From a912c7513bf905287ed9a817ccf4f25e471ec4da Mon Sep 17 00:00:00 2001 From: leedonggyu Date: Wed, 22 May 2024 04:28:50 +0900 Subject: [PATCH] Support multiple audience for jwt authentication Signed-off-by: leedonggyu --- .../jwt/AbstractHTTPJwtAuthenticator.java | 7 ++--- .../auth/http/jwt/HTTPJwtAuthenticator.java | 24 ++++++++++++----- .../auth/http/jwt/keybyoidc/JwtVerifier.java | 15 ++++++----- .../http/jwt/HTTPJwtAuthenticatorTest.java | 27 +++++++++++++++++++ ...wtKeyByOpenIdConnectAuthenticatorTest.java | 22 +++++++++++++++ 5 files changed, 80 insertions(+), 15 deletions(-) diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index ea0a6378d7..2100f68a97 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -16,6 +16,7 @@ import java.security.PrivilegedAction; import java.text.ParseException; import java.util.Collection; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -61,7 +62,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator private final String jwtUrlParameter; private final String subjectKey; private final String rolesKey; - private final String requiredAudience; + private final List requiredAudience; private final String requiredIssuer; public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30; @@ -74,7 +75,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS); - requiredAudience = settings.get("required_audience"); + requiredAudience = settings.getAsList("required_audience"); requiredIssuer = settings.get("required_issuer"); if (!jwtHeaderName.equals(AUTHORIZATION)) { @@ -255,7 +256,7 @@ public Optional reRequestAuthentication(final SecurityRequest ); } - public String getRequiredAudience() { + public List getRequiredAudience() { return requiredAudience; } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index a6ff27eb6b..4a863c7cb1 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -16,6 +16,7 @@ import java.security.PrivilegedAction; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -37,6 +38,7 @@ import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.util.KeyUtils; +import com.nimbusds.jwt.proc.BadJWTException; import io.jsonwebtoken.Claims; import io.jsonwebtoken.JwtParser; import io.jsonwebtoken.JwtParserBuilder; @@ -58,7 +60,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator { private final String jwtUrlParameter; private final String rolesKey; private final String subjectKey; - private final String requireAudience; + private final List requiredAudience; private final String requireIssuer; public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { @@ -70,7 +72,7 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); - requireAudience = settings.get("required_audience"); + requiredAudience = settings.getAsList("required_audience"); requireIssuer = settings.get("required_issuer"); if (!jwtHeaderName.equals(AUTHORIZATION)) { @@ -84,10 +86,6 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { if (jwtParserBuilder == null) { jwtParser = null; } else { - if (requireAudience != null) { - jwtParserBuilder.requireAudience(requireAudience); - } - if (requireIssuer != null) { jwtParserBuilder.requireIssuer(requireIssuer); } @@ -161,6 +159,10 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) { try { final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody(); + if (!requiredAudience.isEmpty()) { + assertValidAudienceClaim(claims); + } + final String subject = extractSubject(claims, request); if (subject == null) { @@ -189,6 +191,16 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) { } } + private void assertValidAudienceClaim(Claims claims) throws BadJWTException { + if (requiredAudience.isEmpty()) { + return; + } + + if (Collections.disjoint(claims.getAudience(), requiredAudience)) { + throw new BadJWTException("Claim of 'aud' doesn't contain any required audience."); + } + } + @Override public Optional reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) { return Optional.of( diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java index da1b8393fb..809d68e531 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java @@ -13,6 +13,8 @@ import java.text.ParseException; import java.util.Collections; +import java.util.HashSet; +import java.util.List; import com.google.common.base.Strings; import org.apache.commons.lang3.StringEscapeUtils; @@ -38,9 +40,9 @@ public class JwtVerifier { private final KeyProvider keyProvider; private final int clockSkewToleranceSeconds; private final String requiredIssuer; - private final String requiredAudience; + private final List requiredAudience; - public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) { + public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, List requiredAudience) { this.keyProvider = keyProvider; this.clockSkewToleranceSeconds = clockSkewToleranceSeconds; this.requiredIssuer = requiredIssuer; @@ -116,9 +118,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio if (claims != null) { DefaultJWTClaimsVerifier claimsVerifier = new DefaultJWTClaimsVerifier<>( - requiredAudience, + requiredAudience.isEmpty() ? null : new HashSet<>(requiredAudience), null, - Collections.emptySet() + Collections.emptySet(), + null ); claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds); claimsVerifier.verify(claims, null); @@ -127,10 +130,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio } private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException { - String audience = claims.getAudience().stream().findFirst().orElse(""); + List audience = claims.getAudience(); String issuer = claims.getIssuer(); - if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) { + if (!requiredAudience.isEmpty() && Collections.disjoint(requiredAudience, audience)) { throw new BadJWTException("Invalid audience"); } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java index 3c9f2c158a..6e70034f9b 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java @@ -482,6 +482,33 @@ public void testRequiredAudienceWithIncorrectAudience() { Assert.assertNull(credentials); } + @Test + public void testRequiredAudienceWithCorrectAtLeastOneAudience() { + + final AuthCredentials credentials = extractCredentialsFromJwtHeader( + Settings.builder() + .put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)) + .put("required_audience", "test_audience,test_audience_2"), + Jwts.builder().setSubject("Leonard McCoy").setAudience("test_audience_2") + ); + + Assert.assertNotNull(credentials); + Assert.assertEquals("Leonard McCoy", credentials.getUsername()); + } + + @Test + public void testRequiredAudienceWithInCorrectAtLeastOneAudience() { + + final AuthCredentials credentials = extractCredentialsFromJwtHeader( + Settings.builder() + .put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)) + .put("required_audience", "test_audience,test_audience_2"), + Jwts.builder().setSubject("Leonard McCoy").setAudience("wrong_audience") + ); + + Assert.assertNull(credentials); + } + @Test public void testRequiredIssuerWithCorrectAudience() { diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index a31e30db39..656922bb7a 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -118,6 +118,28 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() { Assert.assertNull(creds); } + @Test + public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience") + .build(); + + HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(0, creds.getBackendRoles().size()); + Assert.assertEquals(4, creds.getAttributes().size()); + } + @Test public void jwksMissingRequiredAudienceInClaimTest() { Settings settings = Settings.builder()