diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 43065203844e..9d47664cf786 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -1902,6 +1902,7 @@ GO_TARGETS = [ "//pkg/sql/parser:parser", "//pkg/sql/parser:parser_test", "//pkg/sql/pgrepl/lsn:lsn", + "//pkg/sql/pgrepl/lsnutil:lsnutil", "//pkg/sql/pgrepl/pgreplparser:pgreplparser", "//pkg/sql/pgrepl/pgreplparser:pgreplparser_test", "//pkg/sql/pgrepl/pgrepltree:pgrepltree", diff --git a/pkg/acceptance/generated_cli_test.go b/pkg/acceptance/generated_cli_test.go index 08412e06b8b5..efe4ba3e3008 100644 --- a/pkg/acceptance/generated_cli_test.go +++ b/pkg/acceptance/generated_cli_test.go @@ -347,6 +347,13 @@ func TestDockerCLI_test_reconnect(t *testing.T) { runTestDockerCLI(t, "test_reconnect", "../cli/interactive_tests/test_reconnect.tcl") } +func TestDockerCLI_test_replication_protocol(t *testing.T) { + s := log.Scope(t) + defer s.Close(t) + + runTestDockerCLI(t, "test_replication_protocol", "../cli/interactive_tests/test_replication_protocol.tcl") +} + func TestDockerCLI_test_secure(t *testing.T) { s := log.Scope(t) defer s.Close(t) diff --git a/pkg/cli/interactive_tests/test_replication_protocol.tcl b/pkg/cli/interactive_tests/test_replication_protocol.tcl new file mode 100644 index 000000000000..33d1f648f1b9 --- /dev/null +++ b/pkg/cli/interactive_tests/test_replication_protocol.tcl @@ -0,0 +1,20 @@ +#! /usr/bin/env expect -f + +source [file join [file dirname $argv0] common.tcl] + +start_server $argv + +start_test "Ensure that replication mode works as expected in the sql shell" + +# Spawn a sql shell. +spawn /bin/bash + +send "$argv sql --url `cat server_url`'\&replication=database' -e 'IDENTIFY_SYSTEM'\r" +eexpect "(1 row)" + +send_eof +eexpect eof + +end_test + +stop_server $argv diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 588a83494a33..d0eff0933ad3 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -127,6 +127,7 @@ go_library( "grant_revoke_system.go", "grant_role.go", "group.go", + "identify_system.go", "index_backfiller.go", "index_join.go", "index_split_scatter.go", @@ -442,6 +443,9 @@ go_library( "//pkg/sql/paramparse", "//pkg/sql/parser", "//pkg/sql/parser/statements", + "//pkg/sql/pgrepl/lsn", + "//pkg/sql/pgrepl/lsnutil", + "//pkg/sql/pgrepl/pgrepltree", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/pgwire/pgnotice", diff --git a/pkg/sql/catalog/colinfo/result_columns.go b/pkg/sql/catalog/colinfo/result_columns.go index 4a44fa8617e1..9eb8252379d3 100644 --- a/pkg/sql/catalog/colinfo/result_columns.go +++ b/pkg/sql/catalog/colinfo/result_columns.go @@ -344,3 +344,11 @@ const RangesExtraRenders = ` coalesce((crdb_internal.range_stats(start_key)->>'range_key_bytes')::INT, 0) + coalesce((crdb_internal.range_stats(start_key)->>'range_val_bytes')::INT, 0) AS range_size ` + +// IdentifySystemColumns is the schema for IDENTIFY_SYSTEM. +var IdentifySystemColumns = ResultColumns{ + {Name: "systemid", Typ: types.String}, + {Name: "timeline", Typ: types.Int4}, + {Name: "xlogpos", Typ: types.String}, + {Name: "dbname", Typ: types.String}, +} diff --git a/pkg/sql/delegate/BUILD.bazel b/pkg/sql/delegate/BUILD.bazel index 8a6244c37fd4..d354fd28b5ab 100644 --- a/pkg/sql/delegate/BUILD.bazel +++ b/pkg/sql/delegate/BUILD.bazel @@ -53,6 +53,7 @@ go_library( "//pkg/sql/oidext", "//pkg/sql/opt/cat", "//pkg/sql/parser", + "//pkg/sql/pgrepl/pgreplparser", "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/sql/privilege", @@ -60,6 +61,7 @@ go_library( "//pkg/sql/sem/catconstants", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", + "//pkg/sql/sessiondatapb", "//pkg/sql/sqlerrors", "//pkg/sql/sqltelemetry", "//pkg/sql/syntheticprivilege", diff --git a/pkg/sql/delegate/show_syntax.go b/pkg/sql/delegate/show_syntax.go index f74e324e8165..d57d629b321d 100644 --- a/pkg/sql/delegate/show_syntax.go +++ b/pkg/sql/delegate/show_syntax.go @@ -18,7 +18,9 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" "github.com/cockroachdb/cockroach/pkg/sql/lexbase" "github.com/cockroachdb/cockroach/pkg/sql/parser" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgreplparser" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" ) // delegateShowSyntax implements SHOW SYNTAX. This statement is usually handled @@ -41,6 +43,16 @@ func (d *delegator) delegateShowSyntax(n *tree.ShowSyntax) (tree.Statement, erro colinfo.ShowSyntaxColumns[0].Name, colinfo.ShowSyntaxColumns[1].Name, ) + // For replication based statements, return nothing for now. + if d.evalCtx.SessionData().ReplicationMode != sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED && + pgreplparser.IsReplicationProtocolCommand(n.Statement) { + return d.parse(fmt.Sprintf( + `SELECT '' AS %s, '' AS %s FROM generate_series(0, -1) x`, + colinfo.ShowSyntaxColumns[0].Name, + colinfo.ShowSyntaxColumns[1].Name, + )) + } + comma := "" // TODO(knz): in the call below, reportErr is nil although we might // want to be able to capture (and report) these errors as well. diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index da6abfae3771..359776b3d6a0 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -2152,6 +2152,7 @@ type SessionArgs struct { User username.SQLUsername IsSuperuser bool IsSSL bool + ReplicationMode sessiondatapb.ReplicationMode SystemIdentity username.SQLUsername SessionDefaults SessionDefaults CustomOptionSessionDefaults SessionDefaults diff --git a/pkg/sql/identify_system.go b/pkg/sql/identify_system.go new file mode 100644 index 000000000000..fba3dbb99605 --- /dev/null +++ b/pkg/sql/identify_system.go @@ -0,0 +1,66 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package sql + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/lsn" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/lsnutil" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgrepltree" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" +) + +type identifySystemNode struct { + optColumnsSlot + clusterID string + database string + lsn lsn.LSN + shown bool +} + +func (s *identifySystemNode) startExec(params runParams) error { + return nil +} + +func (s *identifySystemNode) Next(params runParams) (bool, error) { + if s.shown { + return false, nil + } + s.shown = true + return true, nil +} + +func (s *identifySystemNode) Values() tree.Datums { + db := tree.DNull + if s.database != "" { + db = tree.NewDString(s.database) + } + return tree.Datums{ + tree.NewDString(s.clusterID), + tree.NewDInt(1), // timeline + tree.NewDString(s.lsn.String()), + db, + } +} + +func (s *identifySystemNode) Close(ctx context.Context) {} + +func (p *planner) IdentifySystem( + ctx context.Context, n *pgrepltree.IdentifySystem, +) (planNode, error) { + return &identifySystemNode{ + // TODO(#105130): correctly populate this field. + lsn: lsnutil.HLCToLSN(p.Txn().ReadTimestamp()), + clusterID: p.ExecCfg().NodeInfo.LogicalClusterID().String(), + database: p.SessionData().Database, + }, nil +} diff --git a/pkg/sql/opaque.go b/pkg/sql/opaque.go index 02d3fead6b3c..166a0a702878 100644 --- a/pkg/sql/opaque.go +++ b/pkg/sql/opaque.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" "github.com/cockroachdb/cockroach/pkg/sql/opt" "github.com/cockroachdb/cockroach/pkg/sql/opt/optbuilder" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgrepltree" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" @@ -281,6 +282,8 @@ func planOpaque(ctx context.Context, p *planner, stmt tree.Statement) (planNode, return p.Truncate(ctx, n) case *tree.Unlisten: return p.Unlisten(ctx, n) + case *pgrepltree.IdentifySystem: + return p.IdentifySystem(ctx, n) case tree.CCLOnlyStatement: plan, err := p.maybePlanHook(ctx, stmt) if plan == nil && err == nil { @@ -397,6 +400,8 @@ func init() { &tree.Truncate{}, &tree.Unlisten{}, + &pgrepltree.IdentifySystem{}, + // CCL statements (without Export which has an optimizer operator). &tree.AlterBackup{}, &tree.AlterBackupSchedule{}, diff --git a/pkg/sql/pgrepl/BUILD.bazel b/pkg/sql/pgrepl/BUILD.bazel index d7d7a2f6f017..35eb6750bf42 100644 --- a/pkg/sql/pgrepl/BUILD.bazel +++ b/pkg/sql/pgrepl/BUILD.bazel @@ -4,9 +4,11 @@ go_test( name = "pgrepl_test", srcs = [ "connect_test.go", + "extended_protocol_test.go", "main_test.go", ], args = ["-test.timeout=295s"], + data = glob(["testdata/**"]), deps = [ "//pkg/base", "//pkg/ccl", @@ -15,13 +17,17 @@ go_test( "//pkg/security/username", "//pkg/server", "//pkg/sql/pgwire/pgcode", + "//pkg/testutils/datapathutils", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/util/leaktest", "//pkg/util/log", + "@com_github_cockroachdb_datadriven//:datadriven", "@com_github_cockroachdb_errors//:errors", + "@com_github_jackc_pgx_v4//:pgx", "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgconn", + "@com_github_jackc_pgx_v5//pgproto3", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/pgrepl/connect_test.go b/pkg/sql/pgrepl/connect_test.go index a2c93222a3ac..19b18ba6ac3a 100644 --- a/pkg/sql/pgrepl/connect_test.go +++ b/pkg/sql/pgrepl/connect_test.go @@ -115,7 +115,7 @@ func TestReplicationConnect(t *testing.T) { } require.NoError(t, err) var val string - require.NoError(t, conn.QueryRow(ctx, "SELECT current_setting('replication')").Scan(&val)) + require.NoError(t, conn.QueryRow(ctx, "SELECT current_setting('replication')", pgx.QueryExecModeSimpleProtocol).Scan(&val)) require.Equal(t, tc.expectedSessionVar, val) }) } diff --git a/pkg/sql/pgrepl/extended_protocol_test.go b/pkg/sql/pgrepl/extended_protocol_test.go new file mode 100644 index 000000000000..f865e0e78a68 --- /dev/null +++ b/pkg/sql/pgrepl/extended_protocol_test.go @@ -0,0 +1,103 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package pgrepl + +import ( + "context" + "net/url" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +// TestExtendedProtocolDisabled ensures the extended protocol is disabled +// during replication mode. +func TestExtendedProtocolDisabled(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + srv, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(context.Background()) + s := srv.ApplicationLayer() + + sqlDB := sqlutils.MakeSQLRunner(db) + sqlDB.Exec(t, `CREATE USER testuser LOGIN REPLICATION`) + + pgURL, cleanup := sqlutils.PGUrl(t, s.AdvSQLAddr(), "pgrepl_extended_protocol_test", url.User(username.TestUser)) + defer cleanup() + + cfg, err := pgconn.ParseConfig(pgURL.String()) + require.NoError(t, err) + cfg.RuntimeParams["replication"] = "database" + ctx := context.Background() + + conn, err := pgconn.ConnectConfig(ctx, cfg) + require.NoError(t, err) + fe := conn.Frontend() + + for _, tc := range []struct { + desc string + msg []pgproto3.FrontendMessage + }{ + {desc: "parse", msg: []pgproto3.FrontendMessage{&pgproto3.Parse{Name: "a", Query: "SELECT 1"}}}, + {desc: "bind", msg: []pgproto3.FrontendMessage{&pgproto3.Bind{}}}, + {desc: "parse and bind", msg: []pgproto3.FrontendMessage{ + &pgproto3.Parse{Name: "a", Query: "SELECT 1"}, + &pgproto3.Bind{}, + }}, + {desc: "describe", msg: []pgproto3.FrontendMessage{&pgproto3.Describe{Name: "a"}}}, + {desc: "exec", msg: []pgproto3.FrontendMessage{&pgproto3.Execute{Portal: "a"}}}, + {desc: "close", msg: []pgproto3.FrontendMessage{&pgproto3.Close{}}}, + } { + t.Run(tc.desc, func(t *testing.T) { + for _, msg := range tc.msg { + fe.Send(msg) + } + fe.Send(&pgproto3.Sync{}) + err := fe.Flush() + require.NoError(t, err) + var pgErr *pgconn.PgError + done := false + for !done { + recv, err := fe.Receive() + require.NoError(t, err) + switch recv := recv.(type) { + case *pgproto3.ReadyForQuery: + done = true + case *pgproto3.ErrorResponse: + // Ensure we do not have multiple errors. + require.Nil(t, pgErr) + pgErr = pgconn.ErrorResponseToPgError(recv) + default: + t.Errorf("received unexpected message %#v", recv) + } + } + require.NotNil(t, pgErr) + require.Equal(t, pgcode.ProtocolViolation.String(), pgErr.Code) + require.Contains(t, pgErr.Message, "extended query protocol not supported in a replication connection") + + // Ensure we can use the connection using the simple protocol. + rows := conn.Exec(ctx, "SELECT 1") + _, err = rows.ReadAll() + require.NoError(t, err) + require.NoError(t, rows.Close()) + }) + } +} diff --git a/pkg/sql/pgrepl/lsnutil/BUILD.bazel b/pkg/sql/pgrepl/lsnutil/BUILD.bazel new file mode 100644 index 000000000000..82ae74536d50 --- /dev/null +++ b/pkg/sql/pgrepl/lsnutil/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "lsnutil", + srcs = ["lsnutil.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/lsnutil", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgrepl/lsn", + "//pkg/util/hlc", + ], +) diff --git a/pkg/sql/pgrepl/lsnutil/lsnutil.go b/pkg/sql/pgrepl/lsnutil/lsnutil.go new file mode 100644 index 000000000000..e6a7af4e3f67 --- /dev/null +++ b/pkg/sql/pgrepl/lsnutil/lsnutil.go @@ -0,0 +1,25 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package lsnutil + +import ( + "time" + + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/lsn" + "github.com/cockroachdb/cockroach/pkg/util/hlc" +) + +// HLCToLSN converts a HLC to a LSN. +// It is in a separate package to prevent the `lsn` package importing `log`. +func HLCToLSN(h hlc.Timestamp) lsn.LSN { + // TODO(#105130): correctly populate this field. + return lsn.LSN(h.WallTime/int64(time.Millisecond)) << 32 +} diff --git a/pkg/sql/pgrepl/main_test.go b/pkg/sql/pgrepl/main_test.go index 2a10a35968b4..518a08c0b217 100644 --- a/pkg/sql/pgrepl/main_test.go +++ b/pkg/sql/pgrepl/main_test.go @@ -8,17 +8,30 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -package pgrepl_test +package pgrepl import ( + "context" + "fmt" + "net/url" "os" + "strings" "testing" + "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl" "github.com/cockroachdb/cockroach/pkg/security/securityassets" "github.com/cockroachdb/cockroach/pkg/security/securitytest" + "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/server" + "github.com/cockroachdb/cockroach/pkg/testutils/datapathutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/datadriven" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/require" ) //go:generate ../../util/leaktest/add-leaktest.sh *_test.go @@ -29,3 +42,80 @@ func TestMain(m *testing.M) { defer ccl.TestingEnableEnterprise()() os.Exit(m.Run()) } + +func TestDataDriven(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + s := serverutils.StartServerOnly(t, base.TestServerArgs{}) + defer s.Stopper().Stop(context.Background()) + + pgURL, cleanup := sqlutils.PGUrl(t, s.AdvSQLAddr(), "pgrepl_datadriven_test", url.User(username.RootUser)) + defer cleanup() + + cfg, err := pgx.ParseConfig(pgURL.String()) + require.NoError(t, err) + cfg.RuntimeParams["replication"] = "database" + ctx := context.Background() + + conn, err := pgx.ConnectConfig(ctx, cfg) + require.NoError(t, err) + + datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) { + datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string { + var expectError bool + args := d.CmdArgs[:0] + for _, arg := range d.CmdArgs { + switch arg.Key { + case "error": + expectError = true + default: + args = append(args, arg) + } + } + d.CmdArgs = args + + switch d.Cmd { + case "simple_query": + rows, err := conn.Query(ctx, d.Input, pgx.QuerySimpleProtocol(true)) + if expectError { + require.Error(t, err) + return err.Error() + } + out, err := sqlutils.PGXRowsToDataDrivenOutput(rows) + rows.Close() + require.NoError(t, err) + return out + case "identify_system": + // IDENTIFY_SYSTEM needs some redaction to be deterministic. + rows, err := conn.Query(ctx, "IDENTIFY_SYSTEM", pgx.QuerySimpleProtocol(true)) + require.NoError(t, err) + var sb strings.Builder + for rows.Next() { + vals, err := rows.Values() + require.NoError(t, err) + for i, val := range vals { + if i > 0 { + sb.WriteRune('\n') + } + switch string(rows.FieldDescriptions()[i].Name) { + case "systemid": + val = "some_cluster_id" + case "xlogpos": + val = "some_lsn" + } + sb.Write(rows.FieldDescriptions()[i].Name) + sb.WriteString(": ") + sb.WriteString(fmt.Sprintf("%v", val)) + } + } + require.NoError(t, rows.Err()) + rows.Close() + return sb.String() + default: + t.Errorf("unhandled command %s", d.Cmd) + } + return "" + }) + }) +} diff --git a/pkg/sql/pgrepl/pgreplparser/BUILD.bazel b/pkg/sql/pgrepl/pgreplparser/BUILD.bazel index a8e62011a661..6ad287ed0225 100644 --- a/pkg/sql/pgrepl/pgreplparser/BUILD.bazel +++ b/pkg/sql/pgrepl/pgreplparser/BUILD.bazel @@ -38,6 +38,7 @@ go_library( deps = [ "//pkg/sql/lexbase", "//pkg/sql/parser", + "//pkg/sql/parser/statements", "//pkg/sql/pgrepl/lsn", "//pkg/sql/pgrepl/pgrepltree", "//pkg/sql/pgwire/pgcode", diff --git a/pkg/sql/pgrepl/pgreplparser/lexer.go b/pkg/sql/pgrepl/pgreplparser/lexer.go index 63494e1d3086..79fbb97a412b 100644 --- a/pkg/sql/pgrepl/pgreplparser/lexer.go +++ b/pkg/sql/pgrepl/pgreplparser/lexer.go @@ -33,6 +33,18 @@ const ( doubleQuote = '"' ) +func IsReplicationProtocolCommand(q string) bool { + l := newLexer(q) + var lval pgreplSymType + switch l.Lex(&lval) { + case K_CREATE_REPLICATION_SLOT, K_DROP_REPLICATION_SLOT, K_READ_REPLICATION_SLOT, K_START_REPLICATION, + K_IDENTIFY_SYSTEM, + K_TIMELINE_HISTORY, K_BASE_BACKUP: + return true + } + return false +} + // lexer implements the Lexer goyacc interface. // It differs from the sql/scanner package as replication has unique, case // sensitive behavior with unique keywords. diff --git a/pkg/sql/pgrepl/pgreplparser/parser.go b/pkg/sql/pgrepl/pgreplparser/parser.go index 0de2f0d35cab..e2213d68977b 100644 --- a/pkg/sql/pgrepl/pgreplparser/parser.go +++ b/pkg/sql/pgrepl/pgreplparser/parser.go @@ -11,21 +11,25 @@ package pgreplparser import ( - "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgrepltree" + "github.com/cockroachdb/cockroach/pkg/sql/parser/statements" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/errors" ) -func Parse(sql string) (pgrepltree.ReplicationStatement, error) { +func Parse(sql string) (statements.Statement[tree.Statement], error) { lexer := newLexer(sql) p := pgreplNewParser() if p.Parse(lexer) != 0 { if lexer.lastError == nil { - return nil, errors.AssertionFailedf("expected lexer error but got none") + return statements.Statement[tree.Statement]{}, errors.AssertionFailedf("expected lexer error but got none") } - return nil, lexer.lastError + return statements.Statement[tree.Statement]{}, lexer.lastError } if lexer.stmt == nil { - return nil, errors.AssertionFailedf("expected statement but got none") + return statements.Statement[tree.Statement]{}, errors.AssertionFailedf("expected statement but got none") } - return lexer.stmt, nil + return statements.Statement[tree.Statement]{ + AST: lexer.stmt, + SQL: sql, + }, nil } diff --git a/pkg/sql/pgrepl/pgreplparser/parser_test.go b/pkg/sql/pgrepl/pgreplparser/parser_test.go index 154fda1f9525..246c5ef1ae0a 100644 --- a/pkg/sql/pgrepl/pgreplparser/parser_test.go +++ b/pkg/sql/pgrepl/pgreplparser/parser_test.go @@ -52,20 +52,20 @@ func TestParser(t *testing.T) { return msg } require.NoError(t, err) - ref := p.String() + ref := tree.AsString(p.AST) note := "" if ref != td.Input { note = " -- normalized!" } var buf bytes.Buffer fmt.Fprintf(&buf, "%s%s\n", ref, note) - constantsHidden := tree.AsStringWithFlags(p, tree.FmtHideConstants) + constantsHidden := tree.AsStringWithFlags(p.AST, tree.FmtHideConstants) fmt.Fprintln(&buf, constantsHidden, "-- literals removed") // Test roundtrip. - reparsed, err := Parse(p.String()) + reparsed, err := Parse(ref) require.NoError(t, err) - assert.Equal(t, p.String(), reparsed.String()) + assert.Equal(t, ref, tree.AsString(reparsed.AST)) return buf.String() default: diff --git a/pkg/sql/pgrepl/pgrepltree/identify_system.go b/pkg/sql/pgrepl/pgrepltree/identify_system.go index e0573e5dff8a..7a72707aaeae 100644 --- a/pkg/sql/pgrepl/pgrepltree/identify_system.go +++ b/pkg/sql/pgrepl/pgrepltree/identify_system.go @@ -26,11 +26,11 @@ func (i *IdentifySystem) Format(ctx *tree.FmtCtx) { } func (i *IdentifySystem) StatementReturnType() tree.StatementReturnType { - return tree.Replication + return tree.Rows } func (i *IdentifySystem) StatementType() tree.StatementType { - return tree.TypeDDL + return tree.TypeDML } func (i *IdentifySystem) StatementTag() string { diff --git a/pkg/sql/pgrepl/testdata/identify_system.ddt b/pkg/sql/pgrepl/testdata/identify_system.ddt new file mode 100644 index 000000000000..6f89fc0ac049 --- /dev/null +++ b/pkg/sql/pgrepl/testdata/identify_system.ddt @@ -0,0 +1,19 @@ +# valid identify_system usages +identify_system +---- +systemid: some_cluster_id +timeline: 1 +xlogpos: some_lsn +dbname: defaultdb + + +# invalid identify_system usages +simple_query error +IDENTIFY_SYSTEM; IDENTIFY_SYSTEM; +---- +ERROR: at or near "IDENTIFY_SYSTEM": syntax error (SQLSTATE 42601) + +simple_query error +identify_system; +---- +ERROR: at or near "identify_system": syntax error (SQLSTATE 42601) diff --git a/pkg/sql/pgwire/BUILD.bazel b/pkg/sql/pgwire/BUILD.bazel index b939d9bb9653..d0765f5ab814 100644 --- a/pkg/sql/pgwire/BUILD.bazel +++ b/pkg/sql/pgwire/BUILD.bazel @@ -43,6 +43,8 @@ go_library( "//pkg/sql/lex", "//pkg/sql/parser", "//pkg/sql/parser/statements", + "//pkg/sql/pgrepl/pgreplparser", + "//pkg/sql/pgrepl/pgrepltree", "//pkg/sql/pgwire/hba", "//pkg/sql/pgwire/identmap", "//pkg/sql/pgwire/pgcode", diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 552e63b38184..f0f588482100 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -30,6 +30,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/clusterunique" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/parser/statements" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgreplparser" + "github.com/cockroachdb/cockroach/pkg/sql/pgrepl/pgrepltree" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" @@ -39,6 +41,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/ring" @@ -344,6 +347,34 @@ func (c *conn) handleSimpleQuery( } startParse := timeutil.Now() + if c.sessionArgs.ReplicationMode != sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED && + pgreplparser.IsReplicationProtocolCommand(query) { + stmt, err := pgreplparser.Parse(query) + if err != nil { + log.SqlExec.Infof(ctx, "could not parse simple query in replication protocol: %s", query) + return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) + } + switch stmt.AST.(type) { + case *pgrepltree.IdentifySystem: + default: + log.SqlExec.Infof(ctx, "unhandled replication protocol query: %s", query) + return c.stmtBuf.Push(ctx, sql.SendError{ + Err: unimplemented.NewWithIssueDetail(0, fmt.Sprintf("%T", stmt.AST), "replication protocol command not implemented"), + }) + } + endParse := timeutil.Now() + return c.stmtBuf.Push( + ctx, + sql.ExecStmt{ + Statement: stmt, + TimeReceived: timeReceived, + ParseStart: startParse, + ParseEnd: endParse, + LastInBatch: true, + LastInBatchBeforeShowCommitTimestamp: false, + }, + ) + } stmts, err := c.parser.ParseWithInt(query, unqualifiedIntSize) if err != nil { log.SqlExec.Infof(ctx, "could not parse simple query: %s", query) diff --git a/pkg/sql/pgwire/pgerror/pgcode.go b/pkg/sql/pgwire/pgerror/pgcode.go index d3721827d93e..7f6fa3ae47a6 100644 --- a/pkg/sql/pgwire/pgerror/pgcode.go +++ b/pkg/sql/pgwire/pgerror/pgcode.go @@ -25,7 +25,6 @@ func WithCandidateCode(err error, code pgcode.Code) error { if err == nil { return nil } - return &withCandidateCode{cause: err, code: code.String()} } diff --git a/pkg/sql/pgwire/pre_serve_options.go b/pkg/sql/pgwire/pre_serve_options.go index 36d28d3b1382..f43b1c0dec8d 100644 --- a/pkg/sql/pgwire/pre_serve_options.go +++ b/pkg/sql/pgwire/pre_serve_options.go @@ -99,6 +99,11 @@ func parseClientProvidedSessionParameters( if err := loadParameter(ctx, key, value, &args.SessionArgs); err != nil { return args, pgerror.Wrapf(err, pgerror.GetPGCode(err), "replication parameter") } + // Cache the value into session args. + args.ReplicationMode, err = sql.ReplicationModeFromString(args.SessionArgs.SessionDefaults["replication"]) + if err != nil { + return args, err + } case "crdb:session_revival_token_base64": token, err := base64.StdEncoding.DecodeString(value) diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 33e97e333526..391d2c08a313 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -33,6 +33,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirecancel" + "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" "github.com/cockroachdb/cockroach/pkg/util/ctxlog" "github.com/cockroachdb/cockroach/pkg/util/envutil" @@ -1133,6 +1134,9 @@ func (s *Server) serveImpl( }) case pgwirebase.ClientMsgExecute: + if err := c.prohibitUnderReplicationMode(ctx); err != nil { + return false, isSimpleQuery, err + } // To support the 1PC txn fast path, we peek at the next command to // see if it is a Sync. This is because in the extended protocol, an // implicit transaction cannot commit until the Sync is seen. If there's @@ -1147,15 +1151,27 @@ func (s *Server) serveImpl( return false, isSimpleQuery, c.handleExecute(ctx, timeReceived, followedBySync) case pgwirebase.ClientMsgParse: + if err := c.prohibitUnderReplicationMode(ctx); err != nil { + return false, isSimpleQuery, err + } return false, isSimpleQuery, c.handleParse(ctx, parser.NakedIntTypeFromDefaultIntSize(atomic.LoadInt32(atomicUnqualifiedIntSize))) case pgwirebase.ClientMsgDescribe: + if err := c.prohibitUnderReplicationMode(ctx); err != nil { + return false, isSimpleQuery, err + } return false, isSimpleQuery, c.handleDescribe(ctx) case pgwirebase.ClientMsgBind: + if err := c.prohibitUnderReplicationMode(ctx); err != nil { + return false, isSimpleQuery, err + } return false, isSimpleQuery, c.handleBind(ctx) case pgwirebase.ClientMsgClose: + if err := c.prohibitUnderReplicationMode(ctx); err != nil { + return false, isSimpleQuery, err + } return false, isSimpleQuery, c.handleClose(ctx) case pgwirebase.ClientMsgTerminate: @@ -1252,6 +1268,24 @@ func (s *Server) serveImpl( } } +// From https://github.com/postgres/postgres/blob/28b5726561841556dc3e00ffe26b01a8107ee654/src/backend/tcop/postgres.c#L4891-L4891 +func (c *conn) prohibitUnderReplicationMode(ctx context.Context) error { + if c.sessionArgs.ReplicationMode == sessiondatapb.ReplicationMode_REPLICATION_MODE_DISABLED { + return nil + } + pgErr := pgerror.New( + pgcode.ProtocolViolation, + "extended query protocol not supported in a replication connection", + ) + if err := c.stmtBuf.Push(ctx, sql.SendError{ + Err: pgErr, + }); err != nil { + return err + } + // return the same error so that ignoreUntilSync is hit. + return pgErr +} + // readCancelKey retrieves the "backend data" key that identifies // a cancellable query, then closes the connection. func readCancelKey( diff --git a/pkg/sql/plan_columns.go b/pkg/sql/plan_columns.go index f5bdb75dbcb7..4b5fd3658a13 100644 --- a/pkg/sql/plan_columns.go +++ b/pkg/sql/plan_columns.go @@ -173,6 +173,9 @@ func getPlanColumns(plan planNode, mut bool) colinfo.ResultColumns { return n.planCols case *cdcValuesNode: return n.resultColumns + + case *identifySystemNode: + return n.getColumns(mut, colinfo.IdentifySystemColumns) } // Every other node has no columns in their results. diff --git a/pkg/sql/walk.go b/pkg/sql/walk.go index e6a3fc8cdcec..a82145875243 100644 --- a/pkg/sql/walk.go +++ b/pkg/sql/walk.go @@ -484,4 +484,5 @@ var planNodeNames = map[reflect.Type]string{ reflect.TypeOf(&zeroNode{}): "norows", reflect.TypeOf(&zigzagJoinNode{}): "zigzag join", reflect.TypeOf(&schemaChangePlanNode{}): "schema change", + reflect.TypeOf(&identifySystemNode{}): "identify system", }