From aa9ce272dcfa1380b2f05bc3a90ef8ca1b0a7f62 Mon Sep 17 00:00:00 2001 From: Tobi Smethurst <31960611+tsmethurst@users.noreply.github.com> Date: Mon, 22 Mar 2021 22:26:54 +0100 Subject: [PATCH] Oauth/token (#7) * add host and protocol options * some fiddling * tidying up and comments * tick off /oauth/token * tidying a bit * tidying * go mod tidy * allow attaching middleware to server * add middleware * more user friendly * add comments * comments * store account + app * tidying * lots of restructuring * lint + tidy --- PROGRESS.md | 2 +- cmd/gotosocial/main.go | 12 + example/config.yaml | 11 + go.mod | 2 +- go.sum | 4 +- internal/api/server.go | 87 --- internal/config/config.go | 16 + internal/db/actions.go | 11 +- internal/db/db.go | 47 +- internal/db/pg-fed.go | 137 +++++ internal/db/{postgres.go => pg.go} | 208 ++----- internal/email/email.go | 2 +- internal/federation/federation.go | 4 +- internal/gotosocial/actions.go | 6 +- internal/gotosocial/gotosocial.go | 8 +- internal/gtsmodel/account.go | 2 +- internal/gtsmodel/application.go | 41 +- internal/gtsmodel/status.go | 49 +- .../account/account.go} | 20 +- internal/module/module.go | 29 + internal/{ => module}/oauth/README.md | 4 +- .../oauth/clientstore.go} | 34 +- .../oauth/clientstore_test.go} | 73 ++- internal/module/oauth/oauth.go | 510 ++++++++++++++++++ internal/module/oauth/oauth_test.go | 191 +++++++ .../oauth/tokenstore.go} | 96 ++-- internal/oauth/oauth.go | 446 --------------- internal/oauth/oauth_test.go | 133 ----- internal/router/router.go | 120 +++++ web/template/authorize.tmpl | 18 +- 30 files changed, 1346 insertions(+), 977 deletions(-) delete mode 100644 internal/api/server.go create mode 100644 internal/db/pg-fed.go rename internal/db/{postgres.go => pg.go} (56%) rename internal/{api/route_statuses.go => module/account/account.go} (65%) create mode 100644 internal/module/module.go rename internal/{ => module}/oauth/README.md (55%) rename internal/{oauth/pgclientstore.go => module/oauth/clientstore.go} (56%) rename internal/{oauth/pgclientstore_test.go => module/oauth/clientstore_test.go} (59%) create mode 100644 internal/module/oauth/oauth.go create mode 100644 internal/module/oauth/oauth_test.go rename internal/{oauth/pgtokenstore.go => module/oauth/tokenstore.go} (72%) delete mode 100644 internal/oauth/oauth.go delete mode 100644 internal/oauth/oauth_test.go create mode 100644 internal/router/router.go diff --git a/PROGRESS.md b/PROGRESS.md index 9c6c784..079f47f 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -6,7 +6,7 @@ * [ ] /api/v1/apps/verify_credentials GET (Verify an application works) * [x] /oauth/authorize GET (Show authorize page to user) * [x] /oauth/authorize POST (Get an oauth access code for an app/user) - * [ ] /oauth/token POST (Obtain a user-level access token) + * [x] /oauth/token POST (Obtain a user-level access token) * [ ] /oauth/revoke POST (Revoke a user-level access token) * [x] /auth/sign_in GET (Show form for user signin) * [x] /auth/sign_in POST (Validate username and password and sign user in) diff --git a/cmd/gotosocial/main.go b/cmd/gotosocial/main.go index 9679803..0919d5f 100644 --- a/cmd/gotosocial/main.go +++ b/cmd/gotosocial/main.go @@ -58,6 +58,18 @@ func main() { Value: "", EnvVars: []string{envNames.ConfigPath}, }, + &cli.StringFlag{ + Name: flagNames.Host, + Usage: "Hostname to use for the server (eg., example.org, gotosocial.whatever.com)", + Value: "localhost", + EnvVars: []string{envNames.Host}, + }, + &cli.StringFlag{ + Name: flagNames.Protocol, + Usage: "Protocol to use for the REST api of the server (only use http for debugging and tests!)", + Value: "https", + EnvVars: []string{envNames.Protocol}, + }, // DATABASE FLAGS &cli.StringFlag{ diff --git a/example/config.yaml b/example/config.yaml index b65149d..58766a2 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -28,6 +28,17 @@ logLevel: "info" # Default: "gotosocial" applicationName: "gotosocial" +# String. Hostname/domain to use for the server. Defaults to localhost for local testing, +# but you should *definitely* change this when running for real, or your server won't work at all. +# Examples: ["example.org","some.server.com"] +# Default: "localhost" +host: "localhost" + +# String. Protocol to use for the server. Only change to http for local testing! +# Options: ["http","https"] +# Default: "https" +protocol: "https" + # Config pertaining to the Gotosocial database connection db: # String. Database type. diff --git a/go.mod b/go.mod index 4d13b11..e913e34 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/go-pg/pg/v10 v10.8.0 github.com/golang/mock v1.4.4 // indirect github.com/google/uuid v1.2.0 - github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 + github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 github.com/onsi/ginkgo v1.15.0 // indirect github.com/onsi/gomega v1.10.5 // indirect github.com/sirupsen/logrus v1.8.0 diff --git a/go.sum b/go.sum index c7f3098..0338e7d 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 h1:CKRz5d7mRum+UMR88Ue33tCYcej14WjUsB59C02DDqY= -github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8= +github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU= +github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= diff --git a/internal/api/server.go b/internal/api/server.go deleted file mode 100644 index 8e22742..0000000 --- a/internal/api/server.go +++ /dev/null @@ -1,87 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package api - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-gonic/gin" - "github.com/gotosocial/gotosocial/internal/config" - "github.com/sirupsen/logrus" -) - -type Server interface { - AttachHandler(method string, path string, handler gin.HandlerFunc) - // AttachMiddleware(handler gin.HandlerFunc) - GetAPIGroup() *gin.RouterGroup - Start() - Stop() -} - -type AddsRoutes interface { - AddRoutes(s Server) error -} - -type server struct { - APIGroup *gin.RouterGroup - logger *logrus.Logger - engine *gin.Engine -} - -func (s *server) GetAPIGroup() *gin.RouterGroup { - return s.APIGroup -} - -func (s *server) Start() { - // todo: start gracefully - if err := s.engine.Run(); err != nil { - s.logger.Panicf("server error: %s", err) - } -} - -func (s *server) Stop() { - // todo: shut down gracefully -} - -func (s *server) AttachHandler(method string, path string, handler gin.HandlerFunc) { - if method == "ANY" { - s.engine.Any(path, handler) - } else { - s.engine.Handle(method, path, handler) - } -} - -func New(config *config.Config, logger *logrus.Logger) Server { - engine := gin.New() - store := memstore.NewStore([]byte("authentication-key"), []byte("encryption-keyencryption-key----")) - engine.Use(sessions.Sessions("gotosocial-session", store)) - cwd, _ := os.Getwd() - tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) - logger.Debugf("loading templates from %s", tmPath) - engine.LoadHTMLGlob(tmPath) - return &server{ - APIGroup: engine.Group("/api").Group("/v1"), - logger: logger, - engine: engine, - } -} diff --git a/internal/config/config.go b/internal/config/config.go index ce194cd..dca325c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,6 +29,8 @@ import ( type Config struct { LogLevel string `yaml:"logLevel"` ApplicationName string `yaml:"applicationName"` + Host string `yaml:"host"` + Protocol string `yaml:"protocol"` DBConfig *DBConfig `yaml:"db"` TemplateConfig *TemplateConfig `yaml:"template"` } @@ -97,6 +99,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) { c.ApplicationName = f.String(fn.ApplicationName) } + if c.Host == "" || f.IsSet(fn.Host) { + c.Host = f.String(fn.Host) + } + + if c.Protocol == "" || f.IsSet(fn.Protocol) { + c.Protocol = f.String(fn.Protocol) + } + // db flags if c.DBConfig.Type == "" || f.IsSet(fn.DbType) { c.DBConfig.Type = f.String(fn.DbType) @@ -142,6 +152,8 @@ type Flags struct { LogLevel string ApplicationName string ConfigPath string + Host string + Protocol string DbType string DbAddress string DbPort string @@ -158,6 +170,8 @@ func GetFlagNames() Flags { LogLevel: "log-level", ApplicationName: "application-name", ConfigPath: "config-path", + Host: "host", + Protocol: "protocol", DbType: "db-type", DbAddress: "db-address", DbPort: "db-port", @@ -175,6 +189,8 @@ func GetEnvNames() Flags { LogLevel: "GTS_LOG_LEVEL", ApplicationName: "GTS_APPLICATION_NAME", ConfigPath: "GTS_CONFIG_PATH", + Host: "GTS_HOST", + Protocol: "GTS_PROTOCOL", DbType: "GTS_DB_TYPE", DbAddress: "GTS_DB_ADDRESS", DbPort: "GTS_DB_PORT", diff --git a/internal/db/actions.go b/internal/db/actions.go index 6fa7d23..01fb44b 100644 --- a/internal/db/actions.go +++ b/internal/db/actions.go @@ -28,9 +28,10 @@ import ( // Initialize will initialize the database given in the config for use with GoToSocial var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - db, err := New(ctx, c, log) - if err != nil { - return err - } - return db.CreateSchema(ctx) + // db, err := New(ctx, c, log) + // if err != nil { + // return err + // } + return nil + // return db.CreateSchema(ctx) } diff --git a/internal/db/db.go b/internal/db/db.go index 4ea4e1a..9952e5e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -30,30 +30,47 @@ import ( const dbTypePostgres string = "POSTGRES" -// DB provides methods for interacting with an underlying database (for now, just postgres). -// The function mapping lines up with the DB interface described in go-fed. -// See here: https://github.com/go-fed/activity/blob/master/pub/database.go +// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). type DB interface { - /* - GO-FED DATABASE FUNCTIONS - */ - pub.Database + // Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions. + // See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database + Federation() pub.Database - /* - ANY ADDITIONAL DESIRED FUNCTIONS - */ + // CreateTable creates a table for the given interface + CreateTable(i interface{}) error - // CreateSchema should populate the database with the required tables - CreateSchema(context.Context) error + // DropTable drops the table for the given interface + DropTable(i interface{}) error // Stop should stop and close the database connection cleanly, returning an error if this is not possible - Stop(context.Context) error + Stop(ctx context.Context) error // IsHealthy should return nil if the database connection is healthy, or an error if not - IsHealthy(context.Context) error + IsHealthy(ctx context.Context) error + + // GetByID gets one entry by its id. + GetByID(id string, i interface{}) error + + // GetWhere gets one entry where key = value + GetWhere(key string, value interface{}, i interface{}) error + + // GetAll gets all entries of interface type i + GetAll(i interface{}) error + + // Put stores i + Put(i interface{}) error + + // Update by id updates i with id id + UpdateByID(id string, i interface{}) error + + // Delete by id removes i with id id + DeleteByID(id string, i interface{}) error + + // Delete where deletes i where key = value + DeleteWhere(key string, value interface{}, i interface{}) error } -// New returns a new database service that satisfies the Service interface and, by extension, +// New returns a new database service that satisfies the DB interface and, by extension, // the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) { switch strings.ToUpper(c.DBConfig.Type) { diff --git a/internal/db/pg-fed.go b/internal/db/pg-fed.go new file mode 100644 index 0000000..ec1957a --- /dev/null +++ b/internal/db/pg-fed.go @@ -0,0 +1,137 @@ +package db + +import ( + "context" + "errors" + "net/url" + "sync" + + "github.com/go-fed/activity/pub" + "github.com/go-fed/activity/streams" + "github.com/go-fed/activity/streams/vocab" + "github.com/go-pg/pg/v10" +) + +type postgresFederation struct { + locks *sync.Map + conn *pg.DB +} + +func newPostgresFederation(conn *pg.DB) pub.Database { + return &postgresFederation{ + locks: new(sync.Map), + conn: conn, + } +} + +/* + GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS +*/ +func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error { + // Before any other Database methods are called, the relevant `id` + // entries are locked to allow for fine-grained concurrency. + + // Strategy: create a new lock, if stored, continue. Otherwise, lock the + // existing mutex. + mu := &sync.Mutex{} + mu.Lock() // Optimistically lock if we do store it. + i, loaded := pf.locks.LoadOrStore(id.String(), mu) + if loaded { + mu = i.(*sync.Mutex) + mu.Lock() + } + return nil +} + +func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error { + // Once Go-Fed is done calling Database methods, the relevant `id` + // entries are unlocked. + + i, ok := pf.locks.Load(id.String()) + if !ok { + return errors.New("missing an id in unlock") + } + mu := i.(*sync.Mutex) + mu.Unlock() + return nil +} + +func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { + return false, nil +} + +func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { + return nil, nil +} + +func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { + return nil +} + +func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { + return false, nil +} + +func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { + return false, nil +} + +func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { + return nil, nil +} + +func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error { + t, err := streams.NewTypeResolver() + if err != nil { + return err + } + if err := t.Resolve(ctx, asType); err != nil { + return err + } + asType.GetTypeName() + return nil +} + +func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error { + return nil +} + +func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error { + return nil +} + +func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { + return nil, nil +} + +func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { + return nil +} + +func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { + return nil, nil +} + +func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} + +func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} + +func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { + return nil, nil +} diff --git a/internal/db/postgres.go b/internal/db/pg.go similarity index 56% rename from internal/db/postgres.go rename to internal/db/pg.go index dae6b11..487af18 100644 --- a/internal/db/postgres.go +++ b/internal/db/pg.go @@ -22,30 +22,26 @@ import ( "context" "errors" "fmt" - "net/url" "regexp" "strings" - "sync" "time" - "github.com/go-fed/activity/streams" - "github.com/go-fed/activity/streams/vocab" + "github.com/go-fed/activity/pub" "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/gtsmodel" - "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" ) +// postgresService satisfies the DB interface type postgresService struct { - config *config.DBConfig - conn *pg.DB - log *logrus.Entry - cancel context.CancelFunc - locks *sync.Map - tokenStore oauth2.TokenStore + config *config.DBConfig + conn *pg.DB + log *logrus.Entry + cancel context.CancelFunc + federationDB pub.Database } // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. @@ -102,36 +98,20 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry return nil, errors.New("db connection timeout") } - // acc := model.StubAccount() - // if _, err := conn.Model(acc).Returning("id").Insert(); err != nil { - // cancel() - // return nil, fmt.Errorf("db insert error: %s", err) - // } - // log.Infof("created account with id %s", acc.ID) - - // note := &model.Note{ - // Visibility: &model.Visibility{ - // Local: true, - // }, - // CreatedAt: time.Now(), - // UpdatedAt: time.Now(), - // } - // if _, err := conn.WithContext(ctx).Model(note).Returning("id").Insert(); err != nil { - // cancel() - // return nil, fmt.Errorf("db insert error: %s", err) - // } - // log.Infof("created note with id %s", note.ID) - // we can confidently return this useable postgres service now return &postgresService{ - config: c.DBConfig, - conn: conn, - log: log, - cancel: cancel, - locks: &sync.Map{}, + config: c.DBConfig, + conn: conn, + log: log, + cancel: cancel, + federationDB: newPostgresFederation(conn), }, nil } +func (ps *postgresService) Federation() pub.Database { + return ps.federationDB +} + /* HANDY STUFF */ @@ -187,118 +167,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { return options, nil } -/* - GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS -*/ -func (ps *postgresService) Lock(ctx context.Context, id *url.URL) error { - // Before any other Database methods are called, the relevant `id` - // entries are locked to allow for fine-grained concurrency. - - // Strategy: create a new lock, if stored, continue. Otherwise, lock the - // existing mutex. - mu := &sync.Mutex{} - mu.Lock() // Optimistically lock if we do store it. - i, loaded := ps.locks.LoadOrStore(id.String(), mu) - if loaded { - mu = i.(*sync.Mutex) - mu.Lock() - } - return nil -} - -func (ps *postgresService) Unlock(ctx context.Context, id *url.URL) error { - // Once Go-Fed is done calling Database methods, the relevant `id` - // entries are unlocked. - - i, ok := ps.locks.Load(id.String()) - if !ok { - return errors.New("missing an id in unlock") - } - mu := i.(*sync.Mutex) - mu.Unlock() - return nil -} - -func (ps *postgresService) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { - return false, nil -} - -func (ps *postgresService) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { - return nil, nil -} - -func (ps *postgresService) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { - return nil -} - -func (ps *postgresService) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { - return false, nil -} - -func (ps *postgresService) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { - return false, nil -} - -func (ps *postgresService) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { - return nil, nil -} - -func (ps *postgresService) Create(ctx context.Context, asType vocab.Type) error { - t, err := streams.NewTypeResolver() - if err != nil { - return err - } - if err := t.Resolve(ctx, asType); err != nil { - return err - } - asType.GetTypeName() - return nil -} - -func (ps *postgresService) Update(ctx context.Context, asType vocab.Type) error { - return nil -} - -func (ps *postgresService) Delete(ctx context.Context, id *url.URL) error { - return nil -} - -func (ps *postgresService) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { - return nil, nil -} - -func (ps *postgresService) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { - return nil -} - -func (ps *postgresService) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { - return nil, nil -} - -func (ps *postgresService) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - -func (ps *postgresService) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - -func (ps *postgresService) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { - return nil, nil -} - /* EXTRA FUNCTIONS */ @@ -338,6 +206,46 @@ func (ps *postgresService) IsHealthy(ctx context.Context) error { return ps.conn.Ping(ctx) } -func (ps *postgresService) TokenStore() oauth2.TokenStore { - return ps.tokenStore +func (ps *postgresService) CreateTable(i interface{}) error { + return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (ps *postgresService) DropTable(i interface{}) error { + return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + +func (ps *postgresService) GetByID(id string, i interface{}) error { + return ps.conn.Model(i).Where("id = ?", id).Select() +} + +func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { + return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select() +} + +func (ps *postgresService) GetAll(i interface{}) error { + return ps.conn.Model(i).Select() +} + +func (ps *postgresService) Put(i interface{}) error { + _, err := ps.conn.Model(i).Insert(i) + return err +} + +func (ps *postgresService) UpdateByID(id string, i interface{}) error { + _, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert() + return err +} + +func (ps *postgresService) DeleteByID(id string, i interface{}) error { + _, err := ps.conn.Model(i).Where("id = ?", id).Delete() + return err +} + +func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { + _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete() + return err } diff --git a/internal/email/email.go b/internal/email/email.go index d70f6c5..3d6a9dd 100644 --- a/internal/email/email.go +++ b/internal/email/email.go @@ -16,5 +16,5 @@ along with this program. If not, see . */ -// package email provides a service for interacting with an SMTP server +// Package email provides a service for interacting with an SMTP server package email diff --git a/internal/federation/federation.go b/internal/federation/federation.go index ebc9102..cbd4eda 100644 --- a/internal/federation/federation.go +++ b/internal/federation/federation.go @@ -30,11 +30,13 @@ import ( "github.com/gotosocial/gotosocial/internal/db" ) +// New returns a go-fed compatible federating actor func New(db db.DB) pub.FederatingActor { fa := &API{} - return pub.NewFederatingActor(fa, fa, db, fa) + return pub.NewFederatingActor(fa, fa, db.Federation(), fa) } +// API implements several go-fed interfaces in one convenient location type API struct { } diff --git a/internal/gotosocial/actions.go b/internal/gotosocial/actions.go index 3d3fdc3..398c0b4 100644 --- a/internal/gotosocial/actions.go +++ b/internal/gotosocial/actions.go @@ -38,9 +38,9 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr return fmt.Errorf("error creating dbservice: %s", err) } - if err := dbService.CreateSchema(ctx); err != nil { - return fmt.Errorf("error creating dbschema: %s", err) - } + // if err := dbService.CreateSchema(ctx); err != nil { + // return fmt.Errorf("error creating dbschema: %s", err) + // } // catch shutdown signals from the operating system sigs := make(chan os.Signal, 1) diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go index 4409e85..d9fb295 100644 --- a/internal/gotosocial/gotosocial.go +++ b/internal/gotosocial/gotosocial.go @@ -22,10 +22,10 @@ import ( "context" "github.com/go-fed/activity/pub" - "github.com/gotosocial/gotosocial/internal/api" "github.com/gotosocial/gotosocial/internal/cache" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/router" ) type Gotosocial interface { @@ -33,11 +33,11 @@ type Gotosocial interface { Stop(context.Context) error } -func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { +func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { return &gotosocial{ db: db, cache: cache, - clientAPI: clientAPI, + apiRouter: apiRouter, federationAPI: federationAPI, config: config, }, nil @@ -46,7 +46,7 @@ func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.Fe type gotosocial struct { db db.DB cache cache.Cache - clientAPI api.Server + apiRouter router.Router federationAPI pub.FederatingActor config *config.Config } diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index 6786014..6c17b90 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -// package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. +// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. // The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/ package gtsmodel diff --git a/internal/gtsmodel/application.go b/internal/gtsmodel/application.go index c0d6ddd..fd0fa6a 100644 --- a/internal/gtsmodel/application.go +++ b/internal/gtsmodel/application.go @@ -18,13 +18,38 @@ package gtsmodel +import "github.com/gotosocial/gotosocial/pkg/mastotypes" + +// Application represents an application that can perform actions on behalf of a user. +// It is used to authorize tokens etc, and is associated with an oauth client id in the database. type Application struct { - ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` - Name string - Website string - RedirectURI string `json:"redirect_uri"` - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` - Scopes string `json:"scopes"` - VapidKey string `json:"vapid_key"` + // id of this application in the db + ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` + // name of the application given when it was created (eg., 'tusky') + Name string + // website for the application given when it was created (eg., 'https://tusky.app') + Website string + // redirect uri requested by the application for oauth2 flow + RedirectURI string + // id of the associated oauth client entity in the db + ClientID string + // secret of the associated oauth client entity in the db + ClientSecret string + // scopes requested when this app was created + Scopes string + // a vapid key generated for this app when it was created + VapidKey string +} + +// ToMastotype returns this application as a mastodon api type, ready for serialization +func (a *Application) ToMastotype() *mastotypes.Application { + return &mastotypes.Application{ + ID: a.ID, + Name: a.Name, + Website: a.Website, + RedirectURI: a.RedirectURI, + ClientID: a.ClientID, + ClientSecret: a.ClientSecret, + VapidKey: a.VapidKey, + } } diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index 22e88c0..1c0e920 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -20,25 +20,44 @@ package gtsmodel import "time" +// Status represents a user-created 'post' or 'status' in the database, either remote or local type Status struct { - ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` - URI string `pg:",unique"` - URL string `pg:",unique"` - Content string - CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` - UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` - Local bool - AccountID string - InReplyToID string - BoostOfID string + // id of the status in the database + ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` + // uri at which this status is reachable + URI string `pg:",unique"` + // web url for viewing this status + URL string `pg:",unique"` + // the html-formatted content of this status + Content string + // when was this status created? + CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + // when was this status updated? + UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` + // is this status from a local account? + Local bool + // which account posted this status? + AccountID string + // id of the status this status is a reply to + InReplyToID string + // id of the status this status is a boost of + BoostOfID string + // cw string for this status ContentWarning string - Visibility *Visibility + // visibility entry for this status + Visibility *Visibility } +// Visibility represents the visibility granularity of a status. It is a combination of flags. type Visibility struct { - Direct bool + // Is this status viewable as a direct message? + Direct bool + // Is this status viewable to followers? Followers bool - Local bool - Unlisted bool - Public bool + // Is this status viewable on the local timeline? + Local bool + // Is this status boostable but not shown on public timelines? + Unlisted bool + // Is this status shown on public and federated timelines? + Public bool } diff --git a/internal/api/route_statuses.go b/internal/module/account/account.go similarity index 65% rename from internal/api/route_statuses.go rename to internal/module/account/account.go index fe8aa96..d82d96e 100644 --- a/internal/api/route_statuses.go +++ b/internal/module/account/account.go @@ -16,4 +16,22 @@ along with this program. If not, see . */ -package api +package account + +import ( + "github.com/gotosocial/gotosocial/internal/module" + "github.com/gotosocial/gotosocial/internal/router" +) + +type accountModule struct { +} + +// New returns a new account module +func New() module.ClientAPIModule { + return &accountModule{} +} + +// Route attaches all routes from this module to the given router +func (m *accountModule) Route(r router.Router) error { + return nil +} diff --git a/internal/module/module.go b/internal/module/module.go new file mode 100644 index 0000000..8618d28 --- /dev/null +++ b/internal/module/module.go @@ -0,0 +1,29 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface. +package module + +import "github.com/gotosocial/gotosocial/internal/router" + +// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set +// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;) +// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/ +type ClientAPIModule interface { + Route(s router.Router) error +} diff --git a/internal/oauth/README.md b/internal/module/oauth/README.md similarity index 55% rename from internal/oauth/README.md rename to internal/module/oauth/README.md index 50a9e12..3d84273 100644 --- a/internal/oauth/README.md +++ b/internal/module/oauth/README.md @@ -1,3 +1,5 @@ # oauth -This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs. +This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. + +It also provides a handler/middleware for attaching to the Gin engine for validating authenticated users. diff --git a/internal/oauth/pgclientstore.go b/internal/module/oauth/clientstore.go similarity index 56% rename from internal/oauth/pgclientstore.go rename to internal/module/oauth/clientstore.go index 1df46fe..f99c160 100644 --- a/internal/oauth/pgclientstore.go +++ b/internal/module/oauth/clientstore.go @@ -22,55 +22,47 @@ import ( "context" "fmt" - "github.com/go-pg/pg/v10" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4" "github.com/gotosocial/oauth2/v4/models" ) -type pgClientStore struct { - conn *pg.DB +type clientStore struct { + db db.DB } -func NewPGClientStore(conn *pg.DB) oauth2.ClientStore { - pts := &pgClientStore{ - conn: conn, +func newClientStore(db db.DB) oauth2.ClientStore { + pts := &clientStore{ + db: db, } return pts } -func (pcs *pgClientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { +func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { poc := &oauthClient{ ID: clientID, } - if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil { - return nil, fmt.Errorf("error in clientstore getbyid searching for client %s: %s", clientID, err) + if err := cs.db.GetByID(clientID, poc); err != nil { + return nil, fmt.Errorf("database error: %s", err) } return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil } -func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { +func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { poc := &oauthClient{ ID: cli.GetID(), Secret: cli.GetSecret(), Domain: cli.GetDomain(), UserID: cli.GetUserID(), } - _, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert() - if err != nil { - return fmt.Errorf("error in clientstore set: %s", err) - } - return nil + return cs.db.UpdateByID(id, poc) } -func (pcs *pgClientStore) Delete(ctx context.Context, id string) error { +func (cs *clientStore) Delete(ctx context.Context, id string) error { poc := &oauthClient{ ID: id, } - _, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete() - if err != nil { - return fmt.Errorf("error in clientstore delete: %s", err) - } - return nil + return cs.db.DeleteByID(id, poc) } type oauthClient struct { diff --git a/internal/oauth/pgclientstore_test.go b/internal/module/oauth/clientstore_test.go similarity index 59% rename from internal/oauth/pgclientstore_test.go rename to internal/module/oauth/clientstore_test.go index eb011fe..bca0024 100644 --- a/internal/oauth/pgclientstore_test.go +++ b/internal/module/oauth/clientstore_test.go @@ -1,11 +1,28 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ package oauth import ( "context" "testing" - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" @@ -13,7 +30,7 @@ import ( type PgClientStoreTestSuite struct { suite.Suite - conn *pg.DB + db db.DB testClientID string testClientSecret string testClientDomain string @@ -32,31 +49,55 @@ func (suite *PgClientStoreTestSuite) SetupSuite() { // SetupTest creates a postgres connection and creates the oauth_clients table before each test func (suite *PgClientStoreTestSuite) SetupTest() { - suite.conn = pg.Connect(&pg.Options{}) - if err := suite.conn.Ping(context.Background()); err != nil { - logrus.Panicf("db connection error: %s", err) + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + c := config.Empty() + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", } - if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }); err != nil { - logrus.Panicf("db connection error: %s", err) + db, err := db.New(context.Background(), c, log) + if err != nil { + logrus.Panicf("error creating database connection: %s", err) + } + + suite.db = db + + models := []interface{}{ + &oauthClient{}, + } + + for _, m := range models { + if err := suite.db.CreateTable(m); err != nil { + logrus.Panicf("db connection error: %s", err) + } } } // TearDownTest drops the oauth_clients table and closes the pg connection after each test func (suite *PgClientStoreTestSuite) TearDownTest() { - if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil { - logrus.Panicf("drop table error: %s", err) + models := []interface{}{ + &oauthClient{}, } - if err := suite.conn.Close(); err != nil { + for _, m := range models { + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) + } + } + if err := suite.db.Stop(context.Background()); err != nil { logrus.Panicf("error closing db connection: %s", err) } - suite.conn = nil + suite.db = nil } func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { // set a new client in the store - cs := NewPGClientStore(suite.conn) + cs := newClientStore(suite.db) if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { suite.FailNow(err.Error()) } @@ -74,7 +115,7 @@ func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { // set a new client in the store - cs := NewPGClientStore(suite.conn) + cs := newClientStore(suite.db) if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/module/oauth/oauth.go b/internal/module/oauth/oauth.go new file mode 100644 index 0000000..4436f7a --- /dev/null +++ b/internal/module/oauth/oauth.go @@ -0,0 +1,510 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +// Package oauth is a module that provides oauth functionality to a router. +// It adds the following paths: +// /api/v1/apps +// /auth/sign_in +// /oauth/token +// /oauth/authorize +// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. +package oauth + +import ( + "fmt" + "net/http" + "net/url" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/module" + "github.com/gotosocial/gotosocial/internal/router" + "github.com/gotosocial/gotosocial/pkg/mastotypes" + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/errors" + "github.com/gotosocial/oauth2/v4/manage" + "github.com/gotosocial/oauth2/v4/server" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" +) + +const ( + appsPath = "/api/v1/apps" + authSignInPath = "/auth/sign_in" + oauthTokenPath = "/oauth/token" + oauthAuthorizePath = "/oauth/authorize" +) + +// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface +type oauthModule struct { + oauthManager *manage.Manager + oauthServer *server.Server + db db.DB + log *logrus.Logger +} + +type login struct { + Email string `form:"username"` + Password string `form:"password"` +} + +// New returns a new oauth module +func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { + manager := manage.NewDefaultManager() + manager.MapTokenStorage(ts) + manager.MapClientStorage(cs) + manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + sc := &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, + } + + srv := server.NewServer(sc, manager) + srv.SetInternalErrorHandler(func(err error) *errors.Response { + log.Errorf("internal oauth error: %s", err) + return nil + }) + + srv.SetResponseErrorHandler(func(re *errors.Response) { + log.Errorf("internal response error: %s", re.Error) + }) + + m := &oauthModule{ + oauthManager: manager, + oauthServer: srv, + db: db, + log: log, + } + + m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler) + m.oauthServer.SetClientInfoHandler(server.ClientFormHandler) + return m +} + +// Route satisfies the RESTAPIModule interface +func (m *oauthModule) Route(s router.Router) error { + s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) + + s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) + s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) + + s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler) + + s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler) + s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) + + s.AttachMiddleware(m.oauthTokenMiddleware) + + return nil +} + +/* + MAIN HANDLERS -- serve these through a server/router +*/ + +// appsPOSTHandler should be served at https://example.org/api/v1/apps +// It is equivalent to: https://docs.joinmastodon.org/methods/apps/ +func (m *oauthModule) appsPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AppsPOSTHandler") + l.Trace("entering AppsPOSTHandler") + + form := &mastotypes.ApplicationPOSTRequest{} + if err := c.ShouldBind(form); err != nil { + c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) + return + } + + // permitted length for most fields + permittedLength := 64 + // redirect can be a bit bigger because we probably need to encode data in the redirect uri + permittedRedirect := 256 + + // check lengths of fields before proceeding so the user can't spam huge entries into the database + if len(form.ClientName) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) + return + } + if len(form.Website) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) + return + } + if len(form.RedirectURIs) > permittedRedirect { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) + return + } + if len(form.Scopes) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) + return + } + + // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ + var scopes string + if form.Scopes == "" { + scopes = "read" + } else { + scopes = form.Scopes + } + + // generate new IDs for this application and its associated client + clientID := uuid.NewString() + clientSecret := uuid.NewString() + vapidKey := uuid.NewString() + + // generate the application to put in the database + app := >smodel.Application{ + Name: form.ClientName, + Website: form.Website, + RedirectURI: form.RedirectURIs, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopes, + VapidKey: vapidKey, + } + + // chuck it in the db + if err := m.db.Put(app); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // now we need to model an oauth client from the application that the oauth library can use + oc := &oauthClient{ + ID: clientID, + Secret: clientSecret, + Domain: form.RedirectURIs, + UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now + } + + // chuck it in the db + if err := m.db.Put(oc); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ + c.JSON(http.StatusOK, app.ToMastotype()) +} + +// signInGETHandler should be served at https://example.org/auth/sign_in. +// The idea is to present a sign in page to the user, where they can enter their username and password. +// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler +func (m *oauthModule) signInGETHandler(c *gin.Context) { + m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") + c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) +} + +// signInPOSTHandler should be served at https://example.org/auth/sign_in. +// The idea is to present a sign in page to the user, where they can enter their username and password. +// The handler will then redirect to the auth handler served at /auth +func (m *oauthModule) signInPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "SignInPOSTHandler") + s := sessions.Default(c) + form := &login{} + if err := c.ShouldBind(form); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + l.Tracef("parsed form: %+v", form) + + userid, err := m.validatePassword(form.Email, form.Password) + if err != nil { + c.String(http.StatusForbidden, err.Error()) + return + } + + s.Set("userid", userid) + if err := s.Save(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + l.Trace("redirecting to auth page") + c.Redirect(http.StatusFound, oauthAuthorizePath) +} + +// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token +// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. +// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token +func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "TokenPOSTHandler") + l.Trace("entered TokenPOSTHandler") + if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } +} + +// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize +// The idea here is to present an oauth authorize page to the user, with a button +// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user +func (m *oauthModule) authorizeGETHandler(c *gin.Context) { + l := m.log.WithField("func", "AuthorizeGETHandler") + s := sessions.Default(c) + + // UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow + // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. + userID, ok := s.Get("userid").(string) + if !ok || userID == "" { + l.Trace("userid was empty, parsing form then redirecting to sign in page") + if err := parseAuthForm(c, l); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } else { + c.Redirect(http.StatusFound, authSignInPath) + } + return + } + + // We can use the client_id on the session to retrieve info about the app associated with the client_id + clientID, ok := s.Get("client_id").(string) + if !ok || clientID == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) + return + } + app := >smodel.Application{ + ClientID: clientID, + } + if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) + return + } + + // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 + user := >smodel.User{ + ID: userID, + } + if err := m.db.GetByID(user.ID, user); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + acct := >smodel.Account{ + ID: user.AccountID, + } + + if err := m.db.GetByID(acct.ID, acct); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Finally we should also get the redirect and scope of this particular request, as stored in the session. + redirect, ok := s.Get("redirect_uri").(string) + if !ok || redirect == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) + return + } + scope, ok := s.Get("scope").(string) + if !ok || scope == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) + return + } + + // the authorize template will display a form to the user where they can get some information + // about the app that's trying to authorize, and the scope of the request. + // They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler + l.Trace("serving authorize html") + c.HTML(http.StatusOK, "authorize.tmpl", gin.H{ + "appname": app.Name, + "appwebsite": app.Website, + "redirect": redirect, + "scope": scope, + "user": acct.Username, + }) +} + +// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize +// At this point we assume that the user has A) logged in and B) accepted that the app should act for them, +// so we should proceed with the authentication flow and generate an oauth token for them if we can. +// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user +func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AuthorizePOSTHandler") + s := sessions.Default(c) + + // At this point we know the user has said 'yes' to allowing the application and oauth client + // work for them, so we can set the + + // We need to retrieve the original form submitted to the authorizeGEThandler, and + // recreate it on the request so that it can be used further by the oauth2 library. + // So first fetch all the values from the session. + forceLogin, ok := s.Get("force_login").(string) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"}) + return + } + responseType, ok := s.Get("response_type").(string) + if !ok || responseType == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"}) + return + } + clientID, ok := s.Get("client_id").(string) + if !ok || clientID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"}) + return + } + redirectURI, ok := s.Get("redirect_uri").(string) + if !ok || redirectURI == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"}) + return + } + scope, ok := s.Get("scope").(string) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"}) + return + } + userID, ok := s.Get("userid").(string) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"}) + return + } + // we're done with the session so we can clear it now + s.Clear() + + // now set the values on the request + values := url.Values{} + values.Set("force_login", forceLogin) + values.Set("response_type", responseType) + values.Set("client_id", clientID) + values.Set("redirect_uri", redirectURI) + values.Set("scope", scope) + values.Set("userid", userID) + c.Request.Form = values + l.Tracef("values on request set to %+v", c.Request.Form) + + // and proceed with authorization using the oauth2 library + if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + } +} + +/* + MIDDLEWARE +*/ + +// oauthTokenMiddleware +func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { + l := m.log.WithField("func", "ValidatePassword") + l.Trace("entering OauthTokenMiddleware") + if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil { + l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) + c.Set("authenticated_user", ti.GetUserID()) + + } else { + l.Trace("continuing with unauthenticated request") + } +} + +/* + SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs +*/ + +// validatePassword takes an email address and a password. +// The goal is to authenticate the password against the one for that email +// address stored in the database. If OK, we return the userid (a uuid) for that user, +// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. +func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { + l := m.log.WithField("func", "ValidatePassword") + + // make sure an email/password was provided and bail if not + if email == "" || password == "" { + l.Debug("email or password was not provided") + return incorrectPassword() + } + + // first we select the user from the database based on email address, bail if no user found for that email + gtsUser := >smodel.User{} + + if err := m.db.GetWhere("email", email, gtsUser); err != nil { + l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) + return incorrectPassword() + } + + // make sure a password is actually set and bail if not + if gtsUser.EncryptedPassword == "" { + l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email) + return incorrectPassword() + } + + // compare the provided password with the encrypted one from the db, bail if they don't match + if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil { + l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err) + return incorrectPassword() + } + + // If we've made it this far the email/password is correct, so we can just return the id of the user. + userid = gtsUser.ID + l.Tracef("returning (%s, %s)", userid, err) + return +} + +// incorrectPassword is just a little helper function to use in the ValidatePassword function +func incorrectPassword() (string, error) { + return "", errors.New("password/email combination was incorrect") +} + +// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form, +// or redirects to the /auth/sign_in page, if this key is not present. +func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { + l := m.log.WithField("func", "UserAuthorizationHandler") + userID = r.FormValue("userid") + if userID == "" { + return "", errors.New("userid was empty, redirecting to sign in page") + } + l.Tracef("returning userID %s", userID) + return userID, err +} + +// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores +// the values in the form into the session. +func parseAuthForm(c *gin.Context, l *logrus.Entry) error { + s := sessions.Default(c) + + // first make sure they've filled out the authorize form with the required values + form := &mastotypes.OAuthAuthorize{} + if err := c.ShouldBind(form); err != nil { + return err + } + l.Tracef("parsed form: %+v", form) + + // these fields are *required* so check 'em + if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" { + return errors.New("missing one of: response_type, client_id or redirect_uri") + } + + // set default scope to read + if form.Scope == "" { + form.Scope = "read" + } + + // save these values from the form so we can use them elsewhere in the session + s.Set("force_login", form.ForceLogin) + s.Set("response_type", form.ResponseType) + s.Set("client_id", form.ClientID) + s.Set("redirect_uri", form.RedirectURI) + s.Set("scope", form.Scope) + return s.Save() +} diff --git a/internal/module/oauth/oauth_test.go b/internal/module/oauth/oauth_test.go new file mode 100644 index 0000000..adfb40a --- /dev/null +++ b/internal/module/oauth/oauth_test.go @@ -0,0 +1,191 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package oauth + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/gtsmodel" + "github.com/gotosocial/gotosocial/internal/router" + "github.com/gotosocial/oauth2/v4" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" + "golang.org/x/crypto/bcrypt" +) + +type OauthTestSuite struct { + suite.Suite + tokenStore oauth2.TokenStore + clientStore oauth2.ClientStore + db db.DB + testAccount *gtsmodel.Account + testApplication *gtsmodel.Application + testUser *gtsmodel.User + testClient *oauthClient + config *config.Config +} + +// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout +func (suite *OauthTestSuite) SetupSuite() { + c := config.Empty() + // we're running on localhost without https so set the protocol to http + c.Protocol = "http" + // just for testing + c.Host = "localhost:8080" + // because go tests are run within the test package directory, we need to fiddle with the templateconfig + // basedir in a way that we wouldn't normally have to do when running the binary, in order to make + // the templates actually load + c.TemplateConfig.BaseDir = "../../../web/template/" + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", + } + suite.config = c + + encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + if err != nil { + logrus.Panicf("error encrypting user pass: %s", err) + } + + acctID := uuid.NewString() + + suite.testAccount = >smodel.Account{ + ID: acctID, + Username: "test_user", + } + suite.testUser = >smodel.User{ + EncryptedPassword: string(encryptedPassword), + Email: "user@example.org", + AccountID: acctID, + } + suite.testClient = &oauthClient{ + ID: "a-known-client-id", + Secret: "some-secret", + Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), + } + suite.testApplication = >smodel.Application{ + Name: "a test application", + Website: "https://some-application-website.com", + RedirectURI: "http://localhost:8080", + ClientID: "a-known-client-id", + ClientSecret: "some-secret", + Scopes: "read", + VapidKey: uuid.NewString(), + } +} + +// SetupTest creates a postgres connection and creates the oauth_clients table before each test +func (suite *OauthTestSuite) SetupTest() { + + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + db, err := db.New(context.Background(), suite.config, log) + if err != nil { + logrus.Panicf("error creating database connection: %s", err) + } + + suite.db = db + + models := []interface{}{ + &oauthClient{}, + &oauthToken{}, + >smodel.User{}, + >smodel.Account{}, + >smodel.Application{}, + } + + for _, m := range models { + if err := suite.db.CreateTable(m); err != nil { + logrus.Panicf("db connection error: %s", err) + } + } + + suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) + suite.clientStore = newClientStore(suite.db) + + if err := suite.db.Put(suite.testAccount); err != nil { + logrus.Panicf("could not insert test account into db: %s", err) + } + if err := suite.db.Put(suite.testUser); err != nil { + logrus.Panicf("could not insert test user into db: %s", err) + } + if err := suite.db.Put(suite.testClient); err != nil { + logrus.Panicf("could not insert test client into db: %s", err) + } + if err := suite.db.Put(suite.testApplication); err != nil { + logrus.Panicf("could not insert test application into db: %s", err) + } + +} + +// TearDownTest drops the oauth_clients table and closes the pg connection after each test +func (suite *OauthTestSuite) TearDownTest() { + models := []interface{}{ + &oauthClient{}, + &oauthToken{}, + >smodel.User{}, + >smodel.Account{}, + >smodel.Application{}, + } + for _, m := range models { + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) + } + } + if err := suite.db.Stop(context.Background()); err != nil { + logrus.Panicf("error closing db connection: %s", err) + } + suite.db = nil +} + +func (suite *OauthTestSuite) TestAPIInitialize() { + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + + r, err := router.New(suite.config, log) + if err != nil { + suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) + } + + api := New(suite.tokenStore, suite.clientStore, suite.db, log) + if err := api.Route(r); err != nil { + suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) + } + + go r.Start() + time.Sleep(60 * time.Second) + // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read + // curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token + // curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080 +} + +func TestOauthTestSuite(t *testing.T) { + suite.Run(t, new(OauthTestSuite)) +} diff --git a/internal/oauth/pgtokenstore.go b/internal/module/oauth/tokenstore.go similarity index 72% rename from internal/oauth/pgtokenstore.go rename to internal/module/oauth/tokenstore.go index 0271afa..d8a6d58 100644 --- a/internal/oauth/pgtokenstore.go +++ b/internal/module/oauth/tokenstore.go @@ -24,31 +24,31 @@ import ( "fmt" "time" - "github.com/go-pg/pg/v10" + "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/oauth2/v4" "github.com/gotosocial/oauth2/v4/models" "github.com/sirupsen/logrus" ) -// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. -type pgTokenStore struct { +// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. +type tokenStore struct { oauth2.TokenStore - conn *pg.DB - log *logrus.Logger + db db.DB + log *logrus.Logger } -// NewPGTokenStore returns a token store, using postgres, that satisfies the oauth2.TokenStore interface. +// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. // -// In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a -// goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired. -func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore { - pts := &pgTokenStore{ - conn: conn, - log: log, +// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through +// the tokens in the DB once per minute and deletes any that have expired. +func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore { + pts := &tokenStore{ + db: db, + log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) { + go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -67,22 +67,22 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *pgTokenStore) sweep() error { +func (pts *tokenStore) sweep() error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - var tokens []oauthToken - if err := pts.conn.Model(&tokens).Select(); err != nil { + tokens := new([]*oauthToken) + if err := pts.db.GetAll(tokens); err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, pgt := range tokens { + for _, pgt := range *tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { - if _, err := pts.conn.Model(&pgt).Delete(); err != nil { + if err := pts.db.DeleteByID(pgt.ID, &pgt); err != nil { return err } } @@ -93,68 +93,61 @@ func (pts *pgTokenStore) sweep() error { // Create creates and store the new token information. // For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 -func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") } - _, err := pts.conn.WithContext(ctx).Model(oauthTokenToPGToken(t)).Insert() - if err != nil { + if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil { return fmt.Errorf("error in tokenstore create: %s", err) } return nil } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebycode: %s", err) - } - return nil +func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { + return pts.db.DeleteWhere("code", code, &oauthToken{}) } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebyaccess: %s", err) - } - return nil +func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { + return pts.db.DeleteWhere("access", access, &oauthToken{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - _, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete() - if err != nil { - return fmt.Errorf("error in tokenstore removebyrefresh: %s", err) - } - return nil +func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) } // GetByCode selects a token from the DB based on the Code field -func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbycode: %s", err) +func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Code: code, + } + if err := pts.db.GetWhere("code", code, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } // GetByAccess selects a token from the DB based on the Access field -func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbyaccess: %s", err) +func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Access: access, + } + if err := pts.db.GetWhere("access", access, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{} - if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil { - return nil, fmt.Errorf("error in tokenstore getbyrefresh: %s", err) +func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + pgt := &oauthToken{ + Refresh: refresh, + } + if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { + return nil, err } return pgTokenToOauthToken(pgt), nil } @@ -174,6 +167,7 @@ func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut // As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. type oauthToken struct { + ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` ClientID string UserID string RedirectURI string diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index 49e04a9..0000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,446 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package oauth - -import ( - "fmt" - "net/http" - "net/url" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" - "github.com/go-pg/pg/v10" - "github.com/google/uuid" - "github.com/gotosocial/gotosocial/internal/api" - "github.com/gotosocial/gotosocial/internal/gtsmodel" - "github.com/gotosocial/gotosocial/pkg/mastotypes" - "github.com/gotosocial/oauth2/v4" - "github.com/gotosocial/oauth2/v4/errors" - "github.com/gotosocial/oauth2/v4/manage" - "github.com/gotosocial/oauth2/v4/server" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" -) - -type API struct { - manager *manage.Manager - server *server.Server - conn *pg.DB - log *logrus.Logger -} - -type login struct { - Email string `form:"username"` - Password string `form:"password"` -} - -type code struct { - Code string `form:"code"` -} - -func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.Logger) *API { - manager := manage.NewDefaultManager() - manager.MapTokenStorage(ts) - manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - // - Refreshing Tokens - // - // Deny: - // - Resource owner secrets (password grant) - // - Client secrets - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - oauth2.Refreshing, - }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ - oauth2.CodeChallengePlain, - }, - } - - srv := server.NewServer(sc, manager) - srv.SetInternalErrorHandler(func(err error) *errors.Response { - log.Errorf("internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *errors.Response) { - log.Errorf("internal response error: %s", re.Error) - }) - - api := &API{ - manager: manager, - server: srv, - conn: conn, - log: log, - } - - api.server.SetUserAuthorizationHandler(api.UserAuthorizationHandler) - api.server.SetClientInfoHandler(server.ClientFormHandler) - return api -} - -func (a *API) AddRoutes(s api.Server) error { - s.AttachHandler(http.MethodPost, "/api/v1/apps", a.AppsPOSTHandler) - - s.AttachHandler(http.MethodGet, "/auth/sign_in", a.SignInGETHandler) - s.AttachHandler(http.MethodPost, "/auth/sign_in", a.SignInPOSTHandler) - - s.AttachHandler(http.MethodPost, "/oauth/token", a.TokenPOSTHandler) - - s.AttachHandler(http.MethodGet, "/oauth/authorize", a.AuthorizeGETHandler) - s.AttachHandler(http.MethodPost, "/oauth/authorize", a.AuthorizePOSTHandler) - - return nil -} - -func incorrectPassword() (string, error) { - return "", errors.New("password/email combination was incorrect") -} - -/* - MAIN HANDLERS -- serve these through a server/router -*/ - -// AppsPOSTHandler should be served at https://example.org/api/v1/apps -// It is equivalent to: https://docs.joinmastodon.org/methods/apps/ -func (a *API) AppsPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "AppsPOSTHandler") - l.Trace("entering AppsPOSTHandler") - - form := &mastotypes.ApplicationPOSTRequest{} - if err := c.ShouldBind(form); err != nil { - c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) - return - } - - // permitted length for most fields - permittedLength := 64 - // redirect can be a bit bigger because we probably need to encode data in the redirect uri - permittedRedirect := 256 - - // check lengths of fields before proceeding so the user can't spam huge entries into the database - if len(form.ClientName) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) - return - } - if len(form.Website) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) - return - } - if len(form.RedirectURIs) > permittedRedirect { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) - return - } - if len(form.Scopes) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) - return - } - - // set default 'read' for scopes if it's not set - var scopes string - if form.Scopes == "" { - scopes = "read" - } else { - scopes = form.Scopes - } - - // generate new IDs for this application and its associated client - clientID := uuid.NewString() - clientSecret := uuid.NewString() - vapidKey := uuid.NewString() - - // generate the application to put in the database - app := >smodel.Application{ - Name: form.ClientName, - Website: form.Website, - RedirectURI: form.RedirectURIs, - ClientID: clientID, - ClientSecret: clientSecret, - Scopes: scopes, - VapidKey: vapidKey, - } - - // chuck it in the db - if _, err := a.conn.Model(app).Insert(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // now we need to model an oauth client from the application that the oauth library can use - oc := &oauthClient{ - ID: clientID, - Secret: clientSecret, - Domain: form.RedirectURIs, - UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now - } - - // chuck it in the db - if _, err := a.conn.Model(oc).Insert(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ - c.JSON(http.StatusOK, app) -} - -// SignInGETHandler should be served at https://example.org/auth/sign_in. -// The idea is to present a sign in page to the user, where they can enter their username and password. -// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler -func (a *API) SignInGETHandler(c *gin.Context) { - a.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") - c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) -} - -// SignInPOSTHandler should be served at https://example.org/auth/sign_in. -// The idea is to present a sign in page to the user, where they can enter their username and password. -// The handler will then redirect to the auth handler served at /auth -func (a *API) SignInPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "SignInPOSTHandler") - s := sessions.Default(c) - form := &login{} - if err := c.ShouldBind(form); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - l.Tracef("parsed form: %+v", form) - - userid, err := a.ValidatePassword(form.Email, form.Password) - if err != nil { - c.String(http.StatusForbidden, err.Error()) - return - } - - s.Set("username", userid) - if err := s.Save(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - l.Trace("redirecting to auth page") - c.Redirect(http.StatusFound, "/oauth/authorize") -} - -// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token -// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. -// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token -func (a *API) TokenPOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "TokenHandler") - l.Trace("entered token handler, will now go to server.HandleTokenRequest") - if err := a.server.HandleTokenRequest(c.Writer, c.Request); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } -} - -// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize -// The idea here is to present an oauth authorize page to the user, with a button -// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (a *API) AuthorizeGETHandler(c *gin.Context) { - l := a.log.WithField("func", "AuthorizeGETHandler") - s := sessions.Default(c) - - // Username will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow - // If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. - v := s.Get("username") - if username, ok := v.(string); !ok || username == "" { - l.Trace("username was empty, parsing form then redirecting to sign in page") - - // first make sure they've filled out the authorize form with the required values - form := &mastotypes.OAuthAuthorize{} - if err := c.ShouldBind(form); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - l.Tracef("parsed form: %+v", form) - - // these fields are *required* so check 'em - if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "missing one of: response_type, client_id or redirect_uri"}) - return - } - - // save these values from the form so we can use them elsewhere in the session - s.Set("force_login", form.ForceLogin) - s.Set("response_type", form.ResponseType) - s.Set("client_id", form.ClientID) - s.Set("redirect_uri", form.RedirectURI) - s.Set("scope", form.Scope) - if err := s.Save(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // send them to the sign in page so we can tell who they are - c.Redirect(http.StatusFound, "/auth/sign_in") - return - } - - // Check if we have a code already. If we do, it means the user used urn:ietf:wg:oauth:2.0:oob as their redirect URI - // and were sent here, which means they just want the code displayed so they can use it out of band. - code := &code{} - if err := c.Bind(code); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // the authorize template will either: - // 1. Display the code to the user if they're already authorized and were redirected here because they selected urn:ietf:wg:oauth:2.0:oob. - // 2. Display a form where they can get some information about the app that's trying to authorize, and approve it, which will then go to AuthorizePOSTHandler - l.Trace("serving authorize html") - c.HTML(http.StatusOK, "authorize.tmpl", gin.H{ - "code": code.Code, - }) -} - -// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize -// The idea here is to present an oauth authorize page to the user, with a button -// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (a *API) AuthorizePOSTHandler(c *gin.Context) { - l := a.log.WithField("func", "AuthorizePOSTHandler") - s := sessions.Default(c) - - v := s.Get("username") - if username, ok := v.(string); !ok || username == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not signed in"}) - } - - values := url.Values{} - - if v, ok := s.Get("force_login").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"}) - return - } else { - values.Add("force_login", v) - } - - if v, ok := s.Get("response_type").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"}) - return - } else { - values.Add("response_type", v) - } - - if v, ok := s.Get("client_id").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"}) - return - } else { - values.Add("client_id", v) - } - - if v, ok := s.Get("redirect_uri").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"}) - return - } else { - // todo: explain this little hack - if v == "urn:ietf:wg:oauth:2.0:oob" { - v = "http://localhost:8080/oauth/authorize" - } - values.Add("redirect_uri", v) - } - - if v, ok := s.Get("scope").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"}) - return - } else { - values.Add("scope", v) - } - - if v, ok := s.Get("username").(string); !ok { - c.JSON(http.StatusBadRequest, gin.H{"error": "session missing username"}) - return - } else { - values.Add("username", v) - } - - c.Request.Form = values - l.Tracef("values on request set to %+v", c.Request.Form) - - if err := s.Save(); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if err := a.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - } -} - -/* - SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server -*/ - -// PasswordAuthorizationHandler takes a username (in this case, we use an email address) -// and a password. The goal is to authenticate the password against the one for that email -// address stored in the database. If OK, we return the userid (a uuid) for that user, -// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (a *API) ValidatePassword(email string, password string) (userid string, err error) { - l := a.log.WithField("func", "PasswordAuthorizationHandler") - - // make sure an email/password was provided and bail if not - if email == "" || password == "" { - l.Debug("email or password was not provided") - return incorrectPassword() - } - - // first we select the user from the database based on email address, bail if no user found for that email - gtsUser := >smodel.User{} - if err := a.conn.Model(gtsUser).Where("email = ?", email).Select(); err != nil { - l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) - return incorrectPassword() - } - - // make sure a password is actually set and bail if not - if gtsUser.EncryptedPassword == "" { - l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email) - return incorrectPassword() - } - - // compare the provided password with the encrypted one from the db, bail if they don't match - if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil { - l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err) - return incorrectPassword() - } - - // If we've made it this far the email/password is correct, so we can just return the id of the user. - userid = gtsUser.ID - l.Tracef("returning (%s, %s)", userid, err) - return -} - -// UserAuthorizationHandler gets the user's ID from the 'username' field of the request form, -// or redirects to the /auth/sign_in page, if this key is not present. -func (a *API) UserAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { - l := a.log.WithField("func", "UserAuthorizationHandler") - userID = r.FormValue("username") - if userID == "" { - l.Trace("username was empty, redirecting to sign in page") - http.Redirect(w, r, "/auth/sign_in", http.StatusFound) - return "", nil - } - l.Tracef("returning (%s, %s)", userID, err) - return userID, err -} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go deleted file mode 100644 index 6c3a17c..0000000 --- a/internal/oauth/oauth_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package oauth - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" - "github.com/gotosocial/gotosocial/internal/api" - "github.com/gotosocial/gotosocial/internal/config" - "github.com/gotosocial/gotosocial/internal/gtsmodel" - "github.com/gotosocial/oauth2/v4" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/suite" - "golang.org/x/crypto/bcrypt" -) - -type OauthTestSuite struct { - suite.Suite - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore - conn *pg.DB - testAccount *gtsmodel.Account - testUser *gtsmodel.User - testClient *oauthClient - config *config.Config -} - -const () - -// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout -func (suite *OauthTestSuite) SetupSuite() { - encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("test-password"), bcrypt.DefaultCost) - if err != nil { - logrus.Panicf("error encrypting user pass: %s", err) - } - - suite.testAccount = >smodel.Account{} - suite.testUser = >smodel.User{ - EncryptedPassword: string(encryptedPassword), - Email: "user@localhost", - AccountID: "some-account-id-it-doesn't-matter-really-since-this-user-doesn't-actually-have-an-account!", - } - suite.testClient = &oauthClient{ - ID: "a-known-client-id", - Secret: "some-secret", - Domain: "http://localhost:8080", - } - - // because go tests are run within the test package directory, we need to fiddle with the templateconfig - // basedir in a way that we wouldn't normally have to do when running the binary, in order to make - // the templates actually load - c := config.Empty() - c.TemplateConfig.BaseDir = "../../web/template/" - suite.config = c -} - -// SetupTest creates a postgres connection and creates the oauth_clients table before each test -func (suite *OauthTestSuite) SetupTest() { - suite.conn = pg.Connect(&pg.Options{}) - if err := suite.conn.Ping(context.Background()); err != nil { - logrus.Panicf("db connection error: %s", err) - } - - models := []interface{}{ - &oauthClient{}, - &oauthToken{}, - >smodel.User{}, - >smodel.Account{}, - >smodel.Application{}, - } - - for _, m := range models { - if err := suite.conn.Model(m).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }); err != nil { - logrus.Panicf("db connection error: %s", err) - } - } - - suite.tokenStore = NewPGTokenStore(context.Background(), suite.conn, logrus.New()) - suite.clientStore = NewPGClientStore(suite.conn) - - if _, err := suite.conn.Model(suite.testUser).Insert(); err != nil { - logrus.Panicf("could not insert test user into db: %s", err) - } - - if _, err := suite.conn.Model(suite.testClient).Insert(); err != nil { - logrus.Panicf("could not insert test client into db: %s", err) - } - -} - -// TearDownTest drops the oauth_clients table and closes the pg connection after each test -func (suite *OauthTestSuite) TearDownTest() { - models := []interface{}{ - &oauthClient{}, - &oauthToken{}, - >smodel.User{}, - >smodel.Account{}, - >smodel.Application{}, - } - for _, m := range models { - if err := suite.conn.Model(m).DropTable(&orm.DropTableOptions{}); err != nil { - logrus.Panicf("drop table error: %s", err) - } - } - if err := suite.conn.Close(); err != nil { - logrus.Panicf("error closing db connection: %s", err) - } - suite.conn = nil -} - -func (suite *OauthTestSuite) TestAPIInitialize() { - log := logrus.New() - log.SetLevel(logrus.TraceLevel) - - r := api.New(suite.config, log) - api := New(suite.tokenStore, suite.clientStore, suite.conn, log) - if err := api.AddRoutes(r); err != nil { - suite.FailNow(fmt.Sprintf("error initializing api: %s", err)) - } - go r.Start() - time.Sleep(30 * time.Second) - // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=https://example.org - // http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=urn:ietf:wg:oauth:2.0:oob -} - -func TestOauthTestSuite(t *testing.T) { - suite.Run(t, new(OauthTestSuite)) -} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..3893503 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,120 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package router + +import ( + "crypto/rand" + "fmt" + "os" + "path/filepath" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/memstore" + "github.com/gin-gonic/gin" + "github.com/gotosocial/gotosocial/internal/config" + "github.com/sirupsen/logrus" +) + +// Router provides the REST interface for gotosocial, using gin. +type Router interface { + // Attach a gin handler to the router with the given method and path + AttachHandler(method string, path string, handler gin.HandlerFunc) + // Attach a gin middleware to the router that will be used globally + AttachMiddleware(handler gin.HandlerFunc) + // Start the router + Start() + // Stop the router + Stop() +} + +// router fulfils the Router interface using gin and logrus +type router struct { + logger *logrus.Logger + engine *gin.Engine +} + +// Start starts the router nicely +func (s *router) Start() { + // todo: start gracefully + if err := s.engine.Run(); err != nil { + s.logger.Panicf("server error: %s", err) + } +} + +// Stop shuts down the router nicely +func (s *router) Stop() { + // todo: shut down gracefully +} + +// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path. +// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path. +func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) { + if method == "ANY" { + s.engine.Any(path, handler) + } else { + s.engine.Handle(method, path, handler) + } +} + +// AttachMiddleware attaches a gin middleware to the router that will be used globally +func (s *router) AttachMiddleware(middleware gin.HandlerFunc) { + s.engine.Use(middleware) +} + +// New returns a new Router with the specified configuration, using the given logrus logger. +func New(config *config.Config, logger *logrus.Logger) (Router, error) { + engine := gin.New() + + // create a new session store middleware + store, err := sessionStore() + if err != nil { + return nil, fmt.Errorf("error creating session store: %s", err) + } + engine.Use(sessions.Sessions("gotosocial-session", store)) + + // load html templates for use by the router + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("error getting current working directory: %s", err) + } + tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) + logger.Debugf("loading templates from %s", tmPath) + engine.LoadHTMLGlob(tmPath) + + return &router{ + logger: logger, + engine: engine, + }, nil +} + +// sessionStore returns a new session store with a random auth and encryption key. +// This means that cookies using the store will be reset if gotosocial is restarted! +func sessionStore() (memstore.Store, error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + return memstore.NewStore(auth, crypt), nil +} diff --git a/web/template/authorize.tmpl b/web/template/authorize.tmpl index 2b3b3ab..fa6338b 100644 --- a/web/template/authorize.tmpl +++ b/web/template/authorize.tmpl @@ -2,7 +2,7 @@ - Auth + GoToSocial Authorization -{{if len .code | eq 0 }}
-

Authorize

-

The client would like to perform actions on your behalf.

+

Hi {{.user}}!

+

Application {{.appname}} {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope {{.scope}}.

+

The application will redirect to {{.redirect}} to continue.

-{{else}} - -
-
- {{.code}} -
-
- -{{end}} -