Skip to content

Commit

Permalink
feat: set Content-Type header to application/gzip|zstd when the compr…
Browse files Browse the repository at this point in the history
…ession property in the request was set #665

That way we lose information about the actual content of the request, but Chrome downloads zstd encoded data instead of failing to display it.
  • Loading branch information
fengelniederhammer committed Feb 29, 2024
1 parent 901e08a commit 5592857
Show file tree
Hide file tree
Showing 17 changed files with 310 additions and 126 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.genspectrum.lapis

import org.genspectrum.lapis.openApi.REQUEST_ID_HEADER
import org.genspectrum.lapis.request.LAPIS_DATA_VERSION_HEADER
import org.genspectrum.lapis.controller.LapisHeaders.LAPIS_DATA_VERSION
import org.genspectrum.lapis.controller.LapisHeaders.REQUEST_ID
import org.springframework.context.annotation.Configuration
import org.springframework.http.HttpHeaders.RETRY_AFTER
import org.springframework.web.servlet.config.annotation.CorsRegistry
Expand All @@ -14,7 +14,7 @@ class CorsConfiguration : WebMvcConfigurer {
.allowedOrigins("*")
.allowedMethods("GET", "POST", "OPTIONS")
.allowedHeaders("*")
.exposedHeaders(LAPIS_DATA_VERSION_HEADER, REQUEST_ID_HEADER, RETRY_AFTER)
.exposedHeaders(LAPIS_DATA_VERSION, REQUEST_ID, RETRY_AFTER)
.maxAge(3600)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import mu.KotlinLogging
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.genspectrum.lapis.util.HeaderModifyingRequestWrapper
import org.springframework.boot.context.properties.bind.Binder
import org.springframework.boot.web.servlet.server.Encoding
import org.springframework.core.annotation.Order
import org.springframework.core.env.Environment
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpHeaders.ACCEPT_ENCODING
import org.springframework.http.HttpHeaders.CONTENT_ENCODING
import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.http.ProblemDetail
import org.springframework.http.converter.StringHttpMessageConverter
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter
import org.springframework.stereotype.Component
import org.springframework.web.context.annotation.RequestScope
import org.springframework.web.filter.OncePerRequestFilter
Expand All @@ -30,9 +32,23 @@ import java.util.zip.GZIPOutputStream

private val log = KotlinLogging.logger {}

enum class Compression(val value: String, val compressionOutputStreamFactory: (OutputStream) -> OutputStream) {
GZIP("gzip", ::LazyGzipOutputStream),
ZSTD("zstd", { ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() } }),
enum class Compression(
val value: String,
val contentType: MediaType,
val compressionOutputStreamFactory: (OutputStream) -> OutputStream,
) {
GZIP(
value = "gzip",
contentType = MediaType.parseMediaType("application/gzip"),
compressionOutputStreamFactory = ::LazyGzipOutputStream,
),
ZSTD(
value = "zstd",
contentType = MediaType.parseMediaType("application/zstd"),
compressionOutputStreamFactory = {
ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() }
},
),
;

companion object {
Expand Down Expand Up @@ -78,7 +94,19 @@ fun ZstdOutputStream.commitUnderlyingResponseToPreventContentLengthFromBeingSet(

@Component
@RequestScope
class RequestCompression(var compression: Compression? = null)
class RequestCompression(var compressionSource: CompressionSource = CompressionSource.None)

sealed interface CompressionSource {
data class RequestProperty(override var compression: Compression) : CompressionSource

data class AcceptEncodingHeader(override var compression: Compression) : CompressionSource

data object None : CompressionSource {
override val compression = null
}

val compression: Compression?
}

@Component
@Order(COMPRESSION_FILTER_ORDER)
Expand All @@ -91,8 +119,8 @@ class CompressionFilter(val objectMapper: ObjectMapper, val requestCompression:
) {
val reReadableRequest = CachedBodyHttpServletRequest(request, objectMapper)

try {
validateCompressionProperty(reReadableRequest)
val compressionPropertyInRequest = try {
getValidatedCompressionProperty(reReadableRequest)
} catch (e: UnknownCompressionFormatException) {
response.status = HttpStatus.BAD_REQUEST.value()
response.contentType = MediaType.APPLICATION_JSON_VALUE
Expand All @@ -110,67 +138,88 @@ class CompressionFilter(val objectMapper: ObjectMapper, val requestCompression:
return
}

val requestWithContentEncoding = HeaderModifyingRequestWrapper(
reReadableRequest = reReadableRequest,
headerName = ACCEPT_ENCODING,
computeHeaderValueFromRequest = ::computeAcceptEncodingValueFromRequest,
)

val maybeCompressingResponse = createMaybeCompressingResponse(
response,
requestWithContentEncoding.getHeaders(ACCEPT_ENCODING),
reReadableRequest.getHeaders(ACCEPT_ENCODING),
compressionPropertyInRequest,
)

filterChain.doFilter(
requestWithContentEncoding,
reReadableRequest,
maybeCompressingResponse,
)

maybeCompressingResponse.outputStream.flush()
maybeCompressingResponse.outputStream.close()
}

private fun validateCompressionProperty(reReadableRequest: CachedBodyHttpServletRequest) {
val compressionFormat = reReadableRequest.getStringField(COMPRESSION_PROPERTY) ?: return
private fun getValidatedCompressionProperty(reReadableRequest: CachedBodyHttpServletRequest): Compression? {
val compressionFormat = reReadableRequest.getStringField(COMPRESSION_PROPERTY) ?: return null

if (Compression.entries.toSet().none { it.value == compressionFormat }) {
throw UnknownCompressionFormatException(unknownFormatValue = compressionFormat)
}
return Compression.entries.toSet().find { it.value == compressionFormat }
?: throw UnknownCompressionFormatException(unknownFormatValue = compressionFormat)
}

private fun createMaybeCompressingResponse(
response: HttpServletResponse,
acceptEncodingHeaders: Enumeration<String>?,
) = when (val compression = Compression.fromHeaders(acceptEncodingHeaders)) {
null -> response
else -> {
requestCompression.compression = compression
CompressingResponse(response, compression)
compressionPropertyInRequest: Compression?,
): HttpServletResponse {
if (compressionPropertyInRequest != null) {
log.info { "Compressing using $compressionPropertyInRequest from request property" }

requestCompression.compressionSource = CompressionSource.RequestProperty(compressionPropertyInRequest)
return CompressingResponse(
response,
compressionPropertyInRequest,
compressionPropertyInRequest.contentType.toString(),
)
}
}
}

private fun computeAcceptEncodingValueFromRequest(reReadableRequest: CachedBodyHttpServletRequest) =
when (reReadableRequest.getStringField(COMPRESSION_PROPERTY)) {
Compression.GZIP.value -> Compression.GZIP.value
Compression.ZSTD.value -> Compression.ZSTD.value
else -> null
val compression = Compression.fromHeaders(acceptEncodingHeaders) ?: return response

log.info { "Compressing using $compression from $ACCEPT_ENCODING header" }

requestCompression.compressionSource = CompressionSource.AcceptEncodingHeader(compression)
return CompressingResponse(response, compression, contentType = null)
.apply {
setHeader(CONTENT_ENCODING, compression.value)
}
}
}

private class UnknownCompressionFormatException(val unknownFormatValue: String) : Exception()

class CompressingResponse(
response: HttpServletResponse,
private val response: HttpServletResponse,
compression: Compression,
private val contentType: String?,
) : HttpServletResponse by response {
init {
log.info { "Compressing using $compression" }
response.setHeader(CONTENT_ENCODING, compression.value)
if (contentType != null) {
response.setHeader(CONTENT_TYPE, contentType)
}
}

private val servletOutputStream = CompressingServletOutputStream(response.outputStream, compression)

override fun getOutputStream() = servletOutputStream

override fun getHeaders(name: String?): MutableCollection<String> {
if (name == CONTENT_TYPE && contentType != null) {
return mutableListOf(contentType)
}

return response.getHeaders(name)
}

override fun getHeader(name: String): String? {
if (name == CONTENT_TYPE && contentType != null) {
return contentType
}

return response.getHeader(name)
}
}

class CompressingServletOutputStream(
Expand Down Expand Up @@ -198,16 +247,46 @@ class CompressingServletOutputStream(
}
}

@Component
class CompressionAwareMappingJackson2HttpMessageConverter(
objectMapper: ObjectMapper,
private val requestCompression: RequestCompression,
) : MappingJackson2HttpMessageConverter(objectMapper) {
override fun canWrite(mediaType: MediaType?): Boolean {
if (requestCompression.compressionSource.compression?.contentType?.isCompatibleWith(mediaType) == true) {
return true
}

return super.canWrite(mediaType)
}

override fun addDefaultHeaders(
headers: HttpHeaders,
value: Any,
contentType: MediaType?,
) {
val compressionSource = requestCompression.compressionSource
if (
compressionSource is CompressionSource.RequestProperty &&
compressionSource.compression.contentType != contentType
) {
headers.set(CONTENT_ENCODING, compressionSource.compression.value)
}

super.addDefaultHeaders(headers, value, contentType)
}
}

@Component
class StringHttpMessageConverterWithUnknownContentLengthInCaseOfCompression(
environment: Environment,
val requestCompression: RequestCompression,
private val requestCompression: RequestCompression,
) : StringHttpMessageConverter(getCharsetFromEnvironment(environment)) {
override fun getContentLength(
str: String,
contentType: MediaType?,
): Long? {
return when (requestCompression.compression) {
return when (requestCompression.compressionSource.compression) {
null -> super.getContentLength(str, contentType)
else -> null
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.genspectrum.lapis.controller

import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV_WITHOUT_HEADERS

const val DETAILS_ENDPOINT_DESCRIPTION = """Returns the specified metadata fields of sequences matching the filter."""
const val AGGREGATED_ENDPOINT_DESCRIPTION =
"""Returns the number of sequences matching the specified sequence filters."""
Expand Down Expand Up @@ -51,7 +53,7 @@ const val OFFSET_DESCRIPTION =
This is useful for pagination in combination with \"limit\"."""
const val FORMAT_DESCRIPTION = """The data format of the response.
Alternatively, the data format can be specified by setting the \"Accept\"-header.
You can include the parameter to return the CSV/TSV without headers: "$TEXT_CSV_WITHOUT_HEADERS_HEADER".
You can include the parameter to return the CSV/TSV without headers: "$TEXT_CSV_WITHOUT_HEADERS".
When both are specified, the request parameter takes precedence over the header."""

private const val MAYBE_DESCRIPTION = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ import org.springframework.web.filter.OncePerRequestFilter

const val HEADERS_ACCEPT_HEADER_PARAMETER = "headers"

const val TEXT_CSV_HEADER = "text/csv"
const val TEXT_CSV_WITHOUT_HEADERS_HEADER = "text/csv;$HEADERS_ACCEPT_HEADER_PARAMETER=false"
const val TEXT_TSV_HEADER = "text/tab-separated-values"

object DataFormat {
const val JSON = "JSON"
const val CSV = "CSV"
Expand Down Expand Up @@ -47,9 +43,9 @@ class DataFormatParameterFilter(val objectMapper: ObjectMapper) : OncePerRequest

private fun findAcceptHeaderOverwriteValue(reReadableRequest: CachedBodyHttpServletRequest) =
when (reReadableRequest.getStringField(FORMAT_PROPERTY)?.uppercase()) {
DataFormat.CSV -> TEXT_CSV_HEADER
DataFormat.CSV_WITHOUT_HEADERS -> TEXT_CSV_WITHOUT_HEADERS_HEADER
DataFormat.TSV -> TEXT_TSV_HEADER
DataFormat.CSV -> LapisMediaType.TEXT_CSV
DataFormat.CSV_WITHOUT_HEADERS -> LapisMediaType.TEXT_CSV_WITHOUT_HEADERS
DataFormat.TSV -> LapisMediaType.TEXT_TSV
DataFormat.JSON -> MediaType.APPLICATION_JSON_VALUE
else -> null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.genspectrum.lapis.controller.LapisMediaType.TEXT_CSV
import org.genspectrum.lapis.controller.LapisMediaType.TEXT_TSV
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.springframework.core.annotation.Order
import org.springframework.http.HttpHeaders.ACCEPT
Expand Down Expand Up @@ -42,8 +44,8 @@ class DownloadAsFileFilter(private val objectMapper: ObjectMapper) : OncePerRequ
}

val fileEnding = when (request.getHeader(ACCEPT)) {
TEXT_CSV_HEADER -> "csv"
TEXT_TSV_HEADER -> "tsv"
TEXT_CSV -> "csv"
TEXT_TSV -> "tsv"
else -> when (matchingRoute?.servesFasta) {
true -> "fasta"
else -> "json"
Expand Down
13 changes: 13 additions & 0 deletions lapis2/src/main/kotlin/org/genspectrum/lapis/controller/Headers.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.genspectrum.lapis.controller

object LapisHeaders {
const val REQUEST_ID = "X-Request-ID"
const val LAPIS_DATA_VERSION = "Lapis-Data-Version"
}

object LapisMediaType {
const val TEXT_X_FASTA = "text/x-fasta"
const val TEXT_CSV = "text/csv"
const val TEXT_CSV_WITHOUT_HEADERS = "text/csv;$HEADERS_ACCEPT_HEADER_PARAMETER=false"
const val TEXT_TSV = "text/tab-separated-values"
}
Loading

0 comments on commit 5592857

Please sign in to comment.