diff --git a/internal/db/db.go b/internal/db/db.go
index 9074c23..54a369e 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -25,59 +25,122 @@ import (
"github.com/go-fed/activity/pub"
"github.com/gotosocial/gotosocial/internal/config"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/db/model"
+ "github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus"
)
const dbTypePostgres string = "POSTGRES"
+type ErrNoEntries struct{}
+
+func (e ErrNoEntries) Error() string {
+ return "no entries"
+}
+
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
+// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
+// by whatever is returned from the database.
type DB interface {
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
Federation() pub.Database
- // CreateTable creates a table for the given interface
+ /*
+ BASIC DB FUNCTIONALITY
+ */
+
+ // CreateTable creates a table for the given interface.
+ // For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) error
- // DropTable drops the table for the given interface
+ // DropTable drops the table for the given interface.
+ // For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) error
- // Stop should stop and close the database connection cleanly, returning an error if this is not possible
+ // Stop should stop and close the database connection cleanly, returning an error if this is not possible.
+ // If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) error
- // IsHealthy should return nil if the database connection is healthy, or an error if not
+ // IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) error
- // GetByID gets one entry by its id.
+ // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
+ // for other implementations (for example, in-memory) it might just be the key of a map.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) error
- // GetWhere gets one entry where key = value
+ // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
+ // name of the key to select from.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
GetWhere(key string, value interface{}, i interface{}) error
- // GetAll gets all entries of interface type i
+ // GetAll will try to get all entries of type i.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
+ // In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) error
- // Put stores i
+ // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) error
- // Update by id updates i with id id
+ // UpdateByID updates i with id id.
+ // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
- // Delete by id removes i with id id
+ // DeleteByID removes i with id id.
+ // If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
- // Delete where deletes i where key = value
+ // DeleteWhere deletes i where key = value
+ // If i didn't exist anyway, then no error should be returned.
DeleteWhere(key string, value interface{}, i interface{}) error
- // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID
- GetAccountByUserID(userID string, account *gtsmodel.Account) error
+ /*
+ HANDY SHORTCUTS
+ */
- // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following
- GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error
+ // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
+ // The given account pointer will be set to the result of the query, whatever it is.
+ // In case of no entries, a 'no entries' error will be returned
+ GetAccountByUserID(userID string, account *model.Account) error
- // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by
- GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error
+ // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
+ // The given slice 'following' will be set to the result of the query, whatever it is.
+ // In case of no entries, a 'no entries' error will be returned
+ GetFollowingByAccountID(accountID string, following *[]model.Follow) error
+
+ // GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
+ // The given slice 'followers' will be set to the result of the query, whatever it is.
+ // In case of no entries, a 'no entries' error will be returned
+ GetFollowersByAccountID(accountID string, followers *[]model.Follow) error
+
+ // GetStatusesByAccountID is a shortcut for the common action of fetching a list of statuses produced by accountID.
+ // The given slice 'statuses' will be set to the result of the query, whatever it is.
+ // In case of no entries, a 'no entries' error will be returned
+ GetStatusesByAccountID(accountID string, statuses *[]model.Status) error
+
+ // GetStatusesByTimeDescending is a shortcut for getting the most recent statuses. accountID is optional, if not provided
+ // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
+ // be very memory intensive so you probably shouldn't do this!
+ // In case of no entries, a 'no entries' error will be returned
+ GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error
+
+ // GetLastStatusForAccountID simply gets the most recent status by the given account.
+ // The given slice 'status' pointer will be set to the result of the query, whatever it is.
+ // In case of no entries, a 'no entries' error will be returned
+ GetLastStatusForAccountID(accountID string, status *model.Status) error
+
+ /*
+ USEFUL CONVERSION FUNCTIONS
+ */
+
+ // AccountToMastoSensitive takes a db model account as a param, and returns a populated mastotype account, or an error
+ // if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields,
+ // so serve it only to an authorized user who should have permission to see it.
+ AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error)
}
// New returns a new database service that satisfies the DB interface and, by extension,
diff --git a/internal/gtsmodel/README.md b/internal/db/model/README.md
similarity index 100%
rename from internal/gtsmodel/README.md
rename to internal/db/model/README.md
diff --git a/internal/gtsmodel/account.go b/internal/db/model/account.go
similarity index 73%
rename from internal/gtsmodel/account.go
rename to internal/db/model/account.go
index d11f676..130347c 100644
--- a/internal/gtsmodel/account.go
+++ b/internal/db/model/account.go
@@ -16,16 +16,15 @@
along with this program. If not, see .
*/
-// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
+// Package model contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
-// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
-package gtsmodel
+// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
+// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
+package model
import (
"net/url"
"time"
-
- "github.com/gotosocial/gotosocial/pkg/mastotypes"
)
// Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc)
@@ -39,20 +38,36 @@ type Account struct {
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
Username string `pg:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
// Domain of the account, will be empty if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
- Domain string `pg:",unique:userdomain"` // username and domain
+ Domain string `pg:",unique:userdomain"` // username and domain should be unique *with* each other
/*
ACCOUNT METADATA
*/
- // Avatar image for this account
- Avatar
- // Header image for this account
- Header
+ // File name of the avatar on local storage
+ AvatarFileName string
+ // Gif? png? jpeg?
+ AvatarContentType string
+ // Size of the avatar in bytes
+ AvatarFileSize int
+ // When was the avatar last updated?
+ AvatarUpdatedAt time.Time `pg:"type:timestamp"`
+ // Where can the avatar be retrieved?
+ AvatarRemoteURL *url.URL `pg:"type:text"`
+ // File name of the header on local storage
+ HeaderFileName string
+ // Gif? png? jpeg?
+ HeaderContentType string
+ // Size of the header in bytes
+ HeaderFileSize int
+ // When was the header last updated?
+ HeaderUpdatedAt time.Time `pg:"type:timestamp"`
+ // Where can the header be retrieved?
+ HeaderRemoteURL *url.URL `pg:"type:text"`
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
DisplayName string
// a key/value map of fields that this account has added to their profile
- Fields map[string]string
+ Fields []Field
// A note that this account has on their profile (ie., the account's bio/description of themselves)
Note string
// Is this a memorial account, ie., has the user passed away?
@@ -85,8 +100,6 @@ type Account struct {
URI string `pg:",unique"`
// At which URL can we see the user account in a web browser?
URL string `pg:",unique"`
- // RemoteURL where this account is located. Will be empty if this is a local account.
- RemoteURL string `pg:",unique"`
// Last time this account was located using the webfinger API.
LastWebfingeredAt time.Time `pg:"type:timestamp"`
// Address of this account's activitypub inbox, for sending activity to
@@ -132,47 +145,8 @@ type Account struct {
SuspensionOrigin int
}
-// Avatar represents the avatar for the account for display purposes
-type Avatar struct {
- // File name of the avatar on local storage
- AvatarFileName string
- // Gif? png? jpeg?
- AvatarContentType string
- AvatarFileSize int
- AvatarUpdatedAt *time.Time `pg:"type:timestamp"`
- // Where can we retrieve the avatar?
- AvatarRemoteURL *url.URL `pg:"type:text"`
- AvatarStorageSchemaVersion int
-}
-
-// Header represents the header of the account for display purposes
-type Header struct {
- // File name of the header on local storage
- HeaderFileName string
- // Gif? png? jpeg?
- HeaderContentType string
- HeaderFileSize int
- HeaderUpdatedAt *time.Time `pg:"type:timestamp"`
- // Where can we retrieve the header?
- HeaderRemoteURL *url.URL `pg:"type:text"`
- HeaderStorageSchemaVersion int
-}
-
-// ToMastoSensitive returns this account as a mastodon api type, ready for serialization
-func (a *Account) ToMastoSensitive() *mastotypes.Account {
- return &mastotypes.Account{
- ID: a.ID,
- Username: a.Username,
- Acct: a.Username, // equivalent to username for local users only, which sensitive always is
- DisplayName: a.DisplayName,
- Locked: a.Locked,
- Bot: a.Bot,
- CreatedAt: a.CreatedAt.Format(time.RFC3339),
- Note: a.Note,
- URL: a.URL,
- Avatar: a.Avatar.AvatarRemoteURL.String(),
- AvatarStatic: a.AvatarRemoteURL.String(),
- Header: a.Header.HeaderRemoteURL.String(),
- HeaderStatic: a.Header.HeaderRemoteURL.String(),
- }
+type Field struct {
+ Name string
+ Value string
+ VerifiedAt time.Time `pg:"type:timestamp"`
}
diff --git a/internal/gtsmodel/application.go b/internal/db/model/application.go
similarity index 99%
rename from internal/gtsmodel/application.go
rename to internal/db/model/application.go
index 1478a24..41a7deb 100644
--- a/internal/gtsmodel/application.go
+++ b/internal/db/model/application.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package gtsmodel
+package model
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
diff --git a/internal/gtsmodel/follow.go b/internal/db/model/follow.go
similarity index 98%
rename from internal/gtsmodel/follow.go
rename to internal/db/model/follow.go
index e0c6616..4c67b60 100644
--- a/internal/gtsmodel/follow.go
+++ b/internal/db/model/follow.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package gtsmodel
+package model
import "time"
diff --git a/internal/gtsmodel/status.go b/internal/db/model/status.go
similarity index 99%
rename from internal/gtsmodel/status.go
rename to internal/db/model/status.go
index 1c0e920..d152587 100644
--- a/internal/gtsmodel/status.go
+++ b/internal/db/model/status.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package gtsmodel
+package model
import "time"
diff --git a/internal/gtsmodel/user.go b/internal/db/model/user.go
similarity index 99%
rename from internal/gtsmodel/user.go
rename to internal/db/model/user.go
index 551cbe2..24de1cc 100644
--- a/internal/gtsmodel/user.go
+++ b/internal/db/model/user.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package gtsmodel
+package model
import (
"net"
diff --git a/internal/db/pg.go b/internal/db/pg.go
index c07f00e..a033801 100644
--- a/internal/db/pg.go
+++ b/internal/db/pg.go
@@ -31,7 +31,8 @@ import (
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/gotosocial/gotosocial/internal/config"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/db/model"
+ "github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus"
)
@@ -46,7 +47,7 @@ type postgresService struct {
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
+func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@@ -108,10 +109,6 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
}, nil
}
-func (ps *postgresService) Federation() pub.Database {
- return ps.federationDB
-}
-
/*
HANDY STUFF
*/
@@ -168,9 +165,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
- EXTRA FUNCTIONS
+ FEDERATION FUNCTIONALITY
*/
+func (ps *postgresService) Federation() pub.Database {
+ return ps.federationDB
+}
+
+/*
+ BASIC DB FUNCTIONALITY
+*/
+
+func (ps *postgresService) CreateTable(i interface{}) error {
+ return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+}
+
+func (ps *postgresService) DropTable(i interface{}) error {
+ return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
+ IfExists: true,
+ })
+}
+
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
@@ -181,11 +198,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
return nil
}
+func (ps *postgresService) IsHealthy(ctx context.Context) error {
+ return ps.conn.Ping(ctx)
+}
+
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
- (*gtsmodel.Account)(nil),
- (*gtsmodel.Status)(nil),
- (*gtsmodel.User)(nil),
+ (*model.Account)(nil),
+ (*model.Status)(nil),
+ (*model.User)(nil),
}
ps.log.Info("creating db schema")
@@ -202,32 +223,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
return nil
}
-func (ps *postgresService) IsHealthy(ctx context.Context) error {
- return ps.conn.Ping(ctx)
-}
-
-func (ps *postgresService) CreateTable(i interface{}) error {
- return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (ps *postgresService) DropTable(i interface{}) error {
- return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
func (ps *postgresService) GetByID(id string, i interface{}) error {
- return ps.conn.Model(i).Where("id = ?", id).Select()
+ if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+
+ }
+ return nil
}
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
- return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
+ if err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
- return ps.conn.Model(i).Select()
+ if err := ps.conn.Model(i).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) Put(i interface{}) error {
@@ -236,34 +260,207 @@ func (ps *postgresService) Put(i interface{}) error {
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
- _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
- return err
+ if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
- _, err := ps.conn.Model(i).Where("id = ?", id).Delete()
- return err
+ if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
- _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
- return err
+ if _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
-func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
- user := >smodel.User{
+/*
+ HANDY SHORTCUTS
+*/
+
+func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
+ user := &model.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
return err
}
- return ps.conn.Model(account).Where("id = ?", user.AccountID).Select()
+ if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
-func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error {
- return ps.conn.Model(following).Where("account_id = ?", accountID).Select()
+func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
+ if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
}
-func (ps *postgresService) GetFollowersByAccountID(accountID string, following *[]gtsmodel.Follow) error {
- return ps.conn.Model(following).Where("target_account_id = ?", accountID).Select()
+func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
+ if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
+ if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
+ q := ps.conn.Model(statuses).Order("created_at DESC")
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+ if accountID != "" {
+ q = q.Where("account_id = ?", accountID)
+ }
+ if err := q.Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+}
+
+func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
+ if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
+ if err == pg.ErrNoRows {
+ return ErrNoEntries{}
+ }
+ return err
+ }
+ return nil
+
+}
+
+/*
+ CONVERSION FUNCTIONS
+*/
+
+// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
+// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
+// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
+// that the account actually belongs to.
+func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
+
+ fields := []mastotypes.Field{}
+ for _, f := range a.Fields {
+ mField := mastotypes.Field{
+ Name: f.Name,
+ Value: f.Value,
+ }
+ if !f.VerifiedAt.IsZero() {
+ mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
+ }
+ fields = append(fields, mField)
+ }
+ fmt.Printf("fields: %+v", fields)
+
+ // count followers
+ var followers []model.Follow
+ if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting followers: %s", err)
+ }
+ }
+ var followersCount int
+ if followers != nil {
+ followersCount = len(followers)
+ }
+
+ // count following
+ var following []model.Follow
+ if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting following: %s", err)
+ }
+ }
+ var followingCount int
+ if following != nil {
+ followingCount = len(following)
+ }
+
+ // count statuses
+ var statuses []model.Status
+ if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting last statuses: %s", err)
+ }
+ }
+ var statusesCount int
+ if statuses != nil {
+ statusesCount = len(statuses)
+ }
+
+ // check when the last status was
+ var lastStatus *model.Status
+ if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
+ if _, ok := err.(ErrNoEntries); !ok {
+ return nil, fmt.Errorf("error getting last status: %s", err)
+ }
+ }
+ var lastStatusAt string
+ if lastStatus != nil {
+ lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
+ }
+
+ return &mastotypes.Account{
+ ID: a.ID,
+ Username: a.Username,
+ Acct: a.Username, // equivalent to username for local users only, which sensitive always is
+ DisplayName: a.DisplayName,
+ Locked: a.Locked,
+ Bot: a.Bot,
+ CreatedAt: a.CreatedAt.Format(time.RFC3339),
+ Note: a.Note,
+ URL: a.URL,
+ Avatar: a.AvatarRemoteURL.String(),
+ AvatarStatic: a.AvatarRemoteURL.String(),
+ Header: a.HeaderRemoteURL.String(),
+ HeaderStatic: a.HeaderRemoteURL.String(),
+ FollowersCount: followersCount,
+ FollowingCount: followingCount,
+ StatusesCount: statusesCount,
+ LastStatusAt: lastStatusAt,
+ Source: nil,
+ Emojis: nil,
+ Fields: fields,
+ }, nil
}
diff --git a/internal/module/account/account.go b/internal/module/account/account.go
index 06d40dd..36af90c 100644
--- a/internal/module/account/account.go
+++ b/internal/module/account/account.go
@@ -19,13 +19,13 @@
package account
import (
+ "fmt"
"net/http"
- "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/module/oauth"
"github.com/gotosocial/gotosocial/internal/router"
@@ -56,19 +56,33 @@ func (m *accountModule) Route(r router.Router) error {
return nil
}
+// AccountVerifyGETHandler serves a user's account details to them IF they reached this
+// handler while in possession of a valid token, according to the oauth middleware.
func (m *accountModule) AccountVerifyGETHandler(c *gin.Context) {
- s := sessions.Default(c)
- userID, ok := s.Get(oauth.SessionAuthorizedUser).(string)
+ i, ok := c.Get(oauth.SessionAuthorizedUser)
+ fmt.Println(i)
+ if !ok {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"})
+ return
+ }
+
+ userID, ok := (i).(string)
if !ok || userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"})
return
}
- acct := >smodel.Account{}
+ acct := &model.Account{}
if err := m.db.GetAccountByUserID(userID, acct); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": err})
- return
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
}
- c.JSON(http.StatusOK, acct.ToMastoSensitive())
+ acctSensitive, err := m.db.AccountToMastoSensitive(acct)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, acctSensitive)
}
diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go
new file mode 100644
index 0000000..37ba727
--- /dev/null
+++ b/internal/module/account/account_test.go
@@ -0,0 +1,233 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package account
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "testing"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "github.com/gotosocial/gotosocial/internal/config"
+ "github.com/gotosocial/gotosocial/internal/db"
+ "github.com/gotosocial/gotosocial/internal/db/model"
+ "github.com/gotosocial/gotosocial/internal/module/oauth"
+ "github.com/gotosocial/gotosocial/internal/router"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type AccountTestSuite struct {
+ suite.Suite
+ db db.DB
+ testAccountLocal *model.Account
+ testAccountRemote *model.Account
+ testUser *model.User
+ config *config.Config
+}
+
+// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
+func (suite *AccountTestSuite) SetupSuite() {
+ c := config.Empty()
+ c.DBConfig = &config.DBConfig{
+ Type: "postgres",
+ Address: "localhost",
+ Port: 5432,
+ User: "postgres",
+ Password: "postgres",
+ Database: "postgres",
+ ApplicationName: "gotosocial",
+ }
+ suite.config = c
+
+ encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
+ if err != nil {
+ logrus.Panicf("error encrypting user pass: %s", err)
+ }
+
+ localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
+ if err != nil {
+ logrus.Panicf("error parsing localavatar url: %s", err)
+ }
+ localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
+ if err != nil {
+ logrus.Panicf("error parsing localheader url: %s", err)
+ }
+
+ acctID := uuid.NewString()
+ suite.testAccountLocal = &model.Account{
+ ID: acctID,
+ Username: "local_account_of_some_kind",
+ AvatarRemoteURL: localAvatar,
+ HeaderRemoteURL: localHeader,
+ DisplayName: "michael caine",
+ Fields: []model.Field{
+ {
+ Name: "come and ave a go",
+ Value: "if you think you're hard enough",
+ },
+ {
+ Name: "website",
+ Value: "https://imdb.com",
+ VerifiedAt: time.Now(),
+ },
+ },
+ Note: "My name is Michael Caine and i'm a local user.",
+ Discoverable: true,
+ }
+
+ avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
+ if err != nil {
+ logrus.Panicf("error parsing avatarURL: %s", err)
+ }
+
+ headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
+ if err != nil {
+ logrus.Panicf("error parsing avatarURL: %s", err)
+ }
+ suite.testAccountRemote = &model.Account{
+ ID: uuid.NewString(),
+ Username: "neato_bombeato",
+ Domain: "example.org",
+
+ AvatarFileName: "avatar.png",
+ AvatarContentType: "image/png",
+ AvatarFileSize: 1024,
+ AvatarUpdatedAt: time.Now(),
+ AvatarRemoteURL: avatarURL,
+
+ HeaderFileName: "avatar.png",
+ HeaderContentType: "image/png",
+ HeaderFileSize: 1024,
+ HeaderUpdatedAt: time.Now(),
+ HeaderRemoteURL: headerURL,
+
+ DisplayName: "one cool dude 420",
+ Fields: []model.Field{
+ {
+ Name: "pronouns",
+ Value: "he/they",
+ },
+ {
+ Name: "website",
+ Value: "https://imcool.edu",
+ VerifiedAt: time.Now(),
+ },
+ },
+ Note: "
I'm cool as heck!
",
+ Discoverable: true,
+ URI: "https://example.org/users/neato_bombeato",
+ URL: "https://example.org/@neato_bombeato",
+ LastWebfingeredAt: time.Now(),
+ InboxURL: "https://example.org/users/neato_bombeato/inbox",
+ OutboxURL: "https://example.org/users/neato_bombeato/outbox",
+ SharedInboxURL: "https://example.org/inbox",
+ FollowersURL: "https://example.org/users/neato_bombeato/followers",
+ FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
+ }
+ suite.testUser = &model.User{
+ ID: uuid.NewString(),
+ EncryptedPassword: string(encryptedPassword),
+ Email: "user@example.org",
+ AccountID: acctID,
+ }
+}
+
+// SetupTest creates a postgres connection and creates the oauth_clients table before each test
+func (suite *AccountTestSuite) SetupTest() {
+
+ log := logrus.New()
+ log.SetLevel(logrus.TraceLevel)
+ db, err := db.New(context.Background(), suite.config, log)
+ if err != nil {
+ logrus.Panicf("error creating database connection: %s", err)
+ }
+
+ suite.db = db
+
+ models := []interface{}{
+ &model.User{},
+ &model.Account{},
+ &model.Follow{},
+ &model.Status{},
+ }
+
+ for _, m := range models {
+ if err := suite.db.CreateTable(m); err != nil {
+ logrus.Panicf("db connection error: %s", err)
+ }
+ }
+
+ if err := suite.db.Put(suite.testAccountLocal); err != nil {
+ logrus.Panicf("could not insert test account into db: %s", err)
+ }
+ if err := suite.db.Put(suite.testUser); err != nil {
+ logrus.Panicf("could not insert test user into db: %s", err)
+ }
+
+}
+
+// TearDownTest drops the oauth_clients table and closes the pg connection after each test
+func (suite *AccountTestSuite) TearDownTest() {
+ models := []interface{}{
+ &model.User{},
+ &model.Account{},
+ &model.Follow{},
+ &model.Status{},
+ }
+ for _, m := range models {
+ if err := suite.db.DropTable(m); err != nil {
+ logrus.Panicf("error dropping table: %s", err)
+ }
+ }
+ if err := suite.db.Stop(context.Background()); err != nil {
+ logrus.Panicf("error closing db connection: %s", err)
+ }
+ suite.db = nil
+}
+
+func (suite *AccountTestSuite) TestAPIInitialize() {
+ log := logrus.New()
+ log.SetLevel(logrus.TraceLevel)
+
+ r, err := router.New(suite.config, log)
+ if err != nil {
+ suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
+ }
+
+ r.AttachMiddleware(func(c *gin.Context) {
+ c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
+ })
+
+ acct := New(suite.config, suite.db)
+ acct.Route(r)
+
+ r.Start()
+ defer r.Stop(context.Background())
+ time.Sleep(10 * time.Second)
+
+}
+
+func TestAccountTestSuite(t *testing.T) {
+ suite.Run(t, new(AccountTestSuite))
+}
diff --git a/internal/module/oauth/clientstore.go b/internal/module/oauth/clientstore.go
index f99c160..45b518c 100644
--- a/internal/module/oauth/clientstore.go
+++ b/internal/module/oauth/clientstore.go
@@ -20,7 +20,6 @@ package oauth
import (
"context"
- "fmt"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/oauth2/v4"
@@ -43,7 +42,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
ID: clientID,
}
if err := cs.db.GetByID(clientID, poc); err != nil {
- return nil, fmt.Errorf("database error: %s", err)
+ return nil, err
}
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
}
diff --git a/internal/module/oauth/clientstore_test.go b/internal/module/oauth/clientstore_test.go
index bca0024..8401142 100644
--- a/internal/module/oauth/clientstore_test.go
+++ b/internal/module/oauth/clientstore_test.go
@@ -136,7 +136,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
- suite.Assert().NotNil(err)
+ suite.Assert().EqualValues(db.ErrNoEntries{}, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {
diff --git a/internal/module/oauth/oauth.go b/internal/module/oauth/oauth.go
index 9536dcd..2f77561 100644
--- a/internal/module/oauth/oauth.go
+++ b/internal/module/oauth/oauth.go
@@ -34,7 +34,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
@@ -47,10 +47,10 @@ import (
)
const (
- appsPath = "/api/v1/apps"
- authSignInPath = "/auth/sign_in"
- oauthTokenPath = "/oauth/token"
- oauthAuthorizePath = "/oauth/authorize"
+ appsPath = "/api/v1/apps"
+ authSignInPath = "/auth/sign_in"
+ oauthTokenPath = "/oauth/token"
+ oauthAuthorizePath = "/oauth/authorize"
SessionAuthorizedUser = "authorized_user"
)
@@ -179,7 +179,7 @@ func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
vapidKey := uuid.NewString()
// generate the application to put in the database
- app := >smodel.Application{
+ app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
@@ -287,7 +287,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
- app := >smodel.Application{
+ app := &model.Application{
ClientID: clientID,
}
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
@@ -296,7 +296,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
- user := >smodel.User{
+ user := &model.User{
ID: userID,
}
if err := m.db.GetByID(user.ID, user); err != nil {
@@ -304,7 +304,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
return
}
- acct := >smodel.Account{
+ acct := &model.Account{
ID: user.AccountID,
}
@@ -413,7 +413,6 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
c.Set(SessionAuthorizedUser, ti.GetUserID())
-
} else {
l.Trace("continuing with unauthenticated request")
}
@@ -437,7 +436,7 @@ func (m *oauthModule) validatePassword(email string, password string) (userid st
}
// first we select the user from the database based on email address, bail if no user found for that email
- gtsUser := >smodel.User{}
+ gtsUser := &model.User{}
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
diff --git a/internal/module/oauth/oauth_test.go b/internal/module/oauth/oauth_test.go
index adfb40a..7dcff0d 100644
--- a/internal/module/oauth/oauth_test.go
+++ b/internal/module/oauth/oauth_test.go
@@ -22,12 +22,11 @@ import (
"context"
"fmt"
"testing"
- "time"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
- "github.com/gotosocial/gotosocial/internal/gtsmodel"
+ "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus"
@@ -40,9 +39,9 @@ type OauthTestSuite struct {
tokenStore oauth2.TokenStore
clientStore oauth2.ClientStore
db db.DB
- testAccount *gtsmodel.Account
- testApplication *gtsmodel.Application
- testUser *gtsmodel.User
+ testAccount *model.Account
+ testApplication *model.Application
+ testUser *model.User
testClient *oauthClient
config *config.Config
}
@@ -76,11 +75,11 @@ func (suite *OauthTestSuite) SetupSuite() {
acctID := uuid.NewString()
- suite.testAccount = >smodel.Account{
+ suite.testAccount = &model.Account{
ID: acctID,
Username: "test_user",
}
- suite.testUser = >smodel.User{
+ suite.testUser = &model.User{
EncryptedPassword: string(encryptedPassword),
Email: "user@example.org",
AccountID: acctID,
@@ -90,7 +89,7 @@ func (suite *OauthTestSuite) SetupSuite() {
Secret: "some-secret",
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
}
- suite.testApplication = >smodel.Application{
+ suite.testApplication = &model.Application{
Name: "a test application",
Website: "https://some-application-website.com",
RedirectURI: "http://localhost:8080",
@@ -116,9 +115,9 @@ func (suite *OauthTestSuite) SetupTest() {
models := []interface{}{
&oauthClient{},
&oauthToken{},
- >smodel.User{},
- >smodel.Account{},
- >smodel.Application{},
+ &model.User{},
+ &model.Account{},
+ &model.Application{},
}
for _, m := range models {
@@ -150,9 +149,9 @@ func (suite *OauthTestSuite) TearDownTest() {
models := []interface{}{
&oauthClient{},
&oauthToken{},
- >smodel.User{},
- >smodel.Account{},
- >smodel.Application{},
+ &model.User{},
+ &model.Account{},
+ &model.Application{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
@@ -179,11 +178,10 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
- go r.Start()
- time.Sleep(60 * time.Second)
- // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
- // curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
- // curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
+ r.Start()
+ if err := r.Stop(context.Background()); err != nil {
+ suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
+ }
}
func TestOauthTestSuite(t *testing.T) {
diff --git a/internal/router/router.go b/internal/router/router.go
index 3893503..779a91d 100644
--- a/internal/router/router.go
+++ b/internal/router/router.go
@@ -19,8 +19,10 @@
package router
import (
+ "context"
"crypto/rand"
"fmt"
+ "net/http"
"os"
"path/filepath"
@@ -40,26 +42,28 @@ type Router interface {
// Start the router
Start()
// Stop the router
- Stop()
+ Stop(ctx context.Context) error
}
// router fulfils the Router interface using gin and logrus
type router struct {
logger *logrus.Logger
engine *gin.Engine
+ srv *http.Server
}
// Start starts the router nicely
-func (s *router) Start() {
- // todo: start gracefully
- if err := s.engine.Run(); err != nil {
- s.logger.Panicf("server error: %s", err)
- }
+func (r *router) Start() {
+ go func() {
+ if err := r.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ r.logger.Fatalf("listen: %s", err)
+ }
+ }()
}
// Stop shuts down the router nicely
-func (s *router) Stop() {
- // todo: shut down gracefully
+func (s *router) Stop(ctx context.Context) error {
+ return s.srv.Shutdown(ctx)
}
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
@@ -100,6 +104,10 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) {
return &router{
logger: logger,
engine: engine,
+ srv: &http.Server{
+ Addr: ":8080",
+ Handler: engine,
+ },
}, nil
}
diff --git a/pkg/mastotypes/account.go b/pkg/mastotypes/account.go
index 031fa7c..ca50c64 100644
--- a/pkg/mastotypes/account.go
+++ b/pkg/mastotypes/account.go
@@ -31,7 +31,7 @@ type Account struct {
// Whether the account manually approves follow requests.
Locked bool `json:"locked"`
// Whether the account has opted into discovery features such as the profile directory.
- Discoverable bool `json:"discoverable"`
+ Discoverable bool `json:"discoverable,omitempty"`
// A presentational flag. Indicates that the account may perform automated actions, may not be monitored, or identifies as a robot.
Bot bool `json:"bot"`
// When the account was created. (ISO 8601 Datetime)
@@ -61,9 +61,9 @@ type Account struct {
// Additional metadata attached to a profile as name-value pairs.
Fields []Field `json:"fields"`
// An extra entity returned when an account is suspended.
- Suspended bool `json:"suspended"`
+ Suspended bool `json:"suspended,omitempty"`
// When a timed mute will expire, if applicable. (ISO 8601 Datetime)
- MuteExpiresAt string `json:"mute_expires_at"`
+ MuteExpiresAt string `json:"mute_expires_at,omitempty"`
// An extra entity to be used with API methods to verify credentials and update credentials.
Source *Source `json:"source"`
}
diff --git a/pkg/mastotypes/field.go b/pkg/mastotypes/field.go
index dbfe08c..29b5a18 100644
--- a/pkg/mastotypes/field.go
+++ b/pkg/mastotypes/field.go
@@ -28,7 +28,6 @@ type Field struct {
Value string `json:"value"`
// OPTIONAL
-
// Timestamp of when the server verified a URL value for a rel="me” link. String (ISO 8601 Datetime) if value is a verified URL
VerifiedAt string `json:"verified_at,omitempty"`
}