diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/CorsConfiguration.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/CorsConfiguration.kt index 03415bbc7..9a8b7a01e 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/CorsConfiguration.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/CorsConfiguration.kt @@ -1,7 +1,7 @@ package org.genspectrum.lapis -import org.genspectrum.lapis.openApi.REQUEST_ID_HEADER -import org.genspectrum.lapis.request.LAPIS_DATA_VERSION_HEADER +import org.genspectrum.lapis.controller.LapisHeaders.LAPIS_DATA_VERSION +import org.genspectrum.lapis.controller.LapisHeaders.REQUEST_ID import org.springframework.context.annotation.Configuration import org.springframework.http.HttpHeaders.RETRY_AFTER import org.springframework.web.servlet.config.annotation.CorsRegistry @@ -14,7 +14,7 @@ class CorsConfiguration : WebMvcConfigurer { .allowedOrigins("*") .allowedMethods("GET", "POST", "OPTIONS") .allowedHeaders("*") - .exposedHeaders(LAPIS_DATA_VERSION_HEADER, REQUEST_ID_HEADER, RETRY_AFTER) + .exposedHeaders(LAPIS_DATA_VERSION, REQUEST_ID, RETRY_AFTER) .maxAge(3600) } } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/CompressionFilter.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/CompressionFilter.kt index 725e2a3df..670c2ec32 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/CompressionFilter.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/CompressionFilter.kt @@ -9,17 +9,19 @@ import jakarta.servlet.http.HttpServletRequest import jakarta.servlet.http.HttpServletResponse import mu.KotlinLogging import org.genspectrum.lapis.util.CachedBodyHttpServletRequest -import org.genspectrum.lapis.util.HeaderModifyingRequestWrapper import org.springframework.boot.context.properties.bind.Binder import org.springframework.boot.web.servlet.server.Encoding import org.springframework.core.annotation.Order import org.springframework.core.env.Environment +import org.springframework.http.HttpHeaders import org.springframework.http.HttpHeaders.ACCEPT_ENCODING import org.springframework.http.HttpHeaders.CONTENT_ENCODING +import org.springframework.http.HttpHeaders.CONTENT_TYPE import org.springframework.http.HttpStatus import org.springframework.http.MediaType import org.springframework.http.ProblemDetail import org.springframework.http.converter.StringHttpMessageConverter +import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter import org.springframework.stereotype.Component import org.springframework.web.context.annotation.RequestScope import org.springframework.web.filter.OncePerRequestFilter @@ -30,9 +32,23 @@ import java.util.zip.GZIPOutputStream private val log = KotlinLogging.logger {} -enum class Compression(val value: String, val compressionOutputStreamFactory: (OutputStream) -> OutputStream) { - GZIP("gzip", ::LazyGzipOutputStream), - ZSTD("zstd", { ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() } }), +enum class Compression( + val value: String, + val contentType: MediaType, + val compressionOutputStreamFactory: (OutputStream) -> OutputStream, +) { + GZIP( + value = "gzip", + contentType = MediaType.parseMediaType("application/gzip"), + compressionOutputStreamFactory = ::LazyGzipOutputStream, + ), + ZSTD( + value = "zstd", + contentType = MediaType.parseMediaType("application/zstd"), + compressionOutputStreamFactory = { + ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() } + }, + ), ; companion object { @@ -78,7 +94,19 @@ fun ZstdOutputStream.commitUnderlyingResponseToPreventContentLengthFromBeingSet( @Component @RequestScope -class RequestCompression(var compression: Compression? = null) +class RequestCompression(var compressionSource: CompressionSource = CompressionSource.None) + +sealed interface CompressionSource { + data class RequestProperty(override var compression: Compression) : CompressionSource + + data class AcceptEncodingHeader(override var compression: Compression) : CompressionSource + + data object None : CompressionSource { + override val compression = null + } + + val compression: Compression? +} @Component @Order(COMPRESSION_FILTER_ORDER) @@ -91,8 +119,8 @@ class CompressionFilter(val objectMapper: ObjectMapper, val requestCompression: ) { val reReadableRequest = CachedBodyHttpServletRequest(request, objectMapper) - try { - validateCompressionProperty(reReadableRequest) + val compressionPropertyInRequest = try { + getValidatedCompressionProperty(reReadableRequest) } catch (e: UnknownCompressionFormatException) { response.status = HttpStatus.BAD_REQUEST.value() response.contentType = MediaType.APPLICATION_JSON_VALUE @@ -110,19 +138,14 @@ class CompressionFilter(val objectMapper: ObjectMapper, val requestCompression: return } - val requestWithContentEncoding = HeaderModifyingRequestWrapper( - reReadableRequest = reReadableRequest, - headerName = ACCEPT_ENCODING, - computeHeaderValueFromRequest = ::computeAcceptEncodingValueFromRequest, - ) - val maybeCompressingResponse = createMaybeCompressingResponse( response, - requestWithContentEncoding.getHeaders(ACCEPT_ENCODING), + reReadableRequest.getHeaders(ACCEPT_ENCODING), + compressionPropertyInRequest, ) filterChain.doFilter( - requestWithContentEncoding, + reReadableRequest, maybeCompressingResponse, ) @@ -130,47 +153,73 @@ class CompressionFilter(val objectMapper: ObjectMapper, val requestCompression: maybeCompressingResponse.outputStream.close() } - private fun validateCompressionProperty(reReadableRequest: CachedBodyHttpServletRequest) { - val compressionFormat = reReadableRequest.getStringField(COMPRESSION_PROPERTY) ?: return + private fun getValidatedCompressionProperty(reReadableRequest: CachedBodyHttpServletRequest): Compression? { + val compressionFormat = reReadableRequest.getStringField(COMPRESSION_PROPERTY) ?: return null - if (Compression.entries.toSet().none { it.value == compressionFormat }) { - throw UnknownCompressionFormatException(unknownFormatValue = compressionFormat) - } + return Compression.entries.toSet().find { it.value == compressionFormat } + ?: throw UnknownCompressionFormatException(unknownFormatValue = compressionFormat) } private fun createMaybeCompressingResponse( response: HttpServletResponse, acceptEncodingHeaders: Enumeration?, - ) = when (val compression = Compression.fromHeaders(acceptEncodingHeaders)) { - null -> response - else -> { - requestCompression.compression = compression - CompressingResponse(response, compression) + compressionPropertyInRequest: Compression?, + ): HttpServletResponse { + if (compressionPropertyInRequest != null) { + log.info { "Compressing using $compressionPropertyInRequest from request property" } + + requestCompression.compressionSource = CompressionSource.RequestProperty(compressionPropertyInRequest) + return CompressingResponse( + response, + compressionPropertyInRequest, + compressionPropertyInRequest.contentType.toString(), + ) } - } -} -private fun computeAcceptEncodingValueFromRequest(reReadableRequest: CachedBodyHttpServletRequest) = - when (reReadableRequest.getStringField(COMPRESSION_PROPERTY)) { - Compression.GZIP.value -> Compression.GZIP.value - Compression.ZSTD.value -> Compression.ZSTD.value - else -> null + val compression = Compression.fromHeaders(acceptEncodingHeaders) ?: return response + + log.info { "Compressing using $compression from $ACCEPT_ENCODING header" } + + requestCompression.compressionSource = CompressionSource.AcceptEncodingHeader(compression) + return CompressingResponse(response, compression, contentType = null) + .apply { + setHeader(CONTENT_ENCODING, compression.value) + } } +} private class UnknownCompressionFormatException(val unknownFormatValue: String) : Exception() class CompressingResponse( - response: HttpServletResponse, + private val response: HttpServletResponse, compression: Compression, + private val contentType: String?, ) : HttpServletResponse by response { init { - log.info { "Compressing using $compression" } - response.setHeader(CONTENT_ENCODING, compression.value) + if (contentType != null) { + response.setHeader(CONTENT_TYPE, contentType) + } } private val servletOutputStream = CompressingServletOutputStream(response.outputStream, compression) override fun getOutputStream() = servletOutputStream + + override fun getHeaders(name: String?): MutableCollection { + if (name == CONTENT_TYPE && contentType != null) { + return mutableListOf(contentType) + } + + return response.getHeaders(name) + } + + override fun getHeader(name: String): String? { + if (name == CONTENT_TYPE && contentType != null) { + return contentType + } + + return response.getHeader(name) + } } class CompressingServletOutputStream( @@ -198,16 +247,46 @@ class CompressingServletOutputStream( } } +@Component +class CompressionAwareMappingJackson2HttpMessageConverter( + objectMapper: ObjectMapper, + private val requestCompression: RequestCompression, +) : MappingJackson2HttpMessageConverter(objectMapper) { + override fun canWrite(mediaType: MediaType?): Boolean { + if (requestCompression.compressionSource.compression?.contentType?.isCompatibleWith(mediaType) == true) { + return true + } + + return super.canWrite(mediaType) + } + + override fun addDefaultHeaders( + headers: HttpHeaders, + value: Any, + contentType: MediaType?, + ) { + val compressionSource = requestCompression.compressionSource + if ( + compressionSource is CompressionSource.RequestProperty && + compressionSource.compression.contentType != contentType + ) { + headers.set(CONTENT_ENCODING, compressionSource.compression.value) + } + + super.addDefaultHeaders(headers, value, contentType) + } +} + @Component class StringHttpMessageConverterWithUnknownContentLengthInCaseOfCompression( environment: Environment, - val requestCompression: RequestCompression, + private val requestCompression: RequestCompression, ) : StringHttpMessageConverter(getCharsetFromEnvironment(environment)) { override fun getContentLength( str: String, contentType: MediaType?, ): Long? { - return when (requestCompression.compression) { + return when (requestCompression.compressionSource.compression) { null -> super.getContentLength(str, contentType) else -> null } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/ControllerDescriptions.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/ControllerDescriptions.kt index a08140b0c..c8831f895 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/ControllerDescriptions.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/ControllerDescriptions.kt @@ -1,5 +1,7 @@ package org.genspectrum.lapis.controller +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV_WITHOUT_HEADERS + const val DETAILS_ENDPOINT_DESCRIPTION = """Returns the specified metadata fields of sequences matching the filter.""" const val AGGREGATED_ENDPOINT_DESCRIPTION = """Returns the number of sequences matching the specified sequence filters.""" @@ -51,7 +53,7 @@ const val OFFSET_DESCRIPTION = This is useful for pagination in combination with \"limit\".""" const val FORMAT_DESCRIPTION = """The data format of the response. Alternatively, the data format can be specified by setting the \"Accept\"-header. -You can include the parameter to return the CSV/TSV without headers: "$TEXT_CSV_WITHOUT_HEADERS_HEADER". +You can include the parameter to return the CSV/TSV without headers: "$TEXT_CSV_WITHOUT_HEADERS". When both are specified, the request parameter takes precedence over the header.""" private const val MAYBE_DESCRIPTION = """ diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DataFormatParameterFilter.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DataFormatParameterFilter.kt index 97fbe565f..0a42f1585 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DataFormatParameterFilter.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DataFormatParameterFilter.kt @@ -14,10 +14,6 @@ import org.springframework.web.filter.OncePerRequestFilter const val HEADERS_ACCEPT_HEADER_PARAMETER = "headers" -const val TEXT_CSV_HEADER = "text/csv" -const val TEXT_CSV_WITHOUT_HEADERS_HEADER = "text/csv;$HEADERS_ACCEPT_HEADER_PARAMETER=false" -const val TEXT_TSV_HEADER = "text/tab-separated-values" - object DataFormat { const val JSON = "JSON" const val CSV = "CSV" @@ -47,9 +43,9 @@ class DataFormatParameterFilter(val objectMapper: ObjectMapper) : OncePerRequest private fun findAcceptHeaderOverwriteValue(reReadableRequest: CachedBodyHttpServletRequest) = when (reReadableRequest.getStringField(FORMAT_PROPERTY)?.uppercase()) { - DataFormat.CSV -> TEXT_CSV_HEADER - DataFormat.CSV_WITHOUT_HEADERS -> TEXT_CSV_WITHOUT_HEADERS_HEADER - DataFormat.TSV -> TEXT_TSV_HEADER + DataFormat.CSV -> LapisMediaType.TEXT_CSV + DataFormat.CSV_WITHOUT_HEADERS -> LapisMediaType.TEXT_CSV_WITHOUT_HEADERS + DataFormat.TSV -> LapisMediaType.TEXT_TSV DataFormat.JSON -> MediaType.APPLICATION_JSON_VALUE else -> null } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DownloadAsFileFilter.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DownloadAsFileFilter.kt index 9e887ba9c..00e36ed12 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DownloadAsFileFilter.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/DownloadAsFileFilter.kt @@ -4,6 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper import jakarta.servlet.FilterChain import jakarta.servlet.http.HttpServletRequest import jakarta.servlet.http.HttpServletResponse +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV import org.genspectrum.lapis.util.CachedBodyHttpServletRequest import org.springframework.core.annotation.Order import org.springframework.http.HttpHeaders.ACCEPT @@ -42,8 +44,8 @@ class DownloadAsFileFilter(private val objectMapper: ObjectMapper) : OncePerRequ } val fileEnding = when (request.getHeader(ACCEPT)) { - TEXT_CSV_HEADER -> "csv" - TEXT_TSV_HEADER -> "tsv" + TEXT_CSV -> "csv" + TEXT_TSV -> "tsv" else -> when (matchingRoute?.servesFasta) { true -> "fasta" else -> "json" diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/Headers.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/Headers.kt new file mode 100644 index 000000000..a21894230 --- /dev/null +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/Headers.kt @@ -0,0 +1,13 @@ +package org.genspectrum.lapis.controller + +object LapisHeaders { + const val REQUEST_ID = "X-Request-ID" + const val LAPIS_DATA_VERSION = "Lapis-Data-Version" +} + +object LapisMediaType { + const val TEXT_X_FASTA = "text/x-fasta" + const val TEXT_CSV = "text/csv" + const val TEXT_CSV_WITHOUT_HEADERS = "text/csv;$HEADERS_ACCEPT_HEADER_PARAMETER=false" + const val TEXT_TSV = "text/tab-separated-values" +} diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/LapisController.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/LapisController.kt index 7b31e663c..0292b8cd7 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/LapisController.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/LapisController.kt @@ -7,6 +7,9 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse import jakarta.servlet.http.HttpServletRequest import org.genspectrum.lapis.controller.Delimiter.COMMA import org.genspectrum.lapis.controller.Delimiter.TAB +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_X_FASTA import org.genspectrum.lapis.logging.RequestContext import org.genspectrum.lapis.model.SiloQueryModel import org.genspectrum.lapis.openApi.AGGREGATED_REQUEST_SCHEMA @@ -67,8 +70,6 @@ import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RequestParam import org.springframework.web.bind.annotation.RestController -const val TEXT_X_FASTA_HEADER = "text/x-fasta" - @RestController @RequestMapping("/sample") class LapisController( @@ -128,7 +129,7 @@ class LapisController( return LapisResponse(siloQueryModel.getAggregated(request)) } - @GetMapping(AGGREGATED_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(AGGREGATED_ROUTE, produces = [TEXT_CSV]) @Operation( description = AGGREGATED_ENDPOINT_DESCRIPTION, operationId = "getAggregatedAsCsv", @@ -182,7 +183,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getAggregated) } - @GetMapping(AGGREGATED_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(AGGREGATED_ROUTE, produces = [TEXT_TSV]) @Operation( description = AGGREGATED_ENDPOINT_DESCRIPTION, operationId = "getAggregatedAsTsv", @@ -251,7 +252,7 @@ class LapisController( return LapisResponse(siloQueryModel.getAggregated(request)) } - @PostMapping(AGGREGATED_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(AGGREGATED_ROUTE, produces = [TEXT_CSV]) @Operation( description = AGGREGATED_ENDPOINT_DESCRIPTION, operationId = "postAggregatedAsCsv", @@ -266,7 +267,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getAggregated) } - @PostMapping(AGGREGATED_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(AGGREGATED_ROUTE, produces = [TEXT_TSV]) @Operation( description = AGGREGATED_ENDPOINT_DESCRIPTION, operationId = "postAggregatedAsTsv", @@ -332,7 +333,7 @@ class LapisController( return LapisResponse(result) } - @GetMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = NUCLEOTIDE_MUTATION_ENDPOINT_DESCRIPTION, operationId = "getNucleotideMutationsAsCsv", @@ -387,7 +388,7 @@ class LapisController( ) } - @GetMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = NUCLEOTIDE_MUTATION_ENDPOINT_DESCRIPTION, operationId = "getNucleotideMutationsAsTsv", @@ -453,7 +454,7 @@ class LapisController( return LapisResponse(result) } - @PostMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = NUCLEOTIDE_MUTATION_ENDPOINT_DESCRIPTION, operationId = "postNucleotideMutationsAsCsv", @@ -473,7 +474,7 @@ class LapisController( ) } - @PostMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(NUCLEOTIDE_MUTATIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = NUCLEOTIDE_MUTATION_ENDPOINT_DESCRIPTION, operationId = "postNucleotideMutationsAsTsv", @@ -540,7 +541,7 @@ class LapisController( return LapisResponse(result) } - @GetMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = AMINO_ACID_MUTATIONS_ENDPOINT_DESCRIPTION, operationId = "getAminoAcidMutationsAsCsv", @@ -595,7 +596,7 @@ class LapisController( ) } - @GetMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = AMINO_ACID_MUTATIONS_ENDPOINT_DESCRIPTION, operationId = "getAminoAcidMutationsAsTsv", @@ -666,7 +667,7 @@ class LapisController( return LapisResponse(result) } - @PostMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = AMINO_ACID_MUTATIONS_ENDPOINT_DESCRIPTION, operationId = "postAminoAcidMutationsAsCsv", @@ -688,7 +689,7 @@ class LapisController( ) } - @PostMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(AMINO_ACID_MUTATIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = AMINO_ACID_MUTATIONS_ENDPOINT_DESCRIPTION, operationId = "postAminoAcidMutationsAsCsv", @@ -760,7 +761,7 @@ class LapisController( return LapisResponse(siloQueryModel.getDetails(request)) } - @GetMapping(DETAILS_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(DETAILS_ROUTE, produces = [TEXT_CSV]) @Operation( operationId = "getDetailsAsCsv", responses = [ApiResponse(responseCode = "200")], @@ -810,7 +811,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getDetails) } - @GetMapping(DETAILS_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(DETAILS_ROUTE, produces = [TEXT_TSV]) @Operation( description = DETAILS_ENDPOINT_DESCRIPTION, operationId = "getDetailsAsTsv", @@ -876,7 +877,7 @@ class LapisController( return LapisResponse(siloQueryModel.getDetails(request)) } - @PostMapping(DETAILS_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(DETAILS_ROUTE, produces = [TEXT_CSV]) @Operation( description = DETAILS_ENDPOINT_DESCRIPTION, operationId = "postDetailsAsCsv", @@ -891,7 +892,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getDetails) } - @PostMapping(DETAILS_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(DETAILS_ROUTE, produces = [TEXT_TSV]) @Operation( description = DETAILS_ENDPOINT_DESCRIPTION, operationId = "postDetailsAsTsv", @@ -954,7 +955,7 @@ class LapisController( return LapisResponse(result) } - @GetMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = NUCLEOTIDE_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "getNucleotideInsertionsAsCsv", @@ -1006,7 +1007,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getNucleotideInsertions) } - @GetMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = NUCLEOTIDE_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "getNucleotideInsertionsAsTsv", @@ -1074,7 +1075,7 @@ class LapisController( return LapisResponse(result) } - @PostMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = NUCLEOTIDE_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "postNucleotideInsertionsAsCsv", @@ -1091,7 +1092,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getNucleotideInsertions) } - @PostMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(NUCLEOTIDE_INSERTIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = NUCLEOTIDE_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "postNucleotideInsertionsAsTsv", @@ -1156,7 +1157,7 @@ class LapisController( return LapisResponse(result) } - @GetMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @GetMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = AMINO_ACID_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "getAminoAcidInsertionsAsCsv", @@ -1208,7 +1209,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getAminoAcidInsertions) } - @GetMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @GetMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = AMINO_ACID_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "getAminoAcidInsertionsAsTsv", @@ -1276,7 +1277,7 @@ class LapisController( return LapisResponse(result) } - @PostMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_CSV_HEADER]) + @PostMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_CSV]) @Operation( description = AMINO_ACID_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "postAminoAcidInsertionsAsCsv", @@ -1293,7 +1294,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, COMMA, siloQueryModel::getAminoAcidInsertions) } - @PostMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_TSV_HEADER]) + @PostMapping(AMINO_ACID_INSERTIONS_ROUTE, produces = [TEXT_TSV]) @Operation( description = AMINO_ACID_INSERTIONS_ENDPOINT_DESCRIPTION, operationId = "postAminoAcidInsertionsAsTsv", @@ -1310,7 +1311,7 @@ class LapisController( return getResponseAsCsv(request, httpHeaders.accept, TAB, siloQueryModel::getAminoAcidInsertions) } - @GetMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA_HEADER]) + @GetMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA]) @LapisAlignedAminoAcidSequenceResponse fun getAlignedAminoAcidSequence( @PathVariable(name = "gene", required = true) @@ -1357,7 +1358,7 @@ class LapisController( return siloQueryModel.getGenomicSequence(request, SequenceType.ALIGNED, gene) } - @PostMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA_HEADER]) + @PostMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA]) @LapisAlignedAminoAcidSequenceResponse fun postAlignedAminoAcidSequence( @PathVariable(name = "gene", required = true) @@ -1401,8 +1402,8 @@ class LapisController( ): String? { val targetMediaType = MediaType.valueOf( when (delimiter) { - COMMA -> TEXT_CSV_HEADER - TAB -> TEXT_TSV_HEADER + COMMA -> TEXT_CSV + TAB -> TEXT_TSV }, ) return acceptHeader.find { it.includes(targetMediaType) } diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/MultiSegmentedSequenceController.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/MultiSegmentedSequenceController.kt index 923604305..e5e18a1f0 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/MultiSegmentedSequenceController.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/MultiSegmentedSequenceController.kt @@ -3,6 +3,7 @@ package org.genspectrum.lapis.controller import io.swagger.v3.oas.annotations.Parameter import io.swagger.v3.oas.annotations.media.Schema import org.genspectrum.lapis.config.REFERENCE_GENOME_SEGMENTS_APPLICATION_ARG_PREFIX +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_X_FASTA import org.genspectrum.lapis.logging.RequestContext import org.genspectrum.lapis.model.SiloQueryModel import org.genspectrum.lapis.openApi.AminoAcidInsertions @@ -44,7 +45,7 @@ class MultiSegmentedSequenceController( private val siloQueryModel: SiloQueryModel, private val requestContext: RequestContext, ) { - @GetMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER]) + @GetMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA]) @LapisAlignedMultiSegmentedNucleotideSequenceResponse fun getAlignedNucleotideSequence( @PathVariable(name = "segment", required = true) @@ -95,7 +96,7 @@ class MultiSegmentedSequenceController( ) } - @PostMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER]) + @PostMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA]) @LapisAlignedMultiSegmentedNucleotideSequenceResponse fun postAlignedNucleotideSequence( @PathVariable(name = "segment", required = true) @@ -114,7 +115,7 @@ class MultiSegmentedSequenceController( ) } - @GetMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER]) + @GetMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA]) @LapisUnalignedMultiSegmentedNucleotideSequenceResponse fun getUnalignedNucleotideSequence( @PathVariable(name = "segment", required = true) @@ -165,7 +166,7 @@ class MultiSegmentedSequenceController( ) } - @PostMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER]) + @PostMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA]) @LapisUnalignedMultiSegmentedNucleotideSequenceResponse fun postUnalignedNucleotideSequence( @PathVariable(name = "segment", required = true) diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/SingleSegmentedSequenceController.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/SingleSegmentedSequenceController.kt index c564d6806..31215ca1e 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/SingleSegmentedSequenceController.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/controller/SingleSegmentedSequenceController.kt @@ -4,6 +4,7 @@ import io.swagger.v3.oas.annotations.Parameter import io.swagger.v3.oas.annotations.media.Schema import org.genspectrum.lapis.config.REFERENCE_GENOME_SEGMENTS_APPLICATION_ARG_PREFIX import org.genspectrum.lapis.config.ReferenceGenomeSchema +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_X_FASTA import org.genspectrum.lapis.logging.RequestContext import org.genspectrum.lapis.model.SiloQueryModel import org.genspectrum.lapis.openApi.AminoAcidInsertions @@ -44,7 +45,7 @@ class SingleSegmentedSequenceController( private val requestContext: RequestContext, private val referenceGenomeSchema: ReferenceGenomeSchema, ) { - @GetMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER]) + @GetMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA]) @LapisAlignedSingleSegmentedNucleotideSequenceResponse fun getAlignedNucleotideSequences( @PrimitiveFieldFilters @@ -92,7 +93,7 @@ class SingleSegmentedSequenceController( ) } - @PostMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER]) + @PostMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA]) @LapisAlignedSingleSegmentedNucleotideSequenceResponse fun postAlignedNucleotideSequence( @Parameter(schema = Schema(ref = "#/components/schemas/$NUCLEOTIDE_SEQUENCE_REQUEST_SCHEMA")) @@ -108,7 +109,7 @@ class SingleSegmentedSequenceController( ) } - @GetMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER]) + @GetMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA]) @LapisUnalignedSingleSegmentedNucleotideSequenceResponse fun getUnalignedNucleotideSequences( @PrimitiveFieldFilters @@ -156,7 +157,7 @@ class SingleSegmentedSequenceController( ) } - @PostMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER]) + @PostMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA]) @LapisUnalignedSingleSegmentedNucleotideSequenceResponse fun postUnalignedNucleotideSequence( @Parameter(schema = Schema(ref = "#/components/schemas/$NUCLEOTIDE_SEQUENCE_REQUEST_SCHEMA")) diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/logging/RequestId.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/logging/RequestId.kt index c2e619dc0..0a4f87aa1 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/logging/RequestId.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/logging/RequestId.kt @@ -3,7 +3,7 @@ package org.genspectrum.lapis.logging import jakarta.servlet.FilterChain import jakarta.servlet.http.HttpServletRequest import jakarta.servlet.http.HttpServletResponse -import org.genspectrum.lapis.openApi.REQUEST_ID_HEADER +import org.genspectrum.lapis.controller.LapisHeaders.REQUEST_ID import org.slf4j.MDC import org.springframework.core.annotation.Order import org.springframework.stereotype.Component @@ -28,11 +28,11 @@ class RequestIdFilter(private val requestIdContext: RequestIdContext) : OncePerR response: HttpServletResponse, filterChain: FilterChain, ) { - val requestId = request.getHeader(REQUEST_ID_HEADER) ?: UUID.randomUUID().toString() + val requestId = request.getHeader(REQUEST_ID) ?: UUID.randomUUID().toString() MDC.put(REQUEST_ID_MDC_KEY, requestId) requestIdContext.requestId = requestId - response.addHeader(REQUEST_ID_HEADER, requestId) + response.addHeader(REQUEST_ID, requestId) try { filterChain.doFilter(request, response) diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/openApi/Schemas.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/openApi/Schemas.kt index 71a454f71..435097ec5 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/openApi/Schemas.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/openApi/Schemas.kt @@ -21,12 +21,13 @@ import org.genspectrum.lapis.controller.DETAILS_FIELDS_DESCRIPTION import org.genspectrum.lapis.controller.DETAILS_ORDER_BY_FIELDS_DESCRIPTION import org.genspectrum.lapis.controller.FORMAT_DESCRIPTION import org.genspectrum.lapis.controller.LIMIT_DESCRIPTION +import org.genspectrum.lapis.controller.LapisHeaders.LAPIS_DATA_VERSION +import org.genspectrum.lapis.controller.LapisHeaders.REQUEST_ID import org.genspectrum.lapis.controller.NUCLEOTIDE_INSERTIONS_ENDPOINT_DESCRIPTION import org.genspectrum.lapis.controller.NUCLEOTIDE_MUTATION_ENDPOINT_DESCRIPTION import org.genspectrum.lapis.controller.OFFSET_DESCRIPTION import org.genspectrum.lapis.controller.UNALIGNED_MULTI_SEGMENTED_NUCLEOTIDE_SEQUENCE_ENDPOINT_DESCRIPTION import org.genspectrum.lapis.controller.UNALIGNED_SINGLE_SEGMENTED_NUCLEOTIDE_SEQUENCE_ENDPOINT_DESCRIPTION -import org.genspectrum.lapis.request.LAPIS_DATA_VERSION_HEADER import org.springframework.core.annotation.AliasFor import org.springframework.http.HttpHeaders.ACCEPT_ENCODING import org.springframework.http.HttpHeaders.CONTENT_DISPOSITION @@ -71,9 +72,8 @@ const val LAPIS_DATA_VERSION_DESCRIPTION = "The data version of data in SILO." const val LAPIS_DATA_VERSION_HEADER_DESCRIPTION = "$LAPIS_DATA_VERSION_DESCRIPTION " + "Same as the value returned in the info object in the response body." const val LAPIS_DATA_VERSION_RESPONSE_DESCRIPTION = "$LAPIS_DATA_VERSION_DESCRIPTION " + - "Same as the value returned in the info object in the header '$LAPIS_DATA_VERSION_HEADER'." + "Same as the value returned in the info object in the header '$LAPIS_DATA_VERSION'." -const val REQUEST_ID_HEADER = "X-Request-ID" const val REQUEST_ID_HEADER_DESCRIPTION = """ A UUID that uniquely identifies the request for tracing purposes. If none if provided in the request, LAPIS will generate one. @@ -103,12 +103,12 @@ the other also grants access to detailed data. description = "OK", headers = [ Header( - name = LAPIS_DATA_VERSION_HEADER, + name = LAPIS_DATA_VERSION, description = LAPIS_DATA_VERSION_HEADER_DESCRIPTION, schema = Schema(type = "string"), ), Header( - name = REQUEST_ID_HEADER, + name = REQUEST_ID, description = REQUEST_ID_HEADER_DESCRIPTION, schema = Schema(type = "string"), ), diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/request/LapisInfo.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/request/LapisInfo.kt index d28aca4c1..39e73eb44 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/request/LapisInfo.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/request/LapisInfo.kt @@ -2,6 +2,7 @@ package org.genspectrum.lapis.request import io.swagger.v3.oas.annotations.media.Schema import org.genspectrum.lapis.controller.LapisErrorResponse +import org.genspectrum.lapis.controller.LapisHeaders.LAPIS_DATA_VERSION import org.genspectrum.lapis.controller.LapisResponse import org.genspectrum.lapis.logging.RequestIdContext import org.genspectrum.lapis.openApi.LAPIS_DATA_VERSION_EXAMPLE @@ -28,8 +29,6 @@ data class LapisInfo( var requestId: String? = null, ) -const val LAPIS_DATA_VERSION_HEADER = "Lapis-Data-Version" - @ControllerAdvice class ResponseBodyAdviceDataVersion( private val dataVersion: DataVersion, @@ -43,7 +42,7 @@ class ResponseBodyAdviceDataVersion( request: ServerHttpRequest, response: ServerHttpResponse, ): Any? { - response.headers.add(LAPIS_DATA_VERSION_HEADER, dataVersion.dataVersion) + response.headers.add(LAPIS_DATA_VERSION, dataVersion.dataVersion) val isDownload = response.headers.getFirst(HttpHeaders.CONTENT_DISPOSITION)?.startsWith("attachment") ?: false diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCompressionTest.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCompressionTest.kt index 8ad369f85..f4be08301 100644 --- a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCompressionTest.kt +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCompressionTest.kt @@ -4,6 +4,9 @@ import com.github.luben.zstd.ZstdInputStream import com.jayway.jsonpath.JsonPath import com.ninjasquad.springmockk.MockkBean import io.mockk.every +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_X_FASTA import org.genspectrum.lapis.controller.SampleRoute.AGGREGATED import org.genspectrum.lapis.controller.SampleRoute.ALIGNED_AMINO_ACID_SEQUENCES import org.genspectrum.lapis.controller.SampleRoute.ALIGNED_NUCLEOTIDE_SEQUENCES @@ -74,7 +77,12 @@ class LapisControllerCompressionTest( .andExpect(status().isOk) .andExpect(content().contentType(requestsScenario.expectedContentType)) .andExpect(header().doesNotExist(CONTENT_LENGTH)) - .andExpect(header().string(CONTENT_ENCODING, requestsScenario.compressionFormat)) + .andExpect( + when (requestsScenario.expectedContentEncoding) { + null -> header().doesNotExist(CONTENT_ENCODING) + else -> header().string(CONTENT_ENCODING, requestsScenario.expectedContentEncoding) + }, + ) .andReturn() val compressionFormat = requestsScenario.compressionFormat @@ -94,6 +102,27 @@ class LapisControllerCompressionTest( val response = mockMvc.perform(getSample("${AGGREGATED.pathSegment}?compression=$compressionFormat")) .andExpect(status().isBadRequest) + .andExpect(content().contentType(APPLICATION_JSON)) + .andExpect(header().string(CONTENT_ENCODING, compressionFormat)) + .andReturn() + + val decompressedContent = decompressContent(response, compressionFormat) + + val errorDetail = JsonPath.read(decompressedContent, "$.error.detail") + assertThat(errorDetail, `is`(errorMessage)) + } + + @ParameterizedTest + @MethodSource("getCompressionFormats") + fun `GIVEN model throws bad request WHEN accepting compressed data THEN it should return compressed error`( + compressionFormat: String, + ) { + val errorMessage = "test message" + every { siloQueryModelMock.getAggregated(any()) } throws BadRequestException(errorMessage) + + val response = mockMvc.perform(getSample(AGGREGATED.pathSegment).header(ACCEPT_ENCODING, compressionFormat)) + .andExpect(status().isBadRequest) + .andExpect(content().contentType(APPLICATION_JSON)) .andExpect(header().string(CONTENT_ENCODING, compressionFormat)) .andReturn() @@ -138,26 +167,22 @@ class LapisControllerCompressionTest( endpoint = it, dataFormat = MockDataCollection.DataFormat.CSV, compressionFormat = "gzip", - expectedContentType = "$TEXT_CSV_HEADER;charset=UTF-8", ) + getRequests( endpoint = it, dataFormat = MockDataCollection.DataFormat.CSV, compressionFormat = "zstd", - expectedContentType = "$TEXT_CSV_HEADER;charset=UTF-8", ) } + getRequests( AGGREGATED, dataFormat = MockDataCollection.DataFormat.NESTED_JSON, compressionFormat = "gzip", - expectedContentType = APPLICATION_JSON_VALUE, ) + getRequests( AGGREGATED, dataFormat = MockDataCollection.DataFormat.TSV, compressionFormat = "zstd", - expectedContentType = "$TEXT_TSV_HEADER;charset=UTF-8", ) + listOf( "${UNALIGNED_NUCLEOTIDE_SEQUENCES.pathSegment}/main", @@ -177,6 +202,7 @@ data class RequestScenario( val request: MockHttpServletRequestBuilder, val compressionFormat: String, val expectedContentType: String, + val expectedContentEncoding: String?, ) { override fun toString() = "$callDescription returns $compressionFormat compressed data" } @@ -185,14 +211,12 @@ fun getRequests( endpoint: SampleRoute, dataFormat: MockDataCollection.DataFormat, compressionFormat: String, - expectedContentType: String, -) = getRequests(endpoint.pathSegment, dataFormat, compressionFormat, expectedContentType) +) = getRequests(endpoint.pathSegment, dataFormat, compressionFormat) fun getRequests( endpoint: String, dataFormat: MockDataCollection.DataFormat, compressionFormat: String, - expectedContentType: String, ) = listOf( RequestScenario( callDescription = "GET $endpoint as $dataFormat with request parameter", @@ -201,7 +225,8 @@ fun getRequests( "$endpoint?country=Switzerland&dataFormat=$dataFormat&compression=$compressionFormat", ), compressionFormat = compressionFormat, - expectedContentType = expectedContentType, + expectedContentType = getContentTypeForCompressionFormat(compressionFormat), + expectedContentEncoding = null, ), RequestScenario( callDescription = "GET $endpoint as $dataFormat with accept header", @@ -209,7 +234,8 @@ fun getRequests( request = getSample("$endpoint?country=Switzerland&dataFormat=$dataFormat") .header(ACCEPT_ENCODING, compressionFormat), compressionFormat = compressionFormat, - expectedContentType = expectedContentType, + expectedContentType = getContentTypeForDataFormat(dataFormat), + expectedContentEncoding = compressionFormat, ), RequestScenario( callDescription = "POST $endpoint as $dataFormat with request parameter", @@ -220,7 +246,8 @@ fun getRequests( ) .contentType(APPLICATION_JSON), compressionFormat = compressionFormat, - expectedContentType = expectedContentType, + expectedContentType = getContentTypeForCompressionFormat(compressionFormat), + expectedContentEncoding = null, ), RequestScenario( callDescription = "POST $endpoint as $dataFormat with accept header", @@ -230,7 +257,8 @@ fun getRequests( .contentType(APPLICATION_JSON) .header(ACCEPT_ENCODING, compressionFormat), compressionFormat = compressionFormat, - expectedContentType = expectedContentType, + expectedContentType = getContentTypeForDataFormat(dataFormat), + expectedContentEncoding = compressionFormat, ), ) @@ -243,7 +271,8 @@ private fun getFastaRequests( mockData = MockDataForEndpoints.fastaMockData, request = getSample("$endpoint?country=Switzerland&compression=$compressionFormat"), compressionFormat = compressionFormat, - expectedContentType = "$TEXT_X_FASTA_HEADER;charset=UTF-8", + expectedContentType = getContentTypeForCompressionFormat(compressionFormat), + expectedContentEncoding = null, ), RequestScenario( callDescription = "GET $endpoint with accept header", @@ -251,7 +280,8 @@ private fun getFastaRequests( request = getSample("$endpoint?country=Switzerland") .header(ACCEPT_ENCODING, compressionFormat), compressionFormat = compressionFormat, - expectedContentType = "$TEXT_X_FASTA_HEADER;charset=UTF-8", + expectedContentType = "$TEXT_X_FASTA;charset=UTF-8", + expectedContentEncoding = compressionFormat, ), RequestScenario( callDescription = "POST $endpoint with request parameter", @@ -260,7 +290,8 @@ private fun getFastaRequests( .content("""{"country": "Switzerland", "compression": "$compressionFormat"}""") .contentType(APPLICATION_JSON), compressionFormat = compressionFormat, - expectedContentType = "$TEXT_X_FASTA_HEADER;charset=UTF-8", + expectedContentType = getContentTypeForCompressionFormat(compressionFormat), + expectedContentEncoding = null, ), RequestScenario( callDescription = "POST $endpoint with accept header", @@ -270,6 +301,22 @@ private fun getFastaRequests( .contentType(APPLICATION_JSON) .header(ACCEPT_ENCODING, compressionFormat), compressionFormat = compressionFormat, - expectedContentType = "$TEXT_X_FASTA_HEADER;charset=UTF-8", + expectedContentType = "$TEXT_X_FASTA;charset=UTF-8", + expectedContentEncoding = compressionFormat, ), ) + +private fun getContentTypeForCompressionFormat(compressionFormat: String) = + when (compressionFormat) { + "gzip" -> "application/gzip" + "zstd" -> "application/zstd" + else -> throw Exception("Test issue: unknown compression format $compressionFormat") + } + +private fun getContentTypeForDataFormat(dataFormat: MockDataCollection.DataFormat) = + when (dataFormat) { + MockDataCollection.DataFormat.PLAIN_JSON -> APPLICATION_JSON_VALUE + MockDataCollection.DataFormat.NESTED_JSON -> APPLICATION_JSON_VALUE + MockDataCollection.DataFormat.CSV -> "$TEXT_CSV;charset=UTF-8" + MockDataCollection.DataFormat.TSV -> "$TEXT_TSV;charset=UTF-8" + } diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCsvTest.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCsvTest.kt index 5af3d7ffa..d721c8ed2 100644 --- a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCsvTest.kt +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/LapisControllerCsvTest.kt @@ -2,6 +2,9 @@ package org.genspectrum.lapis.controller import com.ninjasquad.springmockk.MockkBean import io.mockk.every +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV_WITHOUT_HEADERS +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV import org.genspectrum.lapis.model.SiloQueryModel import org.genspectrum.lapis.request.LapisInfo import org.junit.jupiter.api.BeforeEach @@ -169,9 +172,9 @@ class LapisControllerCsvTest( private fun getAcceptHeaderFor(dataFormat: String) = when (dataFormat) { - "csv" -> TEXT_CSV_HEADER - "csv-without-headers" -> TEXT_CSV_WITHOUT_HEADERS_HEADER - "tsv" -> TEXT_TSV_HEADER + "csv" -> TEXT_CSV + "csv-without-headers" -> TEXT_CSV_WITHOUT_HEADERS + "tsv" -> TEXT_TSV "json" -> MediaType.APPLICATION_JSON_VALUE else -> throw IllegalArgumentException("Unknown data format: $dataFormat") } diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/MockData.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/MockData.kt index 49fafd47a..7f6b917df 100644 --- a/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/MockData.kt +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/controller/MockData.kt @@ -6,6 +6,8 @@ import com.fasterxml.jackson.databind.node.NullNode import com.fasterxml.jackson.databind.node.TextNode import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import io.mockk.every +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV +import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV import org.genspectrum.lapis.model.SiloQueryModel import org.genspectrum.lapis.response.AggregationData import org.genspectrum.lapis.response.AminoAcidInsertionResponse @@ -27,8 +29,8 @@ data class MockDataCollection( enum class DataFormat(val fileFormat: String, val acceptHeader: String) { PLAIN_JSON("json", APPLICATION_JSON_VALUE), NESTED_JSON("json", APPLICATION_JSON_VALUE), - CSV("csv", TEXT_CSV_HEADER), - TSV("tsv", TEXT_TSV_HEADER), + CSV("csv", TEXT_CSV), + TSV("tsv", TEXT_TSV), } companion object { diff --git a/siloLapisTests/test/common.spec.ts b/siloLapisTests/test/common.spec.ts index cce048eac..34e03c6da 100644 --- a/siloLapisTests/test/common.spec.ts +++ b/siloLapisTests/test/common.spec.ts @@ -1,5 +1,5 @@ import { expect } from 'chai'; -import { basePath, expectIsZstdEncoded } from './common'; +import { basePath, expectIsGzipEncoded, expectIsZstdEncoded } from './common'; const routes = [ { pathSegment: '/aggregated', servesFasta: false, expectedDownloadFilename: 'aggregated.json' }, @@ -45,12 +45,12 @@ describe('All endpoints', () => { for (const route of routes) { const url = `${basePath}/sample${route.pathSegment}`; - function get(params?: URLSearchParams) { + function get(params?: URLSearchParams, requestInit?: RequestInit) { if (params === undefined) { - return fetch(url); + return fetch(url, requestInit); } - return fetch(url + '?' + params.toString()); + return fetch(url + '?' + params.toString(), requestInit); } describe(`(${route.pathSegment})`, () => { @@ -72,22 +72,54 @@ describe('All endpoints', () => { expect(response.headers.get('lapis-data-version')).to.match(/\d{10}/); }); - it('should return zstd compressed data', async () => { + it('should return zstd compressed data when asking for compression', async () => { const urlParams = new URLSearchParams({ compression: 'zstd' }); const response = await get(urlParams); expect(response.status).equals(200); + expect(response.headers.get('content-type')).equals('application/zstd'); + expect(response.headers.get('content-encoding')).does.not.exist; + expectIsZstdEncoded(await response.arrayBuffer()); + }); + + it('should return zstd compressed data when accepting compression in header', async () => { + const urlParams = new URLSearchParams(); + + const response = await get(urlParams, { headers: { 'Accept-Encoding': 'zstd' } }); + + expect(response.status).equals(200); + if (route.servesFasta) { + expect(response.headers.get('content-type')).equals('text/x-fasta;charset=UTF-8'); + } else { + expect(response.headers.get('content-type')).equals('application/json'); + } expect(response.headers.get('content-encoding')).equals('zstd'); expectIsZstdEncoded(await response.arrayBuffer()); }); - it('should return gzip compressed data', async () => { + it('should return gzip compressed data when asking for compression', async () => { const urlParams = new URLSearchParams({ compression: 'gzip' }); const response = await get(urlParams); expect(response.status).equals(200); + expect(response.headers.get('content-type')).equals('application/gzip'); + expect(response.headers.get('content-encoding')).does.not.exist; + expectIsGzipEncoded(await response.arrayBuffer()); + }); + + it('should return gzip compressed data when accepting compression in header', async () => { + const urlParams = new URLSearchParams(); + + const response = await get(urlParams, { headers: { 'Accept-Encoding': 'gzip' } }); + + expect(response.status).equals(200); + if (route.servesFasta) { + expect(response.headers.get('content-type')).equals('text/x-fasta;charset=UTF-8'); + } else { + expect(response.headers.get('content-type')).equals('application/json'); + } expect(response.headers.get('content-encoding')).equals('gzip'); if (route.servesFasta) { diff --git a/siloLapisTests/test/common.ts b/siloLapisTests/test/common.ts index 5d106ccda..34b3268e4 100644 --- a/siloLapisTests/test/common.ts +++ b/siloLapisTests/test/common.ts @@ -61,3 +61,9 @@ export function expectIsZstdEncoded(arrayBuffer: ArrayBuffer) { expect([...first4Bytes]).deep.equals([Number('0x28'), Number('0xb5'), Number('0x2f'), Number('0xfd')]); } + +export function expectIsGzipEncoded(arrayBuffer: ArrayBuffer) { + const first2Bytes = new Uint8Array(arrayBuffer).slice(0, 2); + + expect([...first2Bytes]).deep.equals([Number('0x1f'), Number('0x8b')]); +}