diff --git a/api/types/accessgraph/authorized_key.go b/api/types/accessgraph/authorized_key.go index 5532f39bc2775..33de330f9c688 100644 --- a/api/types/accessgraph/authorized_key.go +++ b/api/types/accessgraph/authorized_key.go @@ -28,7 +28,8 @@ import ( ) const ( - authorizedKeyDefaultKeyTTL = 8 * time.Hour + // AuthorizedKeyDefaultKeyTTL is the default TTL for an authorized key. + AuthorizedKeyDefaultKeyTTL = 8 * time.Hour ) // NewAuthorizedKey creates a new SSH authorized key resource. @@ -40,7 +41,7 @@ func NewAuthorizedKey(spec *accessgraphv1pb.AuthorizedKeySpec) (*accessgraphv1pb Metadata: &headerv1.Metadata{ Name: name, Expires: timestamppb.New( - time.Now().Add(authorizedKeyDefaultKeyTTL), + time.Now().Add(AuthorizedKeyDefaultKeyTTL), ), }, Spec: spec, diff --git a/lib/secretsscanner/authorizedkeys/authorized_keys.go b/lib/secretsscanner/authorizedkeys/authorized_keys.go new file mode 100644 index 0000000000000..5e26720fae4c7 --- /dev/null +++ b/lib/secretsscanner/authorizedkeys/authorized_keys.go @@ -0,0 +1,388 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package authorizedkeys + +import ( + "bufio" + "context" + "errors" + "log/slog" + "os" + "os/user" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/constants" + accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1" + "github.com/gravitational/teleport/api/types/accessgraph" + "github.com/gravitational/teleport/api/utils/retryutils" +) + +var ( + // ErrUnsupportedPlatform is returned when the operating system is not supported. + ErrUnsupportedPlatform = errors.New("unsupported platform") +) + +// Watcher watches for changes to authorized_keys files +// and reports them to the cluster. If the cluster does not have +// scanning enabled, the watcher will hold until the feature is enabled. +type Watcher struct { + // client is the client to use to communicate with the cluster. + client ClusterClient + logger *slog.Logger + clock clockwork.Clock + hostID string + usersAccountFile string +} + +// ClusterClient is the client to use to communicate with the cluster. +type ClusterClient interface { + GetClusterAccessGraphConfig(context.Context) (*clusterconfigpb.AccessGraphConfig, error) + AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient +} + +// WatcherConfig is the configuration for the Watcher. +type WatcherConfig struct { + // Client is the client to use to communicate with the cluster. + Client ClusterClient + // Logger is the logger to use. + Logger *slog.Logger + // Clock is the clock to use. + Clock clockwork.Clock + // HostID is the ID of the host. + HostID string + // getRuntimeOS returns the runtime operating system. + // used for testing purposes. + getRuntimeOS func() string + // etcPasswdFile is the path to the file that contains the users account information on the system. + // This file is used to get the list of users on the system and their home directories. + // Value is set to "/etc/passwd" by default. + etcPasswdFile string +} + +// NewWatcher creates a new Watcher instance. +// Returns [ErrUnsupportedPlatform] if the operating system is not supported. +func NewWatcher(ctx context.Context, config WatcherConfig) (*Watcher, error) { + + if getOS(config) != constants.LinuxOS { + return nil, trace.Wrap(ErrUnsupportedPlatform) + } + + if config.HostID == "" { + return nil, trace.BadParameter("missing host ID") + } + if config.Client == nil { + return nil, trace.BadParameter("missing client") + } + if config.Logger == nil { + config.Logger = slog.Default() + } + if config.Clock == nil { + config.Clock = clockwork.NewRealClock() + } + if config.etcPasswdFile == "" { + // etcPasswordPath is the path to the password file. + // This file is used to get the list of users on the system and their home directories. + const etcPasswordPath = "/etc/passwd" + config.etcPasswdFile = etcPasswordPath + } + + w := &Watcher{ + client: config.Client, + logger: config.Logger, + clock: config.Clock, + hostID: config.HostID, + usersAccountFile: config.etcPasswdFile, + } + + return w, nil +} + +func (w *Watcher) Run(ctx context.Context) error { + return trace.Wrap(w.monitorClusterConfigAndStart(ctx)) +} + +func (w *Watcher) monitorClusterConfigAndStart(ctx context.Context) error { + const tickerInterval = 30 * time.Minute + return trace.Wrap(supervisorRunner(ctx, supervisorRunnerConfig{ + clock: w.clock, + tickerInterval: tickerInterval, + runner: w.start, + checkIfMonitorEnabled: w.isAuthorizedKeysReportEnabled, + logger: w.logger, + })) +} + +// start starts the watcher. +func (w *Watcher) start(ctx context.Context) error { + wg := sync.WaitGroup{} + defer wg.Wait() + + fileWatcher, err := fsnotify.NewWatcher() + if err != nil { + return trace.Wrap(err) + } + defer func() { + if err := fileWatcher.Close(); err != nil { + w.logger.WarnContext(ctx, "Failed to close watcher", "error", err) + } + }() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + reload := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case <-fileWatcher.Events: + innerLoop: + for { + select { + case <-ctx.Done(): + return + case <-fileWatcher.Events: + case reload <- struct{}{}: + break innerLoop + } + } + case err := <-fileWatcher.Errors: + w.logger.WarnContext(ctx, "Error watching authorized_keys file", "error", err) + } + } + }() + + if err := fileWatcher.Add(w.usersAccountFile); err != nil { + w.logger.WarnContext(ctx, "Failed to add watcher for file", "error", err) + } + + stream, err := w.client.AccessGraphSecretsScannerClient().ReportAuthorizedKeys(ctx) + if err != nil { + return trace.Wrap(err) + } + + // Wait for the initial delay before sending the first report to spread the load. + // The initial delay is a random value between 0 and maxInitialDelay. + const maxInitialDelay = 5 * time.Minute + select { + case <-ctx.Done(): + return nil + case <-w.clock.After(retryutils.NewFullJitter()(maxInitialDelay)): + } + + jitterFunc := retryutils.NewHalfJitter() + // maxReSendInterval is the maximum interval to re-send the authorized keys report + // to the cluster in case of no changes. + const maxReSendInterval = accessgraph.AuthorizedKeyDefaultKeyTTL - 20*time.Minute + timer := w.clock.NewTimer(jitterFunc(maxReSendInterval)) + defer timer.Stop() + for { + + if err := w.fetchAndReportAuthorizedKeys(ctx, stream, fileWatcher); err != nil { + w.logger.WarnContext(ctx, "Failed to report authorized keys", "error", err) + } + + if !timer.Stop() { + <-timer.Chan() + } + timer.Reset(jitterFunc(maxReSendInterval)) + + select { + case <-ctx.Done(): + return nil + case <-reload: + case <-timer.Chan(): + } + } +} + +// isAuthorizedKeysReportEnabled checks if the cluster has authorized keys report enabled. +func (w *Watcher) isAuthorizedKeysReportEnabled(ctx context.Context) (bool, error) { + accessGraphConfig, err := w.client.GetClusterAccessGraphConfig(ctx) + if err != nil { + return false, trace.Wrap(err) + } + return accessGraphConfig.GetEnabled() && accessGraphConfig.GetSecretsScanConfig().GetSshScanEnabled(), nil +} + +// fetchAndReportAuthorizedKeys fetches the authorized keys from the system and reports them to the cluster. +func (w *Watcher) fetchAndReportAuthorizedKeys( + ctx context.Context, + stream accessgraphsecretsv1pb.SecretsScannerService_ReportAuthorizedKeysClient, + fileWatcher *fsnotify.Watcher, +) error { + users, err := userList(ctx, w.logger, w.usersAccountFile) + if err != nil { + return trace.Wrap(err) + } + var keys []*accessgraphsecretsv1pb.AuthorizedKey + for _, u := range users { + if u.HomeDir == "" { + w.logger.DebugContext(ctx, "Skipping user with empty home directory", "user", u.Name) + continue + } + + for _, file := range []string{"authorized_keys", "authorized_keys2"} { + authorizedKeysPath := filepath.Join(u.HomeDir, ".ssh", file) + if fs, err := os.Stat(authorizedKeysPath); err != nil || fs.IsDir() { + continue + } + + hostKeys, err := w.parseAuthorizedKeysFile(ctx, u, authorizedKeysPath) + if errors.Is(err, os.ErrNotExist) { + continue + } else if err != nil { + w.logger.WarnContext(ctx, "Failed to parse authorized_keys file", "error", err) + continue + } + + // Add the file to the watcher. If file was already added, this is a no-op. + if err := fileWatcher.Add(authorizedKeysPath); err != nil { + w.logger.WarnContext(ctx, "Failed to add watcher for file", "error", err) + } + keys = append(keys, hostKeys...) + } + } + + const maxKeysPerReport = 500 + for i := 0; i < len(keys); i += maxKeysPerReport { + start := i + end := min(i+maxKeysPerReport, len(keys)) + if err := stream.Send( + &accessgraphsecretsv1pb.ReportAuthorizedKeysRequest{ + Keys: keys[start:end], + Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_ADD, + }, + ); err != nil { + return trace.Wrap(err) + } + } + + if err := stream.Send( + &accessgraphsecretsv1pb.ReportAuthorizedKeysRequest{Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_SYNC}, + ); err != nil { + return trace.Wrap(err) + } + return nil +} + +// userList retrieves all users on the system +func userList(ctx context.Context, log *slog.Logger, filePath string) ([]user.User, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer func() { + if err := file.Close(); err != nil { + log.DebugContext(ctx, "Failed to close file", "error", err, "file", filePath) + } + }() + + var users []user.User + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + // username:password:uid:gid:gecos:home:shell + parts := strings.Split(line, ":") + if len(parts) < 7 { + continue + } + users = append(users, user.User{ + Username: parts[0], + Uid: parts[2], + Gid: parts[3], + Name: parts[4], + HomeDir: parts[5], + }) + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (w *Watcher) parseAuthorizedKeysFile(ctx context.Context, u user.User, authorizedKeysPath string) ([]*accessgraphsecretsv1pb.AuthorizedKey, error) { + file, err := os.Open(authorizedKeysPath) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + if err := file.Close(); err != nil { + w.logger.WarnContext(ctx, "Failed to close file", "error", err, "path", authorizedKeysPath) + } + }() + + var keys []*accessgraphsecretsv1pb.AuthorizedKey + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + payload := scanner.Bytes() + if len(payload) == 0 || payload[0] == '#' { + continue + } + parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(payload) + if err != nil { + w.logger.WarnContext(ctx, "Failed to parse authorized key", "error", err) + continue + } else if parsedKey == nil { + continue + } + + authorizedKey, err := accessgraph.NewAuthorizedKey( + &accessgraphsecretsv1pb.AuthorizedKeySpec{ + HostId: w.hostID, + HostUser: u.Username, + KeyFingerprint: ssh.FingerprintSHA256(parsedKey), + }, + ) + if err != nil { + w.logger.WarnContext(ctx, "Failed to create authorized key", "error", err) + continue + } + keys = append(keys, authorizedKey) + } + + return keys, nil +} + +func getOS(config WatcherConfig) string { + goos := runtime.GOOS + if config.getRuntimeOS != nil { + goos = config.getRuntimeOS() + } + return goos +} diff --git a/lib/secretsscanner/authorizedkeys/authorized_keys_test.go b/lib/secretsscanner/authorizedkeys/authorized_keys_test.go new file mode 100644 index 0000000000000..8e54603ad0ec4 --- /dev/null +++ b/lib/secretsscanner/authorizedkeys/authorized_keys_test.go @@ -0,0 +1,220 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package authorizedkeys + +import ( + "context" + "fmt" + "log/slog" + "os" + "path/filepath" + "slices" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/gravitational/teleport/api/constants" + accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" + clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + "github.com/gravitational/teleport/api/types/accessgraph" +) + +func TestAuthorizedKeys(t *testing.T) { + hostID := "hostID" + + etcPasswdFile := createFSData(t) + clock := clockwork.NewFakeClockAt(time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC)) + client := &fakeClient{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + watcher, err := NewWatcher(ctx, WatcherConfig{ + Client: client, + etcPasswdFile: etcPasswdFile, + HostID: hostID, + Clock: clock, + Logger: slog.Default(), + getRuntimeOS: func() string { + return constants.LinuxOS + }, + }) + require.NoError(t, err) + + // Start the watcher + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + return trace.Wrap(watcher.Run(ctx)) + }) + + // Wait for the watcher to start and to block on the initial spread. + clock.BlockUntil(2) // wait for clock to blocked at supervisor and initial delay routine + // Advance the clock to trigger the first scan + clock.Advance(5 * time.Minute) + + // Wait for the watcher to start + require.Eventually(t, func() bool { + return len(client.getReqReceived()) == 2 + }, 1*time.Second, 10*time.Millisecond, "expected watcher to start, but it did not") + + // Check the requests + got := client.getReqReceived() + require.Len(t, got, 2) + expected := []*accessgraphsecretsv1pb.ReportAuthorizedKeysRequest{ + { + Keys: createKeysForUsers(t, hostID), + Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_ADD, + }, + { + Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_SYNC, + }, + } + require.Empty(t, cmp.Diff(got, expected, + protocmp.Transform(), + protocmp.SortRepeated( + func(a, b *accessgraphsecretsv1pb.AuthorizedKey) bool { + return a.Metadata.Name < b.Metadata.Name + }, + ), + protocmp.IgnoreFields(&headerv1.Metadata{}, "expires"), + ), + ) + + // Clear the requests + client.clear() + + // Update the etcPasswdFile + createUsersAndAuthorizedKeys(t, filepath.Dir(etcPasswdFile)) + + cancel() + err = group.Wait() + require.NoError(t, err) + +} + +func createFSData(t *testing.T) string { + dir := t.TempDir() + etcPasswd := exampleEtcPasswdFile(dir) + createFile(t, dir, "passwd", etcPasswd) + + createUsersAndAuthorizedKeys(t, dir) + return filepath.Join(dir, "passwd") +} + +func createFile(t *testing.T, dir, name, content string) { + err := os.MkdirAll(dir, 0755) + require.NoError(t, err) + path := fmt.Sprintf("%s/%s", dir, name) + err = os.WriteFile(path, []byte(content), 0644) + require.NoError(t, err) +} + +func exampleEtcPasswdFile(dir string) string { + return fmt.Sprintf( + `root:x:0:0::%s/root:/usr/bin/bash +bin:x:1:1::/:/usr/bin/nologin +user:x:1000:1000::%s/user:/usr/bin/zsh`, + dir, + dir, + ) +} + +const authorizedFileExample = ` +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQClwXUKOp/S4XEtFjgr8mfaCy4OyI7N9ZMibdCGxvk2VHP9+Vn8Al1lUSVwuBxHI7EHiq42RCTBetIpTjzn6yiPNAeGNL5cfl9i6r+P5k7og1hz+2oheWveGodx6Dp+Z4o2dw65NGf5EPaotXF8AcHJc3+OiMS5yp/x2A3tu2I1SPQ6dtPa067p8q1L49BKbFwrFRBCVwkr6kpEQAIjnMESMPGD5Buu/AtyAdEZQSLTt8RZajJZDfXFKMEtQm2UF248NFl3hSMAcbbTxITBbZxX7THbwQz22Yuw7422G5CYBPf6WRXBY84Rs6jCS4I4GMxj+3rF4mGtjvuz0wOE32s3w4eMh9h3bPuEynufjE8henmPCIW49+kuZO4LZut7Zg5BfVDQnZYclwokEIMz+gR02YpyflxQOa98t/0mENu+t4f0LNAdkQEBpYtGKKDth5kLphi2Sdi9JpGO2sTivlxMsGyBqdd0wT9VwQpWf4wro6t09HdZJX1SAuEi/0tNI10= friel@test +# comment +ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGtqQKEkGIY5+Bc4EmEv7NeSn6aA7KMl5eiNEAOqwTBl friel@test +invalidLine +# comment +` + +func createUsersAndAuthorizedKeys(t *testing.T, dir string) { + for _, user := range []string{"root", "user"} { + dir := filepath.Join(dir, user, ".ssh") + createFile(t, dir, "authorized_keys", authorizedFileExample) + } +} + +type fakeClient struct { + accessgraphsecretsv1pb.SecretsScannerServiceClient + accessgraphsecretsv1pb.SecretsScannerService_ReportAuthorizedKeysClient + mu sync.Mutex + reqReceived []*accessgraphsecretsv1pb.ReportAuthorizedKeysRequest +} + +func (f *fakeClient) GetClusterAccessGraphConfig(_ context.Context) (*clusterconfigpb.AccessGraphConfig, error) { + return &clusterconfigpb.AccessGraphConfig{ + Enabled: true, + SecretsScanConfig: &clusterconfigpb.AccessGraphSecretsScanConfiguration{ + SshScanEnabled: true, + }, + }, nil +} +func (f *fakeClient) AccessGraphSecretsScannerClient() accessgraphsecretsv1pb.SecretsScannerServiceClient { + return f +} + +func (f *fakeClient) ReportAuthorizedKeys(_ context.Context, _ ...grpc.CallOption) (accessgraphsecretsv1pb.SecretsScannerService_ReportAuthorizedKeysClient, error) { + return f, nil +} + +func (f *fakeClient) Send(req *accessgraphsecretsv1pb.ReportAuthorizedKeysRequest) error { + f.mu.Lock() + defer f.mu.Unlock() + f.reqReceived = append(f.reqReceived, req) + return nil +} + +func (f *fakeClient) clear() { + f.mu.Lock() + defer f.mu.Unlock() + f.reqReceived = nil +} + +func (f *fakeClient) getReqReceived() []*accessgraphsecretsv1pb.ReportAuthorizedKeysRequest { + f.mu.Lock() + defer f.mu.Unlock() + return slices.Clone(f.reqReceived) +} + +func createKeysForUsers(t *testing.T, hostID string) []*accessgraphsecretsv1pb.AuthorizedKey { + var keys []*accessgraphsecretsv1pb.AuthorizedKey + for _, fingerprint := range []string{ + "SHA256:GbJlTLeQgZhvGoklWGXHo0AinGgGEcldllgYExoSy+s", /* ssh-rsa */ + "SHA256:ewwMB/nCAYurNrYFXYZuxLZv7T7vgpPd7QuIo0d5n+U", /* ssh-ed25519 */ + } { + for _, user := range []string{"root", "user"} { + at, err := accessgraph.NewAuthorizedKey(&accessgraphsecretsv1pb.AuthorizedKeySpec{ + HostId: hostID, + HostUser: user, + KeyFingerprint: fingerprint, + }) + require.NoError(t, err) + keys = append(keys, at) + } + } + return keys +} diff --git a/lib/secretsscanner/authorizedkeys/supervisor.go b/lib/secretsscanner/authorizedkeys/supervisor.go new file mode 100644 index 0000000000000..c0048abb07877 --- /dev/null +++ b/lib/secretsscanner/authorizedkeys/supervisor.go @@ -0,0 +1,112 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package authorizedkeys + +import ( + "context" + "errors" + "log/slog" + "sync" + "time" + + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/api/utils/retryutils" +) + +var errShutdown = errors.New("watcher is shutting down") + +type supervisorRunnerConfig struct { + clock clockwork.Clock + tickerInterval time.Duration + runner func(context.Context) error + checkIfMonitorEnabled func(context.Context) (bool, error) + logger *slog.Logger +} + +// supervisorRunner runs the runner based on the checkIfMonitorEnabled result. +// If the monitor is enabled, the runner is started. If the monitor is disabled, +// the runner is stopped if it is running. +// The checkIfMonitorEnabled is evaluated every tickerInterval duration to determine +// if the monitor should be started or stopped. +// tickerInterval is jittered to prevent all watchers from running at the same time. +// If the watcher is stopped, it will be restarted after the next checkIfMonitorEnabled evaluation. +func supervisorRunner(parentCtx context.Context, cfg supervisorRunnerConfig) error { + var ( + isRunning = false + runCtx context.Context + runCtxCancel context.CancelCauseFunc + wg sync.WaitGroup + mu sync.Mutex + ) + + getIsRunning := func() bool { + mu.Lock() + defer mu.Unlock() + return isRunning + } + + setIsRunning := func(s bool) { + mu.Lock() + defer mu.Unlock() + isRunning = s + } + + runRoutine := func(ctx context.Context, cancel context.CancelCauseFunc) { + defer func() { + wg.Done() + cancel(errShutdown) + setIsRunning(false) + }() + if err := cfg.runner(ctx); err != nil && !errors.Is(err, errShutdown) { + cfg.logger.WarnContext(ctx, "Runner failed", "error", err) + } + } + + jitterFunc := retryutils.NewHalfJitter() + t := cfg.clock.NewTimer(jitterFunc(cfg.tickerInterval)) + for { + switch enabled, err := cfg.checkIfMonitorEnabled(parentCtx); { + case err != nil: + cfg.logger.WarnContext(parentCtx, "Failed to check if authorized keys report is enabled", "error", err) + case enabled && !getIsRunning(): + runCtx, runCtxCancel = context.WithCancelCause(parentCtx) + setIsRunning(true) + wg.Add(1) + go runRoutine(runCtx, runCtxCancel) + case !enabled && getIsRunning(): + runCtxCancel(errShutdown) + // Wait for the runner to stop before checking if the monitor is enabled again. + wg.Wait() + } + + select { + case <-t.Chan(): + if !t.Stop() { + select { + case <-t.Chan(): + default: + } + } + t.Reset(jitterFunc(cfg.tickerInterval)) + case <-parentCtx.Done(): + return nil + } + } +} diff --git a/lib/secretsscanner/authorizedkeys/supervisor_test.go b/lib/secretsscanner/authorizedkeys/supervisor_test.go new file mode 100644 index 0000000000000..93399ec1fdc5b --- /dev/null +++ b/lib/secretsscanner/authorizedkeys/supervisor_test.go @@ -0,0 +1,135 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package authorizedkeys + +import ( + "context" + "log/slog" + "sync" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestSupervisorRunner(t *testing.T) { + // Create a mock clock + clock := clockwork.NewFakeClock() + + t.Run("runner starts and stops based on monitor state", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var mu sync.Mutex + var running bool + + runner := func(ctx context.Context) error { + mu.Lock() + running = true + mu.Unlock() + <-ctx.Done() + mu.Lock() + running = false + mu.Unlock() + return nil + } + + checker, enable, disable := checkIfMonitorEnabled() + enable() + + cfg := supervisorRunnerConfig{ + clock: clock, + tickerInterval: 1 * time.Second, + runner: runner, + checkIfMonitorEnabled: checker, + logger: slog.Default(), + } + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return supervisorRunner(ctx, cfg) + }) + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return running + }, 100*time.Millisecond, 10*time.Millisecond, "expected runner to start, but it did not") + + disable() + + clock.BlockUntil(1) + clock.Advance(2 * time.Second) + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return !running + }, 100*time.Millisecond, 10*time.Millisecond, "expected runner to stop, but it did not") + + enable() + clock.BlockUntil(1) + clock.Advance(2 * time.Second) + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return running + }, 100*time.Millisecond, 10*time.Millisecond, "expected runner to re-start, but it did not") + + disable() + clock.BlockUntil(1) + clock.Advance(2 * time.Second) + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return !running + }, 100*time.Millisecond, 10*time.Millisecond, "expected runner to re-stop, but it did not") + + // Cancel the context to stop the supervisor + cancel() + if err := g.Wait(); err != nil { + t.Fatal(err) + } + }) + +} + +func checkIfMonitorEnabled() (checker func(context.Context) (bool, error), enable func(), disable func()) { + var ( + enabled bool + mu sync.Mutex + ) + return func(ctx context.Context) (bool, error) { + mu.Lock() + defer mu.Unlock() + return enabled, nil + }, func() { + mu.Lock() + defer mu.Unlock() + enabled = true + }, func() { + mu.Lock() + defer mu.Unlock() + enabled = false + } +} diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 301f1c1d6345f..9975612c12c3d 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -26,6 +26,7 @@ import ( "errors" "fmt" "io" + "log/slog" "maps" "net" "os" @@ -62,6 +63,7 @@ import ( "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" + authorizedkeysreporter "github.com/gravitational/teleport/lib/secretsscanner/authorizedkeys" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" @@ -850,6 +852,12 @@ func New( } s.srv = server + if !s.proxyMode { + if err := s.startAuthorizedKeysManager(ctx, auth); err != nil { + log.WithError(err).Infof("Failed to start authorized keys manager.") + } + } + var heartbeatMode srv.HeartbeatMode if s.proxyMode { heartbeatMode = srv.HeartbeatModeProxy @@ -904,6 +912,31 @@ func (s *Server) tunnelWithAccessChecker(ctx *srv.ServerContext) (reversetunnelc return reversetunnelclient.NewTunnelWithRoles(s.proxyTun, clusterName.GetClusterName(), ctx.Identity.AccessChecker, s.proxyAccessPoint), nil } +// startAuthorizedKeysManager starts the authorized keys manager. +func (s *Server) startAuthorizedKeysManager(ctx context.Context, auth authclient.ClientI) error { + authorizedKeysWatcher, err := authorizedkeysreporter.NewWatcher( + ctx, + authorizedkeysreporter.WatcherConfig{ + Client: auth, + Logger: slog.Default(), + HostID: s.uuid, + Clock: s.clock, + }, + ) + if errors.Is(err, authorizedkeysreporter.ErrUnsupportedPlatform) { + return nil + } else if err != nil { + return trace.Wrap(err) + } + + go func() { + if err := authorizedKeysWatcher.Run(ctx); err != nil { + s.Warningf("Failed to start authorized keys watcher: %v", err) + } + }() + return nil +} + // Context returns server shutdown context func (s *Server) Context() context.Context { return s.ctx