Skip to content

Commit

Permalink
make colormap an independent dependency (#252)
Browse files Browse the repository at this point in the history
* make colormap an independent dependency

* update docs

* add customCmap tests

* update changelog
  • Loading branch information
vincentsarago authored Mar 5, 2021
1 parent 272d62d commit f5f56d4
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

* renamed `OptionalHeaders`, `MimeTypes` and `ImageDrivers` enums to the singular form. (https://github.com/developmentseed/titiler/pull/258)
* renamed `MimeType` to `MediaType` (https://github.com/developmentseed/titiler/pull/258)
* add `ColorMapParams` dependency to ease the creation of custom colormap dependency (https://github.com/developmentseed/titiler/pull/252)

## 0.1.0 (2021-02-17)

Expand Down
22 changes: 17 additions & 5 deletions docs/concepts/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,27 @@ The `factories` allow users to set multiple default dependencies. Here is the li
title="Color Formula",
description="rio-color formula (info: https://github.com/mapbox/rio-color)",
)
color_map: Optional[ColorMapNames] = Query(
None, description="rio-tiler's colormap name"
)
return_mask: bool = Query(True, description="Add mask to the output data.")
colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(init=False)

rescale_range: Optional[List[Union[float, int]]] = field(init=False)

def __post_init__(self):
"""Post Init."""
self.colormap = cmap.get(self.color_map.value) if self.color_map else None
self.rescale_range = (
list(map(float, self.rescale.split(","))) if self.rescale else None
)
```

* **colormap_dependency**: colormap options.

```python
def ColorMapParams(
color_map: ColorMapNames = Query(None, description="Colormap name",)
) -> Optional[Dict]:
"""Colormap Dependency."""
if color_map:
return cmap.get(color_map.value)
return None
```

* **additional_dependency**: Default dependency, will be passed as `**kwargs` to all endpoints.
Expand Down
56 changes: 56 additions & 0 deletions tests/test_CustomCmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# """Test TiTiler Custom Colormap Params."""

from enum import Enum
from io import BytesIO
from typing import Dict, Optional

import numpy
from rio_tiler.colormap import ColorMaps

from titiler.endpoints import factory

from .conftest import DATA_DIR

from fastapi import FastAPI, Query

from starlette.testclient import TestClient

cmap_values = {
"cmap1": {6: (4, 5, 6, 255)},
}
cmap = ColorMaps(data=cmap_values)
ColorMapNames = Enum( # type: ignore
"ColorMapNames", [(a, a) for a in sorted(cmap.list())]
)


def ColorMapParams(
color_map: ColorMapNames = Query(None, description="Colormap name",)
) -> Optional[Dict]:
"""Colormap Dependency."""
if color_map:
return cmap.get(color_map.value)
return None


def test_CustomCmap():
"""Test Custom Render Params dependency."""
app = FastAPI()
cog = factory.TilerFactory(colormap_dependency=ColorMapParams)
app.include_router(cog.router)
client = TestClient(app)

response = client.get(
f"/preview.npy?url={DATA_DIR}/above_cog.tif&bidx=1&color_map=cmap1"
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/x-binary"
data = numpy.load(BytesIO(response.content))
assert 4 in data[0]
assert 5 in data[1]
assert 6 in data[2]

response = client.get(
f"/preview.npy?url={DATA_DIR}/above_cog.tif&bidx=1&color_map=another_cmap"
)
assert response.status_code == 422
16 changes: 10 additions & 6 deletions titiler/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

import numpy
from morecantile import tms
Expand Down Expand Up @@ -68,6 +68,15 @@ def TMSParams(
return tms.get(TileMatrixSetId.name)


def ColorMapParams(
color_map: ColorMapNames = Query(None, description="Colormap name",)
) -> Optional[Dict]:
"""Colormap Dependency."""
if color_map:
return cmap.get(color_map.value)
return None


@dataclass
class DefaultDependency:
"""Dependency Base Class"""
Expand Down Expand Up @@ -327,17 +336,12 @@ class RenderParams(DefaultDependency):
title="Color Formula",
description="rio-color formula (info: https://github.com/mapbox/rio-color)",
)
color_map: Optional[ColorMapNames] = Query(
None, description="rio-tiler's colormap name"
)
return_mask: bool = Query(True, description="Add mask to the output data.")

colormap: Optional[Dict[int, Tuple[int, int, int, int]]] = field(init=False)
rescale_range: Optional[List[Union[float, int]]] = field(init=False)

def __post_init__(self):
"""Post Init."""
self.colormap = cmap.get(self.color_map.value) if self.color_map else None
self.rescale_range = (
list(map(float, self.rescale.split(","))) if self.rescale else None
)
23 changes: 15 additions & 8 deletions titiler/endpoints/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BandsParams,
BidxExprParams,
BidxParams,
ColorMapParams,
DatasetParams,
DefaultDependency,
ImageParams,
Expand Down Expand Up @@ -86,6 +87,8 @@ class BaseTilerFactory(metaclass=abc.ABCMeta):
# Image rendering Dependencies
render_dependency: Type[DefaultDependency] = RenderParams

colormap_dependency: Callable[..., Optional[Dict]] = ColorMapParams

# TileMatrixSet dependency
tms_dependency: Callable[..., TileMatrixSet] = WebMercatorTMSParams

Expand Down Expand Up @@ -298,6 +301,7 @@ def tile(
layer_params=Depends(self.layer_dependency),
dataset_params=Depends(self.dataset_dependency),
render_params=Depends(self.render_dependency),
colormap=Depends(self.colormap_dependency),
kwargs: Dict = Depends(self.additional_dependency),
):
"""Create map tile from a dataset."""
Expand Down Expand Up @@ -338,7 +342,7 @@ def tile(
content = image.render(
add_mask=render_params.return_mask,
img_format=format.driver,
colormap=render_params.colormap or dst_colormap,
colormap=colormap or dst_colormap,
**format.profile,
**render_params.kwargs,
)
Expand Down Expand Up @@ -385,6 +389,7 @@ def tilejson(
layer_params=Depends(self.layer_dependency), # noqa
dataset_params=Depends(self.dataset_dependency), # noqa
render_params=Depends(self.render_dependency), # noqa
colormap=Depends(self.colormap_dependency), # noqa
kwargs: Dict = Depends(self.additional_dependency), # noqa
):
"""Return TileJSON document for a dataset."""
Expand Down Expand Up @@ -452,6 +457,7 @@ def wmts(
layer_params=Depends(self.layer_dependency), # noqa
dataset_params=Depends(self.dataset_dependency), # noqa
render_params=Depends(self.render_dependency), # noqa
colormap=Depends(self.colormap_dependency), # noqa
kwargs: Dict = Depends(self.additional_dependency), # noqa
):
"""OGC WMTS endpoint."""
Expand Down Expand Up @@ -572,6 +578,7 @@ def preview(
img_params=Depends(self.img_dependency),
dataset_params=Depends(self.dataset_dependency),
render_params=Depends(self.render_dependency),
colormap=Depends(self.colormap_dependency),
kwargs: Dict = Depends(self.additional_dependency),
):
"""Create preview of a dataset."""
Expand All @@ -587,9 +594,7 @@ def preview(
**dataset_params.kwargs,
**kwargs,
)
colormap = render_params.colormap or getattr(
src_dst, "colormap", None
)
colormap = colormap or getattr(src_dst, "colormap", None)
timings.append(("dataread", round(t.elapsed * 1000, 2)))

if not format:
Expand Down Expand Up @@ -643,6 +648,7 @@ def part(
image_params=Depends(self.img_dependency),
dataset_params=Depends(self.dataset_dependency),
render_params=Depends(self.render_dependency),
colormap=Depends(self.colormap_dependency),
kwargs: Dict = Depends(self.additional_dependency),
):
"""Create image from part of a dataset."""
Expand All @@ -659,9 +665,7 @@ def part(
**dataset_params.kwargs,
**kwargs,
)
colormap = render_params.colormap or getattr(
src_dst, "colormap", None
)
colormap = colormap or getattr(src_dst, "colormap", None)
timings.append(("dataread", round(t.elapsed * 1000, 2)))

with utils.Timer() as t:
Expand Down Expand Up @@ -1082,6 +1086,7 @@ def tile(
layer_params=Depends(self.layer_dependency),
dataset_params=Depends(self.dataset_dependency),
render_params=Depends(self.render_dependency),
colormap=Depends(self.colormap_dependency),
pixel_selection: PixelSelectionMethod = Query(
PixelSelectionMethod.first, description="Pixel selection method."
),
Expand Down Expand Up @@ -1132,7 +1137,7 @@ def tile(
content = image.render(
add_mask=render_params.return_mask,
img_format=format.driver,
colormap=render_params.colormap,
colormap=colormap,
**format.profile,
**render_params.kwargs,
)
Expand Down Expand Up @@ -1182,6 +1187,7 @@ def tilejson(
layer_params=Depends(self.layer_dependency), # noqa
dataset_params=Depends(self.dataset_dependency), # noqa
render_params=Depends(self.render_dependency), # noqa
colormap=Depends(self.colormap_dependency), # noqa
pixel_selection: PixelSelectionMethod = Query(
PixelSelectionMethod.first, description="Pixel selection method."
), # noqa
Expand Down Expand Up @@ -1247,6 +1253,7 @@ def wmts(
layer_params=Depends(self.layer_dependency), # noqa
dataset_params=Depends(self.dataset_dependency), # noqa
render_params=Depends(self.render_dependency), # noqa
colormap=Depends(self.colormap_dependency), # noqa
pixel_selection: PixelSelectionMethod = Query(
PixelSelectionMethod.first, description="Pixel selection method."
), # noqa
Expand Down

0 comments on commit f5f56d4

Please sign in to comment.