Skip to content

Commit

Permalink
Merge branch 'master' into joerger/oidc-redirect-url
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger authored May 26, 2022
2 parents 21651d0 + 42d3efd commit de9b752
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 81 deletions.
34 changes: 14 additions & 20 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,23 +509,17 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error {
return nil
}
// make sure whatever session is requested is a valid session
_, err := rsession.ParseID(ssid)
id, err := rsession.ParseID(ssid)
if err != nil {
return trace.BadParameter("invalid session id")
}

findSession := func() (*session, bool) {
reg.sessionsMux.Lock()
defer reg.sessionsMux.Unlock()
return reg.findSessionLocked(rsession.ID(ssid))
}

// update ctx with a session ID
c.session, _ = findSession()
if c.session == nil {
log.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
// update ctx with the session if it exists
if sess, found := reg.findSession(*id); found {
c.session = sess
c.Logger.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr())
} else {
log.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr())
c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr())
}

return nil
Expand Down Expand Up @@ -676,16 +670,16 @@ func (c *ServerContext) OpenXServerListener(x11Req x11.ForwardRequestPayload, di
// The client's session hasn't been fully set up yet so this
// could potentially be a break-in attempt.
if ok, err := c.x11Ready(); err != nil {
log.WithError(err).Debug("Failed to get X11 ready status")
c.Logger.WithError(err).Debug("Failed to get X11 ready status")
return
} else if !ok {
log.WithError(err).Debug("Rejecting X11 request, XServer Proxy is not ready")
c.Logger.WithError(err).Debug("Rejecting X11 request, XServer Proxy is not ready")
return
}

xchan, sin, err := c.ServerConn.OpenChannel(sshutils.X11ChannelRequest, x11ChannelReqPayload)
if err != nil {
log.WithError(err).Debug("Failed to open a new X11 channel")
c.Logger.WithError(err).Debug("Failed to open a new X11 channel")
return
}
defer xchan.Close()
Expand All @@ -697,12 +691,12 @@ func (c *ServerContext) OpenXServerListener(x11Req x11.ForwardRequestPayload, di
go func() {
err := sshutils.ForwardRequests(ctx, sin, c.RemoteSession)
if err != nil {
log.WithError(err).Debug("Failed to forward ssh request from server during X11 forwarding")
c.Logger.WithError(err).Debug("Failed to forward ssh request from server during X11 forwarding")
}
}()

if err := x11.Forward(ctx, xconn, xchan); err != nil {
log.WithError(err).Debug("Encountered error during X11 forwarding")
c.Logger.WithError(err).Debug("Encountered error during X11 forwarding")
}
}()

Expand Down Expand Up @@ -940,7 +934,7 @@ func getPAMConfig(c *ServerContext) (*PAMConfig, error) {
// If the trait isn't passed by the IdP due to misconfiguration
// we fallback to setting a value which will indicate this.
if trace.IsNotFound(err) {
log.Warnf("Attempted to interpolate custom PAM environment with external trait %[1]q but received SAML response does not contain claim %[1]q", expr.Name())
c.Logger.Warnf("Attempted to interpolate custom PAM environment with external trait %[1]q but received SAML response does not contain claim %[1]q", expr.Name())
continue
}

Expand Down Expand Up @@ -1033,11 +1027,11 @@ func buildEnvironment(ctx *ServerContext) []string {
// SSH_CONNECTION environment variables.
remoteHost, remotePort, err := net.SplitHostPort(ctx.ServerConn.RemoteAddr().String())
if err != nil {
log.Debugf("Failed to split remote address: %v.", err)
ctx.Logger.Debugf("Failed to split remote address: %v.", err)
} else {
localHost, localPort, err := net.SplitHostPort(ctx.ServerConn.LocalAddr().String())
if err != nil {
log.Debugf("Failed to split local address: %v.", err)
ctx.Logger.Debugf("Failed to split local address: %v.", err)
} else {
env = append(env,
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
Expand Down
123 changes: 66 additions & 57 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ func (s *SessionRegistry) findSessionLocked(id rsession.ID) (*session, bool) {
sess, found := s.sessions[id]
return sess, found
}
func (s *SessionRegistry) findSession(id rsession.ID) (*session, bool) {
s.sessionsMux.Lock()
defer s.sessionsMux.Unlock()
return s.findSessionLocked(id)
}

func (s *SessionRegistry) Close() {
s.sessionsMux.Lock()
Expand Down Expand Up @@ -427,10 +432,6 @@ type session struct {

doneCh chan struct{}

bpfContext *bpf.SessionContext

cgroupID uint64

displayParticipantRequirements bool

// endingContext is the server context which closed this session.
Expand Down Expand Up @@ -549,6 +550,11 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
return nil, trace.Wrap(err)
}

sess.recorder, err = newRecorder(sess, ctx)
if err != nil {
return nil, trace.Wrap(err)
}

return sess, nil
}

Expand Down Expand Up @@ -593,6 +599,10 @@ func (s *session) Stop() {
// close io copy loops
s.io.Close()

// remove session from server context to prevent new requests
// from attempting to join the session during cleanup
s.scx.setSession(nil)

// Close and kill terminal
if s.term != nil {
if err := s.term.Close(); err != nil {
Expand Down Expand Up @@ -626,7 +636,6 @@ func (s *session) Close() error {
p.Close()
}

// Remove session from registry
s.registry.removeSession(s)

// Remove the session from the backend.
Expand Down Expand Up @@ -852,6 +861,12 @@ func (s *session) setEndingContext(ctx *ServerContext) {
s.endingContext = ctx
}

func (s *session) setHasEnhancedRecording(val bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.hasEnhancedRecording = val
}

// launch launches the session.
// Must be called under session Lock.
func (s *session) launch(ctx *ServerContext) error {
Expand Down Expand Up @@ -930,15 +945,6 @@ func (s *session) launch(ctx *ServerContext) error {
case <-s.doneCh:
}

ctx.srv.GetRestrictedSessionManager().CloseSession(s.bpfContext, s.cgroupID)

// Close the BPF recording session. If BPF was not configured, not available,
// or running in a recording proxy, this is simply a NOP.
err = ctx.srv.GetBPF().CloseSession(s.bpfContext)
if err != nil {
ctx.Errorf("Failed to close enhanced recording (interactive) session: %v: %v.", s.id, err)
}

if ctx.ExecRequest.GetCommand() != "" {
emitExecAuditEvent(ctx, ctx.ExecRequest.GetCommand(), err)
}
Expand All @@ -959,49 +965,29 @@ func (s *session) launch(ctx *ServerContext) error {
// startInteractive starts a new interactive process (or a shell) in the
// current session.
func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
// create a new "party" (connected client)
p := newParty(s, types.SessionPeerMode, ch, ctx)

rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder = rec

// allocate a terminal or take the one previously allocated via a
// seaprate "allocate TTY" SSH request
if ctx.GetTerm() != nil {
s.term = ctx.GetTerm()
ctx.SetTerm(nil)
} else {
if s.term, err = NewTerminal(ctx); err != nil {
ctx.Infof("Unable to allocate new terminal: %v", err)
return trace.Wrap(err)
}
}

// Emit a session.start event for the interactive session.
s.emitSessionStartEvent(ctx)

inReader, inWriter := io.Pipe()
s.inWriter = inWriter
s.io.AddReader("reader", inReader)
s.io.AddWriter(sessionRecorderID, utils.WriteCloserWithContext(ctx.srv.Context(), s.recorder))
s.BroadcastMessage("Creating session with ID: %v...", s.id)
s.BroadcastMessage(SessionControlsInfoBroadcast)

if err := s.term.Run(); err != nil {
ctx.Errorf("Unable to run shell command: %v.", err)
return trace.ConvertSystemError(err)
if err := s.startTerminal(ctx); err != nil {
return trace.Wrap(err)
}

// Emit a session.start event for the interactive session.
s.emitSessionStartEvent(ctx)

// create a new "party" (connected client) and launch/join the session.
p := newParty(s, types.SessionPeerMode, ch, ctx)
if err := s.addParty(p, types.SessionPeerMode); err != nil {
return trace.Wrap(err)
}

// Open a BPF recording session. If BPF was not configured, not available,
// or running in a recording proxy, OpenSession is a NOP.
s.bpfContext = &bpf.SessionContext{
sessionContext := &bpf.SessionContext{
Context: ctx.srv.Context(),
PID: s.term.PID(),
Emitter: s.recorder,
Expand All @@ -1012,16 +998,23 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
User: ctx.Identity.TeleportUser,
Events: ctx.Identity.RoleSet.EnhancedRecordingSet(),
}
s.cgroupID, err = ctx.srv.GetBPF().OpenSession(s.bpfContext)
if err != nil {

if cgroupID, err := ctx.srv.GetBPF().OpenSession(sessionContext); err != nil {
ctx.Errorf("Failed to open enhanced recording (interactive) session: %v: %v.", s.id, err)
return trace.Wrap(err)
}

// If a cgroup ID was assigned then enhanced session recording was enabled.
if s.cgroupID > 0 {
s.hasEnhancedRecording = true
ctx.srv.GetRestrictedSessionManager().OpenSession(s.bpfContext, s.cgroupID)
} else if cgroupID > 0 {
// If a cgroup ID was assigned then enhanced session recording was enabled.
s.setHasEnhancedRecording(true)
ctx.srv.GetRestrictedSessionManager().OpenSession(sessionContext, cgroupID)
go func() {
// Close the BPF recording session once the session is closed
<-s.stopC
ctx.srv.GetRestrictedSessionManager().CloseSession(sessionContext, cgroupID)
err = ctx.srv.GetBPF().CloseSession(sessionContext)
if err != nil {
ctx.Errorf("Failed to close enhanced recording (interactive) session: %v: %v.", s.id, err)
}
}()
}

ctx.Debug("Waiting for continue signal")
Expand All @@ -1037,6 +1030,28 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
return nil
}

func (s *session) startTerminal(ctx *ServerContext) error {
s.mu.Lock()
defer s.mu.Unlock()

// allocate a terminal or take the one previously allocated via a
// separate "allocate TTY" SSH request
var err error
if s.term = ctx.GetTerm(); s.term != nil {
ctx.SetTerm(nil)
} else if s.term, err = NewTerminal(ctx); err != nil {
ctx.Infof("Unable to allocate new terminal: %v", err)
return trace.Wrap(err)
}

if err := s.term.Run(); err != nil {
ctx.Errorf("Unable to run shell command: %v.", err)
return trace.ConvertSystemError(err)
}

return nil
}

// newRecorder creates a new events.StreamWriter to be used as the recorder
// of the passed in session.
func newRecorder(s *session, ctx *ServerContext) (events.StreamWriter, error) {
Expand Down Expand Up @@ -1071,12 +1086,6 @@ func newRecorder(s *session, ctx *ServerContext) (events.StreamWriter, error) {
}

func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder = rec

// Emit a session.start event for the exec session.
s.emitSessionStartEvent(ctx)

Expand Down Expand Up @@ -1113,7 +1122,7 @@ func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {

// If a cgroup ID was assigned then enhanced session recording was enabled.
if cgroupID > 0 {
s.hasEnhancedRecording = true
s.setHasEnhancedRecording(true)
ctx.srv.GetRestrictedSessionManager().OpenSession(sessionContext, cgroupID)
}

Expand Down
6 changes: 2 additions & 4 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,10 @@ func TestInteractiveSession(t *testing.T) {
sess.Stop()

sessionClosed := func() bool {
reg.sessionsMux.Lock()
defer reg.sessionsMux.Unlock()
_, found := reg.findSessionLocked(sess.id)
_, found := reg.findSession(sess.id)
return !found
}
require.Eventually(t, sessionClosed, time.Second*5, time.Millisecond*500)
require.Eventually(t, sessionClosed, time.Second*15, time.Millisecond*500)
})

t.Run("BrokenRecorder", func(t *testing.T) {
Expand Down

0 comments on commit de9b752

Please sign in to comment.