From fb191cc84ac4e63618ee648141fcfef01d02fbd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Sat, 21 Mar 2020 18:08:48 +0100 Subject: [PATCH] Feature - add interactive login page (#1) * serve login html when login prompt is specified * formatting, refactoring, and fix failing test for interactive login * set interactivelogin=true as default for standalone server * show debug info on login page * enqueue MockResponse and use this if added instead of OAuth2HttpRequestHandler --- build.gradle.kts | 3 +- .../no/nav/security/mock/MockOAuth2Server.kt | 76 --------- .../mock/extensions/MockResponseExtensions.kt | 44 ----- .../extensions/RecordedRequestExtensions.kt | 29 ---- .../security/mock/oauth2/MockOAuth2Server.kt | 106 ++++++++++++ .../nav/security/mock/oauth2/OAuth2Config.kt | 10 ++ .../security/mock/oauth2/OAuth2Dispatcher.kt | 154 ------------------ .../security/mock/oauth2/OAuth2Response.kt | 36 ---- .../StandaloneMockOAuth2Server.kt | 8 +- .../extensions/HttpUrlExtensions.kt | 2 +- .../extensions/NimbusExtensions.kt | 9 +- .../extensions/RecordedRequestExtensions.kt | 7 + .../grant/AuthorizationCodeGrantHandler.kt | 80 ++++++--- .../grant/ClientCredentialsGrantHandler.kt | 12 +- .../mock/oauth2/grant/GrantHandler.kt | 6 +- .../oauth2/grant/JwtBearerGrantHandler.kt | 12 +- .../mock/oauth2/http/OAuth2HttpRequest.kt | 43 +++++ .../oauth2/http/OAuth2HttpRequestHandler.kt | 146 +++++++++++++++++ .../mock/oauth2/http/OAuth2HttpResponse.kt | 101 ++++++++++++ .../mock/oauth2/login/LoginRequestHandler.kt | 23 +++ .../mock/oauth2/templates/TemplateMapper.kt | 42 +++++ .../token/OAuth2TokenCallback.kt} | 16 +- .../oauth2/{ => token}/OAuth2TokenProvider.kt | 42 ++--- src/main/resources/templates/layout.ftl | 137 ++++++++++++++++ src/main/resources/templates/login.ftl | 20 +++ .../mock/{ => oauth2}/MockOAuth2ServerTest.kt | 154 ++++++++++++++++-- .../grant/AuthorizationCodeHandlerTest.kt | 115 +++++++++++++ 27 files changed, 1013 insertions(+), 420 deletions(-) delete mode 100644 src/main/kotlin/no/nav/security/mock/MockOAuth2Server.kt delete mode 100644 src/main/kotlin/no/nav/security/mock/extensions/MockResponseExtensions.kt delete mode 100644 src/main/kotlin/no/nav/security/mock/extensions/RecordedRequestExtensions.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Config.kt delete mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Dispatcher.kt delete mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Response.kt rename src/main/kotlin/no/nav/security/mock/{ => oauth2}/StandaloneMockOAuth2Server.kt (80%) rename src/main/kotlin/no/nav/security/mock/{ => oauth2}/extensions/HttpUrlExtensions.kt (96%) rename src/main/kotlin/no/nav/security/mock/{ => oauth2}/extensions/NimbusExtensions.kt (76%) create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/extensions/RecordedRequestExtensions.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/login/LoginRequestHandler.kt create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/templates/TemplateMapper.kt rename src/main/kotlin/no/nav/security/mock/{callback/TokenCallback.kt => oauth2/token/OAuth2TokenCallback.kt} (85%) rename src/main/kotlin/no/nav/security/mock/oauth2/{ => token}/OAuth2TokenProvider.kt (79%) create mode 100644 src/main/resources/templates/layout.ftl create mode 100644 src/main/resources/templates/login.ftl rename src/test/kotlin/no/nav/security/mock/{ => oauth2}/MockOAuth2ServerTest.kt (58%) create mode 100644 src/test/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeHandlerTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index f9df2410..303f7ebc 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -9,7 +9,7 @@ val jacksonVersion = "2.10.1" val junitJupiterVersion = "5.5.2" val konfigVersion = "1.6.10.0" val kotlinVersion = "1.3.61" - +val freemarkerVersion = "2.3.29" val mavenRepoBaseUrl = "https://oss.sonatype.org" val mainClassKt = "no.nav.security.mock.StandaloneMockOAuth2ServerKt" @@ -54,6 +54,7 @@ dependencies { api("com.nimbusds:oauth2-oidc-sdk:$nimbusSdkVersion") implementation("io.github.microutils:kotlin-logging:$kotlinLoggingVersion") implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion") + implementation("org.freemarker:freemarker:$freemarkerVersion") testImplementation("org.assertj:assertj-core:$assertjVersion") testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion") testImplementation("org.jetbrains.kotlin:kotlin-test-junit5:$kotlinVersion") diff --git a/src/main/kotlin/no/nav/security/mock/MockOAuth2Server.kt b/src/main/kotlin/no/nav/security/mock/MockOAuth2Server.kt deleted file mode 100644 index b041dc7d..00000000 --- a/src/main/kotlin/no/nav/security/mock/MockOAuth2Server.kt +++ /dev/null @@ -1,76 +0,0 @@ -package no.nav.security.mock - -import com.nimbusds.jwt.SignedJWT -import com.nimbusds.oauth2.sdk.AuthorizationCode -import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant -import com.nimbusds.oauth2.sdk.TokenRequest -import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic -import com.nimbusds.oauth2.sdk.auth.Secret -import com.nimbusds.oauth2.sdk.id.ClientID -import mu.KotlinLogging -import no.nav.security.mock.callback.DefaultTokenCallback -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.toAuthorizationEndpointUrl -import no.nav.security.mock.extensions.toJwksUrl -import no.nav.security.mock.extensions.toTokenEndpointUrl -import no.nav.security.mock.extensions.toWellKnownUrl -import no.nav.security.mock.oauth2.OAuth2Dispatcher -import no.nav.security.mock.oauth2.OAuth2TokenProvider -import okhttp3.HttpUrl -import okhttp3.mockwebserver.Dispatcher -import okhttp3.mockwebserver.MockWebServer -import okhttp3.mockwebserver.RecordedRequest -import java.io.IOException -import java.net.InetSocketAddress -import java.net.URI - -private val log = KotlinLogging.logger {} - -class MockOAuth2Server( - tokenCallbacks: Set = setOf(DefaultTokenCallback()) -) { - private val mockWebServer: MockWebServer = MockWebServer() - private val tokenProvider: OAuth2TokenProvider = OAuth2TokenProvider() - - var dispatcher: Dispatcher = OAuth2Dispatcher(tokenProvider, tokenCallbacks) - - fun start() { - mockWebServer.start() - mockWebServer.dispatcher = dispatcher - } - - fun start(port: Int = 0) { - val address = InetSocketAddress(0).address - log.info("attempting to start server on port $port and InetAddress=$address") - mockWebServer.start(address, port) - mockWebServer.dispatcher = dispatcher - } - - @Throws(IOException::class) - fun shutdown() { - mockWebServer.shutdown() - } - - fun enqueueCallback(tokenCallback: TokenCallback) = - (dispatcher as OAuth2Dispatcher).enqueueJwtCallback(tokenCallback) - - fun takeRequest(): RecordedRequest = mockWebServer.takeRequest() - - fun wellKnownUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toWellKnownUrl() - fun tokenEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toTokenEndpointUrl() - fun jwksUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toJwksUrl() - fun issuerUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId) - fun authorizationEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toAuthorizationEndpointUrl() - fun baseUrl(): HttpUrl = mockWebServer.url("") - - fun issueToken(issuerId: String, clientId: String, tokenCallback: TokenCallback): SignedJWT { - val uri = tokenEndpointUrl(issuerId) - val issuerUrl = issuerUrl(issuerId) - val tokenRequest = TokenRequest( - uri.toUri(), - ClientSecretBasic(ClientID(clientId), Secret("secret")), - AuthorizationCodeGrant(AuthorizationCode("123"), URI.create("http://localhost")) - ) - return tokenProvider.accessToken(tokenRequest, issuerUrl, null, tokenCallback) - } -} diff --git a/src/main/kotlin/no/nav/security/mock/extensions/MockResponseExtensions.kt b/src/main/kotlin/no/nav/security/mock/extensions/MockResponseExtensions.kt deleted file mode 100644 index 11bb922d..00000000 --- a/src/main/kotlin/no/nav/security/mock/extensions/MockResponseExtensions.kt +++ /dev/null @@ -1,44 +0,0 @@ -package no.nav.security.mock.extensions - -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.databind.SerializationFeature -import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper -import com.nimbusds.oauth2.sdk.ErrorObject -import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse -import okhttp3.mockwebserver.MockResponse - -private val objectMapper: ObjectMapper = jacksonObjectMapper() - -fun MockResponse.json(anyObject: Any): MockResponse = - jsonWithCode(200, anyObject) - -fun MockResponse.jsonWithCode(statusCode: Int, anyObject: Any): MockResponse = - this.setResponseCode(statusCode) - .setHeader("Content-Type", "application/json;charset=UTF-8") - .setBody( - when (anyObject) { - is String -> anyObject - else -> objectMapper - .enable(SerializationFeature.INDENT_OUTPUT) - .writeValueAsString(anyObject) - } - ) - -fun MockResponse.oauth2Error(error: ErrorObject): MockResponse { - val responseCode = error.httpStatusCode.takeUnless { it == 302 } ?: 400 - return this.setResponseCode(responseCode) - .setHeader("Content-Type", "application/json;charset=UTF-8") - .setBody( - objectMapper - .enable(SerializationFeature.INDENT_OUTPUT) - .writeValueAsString(error.toJSONObject()) - .toLowerCase() - ) -} - -fun MockResponse.authenticationSuccess( - authenticationSuccessResponse: AuthenticationSuccessResponse -): MockResponse { - val httpResponse = authenticationSuccessResponse.toHTTPResponse() - return this.setResponseCode(httpResponse.statusCode).setHeader("Location", httpResponse.location) -} diff --git a/src/main/kotlin/no/nav/security/mock/extensions/RecordedRequestExtensions.kt b/src/main/kotlin/no/nav/security/mock/extensions/RecordedRequestExtensions.kt deleted file mode 100644 index 6abb9aca..00000000 --- a/src/main/kotlin/no/nav/security/mock/extensions/RecordedRequestExtensions.kt +++ /dev/null @@ -1,29 +0,0 @@ -package no.nav.security.mock.extensions - -import no.nav.security.mock.oauth2.OAuth2Exception -import com.nimbusds.oauth2.sdk.OAuth2Error -import com.nimbusds.oauth2.sdk.TokenRequest -import com.nimbusds.oauth2.sdk.http.HTTPRequest -import com.nimbusds.openid.connect.sdk.AuthenticationRequest -import okhttp3.mockwebserver.RecordedRequest - -fun RecordedRequest.issuerId(): String = - this.requestUrl?.pathSegments - ?.first() - ?: throw OAuth2Exception(OAuth2Error.INVALID_REQUEST, "issuerid must be first segment in url path") - -fun RecordedRequest.asTokenRequest(): TokenRequest = - TokenRequest.parse(fromFormParameters(this)) - -fun RecordedRequest.asAuthenticationRequest(): AuthenticationRequest = - AuthenticationRequest.parse(this.requestUrl!!.toUri()) - -private fun fromFormParameters(request: RecordedRequest): HTTPRequest { - val httpRequest = HTTPRequest( - HTTPRequest.Method.valueOf(request.method!!), - request.requestUrl!!.toUrl() - ) - request.headers.forEach { httpRequest.setHeader(it.first, it.second) } - httpRequest.query = request.body.readUtf8() - return httpRequest -} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt b/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt new file mode 100644 index 00000000..66f0d8a8 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/MockOAuth2Server.kt @@ -0,0 +1,106 @@ +package no.nav.security.mock.oauth2 + +import com.nimbusds.jwt.SignedJWT +import com.nimbusds.oauth2.sdk.AuthorizationCode +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant +import com.nimbusds.oauth2.sdk.TokenRequest +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic +import com.nimbusds.oauth2.sdk.auth.Secret +import com.nimbusds.oauth2.sdk.id.ClientID +import mu.KotlinLogging +import no.nav.security.mock.oauth2.extensions.asOAuth2HttpRequest +import no.nav.security.mock.oauth2.extensions.toAuthorizationEndpointUrl +import no.nav.security.mock.oauth2.extensions.toJwksUrl +import no.nav.security.mock.oauth2.extensions.toTokenEndpointUrl +import no.nav.security.mock.oauth2.extensions.toWellKnownUrl +import no.nav.security.mock.oauth2.http.OAuth2HttpRequestHandler +import no.nav.security.mock.oauth2.http.OAuth2HttpResponse +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider +import okhttp3.HttpUrl +import okhttp3.mockwebserver.Dispatcher +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.mockwebserver.RecordedRequest +import java.io.IOException +import java.net.InetSocketAddress +import java.net.URI +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue + +private val log = KotlinLogging.logger {} + +class MockOAuth2Server( + config: OAuth2Config = OAuth2Config() +) { + private val mockWebServer: MockWebServer = MockWebServer() + private val tokenProvider: OAuth2TokenProvider = + OAuth2TokenProvider() + + var dispatcher: Dispatcher = MockOAuth2Dispatcher(config) + + fun start() { + mockWebServer.start() + mockWebServer.dispatcher = dispatcher + } + + fun start(port: Int = 0) { + val address = InetSocketAddress(0).address + log.info("attempting to start server on port $port and InetAddress=$address") + mockWebServer.start(address, port) + mockWebServer.dispatcher = dispatcher + } + + @Throws(IOException::class) + fun shutdown() { + mockWebServer.shutdown() + } + + fun url(path: String): HttpUrl = mockWebServer.url(path) + fun enqueueResponse(response: MockResponse) = (dispatcher as MockOAuth2Dispatcher).enqueueResponse(response) + fun enqueueCallback(oAuth2TokenCallback: OAuth2TokenCallback) = (dispatcher as MockOAuth2Dispatcher).enqueueTokenCallback(oAuth2TokenCallback) + fun takeRequest(): RecordedRequest = mockWebServer.takeRequest() + + fun wellKnownUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toWellKnownUrl() + fun tokenEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toTokenEndpointUrl() + fun jwksUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toJwksUrl() + fun issuerUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId) + fun authorizationEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toAuthorizationEndpointUrl() + fun baseUrl(): HttpUrl = mockWebServer.url("") + + fun issueToken(issuerId: String, clientId: String, OAuth2TokenCallback: OAuth2TokenCallback): SignedJWT { + val uri = tokenEndpointUrl(issuerId) + val issuerUrl = issuerUrl(issuerId) + val tokenRequest = TokenRequest( + uri.toUri(), + ClientSecretBasic(ClientID(clientId), Secret("secret")), + AuthorizationCodeGrant(AuthorizationCode("123"), URI.create("http://localhost")) + ) + return tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback) + } +} + +class MockOAuth2Dispatcher( + config: OAuth2Config +) : Dispatcher() { + private val httpRequestHandler: OAuth2HttpRequestHandler = OAuth2HttpRequestHandler(config) + private val responseQueue: BlockingQueue = LinkedBlockingQueue() + + fun enqueueResponse(mockResponse: MockResponse) = responseQueue.add(mockResponse) + fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = httpRequestHandler.enqueueTokenCallback(oAuth2TokenCallback) + + override fun dispatch(request: RecordedRequest): MockResponse = + when { + responseQueue.peek() != null -> responseQueue.take() + else -> mockResponse(httpRequestHandler.handleRequest(request.asOAuth2HttpRequest())) + } + + + private fun mockResponse(response: OAuth2HttpResponse): MockResponse = + MockResponse() + .setHeaders(response.headers) + .setResponseCode(response.status) + .apply { + response.body?.let { this.setBody(it) } + } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Config.kt b/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Config.kt new file mode 100644 index 00000000..d401e7c2 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Config.kt @@ -0,0 +1,10 @@ +package no.nav.security.mock.oauth2 + +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider + +data class OAuth2Config( + val interactiveLogin: Boolean = false, + val tokenProvider: OAuth2TokenProvider = OAuth2TokenProvider(), + val oAuth2TokenCallbacks: Set = emptySet() +) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Dispatcher.kt b/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Dispatcher.kt deleted file mode 100644 index d93a76fa..00000000 --- a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Dispatcher.kt +++ /dev/null @@ -1,154 +0,0 @@ -package no.nav.security.mock.oauth2 - -import com.nimbusds.oauth2.sdk.ErrorObject -import com.nimbusds.oauth2.sdk.GeneralException -import com.nimbusds.oauth2.sdk.GrantType -import com.nimbusds.oauth2.sdk.OAuth2Error -import com.nimbusds.oauth2.sdk.ParseException -import com.nimbusds.oauth2.sdk.TokenRequest -import com.nimbusds.openid.connect.sdk.AuthenticationRequest -import mu.KotlinLogging -import no.nav.security.mock.callback.DefaultTokenCallback -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.asAuthenticationRequest -import no.nav.security.mock.extensions.asTokenRequest -import no.nav.security.mock.extensions.authenticationSuccess -import no.nav.security.mock.extensions.grantType -import no.nav.security.mock.extensions.isAuthorizationEndpointUrl -import no.nav.security.mock.extensions.isJwksUrl -import no.nav.security.mock.extensions.isTokenEndpointUrl -import no.nav.security.mock.extensions.isWellKnownUrl -import no.nav.security.mock.extensions.issuerId -import no.nav.security.mock.extensions.json -import no.nav.security.mock.extensions.oauth2Error -import no.nav.security.mock.extensions.toAuthorizationEndpointUrl -import no.nav.security.mock.extensions.toIssuerUrl -import no.nav.security.mock.extensions.toJwksUrl -import no.nav.security.mock.extensions.toTokenEndpointUrl -import no.nav.security.mock.oauth2.grant.AuthorizationCodeHandler -import no.nav.security.mock.oauth2.grant.ClientCredentialsGrantHandler -import no.nav.security.mock.oauth2.grant.GrantHandler -import no.nav.security.mock.oauth2.grant.JwtBearerGrantHandler -import okhttp3.HttpUrl -import okhttp3.mockwebserver.Dispatcher -import okhttp3.mockwebserver.MockResponse -import okhttp3.mockwebserver.RecordedRequest -import java.util.concurrent.BlockingQueue -import java.util.concurrent.LinkedBlockingQueue - -private val log = KotlinLogging.logger {} - -// TODO: support more flows and oidc session management / logout -class OAuth2Dispatcher( - private val tokenProvider: OAuth2TokenProvider = OAuth2TokenProvider(), - private val tokenCallbacks: Set = setOf(DefaultTokenCallback(audience = "default")) -) : Dispatcher() { - - private val tokenCallbackQueue: BlockingQueue = LinkedBlockingQueue() - - private val grantHandlers: Map = mapOf( - GrantType.AUTHORIZATION_CODE to AuthorizationCodeHandler(tokenProvider), - GrantType.CLIENT_CREDENTIALS to ClientCredentialsGrantHandler(tokenProvider), - GrantType.JWT_BEARER to JwtBearerGrantHandler(tokenProvider) - ) - - private fun takeJwtCallbackOrCreateDefault(issuerId: String): TokenCallback { - if (tokenCallbackQueue.peek()?.issuerId() == issuerId) { - return tokenCallbackQueue.take() - } - return tokenCallbacks.firstOrNull { it.issuerId() == issuerId } - ?: DefaultTokenCallback(issuerId = issuerId) - } - - fun enqueueJwtCallback(tokenCallback: TokenCallback) = tokenCallbackQueue.add(tokenCallback) - - override fun dispatch(request: RecordedRequest): MockResponse { - return runCatching { - handleRequest(request) - }.fold( - onSuccess = { result -> result }, - onFailure = { error -> handleException(error) } - ) - } - - private fun handleRequest(request: RecordedRequest): MockResponse { - log.debug("received request on url=${request.requestUrl} with headers=${request.headers}") - val issuerId: String = request.issuerId() - val url = checkNotNull(request.requestUrl) - - return when { - url.isWellKnownUrl() -> { - log.debug("returning well-known json data for url=$url") - MockResponse().json(wellKnown(request)) - } - url.isAuthorizationEndpointUrl() -> { - log.debug("redirecting to callback with auth code") - val authRequest: AuthenticationRequest = request.asAuthenticationRequest() - - when { - authRequest.responseType.impliesCodeFlow() -> { - MockResponse().authenticationSuccess( - (grantHandlers[GrantType.AUTHORIZATION_CODE] as AuthorizationCodeHandler) - .authorizationCodeResponse(request.asAuthenticationRequest()) - ) - } - else -> throw OAuth2Exception( - OAuth2Error.INVALID_GRANT, "hybrid og implicit flow not supported (yet)." - ) - } - } - url.isTokenEndpointUrl() -> { - log.debug("handle token request $request") - val tokenCallback: TokenCallback = takeJwtCallbackOrCreateDefault(issuerId) - val tokenRequest: TokenRequest = request.asTokenRequest().also { - log.debug("query in tokenreq: ${it.toHTTPRequest().query}") - } - val issuerUrl: HttpUrl = issuerUrl(request) - MockResponse().json( - grantHandler(tokenRequest.grantType()).tokenResponse(tokenRequest, issuerUrl, tokenCallback) - ) - } - url.isJwksUrl() -> { - log.debug("handle jwks request") - MockResponse().json(tokenProvider.publicJwkSet().toJSONObject()) - } - else -> { - val msg = "path '${request.requestUrl}' not found" - log.error(msg) - MockResponse().setResponseCode(404).setBody(msg) - } - } - } - - private fun handleException(error: Throwable): MockResponse { - log.error("received exception when handling request.", error) - val errorObject: ErrorObject = when (error) { - is OAuth2Exception -> error.errorObject - is ParseException -> error.errorObject - ?: OAuth2Error.INVALID_REQUEST - .appendDescription(". received exception message: ${error.message}") - is GeneralException -> error.errorObject - else -> null - } ?: OAuth2Error.SERVER_ERROR - .appendDescription(". received exception message: ${error.message}") - - return MockResponse().oauth2Error(errorObject) - } - - private fun wellKnown(request: RecordedRequest): WellKnown = - WellKnown( - issuer = request.requestUrl?.toIssuerUrl().toString(), - authorizationEndpoint = request.requestUrl?.toAuthorizationEndpointUrl().toString(), - tokenEndpoint = request.requestUrl?.toTokenEndpointUrl().toString(), - jwksUri = request.requestUrl?.toJwksUrl().toString() - ) - - private fun issuerUrl(request: RecordedRequest): HttpUrl = - request.requestUrl?.toIssuerUrl() - ?: throw OAuth2Exception(OAuth2Error.INVALID_REQUEST, "issuerid must be first segment in url path") - - private fun grantHandler(grantType: GrantType): GrantHandler = - grantHandlers[grantType] ?: throw OAuth2Exception( - OAuth2Error.INVALID_GRANT, "grant_type $grantType not supported." - ) -} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Response.kt b/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Response.kt deleted file mode 100644 index b43a23ba..00000000 --- a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2Response.kt +++ /dev/null @@ -1,36 +0,0 @@ -package no.nav.security.mock.oauth2 - -import com.fasterxml.jackson.annotation.JsonInclude -import com.fasterxml.jackson.annotation.JsonProperty - -data class WellKnown( - val issuer: String, - @JsonProperty("authorization_endpoint") - val authorizationEndpoint: String?, - @JsonProperty("token_endpoint") - val tokenEndpoint: String?, - @JsonProperty("jwks_uri") - val jwksUri: String?, - @JsonProperty("response_types_supported") - val responseTypesSupported: List = listOf("query", "fragment", "form_post"), - @JsonProperty("subject_types_supported") - val subjectTypesSupported: List = listOf("public"), - @JsonProperty("id_token_signing_alg_values_supported") - val idTokenSigningAlgValuesSupported: List = listOf("RS256") -) - -@JsonInclude(JsonInclude.Include.NON_NULL) -data class OAuth2TokenResponse( - @JsonProperty("token_type") - val tokenType: String, - @JsonProperty("id_token") - val idToken: String? = null, - @JsonProperty("access_token") - val accessToken: String?, - @JsonProperty("refresh_token") - val refreshToken: String? = null, - @JsonProperty("expires_in") - val expiresIn: Int = 0, - @JsonProperty("scope") - val scope: String? = null -) diff --git a/src/main/kotlin/no/nav/security/mock/StandaloneMockOAuth2Server.kt b/src/main/kotlin/no/nav/security/mock/oauth2/StandaloneMockOAuth2Server.kt similarity index 80% rename from src/main/kotlin/no/nav/security/mock/StandaloneMockOAuth2Server.kt rename to src/main/kotlin/no/nav/security/mock/oauth2/StandaloneMockOAuth2Server.kt index 66f687c9..e6d0abb7 100644 --- a/src/main/kotlin/no/nav/security/mock/StandaloneMockOAuth2Server.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/StandaloneMockOAuth2Server.kt @@ -1,4 +1,4 @@ -package no.nav.security.mock +package no.nav.security.mock.oauth2 import com.natpryce.konfig.ConfigurationProperties import com.natpryce.konfig.EnvironmentVariables @@ -21,5 +21,9 @@ data class Configuration( fun main() { val config = Configuration() - MockOAuth2Server().start(config.server.port) + MockOAuth2Server( + OAuth2Config( + interactiveLogin = true + ) + ).start(config.server.port) } diff --git a/src/main/kotlin/no/nav/security/mock/extensions/HttpUrlExtensions.kt b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/HttpUrlExtensions.kt similarity index 96% rename from src/main/kotlin/no/nav/security/mock/extensions/HttpUrlExtensions.kt rename to src/main/kotlin/no/nav/security/mock/oauth2/extensions/HttpUrlExtensions.kt index 10eeb7d2..1e83916f 100644 --- a/src/main/kotlin/no/nav/security/mock/extensions/HttpUrlExtensions.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/HttpUrlExtensions.kt @@ -1,4 +1,4 @@ -package no.nav.security.mock.extensions +package no.nav.security.mock.oauth2.extensions import com.nimbusds.oauth2.sdk.OAuth2Error import no.nav.security.mock.oauth2.OAuth2Exception diff --git a/src/main/kotlin/no/nav/security/mock/extensions/NimbusExtensions.kt b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/NimbusExtensions.kt similarity index 76% rename from src/main/kotlin/no/nav/security/mock/extensions/NimbusExtensions.kt rename to src/main/kotlin/no/nav/security/mock/oauth2/extensions/NimbusExtensions.kt index d1ddee75..b3b74f4b 100644 --- a/src/main/kotlin/no/nav/security/mock/extensions/NimbusExtensions.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/NimbusExtensions.kt @@ -1,4 +1,4 @@ -package no.nav.security.mock.extensions +package no.nav.security.mock.oauth2.extensions import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.AuthorizationCode @@ -6,10 +6,17 @@ import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant import com.nimbusds.oauth2.sdk.GrantType import com.nimbusds.oauth2.sdk.OAuth2Error import com.nimbusds.oauth2.sdk.TokenRequest +import com.nimbusds.openid.connect.sdk.AuthenticationRequest +import com.nimbusds.openid.connect.sdk.Prompt import no.nav.security.mock.oauth2.OAuth2Exception import java.time.Duration import java.time.Instant +fun AuthenticationRequest.isPrompt(): Boolean = + this.prompt?.any { + it == Prompt.Type.LOGIN || it == Prompt.Type.CONSENT || it == Prompt.Type.SELECT_ACCOUNT + } ?: false + fun TokenRequest.grantType(): GrantType = this.authorizationGrant?.type ?: throw OAuth2Exception(OAuth2Error.INVALID_REQUEST, "missing required parameter grant_type") diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/extensions/RecordedRequestExtensions.kt b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/RecordedRequestExtensions.kt new file mode 100644 index 00000000..f332199b --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/RecordedRequestExtensions.kt @@ -0,0 +1,7 @@ +package no.nav.security.mock.oauth2.extensions + +import no.nav.security.mock.oauth2.http.OAuth2HttpRequest +import okhttp3.mockwebserver.RecordedRequest + +fun RecordedRequest.asOAuth2HttpRequest(): OAuth2HttpRequest = + OAuth2HttpRequest(this.headers, this.method!!, this.requestUrl!!, this.body.readUtf8()) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt index fe2f72c3..616a164d 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeGrantHandler.kt @@ -2,51 +2,68 @@ package no.nav.security.mock.oauth2.grant import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.AuthorizationCode +import com.nimbusds.oauth2.sdk.OAuth2Error import com.nimbusds.oauth2.sdk.TokenRequest import com.nimbusds.openid.connect.sdk.AuthenticationRequest import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse import mu.KotlinLogging -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.authorizationCode -import no.nav.security.mock.extensions.expiresIn -import no.nav.security.mock.oauth2.OAuth2TokenProvider -import no.nav.security.mock.oauth2.OAuth2TokenResponse +import no.nav.security.mock.oauth2.OAuth2Exception +import no.nav.security.mock.oauth2.extensions.authorizationCode +import no.nav.security.mock.oauth2.extensions.expiresIn +import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.login.Login +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider import okhttp3.HttpUrl import java.util.UUID private val log = KotlinLogging.logger {} class AuthorizationCodeHandler( - private val tokenProvider: OAuth2TokenProvider + private val tokenProvider: OAuth2TokenProvider = OAuth2TokenProvider() ) : GrantHandler { private val codeToAuthRequestCache: MutableMap = HashMap() + private val codeToLoginCache: MutableMap = HashMap() - fun authorizationCodeResponse(authenticationRequest: AuthenticationRequest): AuthenticationSuccessResponse { - val code = AuthorizationCode() - log.debug("issuing authorization code $code") - codeToAuthRequestCache[code] = authenticationRequest - return AuthenticationSuccessResponse( - authenticationRequest.redirectionURI, - code, - null, - null, - authenticationRequest.state, - null, - authenticationRequest.responseMode - ) + fun authorizationCodeResponse(authenticationRequest: AuthenticationRequest, login: Login? = null): AuthenticationSuccessResponse { + when { + authenticationRequest.responseType.impliesCodeFlow() -> { + val code = AuthorizationCode() + log.debug("issuing authorization code $code") + codeToAuthRequestCache[code] = authenticationRequest + if (login?.username != null) { + log.debug("adding user with username ${login.username} to cache") + codeToLoginCache[code] = login + } + return AuthenticationSuccessResponse( + authenticationRequest.redirectionURI, + code, + null, + null, + authenticationRequest.state, + null, + authenticationRequest.responseMode + ) + } + else -> throw OAuth2Exception( + OAuth2Error.INVALID_GRANT, "hybrid og implicit flow not supported (yet)." + ) + } } override fun tokenResponse( tokenRequest: TokenRequest, issuerUrl: HttpUrl, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): OAuth2TokenResponse { - val authenticationRequest = getAuthenticationRequest(tokenRequest.authorizationCode()) + val code = tokenRequest.authorizationCode() + log.debug("issuing token for code=$code") + val authenticationRequest = takeAuthenticationRequestFromCache(code) val scope: String? = tokenRequest.scope?.toString() val nonce: String? = authenticationRequest?.nonce?.value - val idToken: SignedJWT = tokenProvider.idToken(tokenRequest, issuerUrl, nonce, tokenCallback) - val accessToken: SignedJWT = tokenProvider.accessToken(tokenRequest, issuerUrl, nonce, tokenCallback) + val idToken: SignedJWT = tokenProvider.idToken(tokenRequest, issuerUrl, nonce, getLoginTokenCallbackOrDefault(code, oAuth2TokenCallback)) + val accessToken: SignedJWT = tokenProvider.accessToken(tokenRequest, issuerUrl, nonce, getLoginTokenCallbackOrDefault(code, oAuth2TokenCallback)) return OAuth2TokenResponse( tokenType = "Bearer", @@ -58,5 +75,20 @@ class AuthorizationCodeHandler( ) } - private fun getAuthenticationRequest(code: AuthorizationCode): AuthenticationRequest? = codeToAuthRequestCache[code] + private fun getLoginTokenCallbackOrDefault(code: AuthorizationCode, OAuth2TokenCallback: OAuth2TokenCallback): OAuth2TokenCallback { + return takeLoginFromCache(code)?.username?.let { + LoginOAuth2TokenCallback(it, OAuth2TokenCallback) + } ?: OAuth2TokenCallback + } + + private fun takeLoginFromCache(code: AuthorizationCode): Login? = codeToLoginCache.remove(code) + private fun takeAuthenticationRequestFromCache(code: AuthorizationCode): AuthenticationRequest? = codeToAuthRequestCache.remove(code) + + private class LoginOAuth2TokenCallback(val subject: String, val OAuth2TokenCallback: OAuth2TokenCallback) : OAuth2TokenCallback { + override fun issuerId(): String = OAuth2TokenCallback.issuerId() + override fun subject(tokenRequest: TokenRequest): String = subject + override fun audience(tokenRequest: TokenRequest): String = OAuth2TokenCallback.audience(tokenRequest) + override fun addClaims(tokenRequest: TokenRequest): Map = OAuth2TokenCallback.addClaims(tokenRequest) + override fun tokenExpiry(): Long = OAuth2TokenCallback.tokenExpiry() + } } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/ClientCredentialsGrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/ClientCredentialsGrantHandler.kt index 686c3b27..65d6b790 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/ClientCredentialsGrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/ClientCredentialsGrantHandler.kt @@ -1,10 +1,10 @@ package no.nav.security.mock.oauth2.grant import com.nimbusds.oauth2.sdk.TokenRequest -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.expiresIn -import no.nav.security.mock.oauth2.OAuth2TokenProvider -import no.nav.security.mock.oauth2.OAuth2TokenResponse +import no.nav.security.mock.oauth2.extensions.expiresIn +import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider import okhttp3.HttpUrl import java.util.UUID @@ -15,13 +15,13 @@ class ClientCredentialsGrantHandler( override fun tokenResponse( tokenRequest: TokenRequest, issuerUrl: HttpUrl, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): OAuth2TokenResponse { val accessToken = tokenProvider.accessToken( tokenRequest, issuerUrl, null, - tokenCallback + oAuth2TokenCallback ) return OAuth2TokenResponse( tokenType = "Bearer", diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/GrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/GrantHandler.kt index 91793d53..966adff9 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/GrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/GrantHandler.kt @@ -1,14 +1,14 @@ package no.nav.security.mock.oauth2.grant import com.nimbusds.oauth2.sdk.TokenRequest -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.oauth2.OAuth2TokenResponse +import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback import okhttp3.HttpUrl interface GrantHandler { fun tokenResponse( tokenRequest: TokenRequest, issuerUrl: HttpUrl, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): OAuth2TokenResponse } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt index 2605312a..9cdc303c 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/grant/JwtBearerGrantHandler.kt @@ -14,11 +14,11 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor import com.nimbusds.oauth2.sdk.JWTBearerGrant import com.nimbusds.oauth2.sdk.OAuth2Error import com.nimbusds.oauth2.sdk.TokenRequest -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.expiresIn import no.nav.security.mock.oauth2.OAuth2Exception -import no.nav.security.mock.oauth2.OAuth2TokenProvider -import no.nav.security.mock.oauth2.OAuth2TokenResponse +import no.nav.security.mock.oauth2.extensions.expiresIn +import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider import okhttp3.HttpUrl import java.util.HashSet import java.util.UUID @@ -28,14 +28,14 @@ class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvider) : Gr override fun tokenResponse( tokenRequest: TokenRequest, issuerUrl: HttpUrl, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): OAuth2TokenResponse { val receivedClaimsSet = assertion(tokenRequest) val accessToken = tokenProvider.onBehalfOfAccessToken( receivedClaimsSet, tokenRequest, - tokenCallback + oAuth2TokenCallback ) return OAuth2TokenResponse( tokenType = "Bearer", diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt new file mode 100644 index 00000000..7cf9d54c --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt @@ -0,0 +1,43 @@ +package no.nav.security.mock.oauth2.http + +import com.nimbusds.oauth2.sdk.TokenRequest +import com.nimbusds.oauth2.sdk.http.HTTPRequest +import com.nimbusds.openid.connect.sdk.AuthenticationRequest +import okhttp3.Headers +import okhttp3.HttpUrl +import java.net.URLDecoder +import java.nio.charset.StandardCharsets + +data class OAuth2HttpRequest( + val headers: Headers, + val method: String, + val url: HttpUrl, + val body: String? +) { + val formParameters: Parameters = Parameters(body) + + fun asTokenRequest(): TokenRequest = + TokenRequest.parse( + HTTPRequest(HTTPRequest.Method.valueOf(method), url.toUrl()) + .apply { + headers.forEach { header -> this.setHeader(header.first, header.second) } + query = body + } + ) + + fun asAuthenticationRequest(): AuthenticationRequest = AuthenticationRequest.parse(this.url.toUri()) + + data class Parameters(val parameterString: String?) { + + val map: Map = + parameterString?.split("&") + ?.filter { it.contains("=") } + ?.associate { + val (left, right) = it.split("=") + decode(left) to decode(right) + } ?: emptyMap() + + fun get(name: String): String? = map[name] + private fun decode(string: String): String = URLDecoder.decode(string, StandardCharsets.UTF_8) + } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt new file mode 100644 index 00000000..a1f4f42d --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt @@ -0,0 +1,146 @@ +package no.nav.security.mock.oauth2.http + +import com.nimbusds.oauth2.sdk.ErrorObject +import com.nimbusds.oauth2.sdk.GeneralException +import com.nimbusds.oauth2.sdk.GrantType +import com.nimbusds.oauth2.sdk.OAuth2Error +import com.nimbusds.oauth2.sdk.ParseException +import com.nimbusds.oauth2.sdk.TokenRequest +import com.nimbusds.openid.connect.sdk.AuthenticationRequest +import mu.KotlinLogging +import no.nav.security.mock.oauth2.OAuth2Config +import no.nav.security.mock.oauth2.OAuth2Exception +import no.nav.security.mock.oauth2.extensions.grantType +import no.nav.security.mock.oauth2.extensions.isAuthorizationEndpointUrl +import no.nav.security.mock.oauth2.extensions.isJwksUrl +import no.nav.security.mock.oauth2.extensions.isPrompt +import no.nav.security.mock.oauth2.extensions.isTokenEndpointUrl +import no.nav.security.mock.oauth2.extensions.isWellKnownUrl +import no.nav.security.mock.oauth2.extensions.issuerId +import no.nav.security.mock.oauth2.extensions.toAuthorizationEndpointUrl +import no.nav.security.mock.oauth2.extensions.toIssuerUrl +import no.nav.security.mock.oauth2.extensions.toJwksUrl +import no.nav.security.mock.oauth2.extensions.toTokenEndpointUrl +import no.nav.security.mock.oauth2.grant.AuthorizationCodeHandler +import no.nav.security.mock.oauth2.grant.ClientCredentialsGrantHandler +import no.nav.security.mock.oauth2.grant.GrantHandler +import no.nav.security.mock.oauth2.grant.JwtBearerGrantHandler +import no.nav.security.mock.oauth2.login.Login +import no.nav.security.mock.oauth2.login.LoginRequestHandler +import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue + +private val log = KotlinLogging.logger {} + +// TODO: support more flows and oidc session management / logout +class OAuth2HttpRequestHandler( + private val config: OAuth2Config +) { + private val loginRequestHandler = LoginRequestHandler() + private val oAuth2TokenCallbackQueue: BlockingQueue = LinkedBlockingQueue() + + private val grantHandlers: Map = mapOf( + GrantType.AUTHORIZATION_CODE to AuthorizationCodeHandler(config.tokenProvider), + GrantType.CLIENT_CREDENTIALS to ClientCredentialsGrantHandler(config.tokenProvider), + GrantType.JWT_BEARER to JwtBearerGrantHandler(config.tokenProvider) + ) + + fun handleRequest(request: OAuth2HttpRequest): OAuth2HttpResponse { + return runCatching { + log.debug("received request on url=${request.url} with headers=${request.headers}") + val url = request.url + return when { + url.isWellKnownUrl() -> { + log.debug("returning well-known json data for url=$url") + return json(wellKnown(request)) + } + url.isAuthorizationEndpointUrl() -> { + log.debug("received call to authorization endpoint") + val authRequest: AuthenticationRequest = request.asAuthenticationRequest() + val authorizationCodeHandler = (grantHandler(authRequest) as AuthorizationCodeHandler) + return when (request.method) { + "GET" -> { + if (config.interactiveLogin || authRequest.isPrompt()) + html(loginRequestHandler.loginHtml(request)) + else { + authenticationSuccess(authorizationCodeHandler.authorizationCodeResponse(authRequest)) + } + } + "POST" -> { + val login: Login = LoginRequestHandler().loginSubmit(request) + authenticationSuccess(authorizationCodeHandler.authorizationCodeResponse(authRequest, login)) + } + else -> throw OAuth2Exception( + OAuth2Error.INVALID_REQUEST, + "Unsupported request method ${request.method}" + ) + } + } + url.isTokenEndpointUrl() -> { + log.debug("handle token request $request") + val oAuth2TokenCallback: OAuth2TokenCallback = takeTokenCallbackOrCreateDefault(request.url.issuerId()) + val tokenRequest: TokenRequest = request.asTokenRequest() + json(grantHandler(tokenRequest).tokenResponse(tokenRequest, request.url.toIssuerUrl(), oAuth2TokenCallback)) + } + url.isJwksUrl() -> { + log.debug("handle jwks request") + return json(config.tokenProvider.publicJwkSet().toJSONObject()) + } + else -> { + val msg = "path '${request.url}' not found" + log.error(msg) + return notFound() + } + } + }.fold( + onSuccess = { result -> result }, + onFailure = { error -> handleException(error) } + ) + } + + fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = oAuth2TokenCallbackQueue.add(oAuth2TokenCallback) + + private fun takeTokenCallbackOrCreateDefault(issuerId: String): OAuth2TokenCallback { + if (oAuth2TokenCallbackQueue.peek()?.issuerId() == issuerId) { + return oAuth2TokenCallbackQueue.take() + } + return config.oAuth2TokenCallbacks.firstOrNull { it.issuerId() == issuerId } + ?: DefaultOAuth2TokenCallback(issuerId = issuerId) + } + + private fun handleException(error: Throwable): OAuth2HttpResponse { + log.error("received exception when handling request.", error) + val errorObject: ErrorObject = when (error) { + is OAuth2Exception -> error.errorObject + is ParseException -> error.errorObject + ?: OAuth2Error.INVALID_REQUEST + .appendDescription(". received exception message: ${error.message}") + is GeneralException -> error.errorObject + else -> null + } ?: OAuth2Error.SERVER_ERROR + .appendDescription(". received exception message: ${error.message}") + return oauth2Error(errorObject) + } + + private fun grantHandler(authenticationRequest: AuthenticationRequest): GrantHandler = + if (authenticationRequest.responseType.impliesCodeFlow()) { + (grantHandlers[GrantType.AUTHORIZATION_CODE] as AuthorizationCodeHandler) + } else throw OAuth2Exception( + OAuth2Error.INVALID_GRANT, "hybrid og implicit flow not supported (yet)." + ) + + private fun grantHandler(tokenRequest: TokenRequest): GrantHandler = + grantHandlers[tokenRequest.grantType()] ?: throw OAuth2Exception( + OAuth2Error.INVALID_GRANT, "grant_type ${tokenRequest.grantType()} not supported." + ) + + private fun wellKnown(request: OAuth2HttpRequest): WellKnown = + WellKnown( + issuer = request.url.toIssuerUrl().toString(), + authorizationEndpoint = request.url.toAuthorizationEndpointUrl().toString(), + tokenEndpoint = request.url.toTokenEndpointUrl().toString(), + jwksUri = request.url.toJwksUrl().toString() + ) +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt new file mode 100644 index 00000000..b9a55897 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt @@ -0,0 +1,101 @@ +package no.nav.security.mock.oauth2.http + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import com.nimbusds.oauth2.sdk.ErrorObject +import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse +import okhttp3.Headers + +private val objectMapper: ObjectMapper = jacksonObjectMapper() + +data class OAuth2HttpResponse( + val headers: Headers = Headers.headersOf(), + val status: Int, + val body: String? = null +) { + object ContentType { + const val HEADER = "Content-Type" + const val JSON = "application/json;charset=UTF-8" + const val HTML = "text/html;charset=UTF-8" + } +} + +data class WellKnown( + val issuer: String, + @JsonProperty("authorization_endpoint") + val authorizationEndpoint: String, + @JsonProperty("token_endpoint") + val tokenEndpoint: String, + @JsonProperty("jwks_uri") + val jwksUri: String, + @JsonProperty("response_types_supported") + val responseTypesSupported: List = listOf("query", "fragment", "form_post"), + @JsonProperty("subject_types_supported") + val subjectTypesSupported: List = listOf("public"), + @JsonProperty("id_token_signing_alg_values_supported") + val idTokenSigningAlgValuesSupported: List = listOf("RS256") +) + +@JsonInclude(JsonInclude.Include.NON_NULL) +data class OAuth2TokenResponse( + @JsonProperty("token_type") + val tokenType: String, + @JsonProperty("id_token") + val idToken: String? = null, + @JsonProperty("access_token") + val accessToken: String?, + @JsonProperty("refresh_token") + val refreshToken: String? = null, + @JsonProperty("expires_in") + val expiresIn: Int = 0, + @JsonProperty("scope") + val scope: String? = null +) + +fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse( + headers = Headers.headersOf( + OAuth2HttpResponse.ContentType.HEADER, OAuth2HttpResponse.ContentType.JSON + ), + status = 200, + body = when (anyObject) { + is String -> anyObject + else -> objectMapper + .enable(SerializationFeature.INDENT_OUTPUT) + .writeValueAsString(anyObject) + } +) + +fun html(content: String): OAuth2HttpResponse = OAuth2HttpResponse( + headers = Headers.headersOf( + OAuth2HttpResponse.ContentType.HEADER, OAuth2HttpResponse.ContentType.HTML + ), + status = 200, + body = content +) + +fun notFound(): OAuth2HttpResponse = OAuth2HttpResponse(status = 404) + +fun authenticationSuccess(authenticationSuccessResponse: AuthenticationSuccessResponse): OAuth2HttpResponse { + val httpResponse = authenticationSuccessResponse.toHTTPResponse() + return OAuth2HttpResponse( + headers = Headers.headersOf("Location", httpResponse.location!!.toString()), + status = 302 + ) +} + +fun oauth2Error(error: ErrorObject): OAuth2HttpResponse { + val responseCode = error.httpStatusCode.takeUnless { it == 302 } ?: 400 + return OAuth2HttpResponse( + headers = Headers.headersOf( + OAuth2HttpResponse.ContentType.HEADER, OAuth2HttpResponse.ContentType.JSON + ), + status = responseCode, + body = objectMapper + .enable(SerializationFeature.INDENT_OUTPUT) + .writeValueAsString(error.toJSONObject()) + .toLowerCase() + ) +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/login/LoginRequestHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/login/LoginRequestHandler.kt new file mode 100644 index 00000000..7f8a5385 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/login/LoginRequestHandler.kt @@ -0,0 +1,23 @@ +package no.nav.security.mock.oauth2.login + +import no.nav.security.mock.oauth2.http.OAuth2HttpRequest +import no.nav.security.mock.oauth2.templates.TemplateMapper +import no.nav.security.mock.oauth2.templates.TemplateMapper.Companion.create + +val templateMapper: TemplateMapper = create {} + +class LoginRequestHandler { + + fun loginHtml(httpRequest: OAuth2HttpRequest): String = templateMapper.loginHtml(httpRequest) + + fun loginSubmit(httpRequest: OAuth2HttpRequest): Login { + val formParameters = httpRequest.formParameters + val username = checkNotNull(formParameters.get("username")) + return Login(username, formParameters.get("acr")) + } +} + +data class Login( + val username: String, + val acr: String? = null +) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/templates/TemplateMapper.kt b/src/main/kotlin/no/nav/security/mock/oauth2/templates/TemplateMapper.kt new file mode 100644 index 00000000..a2f86e8a --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/templates/TemplateMapper.kt @@ -0,0 +1,42 @@ +package no.nav.security.mock.oauth2.templates + +import freemarker.cache.ClassTemplateLoader +import freemarker.template.Configuration +import no.nav.security.mock.oauth2.http.OAuth2HttpRequest +import java.io.StringWriter + +data class HtmlContent( + val template: String, + val model: Any? +) + +class TemplateMapper( + private val config: Configuration +) { + + fun loginHtml(oAuth2HttpRequest: OAuth2HttpRequest): String = + asString( + HtmlContent( + "login.ftl", mapOf( + "request" to oAuth2HttpRequest, + "query" to OAuth2HttpRequest.Parameters(oAuth2HttpRequest.url.query).map + ) + ) + ) + + private fun asString(htmlContent: HtmlContent): String = + StringWriter().apply { + config.getTemplate(htmlContent.template).process(htmlContent.model, this) + }.toString() + + companion object { + fun create(configure: Configuration.() -> Unit): TemplateMapper { + val config = Configuration(Configuration.DEFAULT_INCOMPATIBLE_IMPROVEMENTS) + .apply { + templateLoader = ClassTemplateLoader(this::class.java.classLoader, "templates") + } + .apply(configure) + return TemplateMapper(config) + } + } +} diff --git a/src/main/kotlin/no/nav/security/mock/callback/TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt similarity index 85% rename from src/main/kotlin/no/nav/security/mock/callback/TokenCallback.kt rename to src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index b2bcdc19..b18f2f38 100644 --- a/src/main/kotlin/no/nav/security/mock/callback/TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -1,13 +1,13 @@ -package no.nav.security.mock.callback +package no.nav.security.mock.oauth2.token import com.nimbusds.oauth2.sdk.GrantType import com.nimbusds.oauth2.sdk.TokenRequest import com.nimbusds.openid.connect.sdk.OIDCScopeValue -import no.nav.security.mock.extensions.clientIdAsString -import no.nav.security.mock.extensions.grantType +import no.nav.security.mock.oauth2.extensions.clientIdAsString +import no.nav.security.mock.oauth2.extensions.grantType import java.util.UUID -interface TokenCallback { +interface OAuth2TokenCallback { fun issuerId(): String fun subject(tokenRequest: TokenRequest): String fun audience(tokenRequest: TokenRequest): String @@ -15,14 +15,16 @@ interface TokenCallback { fun tokenExpiry(): Long } -class DefaultTokenCallback( +open class DefaultOAuth2TokenCallback( private val issuerId: String = "default", private val subject: String = UUID.randomUUID().toString(), private val audience: String? = null, private val claims: Map = emptyMap(), private val expiry: Long = 3600 -) : TokenCallback { +) : OAuth2TokenCallback { + override fun issuerId(): String = issuerId + override fun subject(tokenRequest: TokenRequest): String { return when (GrantType.CLIENT_CREDENTIALS) { tokenRequest.grantType() -> tokenRequest.clientID.value @@ -35,7 +37,7 @@ class DefaultTokenCallback( return audience ?: let { tokenRequest.scope?.toStringList() - ?.filterNot { oidcScopeList.contains(it) }?.first() + ?.filterNot { oidcScopeList.contains(it) }?.firstOrNull() } ?: "default" } diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2TokenProvider.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt similarity index 79% rename from src/main/kotlin/no/nav/security/mock/oauth2/OAuth2TokenProvider.kt rename to src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt index 8e21cf3c..b96265ea 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/OAuth2TokenProvider.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt @@ -1,4 +1,4 @@ -package no.nav.security.mock.oauth2 +package no.nav.security.mock.oauth2.token import com.nimbusds.jose.JOSEObjectType import com.nimbusds.jose.JWSAlgorithm @@ -9,9 +9,7 @@ import com.nimbusds.jose.jwk.RSAKey import com.nimbusds.jwt.JWTClaimsSet import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.TokenRequest -import mu.KotlinLogging -import no.nav.security.mock.callback.TokenCallback -import no.nav.security.mock.extensions.clientIdAsString +import no.nav.security.mock.oauth2.extensions.clientIdAsString import okhttp3.HttpUrl import java.security.KeyPair import java.security.KeyPairGenerator @@ -22,14 +20,13 @@ import java.time.Instant import java.util.Date import java.util.UUID -private val log = KotlinLogging.logger {} - open class OAuth2TokenProvider { private val jwkSet: JWKSet private val rsaKey: RSAKey init { - jwkSet = generateJWKSet(DEFAULT_KEYID) + jwkSet = + generateJWKSet(DEFAULT_KEYID) rsaKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID) as RSAKey } @@ -41,16 +38,16 @@ open class OAuth2TokenProvider { tokenRequest: TokenRequest, issuerUrl: HttpUrl, nonce: String?, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): SignedJWT { return createSignedJWT( defaultClaims( issuerUrl, - tokenCallback.subject(tokenRequest), + oAuth2TokenCallback.subject(tokenRequest), tokenRequest.clientIdAsString(), nonce, - tokenCallback.addClaims(tokenRequest), - tokenCallback.tokenExpiry() + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry() ).build() ) } @@ -59,16 +56,16 @@ open class OAuth2TokenProvider { tokenRequest: TokenRequest, issuerUrl: HttpUrl, nonce: String?, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): SignedJWT { return createSignedJWT( defaultClaims( issuerUrl, - tokenCallback.subject(tokenRequest), - tokenCallback.audience(tokenRequest), + oAuth2TokenCallback.subject(tokenRequest), + oAuth2TokenCallback.audience(tokenRequest), nonce, - tokenCallback.addClaims(tokenRequest), - tokenCallback.tokenExpiry() + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry() ).build() ) } @@ -76,16 +73,16 @@ open class OAuth2TokenProvider { fun onBehalfOfAccessToken( claimsSet: JWTClaimsSet, tokenRequest: TokenRequest, - tokenCallback: TokenCallback + oAuth2TokenCallback: OAuth2TokenCallback ): SignedJWT { val now = Instant.now() return createSignedJWT( JWTClaimsSet.Builder(claimsSet) - .expirationTime(Date.from(now.plusSeconds(tokenCallback.tokenExpiry()))) + .expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry()))) .notBeforeTime(Date.from(now)) .issueTime(Date.from(now)) .jwtID(UUID.randomUUID().toString()) - .audience(tokenCallback.audience(tokenRequest)) + .audience(oAuth2TokenCallback.audience(tokenRequest)) .build() ) } @@ -130,7 +127,12 @@ open class OAuth2TokenProvider { companion object { private const val DEFAULT_KEYID = "mock-oauth2-server-key" private fun generateJWKSet(keyId: String): JWKSet { - return JWKSet(createJWK(keyId, generateKeyPair())) + return JWKSet( + createJWK( + keyId, + generateKeyPair() + ) + ) } private fun generateKeyPair(): KeyPair { diff --git a/src/main/resources/templates/layout.ftl b/src/main/resources/templates/layout.ftl new file mode 100644 index 00000000..96c0b7cf --- /dev/null +++ b/src/main/resources/templates/layout.ftl @@ -0,0 +1,137 @@ +<#macro mainLayout title="" description=""> + +<#--noinspection HtmlRequiredLangAttribute--> + + + + + ${title} | ${description} + + + + <#nested /> + + + diff --git a/src/main/resources/templates/login.ftl b/src/main/resources/templates/login.ftl new file mode 100644 index 00000000..6a7dcac9 --- /dev/null +++ b/src/main/resources/templates/login.ftl @@ -0,0 +1,20 @@ +<#import "layout.ftl" as layout /> + +<@layout.mainLayout title="mock-oauth2-server" description="Just a mock login"> + + diff --git a/src/test/kotlin/no/nav/security/mock/MockOAuth2ServerTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt similarity index 58% rename from src/test/kotlin/no/nav/security/mock/MockOAuth2ServerTest.kt rename to src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt index cd17507f..571f888a 100644 --- a/src/test/kotlin/no/nav/security/mock/MockOAuth2ServerTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/MockOAuth2ServerTest.kt @@ -1,11 +1,12 @@ -package no.nav.security.mock +package no.nav.security.mock.oauth2 import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import com.fasterxml.jackson.module.kotlin.readValue import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.GrantType -import no.nav.security.mock.callback.DefaultTokenCallback -import no.nav.security.mock.oauth2.OAuth2TokenResponse +import no.nav.security.mock.oauth2.http.OAuth2TokenResponse +import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider import okhttp3.Credentials import okhttp3.FormBody import okhttp3.HttpUrl @@ -14,6 +15,7 @@ import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.RequestBody import okhttp3.Response +import okhttp3.mockwebserver.MockResponse import okio.IOException import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.AfterEach @@ -29,16 +31,25 @@ class MockOAuth2ServerTest { .build() private lateinit var server: MockOAuth2Server + private lateinit var interactiveLoginServer: MockOAuth2Server @BeforeEach fun before() { server = MockOAuth2Server() server.start() + interactiveLoginServer = MockOAuth2Server( + OAuth2Config( + interactiveLogin = true, + oAuth2TokenCallbacks = emptySet(), + tokenProvider = OAuth2TokenProvider() + ) + ) } @AfterEach fun shutdown() { server.shutdown() + interactiveLoginServer.shutdown() } @Test @@ -48,6 +59,23 @@ class MockOAuth2ServerTest { assertWellKnownResponseForIssuer("bar") } + @Test + fun enqueuedResponse(){ + assertWellKnownResponseForIssuer("default") + server.enqueueResponse(MockResponse() + .setResponseCode(200) + .setBody("some body") + ) + val request: Request = Request.Builder() + .url(server.url("/someurl")) + .get() + .build() + + val response = client.newCall(request).execute() + assertThat(response.code).isEqualTo(200) + assertThat(response.body?.string()).isEqualTo("some body") + } + @Test fun noIssuerIdInUrlShouldReturn404() { val request: Request = Request.Builder() @@ -83,12 +111,96 @@ class MockOAuth2ServerTest { ) } + @Test + fun fullAuthorizationCodeFlow() { + val authorizationCodeFlowUrl = authorizationCodeFlowUrl( + "default", + "client1", + "http://myapp/callback", + "openid scope1" + ) + val request: Request = Request.Builder() + .url(authorizationCodeFlowUrl) + .get() + .build() + + val response: Response = client.newCall(request).execute() + val url: HttpUrl = checkNotNull(response.headers["location"]?.toHttpUrlOrNull()) + val code = checkNotNull(url.queryParameter("code")) + val tokenResponse: Response = client.newCall( + authCodeTokenRequest( + server.tokenEndpointUrl("default"), + "client1", + "https://myapp/callback", + "openid scope1", + code + ) + ).execute() + assertThat(tokenResponse.code).isEqualTo(200) + val oAuth2TokenResponse: OAuth2TokenResponse = jacksonObjectMapper().readValue(checkNotNull(tokenResponse.body?.string())) + assertThat(oAuth2TokenResponse.accessToken).isNotNull() + assertThat(oAuth2TokenResponse.idToken).isNotNull() + assertThat(oAuth2TokenResponse.expiresIn).isGreaterThan(0) + assertThat(oAuth2TokenResponse.scope).contains("openid scope1") + assertThat(oAuth2TokenResponse.tokenType).isEqualTo("Bearer") + val idToken: SignedJWT = SignedJWT.parse(oAuth2TokenResponse.idToken) + val accessToken: SignedJWT = SignedJWT.parse(oAuth2TokenResponse.accessToken) + assertThat(idToken.jwtClaimsSet.audience.first()).isEqualTo("client1") + assertThat(accessToken.jwtClaimsSet.audience).containsExactly("scope1") + } + + @Test + fun fullAuthorizationCodeFlowWithInteractiveLogin() { + interactiveLoginServer.start() + val authorizationCodeFlowUrl = authorizationCodeFlowUrl( + interactiveLoginServer.authorizationEndpointUrl("default"), + "client1", + "http://myapp/callback", + "openid scope1" + ) + + val authEndpointResponse: Response = client.newCall( + Request.Builder() + .url(authorizationCodeFlowUrl) + .get() + .build() + ).execute() + assertThat(authEndpointResponse.headers["Content-Type"]).isEqualTo("text/html;charset=UTF-8") + val expectedSubject = "foo" + val loginResponse: Response = client.newCall(loginSubmitRequest(authorizationCodeFlowUrl, expectedSubject)).execute() + assertThat(loginResponse.code).isEqualTo(302) + val url: HttpUrl = checkNotNull(loginResponse.headers["location"]?.toHttpUrlOrNull()) + val code = checkNotNull(url.queryParameter("code")) + val tokenResponse: Response = client.newCall( + authCodeTokenRequest( + interactiveLoginServer.tokenEndpointUrl("default"), + "client1", + "https://myapp/callback", + "openid scope1", + code + ) + ).execute() + assertThat(tokenResponse.code).isEqualTo(200) + val oAuth2TokenResponse: OAuth2TokenResponse = jacksonObjectMapper().readValue(checkNotNull(tokenResponse.body?.string())) + assertThat(oAuth2TokenResponse.accessToken).isNotNull() + assertThat(oAuth2TokenResponse.idToken).isNotNull() + assertThat(oAuth2TokenResponse.expiresIn).isGreaterThan(0) + assertThat(oAuth2TokenResponse.scope).contains("openid scope1") + assertThat(oAuth2TokenResponse.tokenType).isEqualTo("Bearer") + val idToken: SignedJWT = SignedJWT.parse(oAuth2TokenResponse.idToken) + val accessToken: SignedJWT = SignedJWT.parse(oAuth2TokenResponse.accessToken) + assertThat(idToken.jwtClaimsSet.subject).isEqualTo("foo") + assertThat(idToken.jwtClaimsSet.audience.first()).isEqualTo("client1") + assertThat(accessToken.jwtClaimsSet.audience).containsExactly("scope1") + interactiveLoginServer.shutdown() + } + @Test @Throws(IOException::class) fun tokenRequestWithCodeShouldReturnTokensWithDefaultClaims() { val response: Response = client.newCall( authCodeTokenRequest( - "default", + server.tokenEndpointUrl("default"), "client1", "https://myapp/callback", "openid scope1", @@ -113,7 +225,7 @@ class MockOAuth2ServerTest { @Throws(IOException::class) fun tokenWithCodeShouldReturnTokensWithClaimsFromEnqueuedCallback() { server.enqueueCallback( - DefaultTokenCallback( + DefaultOAuth2TokenCallback( issuerId = "custom", subject = "yolo", audience = "myaud" @@ -122,7 +234,7 @@ class MockOAuth2ServerTest { val response: Response = client.newCall( authCodeTokenRequest( - "custom", + server.tokenEndpointUrl("custom"), "client1", "https://myapp/callback", "openid scope1", @@ -147,7 +259,7 @@ class MockOAuth2ServerTest { @Test fun tokenRequestForjwtBearerGrant() { - val signedJWT = server.issueToken("default", "client1", DefaultTokenCallback()) + val signedJWT = server.issueToken("default", "client1", DefaultOAuth2TokenCallback()) val response: Response = client.newCall( jwtBearerGrantTokenRequest( "default", @@ -183,6 +295,16 @@ class MockOAuth2ServerTest { return responseBody } + private fun loginSubmitRequest(url: HttpUrl, username: String): Request { + val formBody: RequestBody = FormBody.Builder() + .add("username", username) + .build() + return Request.Builder() + .url(url) + .post(formBody) + .build() + } + private fun jwtBearerGrantTokenRequest( issuerId: String, clientId: String, @@ -203,7 +325,7 @@ class MockOAuth2ServerTest { } private fun authCodeTokenRequest( - issuerId: String, + tokenEndpointUrl: HttpUrl, clientId: String, redirectUri: String, scope: String, @@ -216,7 +338,7 @@ class MockOAuth2ServerTest { .add("grant_type", "authorization_code") .build() return Request.Builder() - .url(server.tokenEndpointUrl(issuerId)) + .url(tokenEndpointUrl) .addHeader("Authorization", Credentials.basic(clientId, "test")) .post(formBody) .build() @@ -227,8 +349,20 @@ class MockOAuth2ServerTest { clientId: String, redirectUri: String, scope: String + ): HttpUrl = authorizationCodeFlowUrl( + server.authorizationEndpointUrl(issuerId), + clientId, + redirectUri, + scope + ) + + private fun authorizationCodeFlowUrl( + authEndpointUrl: HttpUrl, + clientId: String, + redirectUri: String, + scope: String ): HttpUrl { - return server.authorizationEndpointUrl(issuerId).newBuilder() + return authEndpointUrl.newBuilder() .addEncodedQueryParameter("client_id", clientId) .addEncodedQueryParameter("response_type", "code") .addEncodedQueryParameter("redirect_uri", redirectUri) diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeHandlerTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeHandlerTest.kt new file mode 100644 index 00000000..52d5134a --- /dev/null +++ b/src/test/kotlin/no/nav/security/mock/oauth2/grant/AuthorizationCodeHandlerTest.kt @@ -0,0 +1,115 @@ +package no.nav.security.mock.oauth2.grant + +import com.nimbusds.jwt.SignedJWT +import com.nimbusds.oauth2.sdk.AuthorizationCode +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant +import com.nimbusds.oauth2.sdk.ResponseMode +import com.nimbusds.oauth2.sdk.Scope +import com.nimbusds.oauth2.sdk.TokenRequest +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic +import com.nimbusds.oauth2.sdk.auth.Secret +import com.nimbusds.oauth2.sdk.id.ClientID +import com.nimbusds.oauth2.sdk.id.State +import com.nimbusds.openid.connect.sdk.AuthenticationRequest +import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse +import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback +import no.nav.security.mock.oauth2.login.Login +import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrl +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import java.net.URI + +internal class AuthorizationCodeHandlerTest { + private val handler = AuthorizationCodeHandler() + + @Test + fun authorizationCodeResponse() { + val response: AuthenticationSuccessResponse = + handler.authorizationCodeResponse( + authRequest( + "client1", + "openid", + "code", + "http://redirect", + "someState", + "someNonce", + "query" + ) + ) + assertThatAuthResponseContainsRequiredParams(response) + } + + @Test + fun tokenResponse() { + } + + @Test + fun tokenResponseWithLogin() { + val response: AuthenticationSuccessResponse = + handler.authorizationCodeResponse( + authRequest( + "client1", + "openid", + "code", + "http://redirect", + "someState", + "someNonce", + "query" + ), + Login("foo") + ) + assertThatAuthResponseContainsRequiredParams(response) + + val tokenResponse = handler.tokenResponse( + tokenRequest(response.authorizationCode, "http://redirect", "openid"), + "http://myissuer".toHttpUrl(), + DefaultOAuth2TokenCallback() + ) + val idToken: SignedJWT = SignedJWT.parse(tokenResponse.idToken) + assertThat(idToken.jwtClaimsSet.audience.first()).isEqualTo("client1") + assertThat(idToken.jwtClaimsSet.subject).isEqualTo("foo") + } + + private fun assertThatAuthResponseContainsRequiredParams(response: AuthenticationSuccessResponse) { + assertThat(response.impliedResponseType().impliesCodeFlow()).isTrue() + assertThat(response.impliedResponseMode()).isEqualTo(ResponseMode.QUERY) + assertThat(response.state).isEqualTo(State("someState")) + assertThat(response.redirectionURI).isEqualTo(URI.create("http://redirect")) + } + + private fun authRequest( + clientId: String, + scope: String, + responseType: String, + redirectUri: String, + state: String, + nonce: String, + responseMode: String + ): AuthenticationRequest { + val url: HttpUrl = "http://localhost".toHttpUrl().newBuilder() + .addQueryParameter("client_id", clientId) + .addQueryParameter("scope", scope) + .addQueryParameter("response_type", responseType) + .addQueryParameter("redirect_uri", redirectUri) + .addQueryParameter("state", state) + .addQueryParameter("nonce", nonce) + .addQueryParameter("response_mode", responseMode) + .build() + return AuthenticationRequest.parse(url.toUri()) + } + + private fun tokenRequest( + code: AuthorizationCode, + redirectUri: String, + scope: String + ): TokenRequest { + return TokenRequest( + URI.create("http://localhost/token"), + ClientSecretBasic(ClientID("client1"), Secret("clientSecret")), + AuthorizationCodeGrant(code, + URI.create(redirectUri)), + Scope(scope) + ) + } +}