Skip to content

Commit

Permalink
Merge pull request #135165 from cockroachdb/blathers/backport-release…
Browse files Browse the repository at this point in the history
…-24.3-135086

release-24.3: pgwire,authccl: use pgx for TestAuthenticationAndHBARules
  • Loading branch information
souravcrl authored Nov 14, 2024
2 parents b69ada3 + 30cae71 commit 13b1f5a
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 383 deletions.
3 changes: 2 additions & 1 deletion pkg/ccl/testccl/authccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ go_test(
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//stdstrings",
"@com_github_lib_pq//:pq",
"@com_github_jackc_pgconn//:pgconn",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_stretchr_testify//require",
],
)
84 changes: 57 additions & 27 deletions pkg/ccl/testccl/authccl/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ import (
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -149,6 +150,7 @@ func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func(
}

func authCCLRunTest(t *testing.T, insecure bool) {
ctx := context.Background()
datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()

Expand Down Expand Up @@ -188,15 +190,24 @@ func authCCLRunTest(t *testing.T, insecure bool) {
defer testutils.TestingHook(&ldapccl.NewLDAPUtil, newMockLDAPUtil)()
}

srv, conn, _ := serverutils.StartServer(t,
srv := serverutils.StartServerOnly(t,
base.TestServerArgs{
DefaultTestTenant: base.TestDoesNotWorkWithSharedProcessModeButWeDontKnowWhyYet(
base.TestTenantProbabilistic, 112949,
),
Insecure: insecure,
SocketFile: maybeSocketFile,
})
defer srv.Stopper().Stop(context.Background())
defer srv.Stopper().Stop(ctx)

pgURL, cleanup := srv.PGUrl(t, serverutils.User(username.RootUser), serverutils.ClientCerts(!insecure))
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

s := srv.ApplicationLayer()
pgServer := s.PGServer().(*pgwire.Server)
pgServer.TestingEnableConnLogging()
Expand All @@ -209,7 +220,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {

httpHBAUrl := s.AdminURL().WithPath("/debug/hba_conf").String()
sv := &s.ClusterSettings().SV
if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

Expand All @@ -227,27 +238,27 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting enabled: %s", a.Vals[0])
}
jwtauthccl.JWTAuthEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWTAuthEnabled.Override(ctx, sv, v)
case "audience":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting audience: %d", len(a.Vals))
}
jwtauthccl.JWTAuthAudience.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthAudience.Override(ctx, sv, a.Vals[0])
case "issuers":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting issuers: %d", len(a.Vals))
}
jwtauthccl.JWTAuthIssuersConfig.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthIssuersConfig.Override(ctx, sv, a.Vals[0])
case "jwks":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks: %d", len(a.Vals))
}
jwtauthccl.JWTAuthJWKS.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthJWKS.Override(ctx, sv, a.Vals[0])
case "claim":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting claim: %d", len(a.Vals))
}
jwtauthccl.JWTAuthClaim.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthClaim.Override(ctx, sv, a.Vals[0])
case "ident_map":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
Expand All @@ -256,7 +267,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if len(args) != 3 {
t.Fatalf("wrong number of comma separated argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
}
pgwire.ConnIdentityMapConf.Override(context.Background(), sv, strings.Join(args, " "))
pgwire.ConnIdentityMapConf.Override(ctx, sv, strings.Join(args, " "))
case "jwks_auto_fetch.enabled":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks_auto.fetch_enabled: %d", len(a.Vals))
Expand All @@ -265,7 +276,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting jwks_auto_fetch.enabled: %s", a.Vals[0])
}
jwtauthccl.JWKSAutoFetchEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWKSAutoFetchEnabled.Override(ctx, sv, v)
default:
t.Fatalf("unknown jwt_cluster_setting: %s", a.Key)
}
Expand All @@ -287,13 +298,13 @@ func authCCLRunTest(t *testing.T, insecure bool) {
}

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "query_row":
var query_output string
err := conn.QueryRow(td.Input).Scan(&query_output)
return query_output, err
var query_output interface{}
err := rootConn.QueryRow(ctx, td.Input).Scan(&query_output)
return fmt.Sprintf("%v", query_output), err

case "connect", "connect_unix":
if td.Cmd == "connect_unix" && runtime.GOOS == "windows" {
Expand Down Expand Up @@ -375,24 +386,29 @@ func authCCLRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var dbName string
if err := row.Scan(&dbName); err != nil {
return "", err
Expand Down Expand Up @@ -442,7 +458,7 @@ func authCCLRunTest(t *testing.T, insecure bool) {
return strconv.Itoa(resp.StatusCode), nil

case "set_hba":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
if err != nil {
return "", err
Expand Down Expand Up @@ -521,20 +537,34 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if strings.Contains(hint, "Supported methods:") {
// Depending on whether the test is running on linux or not
// (or, more specifically, whether gss build tag is set),
// "gss" method might not be included, so we remove it here
// and not include into the expected output.
hint = strings.Replace(hint, "gss, ", "", 1)
}
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/pgwire/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ go_library(
"//pkg/sql/sqltelemetry",
"//pkg/sql/types",
"//pkg/util",
"//pkg/util/buildutil",
"//pkg/util/ctxlog",
"//pkg/util/duration",
"//pkg/util/envutil",
Expand Down
Loading

0 comments on commit 13b1f5a

Please sign in to comment.