From 4e3819d853f3e5d5c8a76dea3afd75209c6c8332 Mon Sep 17 00:00:00 2001 From: Thomas Oddsund <8527574+oddsund@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:20:26 +0200 Subject: [PATCH] feat(DefaultOAuth2TokenCallback): Allow overriding tid claim (#663) As it's not a standard claim, it should be possible to override it by the consumer of the library. --- .../mock/oauth2/token/OAuth2TokenCallback.kt | 12 +++++------- .../mock/oauth2/token/OAuth2TokenCallbackTest.kt | 13 +++++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index 56b8a1d3..3a4b9f39 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -60,13 +60,11 @@ open class DefaultOAuth2TokenCallback } override fun addClaims(tokenRequest: TokenRequest): Map = - claims.toMutableMap().apply { - putAll( - mapOf( - "azp" to tokenRequest.clientIdAsString(), - "tid" to issuerId, - ), - ) + mutableMapOf( + "tid" to issuerId, + ).apply { + putAll(claims) + put("azp", tokenRequest.clientIdAsString()) } override fun tokenExpiry(): Long = expiry diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt index 964d2a1f..00581e69 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt @@ -149,6 +149,19 @@ internal class OAuth2TokenCallbackTest { } } } + + @Test + fun `Allow overriding tid`() { + val tokenRequest = clientCredentialsRequest() + DefaultOAuth2TokenCallback().asClue { + it.addClaims(tokenRequest) shouldContainAll mapOf("tid" to "default") + } + + DefaultOAuth2TokenCallback(claims = mapOf("tid" to "test-tid")).asClue { + it.addClaims(tokenRequest) shouldContainAll mapOf("tid" to "test-tid") + } + } + } private fun authCodeRequest(vararg formParams: Pair) =