diff --git a/CHANGES.md b/CHANGES.md index 1a1f8843..78f2564d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -13,6 +13,7 @@ * Adding support for GDAL VRT Connection string for STAC Assets * Improve type hint definition * make `ImageData.rescale` and `ImageData.apply_color_formula` to return `self` +* add support for `.json` colormap files # 6.7.0 (2024-09-05) diff --git a/rio_tiler/colormap.py b/rio_tiler/colormap.py index 0f65aa2a..04ec96f7 100644 --- a/rio_tiler/colormap.py +++ b/rio_tiler/colormap.py @@ -1,5 +1,6 @@ """rio-tiler colormap functions and classes.""" +import json import os import pathlib import re @@ -22,23 +23,30 @@ ) try: - from importlib.resources import files as resources_files # type: ignore + from importlib.resources import as_file + from importlib.resources import files as resources_files except ImportError: # Try backported to PY<39 `importlib_resources`. + from importlib_resources import as_file # type: ignore from importlib_resources import files as resources_files # type: ignore EMPTY_COLORMAP: GDALColorMapType = {i: (0, 0, 0, 0) for i in range(256)} -DEFAULT_CMAPS_FILES = { - f.stem: str(f) - for f in (resources_files(__package__) / "cmap_data").glob("*.npy") # type: ignore -} +_RIO_CMAP_DIR = resources_files(__package__) / "cmap_data" +with as_file(_RIO_CMAP_DIR) as p: + DEFAULT_CMAPS_FILES = { + f.stem: f for f in p.glob("**/*") if f.suffix in {".npy", ".json"} + } USER_CMAPS_DIR = os.environ.get("COLORMAP_DIRECTORY", None) if USER_CMAPS_DIR: DEFAULT_CMAPS_FILES.update( - {f.stem: str(f) for f in pathlib.Path(USER_CMAPS_DIR).glob("*.npy")} + { + f.stem: f + for f in pathlib.Path(USER_CMAPS_DIR).glob("**/*") + if f.suffix in {".npy", ".json"} + } ) @@ -274,7 +282,7 @@ class ColorMaps: """ - data: Dict[str, Union[str, ColorMapType]] = attr.ib( + data: Dict[str, Union[str, pathlib.Path, ColorMapType]] = attr.ib( default=attr.Factory(lambda: DEFAULT_CMAPS_FILES) ) @@ -292,13 +300,37 @@ def get(self, name: str) -> ColorMapType: if cmap is None: raise InvalidColorMapName(f"Invalid colormap name: {name}") - if isinstance(cmap, str): - colormap = numpy.load(cmap) - assert colormap.shape == (256, 4) - assert colormap.dtype == numpy.uint8 - return {idx: tuple(value) for idx, value in enumerate(colormap)} # type: ignore - else: - return cmap + if isinstance(cmap, (pathlib.Path, str)): + if isinstance(cmap, str): + cmap = pathlib.Path(cmap) + + if cmap.suffix == ".npy": + colormap = numpy.load(cmap) + assert colormap.shape == (256, 4) + assert colormap.dtype == numpy.uint8 + return {idx: tuple(value) for idx, value in enumerate(colormap)} + + elif cmap.suffix == ".json": + with cmap.open() as f: + cmap_data = json.load( + f, + object_hook=lambda x: { + int(k): parse_color(v) for k, v in x.items() + }, + ) + + # Make sure to match colormap type + if isinstance(cmap_data, Sequence): + cmap_data = [ + (tuple(inter), parse_color(v)) # type: ignore + for (inter, v) in cmap_data + ] + + return cmap_data + + raise ValueError(f"Not supported {cmap.suffix} extension for ColorMap") + + return cmap def list(self) -> List[str]: """List registered Colormaps. @@ -311,7 +343,7 @@ def list(self) -> List[str]: def register( self, - custom_cmap: Dict[str, Union[str, ColorMapType]], + custom_cmap: Dict[str, Union[str, pathlib.Path, ColorMapType]], overwrite: bool = False, ) -> "ColorMaps": """Register a custom colormap. diff --git a/tests/fixtures/cmap/bad.json b/tests/fixtures/cmap/bad.json new file mode 100644 index 00000000..2cfa4836 --- /dev/null +++ b/tests/fixtures/cmap/bad.json @@ -0,0 +1,3 @@ +{ + "reallybad": "something bad" +} diff --git a/tests/fixtures/cmap/nlcd.json b/tests/fixtures/cmap/nlcd.json new file mode 100644 index 00000000..df84b7e7 --- /dev/null +++ b/tests/fixtures/cmap/nlcd.json @@ -0,0 +1,22 @@ +{ + "11": "#486DA2", + "12": "#E7EFFC", + "21": "#E1CDCE", + "22": "#DC9881", + "23": "#F10100", + "24": "#AB0101", + "31": "#B3AFA4", + "41": "#6BA966", + "42": "#1D6533", + "43": "#BDCC93", + "51": "#B29C46", + "52": "#D1BB82", + "71": "#EDECCD", + "72": "#D0D181", + "73": "#A4CC51", + "74": "#82BA9D", + "81": "#DDD83E", + "82": "#AE7229", + "90": "#BBD7ED", + "95": "#71A4C1" +} diff --git a/tests/fixtures/cmap/sequence.json b/tests/fixtures/cmap/sequence.json new file mode 100644 index 00000000..cef5acb4 --- /dev/null +++ b/tests/fixtures/cmap/sequence.json @@ -0,0 +1,26 @@ +[ + [ + [ + 1, + 2 + ], + [ + 255, + 0, + 0, + 255 + ] + ], + [ + [ + 2, + 3 + ], + [ + 255, + 240, + 255, + 255 + ] + ] +] diff --git a/tests/test_cmap.py b/tests/test_cmap.py index 3760d9d8..7d1c42c6 100644 --- a/tests/test_cmap.py +++ b/tests/test_cmap.py @@ -1,12 +1,14 @@ """tests rio_tiler colormaps""" +import json +import os +import pathlib from copy import deepcopy import numpy import pytest from rio_tiler import colormap -from rio_tiler.colormap import DEFAULT_CMAPS_FILES from rio_tiler.errors import ( ColorMapAlreadyRegistered, InvalidColorFormat, @@ -14,20 +16,30 @@ InvalidFormat, ) +try: + from importlib.resources import as_file +except ImportError: + # Try backported to PY<39 `importlib_resources`. + from importlib_resources import as_file # type: ignore + +PREFIX = os.path.join(os.path.dirname(__file__), "fixtures", "cmap") + colormap_number = 211 +with as_file(colormap._RIO_CMAP_DIR) as p: + DEFAULT_CMAPS_FILES = { + f.stem: f for f in p.glob("**/*") if f.suffix in {".npy", ".json"} + } + -def test_get_cmaplist(monkeypatch): +def test_get_cmaplist(): """Should work as expected return all rio-tiler colormaps.""" - monkeypatch.delenv("COLORMAP_DIRECTORY", raising=False) assert len(DEFAULT_CMAPS_FILES) == colormap_number -def test_cmapObject(monkeypatch): +def test_cmapObject(): """Test Colormap object handler.""" - monkeypatch.delenv("COLORMAP_DIRECTORY", raising=False) - - cmap = colormap.cmap + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) assert len(cmap.list()) == colormap_number with pytest.raises(InvalidColorMapName): @@ -54,17 +66,45 @@ def test_cmapObject(monkeypatch): assert new_cmap.get("empty") +def test_cmap_json(): + """Test Colormap with JSON files.""" + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) + assert len(cmap.list()) == colormap_number + + new_cmap = cmap.register( + { + "nlcd": pathlib.Path(PREFIX) / "nlcd.json", + "sequence": pathlib.Path(PREFIX) / "sequence.json", + "bad": pathlib.Path(PREFIX) / "bad.json", + } + ) + assert len(new_cmap.list()) == colormap_number + 3 + nlcd = new_cmap.get("nlcd") + assert isinstance(nlcd, dict) + assert nlcd[11] == (72, 109, 162, 255) + + seq = new_cmap.get("sequence") + assert isinstance(seq, list) + assert seq[0][0] == (1, 2) + assert seq[0][1] == (255, 0, 0, 255) + + with pytest.raises((json.JSONDecodeError, ValueError)): + new_cmap.get("bad") + + def test_valid_cmaps(): """Make sure all colormaps have 4 values and 256 items.""" - for c in colormap.cmap.list(): - cm = colormap.cmap.get(c) + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) + for c in cmap.list(): + cm = cmap.get(c) assert len(cm[0]) == 4 assert len(cm.items()) == 256 def test_update_alpha(): """Should update the alpha channel.""" - cm = colormap.cmap.get("viridis") + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) + cm = cmap.get("viridis") idx = 1 assert cm[idx][-1] == 255 colormap._update_alpha(cm, idx) @@ -83,7 +123,8 @@ def test_update_alpha(): def test_remove_value(): """Should remove cmap value.""" - cm = colormap.cmap.get("viridis") + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) + cm = cmap.get("viridis") idx = 1 colormap._remove_value(cm, idx) assert not cm.get(1) @@ -96,7 +137,8 @@ def test_remove_value(): def test_update_cmap(): """Should update the colormap.""" - cm = colormap.cmap.get("viridis") + cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES) + cm = cmap.get("viridis") val = {1: (0, 0, 0, 0), 2: (255, 255, 255, 255)} colormap._update_cmap(cm, val) assert cm[1] == (0, 0, 0, 0)