Skip to content

Commit

Permalink
Add usage example tests (#2)
Browse files Browse the repository at this point in the history
* issueToken method is using its own TokenProvider, use the same as in config
* use safe call to clientIdAsString for clientid as subject
* add examples (apps and tests), some linting
  • Loading branch information
tommytroen authored Mar 24, 2020
1 parent fb191cc commit c2f7a28
Show file tree
Hide file tree
Showing 10 changed files with 465 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import no.nav.security.mock.oauth2.extensions.toWellKnownUrl
import no.nav.security.mock.oauth2.http.OAuth2HttpRequestHandler
import no.nav.security.mock.oauth2.http.OAuth2HttpResponse
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.HttpUrl
import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse
Expand All @@ -31,11 +30,9 @@ import java.util.concurrent.LinkedBlockingQueue
private val log = KotlinLogging.logger {}

class MockOAuth2Server(
config: OAuth2Config = OAuth2Config()
val config: OAuth2Config = OAuth2Config()
) {
private val mockWebServer: MockWebServer = MockWebServer()
private val tokenProvider: OAuth2TokenProvider =
OAuth2TokenProvider()

var dispatcher: Dispatcher = MockOAuth2Dispatcher(config)

Expand Down Expand Up @@ -76,7 +73,7 @@ class MockOAuth2Server(
ClientSecretBasic(ClientID(clientId), Secret("secret")),
AuthorizationCodeGrant(AuthorizationCode("123"), URI.create("http://localhost"))
)
return tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback)
return config.tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback)
}
}

Expand All @@ -95,7 +92,6 @@ class MockOAuth2Dispatcher(
else -> mockResponse(httpRequestHandler.handleRequest(request.asOAuth2HttpRequest()))
}


private fun mockResponse(response: OAuth2HttpResponse): MockResponse =
MockResponse()
.setHeaders(response.headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ open class DefaultOAuth2TokenCallback(

override fun subject(tokenRequest: TokenRequest): String {
return when (GrantType.CLIENT_CREDENTIALS) {
tokenRequest.grantType() -> tokenRequest.clientID.value
tokenRequest.grantType() -> tokenRequest.clientIdAsString()
else -> subject
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MockOAuth2ServerTest {
}

@Test
fun enqueuedResponse(){
fun enqueuedResponse() {
assertWellKnownResponseForIssuer("default")
server.enqueueResponse(MockResponse()
.setResponseCode(200)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package no.nav.security.mock.oauth2.examples

import com.fasterxml.jackson.databind.ObjectMapper
import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.source.ImmutableJWKSet
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSKeySelector
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jose.util.DefaultResourceRetriever
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.id.Issuer
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata
import mu.KotlinLogging
import okhttp3.HttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.mockwebserver.Dispatcher
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import okhttp3.mockwebserver.RecordedRequest
import java.net.URL
import java.util.HashSet

private val log = KotlinLogging.logger {}

abstract class AbstractExampleApp(oauth2DiscoveryUrl: String) {

val oauth2Client: OkHttpClient = OkHttpClient()
.newBuilder()
.followRedirects(false)
.build()

val metadata = OIDCProviderMetadata.parse(DefaultResourceRetriever().retrieveResource(URL(oauth2DiscoveryUrl)).content)

lateinit var exampleApp: MockWebServer

fun start() {
exampleApp = MockWebServer()
exampleApp.start()
exampleApp.dispatcher = object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
return runCatching {
handleRequest(request)
}.fold(
onSuccess = { result -> result },
onFailure = { error ->
log.error("received unhandled exception.", error)
MockResponse()
.setResponseCode(500)
.setBody("unhandled exception with message ${error.message}")
}
)
}
}
}

fun shutdown() {
exampleApp.shutdown()
}

fun url(path: String): HttpUrl = exampleApp.url(path)

fun retrieveJwks(): JWKSet {
return oauth2Client.newCall(
Request.Builder()
.url(metadata.jwkSetURI.toURL())
.get()
.build()
).execute().body?.string()?.let {
JWKSet.parse(it)
} ?: throw RuntimeException("could not retrieve jwks")
}

fun verifyJwt(jwt: String, issuer: Issuer, jwkSet: JWKSet): JWTClaimsSet {
val jwtProcessor: ConfigurableJWTProcessor<SecurityContext?> = DefaultJWTProcessor()
jwtProcessor.jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
val keySelector: JWSKeySelector<SecurityContext?> = JWSVerificationKeySelector(
JWSAlgorithm.RS256,
ImmutableJWKSet(jwkSet)
)
jwtProcessor.jwsKeySelector = keySelector
jwtProcessor.jwtClaimsSetVerifier = DefaultJWTClaimsVerifier(
JWTClaimsSet.Builder().issuer(issuer.toString()).build(),
HashSet(listOf("sub", "iat", "exp", "aud"))
)
return try {
jwtProcessor.process(jwt, null)
} catch (e: Exception) {
throw RuntimeException("invalid jwt.", e)
}
}

fun bearerToken(request: RecordedRequest): String? =
request.headers["Authorization"]
?.split("Bearer ")
?.let { it[0] }

fun notAuthorized(): MockResponse = MockResponse().setResponseCode(401)

fun json(value: Any): MockResponse = MockResponse()
.setResponseCode(200)
.setHeader("Content-Type","application/json")
.setBody(ObjectMapper().writeValueAsString(value))

abstract fun handleRequest(request: RecordedRequest): MockResponse
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package no.nav.security.mock.oauth2.examples.clientcredentials

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import no.nav.security.mock.oauth2.examples.AbstractExampleApp
import okhttp3.Credentials
import okhttp3.FormBody
import okhttp3.Request
import okhttp3.Response
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.RecordedRequest

class ExampleAppWithClientCredentialsClient(oauth2DiscoveryUrl: String) : AbstractExampleApp(oauth2DiscoveryUrl) {

override fun handleRequest(request: RecordedRequest): MockResponse {
return getClientCredentialsAccessToken()
?.let {
MockResponse()
.setResponseCode(200)
.setBody("token=$it")
}
?: MockResponse().setResponseCode(500).setBody("could not get access_token")
}

private fun getClientCredentialsAccessToken(): String? {
val tokenResponse: Response = oauth2Client.newCall(
Request.Builder()
.url(metadata.tokenEndpointURI.toURL())
.addHeader("Authorization", Credentials.basic("ExampleAppWithClientCredentialsClient", "test"))
.post(
FormBody.Builder()
.add("client_id", "ExampleAppWithClientCredentialsClient")
.add("scope", "scope1")
.add("grant_type", "client_credentials")
.build()
)
.build()
).execute()
return tokenResponse.body?.string()?.let {
ObjectMapper().readValue<JsonNode>(it).get("access_token")?.textValue()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package no.nav.security.mock.oauth2.examples.clientcredentials

import com.nimbusds.jwt.SignedJWT
import no.nav.security.mock.oauth2.MockOAuth2Server
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

internal class ExampleAppWithClientCredentialsClientTest {
private lateinit var client: OkHttpClient
private lateinit var oAuth2Server: MockOAuth2Server
private lateinit var exampleApp: ExampleAppWithClientCredentialsClient

private val ISSUER_ID = "test"

@BeforeEach
fun before() {
oAuth2Server = MockOAuth2Server()
oAuth2Server.start()
exampleApp = ExampleAppWithClientCredentialsClient(oAuth2Server.wellKnownUrl(ISSUER_ID).toString())
exampleApp.start()
client = OkHttpClient().newBuilder().build()
}

@AfterEach
fun shutdown() {
oAuth2Server.shutdown()
exampleApp.shutdown()
}

@Test
fun appShouldReturnClientCredentialsAccessTokenWhenInvoked() {
val response: Response = client.newCall(
Request.Builder()
.url(exampleApp.url("/clientcredentials"))
.get()
.build()
).execute()
assertThat(response.code).isEqualTo(200)

val token: SignedJWT? = response.body?.string()
?.split("token=")
?.let { it[1] }
?.let { SignedJWT.parse(it) }

assertThat(token).isNotNull
assertThat(token?.jwtClaimsSet?.subject).isEqualTo("ExampleAppWithClientCredentialsClient")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package no.nav.security.mock.oauth2.examples.openidconnect

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
import mu.KotlinLogging
import no.nav.security.mock.oauth2.examples.AbstractExampleApp
import okhttp3.FormBody
import okhttp3.Request
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.RecordedRequest

private val log = KotlinLogging.logger {}

class ExampleAppWithOpenIdConnect(oidcDiscoveryUrl: String) : AbstractExampleApp(oidcDiscoveryUrl) {

override fun handleRequest(request: RecordedRequest): MockResponse {
return when (request.requestUrl?.encodedPath) {
"/login" -> {
MockResponse()
.setResponseCode(302)
.setHeader("Location", authenticationRequest().toURI())
}
"/callback" -> {
log.debug("got callback: $request")
val code = request.requestUrl?.queryParameter("code")!!
val tokenResponse = oauth2Client.newCall(
Request.Builder()
.url(metadata.tokenEndpointURI.toURL())
.post(
FormBody.Builder()
.add("client_id", "client1")
.add("scope", authenticationRequest().scope.toString())
.add("code", code)
.add("redirect_uri", exampleApp.url("/callback").toString())
.add("grant_type", "authorization_code")
.build()
)
.build()
).execute()
val idToken: String = ObjectMapper().readValue<JsonNode>(tokenResponse.body!!.string()).get("id_token").textValue()
val idTokenClaims: JWTClaimsSet = verifyJwt(idToken, metadata.issuer, retrieveJwks())
MockResponse()
.setResponseCode(200)
.setHeader("Set-Cookie", "id_token=$idToken")
.setBody("logged in as ${idTokenClaims.subject}")
}
"/secured" -> {
getCookies(request)["id_token"]
?.let {
verifyJwt(it, metadata.issuer, retrieveJwks())
}?.let {
MockResponse()
.setResponseCode(200)
.setBody("welcome ${it.subject}")
} ?: MockResponse().setResponseCode(302).setHeader("Location", exampleApp.url("/login"))
}
else -> MockResponse().setResponseCode(404)
}
}

private fun getCookies(request: RecordedRequest): Map<String, String> {
return request.getHeader("Cookie")
?.split(";")
?.filter { it.contains("=") }
?.associate {
val (key, value) = it.split("=")
key.trim() to value.trim()
} ?: emptyMap()
}

private fun authenticationRequest(): AuthenticationRequest =
AuthenticationRequest.parse(
metadata.authorizationEndpointURI,
mutableMapOf(
"client_id" to listOf("client"),
"response_type" to listOf("code"),
"redirect_uri" to listOf(exampleApp.url("/callback").toString()),
"response_mode" to listOf("query"),
"scope" to listOf("openid", "scope1"),
"state" to listOf("1234"),
"nonce" to listOf("5678")
)
)
}
Loading

0 comments on commit c2f7a28

Please sign in to comment.