diff --git a/lib/backend/pgbk/atomicwrite.go b/lib/backend/pgbk/atomicwrite.go index 4e263c70ab07c..78d08e0b648b8 100644 --- a/lib/backend/pgbk/atomicwrite.go +++ b/lib/backend/pgbk/atomicwrite.go @@ -101,9 +101,9 @@ func (b *Backend) AtomicWrite(ctx context.Context, condacts []backend.Conditiona return trace.Wrap(row.Scan(&success)) } - var tries int + var attempts int err = pgcommon.RetryTx(ctx, b.log, b.pool, pgx.TxOptions{}, false, func(tx pgx.Tx) error { - tries++ + attempts++ var condBatch, actBatch pgx.Batch for _, bi := range condBatchItems { @@ -130,14 +130,15 @@ func (b *Backend) AtomicWrite(ctx context.Context, condacts []backend.Conditiona return nil }) - if tries > 1 { - backend.AtomicWriteContention.WithLabelValues(b.GetName()).Add(float64(tries - 1)) + if attempts > 1 { + backend.AtomicWriteContention.WithLabelValues(b.GetName()).Add(float64(attempts - 1)) } - if tries > 2 { - // if we retried more than once, txn experienced non-trivial conflict and we should warn about it. Infrequent warnings of this kind - // are nothing to be concerned about, but high volumes may indicate that an automatic process is creating excessive conflicts. - b.log.Warnf("AtomicWrite retried %d times due to postgres transaction contention. Some conflict is expected, but persistent conflict warnings may indicate an unhealthy state.", tries) + if attempts > 2 { + b.log.WarnContext(ctx, + "AtomicWrite was retried several times due to transaction contention. 
Some conflict is expected, but persistent conflict warnings may indicate an unhealthy state.", + "attempts", attempts, + ) } if err != nil { diff --git a/lib/backend/pgbk/background.go b/lib/backend/pgbk/background.go index 15fcf06d17657..5a0daebe564d9 100644 --- a/lib/backend/pgbk/background.go +++ b/lib/backend/pgbk/background.go @@ -22,13 +22,13 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "time" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/lib/backend" pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common" @@ -36,7 +36,7 @@ import ( ) func (b *Backend) backgroundExpiry(ctx context.Context) { - defer b.log.Info("Exited expiry loop.") + defer b.log.InfoContext(ctx, "Exited expiry loop.") for ctx.Err() == nil { // "DELETE FROM kv WHERE expires <= now()" but more complicated: logical @@ -71,15 +71,15 @@ func (b *Backend) backgroundExpiry(ctx context.Context) { return tag.RowsAffected(), nil }) if err != nil { - b.log.WithError(err).Error("Failed to delete expired items.") + b.log.ErrorContext(ctx, "Failed to delete expired items.", "error", err) break } if deleted > 0 { - b.log.WithFields(logrus.Fields{ - "deleted": deleted, - "elapsed": time.Since(t0).String(), - }).Debug("Deleted expired items.") + b.log.DebugContext(ctx, "Deleted expired items.", + "deleted", deleted, + "elapsed", time.Since(t0), + ) } if deleted < int64(b.cfg.ExpiryBatchSize) { @@ -96,16 +96,16 @@ func (b *Backend) backgroundExpiry(ctx context.Context) { } func (b *Backend) backgroundChangeFeed(ctx context.Context) { - defer b.log.Info("Exited change feed loop.") + defer b.log.InfoContext(ctx, "Exited change feed loop.") defer b.buf.Close() for ctx.Err() == nil { - b.log.Info("Starting change feed stream.") + b.log.InfoContext(ctx, "Starting change feed stream.") err := b.runChangeFeed(ctx) if ctx.Err() != nil { break } - 
b.log.WithError(err).Error("Change feed stream lost.") + b.log.ErrorContext(ctx, "Change feed stream lost.", "error", err) select { case <-ctx.Done(): @@ -135,7 +135,7 @@ func (b *Backend) runChangeFeed(ctx context.Context) error { closeCtx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() if err := conn.Close(closeCtx); err != nil && closeCtx.Err() != nil { - b.log.WithError(err).Warn("Error closing change feed connection.") + b.log.WarnContext(ctx, "Error closing change feed connection.", "error", err) } }() if ac := b.feedConfig.AfterConnect; ac != nil { @@ -164,7 +164,7 @@ func (b *Backend) runChangeFeed(ctx context.Context) error { // permission issues, which would delete the temporary slot (it's deleted on // any error), so we have to do it before that if _, err := conn.Exec(ctx, "SET log_min_messages TO fatal", pgx.QueryExecModeExec); err != nil { - b.log.WithError(err).Debug("Failed to silence log messages for change feed session.") + b.log.DebugContext(ctx, "Failed to silence log messages for change feed session.", "error", err) } // this can be useful on Azure if we have azure_pg_admin permissions but not @@ -174,12 +174,12 @@ func (b *Backend) runChangeFeed(ctx context.Context) error { // // HACK(espadolini): ALTER ROLE CURRENT_USER crashes Postgres on Azure, so // we have to use an explicit username - if b.cfg.AuthMode == AzureADAuth && connConfig.User != "" { + if b.cfg.AuthMode == pgcommon.AzureADAuth && connConfig.User != "" { if _, err := conn.Exec(ctx, fmt.Sprintf("ALTER ROLE %v REPLICATION", pgx.Identifier{connConfig.User}.Sanitize()), pgx.QueryExecModeExec, ); err != nil { - b.log.WithError(err).Debug("Failed to enable replication for the current user.") + b.log.DebugContext(ctx, "Failed to enable replication for the current user.", "error", err) } } @@ -188,7 +188,7 @@ func (b *Backend) runChangeFeed(ctx context.Context) error { // 
https://github.com/postgres/postgres/blob/b0ec61c9c27fb932ae6524f92a18e0d1fadbc144/src/backend/replication/slot.c#L193-L194 slotName := fmt.Sprintf("teleport_%x", [16]byte(uuid.New())) - b.log.WithField("slot_name", slotName).Info("Setting up change feed.") + b.log.InfoContext(ctx, "Setting up change feed.", "slot_name", slotName) // be noisy about pg_create_logical_replication_slot taking too long, since // hanging here leaves the backend non-functional @@ -202,7 +202,7 @@ func (b *Backend) runChangeFeed(ctx context.Context) error { } cancel() - b.log.WithField("slot_name", slotName).Info("Change feed started.") + b.log.InfoContext(ctx, "Change feed started.", "slot_name", slotName) b.buf.SetInit() defer b.buf.Reset() @@ -260,10 +260,10 @@ func (b *Backend) pollChangeFeed(ctx context.Context, conn *pgx.Conn, addTables, messages := tag.RowsAffected() if messages > 0 { - b.log.WithFields(logrus.Fields{ - "messages": messages, - "elapsed": time.Since(t0).String(), - }).Debug("Fetched change feed events.") + b.log.LogAttrs(ctx, slog.LevelDebug, "Fetched change feed events.", + slog.Int64("messages", messages), + slog.Duration("elapsed", time.Since(t0)), + ) } return messages, nil diff --git a/lib/backend/pgbk/common/auth.go b/lib/backend/pgbk/common/auth.go new file mode 100644 index 0000000000000..54c5bd272f28d --- /dev/null +++ b/lib/backend/pgbk/common/auth.go @@ -0,0 +1,98 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. 
+ * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package pgcommon + +import ( + "context" + "fmt" + "log/slog" + "slices" + "strings" + + "github.com/gravitational/trace" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// AuthMode determines if we should use some environment-specific authentication +// mechanism or credentials. +type AuthMode string + +const ( + // StaticAuth uses the static credentials as defined in the connection + // string. + StaticAuth AuthMode = "" + // AzureADAuth gets a connection token from Azure and uses it as the + // password when connecting. + AzureADAuth AuthMode = "azure" + // GCPSQLIAMAuth fetches an access token and uses it as password when + // connecting to GCP SQL PostgreSQL. + GCPSQLIAMAuth AuthMode = "gcp-sql" + // GCPAlloyDBIAMAuth fetches an access token and uses it as password when + // connecting to GCP AlloyDB (PostgreSQL-compatible). + GCPAlloyDBIAMAuth AuthMode = "gcp-alloydb" +) + +var supportedAuthModes = []AuthMode{ + StaticAuth, + AzureADAuth, + GCPSQLIAMAuth, + GCPAlloyDBIAMAuth, +} + +// Check returns an error if the AuthMode is invalid. +func (a AuthMode) Check() error { + if slices.Contains(supportedAuthModes, a) { + return nil + } + + quotedModes := make([]string, 0, len(supportedAuthModes)) + for _, mode := range supportedAuthModes { + quotedModes = append(quotedModes, fmt.Sprintf("%q", mode)) + } + + return trace.BadParameter("invalid authentication mode %q, should be one of %s", a, strings.Join(quotedModes, ", ")) +} + +// ConfigurePoolConfigs configures pgxpool.Config based on the authMode.
+func (a AuthMode) ConfigurePoolConfigs(ctx context.Context, logger *slog.Logger, configs ...*pgxpool.Config) error { + if bc, err := a.getBeforeConnect(ctx, logger); err != nil { + return trace.Wrap(err) + } else if bc != nil { + for _, config := range configs { + config.BeforeConnect = bc + } + } + return nil +} + +func (a AuthMode) getBeforeConnect(ctx context.Context, logger *slog.Logger) (func(context.Context, *pgx.ConnConfig) error, error) { + switch a { + case AzureADAuth: + bc, err := AzureBeforeConnect(ctx, logger) + return bc, trace.Wrap(err) + case GCPSQLIAMAuth: + bc, err := GCPSQLBeforeConnect(ctx, logger) + return bc, trace.Wrap(err) + case GCPAlloyDBIAMAuth: + bc, err := GCPAlloyDBBeforeConnect(ctx, logger) + return bc, trace.Wrap(err) + } + return nil, nil +} diff --git a/lib/backend/pgbk/common/auth_test.go b/lib/backend/pgbk/common/auth_test.go new file mode 100644 index 0000000000000..e3170e550e8cf --- /dev/null +++ b/lib/backend/pgbk/common/auth_test.go @@ -0,0 +1,116 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */ + +package pgcommon + +import ( + "context" + "log/slog" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils" +) + +func TestMain(m *testing.M) { + utils.InitLoggerForTests() + os.Exit(m.Run()) +} + +func TestAuthMode(t *testing.T) { + mustSetGoogleApplicationCredentialsEnv(t) + mustSetAzureEnvironmentCredential(t) + + verifyBeforeConnectIsSet := func(t *testing.T, config *pgxpool.Config) { + t.Helper() + require.NotNil(t, config.BeforeConnect) + } + verifyNothingIsSet := func(t *testing.T, config *pgxpool.Config) { + t.Helper() + require.NotNil(t, config) + require.Equal(t, pgxpool.Config{}, *config) + } + + tests := []struct { + authMode AuthMode + requireCheckError require.ErrorAssertionFunc + verifyPoolConfigAfterConfigure func(*testing.T, *pgxpool.Config) + }{ + { + authMode: AuthMode("unknown-mode"), + requireCheckError: require.Error, + }, + { + authMode: StaticAuth, + requireCheckError: require.NoError, + verifyPoolConfigAfterConfigure: verifyNothingIsSet, + }, + { + authMode: AzureADAuth, + requireCheckError: require.NoError, + verifyPoolConfigAfterConfigure: verifyBeforeConnectIsSet, + }, + { + authMode: GCPSQLIAMAuth, + requireCheckError: require.NoError, + verifyPoolConfigAfterConfigure: verifyBeforeConnectIsSet, + }, + { + authMode: GCPAlloyDBIAMAuth, + requireCheckError: require.NoError, + verifyPoolConfigAfterConfigure: verifyBeforeConnectIsSet, + }, + } + + ctx := context.Background() + logger := slog.Default() + for _, tc := range tests { + t.Run(string(tc.authMode), func(t *testing.T) { + err := tc.authMode.Check() + if err != nil { + // Just checking out how the error message looks like. + t.Log(err) + } + tc.requireCheckError(t, err) + + if tc.verifyPoolConfigAfterConfigure != nil { + configs := []*pgxpool.Config{ + &pgxpool.Config{}, + &pgxpool.Config{}, + } + + err := tc.authMode.ConfigurePoolConfigs(ctx, logger, configs...) 
+ require.NoError(t, err) + + for _, config := range configs { + tc.verifyPoolConfigAfterConfigure(t, config) + } + } + }) + } +} + +func mustSetAzureEnvironmentCredential(t *testing.T) { + t.Helper() + t.Setenv("AZURE_TENANT_ID", "teleport-test-tenant-id") + t.Setenv("AZURE_CLIENT_ID", "teleport-test-client-id") + t.Setenv("AZURE_CLIENT_SECRET", "teleport-test-client-secret") +} diff --git a/lib/backend/pgbk/common/azure.go b/lib/backend/pgbk/common/azure.go index 9fef75a0c5fd9..6676820bae671 100644 --- a/lib/backend/pgbk/common/azure.go +++ b/lib/backend/pgbk/common/azure.go @@ -20,19 +20,19 @@ package pgcommon import ( "context" + "log/slog" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/gravitational/trace" "github.com/jackc/pgx/v5" - "github.com/sirupsen/logrus" ) // AzureBeforeConnect will return a pgx BeforeConnect function suitable for // Azure AD authentication. The returned function will set the password of the // connection to a token for the relevant scope. 
-func AzureBeforeConnect(log logrus.FieldLogger) (func(ctx context.Context, config *pgx.ConnConfig) error, error) { +func AzureBeforeConnect(ctx context.Context, logger *slog.Logger) (func(ctx context.Context, config *pgx.ConnConfig) error, error) { cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, trace.Wrap(err, "creating Azure credentials") @@ -48,7 +48,7 @@ func AzureBeforeConnect(log logrus.FieldLogger) (func(ctx context.Context, confi return trace.Wrap(err, "obtaining Azure authentication token") } - log.WithField("ttl", time.Until(token.ExpiresOn).String()).Debug("Acquired Azure access token.") + logger.DebugContext(ctx, "Acquired Azure access token.", "ttl", time.Until(token.ExpiresOn)) config.Password = token.Token return nil diff --git a/lib/backend/pgbk/common/gcp.go b/lib/backend/pgbk/common/gcp.go new file mode 100644 index 0000000000000..8e217de986bb5 --- /dev/null +++ b/lib/backend/pgbk/common/gcp.go @@ -0,0 +1,139 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */ + +package pgcommon + +import ( + "context" + "fmt" + "log/slog" + "time" + + credentials "cloud.google.com/go/iam/credentials/apiv1" + "cloud.google.com/go/iam/credentials/apiv1/credentialspb" + "github.com/gravitational/trace" + "github.com/jackc/pgx/v5" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + + "github.com/gravitational/teleport/lib/utils/gcp" +) + +// GCPSQLBeforeConnect returns a pgx BeforeConnect function suitable for GCP +// SQL PostgreSQL with IAM authentication. +func GCPSQLBeforeConnect(ctx context.Context, logger *slog.Logger) (func(ctx context.Context, config *pgx.ConnConfig) error, error) { + return gcpOAuthAccessTokenBeforeConnect(ctx, gcpAccessTokenGetterImpl{}, gcpSQLOAuthScope, logger) +} + +// GCPAlloyDBBeforeConnect returns a pgx BeforeConnect function suitable for GCP +// AlloyDB (PostgreSQL-compatible) with IAM authentication. +func GCPAlloyDBBeforeConnect(ctx context.Context, logger *slog.Logger) (func(ctx context.Context, config *pgx.ConnConfig) error, error) { + return gcpOAuthAccessTokenBeforeConnect(ctx, gcpAccessTokenGetterImpl{}, gcpAlloyDBOAuthScope, logger) +} + +const ( + // gcpSQLOAuthScope is the scope used for GCP SQL IAM authentication. + // https://developers.google.com/identity/protocols/oauth2/scopes#sqladmin + gcpSQLOAuthScope = "https://www.googleapis.com/auth/sqlservice.admin" + // gcpAlloyDBOAuthScope is the scope used for GCP AlloyDB IAM authentication. 
+ // https://cloud.google.com/alloydb/docs/connect-iam + gcpAlloyDBOAuthScope = "https://www.googleapis.com/auth/alloydb.login" + + gcpServiceAccountEmailSuffix = ".gserviceaccount.com" +) + +type gcpAccessTokenGetter interface { + getFromCredentials(ctx context.Context, credentials *google.Credentials) (*oauth2.Token, error) + generateForServiceAccount(ctx context.Context, serviceAccount, scope string) (string, time.Time, error) +} + +func gcpOAuthAccessTokenBeforeConnect(ctx context.Context, tokenGetter gcpAccessTokenGetter, scope string, logger *slog.Logger) (func(context.Context, *pgx.ConnConfig) error, error) { + defaultCred, err := google.FindDefaultCredentials(ctx, scope) + if err != nil { + // google.FindDefaultCredentials gives pretty error descriptions already. + return nil, trace.Wrap(err) + } + + // This function tries to capture service account emails from various + // credentials methods but may fail for some unknown scenarios. + defaultServiceAccount, err := gcp.GetServiceAccountFromCredentials(defaultCred) + if err != nil || defaultServiceAccount == "" { + logger.WarnContext(ctx, "Failed to get service account email from default google credentials. Teleport will assume the database user in the PostgreSQL connection string matches the service account of the default google credentials.", "err", err, "sa", defaultServiceAccount) + } + + return func(ctx context.Context, config *pgx.ConnConfig) error { + // IAM auth users have the PostgreSQL username of their emails minus the + // ".gserviceaccount.com" part. Now add the suffix back for the full + // service account email. + serviceAccountToAuth := config.User + gcpServiceAccountEmailSuffix + + // If the requested db user is for another service account, the + // "host"/default service account can impersonate the target service + // account as a Token Creator. This is useful when using a different + // database user for change feed. 
+ if defaultServiceAccount != "" && defaultServiceAccount != serviceAccountToAuth { + token, expires, err := tokenGetter.generateForServiceAccount(ctx, serviceAccountToAuth, scope) + if err != nil { + return trace.Wrap(err, "generating GCP access token for %v", serviceAccountToAuth) + } + + logger.DebugContext(ctx, "Generated GCP access token.", "service_account", serviceAccountToAuth, "ttl", time.Until(expires).String()) + config.Password = token + return nil + } + + token, err := tokenGetter.getFromCredentials(ctx, defaultCred) + if err != nil { + return trace.Wrap(err, "obtaining GCP access token from default credentials") + } + + logger.DebugContext(ctx, "Obtained GCP access token from default credentials.", "ttl", time.Until(token.Expiry).String(), "token_type", token.TokenType) + config.Password = token.AccessToken + return nil + }, nil +} + +type gcpAccessTokenGetterImpl struct { +} + +func (g gcpAccessTokenGetterImpl) getFromCredentials(ctx context.Context, credentials *google.Credentials) (*oauth2.Token, error) { + token, err := credentials.TokenSource.Token() + return token, trace.Wrap(err) +} + +func (g gcpAccessTokenGetterImpl) generateForServiceAccount(ctx context.Context, serviceAccount, scope string) (string, time.Time, error) { + gcpIAM, err := credentials.NewIamCredentialsClient(ctx) + if err != nil { + return "", time.Time{}, trace.Wrap(err) + } + + defer func() { + if err := gcpIAM.Close(); err != nil { + slog.DebugContext(ctx, "Failed to close GCP IAM Credentials client.", "err", err) + } + }() + + resp, err := gcpIAM.GenerateAccessToken(ctx, &credentialspb.GenerateAccessTokenRequest{ + Name: fmt.Sprintf("projects/-/serviceAccounts/%v", serviceAccount), + Scope: []string{scope}, + }) + if err != nil { + return "", time.Time{}, trace.Wrap(err) + } + return resp.AccessToken, resp.ExpireTime.AsTime(), nil +} diff --git a/lib/backend/pgbk/common/gcp_test.go b/lib/backend/pgbk/common/gcp_test.go new file mode 100644 index 0000000000000..ba181a97d3ff1 
--- /dev/null +++ b/lib/backend/pgbk/common/gcp_test.go @@ -0,0 +1,122 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +package pgcommon + +import ( + "context" + "fmt" + "log/slog" + "os" + "path" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +type fakeGCPAccessTokenGetter struct { +} + +func (f fakeGCPAccessTokenGetter) getFromCredentials(context.Context, *google.Credentials) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token-from-default-credentials", + Expiry: time.Now().Add(time.Hour), + }, nil +} + +func (f fakeGCPAccessTokenGetter) generateForServiceAccount(ctx context.Context, serviceAccount, scope string) (string, time.Time, error) { + return fmt.Sprintf("token-for-%s-with-scope-%s", serviceAccount, scope), time.Now().Add(time.Hour), nil +} + +func Test_gcpOAuthAccessTokenBeforeConnect(t *testing.T) { + mustSetGoogleApplicationCredentialsEnv(t) + + ctx := context.Background() + tokenGetter := fakeGCPAccessTokenGetter{} + tests := []struct { + name string + config *pgx.ConnConfig + wantUser string + wantPassword string + }{ + { + name: "default service account as user in connection string", + config:
&pgx.ConnConfig{ + Config: pgconn.Config{ + User: "my-service-account@teleport-example-123456.iam", + }, + }, + wantUser: "my-service-account@teleport-example-123456.iam", + wantPassword: "token-from-default-credentials", + }, + { + name: "another service account as user in connection string", + config: &pgx.ConnConfig{ + Config: pgconn.Config{ + User: "another-service-account@teleport-example-123456.iam", + }, + }, + wantUser: "another-service-account@teleport-example-123456.iam", + wantPassword: "token-for-another-service-account@teleport-example-123456.iam.gserviceaccount.com-with-scope-test-scope", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + bc, err := gcpOAuthAccessTokenBeforeConnect(ctx, tokenGetter, "test-scope", slog.Default()) + require.NoError(t, err) + + err = bc(context.Background(), tc.config) + require.NoError(t, err) + require.Equal(t, tc.wantUser, tc.config.User) + require.Equal(t, tc.wantPassword, tc.config.Password) + }) + } +} + +func mustSetGoogleApplicationCredentialsEnv(t *testing.T) { + t.Helper() + + file := path.Join(t.TempDir(), uuid.New().String()) + err := os.WriteFile(file, []byte(fakeServiceAccountCredentialsJSON), 0644) + require.NoError(t, err) + + t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", file) +} + +const ( + fakeServiceAccountCredentialsJSON = `{ + "type": "service_account", + "project_id": "teleport-example-123456", + "private_key_id": "1234569890abcdef1234567890abcdef12345678", + "private_key": "fake-private-key", + "client_email": "my-service-account@teleport-example-123456.iam.gserviceaccount.com", + "client_id": "111111111111111111111", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/my-service-account%40teleport-example-123456.iam.gserviceaccount.com", + 
"universe_domain": "googleapis.com" +}` +) diff --git a/lib/backend/pgbk/common/utils.go b/lib/backend/pgbk/common/utils.go index 314b5a1b76a31..44d22fb1f02f5 100644 --- a/lib/backend/pgbk/common/utils.go +++ b/lib/backend/pgbk/common/utils.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "log/slog" "time" "github.com/gravitational/trace" @@ -29,7 +30,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/utils/retryutils" ) @@ -63,10 +63,10 @@ func ConnectPostgres(ctx context.Context, poolConfig *pgxpool.Config) (*pgx.Conn // TryEnsureDatabase will connect to the "postgres" database and attempt to // create the database named in the pool's configuration. -func TryEnsureDatabase(ctx context.Context, poolConfig *pgxpool.Config, log logrus.FieldLogger) { +func TryEnsureDatabase(ctx context.Context, poolConfig *pgxpool.Config, log *slog.Logger) { pgConn, err := ConnectPostgres(ctx, poolConfig) if err != nil { - log.WithError(err).Warn("Failed to connect to the \"postgres\" database.") + log.WarnContext(ctx, "Failed to connect to the \"postgres\" database.", "error", err) return } @@ -81,13 +81,13 @@ func TryEnsureDatabase(ctx context.Context, poolConfig *pgxpool.Config, log logr // will fail immediately if we can't connect, anyway, so we can log // permission errors at debug level here. 
if IsCode(err, pgerrcode.InsufficientPrivilege) { - log.WithError(err).Debug("Error creating database due to insufficient privileges.") + log.DebugContext(ctx, "Error creating database due to insufficient privileges.", "error", err) } else { - log.WithError(err).Warn("Error creating database.") + log.WarnContext(ctx, "Error creating database.", "error", err) } } if err := pgConn.Close(ctx); err != nil { - log.WithError(err).Warn("Error closing connection to the \"postgres\" database.") + log.WarnContext(ctx, "Error closing connection to the \"postgres\" database.", "error", err) } } @@ -97,7 +97,7 @@ func TryEnsureDatabase(ctx context.Context, poolConfig *pgxpool.Config, log logr // any data has been sent. It will retry unique constraint violation and // exclusion constraint violations, so the closure should not rely on those for // normal behavior. -func Retry[T any](ctx context.Context, log logrus.FieldLogger, f func() (T, error)) (T, error) { +func Retry[T any](ctx context.Context, log *slog.Logger, f func() (T, error)) (T, error) { const idempotent = false v, err := retry(ctx, log, idempotent, f) return v, trace.Wrap(err) @@ -108,13 +108,13 @@ func Retry[T any](ctx context.Context, log logrus.FieldLogger, f func() (T, erro // assumes that f is idempotent, so it will retry even in ambiguous situations. // It will retry unique constraint violation and exclusion constraint // violations, so the closure should not rely on those for normal behavior. 
-func RetryIdempotent[T any](ctx context.Context, log logrus.FieldLogger, f func() (T, error)) (T, error) { +func RetryIdempotent[T any](ctx context.Context, log *slog.Logger, f func() (T, error)) (T, error) { const idempotent = true v, err := retry(ctx, log, idempotent, f) return v, trace.Wrap(err) } -func retry[T any](ctx context.Context, log logrus.FieldLogger, isIdempotent bool, f func() (T, error)) (T, error) { +func retry[T any](ctx context.Context, log *slog.Logger, isIdempotent bool, f func() (T, error)) (T, error) { var v T var err error v, err = f() @@ -143,18 +143,22 @@ func retry[T any](ctx context.Context, log logrus.FieldLogger, isIdempotent bool _ = errors.As(err, &pgErr) if pgErr != nil && isSerializationErrorCode(pgErr.Code) { - log.WithError(err). - WithField("attempt", i). - Debug("Operation failed due to conflicts, retrying quickly.") + log.LogAttrs(ctx, slog.LevelDebug, + "Operation failed due to conflicts, retrying quickly.", + slog.Int("attempt", i), + slog.Any("error", err), + ) retry.Reset() // the very first attempt gets instant retry on serialization failure if i > 1 { retry.Inc() } } else if (isIdempotent && pgErr == nil) || pgconn.SafeToRetry(err) { - log.WithError(err). - WithField("attempt", i). - Debug("Operation failed, retrying.") + log.LogAttrs(ctx, slog.LevelDebug, + "Operation failed, retrying.", + slog.Int("attempt", i), + slog.Any("error", err), + ) retry.Inc() } else { // we either know we shouldn't retry (on a database error), or we @@ -207,7 +211,7 @@ func isSerializationErrorCode(code string) bool { // [pgx.BeginTxFunc]. func RetryTx( ctx context.Context, - log logrus.FieldLogger, + log *slog.Logger, db interface { BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) }, @@ -233,7 +237,7 @@ func IsCode(err error, code string) bool { // the name of a table used to hold schema version numbers. 
func SetupAndMigrate( ctx context.Context, - log logrus.FieldLogger, + log *slog.Logger, db interface { BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) Exec(context.Context, string, ...any) (pgconn.CommandTag, error) @@ -259,7 +263,10 @@ func SetupAndMigrate( }); err != nil { // the very first SELECT in the next transaction will fail, we don't // need anything higher than debug here - log.WithError(err).Debugf("Failed to confirm the existence of the %v table.", tableName) + log.DebugContext(ctx, "Failed to confirm the existence of the configured table.", + "table", tableName, + "error", err, + ) } const idempotent = true @@ -307,10 +314,10 @@ func SetupAndMigrate( } if int(version) != len(schemas) { - log.WithFields(logrus.Fields{ - "previous_version": version, - "current_version": len(schemas), - }).Info("Migrated database schema.") + log.InfoContext(ctx, "Migrated database schema.", + "previous_version", version, + "current_version", len(schemas), + ) } return nil diff --git a/lib/backend/pgbk/pgbk.go b/lib/backend/pgbk/pgbk.go index 098a1f0a6a1ce..e8cb20444e167 100644 --- a/lib/backend/pgbk/pgbk.go +++ b/lib/backend/pgbk/pgbk.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "errors" + "log/slog" "sync" "time" @@ -30,7 +31,6 @@ import ( "github.com/jackc/pgx/v5/pgtype/zeronull" "github.com/jackc/pgx/v5/pgxpool" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -64,36 +64,13 @@ const ( defaultExpiryInterval = 30 * time.Second ) -// AuthMode determines if we should use some environment-specific authentication -// mechanism or credentials. -type AuthMode string - -const ( - // StaticAuth uses the static credentials as defined in the connection - // string. - StaticAuth AuthMode = "" - // AzureADAuth gets a connection token from Azure and uses it as the - // password when connecting. 
- AzureADAuth AuthMode = "azure" -) - -// Check returns an error if the AuthMode is invalid. -func (a AuthMode) Check() error { - switch a { - case StaticAuth, AzureADAuth: - return nil - default: - return trace.BadParameter("invalid authentication mode %q, should be %q or %q", a, StaticAuth, AzureADAuth) - } -} - // Config is the configuration struct for [Backend]; outside of tests or custom // code, it's usually generated by converting the [backend.Params] from the // Teleport configuration file. type Config struct { ConnString string `json:"conn_string"` - AuthMode AuthMode `json:"auth_mode"` + AuthMode pgcommon.AuthMode `json:"auth_mode"` ChangeFeedConnString string `json:"change_feed_conn_string"` ChangeFeedPollInterval types.Duration `json:"change_feed_poll_interval"` @@ -172,26 +149,22 @@ func NewWithConfig(ctx context.Context, cfg Config) (*Backend, error) { return nil, trace.Wrap(err) } - log := logrus.WithField(teleport.ComponentKey, componentName) + log := slog.With(teleport.ComponentKey, componentName) - if cfg.AuthMode == AzureADAuth { - bc, err := pgcommon.AzureBeforeConnect(log) - if err != nil { - return nil, trace.Wrap(err) - } - poolConfig.BeforeConnect = bc - feedConfig.BeforeConnect = bc + if err := cfg.AuthMode.ConfigurePoolConfigs(ctx, log, poolConfig, feedConfig); err != nil { + return nil, trace.Wrap(err) } const defaultTxIsoParamName = "default_transaction_isolation" if defaultTxIso := poolConfig.ConnConfig.RuntimeParams[defaultTxIsoParamName]; defaultTxIso != "" { - log.WithField(defaultTxIsoParamName, defaultTxIso). - Error("The " + defaultTxIsoParamName + " parameter was overridden in the connection string; proceeding with an unsupported configuration.") + const message = "The " + defaultTxIsoParamName + " parameter was overridden in the connection string; proceeding with an unsupported configuration." 
+ log.ErrorContext(ctx, message, + defaultTxIsoParamName, defaultTxIso) } else { poolConfig.ConnConfig.RuntimeParams[defaultTxIsoParamName] = "serializable" } - log.Info("Setting up backend.") + log.InfoContext(ctx, "Setting up backend.") pgcommon.TryEnsureDatabase(ctx, poolConfig, log) @@ -238,7 +211,7 @@ type Backend struct { cfg Config feedConfig *pgxpool.Config - log logrus.FieldLogger + log *slog.Logger pool *pgxpool.Pool buf *backend.CircularBuffer diff --git a/lib/events/pgevents/pgevents.go b/lib/events/pgevents/pgevents.go index 841dbbba2a342..7adcce24e1e25 100644 --- a/lib/events/pgevents/pgevents.go +++ b/lib/events/pgevents/pgevents.go @@ -21,6 +21,7 @@ package pgevents import ( "context" "fmt" + "log/slog" "net/url" "strconv" "strings" @@ -31,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -64,35 +64,12 @@ const ( retentionPeriodParam = "retention_period" ) -// AuthMode determines if we should use some environment-specific authentication -// mechanism or credentials. -type AuthMode string - -const ( - // FixedAuth uses the static credentials as defined in the connection - // string. - FixedAuth AuthMode = "" - // AzureADAuth gets a connection token from Azure and uses it as the - // password when connecting. - AzureADAuth AuthMode = "azure" -) - -// Check returns an error if the AuthMode is invalid. -func (a AuthMode) Check() error { - switch a { - case FixedAuth, AzureADAuth: - return nil - default: - return trace.BadParameter("invalid authentication mode %q", a) - } -} - // Config is the configuration struct to pass to New. 
type Config struct { - Log logrus.FieldLogger + Log *slog.Logger PoolConfig *pgxpool.Config - AuthMode AuthMode + AuthMode pgcommon.AuthMode DisableCleanup bool RetentionPeriod time.Duration @@ -121,7 +98,7 @@ func (c *Config) SetFromURL(u *url.URL) error { } c.PoolConfig = poolConfig - c.AuthMode = AuthMode(params.Get(authModeParam)) + c.AuthMode = pgcommon.AuthMode(params.Get(authModeParam)) if s := params.Get(disableCleanupParam); s != "" { b, err := strconv.ParseBool(s) @@ -176,7 +153,7 @@ func (c *Config) CheckAndSetDefaults() error { } if c.Log == nil { - c.Log = logrus.WithField(teleport.ComponentKey, componentName) + c.Log = slog.With(teleport.ComponentKey, componentName) } return nil @@ -189,15 +166,11 @@ func New(ctx context.Context, cfg Config) (*Log, error) { return nil, trace.Wrap(err) } - if cfg.AuthMode == AzureADAuth { - bc, err := pgcommon.AzureBeforeConnect(cfg.Log) - if err != nil { - return nil, trace.Wrap(err) - } - cfg.PoolConfig.BeforeConnect = bc + if err := cfg.AuthMode.ConfigurePoolConfigs(ctx, cfg.Log, cfg.PoolConfig); err != nil { + return nil, trace.Wrap(err) } - cfg.Log.Info("Setting up events backend.") + cfg.Log.InfoContext(ctx, "Setting up events backend.") pgcommon.TryEnsureDatabase(ctx, cfg.PoolConfig, cfg.Log) @@ -223,14 +196,14 @@ func New(ctx context.Context, cfg Config) (*Log, error) { go l.periodicCleanup(periodicCtx, cfg.CleanupInterval, cfg.RetentionPeriod) } - l.log.Info("Started events backend.") + l.log.InfoContext(ctx, "Started events backend.") return l, nil } // Log is an external [events.AuditLogger] backed by a PostgreSQL database. 
type Log struct { - log logrus.FieldLogger + log *slog.Logger pool *pgxpool.Pool cancel context.CancelFunc @@ -275,7 +248,7 @@ func (l *Log) periodicCleanup(ctx context.Context, cleanupInterval, retentionPer case <-tk.C: } - l.log.Debug("Executing periodic cleanup.") + l.log.DebugContext(ctx, "Executing periodic cleanup.") deleted, err := pgcommon.RetryIdempotent(ctx, l.log, func() (int64, error) { tag, err := l.pool.Exec(ctx, "DELETE FROM events WHERE creation_time < (now() - $1::interval)", @@ -288,9 +261,9 @@ func (l *Log) periodicCleanup(ctx context.Context, cleanupInterval, retentionPer return tag.RowsAffected(), nil }) if err != nil { - l.log.WithError(err).Error("Failed to execute periodic cleanup.") + l.log.ErrorContext(ctx, "Failed to execute periodic cleanup.", "error", err) } else { - l.log.WithField("deleted_rows", deleted).Debug("Executed periodic cleanup.") + l.log.DebugContext(ctx, "Executed periodic cleanup.", "deleted", deleted) } } } diff --git a/lib/events/pgevents/pgevents_test.go b/lib/events/pgevents/pgevents_test.go index 07e193f9ed62f..126b3f8b549c0 100644 --- a/lib/events/pgevents/pgevents_test.go +++ b/lib/events/pgevents/pgevents_test.go @@ -28,6 +28,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + pgcommon "github.com/gravitational/teleport/lib/backend/pgbk/common" "github.com/gravitational/teleport/lib/events/test" "github.com/gravitational/teleport/lib/utils" ) @@ -86,7 +87,12 @@ func TestPostgresEvents(t *testing.T) { func TestConfig(t *testing.T) { configs := map[string]*Config{ "postgres://foo#auth_mode=azure": { - AuthMode: AzureADAuth, + AuthMode: pgcommon.AzureADAuth, + RetentionPeriod: defaultRetentionPeriod, + CleanupInterval: defaultCleanupInterval, + }, + "postgres://foo?sslmode=require#auth_mode=azure": { + AuthMode: pgcommon.AzureADAuth, RetentionPeriod: defaultRetentionPeriod, CleanupInterval: defaultCleanupInterval, }, diff --git a/lib/utils/gcp/gcp.go b/lib/utils/gcp/gcp.go index 
00d71db3e89b5..141a74edbd680 100644 --- a/lib/utils/gcp/gcp.go +++ b/lib/utils/gcp/gcp.go @@ -19,9 +19,12 @@ package gcp import ( + "encoding/json" "strings" + "cloud.google.com/go/compute/metadata" "github.com/gravitational/trace" + "golang.org/x/oauth2/google" ) // SortedGCPServiceAccounts sorts service accounts by project and service account name. @@ -56,7 +59,7 @@ func (s SortedGCPServiceAccounts) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -const expectedParentDomain = "iam.gserviceaccount.com" +const serviceAccountParentDomain = "iam.gserviceaccount.com" func ProjectIDFromServiceAccountName(serviceAccount string) (string, error) { if serviceAccount == "" { @@ -80,8 +83,8 @@ func ProjectIDFromServiceAccountName(serviceAccount string) (string, error) { return "", trace.BadParameter("invalid service account format: missing project ID") } - if iamDomain != expectedParentDomain { - return "", trace.BadParameter("invalid service account format: expected suffix %q, got %q", expectedParentDomain, iamDomain) + if iamDomain != serviceAccountParentDomain { + return "", trace.BadParameter("invalid service account format: expected suffix %q, got %q", serviceAccountParentDomain, iamDomain) } return projectID, nil @@ -91,3 +94,62 @@ func ValidateGCPServiceAccountName(serviceAccount string) error { _, err := ProjectIDFromServiceAccountName(serviceAccount) return err } + +// GetServiceAccountFromCredentials attempts to retrieve service account email +// from provided credentials. +func GetServiceAccountFromCredentials(credentials *google.Credentials) (string, error) { + // When credentials JSON file is provided through either + // GOOGLE_APPLICATION_CREDENTIALS env var or a well known file. + if len(credentials.JSON) > 0 { + sa, err := GetServiceAccountFromCredentialsJSON(credentials.JSON) + return sa, trace.Wrap(err) + } + + // No credentials from JSON files but using metadata endpoints when on + // Google Compute Engine. 
+ if metadata.OnGCE() { + email, err := metadata.Email("") + return email, trace.Wrap(err) + } + + return "", trace.NotImplemented("unknown environment for getting service account") +} + +// GetServiceAccountFromCredentialsJSON attempts to retrieve service account +// email from provided credentials JSON. +func GetServiceAccountFromCredentialsJSON(credentialsJSON []byte) (string, error) { + content := struct { + // ClientEmail defines the service account email for service_account + // credentials. + // + // Reference: https://google.aip.dev/auth/4112 + ClientEmail string `json:"client_email"` + + // ServiceAccountImpersonationURL is used for external + // account_credentials (e.g. Workload Identity Federation) when using + // service account impersonation. + // + // Reference: https://google.aip.dev/auth/4117 + ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` + }{} + + if err := json.Unmarshal(credentialsJSON, &content); err != nil { + return "", trace.Wrap(err) + } + + if content.ClientEmail != "" { + return content.ClientEmail, nil + } + + // Format: + // https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/$EMAIL:generateAccessToken + if _, after, ok := strings.Cut(content.ServiceAccountImpersonationURL, "/serviceAccounts/"); ok { + index := strings.LastIndex(after, serviceAccountParentDomain) + if index < 0 { + return "", trace.BadParameter("invalid service_account_impersonation_url %q", content.ServiceAccountImpersonationURL) + } + return after[:index+len(serviceAccountParentDomain)], nil + } + + return "", trace.NotImplemented("unknown environment for getting service account") +} diff --git a/lib/utils/gcp/gcp_test.go b/lib/utils/gcp/gcp_test.go index 23918f3f7ddda..bcf2b71e2002b 100644 --- a/lib/utils/gcp/gcp_test.go +++ b/lib/utils/gcp/gcp_test.go @@ -192,3 +192,73 @@ func TestProjectIDFromServiceAccountName(t *testing.T) { }) } } + +func TestGetServiceAccountFromCredentialsJSON(t *testing.T) { + tests := 
[]struct { + name string + credentialsJSON []byte + checkError require.ErrorAssertionFunc + wantServiceAccount string + }{ + { + name: "service_account credentials", + credentialsJSON: []byte(fakeServiceAccountCredentialsJSON), + checkError: require.NoError, + wantServiceAccount: "my-service-account@teleport-example-123456.iam.gserviceaccount.com", + }, + { + name: "external_account credentials with sa impersonation", + credentialsJSON: []byte(fakeExternalAccountCredentialsJSON), + checkError: require.NoError, + wantServiceAccount: "my-service-account@teleport-example-987654.iam.gserviceaccount.com", + }, + { + name: "unknown credentials", + credentialsJSON: []byte(`{}`), + checkError: require.Error, + }, + { + name: "bad json", + credentialsJSON: []byte(`{}`), + checkError: require.Error, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sa, err := GetServiceAccountFromCredentialsJSON(tc.credentialsJSON) + tc.checkError(t, err) + require.Equal(t, tc.wantServiceAccount, sa) + }) + } +} + +const ( + fakeServiceAccountCredentialsJSON = `{ + "type": "service_account", + "project_id": "teleport-example-123456", + "private_key_id": "1234569890abcdef1234567890abcdef12345678", + "private_key": "fake-private-key", + "client_email": "my-service-account@teleport-example-123456.iam.gserviceaccount.com", + "client_id": "111111111111111111111", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/my-service-account%40teleport-example-123456.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +}` + fakeExternalAccountCredentialsJSON = `{ + "type": "external_account", + "audience": "//iam.googleapis.com/projects/111111111111/locations/global/workloadIdentityPools/my-identity-pool/providers/my-provider", + "subject_token_type": 
"urn:ietf:params:aws:token-type:aws4_request", + "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-service-account@teleport-example-987654.iam.gserviceaccount.com:generateAccessToken", + "token_url": "https://sts.googleapis.com/v1/token", + "credential_source": { + "environment_id": "aws1", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "imdsv2_session_token_url": "http://169.254.169.254/latest/api/token" + } +}` +)