From 26c482cd86b91d050f5f708ee23cfa8cd04f72e9 Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Wed, 24 Mar 2021 17:34:41 +0100 Subject: [PATCH] groundwork for other account routes --- cmd/gotosocial/main.go | 10 ++- internal/config/accounts.go | 24 ++++++++ internal/config/config.go | 82 ++++++++++++++----------- internal/db/pg.go | 9 +-- internal/module/account/account.go | 37 +++++++---- internal/module/account/account_test.go | 11 +++- internal/module/oauth/oauth.go | 37 +++++++---- internal/util/validation.go | 23 +++++++ pkg/mastotypes/account.go | 18 ++++++ 9 files changed, 186 insertions(+), 65 deletions(-) create mode 100644 internal/config/accounts.go create mode 100644 internal/util/validation.go diff --git a/cmd/gotosocial/main.go b/cmd/gotosocial/main.go index 0919d5f..173c82d 100644 --- a/cmd/gotosocial/main.go +++ b/cmd/gotosocial/main.go @@ -111,10 +111,18 @@ func main() { // TEMPLATE FLAGS &cli.StringFlag{ Name: flagNames.TemplateBaseDir, - Usage: "Basedir for html templating files for rendering pages and composing emails", + Usage: "Basedir for html templating files for rendering pages and composing emails.", Value: "./web/template/", EnvVars: []string{envNames.TemplateBaseDir}, }, + + // ACCOUNTS FLAGS + &cli.BoolFlag{ + Name: flagNames.AccountsOpenRegistration, + Usage: "Allow anyone to submit an account signup request. If false, server will be invite-only.", + Value: false, + EnvVars: []string{envNames.AccountsOpenRegistration}, + }, }, Commands: []*cli.Command{ { diff --git a/internal/config/accounts.go b/internal/config/accounts.go new file mode 100644 index 0000000..d5be1cb --- /dev/null +++ b/internal/config/accounts.go @@ -0,0 +1,24 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package config + +type AccountsConfig struct { + // Do we want people to be able to just submit sign up requests, or do we want invite only? + OpenRegistration bool +} diff --git a/internal/config/config.go b/internal/config/config.go index dca325c..6172de3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,6 +33,7 @@ type Config struct { Protocol string `yaml:"protocol"` DBConfig *DBConfig `yaml:"db"` TemplateConfig *TemplateConfig `yaml:"template"` + AccountsConfig *AccountsConfig `yaml:"accounts"` } // FromFile returns a new config from a file, or an error if something goes amiss. @@ -136,11 +137,17 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) { if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) { c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir) } + + // accounts flags + if f.IsSet(fn.AccountsOpenRegistration) { + c.AccountsConfig.OpenRegistration = f.Bool(fn.AccountsOpenRegistration) + } } // KeyedFlags is a wrapper for any type that can store keyed flags and give them back. // HINT: This works with a urfave cli context struct ;) type KeyedFlags interface { + Bool(k string) bool String(k string) string Int(k string) int IsSet(k string) bool @@ -149,36 +156,38 @@ type KeyedFlags interface { // Flags is used for storing the names of the various flags used for // initializing and storing urfavecli flag variables. type Flags struct { - LogLevel string - ApplicationName string - ConfigPath string - Host string - Protocol string - DbType string - DbAddress string - DbPort string - DbUser string - DbPassword string - DbDatabase string - TemplateBaseDir string + LogLevel string + ApplicationName string + ConfigPath string + Host string + Protocol string + DbType string + DbAddress string + DbPort string + DbUser string + DbPassword string + DbDatabase string + TemplateBaseDir string + AccountsOpenRegistration string } // GetFlagNames returns a struct containing the names of the various flags used for // initializing and storing urfavecli flag variables. func GetFlagNames() Flags { return Flags{ - LogLevel: "log-level", - ApplicationName: "application-name", - ConfigPath: "config-path", - Host: "host", - Protocol: "protocol", - DbType: "db-type", - DbAddress: "db-address", - DbPort: "db-port", - DbUser: "db-user", - DbPassword: "db-password", - DbDatabase: "db-database", - TemplateBaseDir: "template-basedir", + LogLevel: "log-level", + ApplicationName: "application-name", + ConfigPath: "config-path", + Host: "host", + Protocol: "protocol", + DbType: "db-type", + DbAddress: "db-address", + DbPort: "db-port", + DbUser: "db-user", + DbPassword: "db-password", + DbDatabase: "db-database", + TemplateBaseDir: "template-basedir", + AccountsOpenRegistration: "accounts-open-registration", } } @@ -186,17 +195,18 @@ func GetFlagNames() Flags { // initializing and storing urfavecli flag variables. func GetEnvNames() Flags { return Flags{ - LogLevel: "GTS_LOG_LEVEL", - ApplicationName: "GTS_APPLICATION_NAME", - ConfigPath: "GTS_CONFIG_PATH", - Host: "GTS_HOST", - Protocol: "GTS_PROTOCOL", - DbType: "GTS_DB_TYPE", - DbAddress: "GTS_DB_ADDRESS", - DbPort: "GTS_DB_PORT", - DbUser: "GTS_DB_USER", - DbPassword: "GTS_DB_PASSWORD", - DbDatabase: "GTS_DB_DATABASE", - TemplateBaseDir: "GTS_TEMPLATE_BASEDIR", + LogLevel: "GTS_LOG_LEVEL", + ApplicationName: "GTS_APPLICATION_NAME", + ConfigPath: "GTS_CONFIG_PATH", + Host: "GTS_HOST", + Protocol: "GTS_PROTOCOL", + DbType: "GTS_DB_TYPE", + DbAddress: "GTS_DB_ADDRESS", + DbPort: "GTS_DB_PORT", + DbUser: "GTS_DB_USER", + DbPassword: "GTS_DB_PASSWORD", + DbDatabase: "GTS_DB_DATABASE", + TemplateBaseDir: "GTS_TEMPLATE_BASEDIR", + AccountsOpenRegistration: "GTS_ACCOUNTS_OPEN_REGISTRATION", } } diff --git a/internal/db/pg.go b/internal/db/pg.go index a033801..92d6396 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -304,6 +304,7 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco return err } if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { + fmt.Println(account) if err == pg.ErrNoRows { return ErrNoEntries{} } @@ -394,7 +395,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype fmt.Printf("fields: %+v", fields) // count followers - var followers []model.Follow + followers := []model.Follow{} if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { if _, ok := err.(ErrNoEntries); !ok { return nil, fmt.Errorf("error getting followers: %s", err) @@ -406,7 +407,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype } // count following - var following []model.Follow + following := []model.Follow{} if err := ps.GetFollowingByAccountID(a.ID, &following); err != nil { if _, ok := err.(ErrNoEntries); !ok { return nil, fmt.Errorf("error getting following: %s", err) @@ -418,7 +419,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype } // count statuses - var statuses []model.Status + statuses := []model.Status{} if err := ps.GetStatusesByAccountID(a.ID, &statuses); err != nil { if _, ok := err.(ErrNoEntries); !ok { return nil, fmt.Errorf("error getting last statuses: %s", err) @@ -430,7 +431,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype } // check when the last status was - var lastStatus *model.Status + lastStatus := &model.Status{} if err := ps.GetLastStatusForAccountID(a.ID, lastStatus); err != nil { if _, ok := err.(ErrNoEntries); !ok { return nil, fmt.Errorf("error getting last status: %s", err) diff --git a/internal/module/account/account.go b/internal/module/account/account.go index 36af90c..db04ed0 100644 --- a/internal/module/account/account.go +++ b/internal/module/account/account.go @@ -19,7 +19,6 @@ package account import ( - "fmt" "net/http" "github.com/gin-gonic/gin" @@ -29,6 +28,7 @@ import ( "github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/module/oauth" "github.com/gotosocial/gotosocial/internal/router" + "github.com/sirupsen/logrus" ) const ( @@ -40,49 +40,62 @@ const ( type accountModule struct { config *config.Config db db.DB + log *logrus.Logger } // New returns a new account module -func New(config *config.Config, db db.DB) module.ClientAPIModule { +func New(config *config.Config, db db.DB, log *logrus.Logger) module.ClientAPIModule { return &accountModule{ config: config, db: db, + log: log, } } // Route attaches all routes from this module to the given router func (m *accountModule) Route(r router.Router) error { + r.AttachHandler(http.MethodPost, basePath, m.AccountCreatePOSTHandler) r.AttachHandler(http.MethodGet, verifyPath, m.AccountVerifyGETHandler) return nil } +func (m *accountModule) AccountCreatePOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AccountCreatePOSTHandler") + l.Trace("checking if registration is open") + if !m.config.AccountsConfig.OpenRegistration { + l.Trace("account registration is closed, returning error to client") + } +} + // AccountVerifyGETHandler serves a user's account details to them IF they reached this // handler while in possession of a valid token, according to the oauth middleware. func (m *accountModule) AccountVerifyGETHandler(c *gin.Context) { - i, ok := c.Get(oauth.SessionAuthorizedUser) - fmt.Println(i) + l := m.log.WithField("func", "AccountVerifyGETHandler") + + l.Trace("getting account details from session") + i, ok := c.Get(oauth.SessionAuthorizedAccount) if !ok { + l.Trace("no account in session, returning error to client") c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"}) return } - userID, ok := (i).(string) - if !ok || userID == "" { + l.Trace("attempting to convert account interface into account struct...") + acct, ok := i.(*model.Account) + if !ok { + l.Tracef("could not convert %+v into account struct, returning error to client", i) c.JSON(http.StatusUnauthorized, gin.H{"error": "The access token is invalid"}) return } - acct := &model.Account{} - if err := m.db.GetAccountByUserID(userID, acct); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - + l.Tracef("retrieved account %+v, converting to mastosensitive...", acct) acctSensitive, err := m.db.AccountToMastoSensitive(acct) if err != nil { + l.Tracef("could not convert account into mastosensitive account: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) c.JSON(http.StatusOK, acctSensitive) } diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go index 37ba727..8a7917e 100644 --- a/internal/module/account/account_test.go +++ b/internal/module/account/account_test.go @@ -216,10 +216,19 @@ func (suite *AccountTestSuite) TestAPIInitialize() { } r.AttachMiddleware(func(c *gin.Context) { + account := &model.Account{} + if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil { + suite.T().Log(err) + suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account) + fmt.Println(account) + return + } + + c.Set(oauth.SessionAuthorizedAccount, account) c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID) }) - acct := New(suite.config, suite.db) + acct := New(suite.config, suite.db, log) acct.Route(r) r.Start() diff --git a/internal/module/oauth/oauth.go b/internal/module/oauth/oauth.go index 2f77561..c69b626 100644 --- a/internal/module/oauth/oauth.go +++ b/internal/module/oauth/oauth.go @@ -47,11 +47,12 @@ import ( ) const ( - appsPath = "/api/v1/apps" - authSignInPath = "/auth/sign_in" - oauthTokenPath = "/oauth/token" - oauthAuthorizePath = "/oauth/authorize" - SessionAuthorizedUser = "authorized_user" + appsPath = "/api/v1/apps" + authSignInPath = "/auth/sign_in" + oauthTokenPath = "/oauth/token" + oauthAuthorizePath = "/oauth/authorize" + SessionAuthorizedUser = "authorized_user" + SessionAuthorizedAccount = "authorized_account" ) // oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface @@ -406,16 +407,30 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { MIDDLEWARE */ -// oauthTokenMiddleware +// oauthTokenMiddleware checks if the client has presented a valid oauth Bearer token. +// If so, it will check the User that the token belongs to, and set that in the context of +// the request. Then, it will look up the account for that user, and set that in the request too. +// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow +// public requests that don't have a Bearer token set (eg., for public instance information and so on). func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { l := m.log.WithField("func", "ValidatePassword") l.Trace("entering OauthTokenMiddleware") - if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil { - l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) - c.Set(SessionAuthorizedUser, ti.GetUserID()) - } else { - l.Trace("continuing with unauthenticated request") + + ti, err := m.oauthServer.ValidationBearerToken(c.Request) + if err != nil { + l.Trace("no valid token presented: continuing with unauthenticated request") + return } + l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) + + acct := &model.Account{} + if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil { + l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID()) + return + } + + c.Set(SessionAuthorizedAccount, acct) + c.Set(SessionAuthorizedUser, ti.GetUserID()) } /* diff --git a/internal/util/validation.go b/internal/util/validation.go new file mode 100644 index 0000000..3012553 --- /dev/null +++ b/internal/util/validation.go @@ -0,0 +1,23 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package util + +func CheckPasswordStrength(password string) bool { + return true +} diff --git a/pkg/mastotypes/account.go b/pkg/mastotypes/account.go index ca50c64..77581be 100644 --- a/pkg/mastotypes/account.go +++ b/pkg/mastotypes/account.go @@ -67,3 +67,21 @@ type Account struct { // An extra entity to be used with API methods to verify credentials and update credentials. Source *Source `json:"source"` } + +// AccountCreateRequest represents the form submitted during a POST request to /api/v1/accounts. +// See https://docs.joinmastodon.org/methods/accounts/ +type AccountCreateRequest struct { + // Text that will be reviewed by moderators if registrations require manual approval. + Reason string `form:"reason"` + // The desired username for the account + Username string `form:"username"` + // The email address to be used for login + Email string `form:"email"` + // The password to be used for login + Password string `form:"password"` + // Whether the user agrees to the local rules, terms, and policies. + // These should be presented to the user in order to allow them to consent before setting this parameter to TRUE. + Agreement bool `form:"agreement"` + // The language of the confirmation email that will be sent + Locale string `form:"locale"` +}