chunking away at it

This commit is contained in:
tsmethurst 2021-03-26 19:02:20 +01:00
parent 0a244be523
commit f58f77bf1f
23 changed files with 860 additions and 394 deletions

View File

@ -21,6 +21,7 @@ package db
import (
"context"
"fmt"
"net"
"strings"
"github.com/go-fed/activity/pub"
@ -145,6 +146,10 @@ type DB interface {
// C) something went wrong in the db
IsEmailAvailable(email string) error
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error)
/*
USEFUL CONVERSION FUNCTIONS
*/

View File

@ -23,6 +23,7 @@
package model
import (
"crypto/rsa"
"net/url"
"time"
)
@ -82,6 +83,8 @@ type Account struct {
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
// Does this account identify itself as a bot?
Bot bool
// What reason was given for signing up when this account was created?
Reason string
/*
PRIVACY SETTINGS
@ -123,9 +126,9 @@ type Account struct {
Secret string
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
PrivateKey string
PrivateKey *rsa.PrivateKey
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PublicKey string
PublicKey *rsa.PublicKey
/*
ADMIN FIELDS

View File

@ -35,13 +35,13 @@ type DomainBlock struct {
// Account ID of the creator of this block
CreatedByAccountID string `pg:",notnull"`
// TODO: define this
Severity int
Severity int
// Reject media from this domain?
RejectMedia bool
RejectMedia bool
// Reject reports from this domain?
RejectReports bool
RejectReports bool
// Private comment on this block, viewable to admins
PrivateComment string
PrivateComment string
// Public comment on this block, viewable (optionally) by everyone
PublicComment string
PublicComment string
}

View File

@ -20,8 +20,11 @@ package db
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"net/mail"
"regexp"
"strings"
@ -35,6 +38,7 @@ import (
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
// postgresService satisfies the DB interface
@ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
fmt.Println(account)
if err == pg.ErrNoRows {
return ErrNoEntries{}
}
@ -400,7 +403,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with an account already
if err := ps.conn.Model(&model.Account{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
@ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
a := &model.Account{
Username: username,
DisplayName: username,
Reason: reason,
PrivateKey: key,
PublicKey: &key.PublicKey,
ActorType: "Person",
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
u := &model.User{
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP,
Locale: locale,
UnconfirmedEmail: email,
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
/*
CONVERSION FUNCTIONS
*/
@ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
}
fields = append(fields, mField)
}
fmt.Printf("fields: %+v", fields)
// count followers
followers := []model.Follow{}

21
internal/db/pg_test.go Normal file
View File

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for postgres

View File

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (

21
internal/db/pgfed_test.go Normal file
View File

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for pgfed

View File

@ -19,6 +19,8 @@
package account
import (
"fmt"
"net"
"net/http"
"github.com/gin-gonic/gin"
@ -26,9 +28,10 @@ import (
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/module/oauth"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus"
)
@ -39,9 +42,10 @@ const (
)
type accountModule struct {
config *config.Config
db db.DB
log *logrus.Logger
config *config.Config
db db.DB
oauthServer oauth.Server
log *logrus.Logger
}
// New returns a new account module
@ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error {
return nil
}
// accountCreatePOSTHandler handles create account requests, validates them,
// and puts them in the database if they're valid.
// It should be served as a POST at /api/v1/accounts
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AccountCreatePOSTHandler")
// TODO: check whether a valid app token has been presented!!
// See: https://docs.joinmastodon.org/methods/accounts/
l.Trace("checking if registration is open")
if !m.config.AccountsConfig.OpenRegistration {
l.Debug("account registration is closed, returning error to client")
c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"})
l := m.log.WithField("func", "accountCreatePOSTHandler")
authed, err := oauth.GetAuthed(c)
if err != nil {
l.Debugf("couldn't auth: %s", err)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
}
l.Tracef("validating form %+v", form)
if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil {
if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
l.Debugf("error validating form: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
clientIP := c.ClientIP()
l.Tracef("attempting to parse client ip address %s", clientIP)
signUpIP := net.ParseIP(clientIP)
if signUpIP == nil {
l.Debugf("error validating sign up ip address %s", clientIP)
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
return
}
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
if err != nil {
l.Errorf("internal server error while creating new account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, ti)
}
// accountVerifyGETHandler serves a user's account details to them IF they reached this
// handler while in possession of a valid token, according to the oauth middleware.
// It should be served as a GET at /api/v1/accounts/verify_credentials
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AccountVerifyGETHandler")
@ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
c.JSON(http.StatusOK, acctSensitive)
}
/*
HELPER FUNCTIONS
*/
// accountCreate does the dirty work of making an account and user in the database.
// It then returns a token to the caller, for use with the new account, as per the
// spec here: https://docs.joinmastodon.org/methods/accounts/
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
l := m.log.WithField("func", "accountCreate")
// don't store a reason if we don't require one
reason := form.Reason
if !m.config.AccountsConfig.ReasonRequired {
reason = ""
}
l.Trace("creating new username and account")
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale)
if err != nil {
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
}
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
if err != nil {
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
}
return &mastotypes.Token{
AccessToken: ti.GetCode(),
TokenType: "Bearer",
Scope: ti.GetScope(),
CreatedAt: ti.GetCodeCreateAt().Unix(),
}, nil
}

View File

@ -20,34 +20,33 @@ package account
import (
"context"
"fmt"
"net/url"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"golang.org/x/crypto/bcrypt"
)
type AccountTestSuite struct {
suite.Suite
db db.DB
log *logrus.Logger
testAccountLocal *model.Account
testAccountRemote *model.Account
testUser *model.User
config *config.Config
db db.DB
accountModule *accountModule
}
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *AccountTestSuite) SetupSuite() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
c := config.Empty()
c.DBConfig = &config.DBConfig{
Type: "postgres",
@ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() {
Database: "postgres",
ApplicationName: "gotosocial",
}
suite.config = c
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
database, err := db.New(context.Background(), c, log)
if err != nil {
logrus.Panicf("error encrypting user pass: %s", err)
suite.FailNow(err.Error())
}
suite.db = database
suite.accountModule = &accountModule{
config: c,
db: database,
log: log,
}
localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
if err != nil {
logrus.Panicf("error parsing localavatar url: %s", err)
}
localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
if err != nil {
logrus.Panicf("error parsing localheader url: %s", err)
}
// encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
// if err != nil {
// logrus.Panicf("error encrypting user pass: %s", err)
// }
acctID := uuid.NewString()
suite.testAccountLocal = &model.Account{
ID: acctID,
Username: "local_account_of_some_kind",
AvatarRemoteURL: localAvatar,
HeaderRemoteURL: localHeader,
DisplayName: "michael caine",
Fields: []model.Field{
{
Name: "come and ave a go",
Value: "if you think you're hard enough",
},
{
Name: "website",
Value: "https://imdb.com",
VerifiedAt: time.Now(),
},
},
Note: "My name is Michael Caine and i'm a local user.",
Discoverable: true,
}
// localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
// if err != nil {
// logrus.Panicf("error parsing localavatar url: %s", err)
// }
// localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
// if err != nil {
// logrus.Panicf("error parsing localheader url: %s", err)
// }
avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
if err != nil {
logrus.Panicf("error parsing avatarURL: %s", err)
}
// acctID := uuid.NewString()
// suite.testAccountLocal = &model.Account{
// ID: acctID,
// Username: "local_account_of_some_kind",
// AvatarRemoteURL: localAvatar,
// HeaderRemoteURL: localHeader,
// DisplayName: "michael caine",
// Fields: []model.Field{
// {
// Name: "come and ave a go",
// Value: "if you think you're hard enough",
// },
// {
// Name: "website",
// Value: "https://imdb.com",
// VerifiedAt: time.Now(),
// },
// },
// Note: "My name is Michael Caine and i'm a local user.",
// Discoverable: true,
// }
headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
if err != nil {
logrus.Panicf("error parsing avatarURL: %s", err)
}
suite.testAccountRemote = &model.Account{
ID: uuid.NewString(),
Username: "neato_bombeato",
Domain: "example.org",
// avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
// if err != nil {
// logrus.Panicf("error parsing avatarURL: %s", err)
// }
AvatarFileName: "avatar.png",
AvatarContentType: "image/png",
AvatarFileSize: 1024,
AvatarUpdatedAt: time.Now(),
AvatarRemoteURL: avatarURL,
// headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
// if err != nil {
// logrus.Panicf("error parsing avatarURL: %s", err)
// }
// suite.testAccountRemote = &model.Account{
// ID: uuid.NewString(),
// Username: "neato_bombeato",
// Domain: "example.org",
HeaderFileName: "avatar.png",
HeaderContentType: "image/png",
HeaderFileSize: 1024,
HeaderUpdatedAt: time.Now(),
HeaderRemoteURL: headerURL,
// AvatarFileName: "avatar.png",
// AvatarContentType: "image/png",
// AvatarFileSize: 1024,
// AvatarUpdatedAt: time.Now(),
// AvatarRemoteURL: avatarURL,
DisplayName: "one cool dude 420",
Fields: []model.Field{
{
Name: "pronouns",
Value: "he/they",
},
{
Name: "website",
Value: "https://imcool.edu",
VerifiedAt: time.Now(),
},
},
Note: "<p>I'm cool as heck!</p>",
Discoverable: true,
URI: "https://example.org/users/neato_bombeato",
URL: "https://example.org/@neato_bombeato",
LastWebfingeredAt: time.Now(),
InboxURL: "https://example.org/users/neato_bombeato/inbox",
OutboxURL: "https://example.org/users/neato_bombeato/outbox",
SharedInboxURL: "https://example.org/inbox",
FollowersURL: "https://example.org/users/neato_bombeato/followers",
FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
}
suite.testUser = &model.User{
ID: uuid.NewString(),
EncryptedPassword: string(encryptedPassword),
Email: "user@example.org",
AccountID: acctID,
// HeaderFileName: "avatar.png",
// HeaderContentType: "image/png",
// HeaderFileSize: 1024,
// HeaderUpdatedAt: time.Now(),
// HeaderRemoteURL: headerURL,
// DisplayName: "one cool dude 420",
// Fields: []model.Field{
// {
// Name: "pronouns",
// Value: "he/they",
// },
// {
// Name: "website",
// Value: "https://imcool.edu",
// VerifiedAt: time.Now(),
// },
// },
// Note: "<p>I'm cool as heck!</p>",
// Discoverable: true,
// URI: "https://example.org/users/neato_bombeato",
// URL: "https://example.org/@neato_bombeato",
// LastWebfingeredAt: time.Now(),
// InboxURL: "https://example.org/users/neato_bombeato/inbox",
// OutboxURL: "https://example.org/users/neato_bombeato/outbox",
// SharedInboxURL: "https://example.org/inbox",
// FollowersURL: "https://example.org/users/neato_bombeato/followers",
// FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
// }
// suite.testUser = &model.User{
// ID: uuid.NewString(),
// EncryptedPassword: string(encryptedPassword),
// Email: "user@example.org",
// AccountID: acctID,
// }
}
func (suite *AccountTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
}
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
// SetupTest creates a db connection and creates necessary tables before each test
func (suite *AccountTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
db, err := db.New(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
suite.db = db
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.Status{},
&model.Application{},
}
for _, m := range models {
@ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() {
logrus.Panicf("db connection error: %s", err)
}
}
if err := suite.db.Put(suite.testAccountLocal); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
}
if err := suite.db.Put(suite.testUser); err != nil {
logrus.Panicf("could not insert test user into db: %s", err)
}
}
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
// TearDownTest drops tables to make sure there's no data in the db
func (suite *AccountTestSuite) TearDownTest() {
models := []interface{}{
&model.User{},
&model.Account{},
&model.Follow{},
&model.Status{},
&model.Application{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
suite.db = nil
}
func (suite *AccountTestSuite) TestAPIInitialize() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
r, err := router.New(suite.config, log)
if err != nil {
suite.FailNow(fmt.Sprintf("error creating router: %s", err))
}
r.AttachMiddleware(func(c *gin.Context) {
account := &model.Account{}
if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil {
suite.T().Log(err)
suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account)
fmt.Println(account)
return
}
c.Set(oauth.SessionAuthorizedAccount, account)
c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
})
acct := New(suite.config, suite.db, log)
if err := acct.Route(r); err != nil {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
r.Start()
defer func() {
if err := r.Stop(context.Background()); err != nil {
panic(fmt.Errorf("error stopping router: %s", err))
}
}()
time.Sleep(10 * time.Second)
func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() {
// TODO: figure out how to test this properly
recorder := httptest.NewRecorder()
recorder.Header().Set("X-Forwarded-For", "127.0.0.1")
ctx, _ := gin.CreateTestContext(recorder)
// ctx.Set()
suite.accountModule.accountCreatePOSTHandler(ctx)
}
func TestAccountTestSuite(t *testing.T) {

View File

@ -21,12 +21,17 @@ package account
import (
"errors"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/util"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
)
func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error {
func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
if !c.OpenRegistration {
return errors.New("registration is not open for this server")
}
if err := util.ValidateSignUpUsername(form.Username); err != nil {
return err
}
@ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired
return err
}
if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil {
if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
return err
}

140
internal/module/app/app.go Normal file
View File

@ -0,0 +1,140 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus"
)
const appsPath = "/api/v1/apps"
type appModule struct {
server oauth.Server
db db.DB
log *logrus.Logger
}
// New returns a new auth module
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
return &appModule{
server: srv,
db: db,
log: log,
}
}
// Route satisfies the RESTAPIModule interface
func (m *appModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
return nil
}
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *appModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauth.Client{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMasto())
}

View File

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
// TODO: write tests

View File

@ -1,4 +1,4 @@
# oauth
# auth
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.

View File

@ -16,57 +16,42 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
// Package oauth is a module that provides oauth functionality to a router.
// Package auth is a module that provides oauth functionality to a router.
// It adds the following paths:
// /api/v1/apps
// /auth/sign_in
// /oauth/token
// /oauth/authorize
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
package oauth
package auth
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/errors"
"github.com/gotosocial/oauth2/v4/manage"
"github.com/gotosocial/oauth2/v4/server"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
const (
appsPath = "/api/v1/apps"
authSignInPath = "/auth/sign_in"
oauthTokenPath = "/oauth/token"
oauthAuthorizePath = "/oauth/authorize"
// SessionAuthorizedUser is the key set in the gin context for the id of
// a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a string.
SessionAuthorizedUser = "authorized_user"
// SessionAuthorizedAccount is the key set in the gin context for the Account
// of a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
SessionAuthorizedAccount = "authorized_account"
)
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
type oauthModule struct {
oauthManager *manage.Manager
oauthServer *server.Server
db db.DB
log *logrus.Logger
type authModule struct {
server oauth.Server
db db.DB
log *logrus.Logger
}
type login struct {
@ -74,52 +59,17 @@ type login struct {
Password string `form:"password"`
}
// New returns a new oauth module
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
// New returns a new auth module
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
return &authModule{
server: srv,
db: db,
log: log,
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
m := &oauthModule{
oauthManager: manager,
oauthServer: srv,
db: db,
log: log,
}
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
return m
}
// Route satisfies the RESTAPIModule interface
func (m *oauthModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
func (m *authModule) Route(s router.Router) error {
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
@ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
s.AttachMiddleware(m.oauthTokenMiddleware)
return nil
}
@ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error {
MAIN HANDLERS -- serve these through a server/router
*/
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauthClient{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMasto())
}
// signInGETHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
func (m *oauthModule) signInGETHandler(c *gin.Context) {
func (m *authModule) signInGETHandler(c *gin.Context) {
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
}
@ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) {
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password.
// The handler will then redirect to the auth handler served at /auth
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
func (m *authModule) signInPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "SignInPOSTHandler")
s := sessions.Default(c)
form := &login{}
@ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
func (m *authModule) tokenPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "TokenPOSTHandler")
l.Trace("entered TokenPOSTHandler")
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
@ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
// The idea here is to present an oauth authorize page to the user, with a button
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
func (m *authModule) authorizeGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizeGETHandler")
s := sessions.Default(c)
@ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
func (m *authModule) authorizePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizePOSTHandler")
s := sessions.Default(c)
@ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
l.Tracef("values on request set to %+v", c.Request.Form)
// and proceed with authorization using the oauth2 library
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
}
@ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
// the request. Then, it will look up the account for that user, and set that in the request too.
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
l := m.log.WithField("func", "ValidatePassword")
l.Trace("entering OauthTokenMiddleware")
ti, err := m.oauthServer.ValidationBearerToken(c.Request)
ti, err := m.server.ValidationBearerToken(c.Request)
if err != nil {
l.Trace("no valid token presented: continuing with unauthenticated request")
return
}
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
c.Set(oauth.SessionAuthorizedToken, ti)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
acct := &model.Account{}
if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil {
l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID())
return
// check for user-level token
if uid := ti.GetUserID(); uid != "" {
l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
// fetch user's and account for this user id
user := &model.User{}
if err := m.db.GetByID(uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return
}
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
acct := &model.Account{}
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
c.Set(oauth.SessionAuthorizedAccount, acct)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
}
c.Set(SessionAuthorizedAccount, acct)
c.Set(SessionAuthorizedUser, ti.GetUserID())
// check for application token
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &model.Application{}
if err := m.db.GetWhere("client_id", cid, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
}
}
/*
@ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a uuid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
@ -487,18 +378,6 @@ func incorrectPassword() (string, error) {
return "", errors.New("password/email combination was incorrect")
}
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
// or redirects to the /auth/sign_in page, if this key is not present.
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
l := m.log.WithField("func", "UserAuthorizationHandler")
userID = r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty, redirecting to sign in page")
}
l.Tracef("returning userID %s", userID)
return userID, err
}
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
// the values in the form into the session.
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {

View File

@ -16,38 +16,38 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
package auth
import (
"context"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"golang.org/x/crypto/bcrypt"
)
type OauthTestSuite struct {
type AuthTestSuite struct {
suite.Suite
tokenStore oauth2.TokenStore
clientStore oauth2.ClientStore
oauthServer oauth.Server
db db.DB
testAccount *model.Account
testApplication *model.Application
testUser *model.User
testClient *oauthClient
testClient *oauth.Client
config *config.Config
}
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *OauthTestSuite) SetupSuite() {
func (suite *AuthTestSuite) SetupSuite() {
c := config.Empty()
// we're running on localhost without https so set the protocol to http
c.Protocol = "http"
@ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() {
Email: "user@example.org",
AccountID: acctID,
}
suite.testClient = &oauthClient{
suite.testClient = &oauth.Client{
ID: "a-known-client-id",
Secret: "some-secret",
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
@ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
}
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *OauthTestSuite) SetupTest() {
func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
@ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() {
suite.db = db
models := []interface{}{
&oauthClient{},
&oauthToken{},
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
@ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
}
}
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
suite.clientStore = newClientStore(suite.db)
suite.oauthServer = oauth.New(suite.db, log)
if err := suite.db.Put(suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
@ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() {
}
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *OauthTestSuite) TearDownTest() {
func (suite *AuthTestSuite) TearDownTest() {
models := []interface{}{
&oauthClient{},
&oauthToken{},
&oauth.Client{},
&oauth.Token{},
&model.User{},
&model.Account{},
&model.Application{},
@ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
suite.db = nil
}
func (suite *OauthTestSuite) TestAPIInitialize() {
func (suite *AuthTestSuite) TestAPIInitialize() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
@ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
api := New(suite.oauthServer, suite.db, log)
if err := api.Route(r); err != nil {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
r.Start()
time.Sleep(60 * time.Second)
if err := r.Stop(context.Background()); err != nil {
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
}
}
func TestOauthTestSuite(t *testing.T) {
suite.Run(t, new(OauthTestSuite))
func TestAuthTestSuite(t *testing.T) {
suite.Run(t, new(AuthTestSuite))
}

View File

@ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
poc := &oauthClient{
poc := &Client{
ID: clientID,
}
if err := cs.db.GetByID(clientID, poc); err != nil {
@ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
}
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
poc := &oauthClient{
poc := &Client{
ID: cli.GetID(),
Secret: cli.GetSecret(),
Domain: cli.GetDomain(),
@ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
}
func (cs *clientStore) Delete(ctx context.Context, id string) error {
poc := &oauthClient{
poc := &Client{
ID: id,
}
return cs.db.DeleteByID(id, poc)
}
type oauthClient struct {
type Client struct {
ID string
Secret string
Domain string

View File

@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
suite.db = db
models := []interface{}{
&oauthClient{},
&Client{},
}
for _, m := range models {
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *PgClientStoreTestSuite) TearDownTest() {
models := []interface{}{
&oauthClient{},
&Client{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {

212
internal/oauth/oauth.go Normal file
View File

@ -0,0 +1,212 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/errors"
"github.com/gotosocial/oauth2/v4/manage"
"github.com/gotosocial/oauth2/v4/server"
"github.com/sirupsen/logrus"
)
const (
SessionAuthorizedToken = "authorized_token"
// SessionAuthorizedUser is the key set in the gin context for the id of
// a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
SessionAuthorizedUser = "authorized_user"
// SessionAuthorizedAccount is the key set in the gin context for the Account
// of a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
SessionAuthorizedAccount = "authorized_account"
// SessionAuthorizedAccount is the key set in the gin context for the Application
// of a Client who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
SessionAuthorizedApplication = "authorized_app"
)
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
type Server interface {
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
}
// s fulfils the Server interface using the underlying oauth2 server
type s struct {
server *server.Server
log *logrus.Logger
}
type Authed struct {
Token oauth2.TokenInfo
Application *model.Application
User *model.User
Account *model.Account
}
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
// In essence, it tries to extract a token, application, user, and account from the context,
// and then sets them on a struct for convenience.
//
// If any are not present in the context, they will be set to nil on the returned Authed struct.
//
// If *ALL* are not present, then nil and an error will be returned.
//
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
func GetAuthed(c *gin.Context) (*Authed, error) {
ctx := c.Copy()
a := &Authed{}
var i interface{}
var ok bool
i, ok = ctx.Get(SessionAuthorizedToken)
if ok {
parsed, ok := i.(oauth2.TokenInfo)
if !ok {
return nil, errors.New("could not parse token from session context")
}
a.Token = parsed
}
i, ok = ctx.Get(SessionAuthorizedApplication)
if ok {
parsed, ok := i.(*model.Application)
if !ok {
return nil, errors.New("could not parse application from session context")
}
a.Application = parsed
}
i, ok = ctx.Get(SessionAuthorizedUser)
if ok {
parsed, ok := i.(*model.User)
if !ok {
return nil, errors.New("could not parse user from session context")
}
a.User = parsed
}
i, ok = ctx.Get(SessionAuthorizedAccount)
if ok {
parsed, ok := i.(*model.Account)
if !ok {
return nil, errors.New("could not parse account from session context")
}
a.Account = parsed
}
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
return nil, errors.New("not authorized")
}
return a, nil
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleTokenRequest(w, r)
}
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleAuthorizeRequest(w, r)
}
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
return s.server.ValidationBearerToken(r)
}
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
// bearer token *without* requiring that user to log in. This is useful when we
// need to create a token for new users who haven't validated their email or logged in yet.
//
// The ti parameter refers to an existing Application token that was used to make the upstream
// request. This token needs to be validated and exist in database in order to create a new token.
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) {
tgr := &oauth2.TokenGenerateRequest{
ClientID: ti.GetClientID(),
ClientSecret: clientSecret,
UserID: userID,
RedirectURI: ti.GetRedirectURI(),
Scope: ti.GetScope(),
Code: ti.GetCode(),
CodeChallenge: ti.GetCodeChallenge(),
CodeChallengeMethod: ti.GetCodeChallengeMethod(),
}
return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr)
}
func New(database db.DB, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := newClientStore(database)
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
})
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
}
}

View File

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

View File

@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
func (pts *tokenStore) sweep() error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*oauthToken)
tokens := new([]*Token)
if err := pts.db.GetAll(tokens); err != nil {
return err
}
@ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
// RemoveByCode deletes a token from the DB based on the Code field
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return pts.db.DeleteWhere("code", code, &oauthToken{})
return pts.db.DeleteWhere("code", code, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return pts.db.DeleteWhere("access", access, &oauthToken{})
return pts.db.DeleteWhere("access", access, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
return pts.db.DeleteWhere("refresh", refresh, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
pgt := &Token{
Code: code,
}
if err := pts.db.GetWhere("code", code, pgt); err != nil {
@ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
// GetByAccess selects a token from the DB based on the Access field
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
pgt := &Token{
Access: access,
}
if err := pts.db.GetWhere("access", access, pgt); err != nil {
@ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
// GetByRefresh selects a token from the DB based on the Refresh field
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{
pgt := &Token{
Refresh: refresh,
}
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
@ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
*/
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
@ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type oauthToken struct {
type Token struct {
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
ClientID string
UserID string
@ -186,7 +186,7 @@ type oauthToken struct {
}
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
func oauthTokenToPGToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
rea = now.Add(tkn.RefreshExpiresIn)
}
return &oauthToken{
return &Token{
ClientID: tkn.ClientID,
UserID: tkn.UserID,
RedirectURI: tkn.RedirectURI,
@ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
}
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
func pgTokenToOauthToken(pgt *Token) *models.Token {
now := time.Now()
return &models.Token{

View File

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

View File

@ -36,7 +36,7 @@ import (
// Router provides the REST interface for gotosocial, using gin.
type Router interface {
// Attach a gin handler to the router with the given method and path
AttachHandler(method string, path string, handler gin.HandlerFunc)
AttachHandler(method string, path string, f gin.HandlerFunc)
// Attach a gin middleware to the router that will be used globally
AttachMiddleware(handler gin.HandlerFunc)
// Start the router
@ -59,6 +59,8 @@ func (r *router) Start() {
r.logger.Fatalf("listen: %s", err)
}
}()
// c := &gin.Context{}
// c.Get()
}
// Stop shuts down the router nicely

31
pkg/mastotypes/token.go Normal file
View File

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package mastotypes
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
type Token struct {
// An OAuth token to be used for authorization.
AccessToken string `json:"access_token"`
// The OAuth token type. Mastodon uses Bearer tokens.
TokenType string `json:"token_type"`
// The OAuth scopes granted by this token, space-separated.
Scope string `json:"scope"`
// When the token was generated. (UNIX timestamp seconds)
CreatedAt int64 `json:"created_at"`
}