From 496d069980998afc94b2aac98cb9f7fcb9a1e612 Mon Sep 17 00:00:00 2001 From: Cameron Nunez Date: Mon, 13 Mar 2023 11:52:10 -0400 Subject: [PATCH] ccl/oidcccl: support principal matching on list claims Previously, matching on ID token claims was not possible if the claim key specified had a corresponding value that was a list, not a string. With this change, matching can now occur on claims that are list valued in order to add login capabilities to DB Console. It is important to note that this change does NOT offer the user the ability to choose between possible matches; it simply selects the first match to log the user in. This change also adds more verbose logging about ID token details. Epic: none Fixes: #97301, #97468 Release note (enterprise change): The cluster setting `server.oidc_authentication.claim_json_key` for DB Console SSO now accepts list-valued token claims. Release note (general change): Increasing the logging verbosity is more helpful with troubleshooting DB Console SSO issues. --- pkg/ccl/oidcccl/BUILD.bazel | 1 + pkg/ccl/oidcccl/authentication_oidc.go | 38 ++++---- pkg/ccl/oidcccl/authentication_oidc_test.go | 63 +++++++++++++ pkg/ccl/oidcccl/claim_match.go | 97 +++++++++++++++++++++ 4 files changed, 183 insertions(+), 16 deletions(-) create mode 100644 pkg/ccl/oidcccl/claim_match.go diff --git a/pkg/ccl/oidcccl/BUILD.bazel b/pkg/ccl/oidcccl/BUILD.bazel index adffa787919a..8ec0bd6f37da 100644 --- a/pkg/ccl/oidcccl/BUILD.bazel +++ b/pkg/ccl/oidcccl/BUILD.bazel @@ -5,6 +5,7 @@ go_library( name = "oidcccl", srcs = [ "authentication_oidc.go", + "claim_match.go", "settings.go", "state.go", ], diff --git a/pkg/ccl/oidcccl/authentication_oidc.go b/pkg/ccl/oidcccl/authentication_oidc.go index 659e925fff1c..f61889002f08 100644 --- a/pkg/ccl/oidcccl/authentication_oidc.go +++ b/pkg/ccl/oidcccl/authentication_oidc.go @@ -199,7 +199,10 @@ func reloadConfigLocked( provider, err := oidc.NewProvider(ctx, server.conf.providerURL) if err != nil { - log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err) + log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err) + if log.V(1) { + log.Infof(ctx, "check provider URL OIDC cluster setting: "+OIDCProviderURLSettingName) + } return } @@ -208,7 +211,10 @@ func reloadConfigLocked( redirectURL, err := getRegionSpecificRedirectURL(locality, server.conf.redirectURLConf) if err != nil { - log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err) + log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err) + if log.V(1) { + log.Infof(ctx, "check redirect URL OIDC cluster setting: "+OIDCRedirectURLSettingName) + } return } @@ -312,16 +318,16 @@ var ConfigureOIDC = func( return } - oauth2Token, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey)) + credentials, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey)) if err != nil { log.Errorf(ctx, "OIDC: failed to exchange code for token: %v", err) http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } - rawIDToken, ok := oauth2Token.Extra(idTokenKey).(string) + rawIDToken, ok := credentials.Extra(idTokenKey).(string) if !ok { - log.Error(ctx, "OIDC: failed to extract ID token from OAuth2 token") + log.Error(ctx, "OIDC: failed to extract ID token from the token credentials") http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } @@ -340,23 +346,23 @@ var ConfigureOIDC = func( return } - var principal string - claim := claims[oidcAuthentication.conf.claimJSONKey] - if err := json.Unmarshal(claim, &principal); err != nil { - log.Errorf(ctx, "OIDC: failed to complete authentication: failed to extract claim key %s: %v", oidcAuthentication.conf.claimJSONKey, err) - http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) - return + if log.V(1) { + log.Infof( + ctx, + "attempting to extract SQL username from the payload using the claim key %s and regex %s", + oidcAuthentication.conf.claimJSONKey, + oidcAuthentication.conf.principalRegex, + ) } - match := oidcAuthentication.conf.principalRegex.FindStringSubmatch(principal) - numGroups := len(match) - if numGroups != 2 { - log.Errorf(ctx, "OIDC: failed to complete authentication: expected one group in regexp, got %d", numGroups) + username, err := extractUsernameFromClaims( + ctx, claims, oidcAuthentication.conf.claimJSONKey, oidcAuthentication.conf.principalRegex, + ) + if err != nil { http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError) return } - username := match[1] cookie, err := userLoginFromSSO(ctx, username) if err != nil { log.Errorf(ctx, "OIDC: failed to complete authentication: unable to create session for %s: %v", username, err) diff --git a/pkg/ccl/oidcccl/authentication_oidc_test.go b/pkg/ccl/oidcccl/authentication_oidc_test.go index 450f83df87be..1c040db90c4f 100644 --- a/pkg/ccl/oidcccl/authentication_oidc_test.go +++ b/pkg/ccl/oidcccl/authentication_oidc_test.go @@ -13,10 +13,12 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" + "encoding/json" "fmt" "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "testing" @@ -289,6 +291,67 @@ func TestOIDCStateEncodeDecode(t *testing.T) { } } +func TestOIDCClaimMatch(t *testing.T) { + ctx := context.Background() + + for _, tc := range []struct { + testName string + claimKey string + principalRegex string + claims map[string]json.RawMessage + wantError bool + }{ + { + testName: "string valued claim", + claimKey: "email", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "email": json.RawMessage(`"myfakeemail@example.com"`), + }, + }, + { + testName: "string valued claim with no match", + claimKey: "email", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "email": json.RawMessage(`"bademail"`), + }, + wantError: true, + }, + { + testName: "list valued claim", + claimKey: "groups", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "groups": json.RawMessage( + `["badgroupname", "myfakeemail@example.com", "anotherbadgroupname"]`, + ), + }, + }, + { + testName: "list valued claim with no matches", + claimKey: "groups", + principalRegex: "^([^@]+)@[^@]+$", + claims: map[string]json.RawMessage{ + "groups": json.RawMessage(`["badgroupname", "anotherbadgroupname"]`), + }, + wantError: true, + }, + } { + t.Run(tc.testName, func(t *testing.T) { + sqlUsername, err := extractUsernameFromClaims( + ctx, tc.claims, tc.claimKey, regexp.MustCompile(tc.principalRegex), + ) + if !tc.wantError { + require.NoError(t, err) + require.Equal(t, "myfakeemail", sqlUsername) + } else { + require.ErrorContains(t, err, "expected one group in regexp") + } + }) + } +} + func Test_getRegionSpecificRedirectURL(t *testing.T) { type args struct { locality roachpb.Locality diff --git a/pkg/ccl/oidcccl/claim_match.go b/pkg/ccl/oidcccl/claim_match.go new file mode 100644 index 000000000000..8bbc39987431 --- /dev/null +++ b/pkg/ccl/oidcccl/claim_match.go @@ -0,0 +1,97 @@ +// Copyright 2023 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package oidcccl + +import ( + "context" + "encoding/json" + "regexp" + "strings" + + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" +) + +// extractUsernameFromClaims uses a regex to strip out elements of the value +// corresponding to the token claim claimKey. +func extractUsernameFromClaims( + ctx context.Context, + claims map[string]json.RawMessage, + claimKey string, + principalRE *regexp.Regexp, +) (string, error) { + var ( + principal string + principals []string + ) + + claimKeys := make([]string, len(claims)) + i := 0 + for k := range claims { + claimKeys[i] = k + i++ + } + + targetClaim, ok := claims[claimKey] + if !ok { + log.Errorf( + ctx, "OIDC: failed to complete authentication: invalid JSON claim key: %s", claimKey, + ) + log.Infof(ctx, "token payload includes the following claims: %s", strings.Join(claimKeys, ", ")) + } + + if err := json.Unmarshal(targetClaim, &principal); err != nil { + // Try parsing assuming the claim value is a list and not a string. + if log.V(1) { + log.Infof(ctx, + "failed parsing claim as string; attempting to parse as a list", + ) + } + if err = json.Unmarshal(targetClaim, &principals); err != nil { + log.Errorf(ctx, + "OIDC: failed to complete authentication: failed to parse value for the claim %s: %v", + claimKey, err, + ) + return "", err + } + if log.V(1) { + log.Infof(ctx, + "multiple principals in the claim found; selecting first matching principal", + ) + } + } + + if len(principals) == 0 { + principals = []string{principal} + } + + var match []string + for _, principal := range principals { + match = principalRE.FindStringSubmatch(principal) + if len(match) == 2 { + log.Infof(ctx, + "extracted SQL username %s from the target claim %s", match[1], claimKey, + ) + return match[1], nil + } + } + + // Error when there is not a match. + err := errors.Newf("expected one group in regexp") + log.Errorf(ctx, "OIDC: failed to complete authentication: %v", err) + if log.V(1) { + log.Infof(ctx, + "token payload includes the following claims: %s\n"+ + "check OIDC cluster settings: %s, %s", + strings.Join(claimKeys, ", "), + OIDCClaimJSONKeySettingName, OIDCPrincipalRegexSettingName, + ) + } + return "", err +}