diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index 7f7ecb6720cc..acb14f50ea9c 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -520,32 +520,41 @@ func client(ctx context.Context, serverAddr net.Addr, wg *sync.WaitGroup) error // waitForClientConn blocks until a client connects and performs the pgwire // handshake. This emulates what pgwire.Server does. func waitForClientConn(ln net.Listener) (*conn, error) { - conn, err := ln.Accept() + conn, _, err := getSessionArgs(ln, false /* trustRemoteAddr */) if err != nil { return nil, err } + metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval) + pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil) + return pgwireConn, nil +} + +// getSessionArgs blocks until a client connects and returns the connection +// together with session arguments or an error. +func getSessionArgs(ln net.Listener, trustRemoteAddr bool) (net.Conn, sql.SessionArgs, error) { + conn, err := ln.Accept() + if err != nil { + return nil, sql.SessionArgs{}, err + } + buf := pgwirebase.MakeReadBuffer() _, err = buf.ReadUntypedMsg(conn) if err != nil { - return nil, err + return nil, sql.SessionArgs{}, err } version, err := buf.GetUint32() if err != nil { - return nil, err + return nil, sql.SessionArgs{}, err } if version != version30 { - return nil, errors.Errorf("unexpected protocol version: %d", version) - } - - // Consume the connection options. - if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf, conn.RemoteAddr(), false /* trustRemoteAddr */); err != nil { - return nil, err + return nil, sql.SessionArgs{}, errors.Errorf("unexpected protocol version: %d", version) } - metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval) - pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil) - return pgwireConn, nil + args, err := parseClientProvidedSessionParameters( + context.Background(), nil, &buf, conn.RemoteAddr(), trustRemoteAddr, + ) + return conn, args, err } func makeTestingConvCfg() (sessiondatapb.DataConversionConfig, *time.Location) { @@ -1252,3 +1261,263 @@ func TestConnCloseCancelsAuth(t *testing.T) { // Check that the auth process indeed noticed the cancelation. <-authBlocked } + +func TestParseClientProvidedSessionParameters(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // The test server is used only incidentally by this test: this is not the + // server that the client will connect to; we just use it on the side to + // execute some metadata queries that pgx sends whenever it opens a + // connection. + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true, UseDatabase: "system"}) + defer s.Stopper().Stop(context.Background()) + + // Start a pgwire "server". + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + serverAddr := ln.Addr() + log.Infof(context.Background(), "started listener on %s", serverAddr) + testCases := []struct { + desc string + query string + assert func(t *testing.T, args sql.SessionArgs, err error) + }{ + { + desc: "user is set from query", + query: "user=root", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "root", args.User.Normalized()) + }, + }, + { + desc: "user is ignored in options", + query: "user=root&options=-c%20user=test_user_from_options", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "root", args.User.Normalized()) + _, ok := args.SessionDefaults["user"] + require.False(t, ok) + }, + }, + { + desc: "results_buffer_size is not configurable from options", + query: "user=root&options=-c%20results_buffer_size=42", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "options: parameter \"results_buffer_size\" cannot be changed", err) + }, + }, + { + desc: "crdb:remote_addr is ignored in options", + query: "user=root&options=-c%20crdb%3Aremote_addr=2.3.4.5%3A5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.NotEqual(t, "2.3.4.5:5432", args.RemoteAddr.String()) + }, + }, + { + desc: "more keys than values in options error", + query: "user=root&options=-c%20search_path==public,test,default", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path==public,test,default\" is invalid, check '='", err) + }, + }, + { + desc: "more values than keys in options error", + query: "user=root&options=-c%20search_path", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path\" is invalid, check '='", err) + }, + }, + { + desc: "success parsing encoded options", + query: "user=root&options=-c%20search_path%3ddefault%2Ctest", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + }, + }, + { + desc: "success parsing options with no space after '-c'", + query: "user=root&options=-csearch_path=default,test -coptimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"]) + }, + }, + { + desc: "error when no leading '-c'", + query: "user=root&options=search_path=default", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"search_path=default\" is invalid, must have prefix '-c' or '--'", err) + }, + }, + { + desc: "'-c' with no leading space belongs to prev value", + query: "user=root&options=-c search_path=default-c", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default-c", args.SessionDefaults["search_path"]) + }, + }, + { + desc: "fail to parse '-c' with no leading space", + query: "user=root&options=-c search_path=default-c optimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "option \"optimizer_use_multicol_stats=true\" is invalid, must have prefix '-c' or '--'", err) + }, + }, + { + desc: "parse multiple options successfully", + query: "user=root&options=-c%20search_path=default,test%20-c%20optimizer_use_multicol_stats=true", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "default,test", args.SessionDefaults["search_path"]) + require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"]) + }, + }, + { + desc: "success parsing option with space in value", + query: "user=root&options=-c default_transaction_isolation=READ\\ UNCOMMITTED", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "READ UNCOMMITTED", args.SessionDefaults["default_transaction_isolation"]) + }, + }, + { + desc: "remote_addr missing port", + query: "user=root&crdb:remote_addr=5.4.3.2", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "invalid address format: address 5.4.3.2: missing port in address", err) + }, + }, + { + desc: "remote_addr port must be numeric", + query: "user=root&crdb:remote_addr=5.4.3.2:port", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "remote port is not numeric", err) + }, + }, + { + desc: "remote_addr host must be numeric", + query: "user=root&crdb:remote_addr=ip:5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.Error(t, err) + require.Regexp(t, "remote address is not numeric", err) + }, + }, + { + desc: "success setting remote address from query", + query: "user=root&crdb:remote_addr=2.3.4.5:5432", + assert: func(t *testing.T, args sql.SessionArgs, err error) { + require.NoError(t, err) + require.Equal(t, "2.3.4.5:5432", args.RemoteAddr.String()) + }, + }, + } + + baseURL := fmt.Sprintf("postgres://%s/system?sslmode=disable", serverAddr) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + go func() { + url := fmt.Sprintf("%s&%s", baseURL, tc.query) + c, err := gosql.Open("postgres", url) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + // ignore the error because there is no answer from the server, we are + // interested in parsing session arguments only + _ = c.PingContext(ctx) + // closing connection immediately, since getSessionArgs is blocking + _ = c.Close() + }() + + // Wait for the client to connect and perform the handshake. + _, args, err := getSessionArgs(ln, true /* trustRemoteAddr */) + tc.assert(t, args, err) + }) + } +} + +func TestSetSessionArguments(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + ctx := context.Background() + defer s.Stopper().Stop(ctx) + + pgURL, cleanupFunc := sqlutils.PGUrl( + t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser), + ) + defer cleanupFunc() + + q := pgURL.Query() + q.Add("options", " --user=test -c search_path=public,testsp %20 --default-transaction-isolation=read\\ uncommitted -capplication_name=test --datestyle=iso\\ ,\\ mdy\\ ") + pgURL.RawQuery = q.Encode() + noBufferDB, err := gosql.Open("postgres", pgURL.String()) + + if err != nil { + t.Fatal(err) + } + defer noBufferDB.Close() + + pgxConfig, err := pgx.ParseConnectionString(pgURL.String()) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(pgxConfig) + if err != nil { + t.Fatal(err) + } + + rows, err := conn.Query("show all") + if err != nil { + t.Fatal(err) + } + + expectedOptions := map[string]string{ + "search_path": "public,testsp", + // setting an isolation level is a noop: + // all transactions execute with serializable isolation. + "default_transaction_isolation": "serializable", + "application_name": "test", + "datestyle": "ISO, MDY", + } + expectedFoundOptions := len(expectedOptions) + + var foundOptions int + var variable, value string + for rows.Next() { + err = rows.Scan(&variable, &value) + if err != nil { + t.Fatal(err) + } + if v, ok := expectedOptions[variable]; ok { + foundOptions++ + if v != value { + t.Fatalf("option %q expected value %q, actual %q", variable, v, value) + } + } + } + require.Equal(t, expectedFoundOptions, foundOptions) + + if err := conn.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index e95d437c0aa8..8b702e50d4a0 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "io" "net" + "net/url" "strconv" "strings" "sync/atomic" @@ -755,23 +756,21 @@ func parseClientProvidedSessionParameters( } args.RemoteAddr = &net.TCPAddr{IP: ip, Port: port} - default: - exists, configurable := sql.IsSessionVariableConfigurable(key) - - switch { - case exists && configurable: - args.SessionDefaults[key] = value - - case !exists: - if _, ok := sql.UnsupportedVars[key]; ok { - counter := sqltelemetry.UnimplementedClientStatusParameterCounter(key) - telemetry.Inc(counter) + case "options": + opts, err := parseOptions(value) + if err != nil { + return sql.SessionArgs{}, err + } + for _, opt := range opts { + err = loadParameter(ctx, opt.key, opt.value, &args) + if err != nil { + return sql.SessionArgs{}, pgerror.Wrapf(err, pgerror.GetPGCode(err), "options") } - log.Warningf(ctx, "unknown configuration parameter: %q", key) - - case !configurable: - return sql.SessionArgs{}, pgerror.Newf(pgcode.CantChangeRuntimeParam, - "parameter %q cannot be changed", key) + } + default: + err = loadParameter(ctx, key, value, &args) + if err != nil { + return sql.SessionArgs{}, err } } } @@ -790,6 +789,133 @@ func parseClientProvidedSessionParameters( return args, nil } +func loadParameter(ctx context.Context, key, value string, args *sql.SessionArgs) error { + exists, configurable := sql.IsSessionVariableConfigurable(key) + + switch { + case exists && configurable: + args.SessionDefaults[key] = value + + case !exists: + if _, ok := sql.UnsupportedVars[key]; ok { + counter := sqltelemetry.UnimplementedClientStatusParameterCounter(key) + telemetry.Inc(counter) + } + log.Warningf(ctx, "unknown configuration parameter: %q", key) + + case !configurable: + return pgerror.Newf(pgcode.CantChangeRuntimeParam, + "parameter %q cannot be changed", key) + } + return nil +} + +// option represents an option argument passed in the connection URL. +type option struct { + key string + value string +} + +// parseOptions parses the given string into the options. The options must be +// separated by space and have one of the following patterns: +// '-c key=value', '-ckey=value', '--key=value' +func parseOptions(optionsString string) ([]option, error) { + var res []option + optionsRaw, err := url.QueryUnescape(optionsString) + if err != nil { + return nil, pgerror.Newf(pgcode.ProtocolViolation, "failed to unescape options %q", optionsString) + } + + lastWasDashC := false + opts := splitOptions(optionsRaw) + + for i := 0; i < len(opts); i++ { + prefix := "" + if len(opts[i]) > 1 { + prefix = opts[i][:2] + } + + switch { + case opts[i] == "-c": + lastWasDashC = true + continue + case lastWasDashC: + lastWasDashC = false + // if the last option was '-c' parse current option with no regard to + // the prefix + prefix = "" + case prefix == "--" || prefix == "-c": + lastWasDashC = false + default: + return nil, pgerror.Newf(pgcode.ProtocolViolation, + "option %q is invalid, must have prefix '-c' or '--'", opts[i]) + } + + opt, err := splitOption(opts[i], prefix) + if err != nil { + return nil, err + } + res = append(res, opt) + } + return res, nil +} + +// splitOptions slices the given string into substrings separated by space +// unless the space is escaped using backslashes '\\'. It also skips multiple +// subsequent spaces. +func splitOptions(options string) []string { + var res []string + var sb strings.Builder + i := 0 + for i < len(options) { + sb.Reset() + // skip leading space + for i < len(options) && options[i] == ' ' { + i++ + } + if i == len(options) { + break + } + + lastWasEscape := false + + for i < len(options) { + if options[i] == ' ' && !lastWasEscape { + break + } + if !lastWasEscape && options[i] == '\\' { + lastWasEscape = true + } else { + lastWasEscape = false + sb.WriteByte(options[i]) + } + i++ + } + + res = append(res, sb.String()) + } + + return res +} + +// splitOption splits the given opt argument into substrings separated by '='. +// It returns an error if the given option does not comply with the pattern +// "key=value" and the number of elements in the result is not two. +// splitOption removes the prefix from the key and replaces '-' with '_' so +// "--option-name=value" becomes [option_name, value]. +func splitOption(opt, prefix string) (option, error) { + kv := strings.Split(opt, "=") + + if len(kv) != 2 { + return option{}, pgerror.Newf(pgcode.ProtocolViolation, + "option %q is invalid, check '='", opt) + } + + kv[0] = strings.TrimPrefix(kv[0], prefix) + + return option{key: strings.ReplaceAll(kv[0], "-", "_"), value: kv[1]}, nil +} + // Note: Usage of an env var here makes it possible to unconditionally // enable this feature when cluster settings do not work reliably, // e.g. in multi-tenant setups in v20.2. This override mechanism can