Skip to content

Commit

Permalink
handle code review
Browse files Browse the repository at this point in the history
  • Loading branch information
tigrato committed Jul 19, 2024
1 parent 35e68a1 commit 7138ae2
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 70 deletions.
8 changes: 4 additions & 4 deletions lib/secretsscanner/reporter/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func setup(t *testing.T, ops ...option) env {
assert.NoError(t, err)
})

svc := newServiceFake(dtFakeSvc)
svc := newServiceFake(dtFakeSvc.Service)
svc.assertionError = o.assertionError
svc.preReconcileError = o.preReconcileError

Expand Down Expand Up @@ -127,7 +127,7 @@ func setup(t *testing.T, ops ...option) env {
}
}

func newServiceFake(deviceTrustSvc *dttestenv.E) *serviceFake {
func newServiceFake(deviceTrustSvc *dttestenv.FakeDeviceService) *serviceFake {
return &serviceFake{
deviceTrustSvc: deviceTrustSvc,
}
Expand All @@ -136,7 +136,7 @@ func newServiceFake(deviceTrustSvc *dttestenv.E) *serviceFake {
type serviceFake struct {
accessgraphsecretsv1pb.UnimplementedSecretsScannerServiceServer
privateKeysReported []*accessgraphsecretsv1pb.PrivateKey
deviceTrustSvc *dttestenv.E
deviceTrustSvc *dttestenv.FakeDeviceService
assertionError error
preReconcileError error
}
Expand All @@ -146,7 +146,7 @@ func (s *serviceFake) ReportSecrets(in accessgraphsecretsv1pb.SecretsScannerServ
return s.assertionError
}
// Step 1. Assert the device.
if _, err := s.deviceTrustSvc.Service.AssertDevice(in.Context(), streamAdapter{stream: in}); err != nil {
if _, err := s.deviceTrustSvc.AssertDevice(in.Context(), streamAdapter{stream: in}); err != nil {
return trace.Wrap(err)
}
// Step 2. Collect the private keys into a temporary slice.
Expand Down
34 changes: 9 additions & 25 deletions lib/secretsscanner/reporter/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,46 +32,41 @@ import (
secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client"
)

const (
defaultBatchSize = 100
)

// AssertCeremonyBuilderFunc is a function that builds the device authentication ceremony.
type AssertCeremonyBuilderFunc func() (*dtassert.Ceremony, error)

// Config specifies the configuration for the reporter.
type Config struct {
// ProxyAddress is the address of the proxy server to send the report to.
ProxyAddress string
// InsecureSkipVerify specifies whether to trust the certificates from the Proxy Server during registration without verification.
InsecureSkipVerify bool
// Client is a client for the SecretsScannerService.
Client secretsscannerclient.Client
// Log is the logger.
Log *slog.Logger
// BatchSize is the number of secrets to send in a single batch. Defaults to [defaultBatchSize] if not set.
BatchSize int
// AssertCeremonyBuilder is the device authentication ceremony builder.
// If not set, the default device authentication ceremony will be used.
// Used for testing, avoid in production code.
AssertCeremonyBuilder AssertCeremonyBuilderFunc
}

// Reporter reports secrets to the Teleport Proxy.
type Reporter struct {
proxyAddr string
insecureSkipVerify bool
client secretsscannerclient.Client
log *slog.Logger
batchSize int
assertCeremonyBuilder AssertCeremonyBuilderFunc
}

// New creates a new reporter instance.
func New(cfg Config) (*Reporter, error) {
if cfg.ProxyAddress == "" {
return nil, trace.BadParameter("missing proxy address")
if cfg.Client == nil {
return nil, trace.BadParameter("missing client")
}
if cfg.Log == nil {
cfg.Log = slog.Default()
}
if cfg.BatchSize == 0 {
const defaultBatchSize = 100
cfg.BatchSize = defaultBatchSize
}
if cfg.AssertCeremonyBuilder == nil {
Expand All @@ -80,8 +75,7 @@ func New(cfg Config) (*Reporter, error) {
}
}
return &Reporter{
proxyAddr: cfg.ProxyAddress,
insecureSkipVerify: cfg.InsecureSkipVerify,
client: cfg.Client,
log: cfg.Log,
batchSize: cfg.BatchSize,
assertCeremonyBuilder: cfg.AssertCeremonyBuilder,
Expand All @@ -95,18 +89,8 @@ func New(cfg Config) (*Reporter, error) {
// 3. Report the private keys to the Teleport cluster.
// 4. Wait for the server to acknowledge the report.
func (r *Reporter) ReportPrivateKeys(ctx context.Context, pks []*accessgraphsecretsv1pb.PrivateKey) error {
client, err := secretsscannerclient.NewSecretsScannerServiceClient(
ctx,
secretsscannerclient.ClientConfig{
ProxyServer: r.proxyAddr,
Insecure: r.insecureSkipVerify,
Log: slog.Default(),
})
if err != nil {
return trace.Wrap(err, "failed to create client")
}

stream, err := client.ReportSecrets(ctx)
stream, err := r.client.ReportSecrets(ctx)
if err != nil {
return trace.Wrap(err, "failed to create client")
}
Expand Down
26 changes: 19 additions & 7 deletions lib/secretsscanner/reporter/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
dtassert "github.com/gravitational/teleport/lib/devicetrust/assert"
dtauthn "github.com/gravitational/teleport/lib/devicetrust/authn"
dttestenv "github.com/gravitational/teleport/lib/devicetrust/testenv"
secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client"
"github.com/gravitational/teleport/lib/secretsscanner/reporter"
)

Expand Down Expand Up @@ -66,7 +67,7 @@ func TestReporter(t *testing.T) {
name: "assertion error",
assertionError: errors.New("assertion error"),
report: newPrivateKeys(t, deviceID),
assertErr: func(t require.TestingT, err error, i ...any) {
assertErr: func(t require.TestingT, err error, _ ...any) {
require.ErrorContains(t, err, "assertion error")

},
Expand All @@ -75,27 +76,38 @@ func TestReporter(t *testing.T) {
name: "pre-reconcile error",
preReconcileError: errors.New("pre-reconcile error"),
report: newPrivateKeys(t, deviceID),
assertErr: func(t require.TestingT, err error, i ...any) {
assertErr: func(t require.TestingT, err error, _ ...any) {
require.ErrorContains(t, err, "pre-reconcile error")
},
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
e := setup(
t,
withDevice(deviceID, device),
withAssertionError(tt.assertionError),
withPreReconcileError(tt.preReconcileError),
)

ctx := context.Background()

client, err := secretsscannerclient.NewSecretsScannerServiceClient(ctx,
secretsscannerclient.ClientConfig{
ProxyServer: e.secretsScannerAddr,
Insecure: true,
},
)
require.NoError(t, err)

r, err := reporter.New(
reporter.Config{
ProxyAddress: e.secretsScannerAddr,
Log: slog.Default(),
InsecureSkipVerify: true, /* insecureSkipVerify for tests */
BatchSize: 1, /* batch size for tests */
Log: slog.Default(),
Client: client,
BatchSize: 1, /* batch size for tests */
AssertCeremonyBuilder: func() (*dtassert.Ceremony, error) {
return dtassert.NewCeremony(
dtassert.WithNewAuthnCeremonyFunc(
Expand All @@ -117,7 +129,7 @@ func TestReporter(t *testing.T) {
)
require.NoError(t, err)

err = r.ReportPrivateKeys(context.Background(), tt.report)
err = r.ReportPrivateKeys(ctx, tt.report)
tt.assertErr(t, err)

got := e.service.privateKeysReported
Expand Down
44 changes: 25 additions & 19 deletions lib/secretsscanner/scaner/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ import (
type Config struct {
// Dirs is a list of directories to scan.
Dirs []string
// SkipDirs is a list of directories to skip.
// SkipPaths is a list of paths to skip.
// It supports glob patterns (e.g. "/etc/*/").
// Please refer to the [filepath.Match] documentation for more information.
SkipDirs []string
SkipPaths []string
// Log is the logger.
Log *slog.Logger
}
Expand All @@ -58,26 +58,26 @@ func New(cfg Config) (*Scanner, error) {
cfg.Log = slog.Default()
}

// expand the glob patterns in the skipDirs list.
// expand the glob patterns in the skipPaths list.
// we expand the glob patterns here to avoid expanding them for each file during the scan.
// only the directories matched by the glob patterns will be skipped.
skippedDirs, err := expandSkipDirs(cfg.SkipDirs)
skippedPaths, err := expandSkipPaths(cfg.SkipPaths)
if err != nil {
return nil, trace.Wrap(err)
}

return &Scanner{
dirs: cfg.Dirs,
log: cfg.Log,
skippedDirs: skippedDirs,
dirs: cfg.Dirs,
log: cfg.Log,
skippedPaths: skippedPaths,
}, nil
}

// Scanner is a scanner that scans directories for secrets.
type Scanner struct {
dirs []string
log *slog.Logger
skippedDirs map[string]struct{}
dirs []string
log *slog.Logger
skippedPaths map[string]struct{}
}

// ScanPrivateKeys scans directories for SSH private keys.
Expand Down Expand Up @@ -118,12 +118,18 @@ func (s *Scanner) findPrivateKeys(ctx context.Context, root, deviceID string, pr
return fs.SkipDir
}
if info.IsDir() {
if _, ok := s.skippedDirs[path]; ok {
if _, ok := s.skippedPaths[path]; ok {
logger.DebugContext(ctx, "skipping directory", "path", path)
return fs.SkipDir
}
return nil
}

if _, ok := s.skippedPaths[path]; ok {
logger.DebugContext(ctx, "skipping file", "path", path)
return nil
}

switch fileData, isKey, err := s.readFileIfSSHPrivateKey(ctx, path); {
case err != nil:
logger.DebugContext(ctx, "error reading file", "path", path, "error", err)
Expand Down Expand Up @@ -161,7 +167,7 @@ func (s *Scanner) readFileIfSSHPrivateKey(ctx context.Context, filePath string)
}
defer func() {
if err = file.Close(); err != nil {
s.log.WarnContext(ctx, "failed to close file", "path", filePath, "error", err)
s.log.DebugContext(ctx, "failed to close file", "path", filePath, "error", err)
}
}()

Expand Down Expand Up @@ -275,18 +281,18 @@ func privateKeyNameGen(path, deviceID, fingerprint string) string {
return hex.EncodeToString(sha.Sum(nil))
}

// expandSkipDirs expands the glob patterns in the skipDirs list and returns a set of the
// directories matched by the glob patterns to be skipped.
func expandSkipDirs(skipDirs []string) (map[string]struct{}, error) {
skippedDirs := make(map[string]struct{})
for _, glob := range skipDirs {
// expandSkipPaths expands the glob patterns in the skipPaths list and returns a set of the
// paths matched by the glob patterns to be skipped.
func expandSkipPaths(skipPaths []string) (map[string]struct{}, error) {
set := make(map[string]struct{})
for _, glob := range skipPaths {
matches, err := filepath.Glob(glob)
if err != nil {
return nil, trace.Wrap(err, "glob pattern %q is invalid", glob)
}
for _, match := range matches {
skippedDirs[match] = struct{}{}
set[match] = struct{}{}
}
}
return skippedDirs, nil
return set, nil
}
8 changes: 4 additions & 4 deletions lib/secretsscanner/scaner/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ func TestNewScanner(t *testing.T) {

expect := tt.keysGen(t, dir)

var skipDirs []string
var skipPaths []string
if tt.skipTestDir {
// skip the test directory.
skipDirs = []string{filepath.Join(dir, "*")}
skipPaths = []string{filepath.Join(dir, "*")}
// the expected keys should be nil since the test directory is skipped.
expect = nil
}

s, err := New(Config{
Dirs: []string{dir},
SkipDirs: skipDirs,
Dirs: []string{dir},
SkipPaths: skipPaths,
})
require.NoError(t, err)

Expand Down
33 changes: 22 additions & 11 deletions tool/tsh/common/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/gravitational/teleport/api/types/accessgraph"
"github.com/gravitational/teleport/lib/devicetrust/assert"
dtnative "github.com/gravitational/teleport/lib/devicetrust/native"
secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client"
secretsreporter "github.com/gravitational/teleport/lib/secretsscanner/reporter"
secretsscanner "github.com/gravitational/teleport/lib/secretsscanner/scaner"
)
Expand All @@ -51,16 +52,16 @@ func newScanCommand(app *kingpin.Application) scanCommand {

type scanKeysCommand struct {
*kingpin.CmdClause
dirs []string
skipDirs []string
ca string
out io.Writer
dirs []string
skipPaths []string
ca string
out io.Writer
}

func newScanKeysCommand(parent *kingpin.CmdClause) *scanKeysCommand {
c := &scanKeysCommand{CmdClause: parent.Command("keys", "Scan the local machine for SSH private keys and report findings to Teleport")}
c.Flag("dirs", "Directories to scan.").Default(defaultDirValues()).StringsVar(&c.dirs)
c.Flag("skip-dirs", "Directories to skip. Supports matching patterns.").StringsVar(&c.skipDirs)
c.Flag("skip-paths", "Paths to directories or files to skip. Supports for matching patterns.").StringsVar(&c.skipPaths)
return c
}

Expand Down Expand Up @@ -96,9 +97,9 @@ func (c *scanKeysCommand) run(cf *CLIConf) error {
fmt.Printf("Device trust credentials found.\nScanning %s.\n", strings.Join(c.dirs, ", "))

scanner, err := secretsscanner.New(secretsscanner.Config{
Dirs: c.dirs,
SkipDirs: c.skipDirs,
Log: slog.Default(),
Dirs: c.dirs,
SkipPaths: c.skipPaths,
Log: slog.Default(),
})
if err != nil {
return trace.Wrap(err, "failed to create scanner")
Expand All @@ -111,11 +112,21 @@ func (c *scanKeysCommand) run(cf *CLIConf) error {

printPrivateKeys(privateKeys)

client, err := secretsscannerclient.NewSecretsScannerServiceClient(
ctx,
secretsscannerclient.ClientConfig{
ProxyServer: cf.Proxy,
Insecure: cf.InsecureSkipVerify,
Log: slog.Default(),
})
if err != nil {
return trace.Wrap(err, "failed to create client")
}

reporter, err := secretsreporter.New(
secretsreporter.Config{
ProxyAddress: cf.Proxy,
Log: slog.Default(),
InsecureSkipVerify: cf.InsecureSkipVerify,
Client: client,
Log: slog.Default(),
AssertCeremonyBuilder: func() (*assert.Ceremony, error) {
return assert.NewCeremony()
},
Expand Down

0 comments on commit 7138ae2

Please sign in to comment.