diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index bdf9e782ad45..7bfb8cb726a3 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -484,6 +484,7 @@ ALL_TESTS = [ "//pkg/sql/parser:parser_disallowed_imports_test", "//pkg/sql/parser:parser_test", "//pkg/sql/pgrepl/pgreplparser:pgreplparser_test", + "//pkg/sql/pgrepl:pgrepl_test", "//pkg/sql/pgwire/hba:hba_test", "//pkg/sql/pgwire/identmap:identmap_test", "//pkg/sql/pgwire/pgerror:pgerror_test", @@ -1864,6 +1865,7 @@ GO_TARGETS = [ "//pkg/sql/pgrepl/pgreplparser:pgreplparser", "//pkg/sql/pgrepl/pgreplparser:pgreplparser_test", "//pkg/sql/pgrepl/pgrepltree:pgrepltree", + "//pkg/sql/pgrepl:pgrepl_test", "//pkg/sql/pgwire/hba:hba", "//pkg/sql/pgwire/hba:hba_test", "//pkg/sql/pgwire/identmap:identmap", diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index b826db84329e..c66342ce1569 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -3598,6 +3598,10 @@ func (m *sessionDataMutator) SetUnboundedParallelScans(val bool) { m.data.UnboundedParallelScans = val } +func (m *sessionDataMutator) SetReplicationMode(val sessiondatapb.ReplicationMode) { + m.data.ReplicationMode = val +} + // Utility functions related to scrubbing sensitive information on SQL Stats. // quantizeCounts ensures that the Count field in the diff --git a/pkg/sql/pgrepl/BUILD.bazel b/pkg/sql/pgrepl/BUILD.bazel new file mode 100644 index 000000000000..702b02cc348f --- /dev/null +++ b/pkg/sql/pgrepl/BUILD.bazel @@ -0,0 +1,21 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +go_test( + name = "pgrepl_test", + srcs = ["pgrepl_test.go"], + args = ["-test.timeout=295s"], + deps = [ + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/security/username", + "//pkg/server", + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/tests", + "//pkg/testutils/serverutils", + "//pkg/testutils/sqlutils", + "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgx_v5//:pgx", + "@com_github_jackc_pgx_v5//pgconn", + "@com_github_stretchr_testify//require", + ], +) diff --git a/pkg/sql/pgrepl/pgrepl_test.go b/pkg/sql/pgrepl/pgrepl_test.go new file mode 100644 index 000000000000..75dc4bbb7b23 --- /dev/null +++ b/pkg/sql/pgrepl/pgrepl_test.go @@ -0,0 +1,86 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package pgrepl_test + +import ( + "context" + "net/url" + "os" + "testing" + + "github.com/cockroachdb/cockroach/pkg/security/securityassets" + "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/tests" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/errors" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + securityassets.SetLoader(securitytest.EmbeddedAssets) + serverutils.InitTestServerFactory(server.TestServerFactory) + os.Exit(m.Run()) +} + +func TestConnect(t *testing.T) { + params, _ := tests.CreateTestServerParams() + + s, db, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(context.Background()) + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER testuser`) + + for _, tc := range []struct { + replicationMode string + expectedSessionVar string + expectError bool + }{ + {replicationMode: "", expectedSessionVar: "off"}, + {replicationMode: "0", expectedSessionVar: "off"}, + {replicationMode: "1", expectedSessionVar: "on"}, + {replicationMode: "false", expectedSessionVar: "off"}, + {replicationMode: "true", expectedSessionVar: "on"}, + {replicationMode: "database", expectedSessionVar: "database"}, + {replicationMode: "asdf", expectError: true}, + } { + t.Run(tc.replicationMode, func(t *testing.T) { + pgURL, cleanup := sqlutils.PGUrl(t, s.ServingSQLAddr(), "pgrepl_conn_test", url.User(username.TestUser)) + defer cleanup() + + cfg, err := pgx.ParseConfig(pgURL.String()) + require.NoError(t, err) + if tc.replicationMode != "" { + cfg.RuntimeParams["replication"] = tc.replicationMode + } + + ctx := context.Background() + conn, err := pgx.ConnectConfig(ctx, cfg) + if tc.expectError { + require.Error(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, pgcode.InvalidParameterValue.String(), pgErr.Code) + return + } + require.NoError(t, err) + var val string + require.NoError(t, conn.QueryRow(ctx, "SELECT current_setting('replication')").Scan(&val)) + require.Equal(t, tc.expectedSessionVar, val) + }) + } +} diff --git a/pkg/sql/pgwire/pre_serve_options.go b/pkg/sql/pgwire/pre_serve_options.go index 5a5b242b604a..12d8c11bf441 100644 --- a/pkg/sql/pgwire/pre_serve_options.go +++ b/pkg/sql/pgwire/pre_serve_options.go @@ -93,6 +93,12 @@ func parseClientProvidedSessionParameters( // initialization information. args.IsSuperuser = args.User.IsRootUser() + case "replication": + // See session variable comment for the reason behind the remapping. + if err := loadParameter(ctx, key, value, &args.SessionArgs); err != nil { + return args, pgerror.Wrapf(err, pgerror.GetPGCode(err), "replication parameter") + } + case "crdb:session_revival_token_base64": token, err := base64.StdEncoding.DecodeString(value) if err != nil { diff --git a/pkg/sql/sessiondatapb/local_only_session_data.proto b/pkg/sql/sessiondatapb/local_only_session_data.proto index 65559bacc439..6299b9adba06 100644 --- a/pkg/sql/sessiondatapb/local_only_session_data.proto +++ b/pkg/sql/sessiondatapb/local_only_session_data.proto @@ -402,6 +402,9 @@ message LocalOnlySessionData { // NOTE: we'd prefer to use tree.IsolationLevel here, but doing so would // introduce a package dependency cycle. int64 default_txn_isolation_level = 105; + // ReplicationMode represents the replication parameter passed in during + // connection time. + ReplicationMode replication_mode = 106; /////////////////////////////////////////////////////////////////////////// // WARNING: consider whether a session parameter you're adding needs to // @@ -410,6 +413,14 @@ message LocalOnlySessionData { /////////////////////////////////////////////////////////////////////////// } +// ReplicationMode represents the replication={0,1,on,off,database} connection +// parameter in PostgreSQL. +enum ReplicationMode { + REPLICATION_MODE_DISABLED = 0; + REPLICATION_MODE_ENABLED = 1; + REPLICATION_MODE_DATABASE = 2; +} + // SequenceCacheEntry is an entry in a SequenceCache. message SequenceCacheEntry { // CachedVersion stores the descpb.DescriptorVersion that cached values are associated with. diff --git a/pkg/sql/vars.go b/pkg/sql/vars.go index c06d600c2d7e..31fb47525613 100644 --- a/pkg/sql/vars.go +++ b/pkg/sql/vars.go @@ -2762,6 +2762,47 @@ var varGen = map[string]sessionVar{ }, GlobalDefault: globalFalse, }, + + // CockroachDB extension. + // PostgreSQL does not use the "replication" session variable (it is only a + // parameter on the connection string). Instead, it is represented by a + // `am_walsender` / `am_db_walsender` bool on the connection. + `replication`: { + // We are hiding this for now as it is only meant for internal observability. + // It should only be set at connection time. + Hidden: true, + Set: func(_ context.Context, m sessionDataMutator, s string) error { + if strings.ToLower(s) == "database" { + m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_DATABASE) + return nil + } + b, err := paramparse.ParseBoolVar("replication", s) + if err != nil { + return pgerror.Newf( + pgcode.InvalidParameterValue, + `parameter "replication" requires a boolean value or "database"`, + ) + } + if b { + m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_ENABLED) + } else { + m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED) + } + return nil + }, + Get: func(evalCtx *extendedEvalContext, _ *kv.Txn) (string, error) { + switch evalCtx.SessionData().ReplicationMode { + case sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED: + return formatBoolAsPostgresSetting(false), nil + case sessiondatapb.ReplicationMode_REPLICATION_MODE_ENABLED: + return formatBoolAsPostgresSetting(true), nil + case sessiondatapb.ReplicationMode_REPLICATION_MODE_DATABASE: + return "database", nil + } + return "", errors.AssertionFailedf("unknown replication mode: %v", evalCtx.SessionData().ReplicationMode) + }, + GlobalDefault: globalFalse, + }, } // We want test coverage for this on and off so make it metamorphic.