diff --git a/pkg/ccl/oidcccl/BUILD.bazel b/pkg/ccl/oidcccl/BUILD.bazel index adffa787919a..8ec0bd6f37da 100644 --- a/pkg/ccl/oidcccl/BUILD.bazel +++ b/pkg/ccl/oidcccl/BUILD.bazel @@ -5,6 +5,7 @@ go_library( name = "oidcccl", srcs = [ "authentication_oidc.go", + "claim_match.go", "settings.go", "state.go", ], diff --git a/pkg/ccl/oidcccl/authentication_oidc.go b/pkg/ccl/oidcccl/authentication_oidc.go index 659e925fff1c..f61889002f08 100644 --- a/pkg/ccl/oidcccl/authentication_oidc.go +++ b/pkg/ccl/oidcccl/authentication_oidc.go @@ -199,7 +199,10 @@ func reloadConfigLocked( provider, err := oidc.NewProvider(ctx, server.conf.providerURL) if err != nil { - log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err) + log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err) + if log.V(1) { + log.Infof(ctx, "check provider URL OIDC cluster setting: "+OIDCProviderURLSettingName) + } return } @@ -208,7 +211,10 @@ func reloadConfigLocked( redirectURL, err := getRegionSpecificRedirectURL(locality, server.conf.redirectURLConf) if err != nil { - log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err) + log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err) + if log.V(1) { + log.Infof(ctx, "check redirect URL OIDC cluster setting: "+OIDCRedirectURLSettingName) + } return } @@ -312,16 +318,16 @@ var ConfigureOIDC = func( return } - oauth2Token, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey)) + credentials, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey)) if err != nil { log.Errorf(ctx, "OIDC: failed to exchange code for token: %v", err) http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } - rawIDToken, ok := oauth2Token.Extra(idTokenKey).(string) + rawIDToken, ok := credentials.Extra(idTokenKey).(string) if !ok { - log.Error(ctx, "OIDC: failed to extract ID token from OAuth2 token") + log.Error(ctx, "OIDC: failed to extract ID token from the token credentials") http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } @@ -340,23 +346,23 @@ var ConfigureOIDC = func( return } - var principal string - claim := claims[oidcAuthentication.conf.claimJSONKey] - if err := json.Unmarshal(claim, &principal); err != nil { - log.Errorf(ctx, "OIDC: failed to complete authentication: failed to extract claim key %s: %v", oidcAuthentication.conf.claimJSONKey, err) - http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) - return + if log.V(1) { + log.Infof( + ctx, + "attempting to extract SQL username from the payload using the claim key %s and regex %s", + oidcAuthentication.conf.claimJSONKey, + oidcAuthentication.conf.principalRegex, + ) } - match := oidcAuthentication.conf.principalRegex.FindStringSubmatch(principal) - numGroups := len(match) - if numGroups != 2 { - log.Errorf(ctx, "OIDC: failed to complete authentication: expected one group in regexp, got %d", numGroups) + username, err := extractUsernameFromClaims( + ctx, claims, oidcAuthentication.conf.claimJSONKey, oidcAuthentication.conf.principalRegex, + ) + if err != nil { http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } - username := match[1] cookie, err := userLoginFromSSO(ctx, username) if err != nil { log.Errorf(ctx, "OIDC: failed to complete authentication: unable to create session for %s: %v", username, err) diff --git a/pkg/ccl/oidcccl/authentication_oidc_test.go b/pkg/ccl/oidcccl/authentication_oidc_test.go index 450f83df87be..1c040db90c4f 100644 --- a/pkg/ccl/oidcccl/authentication_oidc_test.go +++ b/pkg/ccl/oidcccl/authentication_oidc_test.go @@ -13,10 +13,12 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" + "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "testing" @@ -289,6 +291,67 @@ func TestOIDCStateEncodeDecode(t *testing.T) { } } +func TestOIDCClaimMatch(t *testing.T) { + ctx := context.Background() + + for _, tc := range []struct { + testName string + claimKey string + principalRegex string + claims map[string]json.RawMessage + wantError bool + }{ + { + testName: "string valued claim", + claimKey: "email", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "email": json.RawMessage(`"myfakeemail@example.com"`), + }, + }, + { + testName: "string valued claim with no match", + claimKey: "email", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "email": json.RawMessage(`"bademail"`), + }, + wantError: true, + }, + { + testName: "list valued claim", + claimKey: "groups", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "groups": json.RawMessage( + `["badgroupname", "myfakeemail@example.com", "anotherbadgroupname"]`, + ), + }, + }, + { + testName: "list valued claim with no matches", + claimKey: "groups", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "groups": json.RawMessage(`["badgroupname", "anotherbadgroupname"]`), + }, + wantError: true, + }, + } { + t.Run(tc.testName, func(t *testing.T) { + sqlUsername, err := extractUsernameFromClaims( + ctx, tc.claims, tc.claimKey, regexp.MustCompile(tc.principalRegex), + ) + if !tc.wantError { + require.NoError(t, err) + require.Equal(t, "myfakeemail", sqlUsername) + } else { + require.ErrorContains(t, err, "expected one group in regexp") + } + }) + } +} + func Test_getRegionSpecificRedirectURL(t *testing.T) { type args struct { locality roachpb.Locality diff --git a/pkg/ccl/oidcccl/claim_match.go b/pkg/ccl/oidcccl/claim_match.go new file mode 100644 index 000000000000..8bbc39987431 --- /dev/null +++ b/pkg/ccl/oidcccl/claim_match.go @@ -0,0 +1,97 @@ +// Copyright 2023 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package oidcccl + +import ( + "context" + "encoding/json" + "regexp" + "strings" + + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" +) + +// extractUsernameFromClaims uses a regex to strip out elements of the value +// corresponding to the token claim claimKey. +func extractUsernameFromClaims( + ctx context.Context, + claims map[string]json.RawMessage, + claimKey string, + principalRE *regexp.Regexp, +) (string, error) { + var ( + principal string + principals []string + ) + + claimKeys := make([]string, len(claims)) + i := 0 + for k := range claims { + claimKeys[i] = k + i++ + } + + targetClaim, ok := claims[claimKey] + if !ok { + log.Errorf( + ctx, "OIDC: failed to complete authentication: invalid JSON claim key: %s", claimKey, + ) + log.Infof(ctx, "token payload includes the following claims: %s", strings.Join(claimKeys, ", ")) + } + + if err := json.Unmarshal(targetClaim, &principal); err != nil { + // Try parsing assuming the claim value is a list and not a string. + if log.V(1) { + log.Infof(ctx, + "failed parsing claim as string; attempting to parse as a list", + ) + } + if err = json.Unmarshal(targetClaim, &principals); err != nil { + log.Errorf(ctx, + "OIDC: failed to complete authentication: failed to parse value for the claim %s: %v", + claimKey, err, + ) + return "", err + } + if log.V(1) { + log.Infof(ctx, + "multiple principals in the claim found; selecting first matching principal", + ) + } + } + + if len(principals) == 0 { + principals = []string{principal} + } + + var match []string + for _, principal := range principals { + match = principalRE.FindStringSubmatch(principal) + if len(match) == 2 { + log.Infof(ctx, + "extracted SQL username %s from the target claim %s", match[1], claimKey, + ) + return match[1], nil + } + } + + // Error when there is not a match. + err := errors.Newf("expected one group in regexp") + log.Errorf(ctx, "OIDC: failed to complete authentication: %v", err) + if log.V(1) { + log.Infof(ctx, + "token payload includes the following claims: %s\n"+ + "check OIDC cluster settings: %s, %s", + strings.Join(claimKeys, ", "), + OIDCClaimJSONKeySettingName, OIDCPrincipalRegexSettingName, + ) + } + return "", err +}