Skip to content

Commit

Permalink
Merge pull request #124130 from souravcrl/backport23.1-123697
Browse files Browse the repository at this point in the history
release-23.1: ccl,sql,util: Fix jwt auth and add sensitive error logs
  • Loading branch information
souravcrl authored May 15, 2024
2 parents 59b75e6 + 35b48b6 commit 59c1f69
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 71 deletions.
63 changes: 41 additions & 22 deletions pkg/ccl/jwtauthccl/authentication_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,23 @@ func (authenticator *jwtAuthenticator) mapUsername(
// * the audience field matches the audience cluster setting.
// * the issuer field is one of the values in the issuer cluster setting.
// * the cluster has an enterprise license.
// It returns authError (which is the error sql clients will see in case of
// failures) and detailedError (which is the internal error from http clients
// that might contain sensitive information we do not want to send to sql
// clients but still want to log it). We do not want to send any information
// back to client which was not provided by the client.
func (authenticator *jwtAuthenticator) ValidateJWTLogin(
ctx context.Context,
st *cluster.Settings,
user username.SQLUsername,
tokenBytes []byte,
identMap *identmap.Conf,
) error {
) (detailedErrorMsg string, authError error) {
authenticator.mu.Lock()
defer authenticator.mu.Unlock()

if !authenticator.mu.enabled {
return errors.Newf("JWT authentication: not enabled")
return "", errors.Newf("JWT authentication: not enabled")
}

telemetry.Inc(beginAuthUseCounter)
Expand All @@ -146,7 +151,9 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
// The token will be parsed again later to actually verify the signature.
unverifiedToken, err := jwt.Parse(tokenBytes)
if err != nil {
return errors.Newf("JWT authentication: invalid token")
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid token"),
"token parsing failed: %v", err)
}

// Check for issuer match against configured issuers.
Expand All @@ -160,17 +167,18 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !issuerMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid issuer"),
"token issued by %s", unverifiedToken.Issuer())
}

var jwkSet jwk.Set
// If auto-fetch is enabled, fetch the JWKS remotely from the issuer's well known jwks url.
if authenticator.mu.conf.jwksAutoFetchEnabled {
jwkSet, err = remoteFetchJWKS(ctx, issuerUrl)
jwkSet, err = authenticator.remoteFetchJWKS(ctx, issuerUrl)
if err != nil {
return errors.Newf("JWT authentication: unable to validate token")
return fmt.Sprintf("unable to fetch jwks: %v", err),
errors.Newf("JWT authentication: unable to validate token")
}
} else {
jwkSet = authenticator.mu.conf.jwks
Expand All @@ -179,7 +187,9 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
// Now that both the issuer and key-id are matched, parse the token again to validate the signature.
parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(jwkSet), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
if err != nil {
return errors.Newf("JWT authentication: invalid token")
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid token"),
"unable to parse token: %v", err)
}

// Extract all requested principals from the token. By default, we take it from the subject unless they specify
Expand All @@ -190,7 +200,7 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
} else {
claimValue, ok := parsedToken.Get(authenticator.mu.conf.claim)
if !ok {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: missing claim"),
"token does not contain a claim for %s", authenticator.mu.conf.claim)
}
Expand All @@ -217,14 +227,14 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
for _, tokenPrincipal := range tokenPrincipals {
mappedUsernames, err := authenticator.mapUsername(tokenPrincipal, parsedToken.Issuer(), identMap)
if err != nil {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid claim value"),
"the value %s for the issuer %s is invalid", tokenPrincipal, parsedToken.Issuer())
}
acceptedUsernames = append(acceptedUsernames, mappedUsernames...)
}
if len(acceptedUsernames) == 0 {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid principal"),
"the value %s for the issuer %s is invalid", tokenPrincipals, parsedToken.Issuer())
}
Expand All @@ -236,12 +246,12 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !principalMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid principal"),
"token issued for %s and login was for %s", tokenPrincipals, user.Normalized())
}
if user.IsRootUser() || user.IsReserved() {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid identity"),
"cannot use JWT auth to login to a reserved user %s", user.Normalized())
}
Expand All @@ -255,26 +265,29 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !audienceMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid audience"),
"token issued with an audience of %s", parsedToken.Audience())
}

if err = utilccl.CheckEnterpriseEnabled(st, authenticator.clusterUUID, "JWT authentication"); err != nil {
return err
return "", err
}

telemetry.Inc(loginSuccessUseCounter)
return nil
return "", nil
}

// remoteFetchJWKS fetches the JWKS from the provided URI.
func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
jwksUrl, err := getJWKSUrl(ctx, issuerUrl)
func (authenticator *jwtAuthenticator) remoteFetchJWKS(
ctx context.Context, issuerUrl string,
) (jwk.Set, error) {
jwksUrl, err := authenticator.getJWKSUrl(ctx, issuerUrl)
if err != nil {
return nil, err
}
body, err := getHttpResponse(ctx, jwksUrl)

body, err := getHttpResponse(ctx, jwksUrl, authenticator)
if err != nil {
return nil, err
}
Expand All @@ -286,12 +299,14 @@ func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
}

// getJWKSUrl returns the JWKS URI from the OpenID configuration endpoint.
func getJWKSUrl(ctx context.Context, issuerUrl string) (string, error) {
func (authenticator *jwtAuthenticator) getJWKSUrl(
ctx context.Context, issuerUrl string,
) (string, error) {
type OIDCConfigResponse struct {
JWKSUri string `json:"jwks_uri"`
}
openIdConfigEndpoint := getOpenIdConfigEndpoint(issuerUrl)
body, err := getHttpResponse(ctx, openIdConfigEndpoint)
body, err := getHttpResponse(ctx, openIdConfigEndpoint, authenticator)
if err != nil {
return "", err
}
Expand All @@ -311,8 +326,12 @@ func getOpenIdConfigEndpoint(issuerUrl string) string {
return openIdConfigEndpoint
}

var getHttpResponse = func(ctx context.Context, url string) ([]byte, error) {
resp, err := httputil.Get(ctx, url)
var getHttpResponse = func(ctx context.Context, url string, authenticator *jwtAuthenticator) ([]byte, error) {
// TODO(souravcrl): cache the http client in a callback attached to customCA
// and other http client cluster settings as re parsing the custom CA every
// time is expensive
httpClient := httputil.NewClientWithTimeout(httputil.StandardHTTPTimeout)
resp, err := httpClient.Get(context.Background(), url)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 59c1f69

Please sign in to comment.