Skip to content

Commit

Permalink
Fix broken security definition reference from security requirement fo…
Browse files Browse the repository at this point in the history
…r OAuth2 (ePages-de#220)

Fixes ePages-de#219
  • Loading branch information
Kieun authored Jan 4, 2023
1 parent dbd1788 commit e1bddea
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class RestdocsOpenApiTaskTest : RestdocsOpenApiTaskTestBase() {

override fun thenSecurityDefinitionsFoundInOutputFile() {
with(JsonPath.parse(outputFolder.resolve("$outputFileNamePrefix.$format").readText())) {
then(read<String>("securityDefinitions.oauth2_accessCode.scopes.prod:r")).isEqualTo("Some text")
then(read<String>("securityDefinitions.oauth2_accessCode.type")).isEqualTo("oauth2")
then(read<String>("securityDefinitions.oauth2_accessCode.tokenUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2_accessCode.authorizationUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2_accessCode.flow")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.scopes.prod:r")).isEqualTo("Some text")
then(read<String>("securityDefinitions.oauth2.type")).isEqualTo("oauth2")
then(read<String>("securityDefinitions.oauth2.tokenUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.authorizationUrl")).isNotEmpty()
then(read<String>("securityDefinitions.oauth2.flow")).isNotEmpty()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ open class Oauth2Configuration(
var flows: Array<String> = arrayOf(),
var scopes: Map<String, String> = mapOf()
) {
fun securitySchemeName(flow: String) = "oauth2_$flow"
fun securitySchemeName() = "oauth2"
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object OpenApi20Generator {

private const val API_KEY_SECURITY_NAME = "api_key"
private const val BASIC_SECURITY_NAME = "basic"
private const val OAUTH2_SECURITY_NAME = "oauth2"
private val PATH_PARAMETER_PATTERN = """\{([^/}]+)}""".toRegex()
internal fun generate(
resources: List<ResourceModel>,
Expand Down Expand Up @@ -323,14 +324,7 @@ object OpenApi20Generator {
val securityRequirements = firstModelForPathAndMethod.request.securityRequirements
if (securityRequirements != null) {
when (securityRequirements.type) {
SecurityType.OAUTH2 -> oauth2SecuritySchemeDefinition?.flows?.map {
addSecurity(
oauth2SecuritySchemeDefinition.securitySchemeName(it),
securityRequirements2ScopesList(
securityRequirements
)
)
}
SecurityType.OAUTH2 -> addSecurity(OAUTH2_SECURITY_NAME, securityRequirements2ScopesList(securityRequirements))
SecurityType.BASIC -> addSecurity(BASIC_SECURITY_NAME, null)
SecurityType.API_KEY -> addSecurity(API_KEY_SECURITY_NAME, null)
}
Expand Down Expand Up @@ -372,7 +366,7 @@ object OpenApi20Generator {
addScope(it, scopeAndDescriptions.getOrDefault(it, "No description"))
}
}
openApi.addSecurityDefinition(oauth2SecuritySchemeDefinition.securitySchemeName(flow), oauth2Definition)
openApi.addSecurityDefinition(oauth2SecuritySchemeDefinition.securitySchemeName(), oauth2Definition)
}
if (hasAnyOperationWithSecurityName(
openApi,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ class OpenApi20GeneratorTest {
val openapi = whenOpenApiObjectGenerated(api)

with(openapi.securityDefinitions) {
then(this.containsKey("oauth2_accessCode"))
then(this["oauth2_accessCode"])
then(this.containsKey("oauth2"))
then(this["oauth2"])
.isEqualToComparingFieldByField(
OAuth2Definition().accessCode("http://example.com/authorize", "http://example.com/token")
.apply { addScope("prod:r", "No description") }
Expand Down Expand Up @@ -356,12 +356,12 @@ class OpenApi20GeneratorTest {
then(productPath.get.operationId).isNotEmpty()
then(productPath.get.consumes).contains(successfulGetProductModel.request.contentType)

then(productPath.get.security).hasSize(2)
then(productPath.get.security).hasSize(1)

then(productPath.get.tags).containsOnly("tag1", "tag2")

val combined = productPath.get.security.reduce { map1, map2 -> map1 + map2 }
then(combined).containsOnlyKeys("oauth2_application", "oauth2_accessCode")
then(combined).containsOnlyKeys("oauth2")
then(combined.values).containsOnly(listOf("prod:r"))

then(successfulGetResponse).isNotNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ object OpenApi3Generator {
)
}
)
}.apply { addSecurityItemFromSecurityRequirements(firstModelForPathAndMethod.request.securityRequirements, oauth2SecuritySchemeDefinition) }
}.apply { addSecurityItemFromSecurityRequirements(firstModelForPathAndMethod.request.securityRequirements) }
}

private fun operationId(operationIds: List<String>): String {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ internal object SecuritySchemeGenerator {
private const val API_KEY_SECURITY_NAME = "api_key"
private const val BASIC_SECURITY_NAME = "basic"
private const val JWT_BEARER_SECURITY_NAME = "bearerAuthJWT"
private const val OAUTH2_SECURITY_NAME = "oauth2"

fun OpenAPI.addSecurityDefinitions(oauth2SecuritySchemeDefinition: Oauth2Configuration?) {
if (oauth2SecuritySchemeDefinition?.flows?.isNotEmpty() == true) {
val flows = OAuthFlows()
components.addSecuritySchemes(
"oauth2",
OAUTH2_SECURITY_NAME,
SecurityScheme().apply {
type = SecurityScheme.Type.OAUTH2
this.flows = flows
Expand Down Expand Up @@ -90,17 +91,10 @@ internal object SecuritySchemeGenerator {
}
}

fun Operation.addSecurityItemFromSecurityRequirements(securityRequirements: SecurityRequirements?, oauth2SecuritySchemeDefinition: Oauth2Configuration?) {
fun Operation.addSecurityItemFromSecurityRequirements(securityRequirements: SecurityRequirements?) {
if (securityRequirements != null) {
when (securityRequirements.type) {
SecurityType.OAUTH2 -> oauth2SecuritySchemeDefinition?.flows?.map {
addSecurityItem(
SecurityRequirement().addList(
oauth2SecuritySchemeDefinition.securitySchemeName(it),
securityRequirements2ScopesList(securityRequirements)
)
)
}
SecurityType.OAUTH2 -> addSecurityItem(SecurityRequirement().addList(OAUTH2_SECURITY_NAME, securityRequirements2ScopesList(securityRequirements)))
SecurityType.BASIC -> addSecurityItem(SecurityRequirement().addList(BASIC_SECURITY_NAME))
SecurityType.API_KEY -> addSecurityItem(SecurityRequirement().addList(API_KEY_SECURITY_NAME))
SecurityType.JWT_BEARER -> addSecurityItem(SecurityRequirement().addList(JWT_BEARER_SECURITY_NAME))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,7 @@ class OpenApi3GeneratorTest {
then(openApiJsonPathContext.read<Any>("$productGetByIdPath.responses.200.content.application/json.schema.\$ref")).isNotNull()
then(openApiJsonPathContext.read<Any>("$productGetByIdPath.responses.200.content.application/json.examples.test.value")).isNotNull()

then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2_clientCredentials").flatMap { it }).containsOnly("prod:r")
then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2_authorizationCode").flatMap { it }).containsOnly("prod:r")
then(openApiJsonPathContext.read<List<List<String>>>("$productGetByIdPath.security[*].oauth2").flatMap { it }).containsOnly("prod:r")
}

private fun thenMultiplePathParametersExist() {
Expand Down

0 comments on commit e1bddea

Please sign in to comment.