From 0cee5aa5699e705911df96a1e52c87895c40ce1e Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Thu, 17 Jun 2021 19:20:08 +0200 Subject: [PATCH 1/3] start messing about with streaming api --- go.mod | 1 + internal/api/client/streaming/stream.go | 46 ++++++++++++++ internal/api/client/streaming/streaming.go | 62 +++++++++++++++++++ internal/oauth/server.go | 5 ++ internal/processing/processor.go | 14 ++++- internal/processing/streaming.go | 15 +++++ .../synchronous/streaming/authorize.go | 33 ++++++++++ .../synchronous/streaming/stream.go | 22 +++++++ .../synchronous/streaming/streaming.go | 45 ++++++++++++++ 9 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 internal/api/client/streaming/stream.go create mode 100644 internal/api/client/streaming/streaming.go create mode 100644 internal/processing/streaming.go create mode 100644 internal/processing/synchronous/streaming/authorize.go create mode 100644 internal/processing/synchronous/streaming/stream.go create mode 100644 internal/processing/synchronous/streaming/streaming.go diff --git a/go.mod b/go.mod index aec00d3..81e2c4d 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/golang/mock v1.5.0 // indirect github.com/google/uuid v1.2.0 github.com/gorilla/sessions v1.2.1 // indirect + github.com/gorilla/websocket v1.4.2 // indirect github.com/h2non/filetype v1.1.1 github.com/json-iterator/go v1.1.11 // indirect github.com/leodido/go-urn v1.2.1 // indirect diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go new file mode 100644 index 0000000..c0097d0 --- /dev/null +++ b/internal/api/client/streaming/stream.go @@ -0,0 +1,46 @@ +package streaming + +import ( + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +func (m *Module) StreamGETHandler(c *gin.Context) { + streamType := c.Query(StreamQueryKey) + if streamType == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("no stream type provided under query key %s", StreamQueryKey)}) + return + } + + accessToken := c.Query(AccessTokenQueryKey) + if accessToken == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("no access token provided under query key %s", AccessTokenQueryKey)}) + return + } + + account, err := m.processor.AuthorizeStreamingRequest(accessToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) + return + } + + upgrader := websocket.Upgrader{ + HandshakeTimeout: 5 * time.Second, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + Subprotocols: []string{"wss"}, + } + + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + return + } + + if errWithCode := m.processor.StreamForAccount(conn, account, streamType); errWithCode != nil { + c.JSON(errWithCode.Code(), errWithCode.Safe()) + } +} diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go new file mode 100644 index 0000000..5a7cee9 --- /dev/null +++ b/internal/api/client/streaming/streaming.go @@ -0,0 +1,62 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package streaming + +import ( + "net/http" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/api" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/router" +) + +const ( + // BasePath is the path for the streaming api + BasePath = "/api/v1/streaming" + + // StreamQueryKey is the query key for the type of stream being requested + StreamQueryKey = "stream" + + // AccessTokenQueryKey + AccessTokenQueryKey = "access_token" +) + +// Module implements the api.ClientModule interface for everything related to streaming +type Module struct { + config *config.Config + processor processing.Processor + log *logrus.Logger +} + +// New returns a new streaming module +func New(config *config.Config, processor processing.Processor, log *logrus.Logger) api.ClientModule { + return &Module{ + config: config, + processor: processor, + log: log, + } +} + +// Route attaches all routes from this module to the given router +func (m *Module) Route(r router.Router) error { + r.AttachHandler(http.MethodGet, BasePath, m.StreamGETHandler) + return nil +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go index fb84743..1289b18 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -56,6 +56,7 @@ type Server interface { HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) + LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) } // s fulfils the Server interface using the underlying oauth2 server @@ -171,3 +172,7 @@ func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, us s.log.Tracef("obtained user-level access token: %+v", accessToken) return accessToken, nil } + +func (s *s) LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) { + return s.server.Manager.LoadAccessToken(ctx, access) +} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 301cb57..396f0a2 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -22,6 +22,7 @@ import ( "context" "net/http" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/blob" @@ -33,6 +34,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing/synchronous/status" + "github.com/superseriousbusiness/gotosocial/internal/processing/synchronous/streaming" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" @@ -132,6 +134,11 @@ type Processor interface { // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, gtserror.WithCode) + // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API + AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + // StreamForAccount streams to websocket connection c for an account, with the given streamType. + StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + /* FEDERATION API-FACING PROCESSING FUNCTIONS These functions are intended to be called when the federating client needs an immediate (ie., synchronous) reply @@ -192,7 +199,8 @@ type processor struct { SUB-PROCESSORS */ - statusProcessor status.Processor + statusProcessor status.Processor + streamingProcessor streaming.Processor } // NewProcessor returns a new Processor that uses the given federator and logger @@ -202,6 +210,7 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f fromFederator := make(chan gtsmodel.FromFederator, 1000) statusProcessor := status.New(db, tc, config, fromClientAPI, log) + streamingProcessor := streaming.New(db, tc, oauthServer, config, log) return &processor{ fromClientAPI: fromClientAPI, @@ -218,7 +227,8 @@ func NewProcessor(config *config.Config, tc typeutils.TypeConverter, federator f db: db, filter: visibility.NewFilter(db, log), - statusProcessor: statusProcessor, + statusProcessor: statusProcessor, + streamingProcessor: streamingProcessor, } } diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go new file mode 100644 index 0000000..80ca1bd --- /dev/null +++ b/internal/processing/streaming.go @@ -0,0 +1,15 @@ +package processing + +import ( + "github.com/gorilla/websocket" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { + return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) +} + +func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + return p.streamingProcessor.StreamForAccount(c, account, streamType) +} diff --git a/internal/processing/synchronous/streaming/authorize.go b/internal/processing/synchronous/streaming/authorize.go new file mode 100644 index 0000000..8bbf185 --- /dev/null +++ b/internal/processing/synchronous/streaming/authorize.go @@ -0,0 +1,33 @@ +package streaming + +import ( + "context" + "fmt" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) { + ti, err := p.oauthServer.LoadAccessToken(context.Background(), accessToken) + if err != nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: error loading access token: %s", err) + } + + uid := ti.GetUserID() + if uid == "" { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no userid in token") + } + + // fetch user's and account for this user id + user := >smodel.User{} + if err := p.db.GetByID(uid, user); err != nil || user == nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid) + } + + acct := >smodel.Account{} + if err := p.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid) + } + + return acct, nil +} diff --git a/internal/processing/synchronous/streaming/stream.go b/internal/processing/synchronous/streaming/stream.go new file mode 100644 index 0000000..e2bfaad --- /dev/null +++ b/internal/processing/synchronous/streaming/stream.go @@ -0,0 +1,22 @@ +package streaming + +import ( + "github.com/gorilla/websocket" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + + v, loaded := p.streamMap.LoadOrStore(account.ID, sync.Slice) + if loaded { + + } + + return nil +} + +type streams struct { + accountID string + +} diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go new file mode 100644 index 0000000..207c1ec --- /dev/null +++ b/internal/processing/synchronous/streaming/streaming.go @@ -0,0 +1,45 @@ +package streaming + +import ( + "sync" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +// Processor wraps a bunch of functions for processing streaming. +type Processor interface { + // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API + AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) + StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode +} + +type processor struct { + tc typeutils.TypeConverter + config *config.Config + db db.DB + filter visibility.Filter + log *logrus.Logger + oauthServer oauth.Server + streamMap *sync.Map +} + +// New returns a new status processor. +func New(db db.DB, tc typeutils.TypeConverter, oauthServer oauth.Server, config *config.Config, log *logrus.Logger) Processor { + return &processor{ + tc: tc, + config: config, + db: db, + filter: visibility.NewFilter(db, log), + log: log, + oauthServer: oauthServer, + streamMap: &sync.Map{}, + } +} From 42b8333d1b446866080854f812d5bd06c84d702a Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Thu, 17 Jun 2021 22:51:50 +0200 Subject: [PATCH 2/3] additional faffing around with streaming --- internal/api/client/streaming/stream.go | 8 +- internal/cliactions/server/server.go | 3 + internal/processing/fromcommon.go | 9 ++ internal/processing/processor.go | 4 +- internal/processing/streaming.go | 4 +- .../synchronous/streaming/openstream.go | 108 ++++++++++++++++++ .../synchronous/streaming/stream.go | 22 ---- .../synchronous/streaming/streaming.go | 4 +- .../synchronous/streaming/streamstatus.go | 37 ++++++ 9 files changed, 171 insertions(+), 28 deletions(-) create mode 100644 internal/processing/synchronous/streaming/openstream.go delete mode 100644 internal/processing/synchronous/streaming/stream.go create mode 100644 internal/processing/synchronous/streaming/streamstatus.go diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index c0097d0..3aa0166 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -10,6 +10,8 @@ import ( ) func (m *Module) StreamGETHandler(c *gin.Context) { + l := m.log.WithField("func", "StreamGETHandler") + streamType := c.Query(StreamQueryKey) if streamType == "" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("no stream type provided under query key %s", StreamQueryKey)}) @@ -33,14 +35,18 @@ func (m *Module) StreamGETHandler(c *gin.Context) { ReadBufferSize: 1024, WriteBufferSize: 1024, Subprotocols: []string{"wss"}, + CheckOrigin: func(r *http.Request) bool { + return true + }, } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { + l.Infof("error upgrading websocket connection: %s", err) return } - if errWithCode := m.processor.StreamForAccount(conn, account, streamType); errWithCode != nil { + if errWithCode := m.processor.OpenStreamForAccount(conn, account, streamType); errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) } } diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index f055890..282e0b8 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -39,6 +39,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/router" timelineprocessing "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/transport" @@ -134,6 +135,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log adminModule := admin.New(c, processor, log) statusModule := status.New(c, processor, log) securityModule := security.New(c, log) + streamingModule := streaming.New(c, processor, log) apis := []api.ClientModule{ // modules with middleware go first @@ -157,6 +159,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log filtersModule, emojiModule, listsModule, + streamingModule, } for _, m := range apis { diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 65ccef4..3d60486 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -324,6 +324,15 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID if err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID); err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err) } + + mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + if err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) + } else { + if err := p.streamingProcessor.StreamStatusForAccount(mastoStatus, timelineAccount); err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) + } + } } func (p *processor) deleteStatusFromTimelines(status *gtsmodel.Status) error { diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 396f0a2..5d7df62 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -136,8 +136,8 @@ type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - // StreamForAccount streams to websocket connection c for an account, with the given streamType. - StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + // OpenStreamForAccount streams to websocket connection c for an account, with the given streamType. + OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode /* FEDERATION API-FACING PROCESSING FUNCTIONS diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go index 80ca1bd..353ea00 100644 --- a/internal/processing/streaming.go +++ b/internal/processing/streaming.go @@ -10,6 +10,6 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) } -func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - return p.streamingProcessor.StreamForAccount(c, account, streamType) +func (p *processor) OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + return p.streamingProcessor.OpenStreamForAccount(c, account, streamType) } diff --git a/internal/processing/synchronous/streaming/openstream.go b/internal/processing/synchronous/streaming/openstream.go new file mode 100644 index 0000000..5ea6ba2 --- /dev/null +++ b/internal/processing/synchronous/streaming/openstream.go @@ -0,0 +1,108 @@ +package streaming + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + l := p.log.WithFields(logrus.Fields{ + "func": "OpenStreamForAccount", + "account": account.ID, + "streamType": streamType, + }) + l.Debug("received open stream request") + + streamID, err := id.NewRandomULID() + if err != nil { + return gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + } + + thisStream := &stream{ + streamID: streamID, + streamType: streamType, + conn: conn, + } + + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + // there is no entry in the streamMap for this account yet, so make one and store it + streams := &streamsForAccount{ + s: []*stream{ + thisStream, + }, + } + p.streamMap.Store(account.ID, streams) + } else { + // there is an entry in the streamMap for this account + // parse the interface as a streamsForAccount + streams, ok := v.(*streamsForAccount) + if !ok { + return gtserror.NewErrorInternalError(errors.New("stream map error")) + } + + // append this stream to it + streams.Lock() + streams.s = append(streams.s, thisStream) + streams.Unlock() + } + + // set the close handler to remove the given stream from the stream map so that messages stop getting put into it + conn.SetCloseHandler(func(code int, text string) error { + l.Debug("closing stream") + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + // the map doesn't contain an entry for the account anyway, so we can just return + // this probably should never happen but let's check anyway + return nil + } + + // parse the interface as a streamsForAccount + streams, ok := v.(*streamsForAccount) + if !ok { + return gtserror.NewErrorInternalError(errors.New("stream map error")) + } + + // remove thisStream from the slice of streams stored in streamsForAccount + streams.Lock() + newStreamSlice := []*stream{} + for _, s := range streams.s { + if s.streamID != thisStream.streamID { + newStreamSlice = append(newStreamSlice, s) + } + } + streams.s = newStreamSlice + streams.Unlock() + l.Debug("stream closed") + return nil + }) + + defer conn.Close() + t := time.NewTicker(60 * time.Second) + for range t.C { + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return gtserror.NewErrorInternalError(err) + } + } + + return nil +} + +type streamsForAccount struct { + s []*stream + sync.Mutex +} + +type stream struct { + streamID string + streamType string + conn *websocket.Conn +} diff --git a/internal/processing/synchronous/streaming/stream.go b/internal/processing/synchronous/streaming/stream.go deleted file mode 100644 index e2bfaad..0000000 --- a/internal/processing/synchronous/streaming/stream.go +++ /dev/null @@ -1,22 +0,0 @@ -package streaming - -import ( - "github.com/gorilla/websocket" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - - v, loaded := p.streamMap.LoadOrStore(account.ID, sync.Slice) - if loaded { - - } - - return nil -} - -type streams struct { - accountID string - -} diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go index 207c1ec..0eecd25 100644 --- a/internal/processing/synchronous/streaming/streaming.go +++ b/internal/processing/synchronous/streaming/streaming.go @@ -12,13 +12,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) // Processor wraps a bunch of functions for processing streaming. type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error } type processor struct { diff --git a/internal/processing/synchronous/streaming/streamstatus.go b/internal/processing/synchronous/streaming/streamstatus.go new file mode 100644 index 0000000..440a214 --- /dev/null +++ b/internal/processing/synchronous/streaming/streamstatus.go @@ -0,0 +1,37 @@ +package streaming + +import ( + "errors" + + "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error { + l := p.log.WithFields(logrus.Fields{ + "func": "StreamStatusForAccount", + "account": account.ID, + }) + v, ok := p.streamMap.Load(account.ID) + if !ok { + // no open connections so nothing to stream + return nil + } + + streams, ok := v.(*streamsForAccount) + if !ok { + return errors.New("stream map error") + } + + streams.Lock() + defer streams.Unlock() + for _, stream := range streams.s { + l.Debugf("streaming status to stream id %s", stream.streamID) + if err := stream.conn.WriteJSON(s); err != nil { + return err + } + } + + return nil +} From 79faab7239a6479b0ae96f3e006802412810ddf8 Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Fri, 18 Jun 2021 13:06:02 +0200 Subject: [PATCH 3/3] streaming working --- go.mod | 2 +- internal/api/client/streaming/stream.go | 47 +++++++- internal/api/client/streaming/streaming.go | 2 +- internal/cliactions/server/server.go | 2 +- internal/db/pg/pg.go | 6 +- internal/gtsmodel/stream.go | 38 ++++++ internal/processing/fromcommon.go | 2 +- internal/processing/processor.go | 7 +- internal/processing/streaming.go | 5 +- .../synchronous/streaming/openstream.go | 112 ++++++++---------- .../synchronous/streaming/streaming.go | 7 +- .../synchronous/streaming/streamstatus.go | 29 +++-- internal/router/router.go | 4 +- 13 files changed, 170 insertions(+), 93 deletions(-) create mode 100644 internal/gtsmodel/stream.go diff --git a/go.mod b/go.mod index 81e2c4d..eab14c9 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/golang/mock v1.5.0 // indirect github.com/google/uuid v1.2.0 github.com/gorilla/sessions v1.2.1 // indirect - github.com/gorilla/websocket v1.4.2 // indirect + github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.1 github.com/json-iterator/go v1.1.11 // indirect github.com/leodido/go-urn v1.2.1 // indirect diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 3aa0166..a5b8a2d 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -9,6 +9,7 @@ import ( "github.com/gorilla/websocket" ) +// StreamGETHandler handles the creation of a new websocket streaming request. func (m *Module) StreamGETHandler(c *gin.Context) { l := m.log.WithField("func", "StreamGETHandler") @@ -24,29 +25,65 @@ func (m *Module) StreamGETHandler(c *gin.Context) { return } + // make sure a valid token has been provided and obtain the associated account account, err := m.processor.AuthorizeStreamingRequest(accessToken) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) return } + // prepare to upgrade the connection to a websocket connection upgrader := websocket.Upgrader{ - HandshakeTimeout: 5 * time.Second, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - Subprotocols: []string{"wss"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { + // we fully expect cors requests (via something like pinafore.social) so we should be lenient here return true }, } + // do the actual upgrade here conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { l.Infof("error upgrading websocket connection: %s", err) return } + defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection - if errWithCode := m.processor.OpenStreamForAccount(conn, account, streamType); errWithCode != nil { + // inform the processor that we have a new connection and want a stream for it + stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType) + if errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) + return } + defer close(stream.Hangup) // closing stream.Hangup indicates that we've finished with the connection (the client has gone), so we want to do this on exiting this handler + + // spawn a new ticker for pinging the connection periodically + t := time.NewTicker(30 * time.Second) + + // we want to stay in the sendloop as long as possible while the client is connected -- the only thing that should break the loop is if the client leaves or something else goes wrong +sendLoop: + for { + select { + case m := <-stream.Messages: + // we've got a streaming message!! + l.Debug("received message from stream") + if err := conn.WriteJSON(m); err != nil { + l.Infof("error writing json to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote message into websocket connection") + case <-t.C: + l.Debug("received TICK from ticker") + if err := conn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { + l.Infof("error writing ping to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote ping message into websocket connection") + } + } + + l.Debug("leaving StreamGETHandler") } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index 5a7cee9..92dfccd 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -35,7 +35,7 @@ const ( // StreamQueryKey is the query key for the type of stream being requested StreamQueryKey = "stream" - // AccessTokenQueryKey + // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests. AccessTokenQueryKey = "access_token" ) diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index 282e0b8..74b1c78 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/api/client/notification" "github.com/superseriousbusiness/gotosocial/internal/api/client/search" "github.com/superseriousbusiness/gotosocial/internal/api/client/status" + "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/api/client/timeline" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" @@ -39,7 +40,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/router" timelineprocessing "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/transport" diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 8515013..887df92 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -810,7 +810,7 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -818,7 +818,7 @@ func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targ if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -826,7 +826,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode if account1 == nil || account2 == nil { return false, nil } - + // make sure account 1 follows account 2 f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() if err != nil { diff --git a/internal/gtsmodel/stream.go b/internal/gtsmodel/stream.go new file mode 100644 index 0000000..4a1571d --- /dev/null +++ b/internal/gtsmodel/stream.go @@ -0,0 +1,38 @@ +package gtsmodel + +import "sync" + +// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. +// TODO: put a limit on this +type StreamsForAccount struct { + // The currently held streams for this account + Streams []*Stream + // Mutex to lock/unlock when modifying the slice of streams. + sync.Mutex +} + +// Stream represents one open stream for a client. +type Stream struct { + // ID of this stream, generated during creation. + ID string + // Type of this stream: user/public/etc + Type string + // Channel of messages for the client to read from + Messages chan *Message + // Channel to close when the client drops away + Hangup chan interface{} + // Only put messages in the stream when Connected + Connected bool + // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. + sync.Mutex +} + +// Message represents one streamed message. +type Message struct { + // All the stream types this message should be delivered to. + Stream []string `json:"stream"` + // The event type of the message (update/delete/notification etc) + Event string `json:"event"` + // The actual payload of the message. In case of an update or notification, this will be a JSON string. + Payload string `json:"payload"` +} diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 3d60486..719d72e 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -329,7 +329,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) } else { - if err := p.streamingProcessor.StreamStatusForAccount(mastoStatus, timelineAccount); err != nil { + if err := p.streamingProcessor.StreamStatusToAccount(mastoStatus, timelineAccount); err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) } } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 5d7df62..d1b4431 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -22,7 +22,6 @@ import ( "context" "net/http" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/blob" @@ -134,10 +133,10 @@ type Processor interface { // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, gtserror.WithCode) - // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API + // AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid. AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - // OpenStreamForAccount streams to websocket connection c for an account, with the given streamType. - OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + // OpenStreamForAccount opens a new stream for the given account, with the given stream type. + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) /* FEDERATION API-FACING PROCESSING FUNCTIONS diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go index 353ea00..1e566da 100644 --- a/internal/processing/streaming.go +++ b/internal/processing/streaming.go @@ -1,7 +1,6 @@ package processing import ( - "github.com/gorilla/websocket" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -10,6 +9,6 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) } -func (p *processor) OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - return p.streamingProcessor.OpenStreamForAccount(c, account, streamType) +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { + return p.streamingProcessor.OpenStreamForAccount(account, streamType) } diff --git a/internal/processing/synchronous/streaming/openstream.go b/internal/processing/synchronous/streaming/openstream.go index 5ea6ba2..68446ba 100644 --- a/internal/processing/synchronous/streaming/openstream.go +++ b/internal/processing/synchronous/streaming/openstream.go @@ -3,17 +3,14 @@ package streaming import ( "errors" "fmt" - "sync" - "time" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { l := p.log.WithFields(logrus.Fields{ "func": "OpenStreamForAccount", "account": account.ID, @@ -21,88 +18,83 @@ func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel }) l.Debug("received open stream request") + // each stream needs a unique ID so we know to close it streamID, err := id.NewRandomULID() if err != nil { - return gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) } - thisStream := &stream{ - streamID: streamID, - streamType: streamType, - conn: conn, + thisStream := >smodel.Stream{ + ID: streamID, + Type: streamType, + Messages: make(chan *gtsmodel.Message, 100), + Hangup: make(chan interface{}, 1), + Connected: true, } + go p.waitToCloseStream(account, thisStream) v, ok := p.streamMap.Load(account.ID) if !ok || v == nil { // there is no entry in the streamMap for this account yet, so make one and store it - streams := &streamsForAccount{ - s: []*stream{ + streamsForAccount := >smodel.StreamsForAccount{ + Streams: []*gtsmodel.Stream{ thisStream, }, } - p.streamMap.Store(account.ID, streams) + p.streamMap.Store(account.ID, streamsForAccount) } else { // there is an entry in the streamMap for this account // parse the interface as a streamsForAccount - streams, ok := v.(*streamsForAccount) + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) if !ok { - return gtserror.NewErrorInternalError(errors.New("stream map error")) + return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) } // append this stream to it - streams.Lock() - streams.s = append(streams.s, thisStream) - streams.Unlock() + streamsForAccount.Lock() + streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream) + streamsForAccount.Unlock() } - // set the close handler to remove the given stream from the stream map so that messages stop getting put into it - conn.SetCloseHandler(func(code int, text string) error { - l.Debug("closing stream") - v, ok := p.streamMap.Load(account.ID) - if !ok || v == nil { - // the map doesn't contain an entry for the account anyway, so we can just return - // this probably should never happen but let's check anyway - return nil - } + return thisStream, nil +} - // parse the interface as a streamsForAccount - streams, ok := v.(*streamsForAccount) - if !ok { - return gtserror.NewErrorInternalError(errors.New("stream map error")) - } +// waitToCloseStream waits until the hangup channel is closed for the given stream. +// It then iterates through the map of streams stored by the processor, removes the stream from it, +// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. +func (p *processor) waitToCloseStream(account *gtsmodel.Account, thisStream *gtsmodel.Stream) { + <-thisStream.Hangup // wait for a hangup message - // remove thisStream from the slice of streams stored in streamsForAccount - streams.Lock() - newStreamSlice := []*stream{} - for _, s := range streams.s { - if s.streamID != thisStream.streamID { - newStreamSlice = append(newStreamSlice, s) - } - } - streams.s = newStreamSlice - streams.Unlock() - l.Debug("stream closed") - return nil - }) + // lock the stream to prevent more messages being put in it while we work + thisStream.Lock() + defer thisStream.Unlock() - defer conn.Close() - t := time.NewTicker(60 * time.Second) - for range t.C { - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return gtserror.NewErrorInternalError(err) - } + // indicate the stream is no longer connected + thisStream.Connected = false + + // load and parse the entry for this account from the stream map + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + return + } + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return } - return nil -} + // lock the streams for account while we remove this stream from its slice + streamsForAccount.Lock() + defer streamsForAccount.Unlock() -type streamsForAccount struct { - s []*stream - sync.Mutex -} + // put everything into modified streams *except* the stream we're removing + modifiedStreams := []*gtsmodel.Stream{} + for _, s := range streamsForAccount.Streams { + if s.ID != thisStream.ID { + modifiedStreams = append(modifiedStreams, s) + } + } + streamsForAccount.Streams = modifiedStreams -type stream struct { - streamID string - streamType string - conn *websocket.Conn + // finally close the messages channel so no more messages can be read from it + close(thisStream.Messages) } diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go index 0eecd25..848da32 100644 --- a/internal/processing/synchronous/streaming/streaming.go +++ b/internal/processing/synchronous/streaming/streaming.go @@ -3,8 +3,8 @@ package streaming import ( "sync" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" @@ -12,15 +12,14 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) // Processor wraps a bunch of functions for processing streaming. type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode - StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error } type processor struct { diff --git a/internal/processing/synchronous/streaming/streamstatus.go b/internal/processing/synchronous/streaming/streamstatus.go index 440a214..8d02625 100644 --- a/internal/processing/synchronous/streaming/streamstatus.go +++ b/internal/processing/synchronous/streaming/streamstatus.go @@ -1,14 +1,16 @@ package streaming import ( + "encoding/json" "errors" + "fmt" "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error { +func (p *processor) StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error { l := p.log.WithFields(logrus.Fields{ "func": "StreamStatusForAccount", "account": account.ID, @@ -19,17 +21,28 @@ func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel return nil } - streams, ok := v.(*streamsForAccount) + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) if !ok { return errors.New("stream map error") } - streams.Lock() - defer streams.Unlock() - for _, stream := range streams.s { - l.Debugf("streaming status to stream id %s", stream.streamID) - if err := stream.conn.WriteJSON(s); err != nil { - return err + statusBytes, err := json.Marshal(s) + if err != nil { + return fmt.Errorf("error marshalling status to json: %s", err) + } + + streamsForAccount.Lock() + defer streamsForAccount.Unlock() + for _, stream := range streamsForAccount.Streams { + stream.Lock() + defer stream.Unlock() + if stream.Connected { + l.Debugf("streaming status to stream id %s", stream.ID) + stream.Messages <- >smodel.Message{ + Stream: []string{stream.Type}, + Event: "update", + Payload: string(statusBytes), + } } } diff --git a/internal/router/router.go b/internal/router/router.go index adfd2ce..e575b11 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -128,9 +128,9 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) { AllowAllOrigins: true, AllowBrowserExtensions: true, AllowMethods: []string{"POST", "PUT", "DELETE", "GET", "PATCH", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Upgrade", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key", "Sec-WebSocket-Protocol", "Sec-WebSocket-Version", "Connection"}, AllowWebSockets: true, - ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id"}, + ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id", "Connection", "Sec-WebSocket-Accept", "Upgrade"}, MaxAge: 2 * time.Minute, })) engine.MaxMultipartMemory = 8 << 20 // 8 MiB