Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow string arrays as filter for string and pango lineage fields #516

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.genspectrum.lapis.request.AminoAcidMutation
import org.genspectrum.lapis.request.CommonSequenceFilters
import org.genspectrum.lapis.request.DEFAULT_MIN_PROPORTION
import org.genspectrum.lapis.request.Field
import org.genspectrum.lapis.request.GetRequestSequenceFilters
import org.genspectrum.lapis.request.MutationProportionsRequest
import org.genspectrum.lapis.request.NucleotideInsertion
import org.genspectrum.lapis.request.NucleotideMutation
Expand Down Expand Up @@ -84,7 +85,7 @@ class LapisController(
fun aggregated(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@FieldsToAggregateBy
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -139,7 +140,7 @@ class LapisController(
fun getAggregatedAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@FieldsToAggregateBy
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -192,7 +193,7 @@ class LapisController(
fun getAggregatedAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@FieldsToAggregateBy
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -284,7 +285,7 @@ class LapisController(
fun getNucleotideMutations(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -339,7 +340,7 @@ class LapisController(
fun getNucleotideMutationsAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -388,7 +389,7 @@ class LapisController(
fun getNucleotideMutationsAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -477,7 +478,7 @@ class LapisController(
fun getAminoAcidMutations(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -528,7 +529,7 @@ class LapisController(
fun getAminoAcidMutationsAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -577,7 +578,7 @@ class LapisController(
fun getAminoAcidMutationsAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@RequestParam(required = false)
@NucleotideMutations
nucleotideMutations: List<NucleotideMutation>?,
Expand Down Expand Up @@ -682,7 +683,7 @@ class LapisController(
fun getDetailsAsJson(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@DetailsFields
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -735,7 +736,7 @@ class LapisController(
fun getDetailsAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@DetailsFields
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -785,7 +786,7 @@ class LapisController(
fun getDetailsAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@DetailsFields
@RequestParam
fields: List<Field>?,
Expand Down Expand Up @@ -874,7 +875,7 @@ class LapisController(
fun getNucleotideInsertions(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -926,7 +927,7 @@ class LapisController(
fun getNucleotideInsertionsAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -977,7 +978,7 @@ class LapisController(
fun getNucleotideInsertionsAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -1072,7 +1073,7 @@ class LapisController(
fun getAminoAcidInsertions(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -1124,7 +1125,7 @@ class LapisController(
fun getAminoAcidInsertionsAsCsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -1175,7 +1176,7 @@ class LapisController(
fun getAminoAcidInsertionsAsTsv(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@InsertionsOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down Expand Up @@ -1271,7 +1272,7 @@ class LapisController(
@PathVariable(name = "gene", required = true) gene: String,
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@AminoAcidSequencesOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import org.genspectrum.lapis.openApi.Offset
import org.genspectrum.lapis.openApi.PrimitiveFieldFilters
import org.genspectrum.lapis.request.AminoAcidInsertion
import org.genspectrum.lapis.request.AminoAcidMutation
import org.genspectrum.lapis.request.GetRequestSequenceFilters
import org.genspectrum.lapis.request.NucleotideInsertion
import org.genspectrum.lapis.request.NucleotideMutation
import org.genspectrum.lapis.request.OrderByField
Expand Down Expand Up @@ -47,7 +48,7 @@ class MultiSegmentedSequenceController(
@PathVariable(name = "segment", required = true) segment: String,
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@NucleotideSequencesOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import org.genspectrum.lapis.openApi.Offset
import org.genspectrum.lapis.openApi.PrimitiveFieldFilters
import org.genspectrum.lapis.request.AminoAcidInsertion
import org.genspectrum.lapis.request.AminoAcidMutation
import org.genspectrum.lapis.request.GetRequestSequenceFilters
import org.genspectrum.lapis.request.NucleotideInsertion
import org.genspectrum.lapis.request.NucleotideMutation
import org.genspectrum.lapis.request.OrderByField
Expand Down Expand Up @@ -47,7 +48,7 @@ class SingleSegmentedSequenceController(
fun getAlignedNucleotideSequences(
@PrimitiveFieldFilters
@RequestParam
sequenceFilters: Map<String, String>?,
sequenceFilters: GetRequestSequenceFilters?,
@NucleotideSequencesOrderByFields
@RequestParam
orderBy: List<OrderByField>?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.genspectrum.lapis.silo.IntBetween
import org.genspectrum.lapis.silo.IntEquals
import org.genspectrum.lapis.silo.NucleotideInsertionContains
import org.genspectrum.lapis.silo.NucleotideSymbolEquals
import org.genspectrum.lapis.silo.Or
import org.genspectrum.lapis.silo.PangoLineageEquals
import org.genspectrum.lapis.silo.SiloFilterExpression
import org.genspectrum.lapis.silo.StringEquals
Expand All @@ -30,7 +31,7 @@ import java.time.LocalDate
import java.time.format.DateTimeParseException
import java.util.Locale

data class SequenceFilterValue(val type: SequenceFilterFieldType, val value: String, val originalKey: String)
data class SequenceFilterValue(val type: SequenceFilterFieldType, val values: List<String>, val originalKey: String)

typealias SequenceFilterFieldName = String

Expand All @@ -46,10 +47,10 @@ class SiloFilterExpressionMapper(

val allowedSequenceFiltersWithType = sequenceFilters
.sequenceFilters
.map { (key, value) ->
.map { (key, values) ->
val nullableField = allowedSequenceFilterFields.fields[key.lowercase(Locale.US)]
val (filterExpressionId, type) = mapToFilterExpressionIdentifier(nullableField, key)
filterExpressionId to SequenceFilterValue(type, value, key)
filterExpressionId to SequenceFilterValue(type, values, key)
}
.groupBy({ it.first }, { it.second })

Expand All @@ -62,10 +63,10 @@ class SiloFilterExpressionMapper(
val filterExpressions = allowedSequenceFiltersWithType.map { (key, values) ->
val (siloColumnName, filter) = key
when (filter) {
Filter.StringEquals -> StringEquals(siloColumnName, values[0].value)
Filter.PangoLineage -> mapToPangoLineageFilter(siloColumnName, values[0].value)
Filter.StringEquals -> mapToStringEqualsFilters(siloColumnName, values)
Filter.PangoLineage -> mapToPangoLineageFilter(siloColumnName, values)
Filter.DateBetween -> mapToDateBetweenFilter(siloColumnName, values)
Filter.VariantQuery -> mapToVariantQueryFilter(values[0].value)
Filter.VariantQuery -> mapToVariantQueryFilter(values[0].values[0])
Filter.IntEquals -> mapToIntEqualsFilter(siloColumnName, values)
Filter.IntBetween -> mapToIntBetweenFilter(siloColumnName, values)
Filter.FloatEquals -> mapToFloatEqualsFilter(siloColumnName, values)
Expand Down Expand Up @@ -161,6 +162,11 @@ class SiloFilterExpressionMapper(
}
}

/**
 * Builds a disjunction of [StringEquals] expressions for a string column.
 *
 * Only the first element of [values] is read; each string it carries becomes one
 * [StringEquals] on [siloColumnName], and all of them are wrapped in a single [Or].
 * [values] must therefore contain at least one element.
 */
private fun mapToStringEqualsFilters(
    siloColumnName: SequenceFilterFieldName,
    values: List<SequenceFilterValue>,
) = values[0].values
    .map { requestedValue -> StringEquals(siloColumnName, requestedValue) }
    .let { alternatives -> Or(alternatives) }

private fun mapToVariantQueryFilter(variantQuery: String): SiloFilterExpression {
if (variantQuery.isBlank()) {
throw BadRequestException("variantQuery must not be empty")
Expand Down Expand Up @@ -208,7 +214,8 @@ class SiloFilterExpressionMapper(
}

private fun getAsDate(sequenceFilterValue: SequenceFilterValue?): LocalDate? {
val (_, value, originalKey) = sequenceFilterValue ?: return null
val (_, values, originalKey) = sequenceFilterValue ?: return null
val value = extractSingleFilterValue(values, originalKey)

try {
return LocalDate.parse(value)
Expand All @@ -219,22 +226,26 @@ class SiloFilterExpressionMapper(

private fun mapToPangoLineageFilter(
column: String,
value: String,
) = when {
value.endsWith(".*") -> PangoLineageEquals(column, value.substringBeforeLast(".*"), includeSublineages = true)
value.endsWith('*') -> PangoLineageEquals(column, value.substringBeforeLast('*'), includeSublineages = true)
value.endsWith('.') -> throw BadRequestException(
"Invalid pango lineage: $value must not end with a dot. Did you mean '$value*'?",
)
values: List<SequenceFilterValue>,
) = Or(
values[0].values.map {
when {
it.endsWith(".*") -> PangoLineageEquals(column, it.substringBeforeLast(".*"), includeSublineages = true)
it.endsWith('*') -> PangoLineageEquals(column, it.substringBeforeLast('*'), includeSublineages = true)
it.endsWith('.') -> throw BadRequestException(
"Invalid pango lineage: $it must not end with a dot. Did you mean '$it*'?",
)

else -> PangoLineageEquals(column, value, includeSublineages = false)
}
else -> PangoLineageEquals(column, it, includeSublineages = false)
}
},
)

private fun mapToIntEqualsFilter(
siloColumnName: SequenceFilterFieldName,
values: List<SequenceFilterValue>,
): SiloFilterExpression {
val value = values[0].value
val value = extractSingleFilterValue(values[0])
try {
return IntEquals(siloColumnName, value.toInt())
} catch (exception: NumberFormatException) {
Expand All @@ -249,7 +260,7 @@ class SiloFilterExpressionMapper(
siloColumnName: SequenceFilterFieldName,
values: List<SequenceFilterValue>,
): SiloFilterExpression {
val value = values[0].value
val value = extractSingleFilterValue(values[0])
try {
return FloatEquals(siloColumnName, value.toDouble())
} catch (exception: NumberFormatException) {
Expand All @@ -274,7 +285,8 @@ class SiloFilterExpressionMapper(
private inline fun <reified T : SequenceFilterFieldType> findIntOfFilterType(
dateRangeFilters: List<SequenceFilterValue>,
): Int? {
val (_, value, originalKey) = dateRangeFilters.find { (type, _, _) -> type is T } ?: return null
val (_, values, originalKey) = dateRangeFilters.find { (type, _, _) -> type is T } ?: return null
val value = extractSingleFilterValue(values, originalKey)

try {
return value.toInt()
Expand All @@ -300,7 +312,8 @@ class SiloFilterExpressionMapper(
private inline fun <reified T : SequenceFilterFieldType> findFloatOfFilterType(
dateRangeFilters: List<SequenceFilterValue>,
): Double? {
val (_, value, originalKey) = dateRangeFilters.find { (type, _, _) -> type is T } ?: return null
val (_, values, originalKey) = dateRangeFilters.find { (type, _, _) -> type is T } ?: return null
val value = extractSingleFilterValue(values, originalKey)

try {
return value.toDouble()
Expand Down Expand Up @@ -356,4 +369,14 @@ class SiloFilterExpressionMapper(
}

private val variantQueryTypes = listOf(Filter.PangoLineage)

/**
 * Convenience overload: unpacks the value list and the original request key from
 * [value] and forwards them to the list-based `extractSingleFilterValue`.
 */
private fun extractSingleFilterValue(value: SequenceFilterValue): String =
    extractSingleFilterValue(values = value.values, originalKey = value.originalKey)

/**
 * Returns the only element of [values].
 *
 * @param values the strings supplied for one filter key.
 * @param originalKey the key as it appeared in the request; used only in the error message.
 * @throws BadRequestException if [values] is empty or holds more than one element.
 */
private fun extractSingleFilterValue(
    values: List<String>,
    originalKey: String,
): String {
    // Anything other than exactly one entry (including none at all) is rejected.
    if (values.size != 1) {
        throw BadRequestException(
            "Expected exactly one value for '$originalKey' but got ${values.size} values.",
        )
    }
    return values[0]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.genspectrum.lapis.config.MetadataType
import org.genspectrum.lapis.config.OpennessLevel
import org.genspectrum.lapis.config.ReferenceGenome
import org.genspectrum.lapis.config.SequenceFilterFieldName
import org.genspectrum.lapis.config.SequenceFilterFieldType
import org.genspectrum.lapis.config.SequenceFilterFields
import org.genspectrum.lapis.controller.AGGREGATED_GROUP_BY_FIELDS_DESCRIPTION
import org.genspectrum.lapis.controller.AMINO_ACID_INSERTIONS_PROPERTY
Expand Down Expand Up @@ -282,7 +283,20 @@ private fun mapToOpenApiType(type: MetadataType): String =
private fun primitiveSequenceFilterFieldSchemas(sequenceFilterFields: SequenceFilterFields) =
sequenceFilterFields.fields
.values
.associate { (fieldName, field) -> fieldName to Schema<String>().type(field.openApiType) }
.associate { (fieldName, field) -> fieldName to filterFieldSchema(field) }

/**
 * Creates the OpenAPI schema for a single sequence filter field.
 *
 * For [SequenceFilterFieldType.String] and [SequenceFilterFieldType.PangoLineage] the schema is an
 * `anyOf` of the plain type and an array of that type; every other field type gets a plain schema
 * whose type is the field's `openApiType`.
 */
private fun filterFieldSchema(fieldType: SequenceFilterFieldType) =
    when {
        fieldType == SequenceFilterFieldType.String || fieldType == SequenceFilterFieldType.PangoLineage -> {
            // Two separate Schema instances on purpose: one stands alone, one is the array's item schema.
            val singleValueSchema = Schema<String>().type(fieldType.openApiType)
            val multiValueSchema = arraySchema(Schema<String>().type(fieldType.openApiType))
            Schema<String>().anyOf(listOf(singleValueSchema, multiValueSchema))
        }

        else -> Schema<String>().type(fieldType.openApiType)
    }

private fun requestSchemaForCommonSequenceFilters(
requestProperties: Map<SequenceFilterFieldName, Schema<out Any>>,
Expand Down
Loading