Skip to content

Commit

Permalink
Fixes and cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Apr 30, 2022
1 parent 43e06ee commit d01ad4a
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 248 deletions.
15 changes: 4 additions & 11 deletions lib/events/complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,6 @@ func (u *UploadCompleter) start(ctx context.Context) {

// checkUploads fetches uploads and completes any abandoned uploads
func (u *UploadCompleter) checkUploads(ctx context.Context) error {
trackers, err := u.cfg.SessionTracker.GetActiveSessionTrackers(ctx)
if err != nil {
return trace.Wrap(err)
}

var activeSessionIDs []string
for _, st := range trackers {
activeSessionIDs = append(activeSessionIDs, st.GetSessionID())
}

uploads, err := u.cfg.Uploader.ListUploads(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -153,8 +143,11 @@ func (u *UploadCompleter) checkUploads(ctx context.Context) error {

// Complete upload for any uploads without an active session tracker
for _, upload := range uploads {
if apiutils.SliceContainsStr(activeSessionIDs, upload.SessionID.String()) {
_, err := u.cfg.SessionTracker.GetSessionTracker(ctx, upload.SessionID.String())
if err == nil {
continue
} else if !trace.IsNotFound(err) {
return trace.Wrap(err)
}

parts, err := u.cfg.Uploader.ListParts(ctx, upload)
Expand Down
30 changes: 14 additions & 16 deletions lib/srv/app/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,17 @@ type session struct {

// newSession creates a new session.
func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app types.Application) (*session, error) {
sess := &session{id: identity.RouteToApp.SessionID}

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
err := s.trackSession(sess, identity)
if err != nil {
return nil, trace.Wrap(err)
}

// Create the stream writer that will write this chunk to the audit log.
streamWriter, err := s.newStreamWriter(identity, app)
sess.streamWriter, err = s.newStreamWriter(identity, app)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -73,7 +82,7 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t
// Create a rewriting transport that will be used to forward requests.
transport, err := newTransport(s.closeContext,
&transportConfig{
w: streamWriter,
w: sess.streamWriter,
app: app,
publicPort: s.proxyPort,
cipherSuites: s.c.CipherSuites,
Expand All @@ -85,7 +94,8 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t
if err != nil {
return nil, trace.Wrap(err)
}
fwd, err := forward.New(

sess.fwd, err = forward.New(
forward.FlushInterval(100*time.Millisecond),
forward.RoundTripper(transport),
forward.Logger(logrus.StandardLogger()),
Expand All @@ -96,21 +106,9 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t
return nil, trace.Wrap(err)
}

sess := &session{
id: identity.RouteToApp.SessionID,
fwd: fwd,
streamWriter: streamWriter,
}

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
if err := s.trackSession(sess, identity); err != nil {
return nil, trace.Wrap(err)
}

// Put the session in the cache so the next request can use it for 5 minutes
// or the time until the certificate expires, whichever comes first.
ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute)
ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Second)
err = s.cache.set(sess, ttl)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/common/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type Auth interface {
// AuthConfig is the database access authenticator configuration.
type AuthConfig struct {
// AuthClient is the cluster auth client.
AuthClient *libauth.Client
AuthClient libauth.ClientI
// Clients provides interface for obtaining cloud provider clients.
Clients CloudClients
// Clock is the clock implementation.
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/common/engines.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type EngineConfig struct {
// Audit emits database access audit events.
Audit Audit
// AuthClient is the cluster auth server client.
AuthClient *auth.Client
AuthClient auth.ClientI
// CloudClients provides access to cloud API clients.
CloudClients CloudClients
// Context is the database server close context.
Expand Down
58 changes: 0 additions & 58 deletions lib/srv/db/common/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ limitations under the License.
package common

import (
"context"
"fmt"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/trace"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -60,58 +57,3 @@ func (c *Session) String() string {
return fmt.Sprintf("db[%v] identity[%v] dbUser[%v] dbName[%v]",
c.Database.GetName(), c.Identity.Username, c.DatabaseUser, c.DatabaseName)
}

// TrackSession creates a new session tracker for the database session.
// While ctx is open, the session tracker's expiration will be extended
// on an interval. Once the ctx is closed, the sessiont tracker's state
// will be updated to terminated.
func (c *Session) TrackSession(ctx context.Context, engineCfg EngineConfig) error {
engineCfg.Log.Debug("Creating session tracker")
initiator := &types.Participant{
ID: c.DatabaseUser,
User: c.Identity.Username,
}

tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{
SessionID: c.ID,
Kind: string(types.DatabaseSessionKind),
State: types.SessionState_SessionStateRunning,
Hostname: c.HostID,
DatabaseName: c.DatabaseName,
ClusterName: c.ClusterName,
Login: "root",
Participants: []types.Participant{*initiator},
HostUser: initiator.User,
})
if err != nil {
return trace.Wrap(err)
}

err = engineCfg.AuthClient.UpsertSessionTracker(ctx, tracker)
if err != nil {
return trace.Wrap(err)
}

// Start go routine to push back session expiration until ctx is canceled (session ends).
go func() {
ticker := engineCfg.Clock.NewTicker(defaults.SessionTrackerExpirationUpdateInterval)
defer ticker.Stop()
for {
select {
case time := <-ticker.Chan():
err := services.UpdateSessionTrackerExpiry(ctx, engineCfg.AuthClient, c.ID, time.Add(defaults.SessionTrackerTTL))
if err != nil {
engineCfg.Log.WithError(err).Warningf("Failed to update session tracker expiration for session %v.", c.ID)
return
}
case <-ctx.Done():
if err := services.UpdateSessionTrackerState(engineCfg.Context, engineCfg.AuthClient, c.ID, types.SessionState_SessionStateTerminated); err != nil {
engineCfg.Log.WithError(err).Warningf("Failed to update session tracker state for session %v.", c.ID)
}
return
}
}
}()

return nil
}
8 changes: 0 additions & 8 deletions lib/srv/db/mongodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil {
return trace.Wrap(err)
}

// Start reading client messages and sending them to server.
for {
clientMessage, err := protocol.ReadMessage(e.clientConn)
Expand Down
8 changes: 0 additions & 8 deletions lib/srv/db/mysql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil {
return trace.Wrap(err)
}

// Copy between the connections.
clientErrCh := make(chan error, 1)
serverErrCh := make(chan error, 1)
Expand Down
8 changes: 0 additions & 8 deletions lib/srv/db/postgres/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil {
return trace.Wrap(err)
}

// Reconstruct pgconn.PgConn from hijacked connection for easier access
// to its utility methods (such as Close).
serverConn, err := pgconn.Construct(hijackedConn)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ type monitorConnConfig struct {
identity tlsca.Identity
clock clockwork.Clock
serverID string
authClient *auth.Client
authClient auth.ClientI
teleportUser string
emitter events.Emitter
log logrus.FieldLogger
Expand Down
8 changes: 0 additions & 8 deletions lib/srv/db/redis/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil {
return trace.Wrap(err)
}

if err := e.process(ctx); err != nil {
return trace.Wrap(err)
}
Expand Down
65 changes: 64 additions & 1 deletion lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Config struct {
// DataDir is the path to the data directory for the server.
DataDir string
// AuthClient is a client directly connected to the Auth server.
AuthClient *auth.Client
AuthClient auth.ClientI
// AccessPoint is a caching client connected to the Auth Server.
AccessPoint auth.DatabaseAccessPoint
// StreamEmitter is a non-blocking audit events emitter.
Expand Down Expand Up @@ -674,6 +674,14 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro
return trace.Wrap(err)
}

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := s.trackSession(cancelCtx, sessionCtx); err != nil {
return trace.Wrap(err)
}

streamWriter, err := s.newStreamWriter(sessionCtx)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -858,3 +866,58 @@ func fetchMySQLVersion(ctx context.Context, database types.Database) error {

return nil
}

// trackSession creates a new session tracker for the database session.
// While ctx is open, the session tracker's expiration will be extended
// on an interval. Once the ctx is closed, the sessiont tracker's state
// will be updated to terminated.
func (s *Server) trackSession(ctx context.Context, sessionCtx *common.Session) error {
s.log.Debug("Creating session tracker")
initiator := &types.Participant{
ID: sessionCtx.DatabaseUser,
User: sessionCtx.Identity.Username,
}

tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{
SessionID: sessionCtx.ID,
Kind: string(types.DatabaseSessionKind),
State: types.SessionState_SessionStateRunning,
Hostname: sessionCtx.HostID,
DatabaseName: sessionCtx.DatabaseName,
ClusterName: sessionCtx.ClusterName,
Login: "root",
Participants: []types.Participant{*initiator},
HostUser: initiator.User,
})
if err != nil {
return trace.Wrap(err)
}

err = s.cfg.AuthClient.UpsertSessionTracker(ctx, tracker)
if err != nil {
return trace.Wrap(err)
}

// Start go routine to push back session expiration until ctx is canceled (session ends).
go func() {
ticker := s.cfg.Clock.NewTicker(defaults.SessionTrackerExpirationUpdateInterval)
defer ticker.Stop()
for {
select {
case time := <-ticker.Chan():
err := services.UpdateSessionTrackerExpiry(ctx, s.cfg.AuthClient, sessionCtx.ID, time.Add(defaults.SessionTrackerTTL))
if err != nil {
s.log.WithError(err).Warningf("Failed to update session tracker expiration for session %v.", sessionCtx.ID)
return
}
case <-ctx.Done():
if err := services.UpdateSessionTrackerState(s.closeContext, s.cfg.AuthClient, sessionCtx.ID, types.SessionState_SessionStateTerminated); err != nil {
s.log.WithError(err).Warningf("Failed to update session tracker state for session %v.", sessionCtx.ID)
}
return
}
}
}()

return nil
}
8 changes: 0 additions & 8 deletions lib/srv/db/sqlserver/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
}
defer serverConn.Close()

// Create a session tracker so that other services, such as
// the session upload completer, can track the session's lifetime.
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := sessionCtx.TrackSession(cancelCtx, e.EngineConfig); err != nil {
return trace.Wrap(err)
}

// Pass all flags returned by server during login back to the client.
err = protocol.WriteStreamResponse(e.clientConn, serverFlags)
if err != nil {
Expand Down
Loading

0 comments on commit d01ad4a

Please sign in to comment.