From 7ebca5e753e5563b05868e95cf791b1c5898854f Mon Sep 17 00:00:00 2001 From: Rafi Shamim Date: Mon, 7 Aug 2023 13:28:42 -0400 Subject: [PATCH] sqlccl: test prepared statement in session migration test Release note: None --- .../sqlccl/show_transfer_state_test.go | 51 ++++++++++++++----- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go index db5eaa04ae8a..8dcc144f4ffd 100644 --- a/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go +++ b/pkg/ccl/testccl/sqlccl/show_transfer_state_test.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) @@ -40,6 +41,8 @@ func TestShowTransferState(t *testing.T) { require.NoError(t, err) _, err = mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") require.NoError(t, err) + _, err = tenantDB.Exec("CREATE TYPE typ AS ENUM ('foo', 'bar')") + require.NoError(t, err) testUserConn := tenant.SQLConnForUser(t, username.TestUser, "") @@ -82,24 +85,30 @@ func TestShowTransferState(t *testing.T) { q := pgURL.Query() q.Add("application_name", "carl") pgURL.RawQuery = q.Encode() - conn, err := gosql.Open("postgres", pgURL.String()) + conn, err := pgx.Connect(ctx, pgURL.String()) require.NoError(t, err) - defer conn.Close() + defer func() { _ = conn.Close(ctx) }() // Add a prepared statement to make sure SHOW TRANSFER STATE handles it. - // Since lib/pq doesn't tell us the name of the prepared statement, we won't - // be able to test that we can use it after deserializing the session, but - // there are other tests for that. - stmt, err := conn.Prepare("SELECT 1 WHERE 1 = 1") + _, err = conn.Prepare(ctx, "prepared_stmt", "SELECT $1::INT4, 'foo'::typ WHERE 1 = 1") + require.NoError(t, err) + + var intResult int + var enumResult string + err = conn.QueryRow(ctx, "prepared_stmt", 1).Scan(&intResult, &enumResult) require.NoError(t, err) - defer stmt.Close() + require.Equal(t, 1, intResult) + require.Equal(t, "foo", enumResult) - rows, err := conn.Query(`SHOW TRANSFER STATE WITH 'foobar'`) + rows, err := conn.Query(ctx, `SHOW TRANSFER STATE WITH 'foobar'`, pgx.QuerySimpleProtocol(true)) require.NoError(t, err, "show transfer state failed") defer rows.Close() - resultColumns, err := rows.Columns() - require.NoError(t, err) + fieldDescriptions := rows.FieldDescriptions() + var resultColumns []string + for _, f := range fieldDescriptions { + resultColumns = append(resultColumns, string(f.Name)) + } require.Equal(t, []string{ "error", @@ -136,26 +145,40 @@ func TestShowTransferState(t *testing.T) { q.Add("application_name", "someotherapp") q.Add("crdb:session_revival_token_base64", token) pgURL.RawQuery = q.Encode() - conn, err := gosql.Open("postgres", pgURL.String()) + conn, err := pgx.Connect(ctx, pgURL.String()) require.NoError(t, err) - defer conn.Close() + defer func() { _ = conn.Close(ctx) }() var appName string - err = conn.QueryRow("SHOW application_name").Scan(&appName) + err = conn.QueryRow(ctx, "SHOW application_name").Scan(&appName) require.NoError(t, err) require.Equal(t, "someotherapp", appName) var b bool err = conn.QueryRow( + ctx, "SELECT crdb_internal.deserialize_session(decode($1, 'base64'))", state, ).Scan(&b) require.NoError(t, err) require.True(t, b) - err = conn.QueryRow("SHOW application_name").Scan(&appName) + err = conn.QueryRow(ctx, "SHOW application_name").Scan(&appName) require.NoError(t, err) require.Equal(t, "carl", appName) + + // Confirm that the prepared statement can be used after deserializing the + // session. + result := conn.PgConn().ExecPrepared( + ctx, + "prepared_stmt", + [][]byte{{0, 0, 0, 2}}, // binary representation of 2 + []int16{1}, // paramFormats - 1 means binary + []int16{1}, // resultFormats - 1 means binary + ).Read() + require.Equal(t, [][][]byte{{ + {0, 0, 0, 2}, {0x66, 0x6f, 0x6f}, // binary representation of 2, 'foo' + }}, result.Rows) }) // Errors should be displayed as a SQL value.