diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index 56b8a1d3..3a4b9f39 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -60,13 +60,11 @@ open class DefaultOAuth2TokenCallback } override fun addClaims(tokenRequest: TokenRequest): Map = - claims.toMutableMap().apply { - putAll( - mapOf( - "azp" to tokenRequest.clientIdAsString(), - "tid" to issuerId, - ), - ) + mutableMapOf( + "tid" to issuerId, + ).apply { + putAll(claims) + put("azp", tokenRequest.clientIdAsString()) } override fun tokenExpiry(): Long = expiry diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt index 964d2a1f..00581e69 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt @@ -149,6 +149,19 @@ internal class OAuth2TokenCallbackTest { } } } + + @Test + fun `Allow overriding tid`() { + val tokenRequest = clientCredentialsRequest() + DefaultOAuth2TokenCallback().asClue { + it.addClaims(tokenRequest) shouldContainAll mapOf("tid" to "default") + } + + DefaultOAuth2TokenCallback(claims = mapOf("tid" to "test-tid")).asClue { + it.addClaims(tokenRequest) shouldContainAll mapOf("tid" to "test-tid") + } + } + } private fun authCodeRequest(vararg formParams: Pair) =