diff --git a/pkg/detectors/postgres/postgres.go b/pkg/detectors/postgres/postgres.go index 8f25ed3def71..055a85125eb7 100644 --- a/pkg/detectors/postgres/postgres.go +++ b/pkg/detectors/postgres/postgres.go @@ -21,18 +21,19 @@ import ( const ( defaultPort = "5432" - pg_connect_timeout = "connect_timeout" - pg_dbname = "dbname" - pg_host = "host" - pg_password = "password" - pg_port = "port" - pg_requiressl = "requiressl" - pg_sslmode = "sslmode" - pg_sslmode_allow = "allow" - pg_sslmode_disable = "disable" - pg_sslmode_prefer = "prefer" - pg_sslmode_require = "require" - pg_user = "user" + pgConnectTimeout = "connect_timeout" + pgDbname = "dbname" + pgHost = "host" + pgPassword = "password" + pgPort = "port" + pgRequiressl = "requiressl" + pgSslmode = "sslmode" + pgSslmodeAllow = "allow" + pgSslmodeDisable = "disable" + pgSslmodePrefer = "prefer" + pgSslmodeRequire = "require" + pgUser = "user" + pgDbType = "db_type" ) // This detector currently only finds Postgres connection string URIs @@ -47,7 +48,7 @@ const ( // happen to run into a case where this matters we can address it then. var ( _ detectors.Detector = (*Scanner)(nil) - uriPattern = regexp.MustCompile(`\b(?i)postgres(?:ql)?://\S+\b`) + uriPattern = regexp.MustCompile(`\b(?i)(postgres(?:ql)?)://\S+\b`) connStrPartPattern = regexp.MustCompile(`([[:alpha:]]+)='(.+?)' ?`) ) @@ -71,17 +72,17 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete if common.IsDone(ctx) { break } - user, ok := params[pg_user] + user, ok := params[pgUser] if !ok { continue } - password, ok := params[pg_password] + password, ok := params[pgPassword] if !ok { continue } - host, ok := params[pg_host] + host, ok := params[pgHost] if !ok { continue } @@ -94,13 +95,18 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete } } - port, ok := params[pg_port] + port, ok := params[pgPort] if !ok { port = defaultPort - params[pg_port] = port + params[pgPort] = port } - raw := []byte(fmt.Sprintf("postgresql://%s:%s@%s:%s", user, password, host, port)) + const defaultDBType = "postgresql" + dbType, ok := params[pgDbType] + if !ok { + dbType = defaultDBType + } + raw := []byte(fmt.Sprintf("%s://%s:%s@%s:%s", dbType, user, password, host, port)) result := detectors.Result{ DetectorType: detectorspb.DetectorType_Postgres, @@ -112,17 +118,17 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete // do it for us - but we will do it anyway here so that when we later capture sslmode into ExtraData we will // capture it post-normalization. (The detector's behavior is undefined for candidate secrets that have both // requiressl and sslmode set.) - if requiressl := params[pg_requiressl]; requiressl == "0" { - params[pg_sslmode] = pg_sslmode_prefer + if requiressl := params[pgRequiressl]; requiressl == "0" { + params[pgSslmode] = pgSslmodePrefer } else if requiressl == "1" { - params[pg_sslmode] = pg_sslmode_require + params[pgSslmode] = pgSslmodeRequire } if verify { // pq appears to ignore the context deadline, so we copy any timeout that's been set into the connection // parameters themselves. if timeout, ok := getDeadlineInSeconds(ctx); ok && timeout > 0 { - params[pg_connect_timeout] = strconv.Itoa(timeout) + params[pgConnectTimeout] = strconv.Itoa(timeout) } else if timeout <= 0 { // Deadline in the context has already exceeded. break @@ -137,12 +143,12 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete } // We gather SSL information into ExtraData in case it's useful for later reporting. - sslmode := params[pg_sslmode] + sslmode := params[pgSslmode] if sslmode == "" { sslmode = "" } result.ExtraData = map[string]string{ - pg_sslmode: sslmode, + pgSslmode: sslmode, } results = append(results, result) @@ -158,6 +164,13 @@ func (s Scanner) IsFalsePositive(_ detectors.Result) (bool, string) { func findUriMatches(data []byte) []map[string]string { var matches []map[string]string for _, uri := range uriPattern.FindAll(data, -1) { + // Capture the database type (e.g., "postgres" or "postgresql") + dbTypeMatch := uriPattern.FindSubmatch(uri) + if len(dbTypeMatch) < 2 { + continue + } + dbType := string(dbTypeMatch[1]) + connStr, err := pq.ParseURL(string(uri)) if err != nil { continue @@ -169,6 +182,7 @@ func findUriMatches(data []byte) []map[string]string { params[part[1]] = part[2] } + params[pgDbType] = dbType matches = append(matches, params) } return matches @@ -198,14 +212,14 @@ func isErrorDatabaseNotFound(err error, dbName string) bool { } func verifyPostgres(params map[string]string) (bool, error) { - if sslmode := params[pg_sslmode]; sslmode == pg_sslmode_allow || sslmode == pg_sslmode_prefer { + if sslmode := params[pgSslmode]; sslmode == pgSslmodeAllow || sslmode == pgSslmodePrefer { // pq doesn't support 'allow' or 'prefer'. If we find either of them, we'll just ignore it. This will trigger // the same logic that is run if no sslmode is set at all (which mimics 'prefer', which is the default). - delete(params, pg_sslmode) + delete(params, pgSslmode) // We still want to save the original sslmode in ExtraData, so we'll re-add it before returning. defer func() { - params[pg_sslmode] = sslmode + params[pgSslmode] = sslmode }() } @@ -226,14 +240,14 @@ func verifyPostgres(params map[string]string) (bool, error) { return true, nil case strings.Contains(err.Error(), "password authentication failed"): return false, nil - case errors.Is(err, pq.ErrSSLNotSupported) && params[pg_sslmode] == "": + case errors.Is(err, pq.ErrSSLNotSupported) && params[pgSslmode] == "": // If the sslmode is unset, then either it was unset in the candidate secret, or we've intentionally unset it // because it was specified as 'allow' or 'prefer', neither of which pq supports. In all of these cases, non-SSL // connections are acceptable, so now we try a connection without SSL. - params[pg_sslmode] = pg_sslmode_disable - defer delete(params, pg_sslmode) // We want to return with the original params map intact (for ExtraData) + params[pgSslmode] = pgSslmodeDisable + defer delete(params, pgSslmode) // We want to return with the original params map intact (for ExtraData) return verifyPostgres(params) - case isErrorDatabaseNotFound(err, params[pg_dbname]): + case isErrorDatabaseNotFound(err, params[pgDbname]): return true, nil // If we know this, we were able to authenticate default: return false, err diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go index 5901e67d2d78..df12b2109409 100644 --- a/pkg/engine/engine_test.go +++ b/pkg/engine/engine_test.go @@ -1189,3 +1189,82 @@ func TestEngineInitializesCloudProviderDetectors(t *testing.T) { t.Fatal("no detectors found implementing Endpoints(), did EndpointSetter change?") } } + +func TestEngineignoreLine(t *testing.T) { + tests := []struct { + name string + content string + expectedFindings int + }{ + { + name: "ignore at end of line", + content: ` +# tests/example_false_positive.py + +def test_something(): + connection_string = "who-cares" + + # Ignoring this does not work + assert connection_string == "postgres://master_user:master_password@hostname:1234/main" # trufflehog:ignore`, + expectedFindings: 0, + }, + { + name: "ignore not on secret line", + content: ` +# tests/example_false_positive.py + +def test_something(): + connection_string = "who-cares" + + # Ignoring this does not work + assert some_other_stuff == "blah" # trufflehog:ignore + assert connection_string == "postgres://master_user:master_password@hostname:1234/main"`, + expectedFindings: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tmpFile, err := os.CreateTemp("", "test_creds") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + err = os.WriteFile(tmpFile.Name(), []byte(tt.content), os.ModeAppend) + assert.NoError(t, err) + + const defaultOutputBufferSize = 64 + opts := []func(*sources.SourceManager){ + sources.WithSourceUnits(), + sources.WithBufferedOutput(defaultOutputBufferSize), + } + + sourceManager := sources.NewManager(opts...) + + conf := Config{ + Concurrency: 1, + Decoders: decoders.DefaultDecoders(), + Detectors: DefaultDetectors(), + Verify: false, + SourceManager: sourceManager, + Dispatcher: NewPrinterDispatcher(new(discardPrinter)), + } + + eng, err := NewEngine(ctx, &conf) + assert.NoError(t, err) + + eng.Start(ctx) + + cfg := sources.FilesystemConfig{Paths: []string{tmpFile.Name()}} + err = eng.ScanFileSystem(ctx, cfg) + assert.NoError(t, err) + + assert.NoError(t, eng.Finish(ctx)) + assert.Equal(t, tt.expectedFindings, int(eng.GetMetrics().UnverifiedSecretsFound)) + }) + } +}