diff --git a/internal/db/db.go b/internal/db/db.go index 9074c23..54a369e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -25,59 +25,122 @@ import ( "github.com/go-fed/activity/pub" "github.com/gotosocial/gotosocial/internal/config" - "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/sirupsen/logrus" ) const dbTypePostgres string = "POSTGRES" +type ErrNoEntries struct{} + +func (e ErrNoEntries) Error() string { + return "no entries" +} + // DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). +// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated +// by whatever is returned from the database. type DB interface { // Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions. // See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database Federation() pub.Database - // CreateTable creates a table for the given interface + /* + BASIC DB FUNCTIONALITY + */ + + // CreateTable creates a table for the given interface. + // For implementations that don't use tables, this can just return nil. CreateTable(i interface{}) error - // DropTable drops the table for the given interface + // DropTable drops the table for the given interface. + // For implementations that don't use tables, this can just return nil. DropTable(i interface{}) error - // Stop should stop and close the database connection cleanly, returning an error if this is not possible + // Stop should stop and close the database connection cleanly, returning an error if this is not possible. + // If the database implementation doesn't need to be stopped, this can just return nil. Stop(ctx context.Context) error - // IsHealthy should return nil if the database connection is healthy, or an error if not + // IsHealthy should return nil if the database connection is healthy, or an error if not. IsHealthy(ctx context.Context) error - // GetByID gets one entry by its id. + // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, + // for other implementations (for example, in-memory) it might just be the key of a map. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned GetByID(id string, i interface{}) error - // GetWhere gets one entry where key = value + // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the + // name of the key to select from. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned GetWhere(key string, value interface{}, i interface{}) error - // GetAll gets all entries of interface type i + // GetAll will try to get all entries of type i. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. + // In case of no entries, a 'no entries' error will be returned GetAll(i interface{}) error - // Put stores i + // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. Put(i interface{}) error - // Update by id updates i with id id + // UpdateByID updates i with id id. + // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. UpdateByID(id string, i interface{}) error - // Delete by id removes i with id id + // DeleteByID removes i with id id. + // If i didn't exist anyway, then no error should be returned. DeleteByID(id string, i interface{}) error - // Delete where deletes i where key = value + // DeleteWhere deletes i where key = value + // If i didn't exist anyway, then no error should be returned. DeleteWhere(key string, value interface{}, i interface{}) error - // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID - GetAccountByUserID(userID string, account *gtsmodel.Account) error + /* + HANDY SHORTCUTS + */ - // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following - GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error + // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID. + // The given account pointer will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetAccountByUserID(userID string, account *model.Account) error - // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by - GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error + // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following. + // The given slice 'following' will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetFollowingByAccountID(accountID string, following *[]model.Follow) error + + // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by. + // The given slice 'followers' will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetFollowersByAccountID(accountID string, followers *[]model.Follow) error + + // GetStatusesByAccountID is a shortcut for the common action of fetching a list of statuses produced by accountID. + // The given slice 'statuses' will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetStatusesByAccountID(accountID string, statuses *[]model.Status) error + + // GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided + // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can + // be very memory intensive so you probably shouldn't do this! + // In case of no entries, a 'no entries' error will be returned + GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error + + // GetLastStatusForAccountID simply gets the most recent status by the given account. + // The given slice 'status' pointer will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetLastStatusForAccountID(accountID string, status *model.Status) error + + /* + USEFUL CONVERSION FUNCTIONS + */ + + // AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error + // if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields, + // so serve it only to an authorized user who should have permission to see it. + AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) } // New returns a new database service that satisfies the DB interface and, by extension, diff --git a/internal/gtsmodel/README.md b/internal/db/model/README.md similarity index 100% rename from internal/gtsmodel/README.md rename to internal/db/model/README.md diff --git a/internal/gtsmodel/account.go b/internal/db/model/account.go similarity index 73% rename from internal/gtsmodel/account.go rename to internal/db/model/account.go index d11f676..130347c 100644 --- a/internal/gtsmodel/account.go +++ b/internal/db/model/account.go @@ -16,16 +16,15 @@ along with this program. If not, see . */ -// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. +// Package model contains types used *internally* by GoToSocial and added/removed/selected from the database. // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. -// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/ -package gtsmodel +// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir). +// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/ +package model import ( "net/url" "time" - - "github.com/gotosocial/gotosocial/pkg/mastotypes" ) // Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc) @@ -39,20 +38,36 @@ type Account struct { // Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org`` Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other // Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username. - Domain string `pg:",unique:userdomain"` // username and domain + Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other /* ACCOUNT METADATA */ - // Avatar image for this account - Avatar - // Header image for this account - Header + // File name of the avatar on local storage + AvatarFileName string + // Gif? png? jpeg? + AvatarContentType string + // Size of the avatar in bytes + AvatarFileSize int + // When was the avatar last updated? + AvatarUpdatedAt time.Time `pg:"type:timestamp"` + // Where can the avatar be retrieved? + AvatarRemoteURL *url.URL `pg:"type:text"` + // File name of the header on local storage + HeaderFileName string + // Gif? png? jpeg? + HeaderContentType string + // Size of the header in bytes + HeaderFileSize int + // When was the header last updated? + HeaderUpdatedAt time.Time `pg:"type:timestamp"` + // Where can the header be retrieved? + HeaderRemoteURL *url.URL `pg:"type:text"` // DisplayName for this account. Can be empty, then just the Username will be used for display purposes. DisplayName string // a key/value map of fields that this account has added to their profile - Fields map[string]string + Fields []Field // A note that this account has on their profile (ie., the account's bio/description of themselves) Note string // Is this a memorial account, ie., has the user passed away? @@ -85,8 +100,6 @@ type Account struct { URI string `pg:",unique"` // At which URL can we see the user account in a web browser? URL string `pg:",unique"` - // RemoteURL where this account is located. Will be empty if this is a local account. - RemoteURL string `pg:",unique"` // Last time this account was located using the webfinger API. LastWebfingeredAt time.Time `pg:"type:timestamp"` // Address of this account's activitypub inbox, for sending activity to @@ -132,47 +145,8 @@ type Account struct { SuspensionOrigin int } -// Avatar represents the avatar for the account for display purposes -type Avatar struct { - // File name of the avatar on local storage - AvatarFileName string - // Gif? png? jpeg? - AvatarContentType string - AvatarFileSize int - AvatarUpdatedAt *time.Time `pg:"type:timestamp"` - // Where can we retrieve the avatar? - AvatarRemoteURL *url.URL `pg:"type:text"` - AvatarStorageSchemaVersion int -} - -// Header represents the header of the account for display purposes -type Header struct { - // File name of the header on local storage - HeaderFileName string - // Gif? png? jpeg? - HeaderContentType string - HeaderFileSize int - HeaderUpdatedAt *time.Time `pg:"type:timestamp"` - // Where can we retrieve the header? - HeaderRemoteURL *url.URL `pg:"type:text"` - HeaderStorageSchemaVersion int -} - -// ToMastoSensitive returns this account as a mastodon api type, ready for serialization -func (a *Account) ToMastoSensitive() *mastotypes.Account { - return &mastotypes.Account{ - ID: a.ID, - Username: a.Username, - Acct: a.Username, // equivalent to username for local users only, which sensitive always is - DisplayName: a.DisplayName, - Locked: a.Locked, - Bot: a.Bot, - CreatedAt: a.CreatedAt.Format(time.RFC3339), - Note: a.Note, - URL: a.URL, - Avatar: a.Avatar.AvatarRemoteURL.String(), - AvatarStatic: a.AvatarRemoteURL.String(), - Header: a.Header.HeaderRemoteURL.String(), - HeaderStatic: a.Header.HeaderRemoteURL.String(), - } +type Field struct { + Name string + Value string + VerifiedAt time.Time `pg:"type:timestamp"` } diff --git a/internal/gtsmodel/application.go b/internal/db/model/application.go similarity index 99% rename from internal/gtsmodel/application.go rename to internal/db/model/application.go index 1478a24..41a7deb 100644 --- a/internal/gtsmodel/application.go +++ b/internal/db/model/application.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package gtsmodel +package model import "github.com/gotosocial/gotosocial/pkg/mastotypes" diff --git a/internal/gtsmodel/follow.go b/internal/db/model/follow.go similarity index 98% rename from internal/gtsmodel/follow.go rename to internal/db/model/follow.go index e0c6616..4c67b60 100644 --- a/internal/gtsmodel/follow.go +++ b/internal/db/model/follow.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package gtsmodel +package model import "time" diff --git a/internal/gtsmodel/status.go b/internal/db/model/status.go similarity index 99% rename from internal/gtsmodel/status.go rename to internal/db/model/status.go index 1c0e920..d152587 100644 --- a/internal/gtsmodel/status.go +++ b/internal/db/model/status.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package gtsmodel +package model import "time" diff --git a/internal/gtsmodel/user.go b/internal/db/model/user.go similarity index 99% rename from internal/gtsmodel/user.go rename to internal/db/model/user.go index 551cbe2..24de1cc 100644 --- a/internal/gtsmodel/user.go +++ b/internal/db/model/user.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package gtsmodel +package model import ( "net" diff --git a/internal/db/pg.go b/internal/db/pg.go index c07f00e..a033801 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -31,7 +31,8 @@ import ( "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/gotosocial/gotosocial/internal/config" - "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/sirupsen/logrus" ) @@ -46,7 +47,7 @@ type postgresService struct { // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) { +func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) { opts, err := derivePGOptions(c) if err != nil { return nil, fmt.Errorf("could not create postgres service: %s", err) @@ -108,10 +109,6 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry }, nil } -func (ps *postgresService) Federation() pub.Database { - return ps.federationDB -} - /* HANDY STUFF */ @@ -168,9 +165,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { } /* - EXTRA FUNCTIONS + FEDERATION FUNCTIONALITY */ +func (ps *postgresService) Federation() pub.Database { + return ps.federationDB +} + +/* + BASIC DB FUNCTIONALITY +*/ + +func (ps *postgresService) CreateTable(i interface{}) error { + return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (ps *postgresService) DropTable(i interface{}) error { + return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + func (ps *postgresService) Stop(ctx context.Context) error { ps.log.Info("closing db connection") if err := ps.conn.Close(); err != nil { @@ -181,11 +198,15 @@ func (ps *postgresService) Stop(ctx context.Context) error { return nil } +func (ps *postgresService) IsHealthy(ctx context.Context) error { + return ps.conn.Ping(ctx) +} + func (ps *postgresService) CreateSchema(ctx context.Context) error { models := []interface{}{ - (*gtsmodel.Account)(nil), - (*gtsmodel.Status)(nil), - (*gtsmodel.User)(nil), + (*model.Account)(nil), + (*model.Status)(nil), + (*model.User)(nil), } ps.log.Info("creating db schema") @@ -202,32 +223,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error { return nil } -func (ps *postgresService) IsHealthy(ctx context.Context) error { - return ps.conn.Ping(ctx) -} - -func (ps *postgresService) CreateTable(i interface{}) error { - return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (ps *postgresService) DropTable(i interface{}) error { - return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - func (ps *postgresService) GetByID(id string, i interface{}) error { - return ps.conn.Model(i).Where("id = ?", id).Select() + if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + + } + return nil } func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { - return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select() + if err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) GetAll(i interface{}) error { - return ps.conn.Model(i).Select() + if err := ps.conn.Model(i).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) Put(i interface{}) error { @@ -236,34 +260,207 @@ func (ps *postgresService) Put(i interface{}) error { } func (ps *postgresService) UpdateByID(id string, i interface{}) error { - _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert() - return err + if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) DeleteByID(id string, i interface{}) error { - _, err := ps.conn.Model(i).Where("id = ?", id).Delete() - return err + if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { - _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete() - return err + if _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } -func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error { - user := >smodel.User{ +/* + HANDY SHORTCUTS +*/ + +func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error { + user := &model.User{ ID: userID, } if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } return err } - return ps.conn.Model(account).Where("id = ?", user.AccountID).Select() + if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } -func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error { - return ps.conn.Model(following).Where("account_id = ?", accountID).Select() +func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error { + if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil } -func (ps *postgresService) GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error { - return ps.conn.Model(following).Where("target_account_id = ?", accountID).Select() +func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error { + if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error { + if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error { + q := ps.conn.Model(statuses).Order("created_at DESC") + if limit != 0 { + q = q.Limit(limit) + } + if accountID != "" { + q = q.Where("account_id = ?", accountID) + } + if err := q.Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error { + if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil + +} + +/* + CONVERSION FUNCTIONS +*/ + +// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API. +// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here: +// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user +// that the account actually belongs to. +func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { + + fields := []mastotypes.Field{} + for _, f := range a.Fields { + mField := mastotypes.Field{ + Name: f.Name, + Value: f.Value, + } + if !f.VerifiedAt.IsZero() { + mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) + } + fields = append(fields, mField) + } + fmt.Printf("fields: %+v", fields) + + // count followers + var followers []model.Follow + if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting followers: %s", err) + } + } + var followersCount int + if followers != nil { + followersCount = len(followers) + } + + // count following + var following []model.Follow + if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting following: %s", err) + } + } + var followingCount int + if following != nil { + followingCount = len(following) + } + + // count statuses + var statuses []model.Status + if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting last statuses: %s", err) + } + } + var statusesCount int + if statuses != nil { + statusesCount = len(statuses) + } + + // check when the last status was + var lastStatus *model.Status + if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting last status: %s", err) + } + } + var lastStatusAt string + if lastStatus != nil { + lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) + } + + return &mastotypes.Account{ + ID: a.ID, + Username: a.Username, + Acct: a.Username, // equivalent to username for local users only, which sensitive always is + DisplayName: a.DisplayName, + Locked: a.Locked, + Bot: a.Bot, + CreatedAt: a.CreatedAt.Format(time.RFC3339), + Note: a.Note, + URL: a.URL, + Avatar: a.AvatarRemoteURL.String(), + AvatarStatic: a.AvatarRemoteURL.String(), + Header: a.HeaderRemoteURL.String(), + HeaderStatic: a.HeaderRemoteURL.String(), + FollowersCount: followersCount, + FollowingCount: followingCount, + StatusesCount: statusesCount, + LastStatusAt: lastStatusAt, + Source: nil, + Emojis: nil, + Fields: fields, + }, nil } diff --git a/internal/module/account/account.go b/internal/module/account/account.go index 06d40dd..36af90c 100644 --- a/internal/module/account/account.go +++ b/internal/module/account/account.go @@ -19,13 +19,13 @@ package account import ( + "fmt" "net/http" - "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" - "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/module/oauth" "github.com/gotosocial/gotosocial/internal/router" @@ -56,19 +56,33 @@ func (m *accountModule) Route(r router.Router) error { return nil } +// AccountVerifyGETHandler serves a user's account details to them IF they reached this +// handler while in possession of a valid token, according to the oauth middleware. func (m *accountModule) AccountVerifyGETHandler(c *gin.Context) { - s := sessions.Default(c) - userID, ok := s.Get(oauth.SessionAuthorizedUser).(string) + i, ok := c.Get(oauth.SessionAuthorizedUser) + fmt.Println(i) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"}) + return + } + + userID, ok := (i).(string) if !ok || userID == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"}) return } - acct := >smodel.Account{} + acct := &model.Account{} if err := m.db.GetAccountByUserID(userID, acct); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err}) - return + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - c.JSON(http.StatusOK, acct.ToMastoSensitive()) + acctSensitive, err := m.db.AccountToMastoSensitive(acct) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, acctSensitive) } diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go new file mode 100644 index 0000000..37ba727 --- /dev/null +++ b/internal/module/account/account_test.go @@ -0,0 +1,233 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package account + +import ( + "context" + "fmt" + "net/url" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/gotosocial/internal/module/oauth" + "github.com/gotosocial/gotosocial/internal/router" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" + "golang.org/x/crypto/bcrypt" +) + +type AccountTestSuite struct { + suite.Suite + db db.DB + testAccountLocal *model.Account + testAccountRemote *model.Account + testUser *model.User + config *config.Config +} + +// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout +func (suite *AccountTestSuite) SetupSuite() { + c := config.Empty() + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", + } + suite.config = c + + encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + if err != nil { + logrus.Panicf("error encrypting user pass: %s", err) + } + + localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") + if err != nil { + logrus.Panicf("error parsing localavatar url: %s", err) + } + localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png") + if err != nil { + logrus.Panicf("error parsing localheader url: %s", err) + } + + acctID := uuid.NewString() + suite.testAccountLocal = &model.Account{ + ID: acctID, + Username: "local_account_of_some_kind", + AvatarRemoteURL: localAvatar, + HeaderRemoteURL: localHeader, + DisplayName: "michael caine", + Fields: []model.Field{ + { + Name: "come and ave a go", + Value: "if you think you're hard enough", + }, + { + Name: "website", + Value: "https://imdb.com", + VerifiedAt: time.Now(), + }, + }, + Note: "My name is Michael Caine and i'm a local user.", + Discoverable: true, + } + + avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") + if err != nil { + logrus.Panicf("error parsing avatarURL: %s", err) + } + + headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") + if err != nil { + logrus.Panicf("error parsing avatarURL: %s", err) + } + suite.testAccountRemote = &model.Account{ + ID: uuid.NewString(), + Username: "neato_bombeato", + Domain: "example.org", + + AvatarFileName: "avatar.png", + AvatarContentType: "image/png", + AvatarFileSize: 1024, + AvatarUpdatedAt: time.Now(), + AvatarRemoteURL: avatarURL, + + HeaderFileName: "avatar.png", + HeaderContentType: "image/png", + HeaderFileSize: 1024, + HeaderUpdatedAt: time.Now(), + HeaderRemoteURL: headerURL, + + DisplayName: "one cool dude 420", + Fields: []model.Field{ + { + Name: "pronouns", + Value: "he/they", + }, + { + Name: "website", + Value: "https://imcool.edu", + VerifiedAt: time.Now(), + }, + }, + Note: "

I'm cool as heck!

", + Discoverable: true, + URI: "https://example.org/users/neato_bombeato", + URL: "https://example.org/@neato_bombeato", + LastWebfingeredAt: time.Now(), + InboxURL: "https://example.org/users/neato_bombeato/inbox", + OutboxURL: "https://example.org/users/neato_bombeato/outbox", + SharedInboxURL: "https://example.org/inbox", + FollowersURL: "https://example.org/users/neato_bombeato/followers", + FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", + } + suite.testUser = &model.User{ + ID: uuid.NewString(), + EncryptedPassword: string(encryptedPassword), + Email: "user@example.org", + AccountID: acctID, + } +} + +// SetupTest creates a postgres connection and creates the oauth_clients table before each test +func (suite *AccountTestSuite) SetupTest() { + + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + db, err := db.New(context.Background(), suite.config, log) + if err != nil { + logrus.Panicf("error creating database connection: %s", err) + } + + suite.db = db + + models := []interface{}{ + &model.User{}, + &model.Account{}, + &model.Follow{}, + &model.Status{}, + } + + for _, m := range models { + if err := suite.db.CreateTable(m); err != nil { + logrus.Panicf("db connection error: %s", err) + } + } + + if err := suite.db.Put(suite.testAccountLocal); err != nil { + logrus.Panicf("could not insert test account into db: %s", err) + } + if err := suite.db.Put(suite.testUser); err != nil { + logrus.Panicf("could not insert test user into db: %s", err) + } + +} + +// TearDownTest drops the oauth_clients table and closes the pg connection after each test +func (suite *AccountTestSuite) TearDownTest() { + models := []interface{}{ + &model.User{}, + &model.Account{}, + &model.Follow{}, + &model.Status{}, + } + for _, m := range models { + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) + } + } + if err := suite.db.Stop(context.Background()); err != nil { + logrus.Panicf("error closing db connection: %s", err) + } + suite.db = nil +} + +func (suite *AccountTestSuite) TestAPIInitialize() { + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + + r, err := router.New(suite.config, log) + if err != nil { + suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) + } + + r.AttachMiddleware(func(c *gin.Context) { + c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID) + }) + + acct := New(suite.config, suite.db) + acct.Route(r) + + r.Start() + defer r.Stop(context.Background()) + time.Sleep(10 * time.Second) + +} + +func TestAccountTestSuite(t *testing.T) { + suite.Run(t, new(AccountTestSuite)) +} diff --git a/internal/module/oauth/clientstore.go b/internal/module/oauth/clientstore.go index f99c160..45b518c 100644 --- a/internal/module/oauth/clientstore.go +++ b/internal/module/oauth/clientstore.go @@ -20,7 +20,6 @@ package oauth import ( "context" - "fmt" "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4" @@ -43,7 +42,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli ID: clientID, } if err := cs.db.GetByID(clientID, poc); err != nil { - return nil, fmt.Errorf("database error: %s", err) + return nil, err } return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil } diff --git a/internal/module/oauth/clientstore_test.go b/internal/module/oauth/clientstore_test.go index bca0024..8401142 100644 --- a/internal/module/oauth/clientstore_test.go +++ b/internal/module/oauth/clientstore_test.go @@ -136,7 +136,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { // try to get the deleted client; we should get an error deletedClient, err := cs.GetByID(context.Background(), suite.testClientID) suite.Assert().Nil(deletedClient) - suite.Assert().NotNil(err) + suite.Assert().EqualValues(db.ErrNoEntries{}, err) } func TestPgClientStoreTestSuite(t *testing.T) { diff --git a/internal/module/oauth/oauth.go b/internal/module/oauth/oauth.go index 9536dcd..2f77561 100644 --- a/internal/module/oauth/oauth.go +++ b/internal/module/oauth/oauth.go @@ -34,7 +34,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/gotosocial/gotosocial/internal/db" - "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/pkg/mastotypes" @@ -47,10 +47,10 @@ import ( ) const ( - appsPath = "/api/v1/apps" - authSignInPath = "/auth/sign_in" - oauthTokenPath = "/oauth/token" - oauthAuthorizePath = "/oauth/authorize" + appsPath = "/api/v1/apps" + authSignInPath = "/auth/sign_in" + oauthTokenPath = "/oauth/token" + oauthAuthorizePath = "/oauth/authorize" SessionAuthorizedUser = "authorized_user" ) @@ -179,7 +179,7 @@ func (m *oauthModule) appsPOSTHandler(c *gin.Context) { vapidKey := uuid.NewString() // generate the application to put in the database - app := >smodel.Application{ + app := &model.Application{ Name: form.ClientName, Website: form.Website, RedirectURI: form.RedirectURIs, @@ -287,7 +287,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) return } - app := >smodel.Application{ + app := &model.Application{ ClientID: clientID, } if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil { @@ -296,7 +296,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) { } // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 - user := >smodel.User{ + user := &model.User{ ID: userID, } if err := m.db.GetByID(user.ID, user); err != nil { @@ -304,7 +304,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) { return } - acct := >smodel.Account{ + acct := &model.Account{ ID: user.AccountID, } @@ -413,7 +413,6 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil { l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) c.Set(SessionAuthorizedUser, ti.GetUserID()) - } else { l.Trace("continuing with unauthenticated request") } @@ -437,7 +436,7 @@ func (m *oauthModule) validatePassword(email string, password string) (userid st } // first we select the user from the database based on email address, bail if no user found for that email - gtsUser := >smodel.User{} + gtsUser := &model.User{} if err := m.db.GetWhere("email", email, gtsUser); err != nil { l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) diff --git a/internal/module/oauth/oauth_test.go b/internal/module/oauth/oauth_test.go index adfb40a..7dcff0d 100644 --- a/internal/module/oauth/oauth_test.go +++ b/internal/module/oauth/oauth_test.go @@ -22,12 +22,11 @@ import ( "context" "fmt" "testing" - "time" "github.com/google/uuid" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" - "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" @@ -40,9 +39,9 @@ type OauthTestSuite struct { tokenStore oauth2.TokenStore clientStore oauth2.ClientStore db db.DB - testAccount *gtsmodel.Account - testApplication *gtsmodel.Application - testUser *gtsmodel.User + testAccount *model.Account + testApplication *model.Application + testUser *model.User testClient *oauthClient config *config.Config } @@ -76,11 +75,11 @@ func (suite *OauthTestSuite) SetupSuite() { acctID := uuid.NewString() - suite.testAccount = >smodel.Account{ + suite.testAccount = &model.Account{ ID: acctID, Username: "test_user", } - suite.testUser = >smodel.User{ + suite.testUser = &model.User{ EncryptedPassword: string(encryptedPassword), Email: "user@example.org", AccountID: acctID, @@ -90,7 +89,7 @@ func (suite *OauthTestSuite) SetupSuite() { Secret: "some-secret", Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), } - suite.testApplication = >smodel.Application{ + suite.testApplication = &model.Application{ Name: "a test application", Website: "https://some-application-website.com", RedirectURI: "http://localhost:8080", @@ -116,9 +115,9 @@ func (suite *OauthTestSuite) SetupTest() { models := []interface{}{ &oauthClient{}, &oauthToken{}, - >smodel.User{}, - >smodel.Account{}, - >smodel.Application{}, + &model.User{}, + &model.Account{}, + &model.Application{}, } for _, m := range models { @@ -150,9 +149,9 @@ func (suite *OauthTestSuite) TearDownTest() { models := []interface{}{ &oauthClient{}, &oauthToken{}, - >smodel.User{}, - >smodel.Account{}, - >smodel.Application{}, + &model.User{}, + &model.Account{}, + &model.Application{}, } for _, m := range models { if err := suite.db.DropTable(m); err != nil { @@ -179,11 +178,10 @@ func (suite *OauthTestSuite) TestAPIInitialize() { suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) } - go r.Start() - time.Sleep(60 * time.Second) - // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read - // curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token - // curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080 + r.Start() + if err := r.Stop(context.Background()); err != nil { + suite.FailNow(fmt.Sprintf("error stopping router: %s", err)) + } } func TestOauthTestSuite(t *testing.T) { diff --git a/internal/router/router.go b/internal/router/router.go index 3893503..779a91d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -19,8 +19,10 @@ package router import ( + "context" "crypto/rand" "fmt" + "net/http" "os" "path/filepath" @@ -40,26 +42,28 @@ type Router interface { // Start the router Start() // Stop the router - Stop() + Stop(ctx context.Context) error } // router fulfils the Router interface using gin and logrus type router struct { logger *logrus.Logger engine *gin.Engine + srv *http.Server } // Start starts the router nicely -func (s *router) Start() { - // todo: start gracefully - if err := s.engine.Run(); err != nil { - s.logger.Panicf("server error: %s", err) - } +func (r *router) Start() { + go func() { + if err := r.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + r.logger.Fatalf("listen: %s", err) + } + }() } // Stop shuts down the router nicely -func (s *router) Stop() { - // todo: shut down gracefully +func (s *router) Stop(ctx context.Context) error { + return s.srv.Shutdown(ctx) } // AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path. @@ -100,6 +104,10 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) { return &router{ logger: logger, engine: engine, + srv: &http.Server{ + Addr: ":8080", + Handler: engine, + }, }, nil } diff --git a/pkg/mastotypes/account.go b/pkg/mastotypes/account.go index 031fa7c..ca50c64 100644 --- a/pkg/mastotypes/account.go +++ b/pkg/mastotypes/account.go @@ -31,7 +31,7 @@ type Account struct { // Whether the account manually approves follow requests. Locked bool `json:"locked"` // Whether the account has opted into discovery features such as the profile directory. - Discoverable bool `json:"discoverable"` + Discoverable bool `json:"discoverable,omitempty"` // A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot. Bot bool `json:"bot"` // When the account was created. (ISO 8601 Datetime) @@ -61,9 +61,9 @@ type Account struct { // Additional metadata attached to a profile as name-value pairs. Fields []Field `json:"fields"` // An extra entity returned when an account is suspended. - Suspended bool `json:"suspended"` + Suspended bool `json:"suspended,omitempty"` // When a timed mute will expire, if applicable. (ISO 8601 Datetime) - MuteExpiresAt string `json:"mute_expires_at"` + MuteExpiresAt string `json:"mute_expires_at,omitempty"` // An extra entity to be used with API methods to verify credentials and update credentials. Source *Source `json:"source"` } diff --git a/pkg/mastotypes/field.go b/pkg/mastotypes/field.go index dbfe08c..29b5a18 100644 --- a/pkg/mastotypes/field.go +++ b/pkg/mastotypes/field.go @@ -28,7 +28,6 @@ type Field struct { Value string `json:"value"` // OPTIONAL - // Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL VerifiedAt string `json:"verified_at,omitempty"` }