Skip to content

Commit

Permalink
pgwire,authccl: use pgx for TestAuthenticationAndHBARules
Browse files Browse the repository at this point in the history
The lib/pq driver is not maintained. Since we started to see flakes
related to how that driver does error handling for secure connections,
we switch to pgx instead.

Release note: None
  • Loading branch information
rafiss committed Nov 14, 2024
1 parent efc3b8c commit fa9f7c3
Show file tree
Hide file tree
Showing 10 changed files with 406 additions and 312 deletions.
3 changes: 2 additions & 1 deletion pkg/ccl/testccl/authccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ go_test(
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//stdstrings",
"@com_github_lib_pq//:pq",
"@com_github_jackc_pgconn//:pgconn",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_stretchr_testify//require",
],
)
81 changes: 58 additions & 23 deletions pkg/ccl/testccl/authccl/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ import (
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -144,7 +145,7 @@ func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func(
}

func jwtRunTest(t *testing.T, insecure bool) {

ctx := context.Background()
datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()

Expand Down Expand Up @@ -178,7 +179,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

s, conn, _ := serverutils.StartServer(t,
s, _, _ := serverutils.StartServer(t,
base.TestServerArgs{
Insecure: insecure,
SocketFile: maybeSocketFile,
Expand All @@ -190,7 +191,22 @@ func jwtRunTest(t *testing.T, insecure bool) {
sv = &s.TestTenants()[0].ClusterSettings().SV
}

if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
pgURL, cleanup, err := sqlutils.PGUrlE(
s.ServingSQLAddr(), "TestAuthenticationAndHBARules", url.User(username.RootUser))
if err != nil {
t.Fatal(err)
}
if insecure {
pgURL.RawQuery = "sslmode=disable"
}
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

Expand All @@ -208,27 +224,27 @@ func jwtRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting enabled: %s", a.Vals[0])
}
jwtauthccl.JWTAuthEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWTAuthEnabled.Override(ctx, sv, v)
case "audience":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting audience: %d", len(a.Vals))
}
jwtauthccl.JWTAuthAudience.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthAudience.Override(ctx, sv, a.Vals[0])
case "issuers":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting issuers: %d", len(a.Vals))
}
jwtauthccl.JWTAuthIssuersConfig.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthIssuersConfig.Override(ctx, sv, a.Vals[0])
case "jwks":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks: %d", len(a.Vals))
}
jwtauthccl.JWTAuthJWKS.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthJWKS.Override(ctx, sv, a.Vals[0])
case "claim":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting claim: %d", len(a.Vals))
}
jwtauthccl.JWTAuthClaim.Override(context.Background(), sv, a.Vals[0])
jwtauthccl.JWTAuthClaim.Override(ctx, sv, a.Vals[0])
case "ident_map":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
Expand All @@ -237,7 +253,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
if len(args) != 3 {
t.Fatalf("wrong number of comma separated argumenets to jwt_cluster_setting ident_map: %d", len(a.Vals))
}
pgwire.ConnIdentityMapConf.Override(context.Background(), sv, strings.Join(args, " "))
pgwire.ConnIdentityMapConf.Override(ctx, sv, strings.Join(args, " "))
case "jwks_auto_fetch.enabled":
if len(a.Vals) != 1 {
t.Fatalf("wrong number of argumenets to jwt_cluster_setting jwks_auto.fetch_enabled: %d", len(a.Vals))
Expand All @@ -246,7 +262,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
if err != nil {
t.Fatalf("unknown value for jwt_cluster_setting jwks_auto_fetch.enabled: %s", a.Vals[0])
}
jwtauthccl.JWKSAutoFetchEnabled.Override(context.Background(), sv, v)
jwtauthccl.JWKSAutoFetchEnabled.Override(ctx, sv, v)
default:
t.Fatalf("unknown jwt_cluster_setting: %s", a.Key)
}
Expand All @@ -268,7 +284,7 @@ func jwtRunTest(t *testing.T, insecure bool) {
}

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "connect", "connect_unix":
Expand Down Expand Up @@ -350,24 +366,29 @@ func jwtRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var dbName string
if err := row.Scan(&dbName); err != nil {
return "", err
Expand All @@ -393,20 +414,34 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if strings.Contains(hint, "Supported methods:") {
// Depending on whether the test is running on linux or not
// (or, more specifically, whether gss build tag is set),
// "gss" method might not be included, so we remove it here
// and not include into the expected output.
hint = strings.Replace(hint, "gss, ", "", 1)
}
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Expand Down
68 changes: 48 additions & 20 deletions pkg/sql/pgwire/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ import (
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/stdstrings"
"github.com/cockroachdb/redact"
"github.com/jackc/pgconn"
pgx "github.com/jackc/pgx/v4"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -172,6 +172,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
if !insecure {
httpScheme = "https://"
}
ctx := context.Background()

datadriven.Walk(t, datapathutils.TestDataPath(t, "auth"), func(t *testing.T, path string) {
defer leaktest.AfterTest(t)()
Expand Down Expand Up @@ -213,9 +214,24 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
defer cleanup()

s, conn, _ := serverutils.StartServer(t,
s, _, _ := serverutils.StartServer(t,
base.TestServerArgs{Insecure: insecure, SocketFile: maybeSocketFile})
defer s.Stopper().Stop(context.Background())
defer s.Stopper().Stop(ctx)

pgURL, cleanup, err := sqlutils.PGUrlE(
s.ServingSQLAddr(), "TestAuthenticationAndHBARules", url.User(username.RootUser))
if err != nil {
t.Fatal(err)
}
if insecure {
pgURL.RawQuery = "sslmode=disable"
}
defer cleanup()
rootConn, err := pgx.Connect(ctx, pgURL.String())
if err != nil {
t.Fatal(err)
}
defer func() { _ = rootConn.Close(ctx) }()

// Enable conn/auth logging.
// We can't use the cluster settings to do this, because
Expand All @@ -232,7 +248,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
}
httpHBAUrl := httpScheme + s.HTTPAddr() + "/debug/hba_conf"

if _, err := conn.ExecContext(context.Background(), fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
if _, err := rootConn.Exec(ctx, fmt.Sprintf(`CREATE USER %s`, username.TestUser)); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -269,7 +285,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
testServer.Cfg.AcceptSQLWithoutTLS = false

case "set_hba":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
if err != nil {
return "", err
Expand Down Expand Up @@ -310,7 +326,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "set_identity_map":
_, err := conn.ExecContext(context.Background(),
_, err := rootConn.Exec(ctx,
`SET CLUSTER SETTING server.identity_map.configuration = $1`, td.Input)
if err != nil {
return "", err
Expand Down Expand Up @@ -351,7 +367,7 @@ func hbaRunTest(t *testing.T, insecure bool) {
return string(body), nil

case "sql":
_, err := conn.ExecContext(context.Background(), td.Input)
_, err := rootConn.Exec(ctx, td.Input)
return "ok", err

case "authlog":
Expand Down Expand Up @@ -547,30 +563,35 @@ func hbaRunTest(t *testing.T, insecure bool) {
if len(a.Vals) > 0 {
val = a.Vals[0]
}
if len(val) == 0 {
// pgx.Connect requires empty values to be passed as a
// single-quoted empty string.
val = "''"
}
fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
sp = " "
}
dsn := dsnBuf.String()

// Finally, connect and test the connection.
dbSQL, err := gosql.Open("postgres", dsn)
dbSQL, err := pgx.Connect(ctx, dsn)
if dbSQL != nil {
// Note: gosql.Open may return a valid db (with an open
// Note: pgx.Connect may return a valid db (with an open
// TCP connection) even if there is an error. We want to
// ensure this gets closed so that we catch the conn close
// message in the log.
defer dbSQL.Close()
defer func() { _ = dbSQL.Close(ctx) }()
}
if err != nil {
return "", err
}
row := dbSQL.QueryRow("SELECT current_catalog")
row := dbSQL.QueryRow(ctx, "SELECT current_catalog")
var result string
if err := row.Scan(&result); err != nil {
return "", err
}
if showSystemIdentity {
row := dbSQL.QueryRow(`SHOW system_identity`)
row := dbSQL.QueryRow(ctx, `SHOW system_identity`)
var name string
if err := row.Scan(&name); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -599,20 +620,27 @@ var authLogFileRe = regexp.MustCompile(`"EventType":"client_`)
func fmtErr(err error) string {
if err != nil {
errStr := ""
if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
errStr = pqErr.Message
if pgcode.MakeCode(string(pqErr.Code)) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
if pgxErr := (*pgconn.PgError)(nil); errors.As(err, &pgxErr) {
errStr = pgxErr.Message
if pgcode.MakeCode(pgxErr.Code) != pgcode.Uncategorized {
errStr += fmt.Sprintf(" (SQLSTATE %s)", pgxErr.Code)
}
if pqErr.Hint != "" {
hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
if pgxErr.Hint != "" {
hint := strings.Replace(pgxErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
errStr += "\nHINT: " + hint
}
if pqErr.Detail != "" {
errStr += "\nDETAIL: " + pqErr.Detail
if pgxErr.Detail != "" {
errStr += "\nDETAIL: " + pgxErr.Detail
}
} else {
errStr = err.Error()
// pgx uses an internal type (pgconn.connectError) for "TLS not enabled"
// errors here. We need to munge the error here to avoid including
// non-stable information like IP addresses in the output.
const tlsErr = "tls error (server refused TLS connection)"
if strings.HasSuffix(errStr, tlsErr) {
errStr = tlsErr
}
}
return "ERROR: " + errStr
}
Expand Down
Loading

0 comments on commit fa9f7c3

Please sign in to comment.