some fiddling with tests
This commit is contained in:
@ -33,7 +33,7 @@ type User struct {
|
||||
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"`
|
||||
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
|
||||
Email string `pg:",notnull,unique"`
|
||||
Email string `pg:"default:'',notnull,unique"`
|
||||
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
|
||||
AccountID string `pg:"default:'',notnull,unique"`
|
||||
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
|
||||
|
||||
@ -379,7 +379,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error {
|
||||
// if no error we fail because it means we found something
|
||||
// if error but it's not pg.ErrNoRows then we fail
|
||||
// if err is pg.ErrNoRows we're good, we found nothing so continue
|
||||
if err := ps.conn.Model(&model.Account{}).Where("username = ?").Where("domain = ?", nil).Select(); err == nil {
|
||||
if err := ps.conn.Model(&model.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
|
||||
return fmt.Errorf("username %s already in use", username)
|
||||
} else if err != pg.ErrNoRows {
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
@ -404,8 +404,8 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
}
|
||||
|
||||
// check if this email is associated with an account already
|
||||
if err := ps.conn.Model(&model.Account{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
// check if this email is associated with a user already
|
||||
if err := ps.conn.Model(&model.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
// fail because we found something
|
||||
return fmt.Errorf("email %s already in use", email)
|
||||
} else if err != pg.ErrNoRows {
|
||||
|
||||
@ -69,7 +69,7 @@ func (m *accountModule) Route(r router.Router) error {
|
||||
// It should be served as a POST at /api/v1/accounts
|
||||
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "accountCreatePOSTHandler")
|
||||
authed, err := oauth.GetAuthed(c)
|
||||
authed, err := oauth.MustAuthed(c, true, true, false, false)
|
||||
if err != nil {
|
||||
l.Debugf("couldn't auth: %s", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
@ -167,6 +167,7 @@ func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, sig
|
||||
}
|
||||
|
||||
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
|
||||
fmt.Printf("ACCOUNT CREATE\n\n%+v\n\n%+v\n\n%+v\n", token, app, user)
|
||||
ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
|
||||
|
||||
@ -20,13 +20,19 @@ package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
oauthmodels "github.com/gotosocial/oauth2/v4/models"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
@ -37,6 +43,8 @@ type AccountTestSuite struct {
|
||||
testAccountLocal *model.Account
|
||||
testAccountRemote *model.Account
|
||||
testUser *model.User
|
||||
testApplication *model.Application
|
||||
testToken oauth2.TokenInfo
|
||||
db db.DB
|
||||
accountModule *accountModule
|
||||
}
|
||||
@ -57,6 +65,11 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||
Database: "postgres",
|
||||
ApplicationName: "gotosocial",
|
||||
}
|
||||
c.AccountsConfig = &config.AccountsConfig{
|
||||
OpenRegistration: true,
|
||||
RequireApproval: true,
|
||||
ReasonRequired: true,
|
||||
}
|
||||
|
||||
database, err := db.New(context.Background(), c, log)
|
||||
if err != nil {
|
||||
@ -70,6 +83,26 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||
log: log,
|
||||
}
|
||||
|
||||
suite.testApplication = &model.Application{
|
||||
ID: "weeweeeeeeeeeeeeee",
|
||||
Name: "a test application",
|
||||
Website: "https://some-application-website.com",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
ClientID: "a-known-client-id",
|
||||
ClientSecret: "some-secret",
|
||||
Scopes: "read",
|
||||
VapidKey: "aaaaaa-aaaaaaaa-aaaaaaaaaaa",
|
||||
}
|
||||
|
||||
suite.testToken = &oauthmodels.Token{
|
||||
ClientID: "a-known-client-id",
|
||||
RedirectURI: "http://localhost:8080",
|
||||
Scope: "read",
|
||||
Code: "123456789",
|
||||
CodeCreateAt: time.Now(),
|
||||
CodeExpiresIn: time.Duration(10 * time.Minute),
|
||||
}
|
||||
|
||||
// encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error encrypting user pass: %s", err)
|
||||
@ -177,6 +210,7 @@ func (suite *AccountTestSuite) SetupTest() {
|
||||
&model.Follow{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
@ -194,6 +228,7 @@ func (suite *AccountTestSuite) TearDownTest() {
|
||||
&model.Follow{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
&model.EmailDomainBlock{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
@ -206,8 +241,19 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() {
|
||||
// TODO: figure out how to test this properly
|
||||
recorder := httptest.NewRecorder()
|
||||
recorder.Header().Set("X-Forwarded-For", "127.0.0.1")
|
||||
recorder.Header().Set("Content-Type", "application/json")
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
// ctx.Set()
|
||||
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplication)
|
||||
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "http://localhost:8080/api/v1/accounts", nil)
|
||||
ctx.Request.Form = url.Values{
|
||||
"reason": []string{"a very good reason that's at least 40 characters i swear"},
|
||||
"username": []string{"test_user"},
|
||||
"email": []string{"user@example.org"},
|
||||
"password": []string{"very-strong-password"},
|
||||
"agreement": []string{"true"},
|
||||
"locale": []string{"en"},
|
||||
}
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
}
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -127,6 +128,27 @@ func GetAuthed(c *gin.Context) (*Authed, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// MustAuthed is like GetAuthed, but will fail if one of the requirements is not met.
|
||||
func MustAuthed(c *gin.Context, requireToken bool, requireApp bool, requireUser bool, requireAccount bool) (*Authed, error) {
|
||||
a, err := GetAuthed(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if requireToken && a.Token == nil {
|
||||
return nil, errors.New("token not supplied")
|
||||
}
|
||||
if requireApp && a.Application == nil {
|
||||
return nil, errors.New("application not supplied")
|
||||
}
|
||||
if requireUser && a.User == nil {
|
||||
return nil, errors.New("user not supplied")
|
||||
}
|
||||
if requireAccount && a.Account == nil {
|
||||
return nil, errors.New("account not supplied")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.server.HandleTokenRequest(w, r)
|
||||
@ -149,16 +171,16 @@ func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
|
||||
// The ti parameter refers to an existing Application token that was used to make the upstream
|
||||
// request. This token needs to be validated and exist in database in order to create a new token.
|
||||
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) {
|
||||
|
||||
fmt.Printf("GENERATE USER ACCESS TOKEN %+v\n", ti)
|
||||
tgr := &oauth2.TokenGenerateRequest{
|
||||
ClientID: ti.GetClientID(),
|
||||
ClientSecret: clientSecret,
|
||||
UserID: userID,
|
||||
RedirectURI: ti.GetRedirectURI(),
|
||||
Scope: ti.GetScope(),
|
||||
Code: ti.GetCode(),
|
||||
CodeChallenge: ti.GetCodeChallenge(),
|
||||
CodeChallengeMethod: ti.GetCodeChallengeMethod(),
|
||||
ClientID: ti.GetClientID(),
|
||||
ClientSecret: clientSecret,
|
||||
UserID: userID,
|
||||
RedirectURI: ti.GetRedirectURI(),
|
||||
Scope: ti.GetScope(),
|
||||
Code: ti.GetCode(),
|
||||
// CodeChallenge: ti.GetCodeChallenge(),
|
||||
// CodeChallengeMethod: ti.GetCodeChallengeMethod(),
|
||||
}
|
||||
|
||||
return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr)
|
||||
|
||||
Reference in New Issue
Block a user