Skip to content

Commit

Permalink
refactor(oauth2httprequest): simplify logic for proxy aware url
Browse files Browse the repository at this point in the history
  • Loading branch information
tommytroen committed Dec 15, 2022
1 parent 9974521 commit d1c70a3
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 67 deletions.
104 changes: 37 additions & 67 deletions src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,9 @@ import no.nav.security.mock.oauth2.extensions.toJwksUrl
import no.nav.security.mock.oauth2.extensions.toTokenEndpointUrl
import no.nav.security.mock.oauth2.extensions.toUserInfoUrl
import no.nav.security.mock.oauth2.grant.TokenExchangeGrant
import no.nav.security.mock.oauth2.http.RequestType.AUTHORIZATION
import no.nav.security.mock.oauth2.http.RequestType.DEBUGGER
import no.nav.security.mock.oauth2.http.RequestType.DEBUGGER_CALLBACK
import no.nav.security.mock.oauth2.http.RequestType.END_SESSION
import no.nav.security.mock.oauth2.http.RequestType.FAVICON
import no.nav.security.mock.oauth2.http.RequestType.INTROSPECT
import no.nav.security.mock.oauth2.http.RequestType.JWKS
import no.nav.security.mock.oauth2.http.RequestType.PREFLIGHT
import no.nav.security.mock.oauth2.http.RequestType.TOKEN
import no.nav.security.mock.oauth2.http.RequestType.UNKNOWN
import no.nav.security.mock.oauth2.http.RequestType.USER_INFO
import no.nav.security.mock.oauth2.http.RequestType.WELL_KNOWN
import no.nav.security.mock.oauth2.missingParameter
import okhttp3.Headers
import okhttp3.HttpUrl
import java.net.URI

data class OAuth2HttpRequest(
val headers: Headers,
Expand Down Expand Up @@ -83,21 +70,6 @@ data class OAuth2HttpRequest(

fun asAuthenticationRequest(): AuthenticationRequest = AuthenticationRequest.parse(this.url.toUri())

fun type() = when {
url.isWellKnownUrl() -> WELL_KNOWN
url.isAuthorizationEndpointUrl() -> AUTHORIZATION
url.isTokenEndpointUrl() -> TOKEN
url.isEndSessionEndpointUrl() -> END_SESSION
url.isUserInfoUrl() -> USER_INFO
url.isIntrospectUrl() -> INTROSPECT
url.isJwksUrl() -> JWKS
url.isDebuggerUrl() -> DEBUGGER
url.isDebuggerCallbackUrl() -> DEBUGGER_CALLBACK
url.encodedPath == "/favicon.ico" -> FAVICON
method == "OPTIONS" -> PREFLIGHT
else -> UNKNOWN
}

fun grantType(): GrantType =
this.formParameters.map["grant_type"]
?.ifBlank { null }
Expand All @@ -115,51 +87,49 @@ data class OAuth2HttpRequest(
userInfoEndpoint = this.proxyAwareUrl().toUserInfoUrl().toString()
)

internal fun proxyAwareUrl(): HttpUrl {
val hostheader = this.headers["host"]
val proto = this.headers["x-forwarded-proto"]
val port = this.headers["x-forwarded-port"]
return if (hostheader != null && proto != null) {
val hostUri = URI(null, hostheader, null, null, null).parseServerAuthority()
val hostFromHostHeader = hostUri.host
val portFromHostHeader = hostUri.port

HttpUrl.Builder()
.scheme(proto)
.host(hostFromHostHeader)
.apply {
port?.toInt()?.let {
port(it)
} ?: run {
if (portFromHostHeader != -1) {
port(portFromHostHeader)
}
}
internal fun proxyAwareUrl(): HttpUrl = HttpUrl.Builder()
.scheme(resolveScheme())
.host(resolveHost())
.port(resolvePort())
.encodedPath(originalUrl.encodedPath)
.query(originalUrl.query)
.build()

private fun resolveScheme(): String = headers["x-forwarded-proto"] ?: originalUrl.scheme

private fun resolveHost() = parseHostHeader()?.first ?: originalUrl.host

private fun resolvePort(): Int {
val xForwardedProto = this.headers["x-forwarded-proto"]
val xForwardedPort = this.headers["x-forwarded-port"]?.toInt() ?: -1
val hostHeaderPort = parseHostHeader()?.second ?: -1
return when {
xForwardedPort != -1 -> xForwardedPort
hostHeaderPort != -1 -> hostHeaderPort
xForwardedProto != null -> {
if (xForwardedProto == "https") {
443
} else {
80
}
.encodedPath(originalUrl.encodedPath)
.query(originalUrl.query).build()
} else {
hostheader?.let {
val hostUri = URI(originalUrl.scheme, hostheader, null, null, null).parseServerAuthority()
HttpUrl.Builder()
.scheme(hostUri.scheme)
.host(hostUri.host)
.port(hostUri.port)
.encodedPath(originalUrl.encodedPath)
.query(originalUrl.query)
.build()
} ?: originalUrl
}

else -> originalUrl.port
}
}

private fun parseHostHeader(): Pair<String, Int>? {
val hostHeader = this.headers["host"]
if (hostHeader != null) {
val hostPort = hostHeader.split(":")
val port = if (hostPort.size == 2) hostPort[1].toInt() else -1
return hostPort[0] to port
}
return null
}

data class Parameters(val parameterString: String?) {
val map: Map<String, String> = parameterString?.keyValuesToMap("&") ?: emptyMap()
fun get(name: String): String? = map[name]
}
}

enum class RequestType {
WELL_KNOWN, AUTHORIZATION, TOKEN, END_SESSION,
JWKS, DEBUGGER, DEBUGGER_CALLBACK, FAVICON,
PREFLIGHT, UNKNOWN, USER_INFO, INTROSPECT
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,61 @@ import org.junit.jupiter.api.Test

internal class OAuth2HttpRequestTest {

@Test
fun `proxy aware urls all usecases`() {
// no hostheader
"http://localhost:8080/mypath?query=1".GET().url shouldBe "http://localhost:8080/mypath?query=1".toHttpUrl()

// no host header, x-forwarded-proto set
"http://localhost:8080/mypath?query=1".GET(
"x-forwarded-proto",
"https"
).url shouldBe "https://localhost/mypath?query=1".toHttpUrl()

// host header overrides host and port in url
"http://localhost:8080/mypath?query=1".GET(
"host",
"localhost:8080"
).url shouldBe "http://localhost:8080/mypath?query=1".toHttpUrl()

// host header overrides host in url, port from original url should be used
"http://localhost:8080/mypath?query=1".GET(
"host",
"hostonly"
).url shouldBe "http://hostonly:8080/mypath?query=1".toHttpUrl()

// host header overrides host in url, port from original url should be used
"http://localhost:8080/mypath?query=1".GET(
"host",
"hostonly:-1"
).url shouldBe "http://hostonly:8080/mypath?query=1".toHttpUrl()

// host header present, x-forwarded-port overrides port in url
"http://localhost:8080/mypath?query=1".GET(
"host",
"host:8080",
"x-forwarded-port",
"9090"
).url shouldBe "http://host:9090/mypath?query=1".toHttpUrl()

// host header and x-forwarded- headers present
"http://localhost:8080/mypath?query=1".GET(
"host",
"hostheader:8080",
"x-forwarded-port",
"9090",
"x-forwarded-proto",
"https"
).url shouldBe "https://hostheader:9090/mypath?query=1".toHttpUrl()
}

private fun String.GET(vararg headers: String) =
OAuth2HttpRequest(
originalUrl = this.toHttpUrl(),
headers = Headers.headersOf(*headers),
method = "GET"
)

@Test
fun `proxyAwareUrl should use host header and x-forwarded-for- `() {
val req1 = OAuth2HttpRequest(
Expand Down Expand Up @@ -69,6 +124,36 @@ internal class OAuth2HttpRequestTest {
originalUrl = "http://localhost:8080/mypath?query=1".toHttpUrl()
)
req5.proxyAwareUrl().toString() shouldBe "https://fakedings.nais.io/mypath?query=1"

val req6 = OAuth2HttpRequest(
headers = Headers.headersOf(
"host",
"oauth2"
),
method = "GET",
originalUrl = "http://localhost:8080/mypath?query=1".toHttpUrl()
)
req6.proxyAwareUrl().toString() shouldBe "http://oauth2:8080/mypath?query=1"

val req7 = OAuth2HttpRequest(
headers = Headers.headersOf(
"host",
"oauth2:8080"
),
method = "GET",
originalUrl = "http://localhost:8080/mypath?query=1".toHttpUrl()
)
req7.proxyAwareUrl().toString() shouldBe "http://oauth2:8080/mypath?query=1"

val req8 = OAuth2HttpRequest(
headers = Headers.headersOf(
"host",
"oauth2"
),
method = "GET",
originalUrl = "https://somehost/mypath?query=1".toHttpUrl()
)
req8.proxyAwareUrl().toString() shouldBe "https://oauth2/mypath?query=1"
}

@Test
Expand Down

0 comments on commit d1c70a3

Please sign in to comment.