From 3c9e6bd1f7353593b501bab18ce81fbbe444de5b Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Mon, 30 Oct 2023 21:29:37 -0700 Subject: [PATCH] Introduce a new streaming player API (#31754) This new API can be used to play back sessions of any type. The player accepts a session ID and a streamer, and provides the caller with an API for playback controls (speed, play/pause, seek, etc) as well as a channel that receives events with the proper timing delay applied. The design for this change is discussed in RFD 91. Updates #10578 Updates #10579 Updates gravitational/teleport-private#665 Updates gravitational/teleport-private#1024 --- lib/player/player.go | 321 ++++++++++++++++++++++++++++++++++++++ lib/player/player_test.go | 277 ++++++++++++++++++++++++++++++++ 2 files changed, 598 insertions(+) create mode 100644 lib/player/player.go create mode 100644 lib/player/player_test.go diff --git a/lib/player/player.go b/lib/player/player.go new file mode 100644 index 0000000000000..4be2813f12702 --- /dev/null +++ b/lib/player/player.go @@ -0,0 +1,321 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package player includes an API to play back recorded sessions. 
package player

import (
	"context"
	"errors"
	"math"
	"os"
	"sync/atomic"
	"time"

	"github.com/gravitational/trace"
	"github.com/jonboulle/clockwork"
	"github.com/sirupsen/logrus"

	"github.com/gravitational/teleport/api/types/events"
	"github.com/gravitational/teleport/lib/session"
)

// Player is used to stream recorded sessions over a channel.
//
// Events are pulled from a Streamer and forwarded on the channel
// returned by C after the delay recorded with each event has elapsed.
type Player struct {
	// read only config fields
	clock     clockwork.Clock
	log       logrus.FieldLogger
	sessionID session.ID
	streamer  Streamer

	speed      atomic.Value // playback speed (1.0 for normal speed)
	lastPlayed atomic.Int64 // timestamp of most recently played event

	// advanceTo is used to implement fast-forward and rewind.
	// During normal operation, it is set to [normalPlayback].
	//
	// When set to a positive value the player is seeking forward
	// in time (and plays events as quickly as possible).
	//
	// When set to a negative value, the player needs to "rewind"
	// by starting the stream over from the beginning and then
	// seeking forward to the rewind point.
	advanceTo atomic.Int64

	// emit carries events to the consumer (it is what C returns).
	// done is closed by Close to tell the streaming goroutine to stop.
	emit chan events.AuditEvent
	done chan struct{}

	// playPause holds a channel to be closed when
	// the player transitions from paused to playing,
	// or nil if the player is already playing.
	//
	// This approach mimics a "select-able" condition variable
	// and is inspired by "Rethinking Classical Concurrency Patterns"
	// by Bryan C. Mills (GopherCon 2018): https://www.youtube.com/watch?v=5zXAHh5tJqQ
	playPause chan chan struct{}
}

// normalPlayback is the sentinel stored in Player.advanceTo when no
// seek (fast-forward or rewind) is pending.
const normalPlayback = math.MinInt64

// Streamer is the source of recorded session events for a Player.
//
// The player calls StreamSessionEvents with a start index of 0 and
// selects on both returned channels. A nil event received from the
// event channel is treated as the end of the recording.
type Streamer interface {
	StreamSessionEvents(
		ctx context.Context,
		sessionID session.ID,
		startIndex int64,
	) (chan events.AuditEvent, chan error)
}

// Config configures a session player.
+type Config struct { + Clock clockwork.Clock + Log logrus.FieldLogger + SessionID session.ID + Streamer Streamer +} + +func New(cfg *Config) (*Player, error) { + if cfg.Streamer == nil { + return nil, trace.BadParameter("missing Streamer") + } + + if cfg.SessionID == "" { + return nil, trace.BadParameter("missing SessionID") + } + + clk := cfg.Clock + if clk == nil { + clk = clockwork.NewRealClock() + } + + var log logrus.FieldLogger = cfg.Log + if log == nil { + l := logrus.New().WithField(trace.Component, "player") + l.Logger.SetOutput(os.Stdout) // TODO(zmb3) remove + l.Logger.SetLevel(logrus.DebugLevel) + log = l + } + + p := &Player{ + clock: clk, + log: log, + sessionID: cfg.SessionID, + streamer: cfg.Streamer, + emit: make(chan events.AuditEvent, 64), + playPause: make(chan chan struct{}, 1), + done: make(chan struct{}), + } + + p.speed.Store(float64(defaultPlaybackSpeed)) + p.advanceTo.Store(normalPlayback) + + // start in a paused state + p.playPause <- make(chan struct{}) + + go p.stream() + + return p, nil +} + +// errClosed is an internal error that is used to signal +// that the player has been closed +var errClosed = errors.New("player closed") + +const ( + minPlaybackSpeed = 0.25 + defaultPlaybackSpeed = 1.0 + maxPlaybackSpeed = 16 +) + +func (p *Player) SetSpeed(s float64) error { + if s < minPlaybackSpeed || s > maxPlaybackSpeed { + return trace.BadParameter("speed %v is out of range", s) + } + p.speed.Store(s) + return nil +} + +func (p *Player) stream() { + // TODO(zmb3): consider using context instead of close chan + eventsC, errC := p.streamer.StreamSessionEvents(context.TODO(), p.sessionID, 0) + lastDelay := int64(0) + for { + select { + case <-p.done: + close(p.emit) + return + case err := <-errC: + // TODO(zmb3): figure out how to surface the error + // (probably close the chan and expose a method) + p.log.Warn(err) + return + case evt := <-eventsC: + if evt == nil { + p.log.Debug("reached end of playback") + close(p.emit) + return + } + + 
if err := p.waitWhilePaused(); err != nil { + p.log.Warn(err) + close(p.emit) + return + } + + currentDelay := getDelay(evt) + if currentDelay > 0 && currentDelay > lastDelay { + switch adv := p.advanceTo.Load(); { + case adv >= currentDelay: + // no timing delay necessary, we are fast forwarding + break + case adv < 0 && adv != normalPlayback: + // any negative value other than normalPlayback means + // we rewind (by restarting the stream and seeking forward + // to the rewind point) + p.advanceTo.Store(adv * -1) + go p.stream() + return + default: + if adv != normalPlayback { + p.advanceTo.Store(normalPlayback) + + // we're catching back up to real time, so the delay + // is calculated not from the last event but from the + // time we were advanced to + lastDelay = adv + } + if err := p.applyDelay(time.Duration(currentDelay-lastDelay) * time.Millisecond); err != nil { + close(p.emit) + return + } + } + + lastDelay = currentDelay + } + + p.log.Debugf("playing %v (%v)", evt.GetType(), evt.GetID()) + select { + case p.emit <- evt: + p.lastPlayed.Store(currentDelay) + default: + p.log.Warnf("dropped event %v, reader too slow", evt.GetID()) + } + } + } +} + +// Close shuts down the player and cancels any streams that are +// in progress. +func (p *Player) Close() error { + close(p.done) + return nil +} + +// C returns a read only channel of recorded session events. +// The player manages the timing of events and writes them to the channel +// when they should be rendered. The channel is closed when the player +// has reached the end of playback. +func (p *Player) C() <-chan events.AuditEvent { + return p.emit +} + +// TODO(zmb3): add an Err() method to be checked after C is closed + +// Pause temporarily stops the player from emitting events. +// It is a no-op if playback is currently paused. +func (p *Player) Pause() error { + p.setPlaying(false) + return nil +} + +// Play starts emitting events. 
It is used to start playback +// for the first time and to resume playing after the player +// is paused. +func (p *Player) Play() error { + p.setPlaying(true) + return nil +} + +// SetPos sets playback to a specific time, expressed as a duration +// from the beginning of the session. A duration of 0 restarts playback +// from the beginning. A duration greater than the length of the session +// will cause playback to rapidly advance to the end of the recording. +func (p *Player) SetPos(d time.Duration) error { + if d.Milliseconds() < p.lastPlayed.Load() { + // if we're rewinding we store a negative value + d = -1 * d + } + p.advanceTo.Store(d.Milliseconds()) + return nil +} + +// applyDelay "sleeps" for d in a manner that +// can be canceled +func (p *Player) applyDelay(d time.Duration) error { + p.log.Debugf("waiting %v until next event", d) + scaled := float64(d) / p.speed.Load().(float64) + select { + case <-p.done: + return errClosed + case <-p.clock.After(time.Duration(scaled)): + return nil + } +} + +func (p *Player) setPlaying(play bool) { + ch := <-p.playPause + alreadyPlaying := ch == nil + + if alreadyPlaying && !play { + ch = make(chan struct{}) + } else if !alreadyPlaying && play { + // signal waiters who are paused that it's time to resume playing + close(ch) + ch = nil + } + + p.playPause <- ch +} + +// waitWhilePaused blocks while the player is in a paused state. +// It returns immediately if the player is currently playing. +func (p *Player) waitWhilePaused() error { + ch := <-p.playPause + p.playPause <- ch + + if alreadyPlaying := ch == nil; !alreadyPlaying { + select { + case <-p.done: + return errClosed + case <-ch: + } + } + return nil +} + +// LastPlayed returns the time of the last played event, +// expressed as milliseconds since the start of the session. 
+func (p *Player) LastPlayed() int64 { + return p.lastPlayed.Load() +} + +// getDelay returns the delay of e in milliseconds: the DelayMilliseconds +// field of DesktopRecording and SessionPrint events, and 0 for every +// other event type. +func getDelay(e events.AuditEvent) int64 { + switch x := e.(type) { + case *events.DesktopRecording: + return x.DelayMilliseconds + case *events.SessionPrint: + return x.DelayMilliseconds + default: + return int64(0) + } +} diff --git a/lib/player/player_test.go b/lib/player/player_test.go new file mode 100644 index 0000000000000..0a0aa01588430 --- /dev/null +++ b/lib/player/player_test.go @@ -0,0 +1,277 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package player_test + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/player" + "github.com/gravitational/teleport/lib/session" +) + +func TestBasicStream(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3}, + }) + require.NoError(t, err) + + require.NoError(t, p.Play()) + + count := 0 + for range p.C() { + count++ + } + + require.Equal(t, 3, count) +} + +func TestPlayPause(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3}, + }) + require.NoError(t, err) + + // pausing an already paused player should be a no-op + require.NoError(t, p.Pause()) + require.NoError(t, p.Pause()) + + // toggling back and forth between play and pause + // should not impact our ability to receive all + // 3 events + require.NoError(t, p.Play()) + require.NoError(t, p.Pause()) + require.NoError(t, p.Play()) + + count := 0 + for range p.C() { + count++ + } + + require.Equal(t, 3, count) +} + +func TestAppliesTiming(t *testing.T) { + for _, test := range []struct { + desc string + speed float64 + advance time.Duration + }{ + { + desc: "half speed", + speed: 0.5, + advance: 2000 * time.Millisecond, + }, + { + desc: "normal speed", + speed: 1.0, + advance: 1000 * time.Millisecond, + }, + { + desc: "double speed", + speed: 2.0, + advance: 500 * time.Millisecond, + }, + } { + t.Run(test.desc, func(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3, delay: 1000}, + }) + require.NoError(t, err) + + 
require.NoError(t, p.SetSpeed(test.speed)) + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance to next event (player will have emitted event 0 + // and will be waiting to emit event 1) + clk.Advance(test.advance) + clk.BlockUntil(1) + evt := <-p.C() + require.Equal(t, int64(0), evt.GetIndex()) + + // repeat the process (emit event 1, wait for event 2) + clk.Advance(test.advance) + clk.BlockUntil(1) + evt = <-p.C() + require.Equal(t, int64(1), evt.GetIndex()) + + // advance the player to allow event 2 to be emitted + clk.Advance(test.advance) + evt = <-p.C() + require.Equal(t, int64(2), evt.GetIndex()) + + // channel should be closed + _, ok := <-p.C() + require.False(t, ok, "player should be closed") + }) + } +} + +func TestClose(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 2, delay: 1000}, + }) + require.NoError(t, err) + + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance to next event (player will have emitted event 0 + // and will be waiting to emit event 1) + clk.Advance(1001 * time.Millisecond) + clk.BlockUntil(1) + evt := <-p.C() + require.Equal(t, int64(0), evt.GetIndex()) + + require.NoError(t, p.Close()) + + // channel should have been closed + _, ok := <-p.C() + require.False(t, ok, "player channel should have been closed") +} + +func TestSeekForward(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 10, delay: 1000}, + }) + require.NoError(t, err) + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance playback until right before the last event + p.SetPos(9001 * time.Millisecond) + + // advance the clock to unblock the player + // (it should now spit out 
all but the last event in rapid succession) + clk.Advance(1001 * time.Millisecond) + + ch := make(chan struct{}) + go func() { + defer close(ch) + for evt := range p.C() { + t.Logf("got event %v (delay=%v)", evt.GetID(), evt.GetCode()) + } + }() + + clk.BlockUntil(1) + require.Equal(t, int64(9000), p.LastPlayed()) + + clk.Advance(999 * time.Millisecond) + select { + case <-ch: + case <-time.After(3 * time.Second): + require.FailNow(t, "player hasn't closed in time") + } +} + +func TestRewind(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 10, delay: 1000}, + }) + require.NoError(t, err) + require.NoError(t, p.Play()) + + // play through 7 events at regular speed + for i := 0; i < 7; i++ { + clk.BlockUntil(1) // player is now waiting to emit event + clk.Advance(1000 * time.Millisecond) // unblock event + <-p.C() // read event + } + + // now "rewind" to the point just prior to event index 3 (4000 ms into session) + clk.BlockUntil(1) + p.SetPos(3900 * time.Millisecond) + + // when we advance the clock, we expect the following behavior: + // - event index 7 (which we were blocked on) comes out right away + // - playback restarts, events 0 through 2 are emitted immediately + // - event index 3 is emitted after another 100ms + clk.Advance(1000 * time.Millisecond) + require.Equal(t, int64(7), (<-p.C()).GetIndex()) + require.Equal(t, int64(0), (<-p.C()).GetIndex(), "expected playback to retart for rewind") + require.Equal(t, int64(1), (<-p.C()).GetIndex(), "expected rapid streaming up to rewind point") + require.Equal(t, int64(2), (<-p.C()).GetIndex()) + clk.BlockUntil(1) + clk.Advance(100 * time.Millisecond) + require.Equal(t, int64(3), (<-p.C()).GetIndex()) + + p.Close() +} + +// simpleStreamer streams a fake session that contains +// count events, emitted at a particular interval +type simpleStreamer struct { + count int64 + delay int64 // milliseconds 
+} + +func (s *simpleStreamer) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { + errors := make(chan error, 1) + evts := make(chan apievents.AuditEvent) + + go func() { + defer close(evts) + + for i := int64(0); i < s.count; i++ { + select { + case <-ctx.Done(): + return + case evts <- &apievents.SessionPrint{ + Metadata: apievents.Metadata{ + Type: events.SessionPrintEvent, + Index: i, + ID: strconv.Itoa(int(i)), + Code: strconv.FormatInt((i+1)*s.delay, 10), + }, + Data: []byte(fmt.Sprintf("event %d\n", i)), + ChunkIndex: i, // TODO(zmb3) deprecate this + DelayMilliseconds: (i + 1) * s.delay, + }: + } + } + }() + + return evts, errors +}