diff --git a/cmd/gotosocial/main.go b/cmd/gotosocial/main.go index a850938..983d49d 100644 --- a/cmd/gotosocial/main.go +++ b/cmd/gotosocial/main.go @@ -129,6 +129,52 @@ func main() { Value: true, EnvVars: []string{envNames.AccountsRequireApproval}, }, + + // MEDIA FLAGS + &cli.IntFlag{ + Name: flagNames.MediaMaxImageSize, + Usage: "Max size of accepted images in bytes", + Value: 1048576, // 1mb + EnvVars: []string{envNames.MediaMaxImageSize}, + }, + &cli.IntFlag{ + Name: flagNames.MediaMaxVideoSize, + Usage: "Max size of accepted videos in bytes", + Value: 5242880, // 5mb + EnvVars: []string{envNames.MediaMaxVideoSize}, + }, + + // STORAGE FLAGS + &cli.StringFlag{ + Name: flagNames.StorageBackend, + Usage: "Storage backend to use for media attachments", + Value: "local", + EnvVars: []string{envNames.StorageBackend}, + }, + &cli.StringFlag{ + Name: flagNames.StorageBasePath, + Usage: "Full path to an already-created directory where gts should store/retrieve media files", + Value: "/opt/gotosocial", + EnvVars: []string{envNames.StorageBasePath}, + }, + &cli.StringFlag{ + Name: flagNames.StorageServeProtocol, + Usage: "Protocol to use for serving media attachments (use https if storage is local)", + Value: "https", + EnvVars: []string{envNames.StorageServeProtocol}, + }, + &cli.StringFlag{ + Name: flagNames.StorageServeHost, + Usage: "Hostname to serve media attachments from (use the same value as host if storage is local)", + Value: "localhost", + EnvVars: []string{envNames.StorageServeHost}, + }, + &cli.StringFlag{ + Name: flagNames.StorageServeBasePath, + Usage: "Path to append to protocol and hostname to create the base path from which media files will be served (default will mostly be fine)", + Value: "/fileserver/media", + EnvVars: []string{envNames.StorageServeBasePath}, + }, }, Commands: []*cli.Command{ { diff --git a/internal/apimodule/account/account.go b/internal/apimodule/account/account.go index 27c8a77..2d9ddbb 100644 --- a/internal/apimodule/account/account.go +++ b/internal/apimodule/account/account.go @@ -21,7 +21,9 @@ package account import ( "fmt" "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/apimodule" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -62,8 +64,7 @@ func New(config *config.Config, db db.DB, oauthServer oauth.Server, mediaHandler // Route attaches all routes from this module to the given router func (m *accountModule) Route(r router.Router) error { r.AttachHandler(http.MethodPost, basePath, m.accountCreatePOSTHandler) - r.AttachHandler(http.MethodGet, verifyPath, m.accountVerifyGETHandler) - r.AttachHandler(http.MethodPatch, updateCredentialsPath, m.accountUpdateCredentialsPATCHHandler) + r.AttachHandler(http.MethodGet, basePathWithID, m.muxHandler) return nil } @@ -86,3 +87,14 @@ func (m *accountModule) CreateTables(db db.DB) error { } return nil } + +func (m *accountModule) muxHandler(c *gin.Context) { + ru := c.Request.RequestURI + if strings.HasPrefix(ru, verifyPath) { + m.accountVerifyGETHandler(c) + } else if strings.HasPrefix(ru, updateCredentialsPath) { + m.accountUpdateCredentialsPATCHHandler(c) + } else { + m.accountGETHandler(c) + } +} diff --git a/internal/apimodule/account/accountcreate.go b/internal/apimodule/account/accountcreate.go index 23cb530..44b28bd 100644 --- a/internal/apimodule/account/accountcreate.go +++ b/internal/apimodule/account/accountcreate.go @@ -119,7 +119,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.Acco return errors.New("registration is not open for this server") } - if err := util.ValidateSignUpUsername(form.Username); err != nil { + if err := util.ValidateUsername(form.Username); err != nil { return err } @@ -127,7 +127,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.Acco return err } - if err := util.ValidateSignUpPassword(form.Password); err != nil { + if err := util.ValidateNewPassword(form.Password); err != nil { return err } diff --git a/internal/apimodule/account/account_test.go b/internal/apimodule/account/accountcreate_test.go similarity index 93% rename from internal/apimodule/account/account_test.go rename to internal/apimodule/account/accountcreate_test.go index 293f551..95f73de 100644 --- a/internal/apimodule/account/account_test.go +++ b/internal/apimodule/account/accountcreate_test.go @@ -52,7 +52,7 @@ import ( "golang.org/x/crypto/bcrypt" ) -type AccountTestSuite struct { +type AccountCreateTestSuite struct { suite.Suite config *config.Config log *logrus.Logger @@ -74,7 +74,7 @@ type AccountTestSuite struct { */ // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout -func (suite *AccountTestSuite) SetupSuite() { +func (suite *AccountCreateTestSuite) SetupSuite() { // some of our subsequent entities need a log so create this here log := logrus.New() log.SetLevel(logrus.TraceLevel) @@ -109,6 +109,8 @@ func (suite *AccountTestSuite) SetupSuite() { // Direct config to local postgres instance c := config.Empty() + c.Protocol = "http" + c.Host = "localhost" c.DBConfig = &config.DBConfig{ Type: "postgres", Address: "localhost", @@ -121,6 +123,13 @@ func (suite *AccountTestSuite) SetupSuite() { c.MediaConfig = &config.MediaConfig{ MaxImageSize: 2 << 20, } + c.StorageConfig = &config.StorageConfig{ + Backend: "local", + BasePath: "/tmp", + ServeProtocol: "http", + ServeHost: "localhost", + ServeBasePath: "/fileserver/media", + } suite.config = c // use an actual database for this, because it's just easier than mocking one out @@ -155,14 +164,14 @@ func (suite *AccountTestSuite) SetupSuite() { suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule) } -func (suite *AccountTestSuite) TearDownSuite() { +func (suite *AccountCreateTestSuite) TearDownSuite() { if err := suite.db.Stop(context.Background()); err != nil { logrus.Panicf("error closing db connection: %s", err) } } // SetupTest creates a db connection and creates necessary tables before each test -func (suite *AccountTestSuite) SetupTest() { +func (suite *AccountCreateTestSuite) SetupTest() { // create all the tables we might need in thie suite models := []interface{}{ &model.User{}, @@ -199,7 +208,7 @@ func (suite *AccountTestSuite) SetupTest() { } // TearDownTest drops tables to make sure there's no data in the db -func (suite *AccountTestSuite) TearDownTest() { +func (suite *AccountCreateTestSuite) TearDownTest() { // remove all the tables we might have used so it's clear for the next test models := []interface{}{ @@ -231,7 +240,7 @@ func (suite *AccountTestSuite) TearDownTest() { // and at the end of it a new user and account should be added into the database. // // This is the handler served at /api/v1/accounts as POST -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerSuccessful() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerSuccessful() { // setup recorder := httptest.NewRecorder() @@ -307,7 +316,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerSuccessful() { // TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no authorization is provided: // only registered applications can create accounts, and we don't provide one here. -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerNoAuth() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoAuth() { // setup recorder := httptest.NewRecorder() @@ -330,7 +339,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerNoAuth() { } // TestAccountCreatePOSTHandlerNoAuth makes sure that the handler fails when no form is provided at all. -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerNoForm() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerNoForm() { // setup recorder := httptest.NewRecorder() @@ -352,7 +361,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerNoForm() { } // TestAccountCreatePOSTHandlerWeakPassword makes sure that the handler fails when a weak password is provided -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerWeakPassword() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeakPassword() { // setup recorder := httptest.NewRecorder() @@ -377,7 +386,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerWeakPassword() { } // TestAccountCreatePOSTHandlerWeirdLocale makes sure that the handler fails when a weird locale is provided -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerWeirdLocale() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerWeirdLocale() { // setup recorder := httptest.NewRecorder() @@ -402,7 +411,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerWeirdLocale() { } // TestAccountCreatePOSTHandlerRegistrationsClosed makes sure that the handler fails when registrations are closed -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerRegistrationsClosed() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerRegistrationsClosed() { // setup recorder := httptest.NewRecorder() @@ -428,7 +437,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerRegistrationsClosed() } // TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when no reason is provided but one is required -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerReasonNotProvided() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerReasonNotProvided() { // setup recorder := httptest.NewRecorder() @@ -455,7 +464,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerReasonNotProvided() { } // TestAccountCreatePOSTHandlerReasonNotProvided makes sure that the handler fails when a crappy reason is presented but a good one is required -func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() { +func (suite *AccountCreateTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() { // setup recorder := httptest.NewRecorder() @@ -485,7 +494,7 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() TESTING: AccountUpdateCredentialsPATCHHandler */ -func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() { +func (suite *AccountCreateTestSuite) TestAccountUpdateCredentialsPATCHHandler() { // put test local account in db err := suite.db.Put(suite.testAccountLocal) @@ -533,6 +542,6 @@ func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() { // assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b)) } -func TestAccountTestSuite(t *testing.T) { - suite.Run(t, new(AccountTestSuite)) +func TestAccountCreateTestSuite(t *testing.T) { + suite.Run(t, new(AccountCreateTestSuite)) } diff --git a/internal/apimodule/account/accountget.go b/internal/apimodule/account/accountget.go new file mode 100644 index 0000000..1458c34 --- /dev/null +++ b/internal/apimodule/account/accountget.go @@ -0,0 +1,58 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package account + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/db/model" +) + +// accountGetHandler serves the account information held by the server in response to a GET +// request. It should be served as a GET at /api/v1/accounts/:id. +// +// See: https://docs.joinmastodon.org/methods/accounts/ +func (m *accountModule) accountGETHandler(c *gin.Context) { + targetAcctID := c.Param(idKey) + if targetAcctID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error":"no account id specified"}) + return + } + + targetAccount := &model.Account{} + if err := m.db.GetByID(targetAcctID, targetAccount); err != nil { + if _, ok := err.(db.ErrNoEntries); ok { + c.JSON(http.StatusNotFound, gin.H{"error":"Record not found"}) + return + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + acctInfo, err := m.db.AccountToMastoPublic(targetAccount) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, acctInfo) +} diff --git a/internal/apimodule/account/accountupdate.go b/internal/apimodule/account/accountupdate.go index 6221dac..af312a2 100644 --- a/internal/apimodule/account/accountupdate.go +++ b/internal/apimodule/account/accountupdate.go @@ -29,6 +29,7 @@ import ( "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/db/model" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/pkg/mastotypes" ) @@ -38,6 +39,8 @@ import ( // TODO: this can be optimized massively by building up a picture of what we want the new account // details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one // which is not gonna make the database very happy when lots of requests are going through. +// This way it would also be safer because the update won't happen until *all* the fields are validated. +// Otherwise we risk doing a partial update and that's gonna cause probllleeemmmsss. func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") authed, err := oauth.MustAuth(c, true, false, false, true) @@ -80,6 +83,10 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { } if form.DisplayName != nil { + if err := util.ValidateDisplayName(*form.DisplayName); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -87,6 +94,10 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { } if form.Note != nil { + if err := util.ValidateNote(*form.Note); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil { l.Debugf("error updating note: %s", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -116,17 +127,46 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { if form.Locked != nil { if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()}) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } } if form.Source != nil { + if form.Source.Language != nil { + if err := util.ValidateLanguage(*form.Source.Language); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } else { + if err := m.db.UpdateOneByID(authed.Account.ID, "language", *form.Source.Language, &model.Account{}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + } + if form.Source.Sensitive != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + if form.Source.Privacy != nil { + if err := util.ValidatePrivacy(*form.Source.Privacy); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } else { + if err := m.db.UpdateOneByID(authed.Account.ID, "privacy", *form.Source.Privacy, &model.Account{}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + } } if form.FieldsAttributes != nil { - + // TODO: parse fields attributes nicely and update } // fetch the account with all updated values set @@ -159,7 +199,7 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { // the account's new avatar image. func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { var err error - if avatar.Size > m.config.MediaConfig.MaxImageSize { + if int(avatar.Size) > m.config.MediaConfig.MaxImageSize { err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize) return nil, err } @@ -192,7 +232,7 @@ func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accoun // the account's new header image. func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { var err error - if header.Size > m.config.MediaConfig.MaxImageSize { + if int(header.Size) > m.config.MediaConfig.MaxImageSize { err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize) return nil, err } diff --git a/internal/apimodule/account/accountupdate_test.go b/internal/apimodule/account/accountupdate_test.go new file mode 100644 index 0000000..5886b7b --- /dev/null +++ b/internal/apimodule/account/accountupdate_test.go @@ -0,0 +1,300 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package account + +import ( + "bytes" + "context" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/db/model" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/oauth2/v4" + "github.com/superseriousbusiness/oauth2/v4/models" + oauthmodels "github.com/superseriousbusiness/oauth2/v4/models" +) + +type AccountUpdateTestSuite struct { + suite.Suite + config *config.Config + log *logrus.Logger + testAccountLocal *model.Account + testAccountRemote *model.Account + testUser *model.User + testApplication *model.Application + testToken oauth2.TokenInfo + mockOauthServer *oauth.MockServer + mockStorage *storage.MockStorage + mediaHandler media.MediaHandler + db db.DB + accountModule *accountModule + newUserFormHappyPath url.Values +} + +/* + TEST INFRASTRUCTURE +*/ + +// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout +func (suite *AccountUpdateTestSuite) SetupSuite() { + // some of our subsequent entities need a log so create this here + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + suite.log = log + + suite.testAccountLocal = &model.Account{ + ID: uuid.NewString(), + Username: "test_user", + } + + // can use this test application throughout + suite.testApplication = &model.Application{ + ID: "weeweeeeeeeeeeeeee", + Name: "a test application", + Website: "https://some-application-website.com", + RedirectURI: "http://localhost:8080", + ClientID: "a-known-client-id", + ClientSecret: "some-secret", + Scopes: "read", + VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa", + } + + // can use this test token throughout + suite.testToken = &oauthmodels.Token{ + ClientID: "a-known-client-id", + RedirectURI: "http://localhost:8080", + Scope: "read", + Code: "123456789", + CodeCreateAt: time.Now(), + CodeExpiresIn: time.Duration(10 * time.Minute), + } + + // Direct config to local postgres instance + c := config.Empty() + c.Protocol = "http" + c.Host = "localhost" + c.DBConfig = &config.DBConfig{ + Type: "postgres", + Address: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + ApplicationName: "gotosocial", + } + c.MediaConfig = &config.MediaConfig{ + MaxImageSize: 2 << 20, + } + c.StorageConfig = &config.StorageConfig{ + Backend: "local", + BasePath: "/tmp", + ServeProtocol: "http", + ServeHost: "localhost", + ServeBasePath: "/fileserver/media", + } + suite.config = c + + // use an actual database for this, because it's just easier than mocking one out + database, err := db.New(context.Background(), c, log) + if err != nil { + suite.FailNow(err.Error()) + } + suite.db = database + + // we need to mock the oauth server because account creation needs it to create a new token + suite.mockOauthServer = &oauth.MockServer{} + suite.mockOauthServer.On("GenerateUserAccessToken", suite.testToken, suite.testApplication.ClientSecret, mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + l := suite.log.WithField("func", "GenerateUserAccessToken") + token := args.Get(0).(oauth2.TokenInfo) + l.Infof("received token %+v", token) + clientSecret := args.Get(1).(string) + l.Infof("received clientSecret %+v", clientSecret) + userID := args.Get(2).(string) + l.Infof("received userID %+v", userID) + }).Return(&models.Token{ + Code: "we're authorized now!", + }, nil) + + suite.mockStorage = &storage.MockStorage{} + // We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage + suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil) + + // set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar) + suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log) + + // and finally here's the thing we're actually testing! + suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule) +} + +func (suite *AccountUpdateTestSuite) TearDownSuite() { + if err := suite.db.Stop(context.Background()); err != nil { + logrus.Panicf("error closing db connection: %s", err) + } +} + +// SetupTest creates a db connection and creates necessary tables before each test +func (suite *AccountUpdateTestSuite) SetupTest() { + // create all the tables we might need in thie suite + models := []interface{}{ + &model.User{}, + &model.Account{}, + &model.Follow{}, + &model.FollowRequest{}, + &model.Status{}, + &model.Application{}, + &model.EmailDomainBlock{}, + &model.MediaAttachment{}, + } + for _, m := range models { + if err := suite.db.CreateTable(m); err != nil { + logrus.Panicf("db connection error: %s", err) + } + } + + // form to submit for happy path account create requests -- this will be changed inside tests so it's better to set it before each test + suite.newUserFormHappyPath = url.Values{ + "reason": []string{"a very good reason that's at least 40 characters i swear"}, + "username": []string{"test_user"}, + "email": []string{"user@example.org"}, + "password": []string{"very-strong-password"}, + "agreement": []string{"true"}, + "locale": []string{"en"}, + } + + // same with accounts config + suite.config.AccountsConfig = &config.AccountsConfig{ + OpenRegistration: true, + RequireApproval: true, + ReasonRequired: true, + } +} + +// TearDownTest drops tables to make sure there's no data in the db +func (suite *AccountUpdateTestSuite) TearDownTest() { + + // remove all the tables we might have used so it's clear for the next test + models := []interface{}{ + &model.User{}, + &model.Account{}, + &model.Follow{}, + &model.FollowRequest{}, + &model.Status{}, + &model.Application{}, + &model.EmailDomainBlock{}, + &model.MediaAttachment{}, + } + for _, m := range models { + if err := suite.db.DropTable(m); err != nil { + logrus.Panicf("error dropping table: %s", err) + } + } +} + +/* + ACTUAL TESTS +*/ + +/* + TESTING: AccountUpdateCredentialsPATCHHandler +*/ + +func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() { + + // put test local account in db + err := suite.db.Put(suite.testAccountLocal) + assert.NoError(suite.T(), err) + + // attach avatar to request form + avatarFile, err := os.Open("../../media/test/test-jpeg.jpg") + assert.NoError(suite.T(), err) + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + avatarPart, err := writer.CreateFormFile("avatar", "test-jpeg.jpg") + assert.NoError(suite.T(), err) + + _, err = io.Copy(avatarPart, avatarFile) + assert.NoError(suite.T(), err) + + err = avatarFile.Close() + assert.NoError(suite.T(), err) + + // set display name to a new value + displayNamePart, err := writer.CreateFormField("display_name") + assert.NoError(suite.T(), err) + + _, err = io.Copy(displayNamePart, bytes.NewBufferString("test_user_wohoah")) + assert.NoError(suite.T(), err) + + // set locked to true + lockedPart, err := writer.CreateFormField("locked") + assert.NoError(suite.T(), err) + + _, err = io.Copy(lockedPart, bytes.NewBufferString("true")) + assert.NoError(suite.T(), err) + + // close the request writer, the form is now prepared + err = writer.Close() + assert.NoError(suite.T(), err) + + // setup + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal) + ctx.Set(oauth.SessionAuthorizedToken, suite.testToken) + ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", writer.FormDataContentType()) + suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx) + + // check response + + // 1. we should have OK because our request was valid + suite.EqualValues(http.StatusOK, recorder.Code) + + // 2. we should have an error message in the result body + result := recorder.Result() + defer result.Body.Close() + // TODO: implement proper checks here + // + // b, err := ioutil.ReadAll(result.Body) + // assert.NoError(suite.T(), err) + // assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b)) +} + +func TestAccountUpdateTestSuite(t *testing.T) { + suite.Run(t, new(AccountUpdateTestSuite)) +} diff --git a/internal/apimodule/account/accountverify_test.go b/internal/apimodule/account/accountverify_test.go new file mode 100644 index 0000000..223a0c1 --- /dev/null +++ b/internal/apimodule/account/accountverify_test.go @@ -0,0 +1,19 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package account diff --git a/internal/apimodule/fileserver/fileserver.go b/internal/apimodule/fileserver/fileserver.go new file mode 100644 index 0000000..af7bc5a --- /dev/null +++ b/internal/apimodule/fileserver/fileserver.go @@ -0,0 +1,42 @@ +package fileserver + +import ( + "fmt" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/apimodule" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/router" + "github.com/superseriousbusiness/gotosocial/internal/storage" +) + +// fileServer implements the RESTAPIModule interface. +// The goal here is to serve requested media files if the gotosocial server is configured to use local storage. +type fileServer struct { + config *config.Config + db db.DB + storage storage.Storage + log *logrus.Logger + storageBase string +} + +// New returns a new fileServer module +func New(config *config.Config, db db.DB, storage storage.Storage, log *logrus.Logger) apimodule.ClientAPIModule { + + storageBase := fmt.Sprintf("%s", config.StorageConfig.BasePath) // TODO: do this properly + + return &fileServer{ + config: config, + db: db, + storage: storage, + log: log, + storageBase: storageBase, + } +} + +// Route satisfies the RESTAPIModule interface +func (m *fileServer) Route(s router.Router) error { + // s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index f68c6e6..c68e585 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -35,26 +35,19 @@ type Config struct { TemplateConfig *TemplateConfig `yaml:"template"` AccountsConfig *AccountsConfig `yaml:"accounts"` MediaConfig *MediaConfig `yaml:"media"` + StorageConfig *StorageConfig `yaml:"storage"` } // FromFile returns a new config from a file, or an error if something goes amiss. func FromFile(path string) (*Config, error) { - c, err := loadFromFile(path) - if err != nil { - return nil, fmt.Errorf("error creating config: %s", err) - } - return c, nil -} - -// Default returns a new config with default values. -// Not yet implemented. -func Default() *Config { - // TODO: find a way of doing this without code repetition, because having to - // repeat all values here and elsewhere is annoying and gonna be prone to mistakes. - return &Config{ - DBConfig: &DBConfig{}, - TemplateConfig: &TemplateConfig{}, + if path != "" { + c, err := loadFromFile(path) + if err != nil { + return nil, fmt.Errorf("error creating config: %s", err) + } + return c, nil } + return Empty(), nil } // Empty just returns an empty config @@ -62,6 +55,9 @@ func Empty() *Config { return &Config{ DBConfig: &DBConfig{}, TemplateConfig: &TemplateConfig{}, + AccountsConfig: &AccountsConfig{}, + MediaConfig: &MediaConfig{}, + StorageConfig: &StorageConfig{}, } } @@ -147,6 +143,36 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) { if f.IsSet(fn.AccountsRequireApproval) { c.AccountsConfig.RequireApproval = f.Bool(fn.AccountsRequireApproval) } + + // media flags + if c.MediaConfig.MaxImageSize == 0 || f.IsSet(fn.MediaMaxImageSize) { + c.MediaConfig.MaxImageSize = f.Int(fn.MediaMaxImageSize) + } + + if c.MediaConfig.MaxVideoSize == 0 || f.IsSet(fn.MediaMaxVideoSize) { + c.MediaConfig.MaxVideoSize = f.Int(fn.MediaMaxVideoSize) + } + + // storage flags + if c.StorageConfig.Backend == "" || f.IsSet(fn.StorageBackend) { + c.StorageConfig.Backend = f.String(fn.StorageBackend) + } + + if c.StorageConfig.BasePath == "" || f.IsSet(fn.StorageBasePath) { + c.StorageConfig.BasePath = f.String(fn.StorageBasePath) + } + + if c.StorageConfig.ServeProtocol == "" || f.IsSet(fn.StorageServeProtocol) { + c.StorageConfig.ServeProtocol = f.String(fn.StorageServeProtocol) + } + + if c.StorageConfig.ServeHost == "" || f.IsSet(fn.StorageServeHost) { + c.StorageConfig.ServeHost = f.String(fn.StorageServeHost) + } + + if c.StorageConfig.ServeBasePath == "" || f.IsSet(fn.StorageServeBasePath) { + c.StorageConfig.ServeBasePath = f.String(fn.StorageServeBasePath) + } } // KeyedFlags is a wrapper for any type that can store keyed flags and give them back. @@ -166,15 +192,27 @@ type Flags struct { ConfigPath string Host string Protocol string + DbType string DbAddress string DbPort string DbUser string DbPassword string DbDatabase string + TemplateBaseDir string + AccountsOpenRegistration string AccountsRequireApproval string + + MediaMaxImageSize string + MediaMaxVideoSize string + + StorageBackend string + StorageBasePath string + StorageServeProtocol string + StorageServeHost string + StorageServeBasePath string } // GetFlagNames returns a struct containing the names of the various flags used for @@ -186,15 +224,27 @@ func GetFlagNames() Flags { ConfigPath: "config-path", Host: "host", Protocol: "protocol", + DbType: "db-type", DbAddress: "db-address", DbPort: "db-port", DbUser: "db-user", DbPassword: "db-password", DbDatabase: "db-database", + TemplateBaseDir: "template-basedir", + AccountsOpenRegistration: "accounts-open-registration", AccountsRequireApproval: "accounts-require-approval", + + MediaMaxImageSize: "media-max-image-size", + MediaMaxVideoSize: "media-max-video-size", + + StorageBackend: "storage-backend", + StorageBasePath: "storage-base-path", + StorageServeProtocol: "storage-serve-protocol", + StorageServeHost: "storage-serve-host", + StorageServeBasePath: "storage-serve-base-path", } } @@ -207,14 +257,26 @@ func GetEnvNames() Flags { ConfigPath: "GTS_CONFIG_PATH", Host: "GTS_HOST", Protocol: "GTS_PROTOCOL", + DbType: "GTS_DB_TYPE", DbAddress: "GTS_DB_ADDRESS", DbPort: "GTS_DB_PORT", DbUser: "GTS_DB_USER", DbPassword: "GTS_DB_PASSWORD", DbDatabase: "GTS_DB_DATABASE", + TemplateBaseDir: "GTS_TEMPLATE_BASEDIR", + AccountsOpenRegistration: "GTS_ACCOUNTS_OPEN_REGISTRATION", AccountsRequireApproval: "GTS_ACCOUNTS_REQUIRE_APPROVAL", + + MediaMaxImageSize: "GTS_MEDIA_MAX_IMAGE_SIZE", + MediaMaxVideoSize: "GTS_MEDIA_MAX_VIDEO_SIZE", + + StorageBackend: "GTS_STORAGE_BACKEND", + StorageBasePath: "GTS_STORAGE_BASE_PATH", + StorageServeProtocol: "GTS_STORAGE_SERVE_PROTOCOL", + StorageServeHost: "GTS_STORAGE_SERVE_HOST", + StorageServeBasePath: "GTS_STORAGE_SERVE_BASE_PATH", } } diff --git a/internal/config/media.go b/internal/config/media.go index ae209d7..816e236 100644 --- a/internal/config/media.go +++ b/internal/config/media.go @@ -18,7 +18,10 @@ package config -// AccountsConfig contains configuration to do with creating accounts, new registrations, and defaults. +// MediaConfig contains configuration for receiving and parsing media files and attachments type MediaConfig struct { - MaxImageSize int64 `yaml:"maxImageSize"` + // Max size of uploaded images in bytes + MaxImageSize int `yaml:"maxImageSize"` + // Max size of uploaded video in bytes + MaxVideoSize int `yaml:"maxVideoSize"` } diff --git a/internal/config/storage.go b/internal/config/storage.go new file mode 100644 index 0000000..4a8ff79 --- /dev/null +++ b/internal/config/storage.go @@ -0,0 +1,36 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package config + +// StorageConfig contains configuration for storage and serving of media files and attachments +type StorageConfig struct { + // Type of storage backend to use: currently only 'local' is supported. + // TODO: add S3 support here. + Backend string `yaml:"backend"` + + // The base path for storing things. Should be an already-existing directory. + BasePath string `yaml:"basePath"` + + // Protocol to use when *serving* media files from storage + ServeProtocol string `yaml:"serveProtocol"` + // Host to use when *serving* media files from storage + ServeHost string `yaml:"serveHost"` + // Base path to use when *serving* media files from storage + ServeBasePath string `yaml:"serveBasePath"` +} diff --git a/internal/db/db.go b/internal/db/db.go index 2cd9c15..4921270 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -177,6 +177,11 @@ type DB interface { // if something goes wrong. The returned account should be ready to serialize on an API level, and may have sensitive fields, // so serve it only to an authorized user who should have permission to see it. AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) + + // AccountToMastoPublic takes a db model account as a param, and returns a populated mastotype account, or an error + // if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields. + // In other words, this is the public record that the server has of an account. + AccountToMastoPublic(account *model.Account) (*mastotypes.Account, error) } // New returns a new database service that satisfies the DB interface and, by extension, diff --git a/internal/db/model/mediaattachment.go b/internal/db/model/mediaattachment.go index 4cb90f5..3aff18d 100644 --- a/internal/db/model/mediaattachment.go +++ b/internal/db/model/mediaattachment.go @@ -116,7 +116,7 @@ const ( // FileMeta describes metadata about the actual contents of the file. type FileMeta struct { Original Original - Small Small + Small Small } // Small implements SmallMeta and can be used for a thumbnail of any media type diff --git a/internal/db/pg.go b/internal/db/pg.go index 8d6c4a7..5ad9537 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -43,7 +43,7 @@ import ( // postgresService satisfies the DB interface type postgresService struct { - config *config.DBConfig + config *config.Config conn *pg.DB log *logrus.Entry cancel context.CancelFunc @@ -106,7 +106,7 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry // we can confidently return this useable postgres service now return &postgresService{ - config: c.DBConfig, + config: c, conn: conn, log: log, cancel: cancel, @@ -240,7 +240,7 @@ func (ps *postgresService) GetByID(id string, i interface{}) error { } func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { - if err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select(); err != nil { + if err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Select(); err != nil { if err == pg.ErrNoRows { return ErrNoEntries{} } @@ -275,7 +275,7 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error { } func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { - _, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update() + _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() return err } @@ -290,7 +290,7 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error { } func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { - if _, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete(); err != nil { + if _, err := ps.conn.Model(i).Where("? = ?", pg.Safe(key), value).Delete(); err != nil { if err == pg.ErrNoRows { return ErrNoEntries{} } @@ -437,10 +437,14 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr return nil, err } + // should be something like https://example.org/@some_username + url := fmt.Sprintf("%s://%s/@%s", ps.config.Protocol, ps.config.Host, username) + a := &model.Account{ Username: username, DisplayName: username, Reason: reason, + URL: url, PrivateKey: key, PublicKey: &key.PublicKey, ActorType: "Person", @@ -460,6 +464,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr Locale: locale, UnconfirmedEmail: email, CreatedByApplicationID: appID, + Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user } if _, err = ps.conn.Model(u).Insert(); err != nil { return nil, err @@ -502,7 +507,39 @@ func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, // https://docs.joinmastodon.org/methods/accounts/. Note that it's *sensitive* because it's only meant to be exposed to the user // that the account actually belongs to. func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { + // we can build this sensitive account easily by first getting the public account.... + mastoAccount, err := ps.AccountToMastoPublic(a) + if err != nil { + return nil, err + } + // then adding the Source object to it... + + // check pending follow requests aimed at this account + fr := []model.FollowRequest{} + if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting follow requests: %s", err) + } + } + var frc int + if fr != nil { + frc = len(fr) + } + + mastoAccount.Source = &mastotypes.Source{ + Privacy: a.Privacy, + Sensitive: a.Sensitive, + Language: a.Language, + Note: a.Note, + Fields: mastoAccount.Fields, + FollowRequestsCount: frc, + } + + return mastoAccount, nil +} + +func (ps *postgresService) AccountToMastoPublic(a *model.Account) (*mastotypes.Account, error) { // count followers followers := []model.Follow{} if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { @@ -583,47 +620,33 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype fields = append(fields, mField) } - // check pending follow requests aimed at this account - fr := []model.FollowRequest{} - if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { - if _, ok := err.(ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting follow requests: %s", err) - } - } - var frc int - if fr != nil { - frc = len(fr) - } - - // derive source from fields and other info - source := &mastotypes.Source{ - Privacy: a.Privacy, - Sensitive: a.Sensitive, - Language: a.Language, - Note: a.Note, - Fields: fields, - FollowRequestsCount: frc, + var acct string + if a.Domain != "" { + // this is a remote user + acct = fmt.Sprintf("%s@%s", a.Username, a.Domain) + } else { + // this is a local user + acct = a.Username } return &mastotypes.Account{ ID: a.ID, Username: a.Username, - Acct: a.Username, // equivalent to username for local users only, which sensitive always is + Acct: acct, DisplayName: a.DisplayName, Locked: a.Locked, Bot: a.Bot, CreatedAt: a.CreatedAt.Format(time.RFC3339), Note: a.Note, - URL: a.URL, // TODO: set this during account creation - Avatar: aviURL, // TODO: build this url properly using host and protocol from config - AvatarStatic: aviURLStatic, // TODO: build this url properly using host and protocol from config - Header: headerURL, // TODO: build this url properly using host and protocol from config - HeaderStatic: headerURLStatic, // TODO: build this url properly using host and protocol from config + URL: a.URL, + Avatar: aviURL, + AvatarStatic: aviURLStatic, + Header: headerURL, + HeaderStatic: headerURLStatic, FollowersCount: followersCount, FollowingCount: followingCount, StatusesCount: statusesCount, LastStatusAt: lastStatusAt, - Source: source, Emojis: nil, // TODO: implement this Fields: fields, }, nil diff --git a/internal/gotosocial/actions.go b/internal/gotosocial/actions.go index c348af3..29a391d 100644 --- a/internal/gotosocial/actions.go +++ b/internal/gotosocial/actions.go @@ -27,8 +27,18 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/action" + "github.com/superseriousbusiness/gotosocial/internal/apimodule" + "github.com/superseriousbusiness/gotosocial/internal/apimodule/account" + "github.com/superseriousbusiness/gotosocial/internal/apimodule/app" + "github.com/superseriousbusiness/gotosocial/internal/apimodule/auth" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/router" + "github.com/superseriousbusiness/gotosocial/internal/storage" ) // Run creates and starts a gotosocial server @@ -38,9 +48,45 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr return fmt.Errorf("error creating dbservice: %s", err) } - // if err := dbService.CreateSchema(ctx); err != nil { - // return fmt.Errorf("error creating dbschema: %s", err) - // } + router, err := router.New(c, log) + if err != nil { + return fmt.Errorf("error creating router: %s", err) + } + + storageBackend, err := storage.NewInMem(c, log) + if err != nil { + return fmt.Errorf("error creating storage backend: %s", err) + } + + // build backend handlers + mediaHandler := media.New(c, dbService, storageBackend, log) + oauthServer := oauth.New(dbService, log) + + // build client api modules + authModule := auth.New(oauthServer, dbService, log) + accountModule := account.New(c, dbService, oauthServer, mediaHandler, log) + appsModule := app.New(oauthServer, dbService, log) + + apiModules := []apimodule.ClientAPIModule{ + authModule, // this one has to go first so the other modules use its middleware + accountModule, + appsModule, + } + + for _, m := range apiModules { + if err := m.Route(router); err != nil { + return fmt.Errorf("routing error: %s", err) + } + } + + gts, err := New(dbService, &cache.MockCache{}, router, federation.New(dbService), c) + if err != nil { + return fmt.Errorf("error creating gotosocial service: %s", err) + } + + if err := gts.Start(ctx); err != nil { + return fmt.Errorf("error starting gotosocial service: %s", err) + } // catch shutdown signals from the operating system sigs := make(chan os.Signal, 1) @@ -49,8 +95,8 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr log.Infof("received signal %s, shutting down", sig) // close down all running services in order - if err := dbService.Stop(ctx); err != nil { - return fmt.Errorf("error closing dbservice: %s", err) + if err := gts.Stop(ctx); err != nil { + return fmt.Errorf("error closing gotosocial service: %s", err) } log.Info("done! exiting...") diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go index 3fb1e53..d8f46f8 100644 --- a/internal/gotosocial/gotosocial.go +++ b/internal/gotosocial/gotosocial.go @@ -32,6 +32,7 @@ import ( // The logic of stopping and starting the entire server is contained here. type Gotosocial interface { Start(context.Context) error + Stop(context.Context) error } // New returns a new gotosocial server, initialized with the given configuration. @@ -56,10 +57,19 @@ type gotosocial struct { config *config.Config } -// Start starts up the gotosocial server. It is a blocking call, so only call it when -// you're absolutely sure you want to start up the server. If something goes wrong -// while starting the server, then an error will be returned. You can treat this function a -// lot like you would treat http.ListenAndServe() +// Start starts up the gotosocial server. If something goes wrong +// while starting the server, then an error will be returned. func (gts *gotosocial) Start(ctx context.Context) error { + gts.apiRouter.Start() + return nil +} + +func (gts *gotosocial) Stop(ctx context.Context) error { + if err := gts.apiRouter.Stop(ctx); err != nil { + return err + } + if err := gts.db.Stop(ctx); err != nil { + return err + } return nil } diff --git a/internal/media/media.go b/internal/media/media.go index f66a215..fd517e2 100644 --- a/internal/media/media.go +++ b/internal/media/media.go @@ -151,13 +151,16 @@ func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string // now put it in storage, take a new uuid for the name of the file so we don't store any unnecessary info about it extension := strings.Split(contentType, "/")[1] newMediaID := uuid.NewString() + + base := fmt.Sprintf("%s://%s%s", mh.config.StorageConfig.ServeProtocol, mh.config.StorageConfig.ServeHost, mh.config.StorageConfig.ServeBasePath, ) + // we store the original... - originalPath := fmt.Sprintf("%s/media/%s/original/%s.%s", accountID, headerOrAvi, newMediaID, extension) + originalPath := fmt.Sprintf("%s/%s/%s/original/%s.%s", base, accountID, headerOrAvi, newMediaID, extension) if err := mh.storage.StoreFileAt(originalPath, original.image); err != nil { return nil, fmt.Errorf("storage error: %s", err) } // and a thumbnail... - smallPath := fmt.Sprintf("%s/media/%s/small/%s.%s", accountID, headerOrAvi, newMediaID, extension) + smallPath := fmt.Sprintf("%s/%s/%s/small/%s.%s", base, accountID, headerOrAvi, newMediaID, extension) if err := mh.storage.StoreFileAt(smallPath, small.image); err != nil { return nil, fmt.Errorf("storage error: %s", err) } diff --git a/internal/media/media_test.go b/internal/media/media_test.go index b073e7d..18855a2 100644 --- a/internal/media/media_test.go +++ b/internal/media/media_test.go @@ -55,6 +55,8 @@ func (suite *MediaTestSuite) SetupSuite() { // Direct config to local postgres instance c := config.Empty() + c.Protocol = "http" + c.Host = "localhost" c.DBConfig = &config.DBConfig{ Type: "postgres", Address: "localhost", @@ -67,6 +69,13 @@ func (suite *MediaTestSuite) SetupSuite() { c.MediaConfig = &config.MediaConfig{ MaxImageSize: 2 << 20, } + c.StorageConfig = &config.StorageConfig{ + Backend: "local", + BasePath: "/tmp", + ServeProtocol: "http", + ServeHost: "localhost", + ServeBasePath: "/fileserver/media", + } suite.config = c // use an actual database for this, because it's just easier than mocking one out database, err := db.New(context.Background(), c, log) diff --git a/internal/router/router.go b/internal/router/router.go index b60e215..ce924b2 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -59,8 +59,6 @@ func (r *router) Start() { r.logger.Fatalf("listen: %s", err) } }() - // c := &gin.Context{} - // c.Get() } // Stop shuts down the router nicely diff --git a/internal/storage/inmem.go b/internal/storage/inmem.go new file mode 100644 index 0000000..25432fb --- /dev/null +++ b/internal/storage/inmem.go @@ -0,0 +1,31 @@ +package storage + +import ( + "fmt" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" +) + +func NewInMem(c *config.Config, log *logrus.Logger) (Storage, error) { + return &inMemStorage{ + stored: make(map[string][]byte), + }, nil +} + +type inMemStorage struct { + stored map[string][]byte +} + +func (s *inMemStorage) StoreFileAt(path string, data []byte) error { + s.stored[path] = data + return nil +} + +func (s *inMemStorage) RetrieveFileFrom(path string) ([]byte, error) { + d, ok := s.stored[path] + if !ok { + return nil, fmt.Errorf("no data found at path %s", path) + } + return d, nil +} diff --git a/internal/storage/local.go b/internal/storage/local.go new file mode 100644 index 0000000..29461d5 --- /dev/null +++ b/internal/storage/local.go @@ -0,0 +1,21 @@ +package storage + +import ( + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" +) + +func NewLocal(c *config.Config, log *logrus.Logger) (Storage, error) { + return &localStorage{}, nil +} + +type localStorage struct { +} + +func (s *localStorage) StoreFileAt(path string, data []byte) error { + return nil +} + +func (s *localStorage) RetrieveFileFrom(path string) ([]byte, error) { + return nil, nil +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 7257dc0..fa884ed 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -18,16 +18,7 @@ package storage -import "time" - type Storage interface { StoreFileAt(path string, data []byte) error RetrieveFileFrom(path string) ([]byte, error) } - -type FileInfo struct { - Data []byte - StorePath string - CreatedAt time.Time - UpdatedAt time.Time -} diff --git a/internal/util/validation.go b/internal/util/validation.go index f032539..88a5687 100644 --- a/internal/util/validation.go +++ b/internal/util/validation.go @@ -50,8 +50,8 @@ var ( NewUsernameRegex = regexp.MustCompile(NewUsernameRegexString) ) -// ValidateSignUpPassword returns an error if the given password is not sufficiently strong, or nil if it's ok. -func ValidateSignUpPassword(password string) error { +// ValidateNewPassword returns an error if the given password is not sufficiently strong, or nil if it's ok. +func ValidateNewPassword(password string) error { if password == "" { return errors.New("no password provided") } @@ -63,9 +63,9 @@ func ValidateSignUpPassword(password string) error { return pwv.Validate(password, MinimumPasswordEntropy) } -// ValidateSignUpUsername makes sure that a given username is valid (ie., letters, numbers, underscores, check length). +// ValidateUsername makes sure that a given username is valid (ie., letters, numbers, underscores, check length). // Returns an error if not. -func ValidateSignUpUsername(username string) error { +func ValidateUsername(username string) error { if username == "" { return errors.New("no username provided") } @@ -127,3 +127,18 @@ func ValidateSignUpReason(reason string, reasonRequired bool) error { } return nil } + +func ValidateDisplayName(displayName string) error { + // TODO: add some validation logic here -- length, characters, etc + return nil +} + +func ValidateNote(note string) error { + // TODO: add some validation logic here -- length, characters, etc + return nil +} + +func ValidatePrivacy(privacy string) error { + // TODO: add some validation logic here -- length, characters, etc + return nil +} diff --git a/internal/util/validation_test.go b/internal/util/validation_test.go index 4c40c88..28d6457 100644 --- a/internal/util/validation_test.go +++ b/internal/util/validation_test.go @@ -42,42 +42,42 @@ func (suite *ValidationTestSuite) TestCheckPasswordStrength() { strongPassword := "3dX5@Zc%mV*W2MBNEy$@" var err error - err = ValidateSignUpPassword(empty) + err = ValidateNewPassword(empty) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("no password provided"), err) } - err = ValidateSignUpPassword(terriblePassword) + err = ValidateNewPassword(terriblePassword) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using uppercase letters, using numbers or using a longer password"), err) } - err = ValidateSignUpPassword(weakPassword) + err = ValidateNewPassword(weakPassword) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("insecure password, try including more special characters, using numbers or using a longer password"), err) } - err = ValidateSignUpPassword(shortPassword) + err = ValidateNewPassword(shortPassword) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err) } - err = ValidateSignUpPassword(specialPassword) + err = ValidateNewPassword(specialPassword) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("insecure password, try including more special characters or using a longer password"), err) } - err = ValidateSignUpPassword(longPassword) + err = ValidateNewPassword(longPassword) if assert.NoError(suite.T(), err) { assert.Equal(suite.T(), nil, err) } - err = ValidateSignUpPassword(tooLong) + err = ValidateNewPassword(tooLong) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("password should be no more than 64 chars"), err) } - err = ValidateSignUpPassword(strongPassword) + err = ValidateNewPassword(strongPassword) if assert.NoError(suite.T(), err) { assert.Equal(suite.T(), nil, err) } @@ -94,42 +94,42 @@ func (suite *ValidationTestSuite) TestValidateUsername() { goodUsername := "this_is_a_good_username" var err error - err = ValidateSignUpUsername(empty) + err = ValidateUsername(empty) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), errors.New("no username provided"), err) } - err = ValidateSignUpUsername(tooLong) + err = ValidateUsername(tooLong) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("username should be no more than 64 chars but '%s' was 66", tooLong), err) } - err = ValidateSignUpUsername(withSpaces) + err = ValidateUsername(withSpaces) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", withSpaces), err) } - err = ValidateSignUpUsername(weirdChars) + err = ValidateUsername(weirdChars) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", weirdChars), err) } - err = ValidateSignUpUsername(leadingSpace) + err = ValidateUsername(leadingSpace) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", leadingSpace), err) } - err = ValidateSignUpUsername(trailingSpace) + err = ValidateUsername(trailingSpace) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", trailingSpace), err) } - err = ValidateSignUpUsername(newlines) + err = ValidateUsername(newlines) if assert.Error(suite.T(), err) { assert.Equal(suite.T(), fmt.Errorf("given username %s was invalid: must contain only lowercase letters, numbers, and underscores", newlines), err) } - err = ValidateSignUpUsername(goodUsername) + err = ValidateUsername(goodUsername) if assert.NoError(suite.T(), err) { assert.Equal(suite.T(), nil, err) }