Skip to content

Commit

Permalink
[bug] - correctly capture db type for postgres detector (#3610)
Browse files Browse the repository at this point in the history
* correctly capture db type for postgres detector.

* use const and rename other consts
  • Loading branch information
ahrav authored Nov 15, 2024
1 parent 8f2ebc9 commit cca7e6b
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 32 deletions.
78 changes: 46 additions & 32 deletions pkg/detectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import (
const (
defaultPort = "5432"

pg_connect_timeout = "connect_timeout"
pg_dbname = "dbname"
pg_host = "host"
pg_password = "password"
pg_port = "port"
pg_requiressl = "requiressl"
pg_sslmode = "sslmode"
pg_sslmode_allow = "allow"
pg_sslmode_disable = "disable"
pg_sslmode_prefer = "prefer"
pg_sslmode_require = "require"
pg_user = "user"
pgConnectTimeout = "connect_timeout"
pgDbname = "dbname"
pgHost = "host"
pgPassword = "password"
pgPort = "port"
pgRequiressl = "requiressl"
pgSslmode = "sslmode"
pgSslmodeAllow = "allow"
pgSslmodeDisable = "disable"
pgSslmodePrefer = "prefer"
pgSslmodeRequire = "require"
pgUser = "user"
pgDbType = "db_type"
)

// This detector currently only finds Postgres connection string URIs
Expand All @@ -47,7 +48,7 @@ const (
// happen to run into a case where this matters we can address it then.
var (
_ detectors.Detector = (*Scanner)(nil)
uriPattern = regexp.MustCompile(`\b(?i)postgres(?:ql)?://\S+\b`)
uriPattern = regexp.MustCompile(`\b(?i)(postgres(?:ql)?)://\S+\b`)
connStrPartPattern = regexp.MustCompile(`([[:alpha:]]+)='(.+?)' ?`)
)

Expand All @@ -71,17 +72,17 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
if common.IsDone(ctx) {
break
}
user, ok := params[pg_user]
user, ok := params[pgUser]
if !ok {
continue
}

password, ok := params[pg_password]
password, ok := params[pgPassword]
if !ok {
continue
}

host, ok := params[pg_host]
host, ok := params[pgHost]
if !ok {
continue
}
Expand All @@ -94,13 +95,18 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
}
}

port, ok := params[pg_port]
port, ok := params[pgPort]
if !ok {
port = defaultPort
params[pg_port] = port
params[pgPort] = port
}

raw := []byte(fmt.Sprintf("postgresql://%s:%s@%s:%s", user, password, host, port))
const defaultDBType = "postgresql"
dbType, ok := params[pgDbType]
if !ok {
dbType = defaultDBType
}
raw := []byte(fmt.Sprintf("%s://%s:%s@%s:%s", dbType, user, password, host, port))

result := detectors.Result{
DetectorType: detectorspb.DetectorType_Postgres,
Expand All @@ -112,17 +118,17 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
// do it for us - but we will do it anyway here so that when we later capture sslmode into ExtraData we will
// capture it post-normalization. (The detector's behavior is undefined for candidate secrets that have both
// requiressl and sslmode set.)
if requiressl := params[pg_requiressl]; requiressl == "0" {
params[pg_sslmode] = pg_sslmode_prefer
if requiressl := params[pgRequiressl]; requiressl == "0" {
params[pgSslmode] = pgSslmodePrefer
} else if requiressl == "1" {
params[pg_sslmode] = pg_sslmode_require
params[pgSslmode] = pgSslmodeRequire
}

if verify {
// pq appears to ignore the context deadline, so we copy any timeout that's been set into the connection
// parameters themselves.
if timeout, ok := getDeadlineInSeconds(ctx); ok && timeout > 0 {
params[pg_connect_timeout] = strconv.Itoa(timeout)
params[pgConnectTimeout] = strconv.Itoa(timeout)
} else if timeout <= 0 {
// Deadline in the context has already exceeded.
break
Expand All @@ -137,12 +143,12 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
}

// We gather SSL information into ExtraData in case it's useful for later reporting.
sslmode := params[pg_sslmode]
sslmode := params[pgSslmode]
if sslmode == "" {
sslmode = "<unset>"
}
result.ExtraData = map[string]string{
pg_sslmode: sslmode,
pgSslmode: sslmode,
}

results = append(results, result)
Expand All @@ -158,6 +164,13 @@ func (s Scanner) IsFalsePositive(_ detectors.Result) (bool, string) {
func findUriMatches(data []byte) []map[string]string {
var matches []map[string]string
for _, uri := range uriPattern.FindAll(data, -1) {
// Capture the database type (e.g., "postgres" or "postgresql")
dbTypeMatch := uriPattern.FindSubmatch(uri)
if len(dbTypeMatch) < 2 {
continue
}
dbType := string(dbTypeMatch[1])

connStr, err := pq.ParseURL(string(uri))
if err != nil {
continue
Expand All @@ -169,6 +182,7 @@ func findUriMatches(data []byte) []map[string]string {
params[part[1]] = part[2]
}

params[pgDbType] = dbType
matches = append(matches, params)
}
return matches
Expand Down Expand Up @@ -198,14 +212,14 @@ func isErrorDatabaseNotFound(err error, dbName string) bool {
}

func verifyPostgres(params map[string]string) (bool, error) {
if sslmode := params[pg_sslmode]; sslmode == pg_sslmode_allow || sslmode == pg_sslmode_prefer {
if sslmode := params[pgSslmode]; sslmode == pgSslmodeAllow || sslmode == pgSslmodePrefer {
// pq doesn't support 'allow' or 'prefer'. If we find either of them, we'll just ignore it. This will trigger
// the same logic that is run if no sslmode is set at all (which mimics 'prefer', which is the default).
delete(params, pg_sslmode)
delete(params, pgSslmode)

// We still want to save the original sslmode in ExtraData, so we'll re-add it before returning.
defer func() {
params[pg_sslmode] = sslmode
params[pgSslmode] = sslmode
}()
}

Expand All @@ -226,14 +240,14 @@ func verifyPostgres(params map[string]string) (bool, error) {
return true, nil
case strings.Contains(err.Error(), "password authentication failed"):
return false, nil
case errors.Is(err, pq.ErrSSLNotSupported) && params[pg_sslmode] == "":
case errors.Is(err, pq.ErrSSLNotSupported) && params[pgSslmode] == "":
// If the sslmode is unset, then either it was unset in the candidate secret, or we've intentionally unset it
// because it was specified as 'allow' or 'prefer', neither of which pq supports. In all of these cases, non-SSL
// connections are acceptable, so now we try a connection without SSL.
params[pg_sslmode] = pg_sslmode_disable
defer delete(params, pg_sslmode) // We want to return with the original params map intact (for ExtraData)
params[pgSslmode] = pgSslmodeDisable
defer delete(params, pgSslmode) // We want to return with the original params map intact (for ExtraData)
return verifyPostgres(params)
case isErrorDatabaseNotFound(err, params[pg_dbname]):
case isErrorDatabaseNotFound(err, params[pgDbname]):
return true, nil // If we know this, we were able to authenticate
default:
return false, err
Expand Down
79 changes: 79 additions & 0 deletions pkg/engine/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1189,3 +1189,82 @@ func TestEngineInitializesCloudProviderDetectors(t *testing.T) {
t.Fatal("no detectors found implementing Endpoints(), did EndpointSetter change?")
}
}

func TestEngineignoreLine(t *testing.T) {
tests := []struct {
name string
content string
expectedFindings int
}{
{
name: "ignore at end of line",
content: `
# tests/example_false_positive.py
def test_something():
connection_string = "who-cares"
# Ignoring this does not work
assert connection_string == "postgres://master_user:master_password@hostname:1234/main" # trufflehog:ignore`,
expectedFindings: 0,
},
{
name: "ignore not on secret line",
content: `
# tests/example_false_positive.py
def test_something():
connection_string = "who-cares"
# Ignoring this does not work
assert some_other_stuff == "blah" # trufflehog:ignore
assert connection_string == "postgres://master_user:master_password@hostname:1234/main"`,
expectedFindings: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

tmpFile, err := os.CreateTemp("", "test_creds")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())

err = os.WriteFile(tmpFile.Name(), []byte(tt.content), os.ModeAppend)
assert.NoError(t, err)

const defaultOutputBufferSize = 64
opts := []func(*sources.SourceManager){
sources.WithSourceUnits(),
sources.WithBufferedOutput(defaultOutputBufferSize),
}

sourceManager := sources.NewManager(opts...)

conf := Config{
Concurrency: 1,
Decoders: decoders.DefaultDecoders(),
Detectors: DefaultDetectors(),

Check failure on line 1251 in pkg/engine/engine_test.go

View workflow job for this annotation

GitHub Actions / test

undefined: DefaultDetectors

Check failure on line 1251 in pkg/engine/engine_test.go

View workflow job for this annotation

GitHub Actions / test

undefined: DefaultDetectors
Verify: false,
SourceManager: sourceManager,
Dispatcher: NewPrinterDispatcher(new(discardPrinter)),
}

eng, err := NewEngine(ctx, &conf)
assert.NoError(t, err)

eng.Start(ctx)

cfg := sources.FilesystemConfig{Paths: []string{tmpFile.Name()}}
err = eng.ScanFileSystem(ctx, cfg)
assert.NoError(t, err)

assert.NoError(t, eng.Finish(ctx))
assert.Equal(t, tt.expectedFindings, int(eng.GetMetrics().UnverifiedSecretsFound))
})
}
}

0 comments on commit cca7e6b

Please sign in to comment.