Skip to content

Commit

Permalink
pgwire: set options based on "options" URL parameter
Browse files Browse the repository at this point in the history
Previously, CRDB ignored "options" URL parameter. User session parameters should
have been set via URL parameters directly:
`postgres://user@host:port/database?serial_normalization=virtual_sequence`

CRDB can now parse "options" URL parameter and set corresponding session
parameters (in compliance with Postgres jdbc connection parameters):
`postgres://user@host:port/database?options=-c%20serial_normalization=virtual_sequence`

Fixes #59404

Release note (sql change): CockroachDB now recognizes "options" URL parameter. The
"options" parameter specifies session variables to set at connection start. This
is treated the same as defined in the PostgreSQL docs: https://www.postgresql.org/docs/13/libpq-connect.html#LIBPQ-PARAMKEYWORDS
  • Loading branch information
mneverov committed Feb 9, 2021
1 parent 81a2c26 commit dd2512c
Show file tree
Hide file tree
Showing 2 changed files with 423 additions and 28 deletions.
293 changes: 281 additions & 12 deletions pkg/sql/pgwire/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,32 +520,41 @@ func client(ctx context.Context, serverAddr net.Addr, wg *sync.WaitGroup) error
// waitForClientConn blocks until a client connects and performs the pgwire
// handshake. This emulates what pgwire.Server does.
func waitForClientConn(ln net.Listener) (*conn, error) {
conn, err := ln.Accept()
conn, _, err := getSessionArgs(ln, false /* trustRemoteAddr */)
if err != nil {
return nil, err
}

metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval)
pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil)
return pgwireConn, nil
}

// getSessionArgs blocks until a client connects and returns the connection
// together with session arguments or an error.
func getSessionArgs(ln net.Listener, trustRemoteAddr bool) (net.Conn, sql.SessionArgs, error) {
conn, err := ln.Accept()
if err != nil {
return nil, sql.SessionArgs{}, err
}

buf := pgwirebase.MakeReadBuffer()
_, err = buf.ReadUntypedMsg(conn)
if err != nil {
return nil, err
return nil, sql.SessionArgs{}, err
}
version, err := buf.GetUint32()
if err != nil {
return nil, err
return nil, sql.SessionArgs{}, err
}
if version != version30 {
return nil, errors.Errorf("unexpected protocol version: %d", version)
}

// Consume the connection options.
if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf, conn.RemoteAddr(), false /* trustRemoteAddr */); err != nil {
return nil, err
return nil, sql.SessionArgs{}, errors.Errorf("unexpected protocol version: %d", version)
}

metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval)
pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil)
return pgwireConn, nil
args, err := parseClientProvidedSessionParameters(
context.Background(), nil, &buf, conn.RemoteAddr(), trustRemoteAddr,
)
return conn, args, err
}

func makeTestingConvCfg() (sessiondatapb.DataConversionConfig, *time.Location) {
Expand Down Expand Up @@ -1252,3 +1261,263 @@ func TestConnCloseCancelsAuth(t *testing.T) {
// Check that the auth process indeed noticed the cancelation.
<-authBlocked
}

func TestParseClientProvidedSessionParameters(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

// The test server is used only incidentally by this test: this is not the
// server that the client will connect to; we just use it on the side to
// execute some metadata queries that pgx sends whenever it opens a
// connection.
s, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true, UseDatabase: "system"})
defer s.Stopper().Stop(context.Background())

// Start a pgwire "server".
addr := util.TestAddr
ln, err := net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatal(err)
}
serverAddr := ln.Addr()
log.Infof(context.Background(), "started listener on %s", serverAddr)
testCases := []struct {
desc string
query string
assert func(t *testing.T, args sql.SessionArgs, err error)
}{
{
desc: "user is set from query",
query: "user=root",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "root", args.User.Normalized())
},
},
{
desc: "user is ignored in options",
query: "user=root&options=-c%20user=test_user_from_options",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "root", args.User.Normalized())
_, ok := args.SessionDefaults["user"]
require.False(t, ok)
},
},
{
desc: "results_buffer_size is not configurable from options",
query: "user=root&options=-c%20results_buffer_size=42",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "options: parameter \"results_buffer_size\" cannot be changed", err)
},
},
{
desc: "crdb:remote_addr is ignored in options",
query: "user=root&options=-c%20crdb%3Aremote_addr=2.3.4.5%3A5432",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.NotEqual(t, "2.3.4.5:5432", args.RemoteAddr.String())
},
},
{
desc: "more keys than values in options error",
query: "user=root&options=-c%20search_path==public,test,default",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "option \"search_path==public,test,default\" is invalid, check '='", err)
},
},
{
desc: "more values than keys in options error",
query: "user=root&options=-c%20search_path",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "option \"search_path\" is invalid, check '='", err)
},
},
{
desc: "success parsing encoded options",
query: "user=root&options=-c%20search_path%3ddefault%2Ctest",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "default,test", args.SessionDefaults["search_path"])
},
},
{
desc: "success parsing options with no space after '-c'",
query: "user=root&options=-csearch_path=default,test -coptimizer_use_multicol_stats=true",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "default,test", args.SessionDefaults["search_path"])
require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"])
},
},
{
desc: "error when no leading '-c'",
query: "user=root&options=search_path=default",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "option \"search_path=default\" is invalid, must have prefix '-c' or '--'", err)
},
},
{
desc: "'-c' with no leading space belongs to prev value",
query: "user=root&options=-c search_path=default-c",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "default-c", args.SessionDefaults["search_path"])
},
},
{
desc: "fail to parse '-c' with no leading space",
query: "user=root&options=-c search_path=default-c optimizer_use_multicol_stats=true",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "option \"optimizer_use_multicol_stats=true\" is invalid, must have prefix '-c' or '--'", err)
},
},
{
desc: "parse multiple options successfully",
query: "user=root&options=-c%20search_path=default,test%20-c%20optimizer_use_multicol_stats=true",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "default,test", args.SessionDefaults["search_path"])
require.Equal(t, "true", args.SessionDefaults["optimizer_use_multicol_stats"])
},
},
{
desc: "success parsing option with space in value",
query: "user=root&options=-c default_transaction_isolation=READ\\ UNCOMMITTED",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "READ UNCOMMITTED", args.SessionDefaults["default_transaction_isolation"])
},
},
{
desc: "remote_addr missing port",
query: "user=root&crdb:remote_addr=5.4.3.2",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "invalid address format: address 5.4.3.2: missing port in address", err)
},
},
{
desc: "remote_addr port must be numeric",
query: "user=root&crdb:remote_addr=5.4.3.2:port",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "remote port is not numeric", err)
},
},
{
desc: "remote_addr host must be numeric",
query: "user=root&crdb:remote_addr=ip:5432",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.Error(t, err)
require.Regexp(t, "remote address is not numeric", err)
},
},
{
desc: "success setting remote address from query",
query: "user=root&crdb:remote_addr=2.3.4.5:5432",
assert: func(t *testing.T, args sql.SessionArgs, err error) {
require.NoError(t, err)
require.Equal(t, "2.3.4.5:5432", args.RemoteAddr.String())
},
},
}

baseURL := fmt.Sprintf("postgres://%s/system?sslmode=disable", serverAddr)

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {

go func() {
url := fmt.Sprintf("%s&%s", baseURL, tc.query)
c, err := gosql.Open("postgres", url)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// ignore the error because there is no answer from the server, we are
// interested in parsing session arguments only
_ = c.PingContext(ctx)
// closing connection immediately, since getSessionArgs is blocking
_ = c.Close()
}()

// Wait for the client to connect and perform the handshake.
_, args, err := getSessionArgs(ln, true /* trustRemoteAddr */)
tc.assert(t, args, err)
})
}
}

func TestSetSessionArguments(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
ctx := context.Background()
defer s.Stopper().Stop(ctx)

pgURL, cleanupFunc := sqlutils.PGUrl(
t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser),
)
defer cleanupFunc()

q := pgURL.Query()
q.Add("options", " --user=test -c search_path=public,testsp %20 --default-transaction-isolation=read\\ uncommitted -capplication_name=test --datestyle=iso\\ ,\\ mdy\\ ")
pgURL.RawQuery = q.Encode()
noBufferDB, err := gosql.Open("postgres", pgURL.String())

if err != nil {
t.Fatal(err)
}
defer noBufferDB.Close()

pgxConfig, err := pgx.ParseConnectionString(pgURL.String())
if err != nil {
t.Fatal(err)
}

conn, err := pgx.Connect(pgxConfig)
if err != nil {
t.Fatal(err)
}

rows, err := conn.Query("show all")
if err != nil {
t.Fatal(err)
}

expectedOptions := map[string]string{
"search_path": "public,testsp",
// setting an isolation level is a noop:
// all transactions execute with serializable isolation.
"default_transaction_isolation": "serializable",
"application_name": "test",
"datestyle": "ISO, MDY",
}
expectedFoundOptions := len(expectedOptions)

var foundOptions int
var variable, value string
for rows.Next() {
err = rows.Scan(&variable, &value)
if err != nil {
t.Fatal(err)
}
if v, ok := expectedOptions[variable]; ok {
foundOptions++
if v != value {
t.Fatalf("option %q expected value %q, actual %q", variable, v, value)
}
}
}
require.Equal(t, expectedFoundOptions, foundOptions)

if err := conn.Close(); err != nil {
t.Fatal(err)
}
}
Loading

0 comments on commit dd2512c

Please sign in to comment.