Skip to content

Commit

Permalink
feat: implement compressing the response as zstd and gzip #600
Browse files Browse the repository at this point in the history
  • Loading branch information
fengelniederhammer committed Feb 14, 2024
1 parent 11dc214 commit d6e8d9d
Show file tree
Hide file tree
Showing 20 changed files with 1,002 additions and 394 deletions.
1 change: 1 addition & 0 deletions lapis2/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies {
antlr 'org.antlr:antlr4:4.13.1'
implementation 'org.antlr:antlr4-runtime:4.13.1'
implementation 'org.apache.commons:commons-csv:1.10.0'
implementation 'com.github.luben:zstd-jni:1.5.0-4'

testImplementation('org.springframework.boot:spring-boot-starter-test') {
exclude group: "org.mockito"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package org.genspectrum.lapis.controller

import com.fasterxml.jackson.databind.ObjectMapper
import com.github.luben.zstd.ZstdOutputStream
import jakarta.servlet.FilterChain
import jakarta.servlet.ServletOutputStream
import jakarta.servlet.WriteListener
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import mu.KotlinLogging
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.genspectrum.lapis.util.HeaderModifyingRequestWrapper
import org.springframework.core.annotation.Order
import org.springframework.http.HttpHeaders.ACCEPT_ENCODING
import org.springframework.http.HttpHeaders.CONTENT_ENCODING
import org.springframework.http.HttpHeaders.TRANSFER_ENCODING
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter
import java.io.OutputStream
import java.util.Enumeration
import java.util.zip.GZIPOutputStream

private val log = KotlinLogging.logger {}

enum class Compression(val value: String, val compressionOutputStreamFactory: (OutputStream) -> OutputStream) {
GZIP("gzip", ::GZIPOutputStream),
ZSTD("zstd", { ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() } }),
;

companion object {
fun fromHeaders(acceptEncodingHeaders: Enumeration<String>?): Compression? {
if (acceptEncodingHeaders == null) {
return null
}

val headersList = acceptEncodingHeaders.toList()

return when {
headersList.contains(GZIP.value) -> GZIP
headersList.contains(ZSTD.value) -> ZSTD
else -> null
}
}
}
}

// https://github.com/apache/tomcat/blob/10e3731f344cd0d018d4be2ee767c105d2832283/java/org/apache/catalina/connector/OutputBuffer.java#L223-L229
fun ZstdOutputStream.commitUnderlyingResponseToPreventContentLengthFromBeingSet() {
val nothing = ByteArray(0)
write(nothing)
}

@Component
@Order(DOWNLOAD_AS_FILE_FILTER_ORDER - 1)
class CompressionFilter(val objectMapper: ObjectMapper) : OncePerRequestFilter() {
override fun doFilterInternal(
request: HttpServletRequest,
response: HttpServletResponse,
filterChain: FilterChain,
) {
val reReadableRequest = CachedBodyHttpServletRequest(request, objectMapper)

val requestWithContentEncoding = HeaderModifyingRequestWrapper(
reReadableRequest = reReadableRequest,
headerName = ACCEPT_ENCODING,
computeHeaderValueFromRequest = ::computeAcceptEncodingValueFromRequest,
)

val maybeCompressingResponse = createMaybeCompressingResponse(
response,
requestWithContentEncoding.getHeaders(ACCEPT_ENCODING),
)

filterChain.doFilter(
requestWithContentEncoding,
maybeCompressingResponse,
)

maybeCompressingResponse.outputStream.flush()
maybeCompressingResponse.outputStream.close()
}

private fun createMaybeCompressingResponse(
response: HttpServletResponse,
acceptEncodingHeaders: Enumeration<String>?,
) = when (val compression = Compression.fromHeaders(acceptEncodingHeaders)) {
null -> response
else -> CompressingResponse(response, compression)
}
}

private fun computeAcceptEncodingValueFromRequest(reReadableRequest: CachedBodyHttpServletRequest) =
when (reReadableRequest.getStringField(COMPRESSION_PROPERTY)) {
Compression.GZIP.value -> Compression.GZIP.value
Compression.ZSTD.value -> Compression.ZSTD.value
else -> null
}

class CompressingResponse(
response: HttpServletResponse,
compression: Compression,
) : HttpServletResponse by response {
init {
log.info { "Compressing using $compression" }
response.setHeader(CONTENT_ENCODING, compression.value)
preventSpringFromSettingTheContentLengthWhichIsUnknownWhenCompressing(response)
}

private fun preventSpringFromSettingTheContentLengthWhichIsUnknownWhenCompressing(response: HttpServletResponse) {
response.addHeader(TRANSFER_ENCODING, "chunked")
}

private val servletOutputStream = CompressingServletOutputStream(response.outputStream, compression)

override fun getOutputStream() = servletOutputStream
}

class CompressingServletOutputStream(
private val outputStream: ServletOutputStream,
compression: Compression,
) : ServletOutputStream() {
private val compressingStream = compression.compressionOutputStreamFactory(outputStream)

override fun write(byte: Int) {
compressingStream.write(byte)
}

override fun isReady() = outputStream.isReady

override fun setWriteListener(listener: WriteListener?) = outputStream.setWriteListener(listener)

override fun close() {
super.close()
compressingStream.close()
}

override fun flush() {
super.flush()
compressingStream.flush()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package org.genspectrum.lapis.controller
import com.fasterxml.jackson.databind.ObjectMapper
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletRequestWrapper
import jakarta.servlet.http.HttpServletResponse
import mu.KotlinLogging
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.genspectrum.lapis.util.HeaderModifyingRequestWrapper
import org.springframework.core.annotation.Order
import org.springframework.http.HttpHeaders.ACCEPT
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter
import java.util.Collections
import java.util.Enumeration

private val log = KotlinLogging.logger {}

Expand Down Expand Up @@ -41,52 +40,22 @@ class DataFormatParameterFilter(val objectMapper: ObjectMapper) : OncePerRequest
) {
val reReadableRequest = CachedBodyHttpServletRequest(request, objectMapper)

filterChain.doFilter(AcceptHeaderModifyingRequestWrapper(reReadableRequest), response)
filterChain.doFilter(
HeaderModifyingRequestWrapper(
reReadableRequest,
ACCEPT,
::findAcceptHeaderOverwriteValue,
),
response,
)
}
}

class AcceptHeaderModifyingRequestWrapper(
private val reReadableRequest: CachedBodyHttpServletRequest,
) : HttpServletRequestWrapper(reReadableRequest) {
override fun getHeader(name: String): String? {
if (name.equals("Accept", ignoreCase = true)) {
when (val overwrittenValue = findAcceptHeaderOverwriteValue()) {
null -> {}
else -> return overwriteWith(overwrittenValue)
}
}

return super.getHeader(name)
}

override fun getHeaders(name: String): Enumeration<String>? {
if (name.equals("Accept", ignoreCase = true)) {
when (val overwrittenValue = findAcceptHeaderOverwriteValue()) {
null -> {}
else -> return Collections.enumeration(listOf(overwriteWith(overwrittenValue)))
}
}

return super.getHeaders(name)
}

override fun getHeaderNames(): Enumeration<String> =
when (findAcceptHeaderOverwriteValue()) {
null -> super.getHeaderNames()
else -> Collections.enumeration(super.getHeaderNames().toList() + "Accept")
}

private fun findAcceptHeaderOverwriteValue() =
private fun findAcceptHeaderOverwriteValue(reReadableRequest: CachedBodyHttpServletRequest) =
when (reReadableRequest.getStringField(FORMAT_PROPERTY)?.uppercase()) {
DataFormat.CSV -> TEXT_CSV_HEADER
DataFormat.CSV_WITHOUT_HEADERS -> TEXT_CSV_WITHOUT_HEADERS_HEADER
DataFormat.TSV -> TEXT_TSV_HEADER
DataFormat.JSON -> MediaType.APPLICATION_JSON_VALUE
else -> null
}

private fun overwriteWith(value: String): String {
log.debug { "Overwriting Accept header to $value due to format property" }
return value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.springframework.core.annotation.Order
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpHeaders.ACCEPT
import org.springframework.http.HttpHeaders.ACCEPT_ENCODING
import org.springframework.http.HttpHeaders.CONTENT_DISPOSITION
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter

const val DOWNLOAD_AS_FILE_FILTER_ORDER = DATA_FORMAT_FILTER_ORDER + 1

@Component
@Order(DATA_FORMAT_FILTER_ORDER + 1)
@Order(DOWNLOAD_AS_FILE_FILTER_ORDER)
class DownloadAsFileFilter(private val objectMapper: ObjectMapper) : OncePerRequestFilter() {
override fun doFilterInternal(
request: HttpServletRequest,
Expand All @@ -23,7 +27,7 @@ class DownloadAsFileFilter(private val objectMapper: ObjectMapper) : OncePerRequ
val downloadAsFile = reReadableRequest.getBooleanField(DOWNLOAD_AS_FILE_PROPERTY) ?: false
if (downloadAsFile) {
val filename = getFilename(reReadableRequest)
response.setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=$filename")
response.setHeader(CONTENT_DISPOSITION, "attachment; filename=$filename")
}
filterChain.doFilter(reReadableRequest, response)
}
Expand All @@ -33,14 +37,20 @@ class DownloadAsFileFilter(private val objectMapper: ObjectMapper) : OncePerRequ
SampleRoute.entries.find { request.getProxyAwarePath().startsWith("/sample${it.pathSegment}") }
val dataName = matchingRoute?.pathSegment?.trim('/') ?: "data"

val fileEnding = when (request.getHeader("Accept")) {
val compressionEnding = when (Compression.fromHeaders(request.getHeaders(ACCEPT_ENCODING))) {
Compression.GZIP -> ".gzip"
Compression.ZSTD -> ".zstd"
null -> ""
}

val fileEnding = when (request.getHeader(ACCEPT)) {
TEXT_CSV_HEADER -> "csv"
TEXT_TSV_HEADER -> "tsv"
else -> when (matchingRoute?.servesFasta) {
true -> "fasta"
else -> "json"
}
}
return "$dataName.$fileEnding"
return "$dataName.$fileEnding$compressionEnding"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController

const val TEXT_X_FASTA_HEADER = "text/x-fasta"

@RestController
@RequestMapping("/sample")
class LapisController(
Expand Down Expand Up @@ -1307,7 +1309,7 @@ class LapisController(
return getResponseAsCsv(request, httpHeaders.accept, TAB, siloQueryModel::getAminoAcidInsertions)
}

@GetMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = ["text/x-fasta"])
@GetMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedAminoAcidSequenceResponse
fun getAlignedAminoAcidSequence(
@PathVariable(name = "gene", required = true)
Expand Down Expand Up @@ -1354,7 +1356,7 @@ class LapisController(
return siloQueryModel.getGenomicSequence(request, SequenceType.ALIGNED, gene)
}

@PostMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = ["text/x-fasta"])
@PostMapping("$ALIGNED_AMINO_ACID_SEQUENCES_ROUTE/{gene}", produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedAminoAcidSequenceResponse
fun postAlignedAminoAcidSequence(
@PathVariable(name = "gene", required = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class MultiSegmentedSequenceController(
private val siloQueryModel: SiloQueryModel,
private val requestContext: RequestContext,
) {
@GetMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = ["text/x-fasta"])
@GetMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedMultiSegmentedNucleotideSequenceResponse
fun getAlignedNucleotideSequence(
@PathVariable(name = "segment", required = true)
Expand Down Expand Up @@ -95,7 +95,7 @@ class MultiSegmentedSequenceController(
)
}

@PostMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = ["text/x-fasta"])
@PostMapping("$ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedMultiSegmentedNucleotideSequenceResponse
fun postAlignedNucleotideSequence(
@PathVariable(name = "segment", required = true)
Expand All @@ -114,7 +114,7 @@ class MultiSegmentedSequenceController(
)
}

@GetMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = ["text/x-fasta"])
@GetMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER])
@LapisUnalignedMultiSegmentedNucleotideSequenceResponse
fun getUnalignedNucleotideSequence(
@PathVariable(name = "segment", required = true)
Expand Down Expand Up @@ -165,7 +165,7 @@ class MultiSegmentedSequenceController(
)
}

@PostMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = ["text/x-fasta"])
@PostMapping("$UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE/{segment}", produces = [TEXT_X_FASTA_HEADER])
@LapisUnalignedMultiSegmentedNucleotideSequenceResponse
fun postUnalignedNucleotideSequence(
@PathVariable(name = "segment", required = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SingleSegmentedSequenceController(
private val requestContext: RequestContext,
private val referenceGenomeSchema: ReferenceGenomeSchema,
) {
@GetMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = ["text/x-fasta"])
@GetMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedSingleSegmentedNucleotideSequenceResponse
fun getAlignedNucleotideSequences(
@PrimitiveFieldFilters
Expand Down Expand Up @@ -92,7 +92,7 @@ class SingleSegmentedSequenceController(
)
}

@PostMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = ["text/x-fasta"])
@PostMapping(ALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER])
@LapisAlignedSingleSegmentedNucleotideSequenceResponse
fun postAlignedNucleotideSequence(
@Parameter(schema = Schema(ref = "#/components/schemas/$NUCLEOTIDE_SEQUENCE_REQUEST_SCHEMA"))
Expand All @@ -108,7 +108,7 @@ class SingleSegmentedSequenceController(
)
}

@GetMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = ["text/x-fasta"])
@GetMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER])
@LapisUnalignedSingleSegmentedNucleotideSequenceResponse
fun getUnalignedNucleotideSequences(
@PrimitiveFieldFilters
Expand Down Expand Up @@ -156,7 +156,7 @@ class SingleSegmentedSequenceController(
)
}

@PostMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = ["text/x-fasta"])
@PostMapping(UNALIGNED_NUCLEOTIDE_SEQUENCES_ROUTE, produces = [TEXT_X_FASTA_HEADER])
@LapisUnalignedSingleSegmentedNucleotideSequenceResponse
fun postUnalignedNucleotideSequence(
@Parameter(schema = Schema(ref = "#/components/schemas/$NUCLEOTIDE_SEQUENCE_REQUEST_SCHEMA"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const val ORDER_BY_PROPERTY = "orderBy"
const val LIMIT_PROPERTY = "limit"
const val OFFSET_PROPERTY = "offset"
const val DOWNLOAD_AS_FILE_PROPERTY = "downloadAsFile"
const val COMPRESSION_PROPERTY = "compression"

val SPECIAL_REQUEST_PROPERTIES = listOf(
MIN_PROPORTION_PROPERTY,
Expand All @@ -26,4 +27,5 @@ val SPECIAL_REQUEST_PROPERTIES = listOf(
OFFSET_PROPERTY,
FORMAT_PROPERTY,
DOWNLOAD_AS_FILE_PROPERTY,
COMPRESSION_PROPERTY,
)
Loading

0 comments on commit d6e8d9d

Please sign in to comment.