diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index 2d975200..db2873cf 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -74,7 +74,7 @@ open class DefaultOAuth2TokenCallback data class RequestMappingTokenCallback( val issuerId: String, - val requestMappings: Set, + val requestMappings: List, val tokenExpiry: Long = Duration.ofHours(1).toSeconds(), ) : OAuth2TokenCallback { override fun issuerId(): String = issuerId @@ -89,31 +89,59 @@ data class RequestMappingTokenCallback( override fun tokenExpiry(): Long = tokenExpiry - private fun Set.getClaims(tokenRequest: TokenRequest): Map { + private fun List.getClaims(tokenRequest: TokenRequest): Map { val claims = firstOrNull { it.isMatch(tokenRequest) }?.claims ?: emptyMap() - return if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS && claims["sub"] == "\${clientId}") { - claims + ("sub" to tokenRequest.clientIdAsString()) - } else { - claims + val customParameters = tokenRequest.customParameters.mapValues { (_, value) -> value.first() } + val variables = + if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS) { + customParameters + ("clientId" to tokenRequest.clientIdAsString()) + } else { + customParameters + } + return claims.mapValues { (_, value) -> + when (value) { + is String -> replaceVariables(value, variables) + is List<*> -> + value.map { v -> + if (v is String) { + replaceVariables(v, variables) + } else { + v + } + } + else -> value + } } } - private inline fun Set.getClaimOrNull( + private inline fun List.getClaimOrNull( tokenRequest: TokenRequest, key: String, ): T? = getClaims(tokenRequest)[key] as? T - private fun Set.getTypeHeader(tokenRequest: TokenRequest) = firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type + private fun List.getTypeHeader(tokenRequest: TokenRequest) = firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type + + private fun replaceVariables( + input: String, + replacements: Map, + ): String { + val pattern = Regex("""\$\{(\w+)}""") + return pattern.replace(input) { result -> + val variableName = result.groupValues[1] + val replacement = replacements[variableName] + replacement ?: result.value + } + } } data class RequestMapping( private val requestParam: String, - private val match: String = "*", + private val match: String, val claims: Map = emptyMap(), val typeHeader: String = JOSEObjectType.JWT.type, ) { fun isMatch(tokenRequest: TokenRequest): Boolean = tokenRequest.toHTTPRequest().queryParameters[requestParam]?.any { - if (match != "*") it == match else true + match == "*" || match == it || match.toRegex().matchEntire(it) != null } ?: false } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt index 7e116bdb..a675dfcf 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt @@ -17,7 +17,7 @@ internal class OAuth2TokenCallbackTest { RequestMappingTokenCallback( issuerId = "issuer1", requestMappings = - setOf( + listOf( RequestMapping( requestParam = "scope", match = "scope1", @@ -39,6 +39,15 @@ internal class OAuth2TokenCallbackTest { "custom" to "custom2", ), ), + RequestMapping( + requestParam = "audience", + match = "https://myapp.com/jwt/aud/.*", + claims = + mapOf( + "sub" to "\${clientId}", + "aud" to listOf("\${audience}"), + ), + ), RequestMapping( requestParam = "grant_type", match = "authorization_code", @@ -104,6 +113,17 @@ internal class OAuth2TokenCallbackTest { issuer1.typeHeader(grantTypeShouldMatch) shouldBe "JWT" } } + + @Test + fun `token request with request params matching requestmapping should return specific claims from callback with audience`() { + val grantTypeShouldMatch = clientCredentialsRequest("audience" to "https://myapp.com/jwt/aud/xxx") + assertSoftly { + issuer1.subject(grantTypeShouldMatch) shouldBe clientId + issuer1.audience(grantTypeShouldMatch) shouldBe listOf("https://myapp.com/jwt/aud/xxx") + issuer1.tokenExpiry() shouldBe 120 + issuer1.typeHeader(grantTypeShouldMatch) shouldBe "JWT" + } + } } @Nested