Skip to content

Commit

Permalink
add support for JSON colormap files
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Oct 3, 2024
1 parent 0a84561 commit cb3fa3a
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Adding support for GDAL VRT Connection string for STAC Assets
* Improve type hint definition
* make `ImageData.rescale` and `ImageData.apply_color_formula` to return `self`
* add support for `.json` colormap files

# 6.7.0 (2024-09-05)

Expand Down
62 changes: 47 additions & 15 deletions rio_tiler/colormap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""rio-tiler colormap functions and classes."""

import json
import os
import pathlib
import re
Expand All @@ -22,23 +23,30 @@
)

try:
from importlib.resources import files as resources_files # type: ignore
from importlib.resources import as_file
from importlib.resources import files as resources_files
except ImportError:
# Try backported to PY<39 `importlib_resources`.
from importlib_resources import as_file # type: ignore
from importlib_resources import files as resources_files # type: ignore


EMPTY_COLORMAP: GDALColorMapType = {i: (0, 0, 0, 0) for i in range(256)}

DEFAULT_CMAPS_FILES = {
f.stem: str(f)
for f in (resources_files(__package__) / "cmap_data").glob("*.npy") # type: ignore
}
_RIO_CMAP_DIR = resources_files(__package__) / "cmap_data"
with as_file(_RIO_CMAP_DIR) as p:
DEFAULT_CMAPS_FILES = {
f.stem: f for f in p.glob("**/*") if f.suffix in {".npy", ".json"}
}

USER_CMAPS_DIR = os.environ.get("COLORMAP_DIRECTORY", None)
if USER_CMAPS_DIR:
DEFAULT_CMAPS_FILES.update(
{f.stem: str(f) for f in pathlib.Path(USER_CMAPS_DIR).glob("*.npy")}
{
f.stem: f
for f in pathlib.Path(USER_CMAPS_DIR).glob("**/*")
if f.suffix in {".npy", ".json"}
}
)


Expand Down Expand Up @@ -274,7 +282,7 @@ class ColorMaps:
"""

data: Dict[str, Union[str, ColorMapType]] = attr.ib(
data: Dict[str, Union[str, pathlib.Path, ColorMapType]] = attr.ib(
default=attr.Factory(lambda: DEFAULT_CMAPS_FILES)
)

Expand All @@ -292,13 +300,37 @@ def get(self, name: str) -> ColorMapType:
if cmap is None:
raise InvalidColorMapName(f"Invalid colormap name: {name}")

if isinstance(cmap, str):
colormap = numpy.load(cmap)
assert colormap.shape == (256, 4)
assert colormap.dtype == numpy.uint8
return {idx: tuple(value) for idx, value in enumerate(colormap)} # type: ignore
else:
return cmap
if isinstance(cmap, (pathlib.Path, str)):
if isinstance(cmap, str):
cmap = pathlib.Path(cmap)

if cmap.suffix == ".npy":
colormap = numpy.load(cmap)
assert colormap.shape == (256, 4)
assert colormap.dtype == numpy.uint8
return {idx: tuple(value) for idx, value in enumerate(colormap)}

elif cmap.suffix == ".json":
with cmap.open() as f:
cmap_data = json.load(
f,
object_hook=lambda x: {
int(k): parse_color(v) for k, v in x.items()
},
)

# Make sure to match colormap type
if isinstance(cmap_data, Sequence):
cmap_data = [
(tuple(inter), parse_color(v)) # type: ignore
for (inter, v) in cmap_data
]

return cmap_data

raise ValueError(f"Not supported {cmap.suffix} extension for ColorMap")

return cmap

def list(self) -> List[str]:
"""List registered Colormaps.
Expand All @@ -311,7 +343,7 @@ def list(self) -> List[str]:

def register(
self,
custom_cmap: Dict[str, Union[str, ColorMapType]],
custom_cmap: Dict[str, Union[str, pathlib.Path, ColorMapType]],
overwrite: bool = False,
) -> "ColorMaps":
"""Register a custom colormap.
Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/cmap/bad.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"reallybad": "something bad"
}
22 changes: 22 additions & 0 deletions tests/fixtures/cmap/nlcd.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"11": "#486DA2",
"12": "#E7EFFC",
"21": "#E1CDCE",
"22": "#DC9881",
"23": "#F10100",
"24": "#AB0101",
"31": "#B3AFA4",
"41": "#6BA966",
"42": "#1D6533",
"43": "#BDCC93",
"51": "#B29C46",
"52": "#D1BB82",
"71": "#EDECCD",
"72": "#D0D181",
"73": "#A4CC51",
"74": "#82BA9D",
"81": "#DDD83E",
"82": "#AE7229",
"90": "#BBD7ED",
"95": "#71A4C1"
}
26 changes: 26 additions & 0 deletions tests/fixtures/cmap/sequence.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[
[
[
1,
2
],
[
255,
0,
0,
255
]
],
[
[
2,
3
],
[
255,
240,
255,
255
]
]
]
66 changes: 54 additions & 12 deletions tests/test_cmap.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,45 @@
"""tests rio_tiler colormaps"""

import json
import os
import pathlib
from copy import deepcopy

import numpy
import pytest

from rio_tiler import colormap
from rio_tiler.colormap import DEFAULT_CMAPS_FILES
from rio_tiler.errors import (
ColorMapAlreadyRegistered,
InvalidColorFormat,
InvalidColorMapName,
InvalidFormat,
)

try:
from importlib.resources import as_file
except ImportError:
# Try backported to PY<39 `importlib_resources`.
from importlib_resources import as_file # type: ignore

PREFIX = os.path.join(os.path.dirname(__file__), "fixtures", "cmap")

colormap_number = 211

with as_file(colormap._RIO_CMAP_DIR) as p:
DEFAULT_CMAPS_FILES = {
f.stem: f for f in p.glob("**/*") if f.suffix in {".npy", ".json"}
}


def test_get_cmaplist(monkeypatch):
def test_get_cmaplist():
"""Should work as expected return all rio-tiler colormaps."""
monkeypatch.delenv("COLORMAP_DIRECTORY", raising=False)
assert len(DEFAULT_CMAPS_FILES) == colormap_number


def test_cmapObject(monkeypatch):
def test_cmapObject():
"""Test Colormap object handler."""
monkeypatch.delenv("COLORMAP_DIRECTORY", raising=False)

cmap = colormap.cmap
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
assert len(cmap.list()) == colormap_number

with pytest.raises(InvalidColorMapName):
Expand All @@ -54,17 +66,45 @@ def test_cmapObject(monkeypatch):
assert new_cmap.get("empty")


def test_cmap_json():
"""Test Colormap with JSON files."""
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
assert len(cmap.list()) == colormap_number

new_cmap = cmap.register(
{
"nlcd": pathlib.Path(PREFIX) / "nlcd.json",
"sequence": pathlib.Path(PREFIX) / "sequence.json",
"bad": pathlib.Path(PREFIX) / "bad.json",
}
)
assert len(new_cmap.list()) == colormap_number + 3
nlcd = new_cmap.get("nlcd")
assert isinstance(nlcd, dict)
assert nlcd[11] == (72, 109, 162, 255)

seq = new_cmap.get("sequence")
assert isinstance(seq, list)
assert seq[0][0] == (1, 2)
assert seq[0][1] == (255, 0, 0, 255)

with pytest.raises((json.JSONDecodeError, ValueError)):
new_cmap.get("bad")


def test_valid_cmaps():
"""Make sure all colormaps have 4 values and 256 items."""
for c in colormap.cmap.list():
cm = colormap.cmap.get(c)
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
for c in cmap.list():
cm = cmap.get(c)
assert len(cm[0]) == 4
assert len(cm.items()) == 256


def test_update_alpha():
"""Should update the alpha channel."""
cm = colormap.cmap.get("viridis")
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
cm = cmap.get("viridis")
idx = 1
assert cm[idx][-1] == 255
colormap._update_alpha(cm, idx)
Expand All @@ -83,7 +123,8 @@ def test_update_alpha():

def test_remove_value():
"""Should remove cmap value."""
cm = colormap.cmap.get("viridis")
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
cm = cmap.get("viridis")
idx = 1
colormap._remove_value(cm, idx)
assert not cm.get(1)
Expand All @@ -96,7 +137,8 @@ def test_remove_value():

def test_update_cmap():
"""Should update the colormap."""
cm = colormap.cmap.get("viridis")
cmap = colormap.ColorMaps(data=DEFAULT_CMAPS_FILES)
cm = cmap.get("viridis")
val = {1: (0, 0, 0, 0), 2: (255, 255, 255, 255)}
colormap._update_cmap(cm, val)
assert cm[1] == (0, 0, 0, 0)
Expand Down

0 comments on commit cb3fa3a

Please sign in to comment.