Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple audience for jwt authentication #4359

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -61,7 +62,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator
private final String jwtUrlParameter;
private final String subjectKey;
private final String rolesKey;
private final String requiredAudience;
private final List<String> requiredAudience;
private final String requiredIssuer;

public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30;
Expand All @@ -74,7 +75,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) {
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS);
requiredAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requiredIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand Down Expand Up @@ -255,7 +256,7 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
);
}

public String getRequiredAudience() {
public List<String> getRequiredAudience() {
return requiredAudience;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -37,6 +38,7 @@
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.KeyUtils;

import com.nimbusds.jwt.proc.BadJWTException;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
Expand All @@ -58,7 +60,7 @@
private final String jwtUrlParameter;
private final String rolesKey;
private final String subjectKey;
private final String requireAudience;
private final List<String> requiredAudience;
private final String requireIssuer;

public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
Expand All @@ -70,7 +72,7 @@
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
requireAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requireIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand All @@ -84,10 +86,6 @@
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireAudience != null) {
jwtParserBuilder.requireAudience(requireAudience);
cwperks marked this conversation as resolved.
Show resolved Hide resolved
}

if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}
Expand Down Expand Up @@ -161,6 +159,10 @@
try {
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();

if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}

final String subject = extractSubject(claims, request);

if (subject == null) {
Expand Down Expand Up @@ -189,6 +191,16 @@
}
}

private void assertValidAudienceClaim(Claims claims) throws BadJWTException {
if (requiredAudience.isEmpty()) {
return;

Check warning on line 196 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L196

Added line #L196 was not covered by tests
}

if (Collections.disjoint(claims.getAudience(), requiredAudience)) {
throw new BadJWTException("Claim of 'aud' doesn't contain any required audience.");
}
}

@Override
public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) {
return Optional.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import java.text.ParseException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

import com.google.common.base.Strings;
import org.apache.commons.lang3.StringEscapeUtils;
Expand All @@ -38,9 +40,9 @@ public class JwtVerifier {
private final KeyProvider keyProvider;
private final int clockSkewToleranceSeconds;
private final String requiredIssuer;
private final String requiredAudience;
private final List<String> requiredAudience;

public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) {
public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, List<String> requiredAudience) {
this.keyProvider = keyProvider;
this.clockSkewToleranceSeconds = clockSkewToleranceSeconds;
this.requiredIssuer = requiredIssuer;
Expand Down Expand Up @@ -116,9 +118,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio

if (claims != null) {
DefaultJWTClaimsVerifier<SimpleSecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
requiredAudience,
requiredAudience.isEmpty() ? null : new HashSet<>(requiredAudience),
null,
Collections.emptySet()
Collections.emptySet(),
null
);
claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds);
claimsVerifier.verify(claims, null);
Expand All @@ -127,10 +130,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio
}

private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
String audience = claims.getAudience().stream().findFirst().orElse("");
List<String> audience = claims.getAudience();
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
if (!requiredAudience.isEmpty() && Collections.disjoint(requiredAudience, audience)) {
throw new BadJWTException("Invalid audience");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,33 @@ public void testRequiredAudienceWithIncorrectAudience() {
Assert.assertNull(credentials);
}

@Test
public void testRequiredAudienceWithCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder()
.put("signing_key", BaseEncoding.base64().encode(secretKeyBytes))
.put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("test_audience_2")
);

Assert.assertNotNull(credentials);
Assert.assertEquals("Leonard McCoy", credentials.getUsername());
}

@Test
public void testRequiredAudienceWithInCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder()
.put("signing_key", BaseEncoding.base64().encode(secretKeyBytes))
.put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("wrong_audience")
);

Assert.assertNull(credentials);
}

@Test
public void testRequiredIssuerWithCorrectAudience() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() {
Assert.assertNull(creds);
}

@Test
public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
.put("required_issuer", TestJwts.TEST_ISSUER)
.put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience")
.build();

HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);

AuthCredentials creds = jwtAuth.extractCredentials(
new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(),
null
);

Assert.assertNotNull(creds);
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud"));
Assert.assertEquals(0, creds.getBackendRoles().size());
Assert.assertEquals(4, creds.getAttributes().size());
}

@Test
public void jwksMissingRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
Expand Down
Loading