tests for status create

This commit is contained in:
tsmethurst
2021-04-06 16:31:57 +02:00
parent 21ffcd98ec
commit 1025ac31aa
10 changed files with 571 additions and 57 deletions

View File

@ -4,6 +4,8 @@ package apimodule
import ( import (
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
db "github.com/superseriousbusiness/gotosocial/internal/db"
router "github.com/superseriousbusiness/gotosocial/internal/router" router "github.com/superseriousbusiness/gotosocial/internal/router"
) )
@ -12,6 +14,20 @@ type MockClientAPIModule struct {
mock.Mock mock.Mock
} }
// CreateTables provides a mock function with given fields: _a0
func (_m *MockClientAPIModule) CreateTables(_a0 db.DB) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(db.DB) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// Route provides a mock function with given fields: s // Route provides a mock function with given fields: s
func (_m *MockClientAPIModule) Route(s router.Router) error { func (_m *MockClientAPIModule) Route(s router.Router) error {
ret := _m.Called(s) ret := _m.Called(s)

View File

@ -99,12 +99,15 @@ func (m *statusModule) statusCreatePOSTHandler(c *gin.Context) {
ID: thisStatusID, ID: thisStatusID,
URI: thisStatusURI, URI: thisStatusURI,
URL: thisStatusURL, URL: thisStatusURL,
Content: util.HTMLFormat(form.Status),
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
Local: true, Local: true,
AccountID: authed.Account.ID, AccountID: authed.Account.ID,
ContentWarning: form.SpoilerText, ContentWarning: form.SpoilerText,
ActivityStreamsType: model.ActivityStreamsNote, ActivityStreamsType: model.ActivityStreamsNote,
Sensitive: form.Sensitive,
Language: form.Language,
} }
// check if replyToID is ok // check if replyToID is ok
@ -166,12 +169,28 @@ func (m *statusModule) statusCreatePOSTHandler(c *gin.Context) {
} }
// return populated status to submitter // return populated status to submitter
// mastoStatus := &mastotypes.Status{ mastoAccount, err := m.db.AccountToMastoPublic(authed.Account)
// ID: newStatus.ID, if err != nil {
// CreatedAt: time.Now().Format(time.RFC3339), c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
// InReplyToID: newStatus.InReplyToID, return
// // InReplyToAccountID: newStatus., }
// } mastoStatus := &mastotypes.Status{
ID: newStatus.ID,
CreatedAt: newStatus.CreatedAt.Format(time.RFC3339),
InReplyToID: newStatus.InReplyToID,
// InReplyToAccountID: newStatus.ReplyToAccount.ID,
Sensitive: newStatus.Sensitive,
SpoilerText: newStatus.ContentWarning,
Visibility: util.ParseMastoVisFromGTSVis(newStatus.Visibility),
Language: newStatus.Language,
URI: newStatus.URI,
URL: newStatus.URL,
Content: newStatus.Content,
Application: authed.Application.ToMasto(),
Account: mastoAccount,
Text: form.Status,
}
c.JSON(http.StatusOK, mastoStatus)
} }
func validateCreateStatus(form *advancedStatusCreateForm, config *config.StatusesConfig) error { func validateCreateStatus(form *advancedStatusCreateForm, config *config.StatusesConfig) error {

View File

@ -0,0 +1,214 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package status
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/model"
"github.com/superseriousbusiness/gotosocial/internal/distributor"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusCreateTestSuite struct {
suite.Suite
config *config.Config
mockOauthServer *oauth.MockServer
mockStorage *storage.MockStorage
mediaHandler media.MediaHandler
distributor *distributor.MockDistributor
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*model.Application
testUsers map[string]*model.User
testAccounts map[string]*model.Account
log *logrus.Logger
db db.DB
statusModule *statusModule
}
/*
TEST INFRASTRUCTURE
*/
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *StatusCreateTestSuite) SetupSuite() {
// some of our subsequent entities need a log so create this here
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
// Direct config to local postgres instance
c := config.Empty()
c.Protocol = "http"
c.Host = "localhost"
c.DBConfig = &config.DBConfig{
Type: "postgres",
Address: "localhost",
Port: 5432,
User: "postgres",
Password: "postgres",
Database: "postgres",
ApplicationName: "gotosocial",
}
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
c.StorageConfig = &config.StorageConfig{
Backend: "local",
BasePath: "/tmp",
ServeProtocol: "http",
ServeHost: "localhost",
ServeBasePath: "/fileserver/media",
}
c.StatusesConfig = &config.StatusesConfig{
MaxChars: 500,
CWMaxChars: 50,
PollMaxOptions: 4,
PollOptionMaxChars: 50,
MaxMediaFiles: 4,
}
suite.config = c
// use an actual database for this, because it's just easier than mocking one out
database, err := db.New(context.Background(), c, log)
if err != nil {
suite.FailNow(err.Error())
}
suite.db = database
suite.mockOauthServer = &oauth.MockServer{}
suite.mockStorage = &storage.MockStorage{}
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
suite.distributor = &distributor.MockDistributor{}
suite.distributor.On("FromClientAPI").Return(make(chan distributor.FromClientAPI, 100))
suite.statusModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.distributor, suite.log).(*statusModule)
}
func (suite *StatusCreateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
}
func (suite *StatusCreateTestSuite) SetupTest() {
if err := testrig.StandardDBSetup(suite.db); err != nil {
panic(err)
}
suite.testTokens = testrig.TestTokens()
suite.testClients = testrig.TestClients()
suite.testApplications = testrig.TestApplications()
suite.testUsers = testrig.TestUsers()
suite.testAccounts = testrig.TestAccounts()
}
// TearDownTest drops tables to make sure there's no data in the db
func (suite *StatusCreateTestSuite) TearDownTest() {
if err := testrig.StandardDBTeardown(suite.db); err != nil {
panic(err)
}
}
/*
ACTUAL TESTS
*/
/*
TESTING: StatusCreatePOSTHandler
*/
func (suite *StatusCreateTestSuite) TestStatusCreatePOSTHandlerSuccessful() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.PGTokenToOauthToken(t)
// setup
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauthToken)
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", basePath), nil) // the endpoint we're hitting
ctx.Request.Form = url.Values{
"status": {"this is a brand new status!"},
"spoiler_text": {"hello hello"},
"sensitive": {"true"},
"visibility": {"public"},
// Status string `form:"status"`
// // Array of Attachment ids to be attached as media. If provided, status becomes optional, and poll cannot be used.
// MediaIDs []string `form:"media_ids"`
// // Poll to include with this status.
// Poll *PollRequest `form:"poll"`
// // ID of the status being replied to, if status is a reply
// InReplyToID string `form:"in_reply_to_id"`
// // Mark status and attached media as sensitive?
// Sensitive bool `form:"sensitive"`
// // Text to be shown as a warning or subject before the actual content. Statuses are generally collapsed behind this field.
// SpoilerText string `form:"spoiler_text"`
// // Visibility of the posted status. Enumerable oneOf public, unlisted, private, direct.
// Visibility Visibility `form:"visibility"`
// // ISO 8601 Datetime at which to schedule a status. Providing this paramter will cause ScheduledStatus to be returned instead of Status. Must be at least 5 minutes in the future.
// ScheduledAt string `form:"scheduled_at"`
// // ISO 639 language code for this status.
// Language string `form:"language"`
}
suite.statusModule.statusCreatePOSTHandler(ctx)
// check response
// 1. we should have OK from our call to the function
suite.EqualValues(http.StatusOK, recorder.Code)
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
assert.NoError(suite.T(), err)
statusReply := &mastotypes.Status{}
err = json.Unmarshal(b, statusReply)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), "hello hello", statusReply.SpoilerText)
assert.Equal(suite.T(), "this is a brand new status!", statusReply.Content)
assert.True(suite.T(), statusReply.Sensitive)
assert.Equal(suite.T(), mastotypes.VisibilityPublic, statusReply.Visibility)
}
func TestStatusCreateTestSuite(t *testing.T) {
suite.Run(t, new(StatusCreateTestSuite))
}

View File

@ -20,6 +20,29 @@ type MockDB struct {
mock.Mock mock.Mock
} }
// AccountToMastoPublic provides a mock function with given fields: account
func (_m *MockDB) AccountToMastoPublic(account *model.Account) (*mastotypes.Account, error) {
ret := _m.Called(account)
var r0 *mastotypes.Account
if rf, ok := ret.Get(0).(func(*model.Account) *mastotypes.Account); ok {
r0 = rf(account)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*mastotypes.Account)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*model.Account) error); ok {
r1 = rf(account)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// AccountToMastoSensitive provides a mock function with given fields: account // AccountToMastoSensitive provides a mock function with given fields: account
func (_m *MockDB) AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) { func (_m *MockDB) AccountToMastoSensitive(account *model.Account) (*mastotypes.Account, error) {
ret := _m.Called(account) ret := _m.Called(account)
@ -43,6 +66,27 @@ func (_m *MockDB) AccountToMastoSensitive(account *model.Account) (*mastotypes.A
return r0, r1 return r0, r1
} }
// Blocked provides a mock function with given fields: account1, account2
func (_m *MockDB) Blocked(account1 string, account2 string) (bool, error) {
ret := _m.Called(account1, account2)
var r0 bool
if rf, ok := ret.Get(0).(func(string, string) bool); ok {
r0 = rf(account1, account2)
} else {
r0 = ret.Get(0).(bool)
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(account1, account2)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CreateTable provides a mock function with given fields: i // CreateTable provides a mock function with given fields: i
func (_m *MockDB) CreateTable(i interface{}) error { func (_m *MockDB) CreateTable(i interface{}) error {
ret := _m.Called(i) ret := _m.Called(i)
@ -99,6 +143,29 @@ func (_m *MockDB) DropTable(i interface{}) error {
return r0 return r0
} }
// EmojiStringsToEmojis provides a mock function with given fields: emojis, originAccountID, statusID
func (_m *MockDB) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*model.Emoji, error) {
ret := _m.Called(emojis, originAccountID, statusID)
var r0 []*model.Emoji
if rf, ok := ret.Get(0).(func([]string, string, string) []*model.Emoji); ok {
r0 = rf(emojis, originAccountID, statusID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.Emoji)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string, string, string) error); ok {
r1 = rf(emojis, originAccountID, statusID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Federation provides a mock function with given fields: // Federation provides a mock function with given fields:
func (_m *MockDB) Federation() pub.Database { func (_m *MockDB) Federation() pub.Database {
ret := _m.Called() ret := _m.Called()
@ -143,6 +210,20 @@ func (_m *MockDB) GetAll(i interface{}) error {
return r0 return r0
} }
// GetAvatarForAccountID provides a mock function with given fields: avatar, accountID
func (_m *MockDB) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
ret := _m.Called(avatar, accountID)
var r0 error
if rf, ok := ret.Get(0).(func(*model.MediaAttachment, string) error); ok {
r0 = rf(avatar, accountID)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetByID provides a mock function with given fields: id, i // GetByID provides a mock function with given fields: id, i
func (_m *MockDB) GetByID(id string, i interface{}) error { func (_m *MockDB) GetByID(id string, i interface{}) error {
ret := _m.Called(id, i) ret := _m.Called(id, i)
@ -199,6 +280,20 @@ func (_m *MockDB) GetFollowingByAccountID(accountID string, following *[]model.F
return r0 return r0
} }
// GetHeaderForAccountID provides a mock function with given fields: header, accountID
func (_m *MockDB) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
ret := _m.Called(header, accountID)
var r0 error
if rf, ok := ret.Get(0).(func(*model.MediaAttachment, string) error); ok {
r0 = rf(header, accountID)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetLastStatusForAccountID provides a mock function with given fields: accountID, status // GetLastStatusForAccountID provides a mock function with given fields: accountID, status
func (_m *MockDB) GetLastStatusForAccountID(accountID string, status *model.Status) error { func (_m *MockDB) GetLastStatusForAccountID(accountID string, status *model.Status) error {
ret := _m.Called(accountID, status) ret := _m.Called(accountID, status)
@ -297,6 +392,29 @@ func (_m *MockDB) IsUsernameAvailable(username string) error {
return r0 return r0
} }
// MentionStringsToMentions provides a mock function with given fields: targetAccounts, originAccountID, statusID
func (_m *MockDB) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*model.Mention, error) {
ret := _m.Called(targetAccounts, originAccountID, statusID)
var r0 []*model.Mention
if rf, ok := ret.Get(0).(func([]string, string, string) []*model.Mention); ok {
r0 = rf(targetAccounts, originAccountID, statusID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.Mention)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string, string, string) error); ok {
r1 = rf(targetAccounts, originAccountID, statusID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale, appID // NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale, appID
func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) { func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) {
ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale, appID) ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale, appID)
@ -334,6 +452,20 @@ func (_m *MockDB) Put(i interface{}) error {
return r0 return r0
} }
// SetHeaderOrAvatarForAccountID provides a mock function with given fields: mediaAttachment, accountID
func (_m *MockDB) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
ret := _m.Called(mediaAttachment, accountID)
var r0 error
if rf, ok := ret.Get(0).(func(*model.MediaAttachment, string) error); ok {
r0 = rf(mediaAttachment, accountID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Stop provides a mock function with given fields: ctx // Stop provides a mock function with given fields: ctx
func (_m *MockDB) Stop(ctx context.Context) error { func (_m *MockDB) Stop(ctx context.Context) error {
ret := _m.Called(ctx) ret := _m.Called(ctx)
@ -348,6 +480,29 @@ func (_m *MockDB) Stop(ctx context.Context) error {
return r0 return r0
} }
// TagStringsToTags provides a mock function with given fields: tags, originAccountID, statusID
func (_m *MockDB) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*model.Tag, error) {
ret := _m.Called(tags, originAccountID, statusID)
var r0 []*model.Tag
if rf, ok := ret.Get(0).(func([]string, string, string) []*model.Tag); ok {
r0 = rf(tags, originAccountID, statusID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.Tag)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]string, string, string) error); ok {
r1 = rf(tags, originAccountID, statusID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateByID provides a mock function with given fields: id, i // UpdateByID provides a mock function with given fields: id, i
func (_m *MockDB) UpdateByID(id string, i interface{}) error { func (_m *MockDB) UpdateByID(id string, i interface{}) error {
ret := _m.Called(id, i) ret := _m.Called(id, i)
@ -361,3 +516,17 @@ func (_m *MockDB) UpdateByID(id string, i interface{}) error {
return r0 return r0
} }
// UpdateOneByID provides a mock function with given fields: id, key, value, i
func (_m *MockDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
ret := _m.Called(id, key, value, i)
var r0 error
if rf, ok := ret.Get(0).(func(string, string, interface{}, interface{}) error); ok {
r0 = rf(id, key, value, i)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@ -46,6 +46,10 @@ type Status struct {
ContentWarning string ContentWarning string
// visibility entry for this status // visibility entry for this status
Visibility Visibility Visibility Visibility
// mark the status as sensitive?
Sensitive bool
// what language is this status written in?
Language string
// advanced visibility for this status // advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types

View File

@ -9,32 +9,16 @@ type MockDistributor struct {
mock.Mock mock.Mock
} }
// ClientAPIIn provides a mock function with given fields: // FromClientAPI provides a mock function with given fields:
func (_m *MockDistributor) ClientAPIIn() chan interface{} { func (_m *MockDistributor) FromClientAPI() chan FromClientAPI {
ret := _m.Called() ret := _m.Called()
var r0 chan interface{} var r0 chan FromClientAPI
if rf, ok := ret.Get(0).(func() chan interface{}); ok { if rf, ok := ret.Get(0).(func() chan FromClientAPI); ok {
r0 = rf() r0 = rf()
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(chan interface{}) r0 = ret.Get(0).(chan FromClientAPI)
}
}
return r0
}
// ClientAPIOut provides a mock function with given fields:
func (_m *MockDistributor) ClientAPIOut() chan interface{} {
ret := _m.Called()
var r0 chan interface{}
if rf, ok := ret.Get(0).(func() chan interface{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(chan interface{})
} }
} }
@ -68,3 +52,19 @@ func (_m *MockDistributor) Stop() error {
return r0 return r0
} }
// ToClientAPI provides a mock function with given fields:
func (_m *MockDistributor) ToClientAPI() chan ToClientAPI {
ret := _m.Called()
var r0 chan ToClientAPI
if rf, ok := ret.Get(0).(func() chan ToClientAPI); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(chan ToClientAPI)
}
}
return r0
}

View File

@ -26,3 +26,17 @@ func (_m *MockGotosocial) Start(_a0 context.Context) error {
return r0 return r0
} }
// Stop provides a mock function with given fields: _a0
func (_m *MockGotosocial) Stop(_a0 context.Context) error {
ret := _m.Called(_a0)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@ -98,7 +98,7 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
if !ok { if !ok {
return errors.New("info param was not a models.Token") return errors.New("info param was not a models.Token")
} }
if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil { if err := pts.db.Put(OAuthTokenToPGToken(t)); err != nil {
return fmt.Errorf("error in tokenstore create: %s", err) return fmt.Errorf("error in tokenstore create: %s", err)
} }
return nil return nil
@ -130,7 +130,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
if err := pts.db.GetWhere("code", code, pgt); err != nil { if err := pts.db.GetWhere("code", code, pgt); err != nil {
return nil, err return nil, err
} }
return pgTokenToOauthToken(pgt), nil return PGTokenToOauthToken(pgt), nil
} }
// GetByAccess selects a token from the DB based on the Access field // GetByAccess selects a token from the DB based on the Access field
@ -144,7 +144,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
if err := pts.db.GetWhere("access", access, pgt); err != nil { if err := pts.db.GetWhere("access", access, pgt); err != nil {
return nil, err return nil, err
} }
return pgTokenToOauthToken(pgt), nil return PGTokenToOauthToken(pgt), nil
} }
// GetByRefresh selects a token from the DB based on the Refresh field // GetByRefresh selects a token from the DB based on the Refresh field
@ -158,7 +158,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
return nil, err return nil, err
} }
return pgTokenToOauthToken(pgt), nil return PGTokenToOauthToken(pgt), nil
} }
/* /*
@ -194,8 +194,8 @@ type Token struct {
RefreshExpiresAt time.Time `pg:"type:timestamp"` RefreshExpiresAt time.Time `pg:"type:timestamp"`
} }
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres // OAuthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func oauthTokenToPGToken(tkn *models.Token) *Token { func OAuthTokenToPGToken(tkn *models.Token) *Token {
now := time.Now() now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -236,8 +236,8 @@ func oauthTokenToPGToken(tkn *models.Token) *Token {
} }
} }
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token // PGTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func pgTokenToOauthToken(pgt *Token) *models.Token { func PGTokenToOauthToken(pgt *Token) *models.Token {
now := time.Now() now := time.Now()
return &models.Token{ return &models.Token{

View File

@ -30,6 +30,37 @@ func StandardDBSetup(db db.DB) error {
return err return err
} }
} }
for _, v := range TestTokens() {
if err := db.Put(v); err != nil {
return err
}
}
for _, v := range TestClients() {
if err := db.Put(v); err != nil {
return err
}
}
for _, v := range TestApplications() {
if err := db.Put(v); err != nil {
return err
}
}
for _, v := range TestUsers() {
if err := db.Put(v); err != nil {
return err
}
}
for _, v := range TestAccounts() {
if err := db.Put(v); err != nil {
return err
}
}
return nil return nil
} }

View File

@ -13,7 +13,14 @@ import (
func TestTokens() map[string]*oauth.Token { func TestTokens() map[string]*oauth.Token {
tokens := map[string]*oauth.Token{ tokens := map[string]*oauth.Token{
"local_account_1": { "local_account_1": {
ID: "64cf4214-33ab-4220-b5ca-4a6a12263b20",
ClientID: "73b48d42-029d-4487-80fc-329a5cf67869",
UserID: "44e36b79-44a4-4bd8-91e9-097f477fe97b",
RedirectURI: "http://localhost:8080",
Scope: "read write follow push",
Access: "NZAZOTC0OWITMDU0NC0ZODG4LWE4NJITMWUXM2M4MTRHZDEX",
AccessCreateAt: time.Now(),
AccessExpiresAt: time.Now().Add(72 * time.Hour),
}, },
} }
return tokens return tokens
@ -126,7 +133,7 @@ func TestUsers() map[string]*model.User {
ChosenLanguages: []string{"en"}, ChosenLanguages: []string{"en"},
FilteredLanguages: []string{}, FilteredLanguages: []string{},
Locale: "en", Locale: "en",
CreatedByApplicationID: "", CreatedByApplicationID: "f88697b8-ee3d-46c2-ac3f-dbb85566c3cc",
LastEmailedAt: time.Now().Add(-55 * time.Minute), LastEmailedAt: time.Now().Add(-55 * time.Minute),
ConfirmationToken: "", ConfirmationToken: "",
ConfirmedAt: time.Now().Add(-34 * time.Hour), ConfirmedAt: time.Now().Add(-34 * time.Hour),
@ -203,14 +210,14 @@ func TestAccounts() map[string]*model.Account {
Privacy: model.VisibilityPublic, Privacy: model.VisibilityPublic,
Sensitive: false, Sensitive: false,
Language: "en", Language: "en",
URI: "http://localhost:8080/users/admin", URI: "http://localhost:8080/users/weed_lord420",
URL: "http://localhost:8080/@admin", URL: "http://localhost:8080/@weed_lord420",
LastWebfingeredAt: time.Time{}, LastWebfingeredAt: time.Time{},
InboxURL: "http://localhost:8080/users/admin/inbox", InboxURL: "http://localhost:8080/users/weed_lord420/inbox",
OutboxURL: "http://localhost:8080/users/admin/outbox", OutboxURL: "http://localhost:8080/users/weed_lord420/outbox",
SharedInboxURL: "", SharedInboxURL: "",
FollowersURL: "http://localhost:8080/users/admin/followers", FollowersURL: "http://localhost:8080/users/weed_lord420/followers",
FeaturedCollectionURL: "http://localhost:8080/users/admin/collections/featured", FeaturedCollectionURL: "http://localhost:8080/users/weed_lord420/collections/featured",
ActorType: model.ActivityStreamsPerson, ActorType: model.ActivityStreamsPerson,
AlsoKnownAs: "", AlsoKnownAs: "",
PrivateKey: &rsa.PrivateKey{}, PrivateKey: &rsa.PrivateKey{},
@ -360,21 +367,61 @@ func TestAccounts() map[string]*model.Account {
ID: "c2c6e647-e2a9-4286-883b-e4a188186664", ID: "c2c6e647-e2a9-4286-883b-e4a188186664",
Username: "foss_satan", Username: "foss_satan",
Domain: "fossbros-anonymous.io", Domain: "fossbros-anonymous.io",
// AvatarFileName: "http://localhost:8080/fileserver/media/eecaad73-5703-426d-9312-276641daa31e/avatar/original/d5e7c265-91a6-4d84-8c27-7e1efe5720da.jpeg",
// AvatarContentType: "image/jpeg",
// AvatarFileSize: 0,
// AvatarUpdatedAt: time.Time{},
// AvatarRemoteURL: "",
// HeaderFileName: "http://localhost:8080/fileserver/media/eecaad73-5703-426d-9312-276641daa31e/header/original/e75d4117-21b6-4315-a428-eb3944235996.jpeg",
// HeaderContentType: "image/jpeg",
// HeaderFileSize: 0,
// HeaderUpdatedAt: time.Time{},
// HeaderRemoteURL: "",
DisplayName: "big gerald",
Fields: []model.Field{},
Note: "",
Memorial: false,
MovedToAccountID: "",
CreatedAt: time.Now().Add(-190 * time.Hour),
UpdatedAt: time.Now().Add(-36 * time.Hour),
Bot: false,
Reason: "",
Locked: false,
Discoverable: true,
Sensitive: false,
Language: "en",
URI: "https://fossbros-anonymous.io/users/foss_satan",
URL: "https://fossbros-anonymous.io/@foss_satan",
LastWebfingeredAt: time.Time{},
InboxURL: "https://fossbros-anonymous.io/users/foss_satan/inbox",
OutboxURL: "https://fossbros-anonymous.io/users/foss_satan/outbox",
SharedInboxURL: "",
FollowersURL: "https://fossbros-anonymous.io/users/foss_satan/followers",
FeaturedCollectionURL: "https://fossbros-anonymous.io/users/foss_satan/collections/featured",
ActorType: model.ActivityStreamsPerson,
AlsoKnownAs: "",
PrivateKey: &rsa.PrivateKey{},
PublicKey: nil,
SensitizedAt: time.Time{},
SilencedAt: time.Time{},
SuspendedAt: time.Time{},
HideCollections: false,
SuspensionOrigin: "",
}, },
"remote_account_2": { // "remote_account_2": {
ID: "93287988-76c4-460f-9e68-a45b578bb6b2", // ID: "93287988-76c4-460f-9e68-a45b578bb6b2",
Username: "dailycatpics", // Username: "dailycatpics",
Domain: "uwu.social", // Domain: "uwu.social",
}, // },
"suspended_local_account": { // "suspended_local_account": {
ID: "e8a5cf4e-4b10-45a4-ad82-b6e37a09100d", // ID: "e8a5cf4e-4b10-45a4-ad82-b6e37a09100d",
Username: "jeffbadman", // Username: "jeffbadman",
}, // },
"suspended_remote_account": { // "suspended_remote_account": {
ID: "17e6e09e-855d-4bf8-a1c3-7e780269f215", // ID: "17e6e09e-855d-4bf8-a1c3-7e780269f215",
Username: "ipfreely", // Username: "ipfreely",
Domain: "a-very-bad-website.com", // Domain: "a-very-bad-website.com",
}, // },
} }
// generate keys for each account // generate keys for each account