From d01ad4aa01b2170c83d94553b80cfd3b23dfd0ef Mon Sep 17 00:00:00 2001 From: joerger Date: Fri, 29 Apr 2022 15:57:34 -0700 Subject: [PATCH] Fixes and cleanup. --- lib/events/complete.go | 15 +-- lib/srv/app/session.go | 30 +++-- lib/srv/db/common/auth.go | 2 +- lib/srv/db/common/engines.go | 2 +- lib/srv/db/common/session.go | 58 --------- lib/srv/db/mongodb/engine.go | 8 -- lib/srv/db/mysql/engine.go | 8 -- lib/srv/db/postgres/engine.go | 8 -- lib/srv/db/proxyserver.go | 2 +- lib/srv/db/redis/engine.go | 8 -- lib/srv/db/server.go | 65 +++++++++- lib/srv/db/sqlserver/engine.go | 8 -- lib/srv/db/tracker_test.go | 198 +++++++++++++++--------------- lib/srv/desktop/windows_server.go | 36 +++--- 14 files changed, 200 insertions(+), 248 deletions(-) diff --git a/lib/events/complete.go b/lib/events/complete.go index 3ab981c01c883..9759e7cdbc120 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -129,16 +129,6 @@ func (u *UploadCompleter) start(ctx context.Context) { // checkUploads fetches uploads and completes any abandoned uploads func (u *UploadCompleter) checkUploads(ctx context.Context) error { - trackers, err := u.cfg.SessionTracker.GetActiveSessionTrackers(ctx) - if err != nil { - return trace.Wrap(err) - } - - var activeSessionIDs []string - for _, st := range trackers { - activeSessionIDs = append(activeSessionIDs, st.GetSessionID()) - } - uploads, err := u.cfg.Uploader.ListUploads(ctx) if err != nil { return trace.Wrap(err) @@ -153,8 +143,11 @@ func (u *UploadCompleter) checkUploads(ctx context.Context) error { // Complete upload for any uploads without an active session tracker for _, upload := range uploads { - if apiutils.SliceContainsStr(activeSessionIDs, upload.SessionID.String()) { + _, err := u.cfg.SessionTracker.GetSessionTracker(ctx, upload.SessionID.String()) + if err == nil { continue + } else if !trace.IsNotFound(err) { + return trace.Wrap(err) } parts, err := u.cfg.Uploader.ListParts(ctx, upload) diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index d0d17726f8f89..5106f3b50b89a 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -53,8 +53,17 @@ type session struct { // newSession creates a new session. func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app types.Application) (*session, error) { + sess := &session{id: identity.RouteToApp.SessionID} + + // Create a session tracker so that other services, such as + // the session upload completer, can track the session's lifetime. + err := s.trackSession(sess, identity) + if err != nil { + return nil, trace.Wrap(err) + } + // Create the stream writer that will write this chunk to the audit log. - streamWriter, err := s.newStreamWriter(identity, app) + sess.streamWriter, err = s.newStreamWriter(identity, app) if err != nil { return nil, trace.Wrap(err) } @@ -73,7 +82,7 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t // Create a rewriting transport that will be used to forward requests. transport, err := newTransport(s.closeContext, &transportConfig{ - w: streamWriter, + w: sess.streamWriter, app: app, publicPort: s.proxyPort, cipherSuites: s.c.CipherSuites, @@ -85,7 +94,8 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t if err != nil { return nil, trace.Wrap(err) } - fwd, err := forward.New( + + sess.fwd, err = forward.New( forward.FlushInterval(100*time.Millisecond), forward.RoundTripper(transport), forward.Logger(logrus.StandardLogger()), @@ -96,21 +106,9 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t return nil, trace.Wrap(err) } - sess := &session{ - id: identity.RouteToApp.SessionID, - fwd: fwd, - streamWriter: streamWriter, - } - - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - if err := s.trackSession(sess, identity); err != nil { - return nil, trace.Wrap(err) - } - // Put the session in the cache so the next request can use it for 5 minutes // or the time until the certificate expires, whichever comes first. - ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute) + ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Second) err = s.cache.set(sess, ttl) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 837670237f6bc..544839ebbfad0 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -70,7 +70,7 @@ type Auth interface { // AuthConfig is the database access authenticator configuration. type AuthConfig struct { // AuthClient is the cluster auth client. - AuthClient *libauth.Client + AuthClient libauth.ClientI // Clients provides interface for obtaining cloud provider clients. Clients CloudClients // Clock is the clock implementation. diff --git a/lib/srv/db/common/engines.go b/lib/srv/db/common/engines.go index 25b773677ecce..107565d698a18 100644 --- a/lib/srv/db/common/engines.go +++ b/lib/srv/db/common/engines.go @@ -70,7 +70,7 @@ type EngineConfig struct { // Audit emits database access audit events. Audit Audit // AuthClient is the cluster auth server client. - AuthClient *auth.Client + AuthClient auth.ClientI // CloudClients provides access to cloud API clients. CloudClients CloudClients // Context is the database server close context. diff --git a/lib/srv/db/common/session.go b/lib/srv/db/common/session.go index 4ac2a0a22aa4d..8f214471865b7 100644 --- a/lib/srv/db/common/session.go +++ b/lib/srv/db/common/session.go @@ -17,14 +17,11 @@ limitations under the License. package common import ( - "context" "fmt" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/trace" "github.com/sirupsen/logrus" ) @@ -60,58 +57,3 @@ func (c *Session) String() string { return fmt.Sprintf("db[%v] identity[%v] dbUser[%v] dbName[%v]", c.Database.GetName(), c.Identity.Username, c.DatabaseUser, c.DatabaseName) } - -// TrackSession creates a new session tracker for the database session. -// While ctx is open, the session tracker's expiration will be extended -// on an interval. Once the ctx is closed, the sessiont tracker's state -// will be updated to terminated. -func (c *Session) TrackSession(ctx context.Context, engineCfg EngineConfig) error { - engineCfg.Log.Debug("Creating session tracker") - initiator := &types.Participant{ - ID: c.DatabaseUser, - User: c.Identity.Username, - } - - tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ - SessionID: c.ID, - Kind: string(types.DatabaseSessionKind), - State: types.SessionState_SessionStateRunning, - Hostname: c.HostID, - DatabaseName: c.DatabaseName, - ClusterName: c.ClusterName, - Login: "root", - Participants: []types.Participant{*initiator}, - HostUser: initiator.User, - }) - if err != nil { - return trace.Wrap(err) - } - - err = engineCfg.AuthClient.UpsertSessionTracker(ctx, tracker) - if err != nil { - return trace.Wrap(err) - } - - // Start go routine to push back session expiration until ctx is canceled (session ends). - go func() { - ticker := engineCfg.Clock.NewTicker(defaults.SessionTrackerExpirationUpdateInterval) - defer ticker.Stop() - for { - select { - case time := <-ticker.Chan(): - err := services.UpdateSessionTrackerExpiry(ctx, engineCfg.AuthClient, c.ID, time.Add(defaults.SessionTrackerTTL)) - if err != nil { - engineCfg.Log.WithError(err).Warningf("Failed to update session tracker expiration for session %v.", c.ID) - return - } - case <-ctx.Done(): - if err := services.UpdateSessionTrackerState(engineCfg.Context, engineCfg.AuthClient, c.ID, types.SessionState_SessionStateTerminated); err != nil { - engineCfg.Log.WithError(err).Warningf("Failed to update session tracker state for session %v.", c.ID) - } - return - } - } - }() - - return nil -} diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 94d5d717384db..2d4fa1708162d 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -88,14 +88,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil { - return trace.Wrap(err) - } - // Start reading client messages and sending them to server. for { clientMessage, err := protocol.ReadMessage(e.clientConn) diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 0b2a7046fd68b..c27273604ca96 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -118,14 +118,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil { - return trace.Wrap(err) - } - // Copy between the connections. clientErrCh := make(chan error, 1) serverErrCh := make(chan error, 1) diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 6cc9896680514..2d72463cdc9e0 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -130,14 +130,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil { - return trace.Wrap(err) - } - // Reconstruct pgconn.PgConn from hijacked connection for easier access // to its utility methods (such as Close). serverConn, err := pgconn.Construct(hijackedConn) diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 5870e584b907a..1fe356fb7cd25 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -518,7 +518,7 @@ type monitorConnConfig struct { identity tlsca.Identity clock clockwork.Clock serverID string - authClient *auth.Client + authClient auth.ClientI teleportUser string emitter events.Emitter log logrus.FieldLogger diff --git a/lib/srv/db/redis/engine.go b/lib/srv/db/redis/engine.go index 10fb02774c2a9..37a8fc5383d6a 100644 --- a/lib/srv/db/redis/engine.go +++ b/lib/srv/db/redis/engine.go @@ -167,14 +167,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil { - return trace.Wrap(err) - } - if err := e.process(ctx); err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 48a6ca10e0d16..3d74eba9f09c9 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -58,7 +58,7 @@ type Config struct { // DataDir is the path to the data directory for the server. DataDir string // AuthClient is a client directly connected to the Auth server. - AuthClient *auth.Client + AuthClient auth.ClientI // AccessPoint is a caching client connected to the Auth Server. AccessPoint auth.DatabaseAccessPoint // StreamEmitter is a non-blocking audit events emitter. @@ -674,6 +674,14 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro return trace.Wrap(err) } + // Create a session tracker so that other services, such as + // the session upload completer, can track the session's lifetime. + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + if err := s.trackSession(cancelCtx, sessionCtx); err != nil { + return trace.Wrap(err) + } + streamWriter, err := s.newStreamWriter(sessionCtx) if err != nil { return trace.Wrap(err) @@ -858,3 +866,58 @@ func fetchMySQLVersion(ctx context.Context, database types.Database) error { return nil } + +// trackSession creates a new session tracker for the database session. +// While ctx is open, the session tracker's expiration will be extended +// on an interval. Once the ctx is closed, the sessiont tracker's state +// will be updated to terminated. +func (s *Server) trackSession(ctx context.Context, sessionCtx *common.Session) error { + s.log.Debug("Creating session tracker") + initiator := &types.Participant{ + ID: sessionCtx.DatabaseUser, + User: sessionCtx.Identity.Username, + } + + tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: sessionCtx.ID, + Kind: string(types.DatabaseSessionKind), + State: types.SessionState_SessionStateRunning, + Hostname: sessionCtx.HostID, + DatabaseName: sessionCtx.DatabaseName, + ClusterName: sessionCtx.ClusterName, + Login: "root", + Participants: []types.Participant{*initiator}, + HostUser: initiator.User, + }) + if err != nil { + return trace.Wrap(err) + } + + err = s.cfg.AuthClient.UpsertSessionTracker(ctx, tracker) + if err != nil { + return trace.Wrap(err) + } + + // Start go routine to push back session expiration until ctx is canceled (session ends). + go func() { + ticker := s.cfg.Clock.NewTicker(defaults.SessionTrackerExpirationUpdateInterval) + defer ticker.Stop() + for { + select { + case time := <-ticker.Chan(): + err := services.UpdateSessionTrackerExpiry(ctx, s.cfg.AuthClient, sessionCtx.ID, time.Add(defaults.SessionTrackerTTL)) + if err != nil { + s.log.WithError(err).Warningf("Failed to update session tracker expiration for session %v.", sessionCtx.ID) + return + } + case <-ctx.Done(): + if err := services.UpdateSessionTrackerState(s.closeContext, s.cfg.AuthClient, sessionCtx.ID, types.SessionState_SessionStateTerminated); err != nil { + s.log.WithError(err).Warningf("Failed to update session tracker state for session %v.", sessionCtx.ID) + } + return + } + } + }() + + return nil +} diff --git a/lib/srv/db/sqlserver/engine.go b/lib/srv/db/sqlserver/engine.go index d7dcd358d2895..554cc7228e8ff 100644 --- a/lib/srv/db/sqlserver/engine.go +++ b/lib/srv/db/sqlserver/engine.go @@ -93,14 +93,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio } defer serverConn.Close() - // Create a session tracker so that other services, such as - // the session upload completer, can track the session's lifetime. - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil { - return trace.Wrap(err) - } - // Pass all flags returned by server during login back to the client. err = protocol.WriteStreamResponse(e.clientConn, serverFlags) if err != nil { diff --git a/lib/srv/db/tracker_test.go b/lib/srv/db/tracker_test.go index 7d524505be6fb..fa80e20b5dc38 100644 --- a/lib/srv/db/tracker_test.go +++ b/lib/srv/db/tracker_test.go @@ -18,119 +18,113 @@ package db import ( "context" + "io" "testing" "time" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) -// TestSessiontTracker tests session tracker lifecycle for database sessions. func TestSessionTracker(t *testing.T) { ctx := context.Background() + log := logrus.New() + log.SetOutput(io.Discard) + clock := clockwork.NewFakeClockAt(time.Now()) - for _, tc := range []struct { - desc string - withDatabase withDatabaseOption - openFunc func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) - }{ - { - desc: "postgres", - withDatabase: withSelfHostedPostgres("postgres"), - openFunc: func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) { - psql, err := testCtx.postgresClient(ctx, "alice", "postgres", "admin", "admin") - require.NoError(t, err) - return func(t *testing.T) { - require.NoError(t, psql.Close(ctx)) - } - }, - }, { - desc: "mysql", - withDatabase: withSelfHostedMySQL("mysql"), - openFunc: func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) { - mysql, err := testCtx.mysqlClient("alice", "mysql", "admin") - require.NoError(t, err) - return func(t *testing.T) { - require.NoError(t, mysql.Close()) - } - }, - }, { - desc: "mongo", - withDatabase: withSelfHostedMongo("mongo"), - openFunc: func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) { - mongoClient, err := testCtx.mongoClient(ctx, "alice", "mongo", "admin") - require.NoError(t, err) - return func(t *testing.T) { - require.NoError(t, mongoClient.Disconnect(ctx)) - } - }, - }, { - desc: "redis", - withDatabase: withSelfHostedRedis("redis"), - openFunc: func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) { - redisClient, err := testCtx.redisClient(ctx, "alice", "redis", "admin") - require.NoError(t, err) - return func(t *testing.T) { - require.NoError(t, redisClient.Close()) - } - }, - }, { - desc: "sqlserver", - withDatabase: withSQLServer("sqlserver"), - openFunc: func(t *testing.T, testCtx *testContext) (closeFunc func(t *testing.T)) { - conn, proxy, err := testCtx.sqlServerClient(ctx, "alice", "sqlserver", "admin", "master") - require.NoError(t, err) - return func(t *testing.T) { - require.NoError(t, conn.Close()) - require.NoError(t, proxy.Close()) - } - }, + mockAuthClient := &mockSessiontrackerService{ + clock: clock, + trackers: make(map[string]types.SessionTracker), + } + + s := &Server{ + closeContext: ctx, + log: logrus.NewEntry(log), + cfg: Config{ + Clock: clock, + AuthClient: mockAuthClient, }, - } { - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - testCtx := setupTestContext(ctx, t, tc.withDatabase) - go testCtx.startHandlingConnections() - testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"admin"}, []string{"admin"}) - - // Session tracker should be created for new connection - closeFunc := tc.openFunc(t, testCtx) - - var tracker types.SessionTracker - trackerCreated := func() bool { - trackers, err := testCtx.authClient.GetActiveSessionTrackers(ctx) - require.NoError(t, err) - // Note: mongo test creates 3 sessions (unrelated bug?), we just test the first one. - if len(trackers) > 0 { - tracker = trackers[0] - require.Equal(t, types.SessionState_SessionStateTerminated, tracker.GetState()) - return true - } - return false - } - require.Eventually(t, trackerCreated, time.Second*15, time.Second) - require.Equal(t, types.SessionState_SessionStateRunning, tracker.GetState()) - - // The session tracker expiration should be extended while the session is active - testCtx.clock.Advance(defaults.SessionTrackerExpirationUpdateInterval) - trackerUpdated := func() bool { - updatedTracker, err := testCtx.authClient.GetSessionTracker(ctx, tracker.GetSessionID()) - require.NoError(t, err) - return updatedTracker.Expiry().Equal(tracker.Expiry().Add(defaults.SessionTrackerExpirationUpdateInterval)) - } - require.Eventually(t, trackerUpdated, time.Second*15, time.Second) - - // Closing connection should trigger session tracker state to be terminated. - closeFunc(t) - - trackerTerminated := func() bool { - tracker, err := testCtx.authClient.GetSessionTracker(ctx, tracker.GetSessionID()) - require.NoError(t, err) - return tracker.GetState() == types.SessionState_SessionStateTerminated - } - require.Eventually(t, trackerTerminated, time.Second*15, time.Second) - }) } + + sessionCtx := &common.Session{ + ID: "sessionID", + DatabaseUser: "user", + Identity: tlsca.Identity{ + Username: "teleportUser", + }, + HostID: "hostname", + DatabaseName: "dbName", + ClusterName: "clusterName", + } + + cancelCtx, cancel := context.WithCancel(ctx) + err := s.trackSession(cancelCtx, sessionCtx) + require.NoError(t, err) + + // Tracker should be created + tracker, ok := mockAuthClient.trackers["sessionID"] + require.True(t, ok) + require.Equal(t, types.SessionState_SessionStateRunning, tracker.GetState()) + + // The session tracker expiration should be extended while the session is active + clock.BlockUntil(1) + expectedExpiry := tracker.Expiry().Add(defaults.SessionTrackerExpirationUpdateInterval) + clock.Advance(defaults.SessionTrackerExpirationUpdateInterval) + + trackerExpiryUpdated := func() bool { + return tracker.Expiry() == expectedExpiry + } + require.Eventually(t, trackerExpiryUpdated, time.Second*5, time.Second) + + // Closing ctx should trigger session tracker state to be terminated. + cancel() + trackerTerminated := func() bool { + return tracker.GetState() == types.SessionState_SessionStateTerminated + } + require.Eventually(t, trackerTerminated, time.Second*5, time.Second) +} + +type mockSessiontrackerService struct { + auth.ClientI + clock clockwork.Clock + trackers map[string]types.SessionTracker +} + +func (m *mockSessiontrackerService) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionTracker, error) { + return nil, nil +} + +func (m *mockSessiontrackerService) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { + return nil, nil +} + +func (m *mockSessiontrackerService) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error { + switch update := req.Update.(type) { + case *proto.UpdateSessionTrackerRequest_UpdateExpiry: + m.trackers[req.SessionID].SetExpiry(*update.UpdateExpiry.Expires) + case *proto.UpdateSessionTrackerRequest_UpdateState: + m.trackers[req.SessionID].SetState(update.UpdateState.State) + } + return nil +} + +func (m *mockSessiontrackerService) RemoveSessionTracker(ctx context.Context, sessionID string) error { + return nil +} + +func (m *mockSessiontrackerService) UpdatePresence(ctx context.Context, sessionID, user string) error { + return nil +} + +func (m *mockSessiontrackerService) UpsertSessionTracker(ctx context.Context, tracker types.SessionTracker) error { + tracker.SetExpiry(m.clock.Now().Add(defaults.SessionTrackerTTL)) + m.trackers[tracker.GetSessionID()] = tracker + return nil } diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 417e3b1ee3067..22aa54500d32a 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -777,23 +777,6 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, log.Infof("desktop session %v will not be recorded, user %v's roles disable recording", string(sessionID), authCtx.User.GetName()) } - sw, err := s.newStreamWriter(recordSession, string(sessionID)) - if err != nil { - return trace.Wrap(err) - } - - // Closing the stream writer is needed to flush all recorded data - // and trigger the upload. Do it in a goroutine since depending on - // the session size it can take a while, and we don't want to block - // the client. - defer func() { - go func() { - if err := sw.Close(context.Background()); err != nil { - log.WithError(err).Errorf("closing stream writer for desktop session %v", sessionID.String()) - } - }() - }() - var windowsUser string authorize := func(login string) error { windowsUser = login // capture attempted login user @@ -814,10 +797,29 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, ctx, cancel := context.WithCancel(ctx) defer cancel() + // Create a session tracker so that other services, such as + // the session upload completer, can track the session's lifetime. if err := s.trackSession(ctx, &identity, windowsUser, string(sessionID), desktop); err != nil { return trace.Wrap(err) } + sw, err := s.newStreamWriter(recordSession, string(sessionID)) + if err != nil { + return trace.Wrap(err) + } + + // Closing the stream writer is needed to flush all recorded data + // and trigger the upload. Do it in a goroutine since depending on + // the session size it can take a while, and we don't want to block + // the client. + defer func() { + go func() { + if err := sw.Close(context.Background()); err != nil { + log.WithError(err).Errorf("closing stream writer for desktop session %v", sessionID.String()) + } + }() + }() + delay := timer() tdpConn.OnSend = s.makeTDPSendHandler(ctx, sw, delay, &identity, string(sessionID), desktop.GetAddr()) tdpConn.OnRecv = s.makeTDPReceiveHandler(ctx, sw, delay, &identity, string(sessionID), desktop.GetAddr())