Api/v1/accounts (#8)

* start work on accounts module

* plodding away on the accounts endpoint

* groundwork for other account routes

* add password validator

* validation utils

* require account approval flags

* comments

* comments

* go fmt

* comments

* add distributor stub

* rename api to federator

* tidy a bit

* validate new account requests

* rename r router

* comments

* add domain blocks

* add some more shortcuts

* add some more shortcuts

* check email + username availability

* email block checking for signups

* chunking away at it

* tick off a few more things

* some fiddling with tests

* add mock package

* relocate repo

* move mocks around

* set app id on new signups

* initialize oauth server properly

* rename oauth server

* proper mocking tests

* go fmt ./...

* add required fields

* change name of func

* move validation to account.go

* more tests!

* add some file utility tools

* add mediaconfig

* new shortcut

* add some more fields

* add followrequest model

* add notify

* update mastotypes

* mock out storage interface

* start building media interface

* start on update credentials

* mess about with media a bit more

* test image manipulation

* media more or less working

* account update nearly working

* rearranging my package ;) ;) ;)

* phew big stuff!!!!

* fix type checking

* *fiddles*

* Add CreateTables func

* account registration flow working

* tidy

* script to step through auth flow

* add a lil helper for generating user uris

* fiddling with federation a bit

* update progress

* Tidying and linting
This commit is contained in:
Tobi Smethurst
2021-04-01 20:46:45 +02:00
committed by GitHub
parent aa9ce272dc
commit 71a49e2b43
94 changed files with 6585 additions and 955 deletions

View File

@ -20,8 +20,12 @@ package db
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"net/mail"
"regexp"
"strings"
"time"
@ -30,14 +34,17 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/gtsmodel"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"golang.org/x/crypto/bcrypt"
)
// postgresService satisfies the DB interface
type postgresService struct {
config *config.DBConfig
config *config.Config
conn *pg.DB
log *logrus.Entry
cancel context.CancelFunc
@ -46,7 +53,7 @@ type postgresService struct {
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (*postgresService, error) {
func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry) (DB, error) {
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@ -98,18 +105,18 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
return nil, errors.New("db connection timeout")
}
// we can confidently return this useable postgres service now
return &postgresService{
config: c.DBConfig,
conn: conn,
log: log,
cancel: cancel,
federationDB: newPostgresFederation(conn),
}, nil
}
ps := &postgresService{
config: c,
conn: conn,
log: log,
cancel: cancel,
}
func (ps *postgresService) Federation() pub.Database {
return ps.federationDB
federatingDB := newFederatingDB(ps, c)
ps.federationDB = federatingDB
// we can confidently return this useable postgres service now
return ps, nil
}
/*
@ -168,9 +175,29 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
}
/*
EXTRA FUNCTIONS
FEDERATION FUNCTIONALITY
*/
func (ps *postgresService) Federation() pub.Database {
return ps.federationDB
}
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
@ -181,11 +208,15 @@ func (ps *postgresService) Stop(ctx context.Context) error {
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
(*model.Account)(nil),
(*model.Status)(nil),
(*model.User)(nil),
}
ps.log.Info("creating db schema")
@ -202,32 +233,35 @@ func (ps *postgresService) CreateSchema(ctx context.Context) error {
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) GetByID(id string, i interface{}) error {
return ps.conn.Model(i).Where("id = ?", id).Select()
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
return ps.conn.Model(i).Select()
if err := ps.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) Put(i interface{}) error {
@ -236,16 +270,393 @@ func (ps *postgresService) Put(i interface{}) error {
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
if _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
_, err := ps.conn.Model(i).Where("id = ?", id).Delete()
return err
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) GetAccountByUserID(userID string, account *model.Account) error {
user := &model.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error {
if err := ps.conn.Model(followers).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetStatusesByAccountID(accountID string, statuses *[]model.Status) error {
if err := ps.conn.Model(statuses).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetStatusesByTimeDescending(accountID string, statuses *[]model.Status, limit int) error {
q := ps.conn.Model(statuses).Order("created_at DESC")
if limit != 0 {
q = q.Limit(limit)
}
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *model.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) IsUsernameAvailable(username string) error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
uris := util.GenerateURIs(username, ps.config.Protocol, ps.config.Host)
a := &model.Account{
Username: username,
DisplayName: username,
Reason: reason,
URL: uris.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
ActorType: "Person",
URI: uris.UserURI,
InboxURL: uris.InboxURL,
OutboxURL: uris.OutboxURL,
FollowersURL: uris.FollowersURL,
FeaturedCollectionURL: uris.CollectionURL,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
u := &model.User{
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP,
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
_, err := ps.conn.Model(mediaAttachment).Insert()
return err
}
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
return err
}
return nil
}
/*
CONVERSION FUNCTIONS
*/
// AccountToMastoSensitive takes an internal account model and transforms it into an account ready to be served through the API.
// The resulting account fits the specifications for the path /api/v1/accounts/verify_credentials, as described here:
// https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user
// that the account actually belongs to.
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
// we can build this sensitive account easily by first getting the public account....
mastoAccount, err := ps.AccountToMastoPublic(a)
if err != nil {
return nil, err
}
// then adding the Source object to it...
// check pending follow requests aimed at this account
fr := []model.FollowRequest{}
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting follow requests: %s", err)
}
}
var frc int
if fr != nil {
frc = len(fr)
}
mastoAccount.Source = &mastotypes.Source{
Privacy: a.Privacy,
Sensitive: a.Sensitive,
Language: a.Language,
Note: a.Note,
Fields: mastoAccount.Fields,
FollowRequestsCount: frc,
}
return mastoAccount, nil
}
func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) {
// count followers
followers := []model.Follow{}
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting followers: %s", err)
}
}
var followersCount int
if followers != nil {
followersCount = len(followers)
}
// count following
following := []model.Follow{}
if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting following: %s", err)
}
}
var followingCount int
if following != nil {
followingCount = len(following)
}
// count statuses
statuses := []model.Status{}
if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting last statuses: %s", err)
}
}
var statusesCount int
if statuses != nil {
statusesCount = len(statuses)
}
// check when the last status was
lastStatus := &model.Status{}
if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting last status: %s", err)
}
}
var lastStatusAt string
if lastStatus != nil {
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
}
// build the avatar and header URLs
avi := &model.MediaAttachment{}
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting avatar: %s", err)
}
}
aviURL := avi.File.Path
aviURLStatic := avi.Thumbnail.Path
header := &model.MediaAttachment{}
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
if _, ok := err.(ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting header: %s", err)
}
}
headerURL := header.File.Path
headerURLStatic := header.Thumbnail.Path
// get the fields set on this account
fields := []mastotypes.Field{}
for _, f := range a.Fields {
mField := mastotypes.Field{
Name: f.Name,
Value: f.Value,
}
if !f.VerifiedAt.IsZero() {
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
}
fields = append(fields, mField)
}
var acct string
if a.Domain != "" {
// this is a remote user
acct = fmt.Sprintf("%s@%s", a.Username, a.Domain)
} else {
// this is a local user
acct = a.Username
}
return &mastotypes.Account{
ID: a.ID,
Username: a.Username,
Acct: acct,
DisplayName: a.DisplayName,
Locked: a.Locked,
Bot: a.Bot,
CreatedAt: a.CreatedAt.Format(time.RFC3339),
Note: a.Note,
URL: a.URL,
Avatar: aviURL,
AvatarStatic: aviURLStatic,
Header: headerURL,
HeaderStatic: headerURLStatic,
FollowersCount: followersCount,
FollowingCount: followingCount,
StatusesCount: statusesCount,
LastStatusAt: lastStatusAt,
Emojis: nil, // TODO: implement this
Fields: fields,
}, nil
}