diff --git a/go.mod b/go.mod index aec00d3..eab14c9 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/golang/mock v1.5.0 // indirect github.com/google/uuid v1.2.0 github.com/gorilla/sessions v1.2.1 // indirect + github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.1 github.com/json-iterator/go v1.1.11 // indirect github.com/leodido/go-urn v1.2.1 // indirect diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go new file mode 100644 index 0000000..a5b8a2d --- /dev/null +++ b/internal/api/client/streaming/stream.go @@ -0,0 +1,89 @@ +package streaming + +import ( + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +// StreamGETHandler handles the creation of a new websocket streaming request. +func (m *Module) StreamGETHandler(c *gin.Context) { + l := m.log.WithField("func", "StreamGETHandler") + + streamType := c.Query(StreamQueryKey) + if streamType == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("no stream type provided under query key %s", StreamQueryKey)}) + return + } + + accessToken := c.Query(AccessTokenQueryKey) + if accessToken == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("no access token provided under query key %s", AccessTokenQueryKey)}) + return + } + + // make sure a valid token has been provided and obtain the associated account + account, err := m.processor.AuthorizeStreamingRequest(accessToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) + return + } + + // prepare to upgrade the connection to a websocket connection + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // we fully expect cors requests (via something like pinafore.social) so we should be lenient here + return true + }, + } + + // do the actual upgrade here + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + l.Infof("error upgrading websocket connection: %s", err) + return + } + defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection + + // inform the processor that we have a new connection and want a stream for it + stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType) + if errWithCode != nil { + c.JSON(errWithCode.Code(), errWithCode.Safe()) + return + } + defer close(stream.Hangup) // closing stream.Hangup indicates that we've finished with the connection (the client has gone), so we want to do this on exiting this handler + + // spawn a new ticker for pinging the connection periodically + t := time.NewTicker(30 * time.Second) + + // we want to stay in the sendloop as long as possible while the client is connected -- the only thing that should break the loop is if the client leaves or something else goes wrong +sendLoop: + for { + select { + case m := <-stream.Messages: + // we've got a streaming message!! + l.Debug("received message from stream") + if err := conn.WriteJSON(m); err != nil { + l.Infof("error writing json to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote message into websocket connection") + case <-t.C: + l.Debug("received TICK from ticker") + if err := conn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { + l.Infof("error writing ping to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote ping message into websocket connection") + } + } + + l.Debug("leaving StreamGETHandler") +} diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go new file mode 100644 index 0000000..92dfccd --- /dev/null +++ b/internal/api/client/streaming/streaming.go @@ -0,0 +1,62 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package streaming + +import ( + "net/http" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/api" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/router" +) + +const ( + // BasePath is the path for the streaming api + BasePath = "/api/v1/streaming" + + // StreamQueryKey is the query key for the type of stream being requested + StreamQueryKey = "stream" + + // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests. + AccessTokenQueryKey = "access_token" +) + +// Module implements the api.ClientModule interface for everything related to streaming +type Module struct { + config *config.Config + processor processing.Processor + log *logrus.Logger +} + +// New returns a new streaming module +func New(config *config.Config, processor processing.Processor, log *logrus.Logger) api.ClientModule { + return &Module{ + config: config, + processor: processor, + log: log, + } +} + +// Route attaches all routes from this module to the given router +func (m *Module) Route(r router.Router) error { + r.AttachHandler(http.MethodGet, BasePath, m.StreamGETHandler) + return nil +} diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index f055890..74b1c78 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/api/client/notification" "github.com/superseriousbusiness/gotosocial/internal/api/client/search" "github.com/superseriousbusiness/gotosocial/internal/api/client/status" + "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/api/client/timeline" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" @@ -134,6 +135,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log adminModule := admin.New(c, processor, log) statusModule := status.New(c, processor, log) securityModule := security.New(c, log) + streamingModule := streaming.New(c, processor, log) apis := []api.ClientModule{ // modules with middleware go first @@ -157,6 +159,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log filtersModule, emojiModule, listsModule, + streamingModule, } for _, m := range apis { diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 8515013..9daf94e 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -810,7 +810,7 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -818,7 +818,7 @@ func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targ if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -826,7 +826,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode if account1 == nil || account2 == nil { return false, nil } - + // make sure account 1 follows account 2 f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() if err != nil { @@ -975,6 +975,7 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s q := ps.conn.Model(&statuses). Where("visibility = ?", gtsmodel.VisibilityPublic). Where("? IS NULL", pg.Ident("in_reply_to_id")). + Where("? IS NULL", pg.Ident("in_reply_to_uri")). Where("? IS NULL", pg.Ident("boost_of_id")). Order("status.id DESC") diff --git a/internal/gtsmodel/stream.go b/internal/gtsmodel/stream.go new file mode 100644 index 0000000..4a1571d --- /dev/null +++ b/internal/gtsmodel/stream.go @@ -0,0 +1,38 @@ +package gtsmodel + +import "sync" + +// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. +// TODO: put a limit on this +type StreamsForAccount struct { + // The currently held streams for this account + Streams []*Stream + // Mutex to lock/unlock when modifying the slice of streams. + sync.Mutex +} + +// Stream represents one open stream for a client. +type Stream struct { + // ID of this stream, generated during creation. + ID string + // Type of this stream: user/public/etc + Type string + // Channel of messages for the client to read from + Messages chan *Message + // Channel to close when the client drops away + Hangup chan interface{} + // Only put messages in the stream when Connected + Connected bool + // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. + sync.Mutex +} + +// Message represents one streamed message. +type Message struct { + // All the stream types this message should be delivered to. + Stream []string `json:"stream"` + // The event type of the message (update/delete/notification etc) + Event string `json:"event"` + // The actual payload of the message. In case of an update or notification, this will be a JSON string. + Payload string `json:"payload"` +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go index fb84743..1289b18 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -56,6 +56,7 @@ type Server interface { HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) + LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) } // s fulfils the Server interface using the underlying oauth2 server @@ -171,3 +172,7 @@ func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, us s.log.Tracef("obtained user-level access token: %+v", accessToken) return accessToken, nil } + +func (s *s) LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) { + return s.server.Manager.LoadAccessToken(ctx, access) +} diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 65ccef4..e10f754 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -96,6 +96,16 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { if err := p.db.Put(notif); err != nil { return fmt.Errorf("notifyStatus: error putting notification in database: %s", err) } + + // now stream the notification to the user + mastoNotif, err := p.tc.NotificationToMasto(notif) + if err != nil { + return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) + } + + if err := p.streamingProcessor.StreamNotificationToAccount(mastoNotif, m.GTSAccount); err != nil { + return fmt.Errorf("notifyStatus: error streaming notification to account: %s", err) + } } return nil @@ -123,6 +133,16 @@ func (p *processor) notifyFollowRequest(followRequest *gtsmodel.FollowRequest, r return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err) } + // now stream the notification to the user + mastoNotif, err := p.tc.NotificationToMasto(notif) + if err != nil { + return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) + } + + if err := p.streamingProcessor.StreamNotificationToAccount(mastoNotif, receivingAccount); err != nil { + return fmt.Errorf("notifyStatus: error streaming notification to account: %s", err) + } + return nil } @@ -157,6 +177,16 @@ func (p *processor) notifyFollow(follow *gtsmodel.Follow, receivingAccount *gtsm return fmt.Errorf("notifyFollow: error putting notification in database: %s", err) } + // now stream the notification to the user + mastoNotif, err := p.tc.NotificationToMasto(notif) + if err != nil { + return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) + } + + if err := p.streamingProcessor.StreamNotificationToAccount(mastoNotif, receivingAccount); err != nil { + return fmt.Errorf("notifyStatus: error streaming notification to account: %s", err) + } + return nil } @@ -183,6 +213,16 @@ func (p *processor) notifyFave(fave *gtsmodel.StatusFave, receivingAccount *gtsm return fmt.Errorf("notifyFave: error putting notification in database: %s", err) } + // now stream the notification to the user + mastoNotif, err := p.tc.NotificationToMasto(notif) + if err != nil { + return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) + } + + if err := p.streamingProcessor.StreamNotificationToAccount(mastoNotif, receivingAccount); err != nil { + return fmt.Errorf("notifyStatus: error streaming notification to account: %s", err) + } + return nil } @@ -242,6 +282,16 @@ func (p *processor) notifyAnnounce(status *gtsmodel.Status) error { return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err) } + // now stream the notification to the user + mastoNotif, err := p.tc.NotificationToMasto(notif) + if err != nil { + return fmt.Errorf("notifyStatus: error converting notification to masto representation: %s", err) + } + + if err := p.streamingProcessor.StreamNotificationToAccount(mastoNotif, boostedAcct); err != nil { + return fmt.Errorf("notifyStatus: error streaming notification to account: %s", err) + } + return nil } @@ -321,8 +371,32 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID return } - if err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID); err != nil { + // stick the status in the timeline for the account and then immediately prepare it so they can see it right away + inserted, err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID) + if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err) + return + } + + // the status was inserted to stream it to the user + if inserted { + mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + if err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) + } else { + if err := p.streamingProcessor.StreamStatusToAccount(mastoStatus, timelineAccount); err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) + } + } + } + + mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + if err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) + } else { + if err := p.streamingProcessor.StreamStatusToAccount(mastoStatus, timelineAccount); err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) + } } } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 301cb57..d1b4431 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -33,6 +33,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing/synchronous/status" + "github.com/superseriousbusiness/gotosocial/internal/processing/synchronous/streaming" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" @@ -132,6 +133,11 @@ type Processor interface { // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, gtserror.WithCode) + // AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid. + AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + // OpenStreamForAccount opens a new stream for the given account, with the given stream type. + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + /* FEDERATION API-FACING PROCESSING FUNCTIONS These functions are intended to be called when the federating client needs an immediate (ie., synchronous) reply @@ -192,7 +198,8 @@ type processor struct { SUB-PROCESSORS */ - statusProcessor status.Processor + statusProcessor status.Processor + streamingProcessor streaming.Processor } // NewProcessor returns a new Processor that uses the given federator and logger @@ -202,6 +209,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f fromFederator := make(chan gtsmodel.FromFederator, 1000) statusProcessor := status.New(db, tc, config, fromClientAPI, log) + streamingProcessor := streaming.New(db, tc, oauthServer, config, log) return &processor{ fromClientAPI: fromClientAPI, @@ -218,7 +226,8 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f db: db, filter: visibility.NewFilter(db, log), - statusProcessor: statusProcessor, + statusProcessor: statusProcessor, + streamingProcessor: streamingProcessor, } } diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go new file mode 100644 index 0000000..1e566da --- /dev/null +++ b/internal/processing/streaming.go @@ -0,0 +1,14 @@ +package processing + +import ( + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { + return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) +} + +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { + return p.streamingProcessor.OpenStreamForAccount(account, streamType) +} diff --git a/internal/processing/synchronous/streaming/authorize.go b/internal/processing/synchronous/streaming/authorize.go new file mode 100644 index 0000000..8bbf185 --- /dev/null +++ b/internal/processing/synchronous/streaming/authorize.go @@ -0,0 +1,33 @@ +package streaming + +import ( + "context" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { + ti, err := p.oauthServer.LoadAccessToken(context.Background(), accessToken) + if err != nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: error loading access token: %s", err) + } + + uid := ti.GetUserID() + if uid == "" { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no userid in token") + } + + // fetch user's and account for this user id + user := >smodel.User{} + if err := p.db.GetByID(uid, user); err != nil || user == nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid) + } + + acct := >smodel.Account{} + if err := p.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid) + } + + return acct, nil +} diff --git a/internal/processing/synchronous/streaming/openstream.go b/internal/processing/synchronous/streaming/openstream.go new file mode 100644 index 0000000..68446ba --- /dev/null +++ b/internal/processing/synchronous/streaming/openstream.go @@ -0,0 +1,100 @@ +package streaming + +import ( + "errors" + "fmt" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { + l := p.log.WithFields(logrus.Fields{ + "func": "OpenStreamForAccount", + "account": account.ID, + "streamType": streamType, + }) + l.Debug("received open stream request") + + // each stream needs a unique ID so we know to close it + streamID, err := id.NewRandomULID() + if err != nil { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + } + + thisStream := >smodel.Stream{ + ID: streamID, + Type: streamType, + Messages: make(chan *gtsmodel.Message, 100), + Hangup: make(chan interface{}, 1), + Connected: true, + } + go p.waitToCloseStream(account, thisStream) + + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + // there is no entry in the streamMap for this account yet, so make one and store it + streamsForAccount := >smodel.StreamsForAccount{ + Streams: []*gtsmodel.Stream{ + thisStream, + }, + } + p.streamMap.Store(account.ID, streamsForAccount) + } else { + // there is an entry in the streamMap for this account + // parse the interface as a streamsForAccount + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) + } + + // append this stream to it + streamsForAccount.Lock() + streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream) + streamsForAccount.Unlock() + } + + return thisStream, nil +} + +// waitToCloseStream waits until the hangup channel is closed for the given stream. +// It then iterates through the map of streams stored by the processor, removes the stream from it, +// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. +func (p *processor) waitToCloseStream(account *gtsmodel.Account, thisStream *gtsmodel.Stream) { + <-thisStream.Hangup // wait for a hangup message + + // lock the stream to prevent more messages being put in it while we work + thisStream.Lock() + defer thisStream.Unlock() + + // indicate the stream is no longer connected + thisStream.Connected = false + + // load and parse the entry for this account from the stream map + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + return + } + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return + } + + // lock the streams for account while we remove this stream from its slice + streamsForAccount.Lock() + defer streamsForAccount.Unlock() + + // put everything into modified streams *except* the stream we're removing + modifiedStreams := []*gtsmodel.Stream{} + for _, s := range streamsForAccount.Streams { + if s.ID != thisStream.ID { + modifiedStreams = append(modifiedStreams, s) + } + } + streamsForAccount.Streams = modifiedStreams + + // finally close the messages channel so no more messages can be read from it + close(thisStream.Messages) +} diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go new file mode 100644 index 0000000..9fd6285 --- /dev/null +++ b/internal/processing/synchronous/streaming/streaming.go @@ -0,0 +1,47 @@ +package streaming + +import ( + "sync" + + "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +// Processor wraps a bunch of functions for processing streaming. +type Processor interface { + // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API + AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error + StreamNotificationToAccount(n *apimodel.Notification, account *gtsmodel.Account) error +} + +type processor struct { + tc typeutils.TypeConverter + config *config.Config + db db.DB + filter visibility.Filter + log *logrus.Logger + oauthServer oauth.Server + streamMap *sync.Map +} + +// New returns a new status processor. +func New(db db.DB, tc typeutils.TypeConverter, oauthServer oauth.Server, config *config.Config, log *logrus.Logger) Processor { + return &processor{ + tc: tc, + config: config, + db: db, + filter: visibility.NewFilter(db, log), + log: log, + oauthServer: oauthServer, + streamMap: &sync.Map{}, + } +} diff --git a/internal/processing/synchronous/streaming/streamnotification.go b/internal/processing/synchronous/streaming/streamnotification.go new file mode 100644 index 0000000..24c8342 --- /dev/null +++ b/internal/processing/synchronous/streaming/streamnotification.go @@ -0,0 +1,50 @@ +package streaming + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) StreamNotificationToAccount(n *apimodel.Notification, account *gtsmodel.Account) error { + l := p.log.WithFields(logrus.Fields{ + "func": "StreamNotificationToAccount", + "account": account.ID, + }) + v, ok := p.streamMap.Load(account.ID) + if !ok { + // no open connections so nothing to stream + return nil + } + + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return errors.New("stream map error") + } + + notificationBytes, err := json.Marshal(n) + if err != nil { + return fmt.Errorf("error marshalling notification to json: %s", err) + } + + streamsForAccount.Lock() + defer streamsForAccount.Unlock() + for _, stream := range streamsForAccount.Streams { + stream.Lock() + defer stream.Unlock() + if stream.Connected { + l.Debugf("streaming notification to stream id %s", stream.ID) + stream.Messages <- >smodel.Message{ + Stream: []string{stream.Type}, + Event: "notification", + Payload: string(notificationBytes), + } + } + } + + return nil +} diff --git a/internal/processing/synchronous/streaming/streamstatus.go b/internal/processing/synchronous/streaming/streamstatus.go new file mode 100644 index 0000000..8d02625 --- /dev/null +++ b/internal/processing/synchronous/streaming/streamstatus.go @@ -0,0 +1,50 @@ +package streaming + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error { + l := p.log.WithFields(logrus.Fields{ + "func": "StreamStatusForAccount", + "account": account.ID, + }) + v, ok := p.streamMap.Load(account.ID) + if !ok { + // no open connections so nothing to stream + return nil + } + + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return errors.New("stream map error") + } + + statusBytes, err := json.Marshal(s) + if err != nil { + return fmt.Errorf("error marshalling status to json: %s", err) + } + + streamsForAccount.Lock() + defer streamsForAccount.Unlock() + for _, stream := range streamsForAccount.Streams { + stream.Lock() + defer stream.Unlock() + if stream.Connected { + l.Debugf("streaming status to stream id %s", stream.ID) + stream.Messages <- >smodel.Message{ + Stream: []string{stream.Type}, + Event: "update", + Payload: string(statusBytes), + } + } + } + + return nil +} diff --git a/internal/processing/timeline.go b/internal/processing/timeline.go index a8f42d6..8f6b1d2 100644 --- a/internal/processing/timeline.go +++ b/internal/processing/timeline.go @@ -201,7 +201,7 @@ func (p *processor) indexAndIngest(statuses []*gtsmodel.Status, timelineAccount continue } if timelineable { - if err := p.timelineManager.Ingest(s, timelineAccount.ID); err != nil { + if _, err := p.timelineManager.Ingest(s, timelineAccount.ID); err != nil { l.Error(fmt.Errorf("initTimelineFor: error ingesting status %s: %s", s.ID, err)) continue } diff --git a/internal/router/router.go b/internal/router/router.go index adfd2ce..e575b11 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -128,9 +128,9 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) { AllowAllOrigins: true, AllowBrowserExtensions: true, AllowMethods: []string{"POST", "PUT", "DELETE", "GET", "PATCH", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Upgrade", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key", "Sec-WebSocket-Protocol", "Sec-WebSocket-Version", "Connection"}, AllowWebSockets: true, - ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id"}, + ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id", "Connection", "Sec-WebSocket-Accept", "Upgrade"}, MaxAge: 2 * time.Minute, })) engine.MaxMultipartMemory = 8 << 20 // 8 MiB diff --git a/internal/timeline/index.go b/internal/timeline/index.go index bc1bf99..8d0506a 100644 --- a/internal/timeline/index.go +++ b/internal/timeline/index.go @@ -40,7 +40,7 @@ grabloop: } for _, s := range filtered { - if err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID); err != nil { + if _, err := t.IndexOne(s.CreatedAt, s.ID, s.BoostOfID); err != nil { return fmt.Errorf("IndexBehindAndIncluding: error indexing status with id %s: %s", s.ID, err) } } @@ -52,7 +52,7 @@ func (t *timeline) IndexOneByID(statusID string) error { return nil } -func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string) error { +func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string) (bool, error) { t.Lock() defer t.Unlock() @@ -64,7 +64,7 @@ func (t *timeline) IndexOne(statusCreatedAt time.Time, statusID string, boostOfI return t.postIndex.insertIndexed(postIndexEntry) } -func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string) error { +func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string) (bool, error) { t.Lock() defer t.Unlock() @@ -72,15 +72,18 @@ func (t *timeline) IndexAndPrepareOne(statusCreatedAt time.Time, statusID string statusID: statusID, } - if err := t.postIndex.insertIndexed(postIndexEntry); err != nil { - return fmt.Errorf("IndexAndPrepareOne: error inserting indexed: %s", err) + inserted, err := t.postIndex.insertIndexed(postIndexEntry) + if err != nil { + return inserted, fmt.Errorf("IndexAndPrepareOne: error inserting indexed: %s", err) } - if err := t.prepare(statusID); err != nil { - return fmt.Errorf("IndexAndPrepareOne: error preparing: %s", err) + if inserted { + if err := t.prepare(statusID); err != nil { + return inserted, fmt.Errorf("IndexAndPrepareOne: error preparing: %s", err) + } } - return nil + return inserted, nil } func (t *timeline) OldestIndexedPostID() (string, error) { diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go index c389a6b..2770f9e 100644 --- a/internal/timeline/manager.go +++ b/internal/timeline/manager.go @@ -51,12 +51,18 @@ type Manager interface { // Ingest takes one status and indexes it into the timeline for the given account ID. // // It should already be established before calling this function that the status/post actually belongs in the timeline! - Ingest(status *gtsmodel.Status, timelineAccountID string) error + // + // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where + // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. + Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error) // IngestAndPrepare takes one status and indexes it into the timeline for the given account ID, and then immediately prepares it for serving. // This is useful in cases where we know the status will need to be shown at the top of a user's timeline immediately (eg., a new status is created). // // It should already be established before calling this function that the status/post actually belongs in the timeline! - IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) error + // + // The returned bool indicates whether the status was actually put in the timeline. This could be false in cases where + // the status is a boost, but a boost of the original post or the post itself already exists recently in the timeline. + IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error) // HomeTimeline returns limit n amount of entries from the home timeline of the given account ID, in descending chronological order. // If maxID is provided, it will return entries from that maxID onwards, inclusive. HomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, error) @@ -95,7 +101,7 @@ type manager struct { log *logrus.Logger } -func (m *manager) Ingest(status *gtsmodel.Status, timelineAccountID string) error { +func (m *manager) Ingest(status *gtsmodel.Status, timelineAccountID string) (bool, error) { l := m.log.WithFields(logrus.Fields{ "func": "Ingest", "timelineAccountID": timelineAccountID, @@ -108,7 +114,7 @@ func (m *manager) Ingest(status *gtsmodel.Status, timelineAccountID string) erro return t.IndexOne(status.CreatedAt, status.ID, status.BoostOfID) } -func (m *manager) IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) error { +func (m *manager) IngestAndPrepare(status *gtsmodel.Status, timelineAccountID string) (bool, error) { l := m.log.WithFields(logrus.Fields{ "func": "IngestAndPrepare", "timelineAccountID": timelineAccountID, diff --git a/internal/timeline/postindex.go b/internal/timeline/postindex.go index 7142035..44765bf 100644 --- a/internal/timeline/postindex.go +++ b/internal/timeline/postindex.go @@ -14,7 +14,7 @@ type postIndexEntry struct { boostOfID string } -func (p *postIndex) insertIndexed(i *postIndexEntry) error { +func (p *postIndex) insertIndexed(i *postIndexEntry) (bool, error) { if p.data == nil { p.data = &list.List{} } @@ -22,7 +22,7 @@ func (p *postIndex) insertIndexed(i *postIndexEntry) error { // if we have no entries yet, this is both the newest and oldest entry, so just put it in the front if p.data.Len() == 0 { p.data.PushFront(i) - return nil + return true, nil } var insertMark *list.Element @@ -34,14 +34,14 @@ func (p *postIndex) insertIndexed(i *postIndexEntry) error { entry, ok := e.Value.(*postIndexEntry) if !ok { - return errors.New("index: could not parse e as a postIndexEntry") + return false, errors.New("index: could not parse e as a postIndexEntry") } // don't insert this if it's a boost of a status we've seen recently if i.boostOfID != "" { if i.boostOfID == entry.boostOfID || i.boostOfID == entry.statusID { if position < boostReinsertionDepth { - return nil + return false, nil } } } @@ -55,16 +55,16 @@ func (p *postIndex) insertIndexed(i *postIndexEntry) error { // make sure we don't insert a duplicate if entry.statusID == i.statusID { - return nil + return false, nil } } if insertMark != nil { p.data.InsertBefore(i, insertMark) - return nil + return true, nil } // if we reach this point it's the oldest post we've seen so put it at the back p.data.PushBack(i) - return nil + return true, nil } diff --git a/internal/timeline/timeline.go b/internal/timeline/timeline.go index 363c099..5e274b5 100644 --- a/internal/timeline/timeline.go +++ b/internal/timeline/timeline.go @@ -62,7 +62,10 @@ type Timeline interface { */ // IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property. - IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string) error + // + // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false + // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. + IndexOne(statusCreatedAt time.Time, statusID string, boostOfID string) (bool, error) // OldestIndexedPostID returns the id of the rearmost (ie., the oldest) indexed post, or an error if something goes wrong. // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. @@ -79,7 +82,10 @@ type Timeline interface { PrepareBehind(statusID string, amount int) error // IndexOne puts a status into the timeline at the appropriate place according to its 'createdAt' property, // and then immediately prepares it. - IndexAndPrepareOne(statusCreatedAt time.Time, statusID string) error + // + // The returned bool indicates whether or not the status was actually inserted into the timeline. This will be false + // if the status is a boost and the original post or another boost of it already exists < boostReinsertionDepth back in the timeline. + IndexAndPrepareOne(statusCreatedAt time.Time, statusID string) (bool, error) // OldestPreparedPostID returns the id of the rearmost (ie., the oldest) prepared post, or an error if something goes wrong. // If nothing goes wrong but there's no oldest post, an empty string will be returned so make sure to check for this. OldestPreparedPostID() (string, error)