Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent panic caused by nil session recorder #10792

Merged
merged 5 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 39 additions & 53 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,34 +927,11 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
// create a new "party" (connected client)
p := newParty(s, types.SessionPeerMode, ch, ctx)

// Nodes discard events in cases when proxies are already recording them.
if s.registry.srv.Component() == teleport.ComponentNode &&
services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) {
s.recorder = &events.DiscardStream{}
} else {
streamer, err := s.newStreamer(ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder, err = events.NewAuditWriter(events.AuditWriterConfig{
// Audit stream is using server context, not session context,
// to make sure that session is uploaded even after it is closed
Context: ctx.srv.Context(),
Streamer: streamer,
Clock: ctx.srv.GetClock(),
SessionID: s.id,
Namespace: ctx.srv.GetNamespace(),
ServerID: ctx.srv.HostUUID(),
RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff,
Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()),
ClusterName: ctx.ClusterName,
})
if err != nil {
return trace.Wrap(err)
}
rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}

var err error
s.recorder = rec

// allocate a terminal or take the one previously allocated via a
// separate "allocate TTY" SSH request
Expand Down Expand Up @@ -1055,36 +1032,45 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
return nil
}

func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
var err error

// newRecorder creates a new events.StreamWriter to be used as the recorder
// of the passed in session. It always returns a non-nil recorder or a
// non-nil error, never (nil, nil), so callers can safely use the result.
func newRecorder(s *session, ctx *ServerContext) (events.StreamWriter, error) {
	// Nodes discard events in cases when proxies are already recording them.
	if s.registry.srv.Component() == teleport.ComponentNode &&
		services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) {
		return &events.DiscardStream{}, nil
	}

	streamer, err := s.newStreamer(ctx)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	rec, err := events.NewAuditWriter(events.AuditWriterConfig{
		// Audit stream is using server context, not session context,
		// to make sure that session is uploaded even after it is closed
		Context:      ctx.srv.Context(),
		Streamer:     streamer,
		SessionID:    s.id,
		Clock:        ctx.srv.GetClock(),
		Namespace:    ctx.srv.GetNamespace(),
		ServerID:     ctx.srv.HostUUID(),
		RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff,
		Component:    teleport.Component(teleport.ComponentSession, ctx.srv.Component()),
		ClusterName:  ctx.ClusterName,
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	return rec, nil
}

func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder = rec

// Emit a session.start event for the exec session.
sessionStartEvent := &apievents.SessionStart{
Expand Down
248 changes: 248 additions & 0 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ limitations under the License.
package srv

import (
"context"
"testing"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/pam"
restricted "github.com/gravitational/teleport/lib/restrictedsession"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -63,3 +75,239 @@ func TestParseAccessRequestIDs(t *testing.T) {
}

}

// Compile-time assertion that mockServer satisfies the Server interface.
var _ Server = (*mockServer)(nil)
rosstimothy marked this conversation as resolved.
Show resolved Hide resolved

// mockServer is a minimal stub of the Server interface for tests in this
// file. Most methods return fixed placeholder values; the embedded
// StreamEmitter lets individual test cases plug in a real emitter (e.g.
// events.DiscardEmitter) when audit-writer creation needs one.
type mockServer struct {
	events.StreamEmitter
}

// ID is the unique ID of the server.
// Stubbed to a fixed value for tests.
func (m *mockServer) ID() string {
	return "test"
}

// HostUUID is the UUID of the underlying host. For the forwarding
// server this is the proxy the forwarding server is running in.
// Stubbed to a fixed value for tests.
func (m *mockServer) HostUUID() string {
	return "test"
}

// GetNamespace returns the namespace the server was created in.
// Stubbed to a fixed value for tests.
func (m *mockServer) GetNamespace() string {
	return "test"
}

// AdvertiseAddr is the publicly addressable address of this server.
// Stubbed to a fixed value for tests.
func (m *mockServer) AdvertiseAddr() string {
	return "test"
}

// Component is the type of server, forwarding or regular.
// Returns ComponentNode so tests exercise the node-side recording logic
// in newRecorder (which discards events when the proxy records).
func (m *mockServer) Component() string {
	return teleport.ComponentNode
}

// PermitUserEnvironment returns if reading environment variables upon
// startup is allowed.
// Stubbed to false for tests.
func (m *mockServer) PermitUserEnvironment() bool {
	return false
}

// GetAccessPoint returns an AccessPoint for this cluster.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetAccessPoint() AccessPoint {
	return nil
}

// GetSessionServer returns a session server.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetSessionServer() rsession.Service {
	return nil
}

// GetDataDir returns data directory of the server.
// Stubbed to a fixed value for tests.
func (m *mockServer) GetDataDir() string {
	return "test"
}

// GetPAM returns PAM configuration for this server.
// The stub reports no PAM configuration and no error.
func (m *mockServer) GetPAM() (*pam.Config, error) {
	var cfg *pam.Config
	return cfg, nil
}

// GetClock returns a clock setup for the server.
// The stub hands back a real (wall) clock rather than a fake one.
func (m *mockServer) GetClock() clockwork.Clock {
	clock := clockwork.NewRealClock()
	return clock
}

// GetInfo returns a services.Server that represents this server.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetInfo() types.Server {
	return nil
}

// UseTunnel used to determine if this node has connected to this cluster
// using reverse tunnel.
// Stubbed to false for tests.
func (m *mockServer) UseTunnel() bool {
	return false
}

// GetBPF returns the BPF service used for enhanced session recording.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetBPF() bpf.BPF {
	return nil
}

// GetRestrictedSessionManager returns the manager for restricting user activity.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetRestrictedSessionManager() restricted.Manager {
	return nil
}

// Context returns server shutdown context.
// Stubbed to a background (never-canceled) context for tests.
func (m *mockServer) Context() context.Context {
	return context.Background()
}

// GetUtmpPath returns the path of the user accounting database and log.
// Returns empty for system defaults. The stub reports fixed test paths.
func (m *mockServer) GetUtmpPath() (utmp, wtmp string) {
	utmp, wtmp = "test", "test"
	return utmp, wtmp
}

// GetLockWatcher gets the server's lock watcher.
// Stubbed to nil; tests in this file never dereference it.
func (m *mockServer) GetLockWatcher() *services.LockWatcher {
	return nil
}

// TestSession_newRecorder validates the recorder returned by newRecorder
// for each recording mode: a DiscardStream when the proxy is recording
// the session, an AuditWriter when the node records it, and a non-nil
// error (never a nil recorder) when recorder construction fails.
func TestSession_newRecorder(t *testing.T) {
	proxyRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
		Mode: types.RecordAtProxy,
	})
	require.NoError(t, err)

	proxyRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
		Mode: types.RecordAtProxySync,
	})
	require.NoError(t, err)

	nodeRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
		Mode: types.RecordAtNode,
	})
	require.NoError(t, err)

	nodeRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
		Mode: types.RecordAtNodeSync,
	})
	require.NoError(t, err)

	logger := logrus.WithFields(logrus.Fields{
		trace.Component: teleport.ComponentAuth,
	})

	cases := []struct {
		desc         string
		sess         *session
		sctx         *ServerContext
		errAssertion require.ErrorAssertionFunc
		recAssertion require.ValueAssertionFunc
	}{
		{
			// Node + proxy recording: events are discarded on the node.
			desc: "discard-stream-when-proxy-recording",
			sess: &session{
				id:  "test",
				log: logger,
				registry: &SessionRegistry{
					srv: &mockServer{},
				},
			},
			sctx: &ServerContext{
				SessionRecordingConfig: proxyRecording,
			},
			errAssertion: require.NoError,
			recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
				require.NotNil(t, i)
				_, ok := i.(*events.DiscardStream)
				require.True(t, ok)
			},
		},
		{
			// Node + synchronous proxy recording: events are also discarded.
			desc: "discard-stream-when-proxy-sync-recording",
			sess: &session{
				id:  "test",
				log: logger,
				registry: &SessionRegistry{
					srv: &mockServer{},
				},
			},
			sctx: &ServerContext{
				SessionRecordingConfig: proxyRecordingSync,
			},
			errAssertion: require.NoError,
			recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
				require.NotNil(t, i)
				_, ok := i.(*events.DiscardStream)
				require.True(t, ok)
			},
		},
		{
			// Streamer creation fails: an error (not a nil recorder) is returned.
			desc: "err-new-streamer-fails",
			sess: &session{
				id:  "test",
				log: logger,
				registry: &SessionRegistry{
					srv: &mockServer{},
				},
			},
			sctx: &ServerContext{
				SessionRecordingConfig: nodeRecording,
				srv:                    &mockServer{},
			},
			errAssertion: require.Error,
			recAssertion: require.Nil,
		},
		{
			// AuditWriter creation fails: an error (not a nil recorder) is returned.
			desc: "err-new-audit-writer-fails",
			sess: &session{
				id:  "test",
				log: logger,
				registry: &SessionRegistry{
					srv: &mockServer{},
				},
			},
			sctx: &ServerContext{
				SessionRecordingConfig: nodeRecordingSync,
				srv:                    &mockServer{},
			},
			errAssertion: require.Error,
			recAssertion: require.Nil,
		},
		{
			// Node recording with a working emitter: an AuditWriter is returned.
			desc: "audit-writer",
			sess: &session{
				id:  "test",
				log: logger,
				registry: &SessionRegistry{
					srv: &mockServer{},
				},
			},
			sctx: &ServerContext{
				ClusterName:            "test",
				SessionRecordingConfig: nodeRecordingSync,
				srv: &mockServer{
					StreamEmitter: &events.DiscardEmitter{},
				},
			},
			errAssertion: require.NoError,
			recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
				require.NotNil(t, i)
				aw, ok := i.(*events.AuditWriter)
				require.True(t, ok)
				require.NoError(t, aw.Close(context.Background()))
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.desc, func(t *testing.T) {
			rec, err := newRecorder(tt.sess, tt.sctx)
			tt.errAssertion(t, err)
			tt.recAssertion(t, rec)
		})
	}
}