Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl/oidcccl: support principal matching on list claims #98522

Merged
merged 1 commit into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/ccl/oidcccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
name = "oidcccl",
srcs = [
"authentication_oidc.go",
"claim_match.go",
"settings.go",
"state.go",
],
Expand Down
38 changes: 22 additions & 16 deletions pkg/ccl/oidcccl/authentication_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ func reloadConfigLocked(

provider, err := oidc.NewProvider(ctx, server.conf.providerURL)
if err != nil {
log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err)
log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err)
if log.V(1) {
log.Infof(ctx, "check provider URL OIDC cluster setting: "+OIDCProviderURLSettingName)
}
return
}

Expand All @@ -208,7 +211,10 @@ func reloadConfigLocked(

redirectURL, err := getRegionSpecificRedirectURL(locality, server.conf.redirectURLConf)
if err != nil {
log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err)
log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err)
if log.V(1) {
log.Infof(ctx, "check redirect URL OIDC cluster setting: "+OIDCRedirectURLSettingName)
}
return
}

Expand Down Expand Up @@ -312,16 +318,16 @@ var ConfigureOIDC = func(
return
}

oauth2Token, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey))
credentials, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey))
if err != nil {
log.Errorf(ctx, "OIDC: failed to exchange code for token: %v", err)
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}

rawIDToken, ok := oauth2Token.Extra(idTokenKey).(string)
rawIDToken, ok := credentials.Extra(idTokenKey).(string)
if !ok {
log.Error(ctx, "OIDC: failed to extract ID token from OAuth2 token")
log.Error(ctx, "OIDC: failed to extract ID token from the token credentials")
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}
Expand All @@ -340,23 +346,23 @@ var ConfigureOIDC = func(
return
}

var principal string
claim := claims[oidcAuthentication.conf.claimJSONKey]
if err := json.Unmarshal(claim, &principal); err != nil {
log.Errorf(ctx, "OIDC: failed to complete authentication: failed to extract claim key %s: %v", oidcAuthentication.conf.claimJSONKey, err)
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
if log.V(1) {
log.Infof(
ctx,
"attempting to extract SQL username from the payload using the claim key %s and regex %s",
oidcAuthentication.conf.claimJSONKey,
oidcAuthentication.conf.principalRegex,
)
}

match := oidcAuthentication.conf.principalRegex.FindStringSubmatch(principal)
numGroups := len(match)
if numGroups != 2 {
log.Errorf(ctx, "OIDC: failed to complete authentication: expected one group in regexp, got %d", numGroups)
username, err := extractUsernameFromClaims(
ctx, claims, oidcAuthentication.conf.claimJSONKey, oidcAuthentication.conf.principalRegex,
)
if err != nil {
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}

username := match[1]
cookie, err := userLoginFromSSO(ctx, username)
if err != nil {
log.Errorf(ctx, "OIDC: failed to complete authentication: unable to create session for %s: %v", username, err)
Expand Down
63 changes: 63 additions & 0 deletions pkg/ccl/oidcccl/authentication_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -289,6 +291,67 @@ func TestOIDCStateEncodeDecode(t *testing.T) {
}
}

func TestOIDCClaimMatch(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
testName string
claimKey string
principalRegex string
claims map[string]json.RawMessage
wantError bool
}{
{
testName: "string valued claim",
claimKey: "email",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"email": json.RawMessage(`"[email protected]"`),
},
},
{
testName: "string valued claim with no match",
claimKey: "email",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"email": json.RawMessage(`"bademail"`),
},
wantError: true,
},
{
testName: "list valued claim",
claimKey: "groups",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"groups": json.RawMessage(
`["badgroupname", "[email protected]", "anotherbadgroupname"]`,
),
},
},
{
testName: "list valued claim with no matches",
claimKey: "groups",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"groups": json.RawMessage(`["badgroupname", "anotherbadgroupname"]`),
},
wantError: true,
},
} {
t.Run(tc.testName, func(t *testing.T) {
sqlUsername, err := extractUsernameFromClaims(
ctx, tc.claims, tc.claimKey, regexp.MustCompile(tc.principalRegex),
)
if !tc.wantError {
require.NoError(t, err)
require.Equal(t, "myfakeemail", sqlUsername)
} else {
require.ErrorContains(t, err, "expected one group in regexp")
}
})
}
}

func Test_getRegionSpecificRedirectURL(t *testing.T) {
type args struct {
locality roachpb.Locality
Expand Down
97 changes: 97 additions & 0 deletions pkg/ccl/oidcccl/claim_match.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2023 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package oidcccl

import (
"context"
"encoding/json"
"regexp"
"strings"

"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/errors"
)

// extractUsernameFromClaims uses a regex to strip out elements of the value
// corresponding to the token claim claimKey.
func extractUsernameFromClaims(
ctx context.Context,
claims map[string]json.RawMessage,
claimKey string,
principalRE *regexp.Regexp,
) (string, error) {
var (
principal string
principals []string
)

claimKeys := make([]string, len(claims))
i := 0
for k := range claims {
claimKeys[i] = k
i++
}

targetClaim, ok := claims[claimKey]
if !ok {
log.Errorf(
ctx, "OIDC: failed to complete authentication: invalid JSON claim key: %s", claimKey,
)
log.Infof(ctx, "token payload includes the following claims: %s", strings.Join(claimKeys, ", "))
}

if err := json.Unmarshal(targetClaim, &principal); err != nil {
// Try parsing assuming the claim value is a list and not a string.
if log.V(1) {
log.Infof(ctx,
"failed parsing claim as string; attempting to parse as a list",
)
}
if err = json.Unmarshal(targetClaim, &principals); err != nil {
log.Errorf(ctx,
"OIDC: failed to complete authentication: failed to parse value for the claim %s: %v",
claimKey, err,
)
return "", err
}
if log.V(1) {
log.Infof(ctx,
"multiple principals in the claim found; selecting first matching principal",
)
}
}

if len(principals) == 0 {
principals = []string{principal}
}

var match []string
for _, principal := range principals {
match = principalRE.FindStringSubmatch(principal)
if len(match) == 2 {
log.Infof(ctx,
"extracted SQL username %s from the target claim %s", match[1], claimKey,
)
return match[1], nil
}
}

// Error when there is not a match.
err := errors.Newf("expected one group in regexp")
log.Errorf(ctx, "OIDC: failed to complete authentication: %v", err)
if log.V(1) {
log.Infof(ctx,
"token payload includes the following claims: %s\n"+
"check OIDC cluster settings: %s, %s",
strings.Join(claimKeys, ", "),
OIDCClaimJSONKeySettingName, OIDCPrincipalRegexSettingName,
)
}
return "", err
}