diff --git a/internal/pgproxy/pgproxy.go b/internal/pgproxy/pgproxy.go index e5b5a24b5e..37ce587cc2 100644 --- a/internal/pgproxy/pgproxy.go +++ b/internal/pgproxy/pgproxy.go @@ -93,16 +93,16 @@ func HandleConnection(ctx context.Context, conn net.Conn, connectionFn DSNConstr logger.Debugf("startup message: %+v", startup) logger.Debugf("backend connected: %s", conn.RemoteAddr()) - dsn, err := connectionFn(ctx, startup.Parameters) + frontend, err := connectFrontend(ctx, connectionFn, startup) if err != nil { - handleBackendError(ctx, backend, err) - return - } + // try again, in case there was a credential rotation + logger.Warnf("failed to connect frontend: %s, trying again", err) - frontend, err := connectFrontend(ctx, dsn) - if err != nil { - handleBackendError(ctx, backend, err) - return + frontend, err = connectFrontend(ctx, connectionFn, startup) + if err != nil { + handleBackendError(ctx, backend, err) + return + } } logger.Debugf("frontend connected") @@ -171,7 +171,12 @@ func connectBackend(ctx context.Context, conn net.Conn) (*pgproto3.Backend, *pgp } } -func connectFrontend(ctx context.Context, dsn string) (*pgproto3.Frontend, error) { +func connectFrontend(ctx context.Context, connectionFn DSNConstructor, startup *pgproto3.StartupMessage) (*pgproto3.Frontend, error) { + dsn, err := connectionFn(ctx, startup.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to construct dsn: %w", err) + } + conn, err := pgconn.Connect(ctx, dsn) if err != nil { return nil, fmt.Errorf("failed to connect to backend: %w", err)