Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-23.1: pgwire,authccl: use pgx for TestAuthenticationAndHBARules #135178

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/ccl/testccl/authccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ go_test(
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//stdstrings",
"@com_github_lib_pq//:pq",
"@com_github_jackc_pgconn//:pgconn",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_stretchr_testify//require",
],
)
81 changes: 58 additions & 23 deletions pkg/ccl/testccl/authccl/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ import (
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -144,7 +145,7 @@ func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func(
}

func jwtRunTest(t *testing.T, insecure bool) {

ctx := context.Background()
datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()

Expand Down Expand Up @@ -178,7 +179,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

s, conn, _ := serverutils.StartServer(t,
s, _, _ := serverutils.StartServer(t,
base.TestServerArgs{
Insecure: insecure,
SocketFile: maybeSocketFile,
Expand All @@ -190,7 +191,22 @@ func jwtRunTest(t *testing.T, insecure bool) {
sv = &s.TestTenants()[0].ClusterSettings().SV
}

if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
pgURL, cleanup, err := sqlutils.PGUrlE(
s.ServingSQLAddr(), "TestAuthenticationAndHBARules", url.User(username.RootUser))
if err != nil {
t.Fatal(err)
}
if insecure {
pgURL.RawQuery = "sslmode=disable"
}
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

Expand All @@ -208,27 +224,27 @@ func jwtRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting enabled: %s", a.Vals[0])
}
jwtauthccl.JWTAuthEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWTAuthEnabled.Override(ctx, sv, v)
case "audience":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting audience: %d", len(a.Vals))
}
jwtauthccl.JWTAuthAudience.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthAudience.Override(ctx, sv, a.Vals[0])
case "issuers":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting issuers: %d", len(a.Vals))
}
jwtauthccl.JWTAuthIssuersConfig.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthIssuersConfig.Override(ctx, sv, a.Vals[0])
case "jwks":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks: %d", len(a.Vals))
}
jwtauthccl.JWTAuthJWKS.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthJWKS.Override(ctx, sv, a.Vals[0])
case "claim":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting claim: %d", len(a.Vals))
}
jwtauthccl.JWTAuthClaim.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthClaim.Override(ctx, sv, a.Vals[0])
case "ident_map":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
Expand All @@ -237,7 +253,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
if len(args) != 3 {
t.Fatalf("wrong number of comma separated argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
}
pgwire.ConnIdentityMapConf.Override(context.Background(), sv, strings.Join(args, " "))
pgwire.ConnIdentityMapConf.Override(ctx, sv, strings.Join(args, " "))
case "jwks_auto_fetch.enabled":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks_auto.fetch_enabled: %d", len(a.Vals))
Expand All @@ -246,7 +262,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting jwks_auto_fetch.enabled: %s", a.Vals[0])
}
jwtauthccl.JWKSAutoFetchEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWKSAutoFetchEnabled.Override(ctx, sv, v)
default:
t.Fatalf("unknown jwt_cluster_setting: %s", a.Key)
}
Expand All @@ -268,7 +284,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
}

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "connect", "connect_unix":
Expand Down Expand Up @@ -350,24 +366,29 @@ func jwtRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var dbName string
if err := row.Scan(&dbName); err != nil {
return "", err
Expand All @@ -393,20 +414,34 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if strings.Contains(hint, "Supported methods:") {
// Depending on whether the test is running on linux or not
// (or, more specifically, whether gss build tag is set),
// "gss" method might not be included, so we remove it here
// and not include into the expected output.
hint = strings.Replace(hint, "gss, ", "", 1)
}
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Expand Down
1 change: 0 additions & 1 deletion pkg/sql/pgwire/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ go_library(
"//pkg/sql/sqltelemetry",
"//pkg/sql/types",
"//pkg/util",
"//pkg/util/buildutil",
"//pkg/util/contextutil",
"//pkg/util/duration",
"//pkg/util/envutil",
Expand Down
68 changes: 48 additions & 20 deletions pkg/sql/pgwire/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ import (
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/cockroachdb/redact"
"github.com/jackc/pgconn"
pgx "github.com/jackc/pgx/v4"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -172,6 +172,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
if !insecure {
httpScheme = "https://"
}
ctx := context.Background()

datadriven.Walk(t, datapathutils.TestDataPath(t, "auth"), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()
Expand Down Expand Up @@ -213,9 +214,24 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

s, conn, _ := serverutils.StartServer(t,
s, _, _ := serverutils.StartServer(t,
base.TestServerArgs{Insecure: insecure, SocketFile: maybeSocketFile})
defer s.Stopper().Stop(context.Background())
defer s.Stopper().Stop(ctx)

pgURL, cleanup, err := sqlutils.PGUrlE(
s.ServingSQLAddr(), "TestAuthenticationAndHBARules", url.User(username.RootUser))
if err != nil {
t.Fatal(err)
}
if insecure {
pgURL.RawQuery = "sslmode=disable"
}
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

// Enable conn/auth logging.
// We can't use the cluster settings to do this, because
Expand All @@ -232,7 +248,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
httpHBAUrl := httpScheme + s.HTTPAddr() + "/debug/hba_conf"

if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -269,7 +285,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
testServer.Cfg.AcceptSQLWithoutTLS = false

case "set_hba":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
if err != nil {
return "", err
Expand Down Expand Up @@ -310,7 +326,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "set_identity_map":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.identity_map.configuration = $1`, td.Input)
if err != nil {
return "", err
Expand Down Expand Up @@ -351,7 +367,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "authlog":
Expand Down Expand Up @@ -547,30 +563,35 @@ func hbaRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// Note: pgx.Connect may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var result string
if err := row.Scan(&result); err != nil {
return "", err
}
if showSystemIdentity {
row := dbSQL.QueryRow(`SHOW system_identity`)
row := dbSQL.QueryRow(ctx, `SHOW system_identity`)
var name string
if err := row.Scan(&name); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -599,20 +620,27 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Expand Down
Loading