plodding away on the accounts endpoint
This commit is contained in:
parent
7139116e5d
commit
0ea69345b9
|
@ -25,59 +25,122 @@ import (
|
||||||
|
|
||||||
"github.com/go-fed/activity/pub"
|
"github.com/go-fed/activity/pub"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dbTypePostgres string = "POSTGRES"
|
const dbTypePostgres string = "POSTGRES"
|
||||||
|
|
||||||
|
type ErrNoEntries struct{}
|
||||||
|
|
||||||
|
func (e ErrNoEntries) Error() string {
|
||||||
|
return "no entries"
|
||||||
|
}
|
||||||
|
|
||||||
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
|
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
|
||||||
|
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
|
||||||
|
// by whatever is returned from the database.
|
||||||
type DB interface {
|
type DB interface {
|
||||||
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
|
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
|
||||||
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
|
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
|
||||||
Federation() pub.Database
|
Federation() pub.Database
|
||||||
|
|
||||||
// CreateTable creates a table for the given interface
|
/*
|
||||||
|
BASIC DB FUNCTIONALITY
|
||||||
|
*/
|
||||||
|
|
||||||
|
// CreateTable creates a table for the given interface.
|
||||||
|
// For implementations that don't use tables, this can just return nil.
|
||||||
CreateTable(i interface{}) error
|
CreateTable(i interface{}) error
|
||||||
|
|
||||||
// DropTable drops the table for the given interface
|
// DropTable drops the table for the given interface.
|
||||||
|
// For implementations that don't use tables, this can just return nil.
|
||||||
DropTable(i interface{}) error
|
DropTable(i interface{}) error
|
||||||
|
|
||||||
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
|
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
|
||||||
|
// If the database implementation doesn't need to be stopped, this can just return nil.
|
||||||
Stop(ctx context.Context) error
|
Stop(ctx context.Context) error
|
||||||
|
|
||||||
// IsHealthy should return nil if the database connection is healthy, or an error if not
|
// IsHealthy should return nil if the database connection is healthy, or an error if not.
|
||||||
IsHealthy(ctx context.Context) error
|
IsHealthy(ctx context.Context) error
|
||||||
|
|
||||||
// GetByID gets one entry by its id.
|
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
|
||||||
|
// for other implementations (for example, in-memory) it might just be the key of a map.
|
||||||
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
GetByID(id string, i interface{}) error
|
GetByID(id string, i interface{}) error
|
||||||
|
|
||||||
// GetWhere gets one entry where key = value
|
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
|
||||||
|
// name of the key to select from.
|
||||||
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
GetWhere(key string, value interface{}, i interface{}) error
|
GetWhere(key string, value interface{}, i interface{}) error
|
||||||
|
|
||||||
// GetAll gets all entries of interface type i
|
// GetAll will try to get all entries of type i.
|
||||||
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
GetAll(i interface{}) error
|
GetAll(i interface{}) error
|
||||||
|
|
||||||
// Put stores i
|
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
|
||||||
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
Put(i interface{}) error
|
Put(i interface{}) error
|
||||||
|
|
||||||
// Update by id updates i with id id
|
// UpdateByID updates i with id id.
|
||||||
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
UpdateByID(id string, i interface{}) error
|
UpdateByID(id string, i interface{}) error
|
||||||
|
|
||||||
// Delete by id removes i with id id
|
// DeleteByID removes i with id id.
|
||||||
|
// If i didn't exist anyway, then no error should be returned.
|
||||||
DeleteByID(id string, i interface{}) error
|
DeleteByID(id string, i interface{}) error
|
||||||
|
|
||||||
// Delete where deletes i where key = value
|
// DeleteWhere deletes i where key = value
|
||||||
|
// If i didn't exist anyway, then no error should be returned.
|
||||||
DeleteWhere(key string, value interface{}, i interface{}) error
|
DeleteWhere(key string, value interface{}, i interface{}) error
|
||||||
|
|
||||||
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID
|
/*
|
||||||
GetAccountByUserID(userID string, account *gtsmodel.Account) error
|
HANDY SHORTCUTS
|
||||||
|
*/
|
||||||
|
|
||||||
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following
|
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
|
||||||
GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error
|
// The given account pointer will be set to the result of the query, whatever it is.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetAccountByUserID(userID string, account *model.Account) error
|
||||||
|
|
||||||
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by
|
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
|
||||||
GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error
|
// The given slice 'following' will be set to the result of the query, whatever it is.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetFollowingByAccountID(accountID string, following *[]model.Follow) error
|
||||||
|
|
||||||
|
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
|
||||||
|
// The given slice 'followers' will be set to the result of the query, whatever it is.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetFollowersByAccountID(accountID string, followers *[]model.Follow) error
|
||||||
|
|
||||||
|
// GetStatusesByAccountID is a shortcut for the common action of fetching a list of statuses produced by accountID.
|
||||||
|
// The given slice 'statuses' will be set to the result of the query, whatever it is.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetStatusesByAccountID(accountID string, statuses *[]model.Status) error
|
||||||
|
|
||||||
|
// GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided
|
||||||
|
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
|
||||||
|
// be very memory intensive so you probably shouldn't do this!
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error
|
||||||
|
|
||||||
|
// GetLastStatusForAccountID simply gets the most recent status by the given account.
|
||||||
|
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
|
||||||
|
// In case of no entries, a 'no entries' error will be returned
|
||||||
|
GetLastStatusForAccountID(accountID string, status *model.Status) error
|
||||||
|
|
||||||
|
/*
|
||||||
|
USEFUL CONVERSION FUNCTIONS
|
||||||
|
*/
|
||||||
|
|
||||||
|
// AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error
|
||||||
|
// if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields,
|
||||||
|
// so serve it only to an authorized user who should have permission to see it.
|
||||||
|
AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new database service that satisfies the DB interface and, by extension,
|
// New returns a new database service that satisfies the DB interface and, by extension,
|
||||||
|
|
|
@ -16,16 +16,15 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
// Package model contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
||||||
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
|
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
|
||||||
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
|
// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
|
||||||
package gtsmodel
|
// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
|
||||||
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc)
|
// Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc)
|
||||||
|
@ -39,20 +38,36 @@ type Account struct {
|
||||||
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
|
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
|
||||||
Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
|
Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
|
||||||
// Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
|
// Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
|
||||||
Domain string `pg:",unique:userdomain"` // username and domain
|
Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other
|
||||||
|
|
||||||
/*
|
/*
|
||||||
ACCOUNT METADATA
|
ACCOUNT METADATA
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Avatar image for this account
|
// File name of the avatar on local storage
|
||||||
Avatar
|
AvatarFileName string
|
||||||
// Header image for this account
|
// Gif? png? jpeg?
|
||||||
Header
|
AvatarContentType string
|
||||||
|
// Size of the avatar in bytes
|
||||||
|
AvatarFileSize int
|
||||||
|
// When was the avatar last updated?
|
||||||
|
AvatarUpdatedAt time.Time `pg:"type:timestamp"`
|
||||||
|
// Where can the avatar be retrieved?
|
||||||
|
AvatarRemoteURL *url.URL `pg:"type:text"`
|
||||||
|
// File name of the header on local storage
|
||||||
|
HeaderFileName string
|
||||||
|
// Gif? png? jpeg?
|
||||||
|
HeaderContentType string
|
||||||
|
// Size of the header in bytes
|
||||||
|
HeaderFileSize int
|
||||||
|
// When was the header last updated?
|
||||||
|
HeaderUpdatedAt time.Time `pg:"type:timestamp"`
|
||||||
|
// Where can the header be retrieved?
|
||||||
|
HeaderRemoteURL *url.URL `pg:"type:text"`
|
||||||
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
|
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
|
||||||
DisplayName string
|
DisplayName string
|
||||||
// a key/value map of fields that this account has added to their profile
|
// a key/value map of fields that this account has added to their profile
|
||||||
Fields map[string]string
|
Fields []Field
|
||||||
// A note that this account has on their profile (ie., the account's bio/description of themselves)
|
// A note that this account has on their profile (ie., the account's bio/description of themselves)
|
||||||
Note string
|
Note string
|
||||||
// Is this a memorial account, ie., has the user passed away?
|
// Is this a memorial account, ie., has the user passed away?
|
||||||
|
@ -85,8 +100,6 @@ type Account struct {
|
||||||
URI string `pg:",unique"`
|
URI string `pg:",unique"`
|
||||||
// At which URL can we see the user account in a web browser?
|
// At which URL can we see the user account in a web browser?
|
||||||
URL string `pg:",unique"`
|
URL string `pg:",unique"`
|
||||||
// RemoteURL where this account is located. Will be empty if this is a local account.
|
|
||||||
RemoteURL string `pg:",unique"`
|
|
||||||
// Last time this account was located using the webfinger API.
|
// Last time this account was located using the webfinger API.
|
||||||
LastWebfingeredAt time.Time `pg:"type:timestamp"`
|
LastWebfingeredAt time.Time `pg:"type:timestamp"`
|
||||||
// Address of this account's activitypub inbox, for sending activity to
|
// Address of this account's activitypub inbox, for sending activity to
|
||||||
|
@ -132,47 +145,8 @@ type Account struct {
|
||||||
SuspensionOrigin int
|
SuspensionOrigin int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avatar represents the avatar for the account for display purposes
|
type Field struct {
|
||||||
type Avatar struct {
|
Name string
|
||||||
// File name of the avatar on local storage
|
Value string
|
||||||
AvatarFileName string
|
VerifiedAt time.Time `pg:"type:timestamp"`
|
||||||
// Gif? png? jpeg?
|
|
||||||
AvatarContentType string
|
|
||||||
AvatarFileSize int
|
|
||||||
AvatarUpdatedAt *time.Time `pg:"type:timestamp"`
|
|
||||||
// Where can we retrieve the avatar?
|
|
||||||
AvatarRemoteURL *url.URL `pg:"type:text"`
|
|
||||||
AvatarStorageSchemaVersion int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Header represents the header of the account for display purposes
|
|
||||||
type Header struct {
|
|
||||||
// File name of the header on local storage
|
|
||||||
HeaderFileName string
|
|
||||||
// Gif? png? jpeg?
|
|
||||||
HeaderContentType string
|
|
||||||
HeaderFileSize int
|
|
||||||
HeaderUpdatedAt *time.Time `pg:"type:timestamp"`
|
|
||||||
// Where can we retrieve the header?
|
|
||||||
HeaderRemoteURL *url.URL `pg:"type:text"`
|
|
||||||
HeaderStorageSchemaVersion int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToMastoSensitive returns this account as a mastodon api type, ready for serialization
|
|
||||||
func (a *Account) ToMastoSensitive() *mastotypes.Account {
|
|
||||||
return &mastotypes.Account{
|
|
||||||
ID: a.ID,
|
|
||||||
Username: a.Username,
|
|
||||||
Acct: a.Username, // equivalent to username for local users only, which sensitive always is
|
|
||||||
DisplayName: a.DisplayName,
|
|
||||||
Locked: a.Locked,
|
|
||||||
Bot: a.Bot,
|
|
||||||
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
|
||||||
Note: a.Note,
|
|
||||||
URL: a.URL,
|
|
||||||
Avatar: a.Avatar.AvatarRemoteURL.String(),
|
|
||||||
AvatarStatic: a.AvatarRemoteURL.String(),
|
|
||||||
Header: a.Header.HeaderRemoteURL.String(),
|
|
||||||
HeaderStatic: a.Header.HeaderRemoteURL.String(),
|
|
||||||
}
|
|
||||||
}
|
}
|
|
@ -16,7 +16,7 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package gtsmodel
|
package model
|
||||||
|
|
||||||
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
|
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package gtsmodel
|
package model
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package gtsmodel
|
package model
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package gtsmodel
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
|
@ -31,7 +31,8 @@ import (
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/go-pg/pg/v10"
|
||||||
"github.com/go-pg/pg/v10/orm"
|
"github.com/go-pg/pg/v10/orm"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ type postgresService struct {
|
||||||
|
|
||||||
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
||||||
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
|
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
|
||||||
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
|
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
|
||||||
opts, err := derivePGOptions(c)
|
opts, err := derivePGOptions(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not create postgres service: %s", err)
|
return nil, fmt.Errorf("could not create postgres service: %s", err)
|
||||||
|
@ -108,10 +109,6 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) Federation() pub.Database {
|
|
||||||
return ps.federationDB
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
HANDY STUFF
|
HANDY STUFF
|
||||||
*/
|
*/
|
||||||
|
@ -168,9 +165,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
EXTRA FUNCTIONS
|
FEDERATION FUNCTIONALITY
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
func (ps *postgresService) Federation() pub.Database {
|
||||||
|
return ps.federationDB
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
BASIC DB FUNCTIONALITY
|
||||||
|
*/
|
||||||
|
|
||||||
|
func (ps *postgresService) CreateTable(i interface{}) error {
|
||||||
|
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||||
|
IfNotExists: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) DropTable(i interface{}) error {
|
||||||
|
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||||
|
IfExists: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) Stop(ctx context.Context) error {
|
func (ps *postgresService) Stop(ctx context.Context) error {
|
||||||
ps.log.Info("closing db connection")
|
ps.log.Info("closing db connection")
|
||||||
if err := ps.conn.Close(); err != nil {
|
if err := ps.conn.Close(); err != nil {
|
||||||
|
@ -181,11 +198,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
||||||
|
return ps.conn.Ping(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
(*gtsmodel.Account)(nil),
|
(*model.Account)(nil),
|
||||||
(*gtsmodel.Status)(nil),
|
(*model.Status)(nil),
|
||||||
(*gtsmodel.User)(nil),
|
(*model.User)(nil),
|
||||||
}
|
}
|
||||||
ps.log.Info("creating db schema")
|
ps.log.Info("creating db schema")
|
||||||
|
|
||||||
|
@ -202,32 +223,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
|
||||||
return ps.conn.Ping(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) CreateTable(i interface{}) error {
|
|
||||||
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
|
||||||
IfNotExists: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) DropTable(i interface{}) error {
|
|
||||||
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
|
||||||
IfExists: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
||||||
return ps.conn.Model(i).Where("id = ?", id).Select()
|
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
|
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
|
||||||
return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
|
if err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) GetAll(i interface{}) error {
|
func (ps *postgresService) GetAll(i interface{}) error {
|
||||||
return ps.conn.Model(i).Select()
|
if err := ps.conn.Model(i).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) Put(i interface{}) error {
|
func (ps *postgresService) Put(i interface{}) error {
|
||||||
|
@ -236,34 +260,207 @@ func (ps *postgresService) Put(i interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
||||||
_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
|
if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
||||||
_, err := ps.conn.Model(i).Where("id = ?", id).Delete()
|
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
|
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
|
||||||
_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
|
if _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
|
/*
|
||||||
user := >smodel.User{
|
HANDY SHORTCUTS
|
||||||
|
*/
|
||||||
|
|
||||||
|
func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
|
||||||
|
user := &model.User{
|
||||||
ID: userID,
|
ID: userID,
|
||||||
}
|
}
|
||||||
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return ps.conn.Model(account).Where("id = ?", user.AccountID).Select()
|
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error {
|
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
|
||||||
return ps.conn.Model(following).Where("account_id = ?", accountID).Select()
|
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error {
|
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
|
||||||
return ps.conn.Model(following).Where("target_account_id = ?", accountID).Select()
|
if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
|
||||||
|
if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
|
||||||
|
q := ps.conn.Model(statuses).Order("created_at DESC")
|
||||||
|
if limit != 0 {
|
||||||
|
q = q.Limit(limit)
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
q = q.Where("account_id = ?", accountID)
|
||||||
|
}
|
||||||
|
if err := q.Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
|
||||||
|
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
CONVERSION FUNCTIONS
|
||||||
|
*/
|
||||||
|
|
||||||
|
// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
|
||||||
|
// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
|
||||||
|
// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
|
||||||
|
// that the account actually belongs to.
|
||||||
|
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
|
||||||
|
|
||||||
|
fields := []mastotypes.Field{}
|
||||||
|
for _, f := range a.Fields {
|
||||||
|
mField := mastotypes.Field{
|
||||||
|
Name: f.Name,
|
||||||
|
Value: f.Value,
|
||||||
|
}
|
||||||
|
if !f.VerifiedAt.IsZero() {
|
||||||
|
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
fields = append(fields, mField)
|
||||||
|
}
|
||||||
|
fmt.Printf("fields: %+v", fields)
|
||||||
|
|
||||||
|
// count followers
|
||||||
|
var followers []model.Follow
|
||||||
|
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting followers: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var followersCount int
|
||||||
|
if followers != nil {
|
||||||
|
followersCount = len(followers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// count following
|
||||||
|
var following []model.Follow
|
||||||
|
if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting following: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var followingCount int
|
||||||
|
if following != nil {
|
||||||
|
followingCount = len(following)
|
||||||
|
}
|
||||||
|
|
||||||
|
// count statuses
|
||||||
|
var statuses []model.Status
|
||||||
|
if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting last statuses: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var statusesCount int
|
||||||
|
if statuses != nil {
|
||||||
|
statusesCount = len(statuses)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check when the last status was
|
||||||
|
var lastStatus *model.Status
|
||||||
|
if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting last status: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var lastStatusAt string
|
||||||
|
if lastStatus != nil {
|
||||||
|
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &mastotypes.Account{
|
||||||
|
ID: a.ID,
|
||||||
|
Username: a.Username,
|
||||||
|
Acct: a.Username, // equivalent to username for local users only, which sensitive always is
|
||||||
|
DisplayName: a.DisplayName,
|
||||||
|
Locked: a.Locked,
|
||||||
|
Bot: a.Bot,
|
||||||
|
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
||||||
|
Note: a.Note,
|
||||||
|
URL: a.URL,
|
||||||
|
Avatar: a.AvatarRemoteURL.String(),
|
||||||
|
AvatarStatic: a.AvatarRemoteURL.String(),
|
||||||
|
Header: a.HeaderRemoteURL.String(),
|
||||||
|
HeaderStatic: a.HeaderRemoteURL.String(),
|
||||||
|
FollowersCount: followersCount,
|
||||||
|
FollowingCount: followingCount,
|
||||||
|
StatusesCount: statusesCount,
|
||||||
|
LastStatusAt: lastStatusAt,
|
||||||
|
Source: nil,
|
||||||
|
Emojis: nil,
|
||||||
|
Fields: fields,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,13 +19,13 @@
|
||||||
package account
|
package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/module"
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
@ -56,19 +56,33 @@ func (m *accountModule) Route(r router.Router) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountVerifyGETHandler serves a user's account details to them IF they reached this
|
||||||
|
// handler while in possession of a valid token, according to the oauth middleware.
|
||||||
func (m *accountModule) AccountVerifyGETHandler(c *gin.Context) {
|
func (m *accountModule) AccountVerifyGETHandler(c *gin.Context) {
|
||||||
s := sessions.Default(c)
|
i, ok := c.Get(oauth.SessionAuthorizedUser)
|
||||||
userID, ok := s.Get(oauth.SessionAuthorizedUser).(string)
|
fmt.Println(i)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, ok := (i).(string)
|
||||||
if !ok || userID == "" {
|
if !ok || userID == "" {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acct := >smodel.Account{}
|
acct := &model.Account{}
|
||||||
if err := m.db.GetAccountByUserID(userID, acct); err != nil {
|
if err := m.db.GetAccountByUserID(userID, acct); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, acct.ToMastoSensitive())
|
acctSensitive, err := m.db.AccountToMastoSensitive(acct)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, acctSensitive)
|
||||||
}
|
}
|
||||||
|
|
233
internal/module/account/account_test.go
Normal file
233
internal/module/account/account_test.go
Normal file
|
@ -0,0 +1,233 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccountTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
db db.DB
|
||||||
|
testAccountLocal *model.Account
|
||||||
|
testAccountRemote *model.Account
|
||||||
|
testUser *model.User
|
||||||
|
config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||||
|
func (suite *AccountTestSuite) SetupSuite() {
|
||||||
|
c := config.Empty()
|
||||||
|
c.DBConfig = &config.DBConfig{
|
||||||
|
Type: "postgres",
|
||||||
|
Address: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "postgres",
|
||||||
|
Database: "postgres",
|
||||||
|
ApplicationName: "gotosocial",
|
||||||
|
}
|
||||||
|
suite.config = c
|
||||||
|
|
||||||
|
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error encrypting user pass: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error parsing localavatar url: %s", err)
|
||||||
|
}
|
||||||
|
localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error parsing localheader url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acctID := uuid.NewString()
|
||||||
|
suite.testAccountLocal = &model.Account{
|
||||||
|
ID: acctID,
|
||||||
|
Username: "local_account_of_some_kind",
|
||||||
|
AvatarRemoteURL: localAvatar,
|
||||||
|
HeaderRemoteURL: localHeader,
|
||||||
|
DisplayName: "michael caine",
|
||||||
|
Fields: []model.Field{
|
||||||
|
{
|
||||||
|
Name: "come and ave a go",
|
||||||
|
Value: "if you think you're hard enough",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "website",
|
||||||
|
Value: "https://imdb.com",
|
||||||
|
VerifiedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Note: "My name is Michael Caine and i'm a local user.",
|
||||||
|
Discoverable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error parsing avatarURL: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error parsing avatarURL: %s", err)
|
||||||
|
}
|
||||||
|
suite.testAccountRemote = &model.Account{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
Username: "neato_bombeato",
|
||||||
|
Domain: "example.org",
|
||||||
|
|
||||||
|
AvatarFileName: "avatar.png",
|
||||||
|
AvatarContentType: "image/png",
|
||||||
|
AvatarFileSize: 1024,
|
||||||
|
AvatarUpdatedAt: time.Now(),
|
||||||
|
AvatarRemoteURL: avatarURL,
|
||||||
|
|
||||||
|
HeaderFileName: "avatar.png",
|
||||||
|
HeaderContentType: "image/png",
|
||||||
|
HeaderFileSize: 1024,
|
||||||
|
HeaderUpdatedAt: time.Now(),
|
||||||
|
HeaderRemoteURL: headerURL,
|
||||||
|
|
||||||
|
DisplayName: "one cool dude 420",
|
||||||
|
Fields: []model.Field{
|
||||||
|
{
|
||||||
|
Name: "pronouns",
|
||||||
|
Value: "he/they",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "website",
|
||||||
|
Value: "https://imcool.edu",
|
||||||
|
VerifiedAt: time.Now(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Note: "<p>I'm cool as heck!</p>",
|
||||||
|
Discoverable: true,
|
||||||
|
URI: "https://example.org/users/neato_bombeato",
|
||||||
|
URL: "https://example.org/@neato_bombeato",
|
||||||
|
LastWebfingeredAt: time.Now(),
|
||||||
|
InboxURL: "https://example.org/users/neato_bombeato/inbox",
|
||||||
|
OutboxURL: "https://example.org/users/neato_bombeato/outbox",
|
||||||
|
SharedInboxURL: "https://example.org/inbox",
|
||||||
|
FollowersURL: "https://example.org/users/neato_bombeato/followers",
|
||||||
|
FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
|
||||||
|
}
|
||||||
|
suite.testUser = &model.User{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
EncryptedPassword: string(encryptedPassword),
|
||||||
|
Email: "user@example.org",
|
||||||
|
AccountID: acctID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||||
|
func (suite *AccountTestSuite) SetupTest() {
|
||||||
|
|
||||||
|
log := logrus.New()
|
||||||
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
db, err := db.New(context.Background(), suite.config, log)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error creating database connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
suite.db = db
|
||||||
|
|
||||||
|
models := []interface{}{
|
||||||
|
&model.User{},
|
||||||
|
&model.Account{},
|
||||||
|
&model.Follow{},
|
||||||
|
&model.Status{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
if err := suite.db.CreateTable(m); err != nil {
|
||||||
|
logrus.Panicf("db connection error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := suite.db.Put(suite.testAccountLocal); err != nil {
|
||||||
|
logrus.Panicf("could not insert test account into db: %s", err)
|
||||||
|
}
|
||||||
|
if err := suite.db.Put(suite.testUser); err != nil {
|
||||||
|
logrus.Panicf("could not insert test user into db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||||
|
func (suite *AccountTestSuite) TearDownTest() {
|
||||||
|
models := []interface{}{
|
||||||
|
&model.User{},
|
||||||
|
&model.Account{},
|
||||||
|
&model.Follow{},
|
||||||
|
&model.Status{},
|
||||||
|
}
|
||||||
|
for _, m := range models {
|
||||||
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
|
logrus.Panicf("error dropping table: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := suite.db.Stop(context.Background()); err != nil {
|
||||||
|
logrus.Panicf("error closing db connection: %s", err)
|
||||||
|
}
|
||||||
|
suite.db = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestAPIInitialize() {
|
||||||
|
log := logrus.New()
|
||||||
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
|
||||||
|
r, err := router.New(suite.config, log)
|
||||||
|
if err != nil {
|
||||||
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.AttachMiddleware(func(c *gin.Context) {
|
||||||
|
c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
acct := New(suite.config, suite.db)
|
||||||
|
acct.Route(r)
|
||||||
|
|
||||||
|
r.Start()
|
||||||
|
defer r.Stop(context.Background())
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountTestSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(AccountTestSuite))
|
||||||
|
}
|
|
@ -20,7 +20,6 @@ package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
"github.com/gotosocial/oauth2/v4"
|
||||||
|
@ -43,7 +42,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
|
||||||
ID: clientID,
|
ID: clientID,
|
||||||
}
|
}
|
||||||
if err := cs.db.GetByID(clientID, poc); err != nil {
|
if err := cs.db.GetByID(clientID, poc); err != nil {
|
||||||
return nil, fmt.Errorf("database error: %s", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
|
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,7 +136,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
||||||
// try to get the deleted client; we should get an error
|
// try to get the deleted client; we should get an error
|
||||||
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
|
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||||
suite.Assert().Nil(deletedClient)
|
suite.Assert().Nil(deletedClient)
|
||||||
suite.Assert().NotNil(err)
|
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPgClientStoreTestSuite(t *testing.T) {
|
func TestPgClientStoreTestSuite(t *testing.T) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/module"
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
@ -179,7 +179,7 @@ func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
|
||||||
vapidKey := uuid.NewString()
|
vapidKey := uuid.NewString()
|
||||||
|
|
||||||
// generate the application to put in the database
|
// generate the application to put in the database
|
||||||
app := >smodel.Application{
|
app := &model.Application{
|
||||||
Name: form.ClientName,
|
Name: form.ClientName,
|
||||||
Website: form.Website,
|
Website: form.Website,
|
||||||
RedirectURI: form.RedirectURIs,
|
RedirectURI: form.RedirectURIs,
|
||||||
|
@ -287,7 +287,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
app := >smodel.Application{
|
app := &model.Application{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
}
|
}
|
||||||
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
|
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
|
||||||
|
@ -296,7 +296,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
||||||
user := >smodel.User{
|
user := &model.User{
|
||||||
ID: userID,
|
ID: userID,
|
||||||
}
|
}
|
||||||
if err := m.db.GetByID(user.ID, user); err != nil {
|
if err := m.db.GetByID(user.ID, user); err != nil {
|
||||||
|
@ -304,7 +304,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acct := >smodel.Account{
|
acct := &model.Account{
|
||||||
ID: user.AccountID,
|
ID: user.AccountID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,7 +413,6 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
||||||
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
|
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
|
||||||
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
||||||
c.Set(SessionAuthorizedUser, ti.GetUserID())
|
c.Set(SessionAuthorizedUser, ti.GetUserID())
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
l.Trace("continuing with unauthenticated request")
|
l.Trace("continuing with unauthenticated request")
|
||||||
}
|
}
|
||||||
|
@ -437,7 +436,7 @@ func (m *oauthModule) validatePassword(email string, password string) (userid st
|
||||||
}
|
}
|
||||||
|
|
||||||
// first we select the user from the database based on email address, bail if no user found for that email
|
// first we select the user from the database based on email address, bail if no user found for that email
|
||||||
gtsUser := >smodel.User{}
|
gtsUser := &model.User{}
|
||||||
|
|
||||||
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
|
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
|
||||||
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||||
|
|
|
@ -22,12 +22,11 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
"github.com/gotosocial/oauth2/v4"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -40,9 +39,9 @@ type OauthTestSuite struct {
|
||||||
tokenStore oauth2.TokenStore
|
tokenStore oauth2.TokenStore
|
||||||
clientStore oauth2.ClientStore
|
clientStore oauth2.ClientStore
|
||||||
db db.DB
|
db db.DB
|
||||||
testAccount *gtsmodel.Account
|
testAccount *model.Account
|
||||||
testApplication *gtsmodel.Application
|
testApplication *model.Application
|
||||||
testUser *gtsmodel.User
|
testUser *model.User
|
||||||
testClient *oauthClient
|
testClient *oauthClient
|
||||||
config *config.Config
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
@ -76,11 +75,11 @@ func (suite *OauthTestSuite) SetupSuite() {
|
||||||
|
|
||||||
acctID := uuid.NewString()
|
acctID := uuid.NewString()
|
||||||
|
|
||||||
suite.testAccount = >smodel.Account{
|
suite.testAccount = &model.Account{
|
||||||
ID: acctID,
|
ID: acctID,
|
||||||
Username: "test_user",
|
Username: "test_user",
|
||||||
}
|
}
|
||||||
suite.testUser = >smodel.User{
|
suite.testUser = &model.User{
|
||||||
EncryptedPassword: string(encryptedPassword),
|
EncryptedPassword: string(encryptedPassword),
|
||||||
Email: "user@example.org",
|
Email: "user@example.org",
|
||||||
AccountID: acctID,
|
AccountID: acctID,
|
||||||
|
@ -90,7 +89,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
||||||
Secret: "some-secret",
|
Secret: "some-secret",
|
||||||
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
||||||
}
|
}
|
||||||
suite.testApplication = >smodel.Application{
|
suite.testApplication = &model.Application{
|
||||||
Name: "a test application",
|
Name: "a test application",
|
||||||
Website: "https://some-application-website.com",
|
Website: "https://some-application-website.com",
|
||||||
RedirectURI: "http://localhost:8080",
|
RedirectURI: "http://localhost:8080",
|
||||||
|
@ -116,9 +115,9 @@ func (suite *OauthTestSuite) SetupTest() {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&oauthClient{},
|
||||||
&oauthToken{},
|
&oauthToken{},
|
||||||
>smodel.User{},
|
&model.User{},
|
||||||
>smodel.Account{},
|
&model.Account{},
|
||||||
>smodel.Application{},
|
&model.Application{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
|
@ -150,9 +149,9 @@ func (suite *OauthTestSuite) TearDownTest() {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&oauthClient{},
|
||||||
&oauthToken{},
|
&oauthToken{},
|
||||||
>smodel.User{},
|
&model.User{},
|
||||||
>smodel.Account{},
|
&model.Account{},
|
||||||
>smodel.Application{},
|
&model.Application{},
|
||||||
}
|
}
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if err := suite.db.DropTable(m); err != nil {
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
|
@ -179,11 +178,10 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
|
||||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
go r.Start()
|
r.Start()
|
||||||
time.Sleep(60 * time.Second)
|
if err := r.Stop(context.Background()); err != nil {
|
||||||
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
|
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
|
||||||
// curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
|
}
|
||||||
// curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthTestSuite(t *testing.T) {
|
func TestOauthTestSuite(t *testing.T) {
|
||||||
|
|
|
@ -19,8 +19,10 @@
|
||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
@ -40,26 +42,28 @@ type Router interface {
|
||||||
// Start the router
|
// Start the router
|
||||||
Start()
|
Start()
|
||||||
// Stop the router
|
// Stop the router
|
||||||
Stop()
|
Stop(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// router fulfils the Router interface using gin and logrus
|
// router fulfils the Router interface using gin and logrus
|
||||||
type router struct {
|
type router struct {
|
||||||
logger *logrus.Logger
|
logger *logrus.Logger
|
||||||
engine *gin.Engine
|
engine *gin.Engine
|
||||||
|
srv *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the router nicely
|
// Start starts the router nicely
|
||||||
func (s *router) Start() {
|
func (r *router) Start() {
|
||||||
// todo: start gracefully
|
go func() {
|
||||||
if err := s.engine.Run(); err != nil {
|
if err := r.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
s.logger.Panicf("server error: %s", err)
|
r.logger.Fatalf("listen: %s", err)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop shuts down the router nicely
|
// Stop shuts down the router nicely
|
||||||
func (s *router) Stop() {
|
func (s *router) Stop(ctx context.Context) error {
|
||||||
// todo: shut down gracefully
|
return s.srv.Shutdown(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
|
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
|
||||||
|
@ -100,6 +104,10 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) {
|
||||||
return &router{
|
return &router{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
engine: engine,
|
engine: engine,
|
||||||
|
srv: &http.Server{
|
||||||
|
Addr: ":8080",
|
||||||
|
Handler: engine,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ type Account struct {
|
||||||
// Whether the account manually approves follow requests.
|
// Whether the account manually approves follow requests.
|
||||||
Locked bool `json:"locked"`
|
Locked bool `json:"locked"`
|
||||||
// Whether the account has opted into discovery features such as the profile directory.
|
// Whether the account has opted into discovery features such as the profile directory.
|
||||||
Discoverable bool `json:"discoverable"`
|
Discoverable bool `json:"discoverable,omitempty"`
|
||||||
// A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot.
|
// A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot.
|
||||||
Bot bool `json:"bot"`
|
Bot bool `json:"bot"`
|
||||||
// When the account was created. (ISO 8601 Datetime)
|
// When the account was created. (ISO 8601 Datetime)
|
||||||
|
@ -61,9 +61,9 @@ type Account struct {
|
||||||
// Additional metadata attached to a profile as name-value pairs.
|
// Additional metadata attached to a profile as name-value pairs.
|
||||||
Fields []Field `json:"fields"`
|
Fields []Field `json:"fields"`
|
||||||
// An extra entity returned when an account is suspended.
|
// An extra entity returned when an account is suspended.
|
||||||
Suspended bool `json:"suspended"`
|
Suspended bool `json:"suspended,omitempty"`
|
||||||
// When a timed mute will expire, if applicable. (ISO 8601 Datetime)
|
// When a timed mute will expire, if applicable. (ISO 8601 Datetime)
|
||||||
MuteExpiresAt string `json:"mute_expires_at"`
|
MuteExpiresAt string `json:"mute_expires_at,omitempty"`
|
||||||
// An extra entity to be used with API methods to verify credentials and update credentials.
|
// An extra entity to be used with API methods to verify credentials and update credentials.
|
||||||
Source *Source `json:"source"`
|
Source *Source `json:"source"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,6 @@ type Field struct {
|
||||||
Value string `json:"value"`
|
Value string `json:"value"`
|
||||||
|
|
||||||
// OPTIONAL
|
// OPTIONAL
|
||||||
|
|
||||||
// Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL
|
// Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL
|
||||||
VerifiedAt string `json:"verified_at,omitempty"`
|
VerifiedAt string `json:"verified_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user