diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java index ea0a6378d7..2100f68a97 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/AbstractHTTPJwtAuthenticator.java @@ -16,6 +16,7 @@ import java.security.PrivilegedAction; import java.text.ParseException; import java.util.Collection; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -61,7 +62,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator private final String jwtUrlParameter; private final String subjectKey; private final String rolesKey; - private final String requiredAudience; + private final List requiredAudience; private final String requiredIssuer; public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30; @@ -74,7 +75,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) { rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS); - requiredAudience = settings.get("required_audience"); + requiredAudience = settings.getAsList("required_audience"); requiredIssuer = settings.get("required_issuer"); if (!jwtHeaderName.equals(AUTHORIZATION)) { @@ -255,7 +256,7 @@ public Optional reRequestAuthentication(final SecurityRequest ); } - public String getRequiredAudience() { + public List getRequiredAudience() { return requiredAudience; } diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java index a6ff27eb6b..907cb4cc5d 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java @@ -16,12 +16,15 @@ import java.security.PrivilegedAction; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import java.util.regex.Pattern; +import com.nimbusds.jwt.proc.BadJWTException; +import io.jsonwebtoken.IncorrectClaimException; import org.apache.http.HttpStatus; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -58,7 +61,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator { private final String jwtUrlParameter; private final String rolesKey; private final String subjectKey; - private final String requireAudience; + private final List requiredAudience; private final String requireIssuer; public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { @@ -70,7 +73,7 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName); rolesKey = settings.get("roles_key"); subjectKey = settings.get("subject_key"); - requireAudience = settings.get("required_audience"); + requiredAudience = settings.getAsList("required_audience"); requireIssuer = settings.get("required_issuer"); if (!jwtHeaderName.equals(AUTHORIZATION)) { @@ -84,10 +87,6 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) { if (jwtParserBuilder == null) { jwtParser = null; } else { - if (requireAudience != null) { - jwtParserBuilder.requireAudience(requireAudience); - } - if (requireIssuer != null) { jwtParserBuilder.requireIssuer(requireIssuer); } @@ -161,6 +160,10 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) { try { final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody(); + if (!requiredAudience.isEmpty()) { + assertValidAudienceClaim(claims); + } + final String subject = extractSubject(claims, request); if (subject == null) { @@ -189,6 +192,16 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) { } } + private void assertValidAudienceClaim(Claims claims) throws BadJWTException { + if (requiredAudience.isEmpty()) { + return; + } + + if (Collections.disjoint(claims.getAudience(), requiredAudience)) { + throw new BadJWTException("Claim of 'aud' doesn't contain any required audience."); + } + } + @Override public Optional reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) { return Optional.of( diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java index da1b8393fb..ca0f4284b1 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java @@ -13,6 +13,10 @@ import java.text.ParseException; import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; import com.google.common.base.Strings; import org.apache.commons.lang3.StringEscapeUtils; @@ -38,9 +42,9 @@ public class JwtVerifier { private final KeyProvider keyProvider; private final int clockSkewToleranceSeconds; private final String requiredIssuer; - private final String requiredAudience; + private final List requiredAudience; - public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) { + public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, List requiredAudience) { this.keyProvider = keyProvider; this.clockSkewToleranceSeconds = clockSkewToleranceSeconds; this.requiredIssuer = requiredIssuer; @@ -116,9 +120,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio if (claims != null) { DefaultJWTClaimsVerifier claimsVerifier = new DefaultJWTClaimsVerifier<>( - requiredAudience, - null, - Collections.emptySet() + requiredAudience.isEmpty() ? null : new HashSet<>(requiredAudience), + null, + Collections.emptySet(), + null ); claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds); claimsVerifier.verify(claims, null); @@ -127,10 +132,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio } private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException { - String audience = claims.getAudience().stream().findFirst().orElse(""); + List audience = claims.getAudience(); String issuer = claims.getIssuer(); - if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) { + if (!requiredAudience.isEmpty() && Collections.disjoint(requiredAudience, audience)) { throw new BadJWTException("Invalid audience"); } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java index 3c9f2c158a..2c04e04005 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticatorTest.java @@ -20,7 +20,10 @@ import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import javax.crypto.SecretKey; import com.google.common.io.BaseEncoding; @@ -482,6 +485,29 @@ public void testRequiredAudienceWithIncorrectAudience() { Assert.assertNull(credentials); } + @Test + public void testRequiredAudienceWithCorrectAtLeastOneAudience() { + + final AuthCredentials credentials = extractCredentialsFromJwtHeader( + Settings.builder().put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)).put("required_audience", "test_audience,test_audience_2"), + Jwts.builder().setSubject("Leonard McCoy").setAudience("test_audience_2") + ); + + Assert.assertNotNull(credentials); + Assert.assertEquals("Leonard McCoy", credentials.getUsername()); + } + + @Test + public void testRequiredAudienceWithInCorrectAtLeastOneAudience() { + + final AuthCredentials credentials = extractCredentialsFromJwtHeader( + Settings.builder().put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)).put("required_audience", "test_audience,test_audience_2"), + Jwts.builder().setSubject("Leonard McCoy").setAudience("wrong_audience") + ); + + Assert.assertNull(credentials); + } + @Test public void testRequiredIssuerWithCorrectAudience() { diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index a31e30db39..5ec81d64a3 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -118,6 +118,28 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() { Assert.assertNull(creds); } + @Test + public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience") + .build(); + + HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + + AuthCredentials creds = jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername()); + Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud")); + Assert.assertEquals(0, creds.getBackendRoles().size()); + Assert.assertEquals(4, creds.getAttributes().size()); + } + @Test public void jwksMissingRequiredAudienceInClaimTest() { Settings settings = Settings.builder()