diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt b/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt index c4e7c5ce..67384272 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt @@ -90,7 +90,7 @@ class MockOAuth2Server( DefaultOAuth2TokenCallback( issuerId, subject, - audience, + audience?.let { listOf(it) }, claims, expiry ) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt index cfdcd6a0..c0038b8f 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt @@ -93,7 +93,7 @@ internal class AuthorizationCodeHandler( private class LoginOAuth2TokenCallback(val login: Login, val OAuth2TokenCallback: OAuth2TokenCallback) : OAuth2TokenCallback { override fun issuerId(): String = OAuth2TokenCallback.issuerId() override fun subject(tokenRequest: TokenRequest): String = login.username - override fun audience(tokenRequest: TokenRequest): String = OAuth2TokenCallback.audience(tokenRequest) + override fun audience(tokenRequest: TokenRequest): List = OAuth2TokenCallback.audience(tokenRequest) override fun addClaims(tokenRequest: TokenRequest): Map = OAuth2TokenCallback.addClaims(tokenRequest).toMutableMap().apply { login.acr?.let { put("acr", it) } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt index 7321fe32..ec1fe98d 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt @@ -8,6 +8,7 @@ import no.nav.security.mock.oauth2.OAuth2Exception import no.nav.security.mock.oauth2.extensions.expiresIn import no.nav.security.mock.oauth2.http.OAuth2HttpRequest import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.invalidRequest import no.nav.security.mock.oauth2.token.OAuth2TokenCallback import no.nav.security.mock.oauth2.token.OAuth2TokenProvider import okhttp3.HttpUrl @@ -20,7 +21,7 @@ internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvi oAuth2TokenCallback: OAuth2TokenCallback ): OAuth2TokenResponse { val tokenRequest = request.asNimbusTokenRequest() - val receivedClaimsSet = assertion(tokenRequest) + val receivedClaimsSet = tokenRequest.assertion() val accessToken = tokenProvider.exchangeAccessToken( tokenRequest, issuerUrl, @@ -31,11 +32,17 @@ internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvi tokenType = "Bearer", accessToken = accessToken.serialize(), expiresIn = accessToken.expiresIn(), - scope = tokenRequest.scope.toString() + scope = tokenRequest.responseScope() ) } - private fun assertion(tokenRequest: TokenRequest): JWTClaimsSet = - (tokenRequest.authorizationGrant as? JWTBearerGrant)?.jwtAssertion?.jwtClaimsSet + private fun TokenRequest.responseScope(): String { + return scope?.toString() + ?: assertion().getClaim("scope")?.toString() + ?: invalidRequest("scope must be specified in request or as a claim in assertion parameter") + } + + private fun TokenRequest.assertion(): JWTClaimsSet = + (this.authorizationGrant as? JWTBearerGrant)?.jwtAssertion?.jwtClaimsSet ?: throw OAuth2Exception(OAuth2Error.INVALID_REQUEST, "missing required parameter assertion") } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/TokenExchangeGrant.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/TokenExchangeGrant.kt index 2bc9869e..cfabd772 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/TokenExchangeGrant.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/TokenExchangeGrant.kt @@ -10,7 +10,7 @@ val TOKEN_EXCHANGE = GrantType("urn:ietf:params:oauth:grant-type:token-exchange" class TokenExchangeGrant( val subjectTokenType: String, val subjectToken: String, - val audience: String + val audience: MutableList ) : AuthorizationGrant(TOKEN_EXCHANGE) { override fun toParameters(): MutableMap> = @@ -18,7 +18,7 @@ class TokenExchangeGrant( "grant_type" to mutableListOf(TOKEN_EXCHANGE.value), "subject_token_type" to mutableListOf(subjectTokenType), "subject_token" to mutableListOf(subjectToken), - "audience" to mutableListOf(audience) + "audience" to audience ) companion object { @@ -27,6 +27,8 @@ class TokenExchangeGrant( parameters.require("subject_token_type"), parameters.require("subject_token"), parameters.require("audience") + .split(" ") + .toMutableList() ) } } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index e4a72d8d..77b0fb2f 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -6,11 +6,12 @@ import com.nimbusds.openid.connect.sdk.OIDCScopeValue import java.util.UUID import no.nav.security.mock.oauth2.extensions.clientIdAsString import no.nav.security.mock.oauth2.extensions.grantType +import no.nav.security.mock.oauth2.grant.TokenExchangeGrant interface OAuth2TokenCallback { fun issuerId(): String fun subject(tokenRequest: TokenRequest): String - fun audience(tokenRequest: TokenRequest): String + fun audience(tokenRequest: TokenRequest): List fun addClaims(tokenRequest: TokenRequest): Map fun tokenExpiry(): Long } @@ -19,7 +20,8 @@ interface OAuth2TokenCallback { open class DefaultOAuth2TokenCallback( private val issuerId: String = "default", private val subject: String = UUID.randomUUID().toString(), - private val audience: String? = null, + // needs to be nullable in order to know if a list has explicitly been set, empty list should be a allowable value + private val audience: List? = null, private val claims: Map = emptyMap(), private val expiry: Long = 3600 ) : OAuth2TokenCallback { @@ -33,15 +35,14 @@ open class DefaultOAuth2TokenCallback( } } - override fun audience(tokenRequest: TokenRequest): String { + override fun audience(tokenRequest: TokenRequest): List { val oidcScopeList = OIDCScopeValue.values().map { it.toString() } return audience + ?: (tokenRequest.authorizationGrant as? TokenExchangeGrant)?.audience ?: let { tokenRequest.scope?.toStringList() - ?.filterNot { oidcScopeList.contains(it) }?.firstOrNull() - } - ?: tokenRequest.customParameters["audience"]?.first() - ?: "default" + ?.filterNot { oidcScopeList.contains(it) } + } ?: listOf("default") } override fun addClaims(tokenRequest: TokenRequest): Map = diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt index d89a8ff9..39e19e8f 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt @@ -36,7 +36,7 @@ class OAuth2TokenProvider { defaultClaims( issuerUrl, oAuth2TokenCallback.subject(tokenRequest), - tokenRequest.clientIdAsString(), + listOf(tokenRequest.clientIdAsString()), nonce, oAuth2TokenCallback.addClaims(tokenRequest), oAuth2TokenCallback.tokenExpiry() @@ -90,7 +90,7 @@ class OAuth2TokenProvider { private fun defaultClaims( issuerUrl: HttpUrl, subject: String, - audience: String, + audience: List, nonce: String?, additionalClaims: Map, expiry: Long diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt index 6d1d5d41..686cdaf8 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt @@ -261,7 +261,7 @@ class MockOAuth2ServerTest { DefaultOAuth2TokenCallback( issuerId = "custom", subject = "yolo", - audience = "myaud" + audience = listOf("myaud") ) ) @@ -322,7 +322,7 @@ class MockOAuth2ServerTest { DefaultOAuth2TokenCallback( issuerId = "default", subject = "mysub", - audience = "muyaud", + audience = listOf("muyaud"), claims = mapOf("someclaim" to "claimvalue") ) ) diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/JwtBearerGrantIntegrationTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/JwtBearerGrantIntegrationTest.kt index f2b72aac..0bc8216e 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/JwtBearerGrantIntegrationTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/JwtBearerGrantIntegrationTest.kt @@ -1,7 +1,9 @@ package no.nav.security.mock.oauth2.e2e import com.nimbusds.oauth2.sdk.GrantType +import io.kotest.matchers.collections.shouldBeEmpty import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain @@ -62,4 +64,52 @@ class JwtBearerGrantIntegrationTest { response.accessToken.claims["claim2"] shouldBe "value2" } } + + @Test + fun `token request with JwtBearerGrant should exchange assertion with a new token with scope specified in assertion claim or request parmas`() { + withMockOAuth2Server { + val initialSubject = "mysub" + val initialToken = this.issueToken( + issuerId = "idprovider", + clientId = "client1", + tokenCallback = DefaultOAuth2TokenCallback( + issuerId = "idprovider", + subject = initialSubject, + audience = emptyList(), + claims = mapOf( + "claim1" to "value1", + "claim2" to "value2", + "scope" to "ascope", + "resource" to "aud1", + ) + ) + ) + + initialToken.audience.shouldBeEmpty() + + val issuerId = "aad" + + this.enqueueCallback(DefaultOAuth2TokenCallback(issuerId = issuerId, audience = emptyList())) + + val response: ParsedTokenResponse = client.tokenRequest( + url = this.tokenEndpointUrl(issuerId), + parameters = mapOf( + "grant_type" to GrantType.JWT_BEARER.value, + "assertion" to initialToken.serialize() + ) + ).toTokenResponse() + + println("YOLO:" + response.accessToken?.serialize()) + + response shouldBeValidFor GrantType.JWT_BEARER + response.scope shouldContain "ascope" + response.issuedTokenType shouldBe null + response.accessToken.shouldNotBeNull() + response.accessToken should verifyWith(issuerId, this, listOf("sub", "iss", "iat", "exp")) + response.accessToken.subject shouldBe initialSubject + response.accessToken.audience.shouldBeEmpty() + response.accessToken.claims["claim1"] shouldBe "value1" + response.accessToken.claims["claim2"] shouldBe "value2" + } + } } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Token.kt b/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Token.kt index 4ee271dd..89028e37 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Token.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Token.kt @@ -79,10 +79,14 @@ infix fun ParsedTokenResponse.shouldBeValidFor(type: GrantType) { } } -fun verifyWith(issuerId: String, server: MockOAuth2Server) = object : Matcher { +fun verifyWith( + issuerId: String, + server: MockOAuth2Server, + requiredClaims: List = listOf("sub", "iss", "iat", "exp", "aud") +) = object : Matcher { override fun test(value: SignedJWT): MatcherResult { return try { - value.verifyWith(server.issuerUrl(issuerId), server.jwksUrl(issuerId)) + value.verifyWith(server.issuerUrl(issuerId), server.jwksUrl(issuerId), requiredClaims) MatcherResult( true, "should not happen, famous last words", @@ -105,7 +109,11 @@ val SignedJWT.issuer: String get() = jwtClaimsSet.issuer val SignedJWT.subject: String get() = jwtClaimsSet.subject val SignedJWT.claims: Map get() = jwtClaimsSet.claims -fun SignedJWT.verifyWith(issuer: HttpUrl, jwkSetUri: HttpUrl): JWTClaimsSet { +fun SignedJWT.verifyWith( + issuer: HttpUrl, + jwkSetUri: HttpUrl, + requiredClaims: List = listOf("sub", "iss", "iat", "exp", "aud") +): JWTClaimsSet { return DefaultJWTProcessor() .apply { jwsKeySelector = JWSVerificationKeySelector(JWSAlgorithm.RS256, RemoteJWKSet(jwkSetUri.toUrl())) @@ -113,9 +121,7 @@ fun SignedJWT.verifyWith(issuer: HttpUrl, jwkSetUri: HttpUrl): JWTClaimsSet { JWTClaimsSet.Builder() .issuer(issuer.toString()) .build(), - HashSet( - listOf("sub", "iss", "iat", "exp", "aud") - ) + HashSet(requiredClaims) ) }.process(this, null) }