diff --git a/internal/module/account/account.go b/internal/module/account/account.go index f5649a3..5cd21d6 100644 --- a/internal/module/account/account.go +++ b/internal/module/account/account.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/model" + "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/module" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/router" @@ -38,25 +39,28 @@ import ( ) const ( - basePath = "/api/v1/accounts" - basePathWithID = basePath + "/:id" - verifyPath = basePath + "/verify_credentials" + basePath = "/api/v1/accounts" + basePathWithID = basePath + "/:id" + verifyPath = basePath + "/verify_credentials" + updateCredentialsPath = basePath + "/update_credentials" ) type accountModule struct { - config *config.Config - db db.DB - oauthServer oauth.Server - log *logrus.Logger + config *config.Config + db db.DB + oauthServer oauth.Server + mediaHandler media.MediaHandler + log *logrus.Logger } // New returns a new account module -func New(config *config.Config, db db.DB, oauthServer oauth.Server, log *logrus.Logger) module.ClientAPIModule { +func New(config *config.Config, db db.DB, oauthServer oauth.Server, mediaHandler media.MediaHandler, log *logrus.Logger) module.ClientAPIModule { return &accountModule{ - config: config, - db: db, - oauthServer: oauthServer, - log: log, + config: config, + db: db, + oauthServer: oauthServer, + mediaHandler: mediaHandler, + log: log, } } @@ -64,6 +68,7 @@ func New(config *config.Config, db db.DB, oauthServer oauth.Server, log *logrus. func (m *accountModule) Route(r router.Router) error { r.AttachHandler(http.MethodPost, basePath, m.accountCreatePOSTHandler) r.AttachHandler(http.MethodGet, verifyPath, m.accountVerifyGETHandler) + r.AttachHandler(http.MethodPatch, updateCredentialsPath, m.accountUpdateCredentialsPATCHHandler) return nil } @@ -117,10 +122,15 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { // handler while in possession of a valid token, according to the oauth middleware. // It should be served as a GET at /api/v1/accounts/verify_credentials func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { - l := m.log.WithField("func", "AccountVerifyGETHandler") + l := m.log.WithField("func", "accountVerifyGETHandler") authed, err := oauth.MustAuth(c, true, false, false, true) + if err != nil { + l.Debugf("couldn't auth: %s", err) + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } - l.Tracef("retrieved account %+v, converting to mastosensitive...", authed.Account) + l.Tracef("retrieved account %+v, converting to mastosensitive...", authed.Account.ID) acctSensitive, err := m.db.AccountToMastoSensitive(authed.Account) if err != nil { l.Tracef("could not convert account into mastosensitive account: %s", err) @@ -132,6 +142,53 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { c.JSON(http.StatusOK, acctSensitive) } +// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings. +// It should be served as a PATCH at /api/v1/accounts/update_credentials +func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { + l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") + authed, err := oauth.MustAuth(c, true, false, false, true) + if err != nil { + l.Debugf("couldn't auth: %s", err) + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + + l.Trace("parsing request form") + form := &mastotypes.UpdateCredentialsRequest{} + if err := c.ShouldBind(form); err != nil || form == nil { + l.Debugf("could not parse form from request: %s", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"}) + return + } + + // TODO: form validation + + // TODO: tidy this code into subfunctions + if form.Header != nil && form.Header.Size != 0 { + if form.Header.Size > m.config.MediaConfig.MaxImageSize { + err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize) + l.Debugf("error processing header: %s", err) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + f, err := form.Header.Open() + if err != nil { + l.Debugf("error processing header: %s", err) + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) + return + } + headerInfo, err := m.mediaHandler.SetHeaderForAccountID(f, authed.Account.ID) + if err != nil { + l.Debugf("error processing header: %s", err) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + l.Tracef("new header info for account %s is %+v", headerInfo) + } + + l.Tracef("retrieved account %+v", authed.Account.ID) +} + /* HELPER FUNCTIONS */ diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go index f712797..c515bcc 100644 --- a/internal/module/account/account_test.go +++ b/internal/module/account/account_test.go @@ -38,6 +38,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/model" + "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/pkg/mastotypes" "github.com/superseriousbusiness/oauth2/v4" @@ -56,6 +57,7 @@ type AccountTestSuite struct { testApplication *model.Application testToken oauth2.TokenInfo mockOauthServer *oauth.MockServer + mockMediaHandler *media.MockMediaHandler db db.DB accountModule *accountModule newUserFormHappyPath url.Values @@ -128,8 +130,11 @@ func (suite *AccountTestSuite) SetupSuite() { Code: "we're authorized now!", }, nil) + // mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar) + suite.mockMediaHandler = &media.MockMediaHandler{} + // and finally here's the thing we're actually testing! - suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.log).(*accountModule) + suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule) } func (suite *AccountTestSuite) TearDownSuite() {