Skip to content

Commit

Permalink
sqlccl: test prepared statement in session migration test
Browse files Browse the repository at this point in the history
Release note: None
  • Loading branch information
rafiss committed Aug 8, 2023
1 parent cdf6d15 commit c0867be
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
1 change: 1 addition & 0 deletions pkg/ccl/testccl/sqlccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ go_test(
"@com_github_cockroachdb_redact//:redact",
"@com_github_gogo_protobuf//types",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_stretchr_testify//require",
],
)
51 changes: 37 additions & 14 deletions pkg/ccl/testccl/sqlccl/show_transfer_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"
)

Expand All @@ -40,6 +41,8 @@ func TestShowTransferState(t *testing.T) {
require.NoError(t, err)
_, err = mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true")
require.NoError(t, err)
_, err = tenantDB.Exec("CREATE TYPE typ AS ENUM ('foo', 'bar')")
require.NoError(t, err)

testUserConn := tenant.SQLConnForUser(t, username.TestUser, "")

Expand Down Expand Up @@ -82,24 +85,30 @@ func TestShowTransferState(t *testing.T) {
q := pgURL.Query()
q.Add("application_name", "carl")
pgURL.RawQuery = q.Encode()
conn, err := gosql.Open("postgres", pgURL.String())
conn, err := pgx.Connect(ctx, pgURL.String())
require.NoError(t, err)
defer conn.Close()
defer func() { _ = conn.Close(ctx) }()

// Add a prepared statement to make sure SHOW TRANSFER STATE handles it.
// Since lib/pq doesn't tell us the name of the prepared statement, we won't
// be able to test that we can use it after deserializing the session, but
// there are other tests for that.
stmt, err := conn.Prepare("SELECT 1 WHERE 1 = 1")
_, err = conn.Prepare(ctx, "prepared_stmt", "SELECT $1::INT4, 'foo'::typ WHERE 1 = 1")
require.NoError(t, err)

var intResult int
var enumResult string
err = conn.QueryRow(ctx, "prepared_stmt", 1).Scan(&intResult, &enumResult)
require.NoError(t, err)
defer stmt.Close()
require.Equal(t, 1, intResult)
require.Equal(t, "foo", enumResult)

rows, err := conn.Query(`SHOW TRANSFER STATE WITH 'foobar'`)
rows, err := conn.Query(ctx, `SHOW TRANSFER STATE WITH 'foobar'`, pgx.QueryExecModeSimpleProtocol)
require.NoError(t, err, "show transfer state failed")
defer rows.Close()

resultColumns, err := rows.Columns()
require.NoError(t, err)
fieldDescriptions := rows.FieldDescriptions()
var resultColumns []string
for _, f := range fieldDescriptions {
resultColumns = append(resultColumns, string(f.Name))
}

require.Equal(t, []string{
"error",
Expand Down Expand Up @@ -136,26 +145,40 @@ func TestShowTransferState(t *testing.T) {
q.Add("application_name", "someotherapp")
q.Add("crdb:session_revival_token_base64", token)
pgURL.RawQuery = q.Encode()
conn, err := gosql.Open("postgres", pgURL.String())
conn, err := pgx.Connect(ctx, pgURL.String())
require.NoError(t, err)
defer conn.Close()
defer func() { _ = conn.Close(ctx) }()

var appName string
err = conn.QueryRow("SHOW application_name").Scan(&appName)
err = conn.QueryRow(ctx, "SHOW application_name").Scan(&appName)
require.NoError(t, err)
require.Equal(t, "someotherapp", appName)

var b bool
err = conn.QueryRow(
ctx,
"SELECT crdb_internal.deserialize_session(decode($1, 'base64'))",
state,
).Scan(&b)
require.NoError(t, err)
require.True(t, b)

err = conn.QueryRow("SHOW application_name").Scan(&appName)
err = conn.QueryRow(ctx, "SHOW application_name").Scan(&appName)
require.NoError(t, err)
require.Equal(t, "carl", appName)

// Confirm that the prepared statement can be used after deserializing the
// session.
result := conn.PgConn().ExecPrepared(
ctx,
"prepared_stmt",
[][]byte{{0, 0, 0, 2}}, // binary representation of 2
[]int16{1}, // paramFormats - 1 means binary
[]int16{1}, // resultFormats - 1 means binary
).Read()
require.Equal(t, [][][]byte{{
{0, 0, 0, 2}, {0x66, 0x6f, 0x6f}, // binary representation of 2, 'foo'
}}, result.Rows)
})

// Errors should be displayed as a SQL value.
Expand Down

0 comments on commit c0867be

Please sign in to comment.