Skip to content
This repository has been archived by the owner on Jun 1, 2022. It is now read-only.

Commit

Permalink
Streaming formats for searchLocations plus format=nlgeojson, refs #367
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Apr 22, 2021
1 parent 12c8ae7 commit d29c265
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 27 deletions.
77 changes: 62 additions & 15 deletions vaccinate/api/search.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import json
from collections import namedtuple
from html import escape

import beeline
from core.models import ConcordanceIdentifier, Location, State
from django.http import JsonResponse
from django.http.response import StreamingHttpResponse
from django.shortcuts import render
from django.utils.safestring import mark_safe

OutputFormat = namedtuple(
"Format", ("start", "transform", "separator", "end", "content_type")
)


@beeline.traced("search_locations")
def search_locations(request):
format = request.GET.get("format") or "json"
size = min(int(request.GET.get("size", "10")), 1000)
q = (request.GET.get("q") or "").strip().lower()
all = request.GET.get("all")
state = (request.GET.get("state") or "").upper()
if state:
try:
Expand All @@ -29,6 +36,7 @@ def search_locations(request):
return render(
request, "api/search_locations_map.html", {"query_string": get.urlencode()}
)

qs = Location.objects.filter(soft_deleted=False)
if q:
qs = qs.filter(name__icontains=q)
Expand All @@ -43,29 +51,43 @@ def search_locations(request):
)
qs = location_json_queryset(qs)
page_qs = qs[:size]
json_results = lambda: {
"results": [location_json(location) for location in page_qs],
"total": qs.count(),
}
output = None
if format == "geojson":
output = {
"type": "FeatureCollection",
"features": [location_geojson(location) for location in qs],
}

else:
output = json_results()
if format not in FORMATS:
return JsonResponse({"error": "Invalid format"}, status=400)

formatter = FORMATS[format]

def stream():
if callable(formatter.start):
yield formatter.start(qs)
else:
yield formatter.start
started = False
for location in page_qs:
if started and formatter.separator:
yield formatter.separator
started = True
yield formatter.transform(location)
if callable(formatter.end):
yield formatter.end(qs)
else:
yield formatter.end

if debug:
if all:
return JsonResponse({"error": "Cannot use both all and debug"}, status=400)
output = "".join(stream())
if formatter.content_type == "application/json":
output = json.dumps(json.loads(output), indent=2)
return render(
request,
"api/search_locations_debug.html",
{
"json_results": mark_safe(escape(json.dumps(output, indent=2))),
"output": mark_safe(escape(output)),
},
)
else:
return JsonResponse(output)

return StreamingHttpResponse(stream(), content_type=formatter.content_type)


def location_json_queryset(queryset):
Expand Down Expand Up @@ -114,3 +136,28 @@ def location_geojson(location):
"coordinates": [location.longitude, location.latitude],
},
}


FORMATS = {
"json": OutputFormat(
start='{"results": [',
transform=lambda l: json.dumps(location_json(l)),
separator=",",
end=lambda qs: '], "total": TOTAL}'.replace("TOTAL", str(qs.count())),
content_type="application/json",
),
"geojson": OutputFormat(
start='{"type": "FeatureCollection", "features": [',
transform=lambda l: json.dumps(location_geojson(l)),
separator=",",
end=lambda qs: "]}",
content_type="application/json",
),
"nlgeojson": OutputFormat(
start="",
transform=lambda l: json.dumps(location_geojson(l)),
separator="\n",
end="",
content_type="text/plain",
),
}
70 changes: 59 additions & 11 deletions vaccinate/api/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from core.models import ConcordanceIdentifier, Location, State


def search_get_json(client, query_string):
response = client.get("/api/searchLocations?" + query_string)
assert response.status_code == 200
joined = b"".join(response.streaming_content)
return json.loads(joined)


@pytest.mark.parametrize(
"query_string,expected",
(
Expand All @@ -28,21 +35,62 @@ def test_search_locations(client, query_string, expected, ten_locations):
with_concordances_2.concordances.add(
ConcordanceIdentifier.for_idref("google_places:456")
)
response = client.get("/api/searchLocations?" + query_string)
assert response.status_code == 200
data = json.loads(response.content)
data = search_get_json(client, query_string)
names = [r["name"] for r in data["results"]]
assert names == expected
assert data["total"] == len(expected)


def test_search_locations_ignores_soft_deleted(client, ten_locations):
assert (
json.loads(client.get("/api/searchLocations?q=Location+1").content)["total"]
== 2
)
assert search_get_json(client, "q=Location+1")["total"] == 2
Location.objects.filter(name="Location 10").update(soft_deleted=True)
assert (
json.loads(client.get("/api/searchLocations?q=Location+1").content)["total"]
== 1
)
assert search_get_json(client, "q=Location+1")["total"] == 1


def test_search_locations_format_json(client, ten_locations):
result = search_get_json(client, "q=Location+1")
assert set(result.keys()) == {"results", "keys"}
record = result["results"][0]
assert set(record.keys()) == {
"id",
"name",
"state",
"latitude",
"longitude",
"location_type",
"import_ref",
"phone_number",
"full_address",
"city",
"county",
"google_places_id",
"vaccinefinder_location_id",
"vaccinespotter_location_id",
"zip_code",
"hours",
"website",
"preferred_contact_method",
"provider",
"concordances",
}


def test_search_locations_format_geojson(client, ten_locations):
result = search_get_json(client, "q=Location+1&format=geojson")
assert set(result.keys()) == {"type", "features"}
assert result["type"] == "FeatureCollection"
record = result["features"][0]
assert set(record.keys()) == {"type", "properties", "geometry"}
assert record["geometry"] == {"type": "Point", "coordinates": [40.0, 30.0]}


def test_search_locations_format_nlgeojson(client, ten_locations):
response = client.get("/api/searchLocations?q=Location+1&format=nlgeojson")
assert response.status_code == 200
joined = b"".join(response.streaming_content)
# Should return two results split by newlines
lines = joined.split(b"\n")
assert len(lines) == 2
for line in lines:
record = json.loads(line)
assert set(record.keys()) == {"type", "properties", "geometry"}
2 changes: 1 addition & 1 deletion vaccinate/templates/api/search_locations_debug.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
<head><title>Debug search locations</title>
</head>
<body>
<pre>{{ json_results }}</pre>
<pre>{{ output }}</pre>
</body>
</html>

0 comments on commit d29c265

Please sign in to comment.