diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 3856fb0133533..95cd504cc9816 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -25,6 +25,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "path/filepath" "sort" @@ -957,7 +958,9 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID if rmErr := os.Remove(tarballPath); rmErr != nil { l.log.WithError(rmErr).Warningf("Failed to remove file %v.", tarballPath) } - + if errors.Is(err, fs.ErrNotExist) { + err = trace.NotFound("a recording for session %v was not found", sessionID) + } e <- trace.Wrap(err) return c, e } @@ -980,7 +983,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID for { if ctx.Err() != nil { e <- trace.Wrap(ctx.Err()) - break + return } event, err := protoReader.Read(ctx) @@ -990,12 +993,16 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID } else { close(c) } - - break + return } if event.GetIndex() >= startIndex { - c <- event + select { + case c <- event: + case <-ctx.Done(): + e <- trace.Wrap(ctx.Err()) + return + } } } }() diff --git a/lib/player/player.go b/lib/player/player.go index 4be2813f12702..2e225a2136ac0 100644 --- a/lib/player/player.go +++ b/lib/player/player.go @@ -19,7 +19,6 @@ import ( "context" "errors" "math" - "os" "sync/atomic" "time" @@ -64,10 +63,15 @@ type Player struct { // and is inspired by "Rethinking Classical Concurrency Patterns" // by Bryan C. Mills (GopherCon 2018): https://www.youtube.com/watch?v=5zXAHh5tJqQ playPause chan chan struct{} + + // err holds the error (if any) encountered during playback + err error } const normalPlayback = math.MinInt64 +// Streamer is the underlying streamer that provides +// access to recorded session events. type Streamer interface { StreamSessionEvents( ctx context.Context, @@ -100,10 +104,7 @@ func New(cfg *Config) (*Player, error) { var log logrus.FieldLogger = cfg.Log if log == nil { - l := logrus.New().WithField(trace.Component, "player") - l.Logger.SetOutput(os.Stdout) // TODO(zmb3) remove - l.Logger.SetLevel(logrus.DebugLevel) - log = l + log = logrus.New().WithField(trace.Component, "player") } p := &Player{ @@ -137,6 +138,11 @@ const ( maxPlaybackSpeed = 16 ) +// SetSpeed adjusts the playback speed of the player. +// It can be called at any time (the player can be in a playing +// or paused state). A speed of 1.0 plays back at regular speed, +// while a speed of 2.0 plays back twice as fast as originally +// recorded. Valid speeds range from 0.25 to 16.0. func (p *Player) SetSpeed(s float64) error { if s < minPlaybackSpeed || s > maxPlaybackSpeed { return trace.BadParameter("speed %v is out of range", s) @@ -146,8 +152,10 @@ func (p *Player) SetSpeed(s float64) error { } func (p *Player) stream() { - // TODO(zmb3): consider using context instead of close chan - eventsC, errC := p.streamer.StreamSessionEvents(context.TODO(), p.sessionID, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + eventsC, errC := p.streamer.StreamSessionEvents(ctx, p.sessionID, 0) lastDelay := int64(0) for { select { @@ -155,13 +163,13 @@ func (p *Player) stream() { close(p.emit) return case err := <-errC: - // TODO(zmb3): figure out how to surface the error - // (probably close the chan and expose a method) p.log.Warn(err) + p.err = err + close(p.emit) return case evt := <-eventsC: if evt == nil { - p.log.Debug("reached end of playback") + p.log.Debugf("reached end of playback for session %v", p.sessionID) close(p.emit) return } @@ -203,7 +211,6 @@ func (p *Player) stream() { lastDelay = currentDelay } - p.log.Debugf("playing %v (%v)", evt.GetType(), evt.GetID()) select { case p.emit <- evt: p.lastPlayed.Store(currentDelay) @@ -229,7 +236,12 @@ func (p *Player) C() <-chan events.AuditEvent { return p.emit } -// TODO(zmb3): add an Err() method to be checked after C is closed +// Err returns the error (if any) that occurred during playback. +// It should only be called after the channel returned by [C] is +// closed. +func (p *Player) Err() error { + return p.err +} // Pause temporarily stops the player from emitting events. // It is a no-op if playback is currently paused. @@ -262,7 +274,6 @@ func (p *Player) SetPos(d time.Duration) error { // applyDelay "sleeps" for d in a manner that // can be canceled func (p *Player) applyDelay(d time.Duration) error { - p.log.Debugf("waiting %v until next event", d) scaled := float64(d) / p.speed.Load().(float64) select { case <-p.done: diff --git a/lib/player/player_test.go b/lib/player/player_test.go index 0a0aa01588430..8ac04cb632928 100644 --- a/lib/player/player_test.go +++ b/lib/player/player_test.go @@ -47,6 +47,7 @@ func TestBasicStream(t *testing.T) { } require.Equal(t, 3, count) + require.NoError(t, p.Err()) } func TestPlayPause(t *testing.T) { @@ -163,6 +164,7 @@ func TestClose(t *testing.T) { // channel should have been closed _, ok := <-p.C() require.False(t, ok, "player channel should have been closed") + require.NoError(t, p.Err()) } func TestSeekForward(t *testing.T) { diff --git a/lib/web/desktop/playback.go b/lib/web/desktop/playback.go index de53c02d6c00d..8614ff8083c6a 100644 --- a/lib/web/desktop/playback.go +++ b/lib/web/desktop/playback.go @@ -18,19 +18,14 @@ package desktop import ( "context" - "errors" + "encoding/json" "fmt" - "net" - "os" - "sync" - "time" - "github.com/gravitational/trace" "github.com/sirupsen/logrus" "golang.org/x/net/websocket" - apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/utils" ) @@ -39,74 +34,6 @@ const ( maxPlaybackSpeed = 16 ) -// Player manages the playback of a recorded desktop session. -// It streams events from the audit log to the browser over -// a websocket connection. -type Player struct { - ws *websocket.Conn - streamer Streamer - - mu sync.Mutex - cond *sync.Cond - playState playbackState - playSpeed float32 - - log logrus.FieldLogger - sID string - - closeOnce sync.Once -} - -// Streamer is the interface that can provide with a stream of events related to -// a particular session. -type Streamer interface { - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first - // channel if one is encountered. Otherwise the event channel is closed when the stream ends. - // The event channel is not closed on error to prevent race conditions in downstream select statements. - StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) -} - -// NewPlayer creates a player that streams a desktop session -// over the provided websocket connection. -func NewPlayer(sID string, ws *websocket.Conn, streamer Streamer, log logrus.FieldLogger) *Player { - p := &Player{ - ws: ws, - streamer: streamer, - playState: playStatePlaying, - log: log, - sID: sID, - playSpeed: 1.0, - } - p.cond = sync.NewCond(&p.mu) - return p -} - -// Play kicks off goroutines for receiving actions -// and playing back the session over the websocket, -// and then waits for the stream to complete. -func (pp *Player) Play(ctx context.Context) { - defer pp.log.Debug("playbackPlayer.Play returned") - - pp.ws.PayloadType = websocket.BinaryFrame - ppCtx, cancel := context.WithCancel(ctx) - defer pp.close(cancel) - - go pp.receiveActions(cancel) - go pp.streamSessionEvents(ppCtx, cancel) - - // Wait until the ctx is canceled, either by - // one of the goroutines above or by the http handler. - <-ppCtx.Done() -} - -type playbackState string - -const ( - playStatePlaying = playbackState("playing") - playStatePaused = playbackState("paused") - playStateFinished = playbackState("finished") -) - // playbackAction identifies a command sent from the // browser to control playback type playbackAction string @@ -125,163 +52,99 @@ const ( // control playback. type actionMessage struct { Action playbackAction `json:"action"` - PlaybackSpeed float32 `json:"speed,omitempty"` -} - -// waitWhilePaused waits idly while the player's state is paused, waiting until: -// - the play state is toggled back to playing -// - the play state is set to finished (the player is closed) -func (pp *Player) waitWhilePaused() { - pp.cond.L.Lock() - defer pp.cond.L.Unlock() - - for pp.playState == playStatePaused { - pp.cond.Wait() - } + PlaybackSpeed float64 `json:"speed,omitempty"` } -// togglePlaying toggles the state of the player between playing and paused, -// and wakes up any goroutines waiting in waitWhilePaused. -func (pp *Player) togglePlaying() { - pp.cond.L.Lock() - defer pp.cond.L.Unlock() - switch pp.playState { - case playStatePlaying: - pp.playState = playStatePaused - case playStatePaused: - pp.playState = playStatePlaying - } - pp.cond.Broadcast() -} - -// close closes the websocket connection, wakes up any goroutines waiting on the playState condition, -// and cancels the playbackPlayer's context. -// -// It should be deferred by all the goroutines that use playbackPlayer, -// in order to ensure that when one goroutine closes, all the others do too. -func (pp *Player) close(cancel context.CancelFunc) { - pp.closeOnce.Do(func() { - pp.mu.Lock() - defer pp.mu.Unlock() - - err := pp.ws.Close() - if err != nil { - pp.log.WithError(err).Errorf("websocket.Close() failed") - } - - pp.playState = playStateFinished - pp.cond.Broadcast() - cancel() - }) -} - -// receiveActions handles logic for receiving playbackAction jsons -// over the websocket and modifying playbackPlayer's state accordingly. -func (pp *Player) receiveActions(cancel context.CancelFunc) { - defer pp.log.Debug("playbackPlayer.ReceiveActions returned") - defer pp.close(cancel) +// ReceivePlaybackActions handles logic for receiving playbackAction messages +// over the websocket and updating the player state accordingly. +func ReceivePlaybackActions( + log logrus.FieldLogger, + ws *websocket.Conn, + player *player.Player) { + // playback always starts in a playing state + playing := true for { var action actionMessage - if err := websocket.JSON.Receive(pp.ws, &action); err != nil { - // We expect net.ErrClosed if the websocket is closed by another - // goroutine and io.EOF if the websocket is closed by the browser - // while websocket.JSON.Receive() is hanging. + if err := websocket.JSON.Receive(ws, &action); err != nil { + // Connection close errors are expected if the user closes the tab. + // Only log unexpected errors to avoid cluttering the logs. if !utils.IsOKNetworkError(err) { - pp.log.WithError(err).Error("error reading from websocket") + log.Warnf("websocket read error: %v", err) } return } - pp.log.Debugf("received playback action: %+v", action) + switch action.Action { case actionPlayPause: - pp.togglePlaying() - case actionSpeed: - if action.PlaybackSpeed < minPlaybackSpeed { - action.PlaybackSpeed = minPlaybackSpeed - } else if action.PlaybackSpeed > maxPlaybackSpeed { - action.PlaybackSpeed = maxPlaybackSpeed + if playing { + player.Pause() + } else { + player.Play() } - - pp.mu.Lock() - pp.playSpeed = action.PlaybackSpeed - pp.mu.Unlock() + playing = !playing + case actionSpeed: + action.PlaybackSpeed = max(action.PlaybackSpeed, minPlaybackSpeed) + action.PlaybackSpeed = min(action.PlaybackSpeed, maxPlaybackSpeed) + player.SetSpeed(action.PlaybackSpeed) default: - pp.log.Errorf("received unknown action: %v", action.Action) + log.Warnf("invalid desktop playback action: %v", action.Action) return } } } -// streamSessionEvents streams the session's events as playback events over the websocket. -func (pp *Player) streamSessionEvents(ctx context.Context, cancel context.CancelFunc) { - defer pp.log.Debug("playbackPlayer.StreamSessionEvents returned") - defer pp.close(cancel) - - var lastDelay int64 - scaleDelay := func(delay int64) int64 { - pp.mu.Lock() - defer pp.mu.Unlock() - return int64(float32(delay) / pp.playSpeed) - } - eventsC, errC := pp.streamer.StreamSessionEvents(ctx, session.ID(pp.sID), 0) +// PlayRecording feeds recorded events from a player +// over a websocket. +func PlayRecording( + ctx context.Context, + log logrus.FieldLogger, + ws *websocket.Conn, + player *player.Player) { + player.Play() for { - pp.waitWhilePaused() - select { - case err := <-errC: - if err != nil && !errors.Is(err, context.Canceled) { - pp.log.WithError(err).Errorf("streaming session %v", pp.sID) - var errorText string - if os.IsNotExist(err) || trace.IsNotFound(err) { - errorText = "session not found" - } else { - errorText = "server error" - } - if _, err := pp.ws.Write([]byte(fmt.Sprintf(`{"message": "error", "errorText": "%v"}`, errorText))); err != nil { - pp.log.WithError(err).Error("failed to write \"error\" message over websocket") - } - } + case <-ctx.Done(): return - case evt := <-eventsC: - if evt == nil { - pp.log.Debug("reached end of playback") - if _, err := pp.ws.Write([]byte(`{"message":"end"}`)); err != nil { - pp.log.WithError(err).Error("failed to write \"end\" message over websocket") - } - return - } - switch e := evt.(type) { - case *apievents.DesktopRecording: - if e.DelayMilliseconds > lastDelay { - // TODO(zmb3): replace with time.After so we can cancel - time.Sleep(time.Duration(scaleDelay(e.DelayMilliseconds-lastDelay)) * time.Millisecond) - lastDelay = e.DelayMilliseconds - } - msg, err := utils.FastMarshal(e) - if err != nil { - pp.log.WithError(err).Errorf("failed to marshal DesktopRecording event into JSON: %v", e) - if _, err := pp.ws.Write([]byte(`{"message":"error","errorText":"server error"}`)); err != nil { - pp.log.WithError(err).Error("failed to write \"error\" message over websocket") + case evt, ok := <-player.C(): + if !ok { + if playerErr := player.Err(); playerErr != nil { + // Attempt to JSONify the error (escaping any quotes) + msg, err := json.Marshal(playerErr.Error()) + if err != nil { + log.Warnf("failed to marshal player error message: %v", err) + msg = []byte(`"internal server error"`) } - return - } - if _, err := pp.ws.Write(msg); err != nil { - // We expect net.ErrClosed to arise when another goroutine returns before - // this one or the browser window is closed, both of which cause the websocket to close. - if !errors.Is(err, net.ErrClosed) { - pp.log.WithError(err).Error("failed to write DesktopRecording event over websocket") + //lint:ignore QF1012 this write needs to happen in a single operation + if _, err := ws.Write([]byte(fmt.Sprintf(`{"message":"error", "errorText":%s}`, string(msg)))); err != nil { + log.Errorf("failed to write error message: %v", err) } return } - case *apievents.WindowsDesktopSessionStart, *apievents.WindowsDesktopSessionEnd: - // these events are part of the stream but never needed for playback - case *apievents.DesktopClipboardReceive, *apievents.DesktopClipboardSend: - // these events are not currently needed for playback, - // but may be useful in the future + if _, err := ws.Write([]byte(`{"message":"end"}`)); err != nil { + log.Errorf("failed to write end message: %v", err) + } + return + } - default: - pp.log.Warnf("session %v contains unexpected event type %T", pp.sID, evt) + // some events are part of the stream but not currently + // needed during playback (session start/end, clipboard use, etc) + if _, ok := evt.(*events.DesktopRecording); !ok { + continue + } + msg, err := utils.FastMarshal(evt) + if err != nil { + log.Errorf("failed to marshal desktop event: %v", err) + ws.Write([]byte(`{"message":"error","errorText":"server error"}`)) + return + } + if _, err := ws.Write(msg); err != nil { + // Connection close errors are expected if the user closes the tab. + // Only log unexpected errors to avoid cluttering the logs. + if !utils.IsOKNetworkError(err) { + log.Warnf("websocket write error: %v", err) + } + return } } } diff --git a/lib/web/desktop/playback_test.go b/lib/web/desktop/playback_test.go index 4728b40a6a843..9f66ee7c9ae3d 100644 --- a/lib/web/desktop/playback_test.go +++ b/lib/web/desktop/playback_test.go @@ -23,11 +23,15 @@ import ( "testing" "time" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/websocket" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/player" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/desktop" ) @@ -72,9 +76,19 @@ func newServer(t *testing.T, streamInterval time.Duration, events []apievents.Au t.Helper() fs := eventstest.NewFakeStreamer(events, streamInterval) + log := utils.NewLoggerForTests() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { websocket.Handler(func(ws *websocket.Conn) { - desktop.NewPlayer("session-id", ws, fs, utils.NewLoggerForTests()).Play(r.Context()) + player, err := player.New(&player.Config{ + Clock: clockwork.NewRealClock(), + Log: log, + SessionID: session.ID("session-id"), + Streamer: fs, + }) + assert.NoError(t, err) + player.Play() + desktop.PlayRecording(r.Context(), log, ws, player) }).ServeHTTP(w, r) })) t.Cleanup(s.Close) diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index be2035288e580..086597099b23d 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -17,13 +17,16 @@ limitations under the License. package web import ( + "context" "net/http" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" "golang.org/x/net/websocket" + "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/web/desktop" ) @@ -31,22 +34,52 @@ func (h *Handler) desktopPlaybackHandle( w http.ResponseWriter, r *http.Request, p httprouter.Params, - ctx *SessionContext, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { - return nil, trace.BadParameter("missing sid in request URL") + return nil, trace.BadParameter("missing session ID in request URL") } - clt, err := ctx.GetUserClient(r.Context(), site) + clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) } websocket.Handler(func(ws *websocket.Conn) { - defer h.log.Debug("desktopPlaybackHandle websocket handler returned") - desktop.NewPlayer(sID, ws, clt, h.log).Play(r.Context()) + ws.PayloadType = websocket.BinaryFrame + + player, err := player.New(&player.Config{ + Clock: h.clock, + Log: h.log, + SessionID: session.ID(sID), + Streamer: clt, + }) + if err != nil { + h.log.Errorf("couldn't create player for session %v: %v", sID, err) + ws.Write([]byte(`{"message": "error", "errorText": "Internal server error"}`)) + return + } + + defer player.Close() + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + go func() { + defer cancel() + desktop.ReceivePlaybackActions(h.log, ws, player) + }() + + go func() { + defer cancel() + defer ws.Close() + desktop.PlayRecording(ctx, h.log, ws, player) + }() + + <-ctx.Done() }).ServeHTTP(w, r) + return nil, nil }