Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: cache responses #42

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xpublish_edr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
xpublish_edr is not a real package, just a set of best practices examples.
Xpublish routers for the OGC EDR API.
"""

from xpublish_edr.plugin import CfEdrPlugin
Expand Down
12 changes: 10 additions & 2 deletions xpublish_edr/formats/to_covjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

import numpy as np
import xarray as xr
from fastapi.responses import JSONResponse


class CovJSONResponse(JSONResponse):
    """JSON response served with the CoverageJSON media type.

    Behaves like ``JSONResponse``; only the advertised ``media_type``
    differs, so clients can tell a CoverageJSON payload from plain JSON.
    """

    # Media type for CoverageJSON documents; see the OGC specification:
    # https://docs.ogc.org/cs/21-069r2/21-069r2.html#_b8b17e78-0147-4b58-8ade-a19465b57abc
    media_type = "application/vnd.cov+json"


class Domain(TypedDict):
Expand Down Expand Up @@ -74,7 +82,7 @@ def invert_cf_dims(ds):
return inverted


def to_cf_covjson(ds: xr.Dataset) -> CovJSON:
def to_cf_covjson(ds: xr.Dataset) -> CovJSONResponse:
"""Transform an xarray dataset to CoverageJSON using CF conventions"""

covjson: CovJSON = {
Expand Down Expand Up @@ -164,4 +172,4 @@ def to_cf_covjson(ds: xr.Dataset) -> CovJSON:

covjson["ranges"][var] = cov_range

return covjson
return CovJSONResponse(content=covjson)
2 changes: 1 addition & 1 deletion xpublish_edr/formats/to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import Response


def to_csv(ds: xr.Dataset):
def to_csv(ds: xr.Dataset) -> Response:
"""Return a CSV response from an xarray dataset"""
ds = ds.squeeze()
df = ds.to_pandas()
Expand Down
2 changes: 1 addition & 1 deletion xpublish_edr/formats/to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi import Response


def to_netcdf(ds: xr.Dataset):
def to_netcdf(ds: xr.Dataset) -> Response:
"""Return a NetCDF response from a dataset"""
with TemporaryDirectory() as tmpdir:
path = Path(tmpdir) / "position.nc"
Expand Down
164 changes: 102 additions & 62 deletions xpublish_edr/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,35 @@
"""
import importlib
import logging
from typing import List, Optional
from functools import cache
from typing import Hashable, List, Optional, Tuple

import cachey
import dask
import xarray as xr
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from xpublish import Dependencies, Plugin, hookimpl
from xpublish.utils.cache import CostTimer

from .formats.to_covjson import to_cf_covjson
from .query import EDRQuery, edr_query, edr_query_params

logger = logging.getLogger("cf_edr")


def cache_key_from_request(
    route: str,
    request: Request,
    query: EDRQuery,
    dataset: xr.Dataset,
) -> Tuple[Hashable, ...]:
    """Generate a cache key from the request and query parameters.

    Args:
        route: Name of the route being served (e.g. ``"position"``).
        request: Incoming request. Only its URL path and query parameters
            contribute to the key, never the request object itself.
        query: Parsed EDR query parameters.
        dataset: Dataset the query is evaluated against.

    Returns:
        A hashable tuple that compares equal for equivalent requests
        against the same dataset.
    """
    # Ask dask for a deterministic token so the same dataset always maps to
    # the same key component (dask raises instead of using a random token).
    with dask.config.set({"tokenize.ensure-deterministic": True}):
        ds_token = dask.base.tokenize(dataset)

    # Starlette requests hash and compare by identity, so putting the
    # ``Request`` object in the key would make every key unique and the cache
    # could never hit. Use the request's value instead: the path plus the
    # query parameters, sorted so parameter order does not matter.
    # ``items()`` keeps the last value per name, matching
    # ``dict(request.query_params)``.
    request_token = (
        request.url.path,
        tuple(sorted(request.query_params.items())),
    )

    return (route, request_token, query, ds_token)


@cache
def position_formats():
"""
Return response format functions from registered
Expand Down Expand Up @@ -71,85 +88,108 @@ def get_position(
request: Request,
query: EDRQuery = Depends(edr_query),
dataset: xr.Dataset = Depends(deps.dataset),
cache: cachey.Cache = Depends(deps.cache),
):
"""
Returns position data based on WKT `Point(lon lat)` coordinates

Extra selecting/slicing parameters can be provided as extra query parameters
"""
try:
ds = dataset.cf.sel(X=query.point.x, Y=query.point.y, method="nearest")
except KeyError:
raise HTTPException(
status_code=404,
detail="Dataset does not have CF Convention compliant metadata",
)
cache_key = cache_key_from_request("position", request, query, dataset)
response: Optional[Response] = cache.get(cache_key)

if query.z:
ds = dataset.cf.sel(Z=query.z, method="nearest")

if query.datetime:
datetimes = query.datetime.split("/")
if response is not None:
logger.debug(f"Cache hit for {cache_key}")
return response

with CostTimer() as ct:
try:
if len(datetimes) == 1:
ds = ds.cf.sel(T=datetimes[0], method="nearest")
elif len(datetimes) == 2:
ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
else:
raise HTTPException(
status_code=404,
detail="Invalid datetimes submitted",
)
except ValueError as e:
logger.error("Error with datetime", exc_info=True)
raise HTTPException(
status_code=404,
detail=f"Invalid datetime ({e})",
) from e

if query.parameters:
try:
ds = ds.cf[query.parameters.split(",")]
except KeyError as e:
ds = dataset.cf.sel(
X=query.point.x,
Y=query.point.y,
method="nearest",
)
except KeyError:
raise HTTPException(
status_code=404,
detail=f"Invalid variable: {e}",
detail="Dataset does not have CF Convention compliant metadata",
)

logger.debug(f"Dataset filtered by query params {ds}")
if query.z:
ds = dataset.cf.sel(Z=query.z, method="nearest")

if query.datetime:
datetimes = query.datetime.split("/")

try:
if len(datetimes) == 1:
ds = ds.cf.sel(T=datetimes[0], method="nearest")
elif len(datetimes) == 2:
ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
else:
raise HTTPException(
status_code=404,
detail="Invalid datetimes submitted",
)
except ValueError as e:
logger.error("Error with datetime", exc_info=True)
raise HTTPException(
status_code=404,
detail=f"Invalid datetime ({e})",
) from e

query_params = dict(request.query_params)
for query_param in request.query_params:
if query_param in edr_query_params:
del query_params[query_param]
if query.parameters:
try:
ds = ds.cf[query.parameters.split(",")]
except KeyError as e:
raise HTTPException(
status_code=404,
detail=f"Invalid variable: {e}",
)

method: Optional[str] = "nearest"
logger.debug(f"Dataset filtered by query params {ds}")

for key, value in query_params.items():
split_value = value.split("/")
if len(split_value) == 1:
continue
elif len(split_value) == 2:
query_params[key] = slice(split_value[0], split_value[1])
method = None
else:
raise HTTPException(404, f"Too many values for selecting {key}")
query_params = dict(request.query_params)
for query_param in request.query_params:
if query_param in edr_query_params:
del query_params[query_param]

ds = ds.sel(query_params, method=method)
method: Optional[str] = "nearest"

if query.format:
try:
format_fn = position_formats()[query.format]
except KeyError:
raise HTTPException(
404,
f"{query.format} is not a valid format for EDR position queries. "
"Get `./formats` for valid formats",
)
for key, value in query_params.items():
split_value = value.split("/")
if len(split_value) == 1:
continue
elif len(split_value) == 2:
query_params[key] = slice(split_value[0], split_value[1])
method = None
else:
raise HTTPException(
404,
f"Too many values for selecting {key}",
)

return format_fn(ds)
ds = ds.sel(query_params, method=method)

return to_cf_covjson(ds)
if query.format:
try:
format_fn = position_formats()[query.format]
except KeyError:
raise HTTPException(
404,
f"{query.format} is not a valid format for EDR position queries. "
"Get `./formats` for valid formats",
)
else:
format_fn = to_cf_covjson

response = format_fn(ds)
cache.put(
cache_key,
response,
ct.time,
int(response.headers["content-length"]),
)
return response

return router
13 changes: 13 additions & 0 deletions xpublish_edr/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def point(self):
"""Shapely point from WKT query params"""
return wkt.loads(self.coords)

def __hash__(self):
    """Return a hash derived from the query parameter values.

    The hash combines ``coords``, ``z``, ``datetime``, ``parameters``,
    ``crs`` and ``format``, in that order.
    """
    hashed_fields = (
        self.coords,
        self.z,
        self.datetime,
        self.parameters,
        self.crs,
        self.format,
    )
    return hash(hashed_fields)


def edr_query(
coords: str = Query(
Expand Down
Loading