Skip to content

Commit

Permalink
Support multiple audience for jwt authentication
Browse files Browse the repository at this point in the history
Signed-off-by: leedonggyu <[email protected]>
  • Loading branch information
donggyu04 committed May 22, 2024
1 parent 382bc5f commit f08ceb3
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -61,7 +62,7 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator
private final String jwtUrlParameter;
private final String subjectKey;
private final String rolesKey;
private final String requiredAudience;
private final List<String> requiredAudience;
private final String requiredIssuer;

public static final int DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS = 30;
Expand All @@ -74,7 +75,7 @@ public AbstractHTTPJwtAuthenticator(Settings settings, Path configPath) {
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
clockSkewToleranceSeconds = settings.getAsInt("jwt_clock_skew_tolerance_seconds", DEFAULT_CLOCK_SKEW_TOLERANCE_SECONDS);
requiredAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requiredIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand Down Expand Up @@ -255,7 +256,7 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
);
}

public String getRequiredAudience() {
public List<String> getRequiredAudience() {
return requiredAudience;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,14 @@

package com.amazon.dlic.auth.http.jwt;

import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;

import com.nimbusds.jwt.proc.BadJWTException;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.security.WeakKeyException;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.OpenSearchSecurityException;
import org.opensearch.SpecialPermission;
import org.opensearch.common.logging.DeprecationLogger;
Expand All @@ -37,10 +30,17 @@
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.KeyUtils;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.security.WeakKeyException;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;

import static org.apache.http.HttpHeaders.AUTHORIZATION;

Expand All @@ -58,7 +58,7 @@ public class HTTPJwtAuthenticator implements HTTPAuthenticator {
private final String jwtUrlParameter;
private final String rolesKey;
private final String subjectKey;
private final String requireAudience;
private final List<String> requiredAudience;
private final String requireIssuer;

public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
Expand All @@ -70,7 +70,7 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
isDefaultAuthHeader = AUTHORIZATION.equalsIgnoreCase(jwtHeaderName);
rolesKey = settings.get("roles_key");
subjectKey = settings.get("subject_key");
requireAudience = settings.get("required_audience");
requiredAudience = settings.getAsList("required_audience");
requireIssuer = settings.get("required_issuer");

if (!jwtHeaderName.equals(AUTHORIZATION)) {
Expand All @@ -84,10 +84,6 @@ public HTTPJwtAuthenticator(final Settings settings, final Path configPath) {
if (jwtParserBuilder == null) {
jwtParser = null;
} else {
if (requireAudience != null) {
jwtParserBuilder.requireAudience(requireAudience);
}

if (requireIssuer != null) {
jwtParserBuilder.requireIssuer(requireIssuer);
}
Expand Down Expand Up @@ -161,6 +157,10 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) {
try {
final Claims claims = jwtParser.parseClaimsJws(jwtToken).getBody();

if (!requiredAudience.isEmpty()) {
assertValidAudienceClaim(claims);
}

final String subject = extractSubject(claims, request);

if (subject == null) {
Expand Down Expand Up @@ -189,6 +189,16 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) {
}
}

private void assertValidAudienceClaim(Claims claims) throws BadJWTException {
if (requiredAudience.isEmpty()) {
return;

Check warning on line 194 in src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/HTTPJwtAuthenticator.java#L194

Added line #L194 was not covered by tests
}

if (Collections.disjoint(claims.getAudience(), requiredAudience)) {
throw new BadJWTException("Claim of 'aud' doesn't contain any required audience.");
}
}

@Override
public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest channel, AuthCredentials creds) {
return Optional.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,7 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import java.text.ParseException;
import java.util.Collections;

import com.google.common.base.Strings;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
Expand All @@ -30,6 +23,14 @@
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.text.ParseException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

public class JwtVerifier {

Expand All @@ -38,9 +39,9 @@ public class JwtVerifier {
private final KeyProvider keyProvider;
private final int clockSkewToleranceSeconds;
private final String requiredIssuer;
private final String requiredAudience;
private final List<String> requiredAudience;

public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, String requiredAudience) {
public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, String requiredIssuer, List<String> requiredAudience) {
this.keyProvider = keyProvider;
this.clockSkewToleranceSeconds = clockSkewToleranceSeconds;
this.requiredIssuer = requiredIssuer;
Expand Down Expand Up @@ -116,9 +117,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio

if (claims != null) {
DefaultJWTClaimsVerifier<SimpleSecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
requiredAudience,
null,
Collections.emptySet()
requiredAudience.isEmpty() ? null : new HashSet<>(requiredAudience),
null,
Collections.emptySet(),
null
);
claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds);
claimsVerifier.verify(claims, null);
Expand All @@ -127,10 +129,10 @@ private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTExceptio
}

private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
String audience = claims.getAudience().stream().findFirst().orElse("");
List<String> audience = claims.getAudience();
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
if (!requiredAudience.isEmpty() && Collections.disjoint(requiredAudience, audience)) {
throw new BadJWTException("Invalid audience");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@

package com.amazon.dlic.auth.http.jwt;

import com.google.common.io.BaseEncoding;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;
import org.apache.hc.core5.http.HttpHeaders;
import org.junit.Assert;
import org.junit.Test;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.common.settings.Settings;
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.FakeRestRequest;

import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
Expand All @@ -21,26 +35,12 @@
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import javax.crypto.SecretKey;

import com.google.common.io.BaseEncoding;
import org.apache.hc.core5.http.HttpHeaders;
import org.junit.Assert;
import org.junit.Test;

import org.opensearch.OpenSearchSecurityException;
import org.opensearch.common.settings.Settings;
import org.opensearch.security.user.AuthCredentials;
import org.opensearch.security.util.FakeRestRequest;

import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;

import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

;

public class HTTPJwtAuthenticatorTest {

final static byte[] secretKeyBytes = new byte[1024];
Expand Down Expand Up @@ -482,6 +482,29 @@ public void testRequiredAudienceWithIncorrectAudience() {
Assert.assertNull(credentials);
}

@Test
public void testRequiredAudienceWithCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder().put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)).put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("test_audience_2")
);

Assert.assertNotNull(credentials);
Assert.assertEquals("Leonard McCoy", credentials.getUsername());
}

@Test
public void testRequiredAudienceWithInCorrectAtLeastOneAudience() {

final AuthCredentials credentials = extractCredentialsFromJwtHeader(
Settings.builder().put("signing_key", BaseEncoding.base64().encode(secretKeyBytes)).put("required_audience", "test_audience,test_audience_2"),
Jwts.builder().setSubject("Leonard McCoy").setAudience("wrong_audience")
);

Assert.assertNull(credentials);
}

@Test
public void testRequiredIssuerWithCorrectAudience() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ public void jwksNotMatchingRequiredIssuerInClaimTest() {
Assert.assertNull(creds);
}

@Test
public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
.put("openid_connect_url", mockIdpServer.getDiscoverUri())
.put("required_issuer", TestJwts.TEST_ISSUER)
.put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience")
.build();

HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null);

AuthCredentials creds = jwtAuth.extractCredentials(
new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(),
null
);

Assert.assertNotNull(creds);
Assert.assertEquals(TestJwts.MCCOY_SUBJECT, creds.getUsername());
Assert.assertEquals(List.of(TestJwts.TEST_AUDIENCE).toString(), creds.getAttributes().get("attr.jwt.aud"));
Assert.assertEquals(0, creds.getBackendRoles().size());
Assert.assertEquals(4, creds.getAttributes().size());
}

@Test
public void jwksMissingRequiredAudienceInClaimTest() {
Settings settings = Settings.builder()
Expand Down

0 comments on commit f08ceb3

Please sign in to comment.