diff --git a/go.mod b/go.mod index 81e2c4d..eab14c9 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/golang/mock v1.5.0 // indirect github.com/google/uuid v1.2.0 github.com/gorilla/sessions v1.2.1 // indirect - github.com/gorilla/websocket v1.4.2 // indirect + github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.1 github.com/json-iterator/go v1.1.11 // indirect github.com/leodido/go-urn v1.2.1 // indirect diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 3aa0166..a5b8a2d 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -9,6 +9,7 @@ import ( "github.com/gorilla/websocket" ) +// StreamGETHandler handles the creation of a new websocket streaming request. func (m *Module) StreamGETHandler(c *gin.Context) { l := m.log.WithField("func", "StreamGETHandler") @@ -24,29 +25,65 @@ func (m *Module) StreamGETHandler(c *gin.Context) { return } + // make sure a valid token has been provided and obtain the associated account account, err := m.processor.AuthorizeStreamingRequest(accessToken) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"}) return } + // prepare to upgrade the connection to a websocket connection upgrader := websocket.Upgrader{ - HandshakeTimeout: 5 * time.Second, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - Subprotocols: []string{"wss"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { + // we fully expect cors requests (via something like pinafore.social) so we should be lenient here return true }, } + // do the actual upgrade here conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { l.Infof("error upgrading websocket connection: %s", err) return } + defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection - if errWithCode := m.processor.OpenStreamForAccount(conn, account, streamType); errWithCode != nil { + // inform the processor that we have a new connection and want a stream for it + stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType) + if errWithCode != nil { c.JSON(errWithCode.Code(), errWithCode.Safe()) + return } + defer close(stream.Hangup) // closing stream.Hangup indicates that we've finished with the connection (the client has gone), so we want to do this on exiting this handler + + // spawn a new ticker for pinging the connection periodically + t := time.NewTicker(30 * time.Second) + + // we want to stay in the sendloop as long as possible while the client is connected -- the only thing that should break the loop is if the client leaves or something else goes wrong +sendLoop: + for { + select { + case m := <-stream.Messages: + // we've got a streaming message!! + l.Debug("received message from stream") + if err := conn.WriteJSON(m); err != nil { + l.Infof("error writing json to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote message into websocket connection") + case <-t.C: + l.Debug("received TICK from ticker") + if err := conn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { + l.Infof("error writing ping to websocket connection: %s", err) + // if something is wrong we want to bail and drop the connection -- the client will create a new one + break sendLoop + } + l.Debug("wrote ping message into websocket connection") + } + } + + l.Debug("leaving StreamGETHandler") } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index 5a7cee9..92dfccd 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -35,7 +35,7 @@ const ( // StreamQueryKey is the query key for the type of stream being requested StreamQueryKey = "stream" - // AccessTokenQueryKey + // AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests. AccessTokenQueryKey = "access_token" ) diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index 282e0b8..74b1c78 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/api/client/notification" "github.com/superseriousbusiness/gotosocial/internal/api/client/search" "github.com/superseriousbusiness/gotosocial/internal/api/client/status" + "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/api/client/timeline" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/user" "github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger" @@ -39,7 +40,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/router" timelineprocessing "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/transport" diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 8515013..887df92 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -810,7 +810,7 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -818,7 +818,7 @@ func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targ if sourceAccount == nil || targetAccount == nil { return false, nil } - + return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } @@ -826,7 +826,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode if account1 == nil || account2 == nil { return false, nil } - + // make sure account 1 follows account 2 f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() if err != nil { diff --git a/internal/gtsmodel/stream.go b/internal/gtsmodel/stream.go new file mode 100644 index 0000000..4a1571d --- /dev/null +++ b/internal/gtsmodel/stream.go @@ -0,0 +1,38 @@ +package gtsmodel + +import "sync" + +// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. +// TODO: put a limit on this +type StreamsForAccount struct { + // The currently held streams for this account + Streams []*Stream + // Mutex to lock/unlock when modifying the slice of streams. + sync.Mutex +} + +// Stream represents one open stream for a client. +type Stream struct { + // ID of this stream, generated during creation. + ID string + // Type of this stream: user/public/etc + Type string + // Channel of messages for the client to read from + Messages chan *Message + // Channel to close when the client drops away + Hangup chan interface{} + // Only put messages in the stream when Connected + Connected bool + // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. + sync.Mutex +} + +// Message represents one streamed message. +type Message struct { + // All the stream types this message should be delivered to. + Stream []string `json:"stream"` + // The event type of the message (update/delete/notification etc) + Event string `json:"event"` + // The actual payload of the message. In case of an update or notification, this will be a JSON string. + Payload string `json:"payload"` +} diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 3d60486..719d72e 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -329,7 +329,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err) } else { - if err := p.streamingProcessor.StreamStatusForAccount(mastoStatus, timelineAccount); err != nil { + if err := p.streamingProcessor.StreamStatusToAccount(mastoStatus, timelineAccount); err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err) } } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 5d7df62..d1b4431 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -22,7 +22,6 @@ import ( "context" "net/http" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/blob" @@ -134,10 +133,10 @@ type Processor interface { // PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters. PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, gtserror.WithCode) - // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API + // AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid. AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - // OpenStreamForAccount streams to websocket connection c for an account, with the given streamType. - OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode + // OpenStreamForAccount opens a new stream for the given account, with the given stream type. + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) /* FEDERATION API-FACING PROCESSING FUNCTIONS diff --git a/internal/processing/streaming.go b/internal/processing/streaming.go index 353ea00..1e566da 100644 --- a/internal/processing/streaming.go +++ b/internal/processing/streaming.go @@ -1,7 +1,6 @@ package processing import ( - "github.com/gorilla/websocket" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -10,6 +9,6 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc return p.streamingProcessor.AuthorizeStreamingRequest(accessToken) } -func (p *processor) OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { - return p.streamingProcessor.OpenStreamForAccount(c, account, streamType) +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { + return p.streamingProcessor.OpenStreamForAccount(account, streamType) } diff --git a/internal/processing/synchronous/streaming/openstream.go b/internal/processing/synchronous/streaming/openstream.go index 5ea6ba2..68446ba 100644 --- a/internal/processing/synchronous/streaming/openstream.go +++ b/internal/processing/synchronous/streaming/openstream.go @@ -3,17 +3,14 @@ package streaming import ( "errors" "fmt" - "sync" - "time" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" ) -func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode { +func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) { l := p.log.WithFields(logrus.Fields{ "func": "OpenStreamForAccount", "account": account.ID, @@ -21,88 +18,83 @@ func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel }) l.Debug("received open stream request") + // each stream needs a unique ID so we know to close it streamID, err := id.NewRandomULID() if err != nil { - return gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err)) } - thisStream := &stream{ - streamID: streamID, - streamType: streamType, - conn: conn, + thisStream := >smodel.Stream{ + ID: streamID, + Type: streamType, + Messages: make(chan *gtsmodel.Message, 100), + Hangup: make(chan interface{}, 1), + Connected: true, } + go p.waitToCloseStream(account, thisStream) v, ok := p.streamMap.Load(account.ID) if !ok || v == nil { // there is no entry in the streamMap for this account yet, so make one and store it - streams := &streamsForAccount{ - s: []*stream{ + streamsForAccount := >smodel.StreamsForAccount{ + Streams: []*gtsmodel.Stream{ thisStream, }, } - p.streamMap.Store(account.ID, streams) + p.streamMap.Store(account.ID, streamsForAccount) } else { // there is an entry in the streamMap for this account // parse the interface as a streamsForAccount - streams, ok := v.(*streamsForAccount) + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) if !ok { - return gtserror.NewErrorInternalError(errors.New("stream map error")) + return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) } // append this stream to it - streams.Lock() - streams.s = append(streams.s, thisStream) - streams.Unlock() + streamsForAccount.Lock() + streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream) + streamsForAccount.Unlock() } - // set the close handler to remove the given stream from the stream map so that messages stop getting put into it - conn.SetCloseHandler(func(code int, text string) error { - l.Debug("closing stream") - v, ok := p.streamMap.Load(account.ID) - if !ok || v == nil { - // the map doesn't contain an entry for the account anyway, so we can just return - // this probably should never happen but let's check anyway - return nil - } + return thisStream, nil +} - // parse the interface as a streamsForAccount - streams, ok := v.(*streamsForAccount) - if !ok { - return gtserror.NewErrorInternalError(errors.New("stream map error")) - } +// waitToCloseStream waits until the hangup channel is closed for the given stream. +// It then iterates through the map of streams stored by the processor, removes the stream from it, +// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. +func (p *processor) waitToCloseStream(account *gtsmodel.Account, thisStream *gtsmodel.Stream) { + <-thisStream.Hangup // wait for a hangup message - // remove thisStream from the slice of streams stored in streamsForAccount - streams.Lock() - newStreamSlice := []*stream{} - for _, s := range streams.s { - if s.streamID != thisStream.streamID { - newStreamSlice = append(newStreamSlice, s) - } - } - streams.s = newStreamSlice - streams.Unlock() - l.Debug("stream closed") - return nil - }) + // lock the stream to prevent more messages being put in it while we work + thisStream.Lock() + defer thisStream.Unlock() - defer conn.Close() - t := time.NewTicker(60 * time.Second) - for range t.C { - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return gtserror.NewErrorInternalError(err) - } + // indicate the stream is no longer connected + thisStream.Connected = false + + // load and parse the entry for this account from the stream map + v, ok := p.streamMap.Load(account.ID) + if !ok || v == nil { + return + } + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) + if !ok { + return } - return nil -} + // lock the streams for account while we remove this stream from its slice + streamsForAccount.Lock() + defer streamsForAccount.Unlock() -type streamsForAccount struct { - s []*stream - sync.Mutex -} + // put everything into modified streams *except* the stream we're removing + modifiedStreams := []*gtsmodel.Stream{} + for _, s := range streamsForAccount.Streams { + if s.ID != thisStream.ID { + modifiedStreams = append(modifiedStreams, s) + } + } + streamsForAccount.Streams = modifiedStreams -type stream struct { - streamID string - streamType string - conn *websocket.Conn + // finally close the messages channel so no more messages can be read from it + close(thisStream.Messages) } diff --git a/internal/processing/synchronous/streaming/streaming.go b/internal/processing/synchronous/streaming/streaming.go index 0eecd25..848da32 100644 --- a/internal/processing/synchronous/streaming/streaming.go +++ b/internal/processing/synchronous/streaming/streaming.go @@ -3,8 +3,8 @@ package streaming import ( "sync" - "github.com/gorilla/websocket" "github.com/sirupsen/logrus" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" @@ -12,15 +12,14 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" ) // Processor wraps a bunch of functions for processing streaming. type Processor interface { // AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error) - OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode - StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error + OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) + StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error } type processor struct { diff --git a/internal/processing/synchronous/streaming/streamstatus.go b/internal/processing/synchronous/streaming/streamstatus.go index 440a214..8d02625 100644 --- a/internal/processing/synchronous/streaming/streamstatus.go +++ b/internal/processing/synchronous/streaming/streamstatus.go @@ -1,14 +1,16 @@ package streaming import ( + "encoding/json" "errors" + "fmt" "github.com/sirupsen/logrus" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error { +func (p *processor) StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error { l := p.log.WithFields(logrus.Fields{ "func": "StreamStatusForAccount", "account": account.ID, @@ -19,17 +21,28 @@ func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel return nil } - streams, ok := v.(*streamsForAccount) + streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount) if !ok { return errors.New("stream map error") } - streams.Lock() - defer streams.Unlock() - for _, stream := range streams.s { - l.Debugf("streaming status to stream id %s", stream.streamID) - if err := stream.conn.WriteJSON(s); err != nil { - return err + statusBytes, err := json.Marshal(s) + if err != nil { + return fmt.Errorf("error marshalling status to json: %s", err) + } + + streamsForAccount.Lock() + defer streamsForAccount.Unlock() + for _, stream := range streamsForAccount.Streams { + stream.Lock() + defer stream.Unlock() + if stream.Connected { + l.Debugf("streaming status to stream id %s", stream.ID) + stream.Messages <- >smodel.Message{ + Stream: []string{stream.Type}, + Event: "update", + Payload: string(statusBytes), + } } } diff --git a/internal/router/router.go b/internal/router/router.go index adfd2ce..e575b11 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -128,9 +128,9 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) { AllowAllOrigins: true, AllowBrowserExtensions: true, AllowMethods: []string{"POST", "PUT", "DELETE", "GET", "PATCH", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Upgrade", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key", "Sec-WebSocket-Protocol", "Sec-WebSocket-Version", "Connection"}, AllowWebSockets: true, - ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id"}, + ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id", "Connection", "Sec-WebSocket-Accept", "Upgrade"}, MaxAge: 2 * time.Minute, })) engine.MaxMultipartMemory = 8 << 20 // 8 MiB