From 42b8333d1b446866080854f812d5bd06c84d702a Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Thu, 17 Jun 2021 22:51:50 +0200 Subject: [PATCH] additional faffing around with streaming --- internal/api/client/streaming/stream.go | 8 +- internal/cliactions/server/server.go | 3 + internal/processing/fromcommon.go | 9 ++ internal/processing/processor.go | 4 +- internal/processing/streaming.go | 4 +- .../synchronous/streaming/openstream.go | 108 ++++++++++++++++++ .../synchronous/streaming/stream.go | 22 ---- .../synchronous/streaming/streaming.go | 4 +- .../synchronous/streaming/streamstatus.go | 37 ++++++ 9 files changed, 171 insertions(+), 28 deletions(-) create mode 100644 internal/processing/synchronous/streaming/openstream.go delete mode 100644 internal/processing/synchronous/streaming/stream.go create mode 100644 internal/processing/synchronous/streaming/streamstatus.go diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index c0097d0..3aa0166 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -10,6 +10,8 @@ import ( ) func (m *Module) StreamGETHandler(c *gin.Context) { + l := m.log.WithField("func", "StreamGETHandler") + streamType := c.Query(StreamQueryKey) if streamType == "" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("no stream type provided under query key %s", StreamQueryKey)}) @@ -33,14 +35,18 @@ func (m *Module) StreamGETHandler(c *gin.Context) { ReadBufferSize: 1024, WriteBufferSize: 1024, Subprotocols: []string{"wss"}, + CheckOrigin: func(r *http.Request) bool { + return true + }, } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { + l.Infof("error upgrading websocket connection: %s", err) return } - if errWithCode := m.processor.StreamForAccount(conn, account, streamType); errWithCode != nil { + if errWithCode := m.processor.OpenStreamForAccount(conn, account, streamType); errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) } } diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index f055890..282e0b8 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -39,6 +39,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/router" timelineprocessing "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/transport" @@ -134,6 +135,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log adminModule := admin.New(c, processor, log) statusModule := status.New(c, processor, log) securityModule := security.New(c, log) + streamingModule := streaming.New(c, processor, log) apis := []api.ClientModule{ // modules with middleware go first @@ -157,6 +159,7 @@ var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log filtersModule, emojiModule, listsModule, + streamingModule, } for _, m := range apis { diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 65ccef4..3d60486 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -324,6 +324,15 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID if err := p.timelineManager.IngestAndPrepare(status, timelineAccount.ID); err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %s", status.ID, err) } + + mastoStatus, err := p.tc.StatusToMasto(status, timelineAccount) + if err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) + } else { + if err := p.streamingProcessor.StreamStatusForAccount(mastoStatus, timelineAccount); err != nil { + errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) + } + } } func (p *processor) deleteStatusFromTimelines(status *gtsmodel.Status) error { diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 396f0a2..5d7df62 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -136,8 +136,8 @@ type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - // StreamForAccount streams to websocket connection c for an account, with the given streamType. - StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + // OpenStreamForAccount streams to websocket connection c for an account, with the given streamType. + OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode /* FEDERATION API-FACING PROCESSING FUNCTIONS diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go index 80ca1bd..353ea00 100644 --- a/internal/processing/streaming.go +++ b/internal/processing/streaming.go @@ -10,6 +10,6 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) } -func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - return p.streamingProcessor.StreamForAccount(c, account, streamType) +func (p *processor) OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + return p.streamingProcessor.OpenStreamForAccount(c, account, streamType) } diff --git a/internal/processing/synchronous/streaming/openstream.go b/internal/processing/synchronous/streaming/openstream.go new file mode 100644 index 0000000..5ea6ba2 --- /dev/null +++ b/internal/processing/synchronous/streaming/openstream.go @@ -0,0 +1,108 @@ +package streaming + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" +) + +func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { + l := p.log.WithFields(logrus.Fields{ + "func": "OpenStreamForAccount", + "account": account.ID, + "streamType": streamType, + }) + l.Debug("received open stream request") + + streamID, err := id.NewRandomULID() + if err != nil { + return gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + } + + thisStream := &stream{ + streamID: streamID, + streamType: streamType, + conn: conn, + } + + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + // there is no entry in the streamMap for this account yet, so make one and store it + streams := &streamsForAccount{ + s: []*stream{ + thisStream, + }, + } + p.streamMap.Store(account.ID, streams) + } else { + // there is an entry in the streamMap for this account + // parse the interface as a streamsForAccount + streams, ok := v.(*streamsForAccount) + if !ok { + return gtserror.NewErrorInternalError(errors.New("stream map error")) + } + + // append this stream to it + streams.Lock() + streams.s = append(streams.s, thisStream) + streams.Unlock() + } + + // set the close handler to remove the given stream from the stream map so that messages stop getting put into it + conn.SetCloseHandler(func(code int, text string) error { + l.Debug("closing stream") + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + // the map doesn't contain an entry for the account anyway, so we can just return + // this probably should never happen but let's check anyway + return nil + } + + // parse the interface as a streamsForAccount + streams, ok := v.(*streamsForAccount) + if !ok { + return gtserror.NewErrorInternalError(errors.New("stream map error")) + } + + // remove thisStream from the slice of streams stored in streamsForAccount + streams.Lock() + newStreamSlice := []*stream{} + for _, s := range streams.s { + if s.streamID != thisStream.streamID { + newStreamSlice = append(newStreamSlice, s) + } + } + streams.s = newStreamSlice + streams.Unlock() + l.Debug("stream closed") + return nil + }) + + defer conn.Close() + t := time.NewTicker(60 * time.Second) + for range t.C { + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return gtserror.NewErrorInternalError(err) + } + } + + return nil +} + +type streamsForAccount struct { + s []*stream + sync.Mutex +} + +type stream struct { + streamID string + streamType string + conn *websocket.Conn +} diff --git a/internal/processing/synchronous/streaming/stream.go b/internal/processing/synchronous/streaming/stream.go deleted file mode 100644 index e2bfaad..0000000 --- a/internal/processing/synchronous/streaming/stream.go +++ /dev/null @@ -1,22 +0,0 @@ -package streaming - -import ( - "github.com/gorilla/websocket" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -) - -func (p *processor) StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - - v, loaded := p.streamMap.LoadOrStore(account.ID, sync.Slice) - if loaded { - - } - - return nil -} - -type streams struct { - accountID string - -} diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go index 207c1ec..0eecd25 100644 --- a/internal/processing/synchronous/streaming/streaming.go +++ b/internal/processing/synchronous/streaming/streaming.go @@ -12,13 +12,15 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) // Processor wraps a bunch of functions for processing streaming. type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - StreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error } type processor struct { diff --git a/internal/processing/synchronous/streaming/streamstatus.go b/internal/processing/synchronous/streaming/streamstatus.go new file mode 100644 index 0000000..440a214 --- /dev/null +++ b/internal/processing/synchronous/streaming/streamstatus.go @@ -0,0 +1,37 @@ +package streaming + +import ( + "errors" + + "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error { + l := p.log.WithFields(logrus.Fields{ + "func": "StreamStatusForAccount", + "account": account.ID, + }) + v, ok := p.streamMap.Load(account.ID) + if !ok { + // no open connections so nothing to stream + return nil + } + + streams, ok := v.(*streamsForAccount) + if !ok { + return errors.New("stream map error") + } + + streams.Lock() + defer streams.Unlock() + for _, stream := range streams.s { + l.Debugf("streaming status to stream id %s", stream.streamID) + if err := stream.conn.WriteJSON(s); err != nil { + return err + } + } + + return nil +}