diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt index 34469327f..4da0888a1 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt @@ -1,17 +1,15 @@ package org.genspectrum.lapis -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory import com.fasterxml.jackson.module.kotlin.readValue -import com.fasterxml.jackson.module.kotlin.registerKotlinModule import mu.KotlinLogging -import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilter +import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilterFactory import org.genspectrum.lapis.config.DatabaseConfig import org.genspectrum.lapis.config.SequenceFilterFields import org.genspectrum.lapis.logging.RequestContext import org.genspectrum.lapis.logging.RequestContextLogger import org.genspectrum.lapis.logging.StatisticsLogObjectMapper import org.genspectrum.lapis.util.TimeFactory +import org.genspectrum.lapis.util.YamlObjectMapper import org.springframework.beans.factory.annotation.Value import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration @@ -24,8 +22,11 @@ class LapisSpringConfig { fun openAPI(sequenceFilterFields: SequenceFilterFields) = buildOpenApiSchema(sequenceFilterFields) @Bean - fun databaseConfig(@Value("\${lapis.databaseConfig.path}") configPath: String): DatabaseConfig { - return ObjectMapper(YAMLFactory()).registerKotlinModule().readValue(File(configPath)) + fun databaseConfig( + @Value("\${lapis.databaseConfig.path}") configPath: String, + yamlObjectMapper: YamlObjectMapper, + ): DatabaseConfig { + return yamlObjectMapper.objectMapper.readValue(File(configPath)) } @Bean @@ -55,6 +56,7 @@ class LapisSpringConfig { ) @Bean - fun dataOpennessAuthorizationFilter(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) = - DataOpennessAuthorizationFilter.createFromConfig(databaseConfig, objectMapper) + fun dataOpennessAuthorizationFilter( + dataOpennessAuthorizationFilterFactory: DataOpennessAuthorizationFilterFactory, + ) = dataOpennessAuthorizationFilterFactory.create() } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/auth/DataOpennessAuthorizationFilter.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/auth/DataOpennessAuthorizationFilter.kt index 8612bffc1..682e87876 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/auth/DataOpennessAuthorizationFilter.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/auth/DataOpennessAuthorizationFilter.kt @@ -1,24 +1,43 @@ package org.genspectrum.lapis.auth import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.kotlin.readValue import jakarta.servlet.FilterChain import jakarta.servlet.http.HttpServletRequest import jakarta.servlet.http.HttpServletResponse +import org.genspectrum.lapis.config.AccessKeys +import org.genspectrum.lapis.config.AccessKeysReader import org.genspectrum.lapis.config.DatabaseConfig import org.genspectrum.lapis.config.OpennessLevel import org.genspectrum.lapis.controller.LapisHttpErrorResponse import org.springframework.http.HttpStatus import org.springframework.http.MediaType +import org.springframework.stereotype.Component import org.springframework.web.filter.OncePerRequestFilter +import org.springframework.web.util.ContentCachingRequestWrapper -abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) : OncePerRequestFilter() { +@Component +class DataOpennessAuthorizationFilterFactory( + private val databaseConfig: DatabaseConfig, + private val objectMapper: ObjectMapper, + private val accessKeysReader: AccessKeysReader, +) { + fun create() = when (databaseConfig.schema.opennessLevel) { + OpennessLevel.OPEN -> AlwaysAuthorizedAuthorizationFilter(objectMapper) + OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(objectMapper, accessKeysReader.read()) + } +} + +abstract class DataOpennessAuthorizationFilter(protected val objectMapper: ObjectMapper) : OncePerRequestFilter() { override fun doFilterInternal( request: HttpServletRequest, response: HttpServletResponse, filterChain: FilterChain, ) { - when (val result = isAuthorizedForEndpoint(request)) { - AuthorizationResult.Success -> filterChain.doFilter(request, response) + val reReadableRequest = ContentCachingRequestWrapper(request) + + when (val result = isAuthorizedForEndpoint(reReadableRequest)) { + AuthorizationResult.Success -> filterChain.doFilter(reReadableRequest, response) is AuthorizationResult.Failure -> { response.status = HttpStatus.FORBIDDEN.value() response.contentType = MediaType.APPLICATION_JSON_VALUE @@ -34,15 +53,7 @@ abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) : } } - abstract fun isAuthorizedForEndpoint(request: HttpServletRequest): AuthorizationResult - - companion object { - fun createFromConfig(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) = - when (databaseConfig.schema.opennessLevel) { - OpennessLevel.OPEN -> NoOpAuthorizationFilter(objectMapper) - OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(objectMapper) - } - } + abstract fun isAuthorizedForEndpoint(request: ContentCachingRequestWrapper): AuthorizationResult } sealed interface AuthorizationResult { @@ -52,24 +63,40 @@ sealed interface AuthorizationResult { fun failure(message: String): AuthorizationResult = Failure(message) } - fun isSuccessful(): Boolean - - object Success : AuthorizationResult { - override fun isSuccessful() = true - } + object Success : AuthorizationResult - class Failure(val message: String) : AuthorizationResult { - override fun isSuccessful() = false - } + class Failure(val message: String) : AuthorizationResult } -private class NoOpAuthorizationFilter(objectMapper: ObjectMapper) : DataOpennessAuthorizationFilter(objectMapper) { - override fun isAuthorizedForEndpoint(request: HttpServletRequest) = AuthorizationResult.success() +private class AlwaysAuthorizedAuthorizationFilter(objectMapper: ObjectMapper) : + DataOpennessAuthorizationFilter(objectMapper) { + + override fun isAuthorizedForEndpoint(request: ContentCachingRequestWrapper) = AuthorizationResult.success() } -private class ProtectedGisaidDataAuthorizationFilter(objectMapper: ObjectMapper) : +private class ProtectedGisaidDataAuthorizationFilter(objectMapper: ObjectMapper, private val accessKeys: AccessKeys) : DataOpennessAuthorizationFilter(objectMapper) { - override fun isAuthorizedForEndpoint(request: HttpServletRequest) = - AuthorizationResult.failure("An access key is required to access this endpoint.") + override fun isAuthorizedForEndpoint(request: ContentCachingRequestWrapper): AuthorizationResult { + val accessKey = request.getParameter("accessKey") + ?: getAccessKeyFromBody(request) + ?: return AuthorizationResult.failure("An access key is required to access this endpoint.") + + // TODO validate access keys + + return AuthorizationResult.failure("You are not authorized to access this endpoint.") + } + + private fun getAccessKeyFromBody(request: ContentCachingRequestWrapper): String? { + return if (request.contentLength > 0) { + try { + objectMapper.readValue>(request.reader)["accessKey"] + } catch (exception: Exception) { + // TODO logging + return null + } + } else { + null + } + } } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/config/AccessKeys.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/config/AccessKeys.kt new file mode 100644 index 000000000..25eed585b --- /dev/null +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/config/AccessKeys.kt @@ -0,0 +1,23 @@ +package org.genspectrum.lapis.config + +import com.fasterxml.jackson.module.kotlin.readValue +import org.genspectrum.lapis.util.YamlObjectMapper +import org.springframework.beans.factory.annotation.Value +import org.springframework.stereotype.Component +import java.io.File + +@Component +class AccessKeysReader( + @Value("\${lapis.accessKeys.path:#{null}}") private val accessKeysFile: String?, + private val yamlObjectMapper: YamlObjectMapper, +) { + fun read(): AccessKeys { + if (accessKeysFile == null) { + throw IllegalArgumentException("Cannot read LAPIS access keys, lapis.accessKeys.path was not set.") + } + + return yamlObjectMapper.objectMapper.readValue(File(accessKeysFile)) + } +} + +data class AccessKeys(val fullAccessKey: String, val aggregatedDataAccessKey: String) diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/util/YamlObjectMapper.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/util/YamlObjectMapper.kt new file mode 100644 index 000000000..a7705f2b6 --- /dev/null +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/util/YamlObjectMapper.kt @@ -0,0 +1,11 @@ +package org.genspectrum.lapis.util + +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import com.fasterxml.jackson.module.kotlin.registerKotlinModule +import org.springframework.stereotype.Component + +@Component +object YamlObjectMapper { + val objectMapper: ObjectMapper = ObjectMapper(YAMLFactory()).registerKotlinModule() +} diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/auth/GisaidAuthorizationTest.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/auth/GisaidAuthorizationTest.kt index 68e0aefe0..fa9282a17 100644 --- a/lapis2/src/test/kotlin/org/genspectrum/lapis/auth/GisaidAuthorizationTest.kt +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/auth/GisaidAuthorizationTest.kt @@ -34,7 +34,7 @@ class GisaidAuthorizationTest(@Autowired val mockMvc: MockMvc) { } @Test - fun `given no access key in request to GISAID instance, then access is denied`() { + fun `given no access key in GET request to GISAID instance, then access is denied`() { mockMvc.perform(MockMvcRequestBuilders.get(validRoute)) .andExpect(MockMvcResultMatchers.status().isForbidden) .andExpect(MockMvcResultMatchers.content().contentType(MediaType.APPLICATION_JSON)) @@ -49,4 +49,55 @@ class GisaidAuthorizationTest(@Autowired val mockMvc: MockMvc) { ), ) } + + @Test + fun `given no access key in POST request to GISAID instance, then access is denied`() { + mockMvc.perform(MockMvcRequestBuilders.post(validRoute)) + .andExpect(MockMvcResultMatchers.status().isForbidden) + .andExpect(MockMvcResultMatchers.content().contentType(MediaType.APPLICATION_JSON)) + .andExpect( + MockMvcResultMatchers.content().json( + """ + { + "title": "Forbidden", + "message": "An access key is required to access this endpoint." + } + """, + ), + ) + } + + @Test + fun `given wrong access key in GET request to GISAID instance, then access is denied`() { + mockMvc.perform(MockMvcRequestBuilders.get("$validRoute?accessKey=invalidKey")) + .andExpect(MockMvcResultMatchers.status().isForbidden) + .andExpect(MockMvcResultMatchers.content().contentType(MediaType.APPLICATION_JSON)) + .andExpect( + MockMvcResultMatchers.content().json( + """ + { + "title": "Forbidden", + "message": "You are not authorized to access this endpoint." + } + """, + ), + ) + } + + @Test + fun `given wrong access key in POST request to GISAID instance, then access is denied`() { + mockMvc.perform(MockMvcRequestBuilders.post(validRoute).content("""{"accessKey": "invalidKey"}""")) + .andExpect(MockMvcResultMatchers.status().isForbidden) + .andExpect(MockMvcResultMatchers.content().contentType(MediaType.APPLICATION_JSON)) + .andExpect( + MockMvcResultMatchers.content().json( + """ + { + "title": "Forbidden", + "message": "You are not authorized to access this endpoint." + } + """, + ), + ) + } } diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/config/AccessKeysReaderTest.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/config/AccessKeysReaderTest.kt new file mode 100644 index 000000000..a19edac6b --- /dev/null +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/config/AccessKeysReaderTest.kt @@ -0,0 +1,36 @@ +package org.genspectrum.lapis.config + +import org.hamcrest.MatcherAssert.assertThat +import org.hamcrest.Matchers.equalTo +import org.hamcrest.Matchers.`is` +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.test.context.ActiveProfiles + +@SpringBootTest +class AccessKeysReaderTest { + @Autowired + lateinit var underTest: AccessKeysReader + + @Test + fun `given access keys file path as property then should successfully read access keys`() { + val result = underTest.read() + + assertThat(result.fullAccessKey, `is`(equalTo("testFullAccessKey"))) + assertThat(result.aggregatedDataAccessKey, `is`(equalTo("testAggregatedDataAccessKey"))) + } +} + +@SpringBootTest +@ActiveProfiles("testWithoutAccessKeys") +class AccessKeysReaderWithPathNotSetTest { + @Autowired + lateinit var underTest: AccessKeysReader + + @Test + fun `given access keys file path property is not set then should throw exception when reading access keys`() { + assertThrows { underTest.read() } + } +} diff --git a/lapis2/src/test/resources/application-test.properties b/lapis2/src/test/resources/application-test.properties index 7bb53e9ea..9665721ec 100644 --- a/lapis2/src/test/resources/application-test.properties +++ b/lapis2/src/test/resources/application-test.properties @@ -1,2 +1,3 @@ silo.url=http://url.to.silo lapis.databaseConfig.path=src/test/resources/config/testDatabaseConfig.yaml +lapis.accessKeys.path=src/test/resources/config/testAccessKeys.yaml diff --git a/lapis2/src/test/resources/application-testWithoutAccessKeys.properties b/lapis2/src/test/resources/application-testWithoutAccessKeys.properties new file mode 100644 index 000000000..fd626d23b --- /dev/null +++ b/lapis2/src/test/resources/application-testWithoutAccessKeys.properties @@ -0,0 +1,4 @@ +spring.config.import=file:src/main/resources/application.properties + +silo.url=http://url.to.silo +lapis.databaseConfig.path=src/test/resources/config/testDatabaseConfig.yaml diff --git a/lapis2/src/test/resources/config/testAccessKeys.yaml b/lapis2/src/test/resources/config/testAccessKeys.yaml new file mode 100644 index 000000000..505988abb --- /dev/null +++ b/lapis2/src/test/resources/config/testAccessKeys.yaml @@ -0,0 +1,2 @@ +fullAccessKey: testFullAccessKey +aggregatedDataAccessKey: testAggregatedDataAccessKey