From 6d81df50cef20cb5f321614f45bacb9f84ce24ce Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Mon, 22 Mar 2021 22:23:42 +0100 Subject: [PATCH] lots of restructuring --- internal/api/server.go | 91 -------- internal/db/actions.go | 11 +- internal/db/db.go | 47 ++-- internal/db/pg-fed.go | 137 ++++++++++++ internal/db/{postgres.go => pg.go} | 208 +++++------------- internal/email/email.go | 2 +- internal/federation/federation.go | 4 +- internal/gotosocial/actions.go | 6 +- internal/gotosocial/gotosocial.go | 8 +- internal/gtsmodel/account.go | 2 +- internal/gtsmodel/application.go | 8 +- internal/gtsmodel/status.go | 30 +-- .../account/account.go} | 20 +- internal/module/module.go | 29 +++ internal/{ => module}/oauth/README.md | 2 +- .../oauth/clientstore.go} | 34 ++- .../oauth/clientstore_test.go} | 74 +++++-- internal/{ => module}/oauth/oauth.go | 122 +++++----- internal/{ => module}/oauth/oauth_test.go | 87 +++++--- .../oauth/tokenstore.go} | 96 ++++---- internal/router/router.go | 120 ++++++++++ 21 files changed, 674 insertions(+), 464 deletions(-) delete mode 100644 internal/api/server.go create mode 100644 internal/db/pg-fed.go rename internal/db/{postgres.go => pg.go} (56%) rename internal/{api/route_statuses.go => module/account/account.go} (65%) create mode 100644 internal/module/module.go rename internal/{ => module}/oauth/README.md (84%) rename internal/{oauth/pgclientstore.go => module/oauth/clientstore.go} (56%) rename internal/{oauth/pgclientstore_test.go => module/oauth/clientstore_test.go} (59%) rename internal/{ => module}/oauth/oauth.go (81%) rename internal/{ => module}/oauth/oauth_test.go (62%) rename internal/{oauth/pgtokenstore.go => module/oauth/tokenstore.go} (72%) create mode 100644 internal/router/router.go diff --git a/internal/api/server.go b/internal/api/server.go deleted file mode 100644 index 5f98485..0000000 --- a/internal/api/server.go +++ /dev/null @@ -1,91 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package api - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-gonic/gin" - "github.com/gotosocial/gotosocial/internal/config" - "github.com/sirupsen/logrus" -) - -type Server interface { - AttachHandler(method string, path string, handler gin.HandlerFunc) - AttachMiddleware(handler gin.HandlerFunc) - GetAPIGroup() *gin.RouterGroup - Start() - Stop() -} - -type AddsRoutes interface { - AddRoutes(s Server) error -} - -type server struct { - APIGroup *gin.RouterGroup - logger *logrus.Logger - engine *gin.Engine -} - -func (s *server) GetAPIGroup() *gin.RouterGroup { - return s.APIGroup -} - -func (s *server) Start() { - // todo: start gracefully - if err := s.engine.Run(); err != nil { - s.logger.Panicf("server error: %s", err) - } -} - -func (s *server) Stop() { - // todo: shut down gracefully -} - -func (s *server) AttachHandler(method string, path string, handler gin.HandlerFunc) { - if method == "ANY" { - s.engine.Any(path, handler) - } else { - s.engine.Handle(method, path, handler) - } -} - -func (s *server) AttachMiddleware(middleware gin.HandlerFunc) { - s.engine.Use(middleware) -} - -func New(config *config.Config, logger *logrus.Logger) Server { - engine := gin.New() - store := memstore.NewStore([]byte("authentication-key"), []byte("encryption-keyencryption-key----")) - engine.Use(sessions.Sessions("gotosocial-session", store)) - cwd, _ := os.Getwd() - tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) - logger.Debugf("loading templates from %s", tmPath) - engine.LoadHTMLGlob(tmPath) - return &server{ - APIGroup: engine.Group("/api").Group("/v1"), - logger: logger, - engine: engine, - } -} diff --git a/internal/db/actions.go b/internal/db/actions.go index 6fa7d23..01fb44b 100644 --- a/internal/db/actions.go +++ b/internal/db/actions.go @@ -28,9 +28,10 @@ import ( // Initialize will initialize the database given in the config for use with GoToSocial var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - db, err := New(ctx, c, log) - if err != nil { - return err - } - return db.CreateSchema(ctx) + // db, err := New(ctx, c, log) + // if err != nil { + // return err + // } + return nil + // return db.CreateSchema(ctx) } diff --git a/internal/db/db.go b/internal/db/db.go index 4ea4e1a..9952e5e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -30,30 +30,47 @@ import ( const dbTypePostgres string = "POSTGRES" -// DB provides methods for interacting with an underlying database (for now, just postgres). -// The function mapping lines up with the DB interface described in go-fed. -// See here: https://github.com/go-fed/activity/blob/master/pub/database.go +// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). type DB interface { - /* - GO-FED DATABASE FUNCTIONS - */ - pub.Database + // Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions. + // See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database + Federation() pub.Database - /* - ANY ADDITIONAL DESIRED FUNCTIONS - */ + // CreateTable creates a table for the given interface + CreateTable(i interface{}) error - // CreateSchema should populate the database with the required tables - CreateSchema(context.Context) error + // DropTable drops the table for the given interface + DropTable(i interface{}) error // Stop should stop and close the database connection cleanly, returning an error if this is not possible - Stop(context.Context) error + Stop(ctx context.Context) error // IsHealthy should return nil if the database connection is healthy, or an error if not - IsHealthy(context.Context) error + IsHealthy(ctx context.Context) error + + // GetByID gets one entry by its id. + GetByID(id string, i interface{}) error + + // GetWhere gets one entry where key = value + GetWhere(key string, value interface{}, i interface{}) error + + // GetAll gets all entries of interface type i + GetAll(i interface{}) error + + // Put stores i + Put(i interface{}) error + + // Update by id updates i with id id + UpdateByID(id string, i interface{}) error + + // Delete by id removes i with id id + DeleteByID(id string, i interface{}) error + + // Delete where deletes i where key = value + DeleteWhere(key string, value interface{}, i interface{}) error } -// New returns a new database service that satisfies the Service interface and, by extension, +// New returns a new database service that satisfies the DB interface and, by extension, // the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) { switch strings.ToUpper(c.DBConfig.Type) { diff --git a/internal/db/pg-fed.go b/internal/db/pg-fed.go new file mode 100644 index 0000000..ec1957a --- /dev/null +++ b/internal/db/pg-fed.go @@ -0,0 +1,137 @@ +package db + +import ( + "context" + "errors" + "net/url" + "sync" + + "github.com/go-fed/activity/pub" + "github.com/go-fed/activity/streams" + "github.com/go-fed/activity/streams/vocab" + "github.com/go-pg/pg/v10" +) + +type postgresFederation struct { + locks *sync.Map + conn *pg.DB +} + +func newPostgresFederation(conn *pg.DB) pub.Database { + return &postgresFederation{ + locks: new(sync.Map), + conn: conn, + } +} + +/* + GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS +*/ +func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error { + // Before any other Database methods are called, the relevant `id` + // entries are locked to allow for fine-grained concurrency. + + // Strategy: create a new lock, if stored, continue. Otherwise, lock the + // existing mutex. + mu := &sync.Mutex{} + mu.Lock() // Optimistically lock if we do store it. + i, loaded := pf.locks.LoadOrStore(id.String(), mu) + if loaded { + mu = i.(*sync.Mutex) + mu.Lock() + } + return nil +} + +func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error { + // Once Go-Fed is done calling Database methods, the relevant `id` + // entries are unlocked. + + i, ok := pf.locks.Load(id.String()) + if !ok { + return errors.New("missing an id in unlock") + } + mu := i.(*sync.Mutex) + mu.Unlock() + return nil +} + +func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { + return false, nil +} + +func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { + return nil, nil +} + +func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { + return nil +} + +func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { + return false, nil +} + +func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { + return false, nil +} + +func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { + return nil, nil +} + +func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error { + t, err := streams.NewTypeResolver() + if err != nil { + return err + } + if err := t.Resolve(ctx, asType); err != nil { + return err + } + asType.GetTypeName() + return nil +} + +func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error { + return nil +} + +func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error { + return nil +} + +func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { + return nil, nil +} + +func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { + return nil +} + +func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} + +func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} + +func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} diff --git a/internal/db/postgres.go b/internal/db/pg.go similarity index 56% rename from internal/db/postgres.go rename to internal/db/pg.go index dae6b11..2919190 100644 --- a/internal/db/postgres.go +++ b/internal/db/pg.go @@ -22,14 +22,11 @@ import ( "context" "errors" "fmt" - "net/url" "regexp" "strings" - "sync" "time" - "github.com/go-fed/activity/streams" - "github.com/go-fed/activity/streams/vocab" + "github.com/go-fed/activity/pub" "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" @@ -39,13 +36,14 @@ import ( "github.com/sirupsen/logrus" ) +// postgresService satisfies the DB interface type postgresService struct { - config *config.DBConfig - conn *pg.DB - log *logrus.Entry - cancel context.CancelFunc - locks *sync.Map - tokenStore oauth2.TokenStore + config *config.DBConfig + conn *pg.DB + log *logrus.Entry + cancel context.CancelFunc + tokenStore oauth2.TokenStore + federationDB pub.Database } // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. @@ -102,36 +100,20 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry return nil, errors.New("db connection timeout") } - // acc := model.StubAccount() - // if _, err := conn.Model(acc).Returning("id").Insert(); err != nil { - // cancel() - // return nil, fmt.Errorf("db insert error: %s", err) - // } - // log.Infof("created account with id %s", acc.ID) - - // note := &model.Note{ - // Visibility: &model.Visibility{ - // Local: true, - // }, - // CreatedAt: time.Now(), - // UpdatedAt: time.Now(), - // } - // if _, err := conn.WithContext(ctx).Model(note).Returning("id").Insert(); err != nil { - // cancel() - // return nil, fmt.Errorf("db insert error: %s", err) - // } - // log.Infof("created note with id %s", note.ID) - // we can confidently return this useable postgres service now return &postgresService{ - config: c.DBConfig, - conn: conn, - log: log, - cancel: cancel, - locks: &sync.Map{}, + config: c.DBConfig, + conn: conn, + log: log, + cancel: cancel, + federationDB: newPostgresFederation(conn), }, nil } +func (ps *postgresService) Federation() pub.Database { + return ps.federationDB +} + /* HANDY STUFF */ @@ -187,118 +169,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { return options, nil } -/* - GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS -*/ -func (ps *postgresService) Lock(ctx context.Context, id *url.URL) error { - // Before any other Database methods are called, the relevant `id` - // entries are locked to allow for fine-grained concurrency. - - // Strategy: create a new lock, if stored, continue. Otherwise, lock the - // existing mutex. - mu := &sync.Mutex{} - mu.Lock() // Optimistically lock if we do store it. - i, loaded := ps.locks.LoadOrStore(id.String(), mu) - if loaded { - mu = i.(*sync.Mutex) - mu.Lock() - } - return nil -} - -func (ps *postgresService) Unlock(ctx context.Context, id *url.URL) error { - // Once Go-Fed is done calling Database methods, the relevant `id` - // entries are unlocked. - - i, ok := ps.locks.Load(id.String()) - if !ok { - return errors.New("missing an id in unlock") - } - mu := i.(*sync.Mutex) - mu.Unlock() - return nil -} - -func (ps *postgresService) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { - return false, nil -} - -func (ps *postgresService) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { - return nil, nil -} - -func (ps *postgresService) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { - return nil -} - -func (ps *postgresService) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { - return false, nil -} - -func (ps *postgresService) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { - return false, nil -} - -func (ps *postgresService) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { - return nil, nil -} - -func (ps *postgresService) Create(ctx context.Context, asType vocab.Type) error { - t, err := streams.NewTypeResolver() - if err != nil { - return err - } - if err := t.Resolve(ctx, asType); err != nil { - return err - } - asType.GetTypeName() - return nil -} - -func (ps *postgresService) Update(ctx context.Context, asType vocab.Type) error { - return nil -} - -func (ps *postgresService) Delete(ctx context.Context, id *url.URL) error { - return nil -} - -func (ps *postgresService) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { - return nil, nil -} - -func (ps *postgresService) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { - return nil -} - -func (ps *postgresService) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - -func (ps *postgresService) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - -func (ps *postgresService) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - /* EXTRA FUNCTIONS */ @@ -338,6 +208,46 @@ func (ps *postgresService) IsHealthy(ctx context.Context) error { return ps.conn.Ping(ctx) } -func (ps *postgresService) TokenStore() oauth2.TokenStore { - return ps.tokenStore +func (ps *postgresService) CreateTable(i interface{}) error { + return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (ps *postgresService) DropTable(i interface{}) error { + return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + +func (ps *postgresService) GetByID(id string, i interface{}) error { + return ps.conn.Model(i).Where("id = ?", id).Select() +} + +func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { + return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select() +} + +func (ps *postgresService) GetAll(i interface{}) error { + return ps.conn.Model(i).Select() +} + +func (ps *postgresService) Put(i interface{}) error { + _, err := ps.conn.Model(i).Insert(i) + return err +} + +func (ps *postgresService) UpdateByID(id string, i interface{}) error { + _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert() + return err +} + +func (ps *postgresService) DeleteByID(id string, i interface{}) error { + _, err := ps.conn.Model(i).Where("id = ?", id).Delete() + return err +} + +func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { + _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete() + return err } diff --git a/internal/email/email.go b/internal/email/email.go index d70f6c5..3d6a9dd 100644 --- a/internal/email/email.go +++ b/internal/email/email.go @@ -16,5 +16,5 @@ along with this program. If not, see . */ -// package email provides a service for interacting with an SMTP server +// Package email provides a service for interacting with an SMTP server package email diff --git a/internal/federation/federation.go b/internal/federation/federation.go index ebc9102..cbd4eda 100644 --- a/internal/federation/federation.go +++ b/internal/federation/federation.go @@ -30,11 +30,13 @@ import ( "github.com/gotosocial/gotosocial/internal/db" ) +// New returns a go-fed compatible federating actor func New(db db.DB) pub.FederatingActor { fa := &API{} - return pub.NewFederatingActor(fa, fa, db, fa) + return pub.NewFederatingActor(fa, fa, db.Federation(), fa) } +// API implements several go-fed interfaces in one convenient location type API struct { } diff --git a/internal/gotosocial/actions.go b/internal/gotosocial/actions.go index 3d3fdc3..398c0b4 100644 --- a/internal/gotosocial/actions.go +++ b/internal/gotosocial/actions.go @@ -38,9 +38,9 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr return fmt.Errorf("error creating dbservice: %s", err) } - if err := dbService.CreateSchema(ctx); err != nil { - return fmt.Errorf("error creating dbschema: %s", err) - } + // if err := dbService.CreateSchema(ctx); err != nil { + // return fmt.Errorf("error creating dbschema: %s", err) + // } // catch shutdown signals from the operating system sigs := make(chan os.Signal, 1) diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go index 4409e85..d9fb295 100644 --- a/internal/gotosocial/gotosocial.go +++ b/internal/gotosocial/gotosocial.go @@ -22,10 +22,10 @@ import ( "context" "github.com/go-fed/activity/pub" - "github.com/gotosocial/gotosocial/internal/api" "github.com/gotosocial/gotosocial/internal/cache" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/router" ) type Gotosocial interface { @@ -33,11 +33,11 @@ type Gotosocial interface { Stop(context.Context) error } -func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { +func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { return &gotosocial{ db: db, cache: cache, - clientAPI: clientAPI, + apiRouter: apiRouter, federationAPI: federationAPI, config: config, }, nil @@ -46,7 +46,7 @@ func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.Fe type gotosocial struct { db db.DB cache cache.Cache - clientAPI api.Server + apiRouter router.Router federationAPI pub.FederatingActor config *config.Config } diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index 6786014..6c17b90 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -// package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. +// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. // The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/ package gtsmodel diff --git a/internal/gtsmodel/application.go b/internal/gtsmodel/application.go index 0d3265d..fd0fa6a 100644 --- a/internal/gtsmodel/application.go +++ b/internal/gtsmodel/application.go @@ -35,10 +35,10 @@ type Application struct { ClientID string // secret of the associated oauth client entity in the db ClientSecret string - // scopes requested when this app was created - Scopes string - // a vapid key generated for this app when it was created - VapidKey string + // scopes requested when this app was created + Scopes string + // a vapid key generated for this app when it was created + VapidKey string } // ToMastotype returns this application as a mastodon api type, ready for serialization diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index 3a15cf5..1c0e920 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -23,41 +23,41 @@ import "time" // Status represents a user-created 'post' or 'status' in the database, either remote or local type Status struct { // id of the status in the database - ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` + ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` // uri at which this status is reachable - URI string `pg:",unique"` + URI string `pg:",unique"` // web url for viewing this status - URL string `pg:",unique"` + URL string `pg:",unique"` // the html-formatted content of this status - Content string + Content string // when was this status created? - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` // when was this status updated? - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` // is this status from a local account? - Local bool + Local bool // which account posted this status? - AccountID string + AccountID string // id of the status this status is a reply to - InReplyToID string + InReplyToID string // id of the status this status is a boost of - BoostOfID string + BoostOfID string // cw string for this status ContentWarning string // visibility entry for this status - Visibility *Visibility + Visibility *Visibility } // Visibility represents the visibility granularity of a status. It is a combination of flags. type Visibility struct { // Is this status viewable as a direct message? - Direct bool + Direct bool // Is this status viewable to followers? Followers bool // Is this status viewable on the local timeline? - Local bool + Local bool // Is this status boostable but not shown on public timelines? - Unlisted bool + Unlisted bool // Is this status shown on public and federated timelines? - Public bool + Public bool } diff --git a/internal/api/route_statuses.go b/internal/module/account/account.go similarity index 65% rename from internal/api/route_statuses.go rename to internal/module/account/account.go index fe8aa96..d82d96e 100644 --- a/internal/api/route_statuses.go +++ b/internal/module/account/account.go @@ -16,4 +16,22 @@ along with this program. If not, see . */ -package api +package account + +import ( + "github.com/gotosocial/gotosocial/internal/module" + "github.com/gotosocial/gotosocial/internal/router" +) + +type accountModule struct { +} + +// New returns a new account module +func New() module.ClientAPIModule { + return &accountModule{} +} + +// Route attaches all routes from this module to the given router +func (m *accountModule) Route(r router.Router) error { + return nil +} diff --git a/internal/module/module.go b/internal/module/module.go new file mode 100644 index 0000000..8618d28 --- /dev/null +++ b/internal/module/module.go @@ -0,0 +1,29 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface. +package module + +import "github.com/gotosocial/gotosocial/internal/router" + +// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set +// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;) +// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/ +type ClientAPIModule interface { + Route(s router.Router) error +} diff --git a/internal/oauth/README.md b/internal/module/oauth/README.md similarity index 84% rename from internal/oauth/README.md rename to internal/module/oauth/README.md index f739aa6..3d84273 100644 --- a/internal/oauth/README.md +++ b/internal/module/oauth/README.md @@ -1,5 +1,5 @@ # oauth -This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs. +This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. It also provides a handler/middleware for attaching to the Gin engine for validating authenticated users. diff --git a/internal/oauth/pgclientstore.go b/internal/module/oauth/clientstore.go similarity index 56% rename from internal/oauth/pgclientstore.go rename to internal/module/oauth/clientstore.go index 1df46fe..f99c160 100644 --- a/internal/oauth/pgclientstore.go +++ b/internal/module/oauth/clientstore.go @@ -22,55 +22,47 @@ import ( "context" "fmt" - "github.com/go-pg/pg/v10" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4" "github.com/gotosocial/oauth2/v4/models" ) -type pgClientStore struct { - conn *pg.DB +type clientStore struct { + db db.DB } -func NewPGClientStore(conn *pg.DB) oauth2.ClientStore { - pts := &pgClientStore{ - conn: conn, +func newClientStore(db db.DB) oauth2.ClientStore { + pts := &clientStore{ + db: db, } return pts } -func (pcs *pgClientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { +func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { poc := &oauthClient{ ID: clientID, } - if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil { - return nil, fmt.Errorf("error in clientstore getbyid searching for client %s: %s", clientID, err) + if err := cs.db.GetByID(clientID, poc); err != nil { + return nil, fmt.Errorf("database error: %s", err) } return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil } -func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { +func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { poc := &oauthClient{ ID: cli.GetID(), Secret: cli.GetSecret(), Domain: cli.GetDomain(), UserID: cli.GetUserID(), } - _, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert() - if err != nil { - return fmt.Errorf("error in clientstore set: %s", err) - } - return nil + return cs.db.UpdateByID(id, poc) } -func (pcs *pgClientStore) Delete(ctx context.Context, id string) error { +func (cs *clientStore) Delete(ctx context.Context, id string) error { poc := &oauthClient{ ID: id, } - _, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete() - if err != nil { - return fmt.Errorf("error in clientstore delete: %s", err) - } - return nil + return cs.db.DeleteByID(id, poc) } type oauthClient struct { diff --git a/internal/oauth/pgclientstore_test.go b/internal/module/oauth/clientstore_test.go similarity index 59% rename from internal/oauth/pgclientstore_test.go rename to internal/module/oauth/clientstore_test.go index eb011fe..032f4f3 100644 --- a/internal/oauth/pgclientstore_test.go +++ b/internal/module/oauth/clientstore_test.go @@ -1,11 +1,29 @@ + +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ package oauth import ( "context" "testing" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" @@ -13,7 +31,7 @@ import ( type PgClientStoreTestSuite struct { suite.Suite - conn *pg.DB + db db.DB testClientID string testClientSecret string testClientDomain string @@ -32,31 +50,55 @@ func (suite *PgClientStoreTestSuite) SetupSuite() { // SetupTest creates a postgres connection and creates the oauth_clients table before each test func (suite *PgClientStoreTestSuite) SetupTest() { - suite.conn = pg.Connect(&pg.Options{}) - if err := suite.conn.Ping(context.Background()); err != nil { - logrus.Panicf("db connection error: %s", err) + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + c := config.Empty() + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", } - if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }); err != nil { - logrus.Panicf("db connection error: %s", err) + db, err := db.New(context.Background(), c, log) + if err != nil { + logrus.Panicf("error creating database connection: %s", err) + } + + suite.db = db + + models := []interface{}{ + &oauthClient{}, + } + + for _, m := range models { + if err := suite.db.CreateTable(m); err != nil { + logrus.Panicf("db connection error: %s", err) + } } } // TearDownTest drops the oauth_clients table and closes the pg connection after each test func (suite *PgClientStoreTestSuite) TearDownTest() { - if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil { - logrus.Panicf("drop table error: %s", err) + models := []interface{}{ + &oauthClient{}, } - if err := suite.conn.Close(); err != nil { + for _, m := range models { + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) + } + } + if err := suite.db.Stop(context.Background()); err != nil { logrus.Panicf("error closing db connection: %s", err) } - suite.conn = nil + suite.db = nil } func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { // set a new client in the store - cs := NewPGClientStore(suite.conn) + cs := newClientStore(suite.db) if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { suite.FailNow(err.Error()) } @@ -74,7 +116,7 @@ func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { // set a new client in the store - cs := NewPGClientStore(suite.conn) + cs := newClientStore(suite.db) if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/oauth/oauth.go b/internal/module/oauth/oauth.go similarity index 81% rename from internal/oauth/oauth.go rename to internal/module/oauth/oauth.go index f4b7fbb..b0530c6 100644 --- a/internal/oauth/oauth.go +++ b/internal/module/oauth/oauth.go @@ -16,6 +16,13 @@ along with this program. If not, see . */ +// Package oauth is a module that provides oauth functionality to a router. +// It adds the following paths: +// /api/v1/apps +// /auth/sign_in +// /oauth/token +// /oauth/authorize +// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. package oauth import ( @@ -25,10 +32,11 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/go-pg/pg/v10" "github.com/google/uuid" - "github.com/gotosocial/gotosocial/internal/api" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/gotosocial/oauth2/v4" "github.com/gotosocial/oauth2/v4/errors" @@ -39,18 +47,18 @@ import ( ) const ( - outOfBandRedirect = "urn:ietf:wg:oauth:2.0:oob" appsPath = "/api/v1/apps" authSignInPath = "/auth/sign_in" oauthTokenPath = "/oauth/token" oauthAuthorizePath = "/oauth/authorize" ) -type API struct { - manager *manage.Manager - server *server.Server - conn *pg.DB - log *logrus.Logger +// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface +type oauthModule struct { + oauthManager *manage.Manager + oauthServer *server.Server + db db.DB + log *logrus.Logger } type login struct { @@ -58,11 +66,8 @@ type login struct { Password string `form:"password"` } -type code struct { - Code string `form:"code"` -} - -func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.Logger) *API { +// New returns a new oauth module +func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { manager := manage.NewDefaultManager() manager.MapTokenStorage(ts) manager.MapClientStorage(cs) @@ -91,30 +96,31 @@ func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.L log.Errorf("internal response error: %s", re.Error) }) - api := &API{ - manager: manager, - server: srv, - conn: conn, - log: log, + m := &oauthModule{ + oauthManager: manager, + oauthServer: srv, + db: db, + log: log, } - api.server.SetUserAuthorizationHandler(api.userAuthorizationHandler) - api.server.SetClientInfoHandler(server.ClientFormHandler) - return api + m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler) + m.oauthServer.SetClientInfoHandler(server.ClientFormHandler) + return m } -func (a *API) Route(s api.Server) error { - s.AttachHandler(http.MethodPost, appsPath, a.appsPOSTHandler) +// Route satisfies the RESTAPIModule interface +func (m *oauthModule) Route(s router.Router) error { + s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) - s.AttachHandler(http.MethodGet, authSignInPath, a.signInGETHandler) - s.AttachHandler(http.MethodPost, authSignInPath, a.signInPOSTHandler) + s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) + s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) - s.AttachHandler(http.MethodPost, oauthTokenPath, a.tokenPOSTHandler) + s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler) - s.AttachHandler(http.MethodGet, oauthAuthorizePath, a.authorizeGETHandler) - s.AttachHandler(http.MethodPost, oauthAuthorizePath, a.authorizePOSTHandler) + s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler) + s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) - s.AttachMiddleware(a.oauthTokenMiddleware) + s.AttachMiddleware(m.oauthTokenMiddleware) return nil } @@ -125,8 +131,8 @@ func (a *API) Route(s api.Server) error { // appsPOSTHandler should be served at https://example.org/api/v1/apps // It is equivalent to: https://docs.joinmastodon.org/methods/apps/ -func (a *API) appsPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "AppsPOSTHandler") +func (m *oauthModule) appsPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AppsPOSTHandler") l.Trace("entering AppsPOSTHandler") form := &mastotypes.ApplicationPOSTRequest{} @@ -183,7 +189,7 @@ func (a *API) appsPOSTHandler(c *gin.Context) { } // chuck it in the db - if _, err := a.conn.Model(app).Insert(); err != nil { + if err := m.db.Put(app); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -197,7 +203,7 @@ func (a *API) appsPOSTHandler(c *gin.Context) { } // chuck it in the db - if _, err := a.conn.Model(oc).Insert(); err != nil { + if err := m.db.Put(oc); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -209,16 +215,16 @@ func (a *API) appsPOSTHandler(c *gin.Context) { // signInGETHandler should be served at https://example.org/auth/sign_in. // The idea is to present a sign in page to the user, where they can enter their username and password. // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler -func (a *API) signInGETHandler(c *gin.Context) { - a.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") +func (m *oauthModule) signInGETHandler(c *gin.Context) { + m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) } // signInPOSTHandler should be served at https://example.org/auth/sign_in. // The idea is to present a sign in page to the user, where they can enter their username and password. // The handler will then redirect to the auth handler served at /auth -func (a *API) signInPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "SignInPOSTHandler") +func (m *oauthModule) signInPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "SignInPOSTHandler") s := sessions.Default(c) form := &login{} if err := c.ShouldBind(form); err != nil { @@ -227,7 +233,7 @@ func (a *API) signInPOSTHandler(c *gin.Context) { } l.Tracef("parsed form: %+v", form) - userid, err := a.validatePassword(form.Email, form.Password) + userid, err := m.validatePassword(form.Email, form.Password) if err != nil { c.String(http.StatusForbidden, err.Error()) return @@ -246,10 +252,10 @@ func (a *API) signInPOSTHandler(c *gin.Context) { // tokenPOSTHandler should be served as a POST at https://example.org/oauth/token // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token -func (a *API) tokenPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "TokenPOSTHandler") +func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "TokenPOSTHandler") l.Trace("entered TokenPOSTHandler") - if err := a.server.HandleTokenRequest(c.Writer, c.Request); err != nil { + if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } @@ -257,8 +263,8 @@ func (a *API) tokenPOSTHandler(c *gin.Context) { // authorizeGETHandler should be served as GET at https://example.org/oauth/authorize // The idea here is to present an oauth authorize page to the user, with a button // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (a *API) authorizeGETHandler(c *gin.Context) { - l := a.log.WithField("func", "AuthorizeGETHandler") +func (m *oauthModule) authorizeGETHandler(c *gin.Context) { + l := m.log.WithField("func", "AuthorizeGETHandler") s := sessions.Default(c) // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow @@ -283,7 +289,7 @@ func (a *API) authorizeGETHandler(c *gin.Context) { app := >smodel.Application{ ClientID: clientID, } - if err := a.conn.Model(app).Where("client_id = ?", app.ClientID).Select(); err != nil { + if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) return } @@ -292,7 +298,7 @@ func (a *API) authorizeGETHandler(c *gin.Context) { user := >smodel.User{ ID: userID, } - if err := a.conn.Model(user).Where("id = ?", user.ID).Select(); err != nil { + if err := m.db.GetByID(user.ID, user); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -300,7 +306,8 @@ func (a *API) authorizeGETHandler(c *gin.Context) { acct := >smodel.Account{ ID: user.AccountID, } - if err := a.conn.Model(acct).Where("id = ?", acct.ID).Select(); err != nil { + + if err := m.db.GetByID(acct.ID, acct); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -334,8 +341,8 @@ func (a *API) authorizeGETHandler(c *gin.Context) { // At this point we assume that the user has A) logged in and B) accepted that the app should act for them, // so we should proceed with the authentication flow and generate an oauth token for them if we can. // See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (a *API) authorizePOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "AuthorizePOSTHandler") +func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AuthorizePOSTHandler") s := sessions.Default(c) // At this point we know the user has said 'yes' to allowing the application and oauth client @@ -389,7 +396,7 @@ func (a *API) authorizePOSTHandler(c *gin.Context) { l.Tracef("values on request set to %+v", c.Request.Form) // and proceed with authorization using the oauth2 library - if err := a.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { + if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) } } @@ -399,10 +406,10 @@ func (a *API) authorizePOSTHandler(c *gin.Context) { */ // oauthTokenMiddleware -func (a *API) oauthTokenMiddleware(c *gin.Context) { - l := a.log.WithField("func", "ValidatePassword") +func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { + l := m.log.WithField("func", "ValidatePassword") l.Trace("entering OauthTokenMiddleware") - if ti, err := a.server.ValidationBearerToken(c.Request); err == nil { + if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil { l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) c.Set("authenticated_user", ti.GetUserID()) @@ -419,8 +426,8 @@ func (a *API) oauthTokenMiddleware(c *gin.Context) { // The goal is to authenticate the password against the one for that email // address stored in the database. If OK, we return the userid (a uuid) for that user, // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (a *API) validatePassword(email string, password string) (userid string, err error) { - l := a.log.WithField("func", "ValidatePassword") +func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { + l := m.log.WithField("func", "ValidatePassword") // make sure an email/password was provided and bail if not if email == "" || password == "" { @@ -430,7 +437,8 @@ func (a *API) validatePassword(email string, password string) (userid string, er // first we select the user from the database based on email address, bail if no user found for that email gtsUser := >smodel.User{} - if err := a.conn.Model(gtsUser).Where("email = ?", email).Select(); err != nil { + + if err := m.db.GetWhere("email", email, gtsUser); err != nil { l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) return incorrectPassword() } @@ -460,8 +468,8 @@ func incorrectPassword() (string, error) { // userAuthorizationHandler gets the user's ID from the 'userid' field of the request form, // or redirects to the /auth/sign_in page, if this key is not present. -func (a *API) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { - l := a.log.WithField("func", "UserAuthorizationHandler") +func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { + l := m.log.WithField("func", "UserAuthorizationHandler") userID = r.FormValue("userid") if userID == "" { return "", errors.New("userid was empty, redirecting to sign in page") diff --git a/internal/oauth/oauth_test.go b/internal/module/oauth/oauth_test.go similarity index 62% rename from internal/oauth/oauth_test.go rename to internal/module/oauth/oauth_test.go index 5ad7cde..adfb40a 100644 --- a/internal/oauth/oauth_test.go +++ b/internal/module/oauth/oauth_test.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package oauth import ( @@ -6,12 +24,11 @@ import ( "testing" "time" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/google/uuid" - "github.com/gotosocial/gotosocial/internal/api" "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" @@ -22,7 +39,7 @@ type OauthTestSuite struct { suite.Suite tokenStore oauth2.TokenStore clientStore oauth2.ClientStore - conn *pg.DB + db db.DB testAccount *gtsmodel.Account testApplication *gtsmodel.Application testUser *gtsmodel.User @@ -40,7 +57,16 @@ func (suite *OauthTestSuite) SetupSuite() { // because go tests are run within the test package directory, we need to fiddle with the templateconfig // basedir in a way that we wouldn't normally have to do when running the binary, in order to make // the templates actually load - c.TemplateConfig.BaseDir = "../../web/template/" + c.TemplateConfig.BaseDir = "../../../web/template/" + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", + } suite.config = c encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) @@ -77,11 +103,16 @@ func (suite *OauthTestSuite) SetupSuite() { // SetupTest creates a postgres connection and creates the oauth_clients table before each test func (suite *OauthTestSuite) SetupTest() { - suite.conn = pg.Connect(&pg.Options{}) - if err := suite.conn.Ping(context.Background()); err != nil { - logrus.Panicf("db connection error: %s", err) + + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + db, err := db.New(context.Background(), suite.config, log) + if err != nil { + logrus.Panicf("error creating database connection: %s", err) } + suite.db = db + models := []interface{}{ &oauthClient{}, &oauthToken{}, @@ -91,29 +122,24 @@ func (suite *OauthTestSuite) SetupTest() { } for _, m := range models { - if err := suite.conn.Model(m).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }); err != nil { + if err := suite.db.CreateTable(m); err != nil { logrus.Panicf("db connection error: %s", err) } } - suite.tokenStore = NewPGTokenStore(context.Background(), suite.conn, logrus.New()) - suite.clientStore = NewPGClientStore(suite.conn) + suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) + suite.clientStore = newClientStore(suite.db) - if _, err := suite.conn.Model(suite.testAccount).Insert(); err != nil { + if err := suite.db.Put(suite.testAccount); err != nil { logrus.Panicf("could not insert test account into db: %s", err) } - - if _, err := suite.conn.Model(suite.testUser).Insert(); err != nil { + if err := suite.db.Put(suite.testUser); err != nil { logrus.Panicf("could not insert test user into db: %s", err) } - - if _, err := suite.conn.Model(suite.testClient).Insert(); err != nil { + if err := suite.db.Put(suite.testClient); err != nil { logrus.Panicf("could not insert test client into db: %s", err) } - - if _, err := suite.conn.Model(suite.testApplication).Insert(); err != nil { + if err := suite.db.Put(suite.testApplication); err != nil { logrus.Panicf("could not insert test application into db: %s", err) } @@ -129,25 +155,30 @@ func (suite *OauthTestSuite) TearDownTest() { >smodel.Application{}, } for _, m := range models { - if err := suite.conn.Model(m).DropTable(&orm.DropTableOptions{}); err != nil { - logrus.Panicf("drop table error: %s", err) + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) } } - if err := suite.conn.Close(); err != nil { + if err := suite.db.Stop(context.Background()); err != nil { logrus.Panicf("error closing db connection: %s", err) } - suite.conn = nil + suite.db = nil } func (suite *OauthTestSuite) TestAPIInitialize() { log := logrus.New() log.SetLevel(logrus.TraceLevel) - r := api.New(suite.config, log) - api := New(suite.tokenStore, suite.clientStore, suite.conn, log) - if err := api.Route(r); err != nil { - suite.FailNow(fmt.Sprintf("error initializing api: %s", err)) + r, err := router.New(suite.config, log) + if err != nil { + suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) } + + api := New(suite.tokenStore, suite.clientStore, suite.db, log) + if err := api.Route(r); err != nil { + suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) + } + go r.Start() time.Sleep(60 * time.Second) // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read diff --git a/internal/oauth/pgtokenstore.go b/internal/module/oauth/tokenstore.go similarity index 72% rename from internal/oauth/pgtokenstore.go rename to internal/module/oauth/tokenstore.go index 0271afa..d8a6d58 100644 --- a/internal/oauth/pgtokenstore.go +++ b/internal/module/oauth/tokenstore.go @@ -24,31 +24,31 @@ import ( "fmt" "time" - "github.com/go-pg/pg/v10" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4" "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" ) -// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. -type pgTokenStore struct { +// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. +type tokenStore struct { oauth2.TokenStore - conn *pg.DB - log *logrus.Logger + db db.DB + log *logrus.Logger } -// NewPGTokenStore returns a token store, using postgres, that satisfies the oauth2.TokenStore interface. +// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // -// In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a -// goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired. -func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore { - pts := &pgTokenStore{ - conn: conn, - log: log, +// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through +// the tokens in the DB once per minute and deletes any that have expired. +func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore { + pts := &tokenStore{ + db: db, + log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) { + go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -67,22 +67,22 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *pgTokenStore) sweep() error { +func (pts *tokenStore) sweep() error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - var tokens []oauthToken - if err := pts.conn.Model(&tokens).Select(); err != nil { + tokens := new([]*oauthToken) + if err := pts.db.GetAll(tokens); err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, pgt := range tokens { + for _, pgt := range *tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { - if _, err := pts.conn.Model(&pgt).Delete(); err != nil { + if err := pts.db.DeleteByID(pgt.ID, &pgt); err != nil { return err } } @@ -93,68 +93,61 @@ func (pts *pgTokenStore) sweep() error { // Create creates and store the new token information. // For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 -func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") } - _, err := pts.conn.WithContext(ctx).Model(oauthTokenToPGToken(t)).Insert() - if err != nil { + if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil { return fmt.Errorf("error in tokenstore create: %s", err) } return nil } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebycode: %s", err) - } - return nil +func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { + return pts.db.DeleteWhere("code", code, &oauthToken{}) } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebyaccess: %s", err) - } - return nil +func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { + return pts.db.DeleteWhere("access", access, &oauthToken{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebyrefresh: %s", err) - } - return nil +func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) } // GetByCode selects a token from the DB based on the Code field -func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbycode: %s", err) +func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Code: code, + } + if err := pts.db.GetWhere("code", code, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } // GetByAccess selects a token from the DB based on the Access field -func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbyaccess: %s", err) +func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Access: access, + } + if err := pts.db.GetWhere("access", access, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbyrefresh: %s", err) +func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Refresh: refresh, + } + if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } @@ -174,6 +167,7 @@ func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut // As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. type oauthToken struct { + ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` ClientID string UserID string RedirectURI string diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..3893503 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,120 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package router + +import ( + "crypto/rand" + "fmt" + "os" + "path/filepath" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/memstore" + "github.com/gin-gonic/gin" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/sirupsen/logrus" +) + +// Router provides the REST interface for gotosocial, using gin. +type Router interface { + // Attach a gin handler to the router with the given method and path + AttachHandler(method string, path string, handler gin.HandlerFunc) + // Attach a gin middleware to the router that will be used globally + AttachMiddleware(handler gin.HandlerFunc) + // Start the router + Start() + // Stop the router + Stop() +} + +// router fulfils the Router interface using gin and logrus +type router struct { + logger *logrus.Logger + engine *gin.Engine +} + +// Start starts the router nicely +func (s *router) Start() { + // todo: start gracefully + if err := s.engine.Run(); err != nil { + s.logger.Panicf("server error: %s", err) + } +} + +// Stop shuts down the router nicely +func (s *router) Stop() { + // todo: shut down gracefully +} + +// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path. +// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path. +func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) { + if method == "ANY" { + s.engine.Any(path, handler) + } else { + s.engine.Handle(method, path, handler) + } +} + +// AttachMiddleware attaches a gin middleware to the router that will be used globally +func (s *router) AttachMiddleware(middleware gin.HandlerFunc) { + s.engine.Use(middleware) +} + +// New returns a new Router with the specified configuration, using the given logrus logger. +func New(config *config.Config, logger *logrus.Logger) (Router, error) { + engine := gin.New() + + // create a new session store middleware + store, err := sessionStore() + if err != nil { + return nil, fmt.Errorf("error creating session store: %s", err) + } + engine.Use(sessions.Sessions("gotosocial-session", store)) + + // load html templates for use by the router + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("error getting current working directory: %s", err) + } + tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) + logger.Debugf("loading templates from %s", tmPath) + engine.LoadHTMLGlob(tmPath) + + return &router{ + logger: logger, + engine: engine, + }, nil +} + +// sessionStore returns a new session store with a random auth and encryption key. +// This means that cookies using the store will be reset if gotosocial is restarted! +func sessionStore() (memstore.Store, error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + return memstore.NewStore(auth, crypt), nil +}