Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ccl,sql,util: Fix jwt auth and add sensitive error logs #123697

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions pkg/ccl/jwtauthccl/authentication_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,23 @@ func (authenticator *jwtAuthenticator) mapUsername(
// * the audience field matches the audience cluster setting.
// * the issuer field is one of the values in the issuer cluster setting.
// * the cluster has an enterprise license.
// It returns authError (which is the error sql clients will see in case of
// failures) and detailedError (which is the internal error from http clients
// that might contain sensitive information we do not want to send to sql
// clients but still want to log it). We do not want to send any information
// back to client which was not provided by the client.
func (authenticator *jwtAuthenticator) ValidateJWTLogin(
ctx context.Context,
st *cluster.Settings,
user username.SQLUsername,
tokenBytes []byte,
identMap *identmap.Conf,
) error {
) (detailedErrorMsg string, authError error) {
authenticator.mu.Lock()
defer authenticator.mu.Unlock()

if !authenticator.mu.enabled {
return errors.Newf("JWT authentication: not enabled")
return "", errors.Newf("JWT authentication: not enabled")
}

telemetry.Inc(beginAuthUseCounter)
Expand All @@ -146,7 +151,9 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
// The token will be parsed again later to actually verify the signature.
unverifiedToken, err := jwt.Parse(tokenBytes)
if err != nil {
return errors.Newf("JWT authentication: invalid token")
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid token"),
"token parsing failed: %v", err)
}

// Check for issuer match against configured issuers.
Expand All @@ -160,17 +167,18 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !issuerMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid issuer"),
"token issued by %s", unverifiedToken.Issuer())
}

var jwkSet jwk.Set
// If auto-fetch is enabled, fetch the JWKS remotely from the issuer's well known jwks url.
if authenticator.mu.conf.jwksAutoFetchEnabled {
jwkSet, err = remoteFetchJWKS(ctx, issuerUrl)
jwkSet, err = authenticator.remoteFetchJWKS(ctx, issuerUrl)
if err != nil {
return errors.Newf("JWT authentication: unable to validate token")
return fmt.Sprintf("unable to fetch jwks: %v", err),
errors.Newf("JWT authentication: unable to validate token")
}
} else {
jwkSet = authenticator.mu.conf.jwks
Expand All @@ -179,7 +187,9 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
// Now that both the issuer and key-id are matched, parse the token again to validate the signature.
parsedToken, err := jwt.Parse(tokenBytes, jwt.WithKeySet(jwkSet), jwt.WithValidate(true), jwt.InferAlgorithmFromKey(true))
if err != nil {
return errors.Newf("JWT authentication: invalid token")
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid token"),
"unable to parse token: %v", err)
}

// Extract all requested principals from the token. By default, we take it from the subject unless they specify
Expand All @@ -190,7 +200,7 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
} else {
claimValue, ok := parsedToken.Get(authenticator.mu.conf.claim)
if !ok {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: missing claim"),
"token does not contain a claim for %s", authenticator.mu.conf.claim)
}
Expand All @@ -217,14 +227,14 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
for _, tokenPrincipal := range tokenPrincipals {
mappedUsernames, err := authenticator.mapUsername(tokenPrincipal, parsedToken.Issuer(), identMap)
if err != nil {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid claim value"),
"the value %s for the issuer %s is invalid", tokenPrincipal, parsedToken.Issuer())
}
acceptedUsernames = append(acceptedUsernames, mappedUsernames...)
}
if len(acceptedUsernames) == 0 {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid principal"),
"the value %s for the issuer %s is invalid", tokenPrincipals, parsedToken.Issuer())
}
Expand All @@ -236,12 +246,12 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !principalMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid principal"),
"token issued for %s and login was for %s", tokenPrincipals, user.Normalized())
}
if user.IsRootUser() || user.IsReserved() {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid identity"),
"cannot use JWT auth to login to a reserved user %s", user.Normalized())
}
Expand All @@ -255,26 +265,29 @@ func (authenticator *jwtAuthenticator) ValidateJWTLogin(
}
}
if !audienceMatch {
return errors.WithDetailf(
return "", errors.WithDetailf(
errors.Newf("JWT authentication: invalid audience"),
"token issued with an audience of %s", parsedToken.Audience())
}

if err = utilccl.CheckEnterpriseEnabled(st, "JWT authentication"); err != nil {
return err
return "", err
}

telemetry.Inc(loginSuccessUseCounter)
return nil
return "", nil
}

// remoteFetchJWKS fetches the JWKS from the provided URI.
func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
jwksUrl, err := getJWKSUrl(ctx, issuerUrl)
func (authenticator *jwtAuthenticator) remoteFetchJWKS(
ctx context.Context, issuerUrl string,
) (jwk.Set, error) {
jwksUrl, err := authenticator.getJWKSUrl(ctx, issuerUrl)
if err != nil {
return nil, err
}
body, err := getHttpResponse(ctx, jwksUrl)

body, err := getHttpResponse(ctx, jwksUrl, authenticator)
if err != nil {
return nil, err
}
Expand All @@ -286,12 +299,14 @@ func remoteFetchJWKS(ctx context.Context, issuerUrl string) (jwk.Set, error) {
}

// getJWKSUrl returns the JWKS URI from the OpenID configuration endpoint.
func getJWKSUrl(ctx context.Context, issuerUrl string) (string, error) {
func (authenticator *jwtAuthenticator) getJWKSUrl(
ctx context.Context, issuerUrl string,
) (string, error) {
type OIDCConfigResponse struct {
JWKSUri string `json:"jwks_uri"`
}
openIdConfigEndpoint := getOpenIdConfigEndpoint(issuerUrl)
body, err := getHttpResponse(ctx, openIdConfigEndpoint)
body, err := getHttpResponse(ctx, openIdConfigEndpoint, authenticator)
if err != nil {
return "", err
}
Expand All @@ -311,8 +326,12 @@ func getOpenIdConfigEndpoint(issuerUrl string) string {
return openIdConfigEndpoint
}

var getHttpResponse = func(ctx context.Context, url string) ([]byte, error) {
resp, err := httputil.Get(ctx, url)
var getHttpResponse = func(ctx context.Context, url string, authenticator *jwtAuthenticator) ([]byte, error) {
// TODO(souravcrl): cache the http client in a callback attached to customCA
// and other http client cluster settings as re parsing the custom CA every
// time is expensive
httpClient := httputil.NewClientWithTimeout(httputil.StandardHTTPTimeout)
resp, err := httpClient.Get(context.Background(), url)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading