streaming working

This commit is contained in:
tsmethurst 2021-06-18 13:06:02 +02:00 committed by tsmethurst
parent e87dec807e
commit a3d3cd8d43
13 changed files with 170 additions and 93 deletions

2
go.mod
View File

@ -26,7 +26,7 @@ require (
github.com/golang/mock v1.5.0 // indirect
github.com/google/uuid v1.2.0
github.com/gorilla/sessions v1.2.1 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/gorilla/websocket v1.4.2
github.com/h2non/filetype v1.1.1
github.com/json-iterator/go v1.1.11 // indirect
github.com/leodido/go-urn v1.2.1 // indirect

View File

@ -9,6 +9,7 @@ import (
"github.com/gorilla/websocket"
)
// StreamGETHandler handles the creation of a new websocket streaming request.
func (m *Module) StreamGETHandler(c *gin.Context) {
l := m.log.WithField("func", "StreamGETHandler")
@ -24,29 +25,65 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return
}
// make sure a valid token has been provided and obtain the associated account
account, err := m.processor.AuthorizeStreamingRequest(accessToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"})
return
}
// prepare to upgrade the connection to a websocket connection
upgrader := websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Subprotocols: []string{"wss"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// we fully expect cors requests (via something like pinafore.social) so we should be lenient here
return true
},
}
// do the actual upgrade here
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
l.Infof("error upgrading websocket connection: %s", err)
return
}
defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection
if errWithCode := m.processor.OpenStreamForAccount(conn, account, streamType); errWithCode != nil {
// inform the processor that we have a new connection and want a stream for it
stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType)
if errWithCode != nil {
c.JSON(errWithCode.Code(), errWithCode.Safe())
return
}
defer close(stream.Hangup) // closing stream.Hangup indicates that we've finished with the connection (the client has gone), so we want to do this on exiting this handler
// spawn a new ticker for pinging the connection periodically
t := time.NewTicker(30 * time.Second)
// we want to stay in the sendloop as long as possible while the client is connected -- the only thing that should break the loop is if the client leaves or something else goes wrong
sendLoop:
for {
select {
case m := <-stream.Messages:
// we've got a streaming message!!
l.Debug("received message from stream")
if err := conn.WriteJSON(m); err != nil {
l.Infof("error writing json to websocket connection: %s", err)
// if something is wrong we want to bail and drop the connection -- the client will create a new one
break sendLoop
}
l.Debug("wrote message into websocket connection")
case <-t.C:
l.Debug("received TICK from ticker")
if err := conn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil {
l.Infof("error writing ping to websocket connection: %s", err)
// if something is wrong we want to bail and drop the connection -- the client will create a new one
break sendLoop
}
l.Debug("wrote ping message into websocket connection")
}
}
l.Debug("leaving StreamGETHandler")
}

View File

@ -35,7 +35,7 @@ const (
// StreamQueryKey is the query key for the type of stream being requested
StreamQueryKey = "stream"
// AccessTokenQueryKey
// AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests.
AccessTokenQueryKey = "access_token"
)

View File

@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/client/notification"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search"
"github.com/superseriousbusiness/gotosocial/internal/api/client/status"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/api/client/timeline"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/user"
"github.com/superseriousbusiness/gotosocial/internal/api/s2s/webfinger"
@ -39,7 +40,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/router"
timelineprocessing "github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/transport"

View File

@ -810,7 +810,7 @@ func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccoun
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
@ -818,7 +818,7 @@ func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targ
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
@ -826,7 +826,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {

View File

@ -0,0 +1,38 @@
package gtsmodel
import "sync"
// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time.
// TODO: put a limit on this
type StreamsForAccount struct {
// The currently held streams for this account
Streams []*Stream
// Mutex to lock/unlock when modifying the slice of streams.
sync.Mutex
}
// Stream represents one open stream for a client.
type Stream struct {
// ID of this stream, generated during creation.
ID string
// Type of this stream: user/public/etc
Type string
// Channel of messages for the client to read from
Messages chan *Message
// Channel to close when the client drops away
Hangup chan interface{}
// Only put messages in the stream when Connected
Connected bool
// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
sync.Mutex
}
// Message represents one streamed message.
type Message struct {
// All the stream types this message should be delivered to.
Stream []string `json:"stream"`
// The event type of the message (update/delete/notification etc)
Event string `json:"event"`
// The actual payload of the message. In case of an update or notification, this will be a JSON string.
Payload string `json:"payload"`
}

View File

@ -329,7 +329,7 @@ func (p *processor) timelineStatusForAccount(status *gtsmodel.Status, accountID
if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %s", status.ID, err)
} else {
if err := p.streamingProcessor.StreamStatusForAccount(mastoStatus, timelineAccount); err != nil {
if err := p.streamingProcessor.StreamStatusToAccount(mastoStatus, timelineAccount); err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error streaming status %s: %s", status.ID, err)
}
}

View File

@ -22,7 +22,6 @@ import (
"context"
"net/http"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/blob"
@ -134,10 +133,10 @@ type Processor interface {
// PublicTimelineGet returns statuses from the public/local timeline, with the given filters/parameters.
PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) ([]*apimodel.Status, gtserror.WithCode)
// AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API
// AuthorizeStreamingRequest returns a gotosocial account in exchange for an access token, or an error if the given token is not valid.
AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error)
// OpenStreamForAccount streams to websocket connection c for an account, with the given streamType.
OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode
// OpenStreamForAccount opens a new stream for the given account, with the given stream type.
OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
/*
FEDERATION API-FACING PROCESSING FUNCTIONS

View File

@ -1,7 +1,6 @@
package processing
import (
"github.com/gorilla/websocket"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@ -10,6 +9,6 @@ func (p *processor) AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Acc
return p.streamingProcessor.AuthorizeStreamingRequest(accessToken)
}
func (p *processor) OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode {
return p.streamingProcessor.OpenStreamForAccount(c, account, streamType)
func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
return p.streamingProcessor.OpenStreamForAccount(account, streamType)
}

View File

@ -3,17 +3,14 @@ package streaming
import (
"errors"
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode {
func (p *processor) OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode) {
l := p.log.WithFields(logrus.Fields{
"func": "OpenStreamForAccount",
"account": account.ID,
@ -21,88 +18,83 @@ func (p *processor) OpenStreamForAccount(conn *websocket.Conn, account *gtsmodel
})
l.Debug("received open stream request")
// each stream needs a unique ID so we know to close it
streamID, err := id.NewRandomULID()
if err != nil {
return gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err))
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err))
}
thisStream := &stream{
streamID: streamID,
streamType: streamType,
conn: conn,
thisStream := &gtsmodel.Stream{
ID: streamID,
Type: streamType,
Messages: make(chan *gtsmodel.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
}
go p.waitToCloseStream(account, thisStream)
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
// there is no entry in the streamMap for this account yet, so make one and store it
streams := &streamsForAccount{
s: []*stream{
streamsForAccount := &gtsmodel.StreamsForAccount{
Streams: []*gtsmodel.Stream{
thisStream,
},
}
p.streamMap.Store(account.ID, streams)
p.streamMap.Store(account.ID, streamsForAccount)
} else {
// there is an entry in the streamMap for this account
// parse the interface as a streamsForAccount
streams, ok := v.(*streamsForAccount)
streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount)
if !ok {
return gtserror.NewErrorInternalError(errors.New("stream map error"))
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
}
// append this stream to it
streams.Lock()
streams.s = append(streams.s, thisStream)
streams.Unlock()
streamsForAccount.Lock()
streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream)
streamsForAccount.Unlock()
}
// set the close handler to remove the given stream from the stream map so that messages stop getting put into it
conn.SetCloseHandler(func(code int, text string) error {
l.Debug("closing stream")
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
// the map doesn't contain an entry for the account anyway, so we can just return
// this probably should never happen but let's check anyway
return nil
}
return thisStream, nil
}
// parse the interface as a streamsForAccount
streams, ok := v.(*streamsForAccount)
if !ok {
return gtserror.NewErrorInternalError(errors.New("stream map error"))
}
// waitToCloseStream waits until the hangup channel is closed for the given stream.
// It then iterates through the map of streams stored by the processor, removes the stream from it,
// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
func (p *processor) waitToCloseStream(account *gtsmodel.Account, thisStream *gtsmodel.Stream) {
<-thisStream.Hangup // wait for a hangup message
// remove thisStream from the slice of streams stored in streamsForAccount
streams.Lock()
newStreamSlice := []*stream{}
for _, s := range streams.s {
if s.streamID != thisStream.streamID {
newStreamSlice = append(newStreamSlice, s)
}
}
streams.s = newStreamSlice
streams.Unlock()
l.Debug("stream closed")
return nil
})
// lock the stream to prevent more messages being put in it while we work
thisStream.Lock()
defer thisStream.Unlock()
defer conn.Close()
t := time.NewTicker(60 * time.Second)
for range t.C {
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return gtserror.NewErrorInternalError(err)
}
// indicate the stream is no longer connected
thisStream.Connected = false
// load and parse the entry for this account from the stream map
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
return
}
streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount)
if !ok {
return
}
return nil
}
// lock the streams for account while we remove this stream from its slice
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
type streamsForAccount struct {
s []*stream
sync.Mutex
}
// put everything into modified streams *except* the stream we're removing
modifiedStreams := []*gtsmodel.Stream{}
for _, s := range streamsForAccount.Streams {
if s.ID != thisStream.ID {
modifiedStreams = append(modifiedStreams, s)
}
}
streamsForAccount.Streams = modifiedStreams
type stream struct {
streamID string
streamType string
conn *websocket.Conn
// finally close the messages channel so no more messages can be read from it
close(thisStream.Messages)
}

View File

@ -3,8 +3,8 @@ package streaming
import (
"sync"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
@ -12,15 +12,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
)
// Processor wraps a bunch of functions for processing streaming.
type Processor interface {
// AuthorizeStreamingRequest returns an oauth2 token info in response to an access token query from the streaming API
AuthorizeStreamingRequest(accessToken string) (*gtsmodel.Account, error)
OpenStreamForAccount(c *websocket.Conn, account *gtsmodel.Account, streamType string) gtserror.WithCode
StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error
OpenStreamForAccount(account *gtsmodel.Account, streamType string) (*gtsmodel.Stream, gtserror.WithCode)
StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error
}
type processor struct {

View File

@ -1,14 +1,16 @@
package streaming
import (
"encoding/json"
"errors"
"fmt"
"github.com/sirupsen/logrus"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel.Account) error {
func (p *processor) StreamStatusToAccount(s *apimodel.Status, account *gtsmodel.Account) error {
l := p.log.WithFields(logrus.Fields{
"func": "StreamStatusForAccount",
"account": account.ID,
@ -19,17 +21,28 @@ func (p *processor) StreamStatusForAccount(s *apimodel.Status, account *gtsmodel
return nil
}
streams, ok := v.(*streamsForAccount)
streamsForAccount, ok := v.(*gtsmodel.StreamsForAccount)
if !ok {
return errors.New("stream map error")
}
streams.Lock()
defer streams.Unlock()
for _, stream := range streams.s {
l.Debugf("streaming status to stream id %s", stream.streamID)
if err := stream.conn.WriteJSON(s); err != nil {
return err
statusBytes, err := json.Marshal(s)
if err != nil {
return fmt.Errorf("error marshalling status to json: %s", err)
}
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
for _, stream := range streamsForAccount.Streams {
stream.Lock()
defer stream.Unlock()
if stream.Connected {
l.Debugf("streaming status to stream id %s", stream.ID)
stream.Messages <- &gtsmodel.Message{
Stream: []string{stream.Type},
Event: "update",
Payload: string(statusBytes),
}
}
}

View File

@ -128,9 +128,9 @@ func New(config *config.Config, logger *logrus.Logger) (Router, error) {
AllowAllOrigins: true,
AllowBrowserExtensions: true,
AllowMethods: []string{"POST", "PUT", "DELETE", "GET", "PATCH", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"},
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "Upgrade", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key", "Sec-WebSocket-Protocol", "Sec-WebSocket-Version", "Connection"},
AllowWebSockets: true,
ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id"},
ExposeHeaders: []string{"Link", "X-RateLimit-Reset", "X-RateLimit-Limit", " X-RateLimit-Remaining", "X-Request-Id", "Connection", "Sec-WebSocket-Accept", "Upgrade"},
MaxAge: 2 * time.Minute,
}))
engine.MaxMultipartMemory = 8 << 20 // 8 MiB