forked from navikt/mock-oauth2-server
-
Notifications
You must be signed in to change notification settings - Fork 0
/
OAuth2TokenCallback.kt
119 lines (96 loc) · 4.56 KB
/
OAuth2TokenCallback.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.oauth2.sdk.GrantType
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.extensions.clientIdAsString
import no.nav.security.mock.oauth2.extensions.grantType
import no.nav.security.mock.oauth2.extensions.scopesWithoutOidcScopes
import no.nav.security.mock.oauth2.extensions.tokenExchangeGrantOrNull
import java.time.Duration
import java.util.UUID
/**
 * Callback that supplies the per-request values of a token issued in response to a [TokenRequest]:
 * subject, JOSE `typ` header, audience, additional claims and expiry.
 */
interface OAuth2TokenCallback {
    /** Identifier of the issuer this callback belongs to. */
    fun issuerId(): String

    /**
     * Subject (`sub`) to use for [tokenRequest].
     * Nullable: an implementation may have no subject configured for a given request.
     */
    fun subject(tokenRequest: TokenRequest): String?

    /** Value of the JOSE `typ` header to use for [tokenRequest]. */
    fun typeHeader(tokenRequest: TokenRequest): String

    /** Audience (`aud`) values to use for [tokenRequest]; may be empty. */
    fun audience(tokenRequest: TokenRequest): List<String>

    /** Additional claims to include for [tokenRequest]. */
    fun addClaims(tokenRequest: TokenRequest): Map<String, Any>

    /**
     * Token lifetime. Presumably in seconds — both implementations in this file default to
     * 3600 / one hour in seconds; confirm against the token provider that consumes it.
     */
    fun tokenExpiry(): Long
}
// TODO: for JwtBearerGrant and TokenExchange it should be possible to override sub: make sub nullable and return some default
/**
 * [OAuth2TokenCallback] whose values are fixed at construction time.
 *
 * @param issuerId returned from [issuerId] and always written as the `tid` claim by [addClaims].
 * @param subject subject for every grant type except `client_credentials`, which uses the client id instead.
 * @param typeHeader returned from [typeHeader] for every request.
 * @param audience explicit audience; `null` means "not set" (see [audience] for the fallbacks).
 * @param claims extra claims; any `azp` / `tid` entries are overwritten by [addClaims].
 * @param expiry returned from [tokenExpiry].
 */
open class DefaultOAuth2TokenCallback @JvmOverloads constructor(
    private val issuerId: String = "default",
    private val subject: String = UUID.randomUUID().toString(),
    private val typeHeader: String = JOSEObjectType.JWT.type,
    // Nullable so that "not set" can be told apart from an explicitly empty list, which is a valid value.
    private val audience: List<String>? = null,
    private val claims: Map<String, Any> = emptyMap(),
    private val expiry: Long = 3600,
) : OAuth2TokenCallback {
    override fun issuerId(): String = issuerId

    /** The requesting client's id for `client_credentials` grants, the configured subject otherwise. */
    override fun subject(tokenRequest: TokenRequest): String =
        if (GrantType.CLIENT_CREDENTIALS == tokenRequest.grantType()) {
            tokenRequest.clientIdAsString()
        } else {
            subject
        }

    override fun typeHeader(tokenRequest: TokenRequest): String = typeHeader

    /**
     * Audience resolved in priority order:
     * 1. the audience passed to the constructor (even if empty),
     * 2. the audience of a token-exchange grant, if the request carries one,
     * 3. the request's scopes minus the OIDC scopes, when a scope is present,
     * 4. `listOf("default")`.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> {
        // Looked up up-front, regardless of whether it ends up being used.
        val exchangeAudience = tokenRequest.tokenExchangeGrantOrNull()?.audience
        return audience
            ?: exchangeAudience
            ?: tokenRequest.scope?.let { tokenRequest.scopesWithoutOidcScopes() }
            ?: listOf("default")
    }

    /** A copy of the configured claims with `azp` set to the client id and `tid` set to the issuer id. */
    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> {
        val merged = claims.toMutableMap()
        merged["azp"] = tokenRequest.clientIdAsString()
        merged["tid"] = issuerId
        return merged
    }

    override fun tokenExpiry(): Long = expiry
}
/**
 * [OAuth2TokenCallback] driven by a set of [RequestMapping]s: the first mapping (in the set's iteration
 * order) whose [RequestMapping.isMatch] accepts the token request supplies the claims and `typ` header.
 *
 * @property issuerId returned from [issuerId].
 * @property requestMappings candidate mappings; when several match, iteration order decides the winner.
 * @property tokenExpiry returned from [tokenExpiry]; defaults to one hour expressed in seconds.
 */
data class RequestMappingTokenCallback(
    val issuerId: String,
    val requestMappings: Set<RequestMapping>,
    val tokenExpiry: Long = Duration.ofHours(1).toSeconds(),
) : OAuth2TokenCallback {
    override fun issuerId(): String = issuerId

    /** The mapped `sub` claim, or `null` when no mapping matches or the claim is not a String. */
    override fun subject(tokenRequest: TokenRequest): String? = requestMappings.getClaimOrNull(tokenRequest, "sub")

    /** The matching mapping's `typ` header, or `JWT` when no mapping matches. */
    override fun typeHeader(tokenRequest: TokenRequest): String = requestMappings.getTypeHeader(tokenRequest)

    /**
     * The mapped `aud` claim as a list.
     *
     * A JWT `aud` may be either a single string or an array of strings, so a String value is wrapped in a
     * one-element list instead of being dropped. From a collection value only the String elements are kept,
     * so non-string entries can't leak into a `List<String>` through an unchecked cast.
     * Returns an empty list when no mapping matches or no `aud` claim is configured.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> =
        when (val aud = requestMappings.getClaimOrNull<Any>(tokenRequest, "aud")) {
            is String -> listOf(aud)
            is Collection<*> -> aud.filterIsInstance<String>()
            else -> emptyList()
        }

    /** All claims of the matching mapping (see [getClaims]). */
    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> = requestMappings.getClaims(tokenRequest)

    override fun tokenExpiry(): Long = tokenExpiry

    /**
     * Claims of the first mapping matching [tokenRequest], or an empty map if none matches.
     * For `client_credentials` requests a `sub` claim with the literal placeholder `${clientId}` is
     * replaced by the requesting client's id.
     */
    private fun Set<RequestMapping>.getClaims(tokenRequest: TokenRequest): Map<String, Any> {
        val claims = firstOrNull { it.isMatch(tokenRequest) }?.claims ?: emptyMap()
        return if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS && claims["sub"] == "\${clientId}") {
            claims + ("sub" to tokenRequest.clientIdAsString())
        } else {
            claims
        }
    }

    /**
     * The claim [key] of the matching mapping cast to [T], or `null` when absent or of another type.
     * Note: generic type arguments of [T] are erased, so only the raw class is checked.
     */
    private inline fun <reified T> Set<RequestMapping>.getClaimOrNull(
        tokenRequest: TokenRequest,
        key: String,
    ): T? = getClaims(tokenRequest)[key] as? T

    /** The `typ` header of the first mapping matching [tokenRequest], defaulting to `JWT`. */
    private fun Set<RequestMapping>.getTypeHeader(tokenRequest: TokenRequest) =
        firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type
}
/**
 * Maps a token-request parameter value to the claims and JOSE `typ` header to use for the token.
 *
 * @param requestParam name of the parameter looked up in `tokenRequest.toHTTPRequest().queryParameters`.
 * @param match required parameter value; `"*"` accepts any value.
 * @property claims claims for requests that match.
 * @property typeHeader JOSE `typ` header for requests that match.
 */
data class RequestMapping(
    private val requestParam: String,
    private val match: String = "*",
    val claims: Map<String, Any> = emptyMap(),
    val typeHeader: String = JOSEObjectType.JWT.type,
) {
    /**
     * `true` when [tokenRequest] has at least one value for [requestParam] equal to [match]
     * (any value when [match] is `"*"`). A request without the parameter never matches.
     */
    fun isMatch(tokenRequest: TokenRequest): Boolean {
        val values = tokenRequest.toHTTPRequest().queryParameters[requestParam] ?: return false
        return values.any { value -> match == "*" || value == match }
    }
}