From 348149cbacc2f28596b4caf98755ee8c23100dc0 Mon Sep 17 00:00:00 2001 From: luneo7 Date: Fri, 15 Nov 2024 12:48:00 -0600 Subject: [PATCH] Add jwksRetainOnErrorDuration --- .../AbstractKeyLocationResolver.java | 1 + .../auth/principal/JWTAuthContextInfo.java | 11 +++ .../config/JWTAuthContextInfoProvider.java | 11 +++ .../principal/KeyLocationResolverTest.java | 68 ++++++++++++++++++- 4 files changed, 90 insertions(+), 1 deletion(-) diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java index f039e04e..59f802c3 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/AbstractKeyLocationResolver.java @@ -115,6 +115,7 @@ protected HttpsJwks initializeHttpsJwks(String location) new InetSocketAddress(authContextInfo.getHttpProxyHost(), authContextInfo.getHttpProxyPort()))); } theHttpsJwks.setSimpleHttpGet(httpGet); + theHttpsJwks.setRetainCacheOnErrorDuration(authContextInfo.getJwksRetainOnErrorDuration()); return theHttpsJwks; } diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java index 9eb36be8..e4eae651 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/auth/principal/JWTAuthContextInfo.java @@ -51,6 +51,7 @@ public class JWTAuthContextInfo { private String decryptionKeyContent; private Integer jwksRefreshInterval = 60; private int forcedJwksRefreshInterval = 30; + private long jwksRetainOnErrorDuration = 0; private String tokenHeader = "Authorization"; private String tokenCookie; private boolean alwaysCheckAuthorization; @@ -121,6 +122,7 @@ public JWTAuthContextInfo(JWTAuthContextInfo orig) { this.decryptionKeyContent = orig.getDecryptionKeyContent(); this.jwksRefreshInterval = orig.getJwksRefreshInterval(); this.forcedJwksRefreshInterval = orig.getForcedJwksRefreshInterval(); + this.jwksRetainOnErrorDuration = orig.getJwksRetainOnErrorDuration(); this.tokenHeader = orig.getTokenHeader(); this.tokenCookie = orig.getTokenCookie(); this.alwaysCheckAuthorization = orig.isAlwaysCheckAuthorization(); @@ -283,6 +285,14 @@ public void setForcedJwksRefreshInterval(int forcedJwksRefreshInterval) { this.forcedJwksRefreshInterval = forcedJwksRefreshInterval; } + public long getJwksRetainOnErrorDuration() { + return jwksRetainOnErrorDuration; + } + + public void setJwksRetainOnErrorDuration(long jwksRetainOnErrorDuration) { + this.jwksRetainOnErrorDuration = jwksRetainOnErrorDuration; + } + public String getTokenHeader() { return tokenHeader; } @@ -436,6 +446,7 @@ public String toString() { ", decryptionKeyLocation='" + decryptionKeyLocation + '\'' + ", decryptionKeyContent='" + decryptionKeyContent + '\'' + ", jwksRefreshInterval=" + jwksRefreshInterval + + ", jwksRetainOnErrorDuration=" + jwksRetainOnErrorDuration + ", tokenHeader='" + tokenHeader + '\'' + ", tokenCookie='" + tokenCookie + '\'' + ", alwaysCheckAuthorization=" + alwaysCheckAuthorization + diff --git a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java index dd2482c4..06490d8b 100644 --- a/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java +++ b/implementation/jwt-auth/src/main/java/io/smallrye/jwt/config/JWTAuthContextInfoProvider.java @@ -186,6 +186,7 @@ private static JWTAuthContextInfoProvider create(String key, provider.mpJwtVerifyTokenAge = Optional.empty(); provider.jwksRefreshInterval = 60; provider.forcedJwksRefreshInterval = 30; + provider.jwksRetainOnErrorDuration = 0; provider.signatureAlgorithm = Optional.of(SignatureAlgorithm.RS256); provider.keyEncryptionAlgorithm = Optional.empty(); provider.mpJwtDecryptKeyAlgorithm = new HashSet<>(Arrays.asList(KeyEncryptionAlgorithm.RSA_OAEP, @@ -465,6 +466,15 @@ private static JWTAuthContextInfoProvider create(String key, @ConfigProperty(name = "smallrye.jwt.jwks.forced-refresh-interval", defaultValue = "30") private int forcedJwksRefreshInterval; + /** + * JWK cache retain on error duration in seconds which sets the length of time, before trying again, to keep using the cache + * when an error occurs making the request to the JWKS URI or parsing the response. + * It will be ignored unless the 'mp.jwt.verify.publickey.location' property points to the HTTP or HTTPS URL based JWK set. + */ + @Inject + @ConfigProperty(name = "smallrye.jwt.jwks.retain-on-error-duration", defaultValue = "0") + private long jwksRetainOnErrorDuration; + /** * Supported JSON Web Algorithm asymmetric or symmetric signature algorithm. * @@ -836,6 +846,7 @@ Optional getOptionalContextInfo() { contextInfo.setTokenAge(mpJwtVerifyTokenAge.orElse(null)); contextInfo.setJwksRefreshInterval(jwksRefreshInterval); contextInfo.setForcedJwksRefreshInterval(forcedJwksRefreshInterval); + contextInfo.setJwksRetainOnErrorDuration(jwksRetainOnErrorDuration); Set resolvedAlgorithm = mpJwtPublicKeyAlgorithm; if (signatureAlgorithm.isPresent()) { if (signatureAlgorithm.get().getAlgorithm().startsWith("HS")) { diff --git a/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java b/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java index 6ec74134..2d7dffd1 100644 --- a/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java +++ b/implementation/jwt-auth/src/test/java/io/smallrye/jwt/auth/principal/KeyLocationResolverTest.java @@ -21,7 +21,11 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.net.Proxy; @@ -30,12 +34,18 @@ import java.security.interfaces.RSAPublicKey; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.TimeUnit; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; +import org.jose4j.base64url.Base64Url; import org.jose4j.http.Get; +import org.jose4j.http.SimpleResponse; +import org.jose4j.json.internal.json_simple.JSONObject; import org.jose4j.jwk.HttpsJwks; import org.jose4j.jwk.JsonWebKey; import org.jose4j.jwk.OctetSequenceJsonWebKey; @@ -68,6 +78,8 @@ class KeyLocationResolverTest { Get mockedGet; @Mock UrlStreamResolver urlResolver; + @Mock + SimpleResponse simpleResponse; RSAPublicKey rsaKey; SecretKey secretKey; @@ -180,6 +192,46 @@ protected Get getHttpGet() { assertNull(keyLocationResolver.key); } + @Test + void keepsRsaKeyFromHttpsJwksWhenErrorDuringRefresh() throws Exception { + long cacheDuration = 1L; + long jwksRetainOnErrorDuration = 10; + JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("https://github.com/my_key.jwks", "issuer"); + contextInfo.setJwksRetainOnErrorDuration(jwksRetainOnErrorDuration); + + HttpsJwks spiedHttpsJwks = Mockito.spy(new HttpsJwks(contextInfo.getPublicKeyLocation())); + spiedHttpsJwks.setDefaultCacheDuration(cacheDuration); + when(simpleResponse.getBody()).thenReturn(generateJWK(rsaKey)); + when(mockedGet.get(contextInfo.getPublicKeyLocation())).thenReturn(simpleResponse); + + KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) { + protected HttpsJwks getHttpsJwks(String loc) { + return spiedHttpsJwks; + } + + protected Get getHttpGet() { + return mockedGet; + } + }; + + Mockito.verify(spiedHttpsJwks).setRetainCacheOnErrorDuration(jwksRetainOnErrorDuration); + Mockito.verify(spiedHttpsJwks).setSimpleHttpGet(mockedGet); + + when(signature.getHeaders()).thenReturn(headers); + when(headers.getStringHeaderValue(JsonWebKey.KEY_ID_PARAMETER)).thenReturn("1"); + when(headers.getStringHeaderValue(JsonWebKey.ALGORITHM_PARAMETER)).thenReturn("RS256"); + + assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList())); + + doThrow(RuntimeException.class).when(mockedGet).get(contextInfo.getPublicKeyLocation()); + + TimeUnit.SECONDS.sleep(cacheDuration); + + assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList())); + + verify(mockedGet, atLeastOnce()).get(contextInfo.getPublicKeyLocation()); + } + @Test void loadRsaKeyFromHttpJwks() throws Exception { JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("http://github.com/my_key.jwks", "issuer"); @@ -330,7 +382,7 @@ void loadHttpsPemCrt() throws Exception { contextInfo.setJwksRefreshInterval(10); Mockito.doThrow(new JoseException("")).when(mockedHttpsJwks).refresh(); - Mockito.doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem")) + doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem")) .when(urlResolver).resolve(Mockito.any()); KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) { protected HttpsJwks initializeHttpsJwks(String loc) { @@ -380,4 +432,18 @@ void loadJWKOnClassPath() throws Exception { assertEquals(keyLocationResolver.key, keyLocationResolver.getJsonWebKey("key1", null).getKey()); } + + private String generateJWK(RSAPublicKey publicKey) { + Map key = new HashMap<>(); + + key.put("alg", "RS256"); + key.put("use", "sig"); + key.put("kty", publicKey.getAlgorithm()); + key.put("kid", "1"); + key.put("n", Base64Url.encode(publicKey.getModulus().toByteArray())); + key.put("e", Base64Url.encode(publicKey.getPublicExponent().toByteArray())); + + return JSONObject.toJSONString(Collections.singletonMap("keys", + Collections.singletonList(key))); + } }