Skip to content

Commit

Permalink
pgrepl: allow logging in with replication parameter
Browse files Browse the repository at this point in the history
This commit adds the `replication=database` connection parameter, which
enables certain replication commands to be run.
See: https://www.postgresql.org/docs/current/protocol-replication.html

We do this by making it a session variable instead of a parameter on the
connection like PG does. This adds a hidden `replication` connection
parameter to accomplish this.

Release note: None
  • Loading branch information
otan committed Jun 23, 2023
1 parent 320e861 commit c6728de
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ ALL_TESTS = [
"//pkg/sql/parser:parser_disallowed_imports_test",
"//pkg/sql/parser:parser_test",
"//pkg/sql/pgrepl/pgreplparser:pgreplparser_test",
"//pkg/sql/pgrepl:pgrepl_test",
"//pkg/sql/pgwire/hba:hba_test",
"//pkg/sql/pgwire/identmap:identmap_test",
"//pkg/sql/pgwire/pgerror:pgerror_test",
Expand Down Expand Up @@ -1864,6 +1865,7 @@ GO_TARGETS = [
"//pkg/sql/pgrepl/pgreplparser:pgreplparser",
"//pkg/sql/pgrepl/pgreplparser:pgreplparser_test",
"//pkg/sql/pgrepl/pgrepltree:pgrepltree",
"//pkg/sql/pgrepl:pgrepl_test",
"//pkg/sql/pgwire/hba:hba",
"//pkg/sql/pgwire/hba:hba_test",
"//pkg/sql/pgwire/identmap:identmap",
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/exec_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3598,6 +3598,10 @@ func (m *sessionDataMutator) SetUnboundedParallelScans(val bool) {
m.data.UnboundedParallelScans = val
}

func (m *sessionDataMutator) SetReplicationMode(val sessiondatapb.ReplicationMode) {
m.data.ReplicationMode = val
}

// Utility functions related to scrubbing sensitive information on SQL Stats.

// quantizeCounts ensures that the Count field in the
Expand Down
21 changes: 21 additions & 0 deletions pkg/sql/pgrepl/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")

go_test(
name = "pgrepl_test",
srcs = ["pgrepl_test.go"],
args = ["-test.timeout=295s"],
deps = [
"//pkg/security/securityassets",
"//pkg/security/securitytest",
"//pkg/security/username",
"//pkg/server",
"//pkg/sql/pgwire/pgcode",
"//pkg/sql/tests",
"//pkg/testutils/serverutils",
"//pkg/testutils/sqlutils",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgconn",
"@com_github_stretchr_testify//require",
],
)
86 changes: 86 additions & 0 deletions pkg/sql/pgrepl/pgrepl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package pgrepl_test

import (
"context"
"net/url"
"os"
"testing"

"github.com/cockroachdb/cockroach/pkg/security/securityassets"
"github.com/cockroachdb/cockroach/pkg/security/securitytest"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/server"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/tests"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)

func TestMain(m *testing.M) {
securityassets.SetLoader(securitytest.EmbeddedAssets)
serverutils.InitTestServerFactory(server.TestServerFactory)
os.Exit(m.Run())
}

func TestConnect(t *testing.T) {
params, _ := tests.CreateTestServerParams()

s, db, _ := serverutils.StartServer(t, params)
defer s.Stopper().Stop(context.Background())

sqlDB := sqlutils.MakeSQLRunner(db)
sqlDB.Exec(t, `CREATE USER testuser`)

for _, tc := range []struct {
replicationMode string
expectedSessionVar string
expectError bool
}{
{replicationMode: "", expectedSessionVar: "off"},
{replicationMode: "0", expectedSessionVar: "off"},
{replicationMode: "1", expectedSessionVar: "on"},
{replicationMode: "false", expectedSessionVar: "off"},
{replicationMode: "true", expectedSessionVar: "on"},
{replicationMode: "database", expectedSessionVar: "database"},
{replicationMode: "asdf", expectError: true},
} {
t.Run(tc.replicationMode, func(t *testing.T) {
pgURL, cleanup := sqlutils.PGUrl(t, s.ServingSQLAddr(), "pgrepl_conn_test", url.User(username.TestUser))
defer cleanup()

cfg, err := pgx.ParseConfig(pgURL.String())
require.NoError(t, err)
if tc.replicationMode != "" {
cfg.RuntimeParams["replication"] = tc.replicationMode
}

ctx := context.Background()
conn, err := pgx.ConnectConfig(ctx, cfg)
if tc.expectError {
require.Error(t, err)
var pgErr *pgconn.PgError
require.True(t, errors.As(err, &pgErr))
require.Equal(t, pgcode.InvalidParameterValue.String(), pgErr.Code)
return
}
require.NoError(t, err)
var val string
require.NoError(t, conn.QueryRow(ctx, "SELECT current_setting('replication')").Scan(&val))
require.Equal(t, tc.expectedSessionVar, val)
})
}
}
6 changes: 6 additions & 0 deletions pkg/sql/pgwire/pre_serve_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ func parseClientProvidedSessionParameters(
// initialization information.
args.IsSuperuser = args.User.IsRootUser()

case "replication":
// See session variable comment for the reason behind the remapping.
if err := loadParameter(ctx, key, value, &args.SessionArgs); err != nil {
return args, pgerror.Wrapf(err, pgerror.GetPGCode(err), "replication parameter")
}

case "crdb:session_revival_token_base64":
token, err := base64.StdEncoding.DecodeString(value)
if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions pkg/sql/sessiondatapb/local_only_session_data.proto
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ message LocalOnlySessionData {
// NOTE: we'd prefer to use tree.IsolationLevel here, but doing so would
// introduce a package dependency cycle.
int64 default_txn_isolation_level = 105;
// ReplicationMode represents the replication parameter passed in during
// connection time.
ReplicationMode replication_mode = 106;

///////////////////////////////////////////////////////////////////////////
// WARNING: consider whether a session parameter you're adding needs to //
Expand All @@ -410,6 +413,14 @@ message LocalOnlySessionData {
///////////////////////////////////////////////////////////////////////////
}

// ReplicationMode represents the replication={0,1,on,off,database} connection
// parameter in PostgreSQL.
enum ReplicationMode {
REPLICATION_MODE_DISABLED = 0;
REPLICATION_MODE_ENABLED = 1;
REPLICATION_MODE_DATABASE = 2;
}

// SequenceCacheEntry is an entry in a SequenceCache.
message SequenceCacheEntry {
// CachedVersion stores the descpb.DescriptorVersion that cached values are associated with.
Expand Down
41 changes: 41 additions & 0 deletions pkg/sql/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -2762,6 +2762,47 @@ var varGen = map[string]sessionVar{
},
GlobalDefault: globalFalse,
},

// CockroachDB extension.
// PostgreSQL does not use the "replication" session variable (it is only a
// parameter on the connection string). Instead, it is represented by a
// `am_walsender` / `am_db_walsender` bool on the connection.
`replication`: {
// We are hiding this for now as it is only meant for internal observability.
// It should only be set at connection time.
Hidden: true,
Set: func(_ context.Context, m sessionDataMutator, s string) error {
if strings.ToLower(s) == "database" {
m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_DATABASE)
return nil
}
b, err := paramparse.ParseBoolVar("replication", s)
if err != nil {
return pgerror.Newf(
pgcode.InvalidParameterValue,
`parameter "replication" requires a boolean value or "database"`,
)
}
if b {
m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_ENABLED)
} else {
m.SetReplicationMode(sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED)
}
return nil
},
Get: func(evalCtx *extendedEvalContext, _ *kv.Txn) (string, error) {
switch evalCtx.SessionData().ReplicationMode {
case sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED:
return formatBoolAsPostgresSetting(false), nil
case sessiondatapb.ReplicationMode_REPLICATION_MODE_ENABLED:
return formatBoolAsPostgresSetting(true), nil
case sessiondatapb.ReplicationMode_REPLICATION_MODE_DATABASE:
return "database", nil
}
return "", errors.AssertionFailedf("unknown replication mode: %v", evalCtx.SessionData().ReplicationMode)
},
GlobalDefault: globalFalse,
},
}

// We want test coverage for this on and off so make it metamorphic.
Expand Down

0 comments on commit c6728de

Please sign in to comment.