Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add usage example tests #2

Merged
merged 4 commits into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import no.nav.security.mock.oauth2.extensions.toWellKnownUrl
import no.nav.security.mock.oauth2.http.OAuth2HttpRequestHandler
import no.nav.security.mock.oauth2.http.OAuth2HttpResponse
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.HttpUrl
import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse
Expand All @@ -31,11 +30,9 @@ import java.util.concurrent.LinkedBlockingQueue
private val log = KotlinLogging.logger {}

class MockOAuth2Server(
config: OAuth2Config = OAuth2Config()
val config: OAuth2Config = OAuth2Config()
) {
private val mockWebServer: MockWebServer = MockWebServer()
private val tokenProvider: OAuth2TokenProvider =
OAuth2TokenProvider()

var dispatcher: Dispatcher = MockOAuth2Dispatcher(config)

Expand Down Expand Up @@ -76,7 +73,7 @@ class MockOAuth2Server(
ClientSecretBasic(ClientID(clientId), Secret("secret")),
AuthorizationCodeGrant(AuthorizationCode("123"), URI.create("http://localhost"))
)
return tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback)
return config.tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback)
}
}

Expand All @@ -95,7 +92,6 @@ class MockOAuth2Dispatcher(
else -> mockResponse(httpRequestHandler.handleRequest(request.asOAuth2HttpRequest()))
}


private fun mockResponse(response: OAuth2HttpResponse): MockResponse =
MockResponse()
.setHeaders(response.headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ open class DefaultOAuth2TokenCallback(

override fun subject(tokenRequest: TokenRequest): String {
return when (GrantType.CLIENT_CREDENTIALS) {
tokenRequest.grantType() -> tokenRequest.clientID.value
tokenRequest.grantType() -> tokenRequest.clientIdAsString()
else -> subject
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MockOAuth2ServerTest {
}

@Test
fun enqueuedResponse(){
fun enqueuedResponse() {
assertWellKnownResponseForIssuer("default")
server.enqueueResponse(MockResponse()
.setResponseCode(200)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package no.nav.security.mock.oauth2.examples

import com.fasterxml.jackson.databind.ObjectMapper
import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSKeySelector
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jose.util.DefaultResourceRetriever
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.id.Issuer
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata
import mu.KotlinLogging
import okhttp3.HttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import okhttp3.mockwebserver.RecordedRequest
import java.net.URL
import java.util.HashSet

private val log = KotlinLogging.logger {}

abstract class AbstractExampleApp(oauth2DiscoveryUrl: String) {

val oauth2Client: OkHttpClient = OkHttpClient()
.newBuilder()
.followRedirects(false)
.build()

val metadata = OIDCProviderMetadata.parse(DefaultResourceRetriever().retrieveResource(URL(oauth2DiscoveryUrl)).content)

lateinit var exampleApp: MockWebServer

fun start() {
exampleApp = MockWebServer()
exampleApp.start()
exampleApp.dispatcher = object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
return runCatching {
handleRequest(request)
}.fold(
onSuccess = { result -> result },
onFailure = { error ->
log.error("received unhandled exception.", error)
MockResponse()
.setResponseCode(500)
.setBody("unhandled exception with message ${error.message}")
}
)
}
}
}

fun shutdown() {
exampleApp.shutdown()
}

fun url(path: String): HttpUrl = exampleApp.url(path)

fun retrieveJwks(): JWKSet {
return oauth2Client.newCall(
Request.Builder()
.url(metadata.jwkSetURI.toURL())
.get()
.build()
).execute().body?.string()?.let {
JWKSet.parse(it)
} ?: throw RuntimeException("could not retrieve jwks")
}

fun verifyJwt(jwt: String, issuer: Issuer, jwkSet: JWKSet): JWTClaimsSet {
val jwtProcessor: ConfigurableJWTProcessor<SecurityContext?> = DefaultJWTProcessor()
jwtProcessor.jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
val keySelector: JWSKeySelector<SecurityContext?> = JWSVerificationKeySelector(
JWSAlgorithm.RS256,
ImmutableJWKSet(jwkSet)
)
jwtProcessor.jwsKeySelector = keySelector
jwtProcessor.jwtClaimsSetVerifier = DefaultJWTClaimsVerifier(
JWTClaimsSet.Builder().issuer(issuer.toString()).build(),
HashSet(listOf("sub", "iat", "exp", "aud"))
)
return try {
jwtProcessor.process(jwt, null)
} catch (e: Exception) {
throw RuntimeException("invalid jwt.", e)
}
}

fun bearerToken(request: RecordedRequest): String? =
request.headers["Authorization"]
?.split("Bearer ")
?.let { it[0] }

fun notAuthorized(): MockResponse = MockResponse().setResponseCode(401)

fun json(value: Any): MockResponse = MockResponse()
.setResponseCode(200)
.setHeader("Content-Type","application/json")
.setBody(ObjectMapper().writeValueAsString(value))

abstract fun handleRequest(request: RecordedRequest): MockResponse
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package no.nav.security.mock.oauth2.examples.clientcredentials

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import no.nav.security.mock.oauth2.examples.AbstractExampleApp
import okhttp3.Credentials
import okhttp3.FormBody
import okhttp3.Request
import okhttp3.Response
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.RecordedRequest

class ExampleAppWithClientCredentialsClient(oauth2DiscoveryUrl: String) : AbstractExampleApp(oauth2DiscoveryUrl) {

override fun handleRequest(request: RecordedRequest): MockResponse {
return getClientCredentialsAccessToken()
?.let {
MockResponse()
.setResponseCode(200)
.setBody("token=$it")
}
?: MockResponse().setResponseCode(500).setBody("could not get access_token")
}

private fun getClientCredentialsAccessToken(): String? {
val tokenResponse: Response = oauth2Client.newCall(
Request.Builder()
.url(metadata.tokenEndpointURI.toURL())
.addHeader("Authorization", Credentials.basic("ExampleAppWithClientCredentialsClient", "test"))
.post(
FormBody.Builder()
.add("client_id", "ExampleAppWithClientCredentialsClient")
.add("scope", "scope1")
.add("grant_type", "client_credentials")
.build()
)
.build()
).execute()
return tokenResponse.body?.string()?.let {
ObjectMapper().readValue<JsonNode>(it).get("access_token")?.textValue()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package no.nav.security.mock.oauth2.examples.clientcredentials

import com.nimbusds.jwt.SignedJWT
import no.nav.security.mock.oauth2.MockOAuth2Server
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

internal class ExampleAppWithClientCredentialsClientTest {
private lateinit var client: OkHttpClient
private lateinit var oAuth2Server: MockOAuth2Server
private lateinit var exampleApp: ExampleAppWithClientCredentialsClient

private val ISSUER_ID = "test"

@BeforeEach
fun before() {
oAuth2Server = MockOAuth2Server()
oAuth2Server.start()
exampleApp = ExampleAppWithClientCredentialsClient(oAuth2Server.wellKnownUrl(ISSUER_ID).toString())
exampleApp.start()
client = OkHttpClient().newBuilder().build()
}

@AfterEach
fun shutdown() {
oAuth2Server.shutdown()
exampleApp.shutdown()
}

@Test
fun appShouldReturnClientCredentialsAccessTokenWhenInvoked() {
val response: Response = client.newCall(
Request.Builder()
.url(exampleApp.url("/clientcredentials"))
.get()
.build()
).execute()
assertThat(response.code).isEqualTo(200)

val token: SignedJWT? = response.body?.string()
?.split("token=")
?.let { it[1] }
?.let { SignedJWT.parse(it) }

assertThat(token).isNotNull
assertThat(token?.jwtClaimsSet?.subject).isEqualTo("ExampleAppWithClientCredentialsClient")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package no.nav.security.mock.oauth2.examples.openidconnect

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
import mu.KotlinLogging
import no.nav.security.mock.oauth2.examples.AbstractExampleApp
import okhttp3.FormBody
import okhttp3.Request
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.RecordedRequest

private val log = KotlinLogging.logger {}

class ExampleAppWithOpenIdConnect(oidcDiscoveryUrl: String) : AbstractExampleApp(oidcDiscoveryUrl) {

override fun handleRequest(request: RecordedRequest): MockResponse {
return when (request.requestUrl?.encodedPath) {
"/login" -> {
MockResponse()
.setResponseCode(302)
.setHeader("Location", authenticationRequest().toURI())
}
"/callback" -> {
log.debug("got callback: $request")
val code = request.requestUrl?.queryParameter("code")!!
val tokenResponse = oauth2Client.newCall(
Request.Builder()
.url(metadata.tokenEndpointURI.toURL())
.post(
FormBody.Builder()
.add("client_id", "client1")
.add("scope", authenticationRequest().scope.toString())
.add("code", code)
.add("redirect_uri", exampleApp.url("/callback").toString())
.add("grant_type", "authorization_code")
.build()
)
.build()
).execute()
val idToken: String = ObjectMapper().readValue<JsonNode>(tokenResponse.body!!.string()).get("id_token").textValue()
val idTokenClaims: JWTClaimsSet = verifyJwt(idToken, metadata.issuer, retrieveJwks())
MockResponse()
.setResponseCode(200)
.setHeader("Set-Cookie", "id_token=$idToken")
.setBody("logged in as ${idTokenClaims.subject}")
}
"/secured" -> {
getCookies(request)["id_token"]
?.let {
verifyJwt(it, metadata.issuer, retrieveJwks())
}?.let {
MockResponse()
.setResponseCode(200)
.setBody("welcome ${it.subject}")
} ?: MockResponse().setResponseCode(302).setHeader("Location", exampleApp.url("/login"))
}
else -> MockResponse().setResponseCode(404)
}
}

private fun getCookies(request: RecordedRequest): Map<String, String> {
return request.getHeader("Cookie")
?.split(";")
?.filter { it.contains("=") }
?.associate {
val (key, value) = it.split("=")
key.trim() to value.trim()
} ?: emptyMap()
}

private fun authenticationRequest(): AuthenticationRequest =
AuthenticationRequest.parse(
metadata.authorizationEndpointURI,
mutableMapOf(
"client_id" to listOf("client"),
"response_type" to listOf("code"),
"redirect_uri" to listOf(exampleApp.url("/callback").toString()),
"response_mode" to listOf("query"),
"scope" to listOf("openid", "scope1"),
"state" to listOf("1234"),
"nonce" to listOf("5678")
)
)
}
Loading