Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-24.2: pgwire,authccl: use pgx for TestAuthenticationAndHBARules #135172

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
pgwire,authccl: use pgx for TestAuthenticationAndHBARules
The lib/pq driver is not maintained. Since we started to see flakes
related to how that driver does error handling for secure connections,
we switch to pgx instead.

Release note: None
rafiss committed Nov 14, 2024
commit 46f29e4e3c174aca1c3a2954bb8fb30761dceee1
3 changes: 2 additions & 1 deletion pkg/ccl/testccl/authccl/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -34,7 +34,8 @@ go_test(
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//stdstrings",
"@com_github_lib_pq//:pq",
"@com_github_jackc_pgconn//:pgconn",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_stretchr_testify//require",
],
)
78 changes: 54 additions & 24 deletions pkg/ccl/testccl/authccl/auth_test.go
Original file line number Diff line number Diff line change
@@ -44,7 +44,8 @@ import (
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)

@@ -148,6 +149,7 @@ func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func(
}

func authCCLRunTest(t *testing.T, insecure bool) {
ctx := context.Background()
datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()

@@ -181,15 +183,24 @@ func authCCLRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

srv, conn, _ := serverutils.StartServer(t,
srv := serverutils.StartServerOnly(t,
base.TestServerArgs{
DefaultTestTenant: base.TestDoesNotWorkWithSharedProcessModeButWeDontKnowWhyYet(
base.TestTenantProbabilistic, 112949,
),
Insecure: insecure,
SocketFile: maybeSocketFile,
})
defer srv.Stopper().Stop(context.Background())
defer srv.Stopper().Stop(ctx)

pgURL, cleanup := srv.PGUrl(t, serverutils.User(username.RootUser), serverutils.ClientCerts(!insecure))
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

s := srv.ApplicationLayer()
pgServer := s.PGServer().(*pgwire.Server)
pgServer.TestingEnableConnLogging()
@@ -202,7 +213,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {

httpHBAUrl := s.AdminURL().WithPath("/debug/hba_conf").String()
sv := &s.ClusterSettings().SV
if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

@@ -227,27 +238,27 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting enabled: %s", a.Vals[0])
}
jwtauthccl.JWTAuthEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWTAuthEnabled.Override(ctx, sv, v)
case "audience":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting audience: %d", len(a.Vals))
}
jwtauthccl.JWTAuthAudience.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthAudience.Override(ctx, sv, a.Vals[0])
case "issuers":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting issuers: %d", len(a.Vals))
}
jwtauthccl.JWTAuthIssuersConfig.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthIssuersConfig.Override(ctx, sv, a.Vals[0])
case "jwks":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks: %d", len(a.Vals))
}
jwtauthccl.JWTAuthJWKS.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthJWKS.Override(ctx, sv, a.Vals[0])
case "claim":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting claim: %d", len(a.Vals))
}
jwtauthccl.JWTAuthClaim.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthClaim.Override(ctx, sv, a.Vals[0])
case "ident_map":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
@@ -256,7 +267,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if len(args) != 3 {
t.Fatalf("wrong number of comma separated argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
}
pgwire.ConnIdentityMapConf.Override(context.Background(), sv, strings.Join(args, " "))
pgwire.ConnIdentityMapConf.Override(ctx, sv, strings.Join(args, " "))
case "jwks_auto_fetch.enabled":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks_auto.fetch_enabled: %d", len(a.Vals))
@@ -265,7 +276,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting jwks_auto_fetch.enabled: %s", a.Vals[0])
}
jwtauthccl.JWKSAutoFetchEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWKSAutoFetchEnabled.Override(ctx, sv, v)
default:
t.Fatalf("unknown jwt_cluster_setting: %s", a.Key)
}
@@ -287,7 +298,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
}

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "connect", "connect_unix":
@@ -370,24 +381,29 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var dbName string
if err := row.Scan(&dbName); err != nil {
return "", err
@@ -437,7 +453,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
return strconv.Itoa(resp.StatusCode), nil

case "set_hba":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
if err != nil {
return "", err
@@ -496,20 +512,34 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if strings.Contains(hint, "Supported methods:") {
// Depending on whether the test is running on linux or not
// (or, more specifically, whether gss build tag is set),
// "gss" method might not be included, so we remove it here
// and not include into the expected output.
hint = strings.Replace(hint, "gss, ", "", 1)
}
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
61 changes: 41 additions & 20 deletions pkg/sql/pgwire/auth_test.go
Original file line number Diff line number Diff line change
@@ -42,8 +42,8 @@ import (
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/cockroachdb/redact"
"github.com/jackc/pgconn"
pgx "github.com/jackc/pgx/v4"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)

@@ -171,6 +171,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
if !insecure {
httpScheme = "https://"
}
ctx := context.Background()

datadriven.Walk(t, datapathutils.TestDataPath(t, "auth"), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()
@@ -212,13 +213,21 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

s, conn, _ := serverutils.StartServer(t,
s := serverutils.StartServerOnly(t,
base.TestServerArgs{
DefaultTestTenant: base.TestIsForStuffThatShouldWorkWithSecondaryTenantsButDoesntYet(107310),
Insecure: insecure,
SocketFile: maybeSocketFile,
})
defer s.Stopper().Stop(context.Background())
defer s.Stopper().Stop(ctx)

pgURL, cleanup := s.PGUrl(t, serverutils.User(username.RootUser), serverutils.ClientCerts(!insecure))
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

// Enable conn/auth logging.
// We can't use the cluster settings to do this, because
@@ -235,7 +244,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
httpHBAUrl := httpScheme + s.HTTPAddr() + "/debug/hba_conf"

if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

@@ -272,7 +281,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
testServer.SetAcceptSQLWithoutTLS(false)

case "set_hba":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
if err != nil {
return "", err
@@ -313,7 +322,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "set_identity_map":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.identity_map.configuration = $1`, td.Input)
if err != nil {
return "", err
@@ -354,7 +363,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "authlog":
@@ -550,30 +559,35 @@ func hbaRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// Note: pgx.Connect may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var result string
if err := row.Scan(&result); err != nil {
return "", err
}
if showSystemIdentity {
row := dbSQL.QueryRow(`SHOW system_identity`)
row := dbSQL.QueryRow(ctx, `SHOW system_identity`)
var name string
if err := row.Scan(&name); err != nil {
t.Fatal(err)
@@ -602,13 +616,13 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if strings.Contains(hint, "Supported methods:") {
// Depending on whether the test is running on linux or not
// (or, more specifically, whether gss build tag is set),
@@ -618,11 +632,18 @@ func fmtErr(err error) string {
}
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Loading