Skip to content

Commit

Permalink
Add validation rules to PrevalenceAllLineagesByLocationHandler; Update PrevalenceByLocationAndTimeHandler
Browse files Browse the repository at this point in the history

* Add validation rules to PrevalenceAllLineagesByLocationHandler parameters
* Add validation rules to PrevalenceByLocationAndTimeHandler
* Allow location_id = None for PrevalenceAllLineagesByLocationHandler
  • Loading branch information
remoteeng00 authored Feb 22, 2023
1 parent 09dcc5b commit 897b144
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 43 deletions.
8 changes: 0 additions & 8 deletions web/handlers/genomics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,6 @@ def get_total_hits(d): # To account for difference in ES versions 7.12.0 vs 6.8.
return d["hits"]["total"]["value"] if isinstance(d["hits"]["total"], dict) else d["hits"]["total"]


def validate_iso_date(date_str):
    """Return True if ``date_str`` parses as a ``%Y-%m-%d`` date, else False.

    Parsing is delegated to ``datetime.strptime`` with the ``'%Y-%m-%d'``
    format, so impossible calendar dates (e.g. ``2023-02-30``) are rejected.

    Args:
        date_str: The candidate date value, normally a string taken from a
            request parameter.

    Returns:
        bool: True when the value parses, False otherwise. Non-string input
        such as ``None`` yields False instead of raising.
    """
    try:
        dt.strptime(date_str, '%Y-%m-%d')
    except (ValueError, TypeError):
        # ValueError: string does not match the format or is not a real date.
        # TypeError: strptime was given a non-string (e.g. None); a validator
        # should report that as "invalid" rather than crash the caller.
        return False
    return True


def create_date_range_filter(field_name, min_date=None, max_date=None):
date_range_filter = {"range": {field_name: {}}}
if not max_date and not min_date:
Expand Down
33 changes: 11 additions & 22 deletions web/handlers/v2/genomics/prevalence_all_lineages_by_location.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import datetime as dt, timedelta

import pandas as pd
from tornado.web import HTTPError

from web.handlers.genomics.base import BaseHandler
from web.handlers.genomics.util import (
Expand All @@ -12,50 +11,38 @@
get_major_lineage_prevalence,
parse_location_id_to_query,
parse_time_window_to_query,
validate_iso_date,
)


class PrevalenceAllLineagesByLocationHandler(BaseHandler):
# size = 100 # If size=1000 it will raise too_many_buckets_exception in case missing location_id in query.
name = "prevalence-by-location-all-lineages"
kwargs = dict(BaseHandler.kwargs)
kwargs["GET"] = {
"location_id": {"type": str, "default": None},
"window": {"type": str, "default": None},
"other_threshold": {"type": float, "default": 0.05},
"nday_threshold": {"type": float, "default": 10},
"ndays": {"type": float, "default": 180},
"window": {"type": int, "default": None, "min": 1},
"other_threshold": {"type": float, "default": 0.05, "min": 0, "max": 1},
"nday_threshold": {"type": int, "default": 10, "min": 1},
"ndays": {"type": int, "default": 180, "min": 1},
"other_exclude": {"type": str, "default": None},
"cumulative": {"type": str, "default": None},
"min_date": {"type": str, "default": None},
"max_date": {"type": str, "default": None},
"cumulative": {"type": bool, "default": False},
"min_date": {"type": str, "default": None, "date_format": "%Y-%m-%d"},
"max_date": {"type": str, "default": None, "date_format": "%Y-%m-%d"},
}

async def _get(self):
query_location = self.args.location_id
query_window = self.args.window
query_window = int(query_window) if query_window is not None else None
query_other_threshold = self.args.other_threshold
query_other_threshold = float(query_other_threshold)
query_nday_threshold = self.args.nday_threshold
query_nday_threshold = float(query_nday_threshold)
query_ndays = self.args.ndays
query_ndays = int(query_ndays)
query_other_exclude = self.args.other_exclude
query_other_exclude = (
query_other_exclude.split(",") if query_other_exclude is not None else []
)
query_cumulative = self.args.cumulative
query_cumulative = True if query_cumulative == "true" else False
if self.args.max_date:
if not validate_iso_date(self.args.max_date):
raise HTTPError(400, reason="Invalid max_date format")
if self.args.min_date:
if not validate_iso_date(self.args.min_date):
raise HTTPError(400, reason="Invalid min_date format")
query = {
"size": 0,
"query": {},
"aggs": {
"count": {
"terms": {"field": "date_collected", "size": self.size},
Expand All @@ -69,7 +56,9 @@ async def _get(self):
date_range_filter = create_date_range_filter(
"date_collected", self.args.min_date, self.args.max_date
)
query["query"] = parse_time_window_to_query(date_range_filter, query_obj=query_obj)
query_obj = parse_time_window_to_query(date_range_filter, query_obj=query_obj)
if query_obj:
query["query"] = query_obj
# import json
# print(json.dumps(query))
resp = await self.asynchronous_fetch(query)
Expand Down
20 changes: 7 additions & 13 deletions web/handlers/v2/genomics/prevalence_by_location_and_time.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
from tornado.web import HTTPError

from web.handlers.genomics.base import BaseHandler
from web.handlers.genomics.util import (
create_iterator,
create_nested_mutation_query,
parse_location_id_to_query,
transform_prevalence,
validate_iso_date,
)


class PrevalenceByLocationAndTimeHandler(BaseHandler):
name = "prevalence-by-location"
kwargs = dict(BaseHandler.kwargs)
kwargs["GET"] = {
"pangolin_lineage": {"type": str, "default": None},
"pangolin_lineage": {"type": str, "required": True},
"mutations": {"type": str, "default": None},
"location_id": {"type": str, "default": None},
"cumulative": {"type": str, "default": None},
"min_date": {"type": str, "default": None},
"max_date": {"type": str, "default": None},
"cumulative": {"type": bool, "default": False},
"min_date": {"type": str, "default": None, "date_format": "%Y-%m-%d"},
"max_date": {"type": str, "default": None, "date_format": "%Y-%m-%d"},
}

async def _get(self):
Expand All @@ -31,15 +28,10 @@ async def _get(self):
query_mutations = self.args.mutations
query_mutations = query_mutations.split(" AND ") if query_mutations is not None else []
cumulative = self.args.cumulative
cumulative = True if cumulative == "true" else False
date_range_filter = {"query": {"range": {"date_collected": {}}}}
if self.args.max_date:
if not validate_iso_date(self.args.max_date):
raise HTTPError(400, reason="Invalid max_date format")
date_range_filter["query"]["range"]["date_collected"]["lte"] = self.args.max_date
if self.args.min_date:
if not validate_iso_date(self.args.min_date):
raise HTTPError(400, reason="Invalid min_date format")
date_range_filter["query"]["range"]["date_collected"]["gte"] = self.args.min_date

results = {}
Expand All @@ -65,7 +57,9 @@ async def _get(self):
query_obj = create_nested_mutation_query(
lineages=lineages, mutations=j, location_id=query_location
)
query["aggs"]["prevalence"]["aggs"]["count"]["aggs"]["lineage_count"]["filter"] = query_obj
query["aggs"]["prevalence"]["aggs"]["count"]["aggs"]["lineage_count"][
"filter"
] = query_obj
# import json
# print(json.dumps(query))
resp = await self.asynchronous_fetch(query)
Expand Down

0 comments on commit 897b144

Please sign in to comment.