Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add statistics endpoints #347

Merged
merged 4 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 185 additions & 4 deletions src/titiler/core/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
def test_TilerFactory():
"""Test TilerFactory class."""
cog = TilerFactory()
assert len(cog.router.routes) == 24
assert len(cog.router.routes) == 26
assert cog.tms_dependency == TMSParams

cog = TilerFactory(router_prefix="something", tms_dependency=WebMercatorTMSParams)
Expand All @@ -47,7 +47,7 @@ def test_TilerFactory():
response = client.get(f"/something/NZTM2000/tilejson.json?url={DATA_DIR}/cog.tif")
assert response.status_code == 422

cog = TilerFactory(add_preview=False, add_part=False)
cog = TilerFactory(add_preview=False, add_part=False, add_statistics=False)
assert len(cog.router.routes) == 17

app = FastAPI()
Expand Down Expand Up @@ -313,14 +313,147 @@ def test_TilerFactory():
assert meta["width"] == 100
assert meta["height"] == 100

# GET - statistics
response = client.get(f"/statistics?url={DATA_DIR}/cog.tif&bidx=1,1,1")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 3
assert list(resp[0]) == [
"min",
"max",
"mean",
"count",
"sum",
"std",
"median",
"majority",
"minority",
"unique",
"percentile_2",
"percentile_98",
"valid_pixels",
"masked_pixels",
"valid_percent",
]

response = client.get(f"/statistics?url={DATA_DIR}/cog.tif&bidx=1,1,1&p=4")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 3
assert list(resp[0]) == [
"min",
"max",
"mean",
"count",
"sum",
"std",
"median",
"majority",
"minority",
"unique",
"percentile_4",
"valid_pixels",
"masked_pixels",
"valid_percent",
]

response = client.get(f"/statistics?url={DATA_DIR}/cog.tif&categorical=true")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 1
assert list(resp[0]) == [
"categories",
"valid_pixels",
"masked_pixels",
"valid_percent",
]
assert len(resp[0]["categories"]) == 15

response = client.get(
f"/statistics?url={DATA_DIR}/cog.tif&categorical=true&c=1&c=2&c=3&c=4"
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 1
assert list(resp[0]) == [
"categories",
"valid_pixels",
"masked_pixels",
"valid_percent",
]
assert len(resp[0]["categories"]) == 4
assert resp[0]["categories"]["4"] == 0

# POST - statistics
response = client.post(
f"/statistics?url={DATA_DIR}/cog.tif&bidx=1,1,1", data=feature
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 3
assert list(resp[0]) == [
"min",
"max",
"mean",
"count",
"sum",
"std",
"median",
"majority",
"minority",
"unique",
"percentile_2",
"percentile_98",
"valid_pixels",
"masked_pixels",
"valid_percent",
]

response = client.post(
f"/statistics?url={DATA_DIR}/cog.tif&categorical=true", data=feature
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 1
assert list(resp[0]) == [
"categories",
"valid_pixels",
"masked_pixels",
"valid_percent",
]
assert len(resp[0]["categories"]) == 12

response = client.post(
f"/statistics?url={DATA_DIR}/cog.tif&categorical=true&c=1&c=2&c=3&c=4",
data=feature,
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 1
assert list(resp[0]) == [
"categories",
"valid_pixels",
"masked_pixels",
"valid_percent",
]
assert len(resp[0]["categories"]) == 4
assert resp[0]["categories"]["4"] == 0


@patch("rio_tiler.io.cogeo.rasterio")
def test_MultiBaseTilerFactory(rio):
"""test MultiBaseTilerFactory."""
rio.open = mock_rasterio_open

stac = MultiBaseTilerFactory(reader=STACReader)
assert len(stac.router.routes) == 25
assert len(stac.router.routes) == 27

app = FastAPI()
app.include_router(stac.router)
Expand Down Expand Up @@ -371,6 +504,30 @@ def test_MultiBaseTilerFactory(rio):
assert meta["dtype"] == "int32"
assert meta["count"] == 3

# GET - statistics
response = client.get(f"/statistics?url={DATA_DIR}/item.json&assets=B01,B09")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 2
assert list(resp[0]) == [
"min",
"max",
"mean",
"count",
"sum",
"std",
"median",
"majority",
"minority",
"unique",
"percentile_2",
"percentile_98",
"valid_pixels",
"masked_pixels",
"valid_percent",
]


@attr.s
class BandFileReader(MultiBandReader):
Expand Down Expand Up @@ -400,7 +557,7 @@ def test_MultiBandTilerFactory():
"""test MultiBandTilerFactory."""

bands = MultiBandTilerFactory(reader=BandFileReader)
assert len(bands.router.routes) == 25
assert len(bands.router.routes) == 27

app = FastAPI()
app.include_router(bands.router)
Expand Down Expand Up @@ -446,6 +603,30 @@ def test_MultiBandTilerFactory():
assert meta["dtype"] == "int32"
assert meta["count"] == 3

# GET - statistics
response = client.get(f"/statistics?url={DATA_DIR}&bands=B01,B09")
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
resp = response.json()
assert len(resp) == 2
assert list(resp[0]) == [
"min",
"max",
"mean",
"count",
"sum",
"std",
"median",
"majority",
"minority",
"unique",
"percentile_2",
"percentile_98",
"valid_pixels",
"masked_pixels",
"valid_percent",
]


def test_TMSFactory():
"""test TMSFactory."""
Expand Down
90 changes: 88 additions & 2 deletions src/titiler/core/titiler/core/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlencode, urlparse

import rasterio
Expand Down Expand Up @@ -33,7 +33,7 @@
from titiler.core.models.OGC import TileMatrixSetList
from titiler.core.resources.enums import ImageType, MediaType, OptionalHeader
from titiler.core.resources.responses import GeoJSONResponse, XMLResponse
from titiler.core.utils import Timer, bbox_to_feature
from titiler.core.utils import Timer, bbox_to_feature, data_stats

from fastapi import APIRouter, Body, Depends, Path, Query

Expand Down Expand Up @@ -146,6 +146,7 @@ class TilerFactory(BaseTilerFactory):
# Add/Remove some endpoints
add_preview: bool = True
add_part: bool = True
add_statistics: bool = True

def register_routes(self):
"""
Expand All @@ -172,6 +173,9 @@ def register_routes(self):
if self.add_part:
self.part()

if self.add_statistics:
self.statistics()

############################################################################
# /bounds
############################################################################
Expand Down Expand Up @@ -772,6 +776,88 @@ def geojson_crop(

return Response(content, media_type=format.mediatype, headers=headers)

############################################################################
# /statistics (Optional)
############################################################################
def statistics(self):
"""add statistics endpoints."""

@self.router.get(
"/statistics",
responses={
200: {
"content": {"application/json": {}},
"description": "Return dataset's statistics.",
}
},
)
def statistics(
src_path=Depends(self.path_dependency),
layer_params=Depends(self.layer_dependency),
image_params=Depends(self.img_dependency),
dataset_params=Depends(self.dataset_dependency),
categorical: bool = Query(
False, description="Return statistics for categorical dataset."
),
c: List[Union[float, int]] = Query(
None, description="Pixels values for categories."
),
p: List[int] = Query([2, 98], description="Percentiles values."),
kwargs: Dict = Depends(self.additional_dependency),
):
"""Create image from a geojson feature."""
with rasterio.Env(**self.gdal_config):
with self.reader(src_path, **self.reader_options) as src_dst:
data = src_dst.preview(
**layer_params.kwargs,
**image_params.kwargs,
**dataset_params.kwargs,
**kwargs,
).as_masked()

return data_stats(
data, categorical=categorical, categories=c, percentiles=p
)

@self.router.post(
"/statistics",
responses={
200: {
"content": {"application/json": {}},
"description": "Return dataset's statistics.",
}
},
)
def geojson_statistics(
feature: Feature = Body(..., descriptiom="GeoJSON Feature."),
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
src_path=Depends(self.path_dependency),
layer_params=Depends(self.layer_dependency),
image_params=Depends(self.img_dependency),
dataset_params=Depends(self.dataset_dependency),
categorical: bool = Query(
False, description="Return statistics for categorical dataset."
),
c: List[Union[float, int]] = Query(
None, description="Pixels values for categories."
),
p: List[int] = Query([2, 98], description="Percentiles values."),
kwargs: Dict = Depends(self.additional_dependency),
):
"""Create image from a geojson feature."""
with rasterio.Env(**self.gdal_config):
with self.reader(src_path, **self.reader_options) as src_dst:
data = src_dst.feature(
feature.dict(exclude_none=True),
**layer_params.kwargs,
**image_params.kwargs,
**dataset_params.kwargs,
**kwargs,
).as_masked()

return data_stats(
data, categorical=categorical, categories=c, percentiles=p,
)
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class MultiBaseTilerFactory(TilerFactory):
Expand Down
Loading