Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forward dataset statistics to ImageClass #531

Merged
merged 3 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ class ImageData:
crs (rasterio.crs.CRS, optional): Coordinates Reference System of the bounds.
metadata (dict, optional): Additional metadata. Defaults to `{}`.
band_names (list, optional): name of each band. Defaults to `["1", "2", "3"]` for 3 bands image.
dataset_statistics (list, optional): dataset statistics `[(min, max), (min, max)]`

"""

Expand All @@ -266,6 +267,7 @@ class ImageData:
crs: Optional[CRS] = attr.ib(default=None)
metadata: Optional[Dict] = attr.ib(factory=dict)
band_names: List[str] = attr.ib()
dataset_statistics: Optional[Sequence[Tuple[float, float]]] = attr.ib(default=None)

@data.validator
def _validate_data(self, attribute, value):
Expand Down Expand Up @@ -331,8 +333,21 @@ def create_from_list(cls, data: Sequence["ImageData"]):
)
)

stats = list(
itertools.chain.from_iterable(
[img.dataset_statistics for img in data if img.dataset_statistics]
)
)
dataset_statistics = stats if len(stats) == len(band_names) else None

return cls(
arr, mask, assets=assets, crs=crs, bounds=bounds, band_names=band_names
arr,
mask,
assets=assets,
crs=crs,
bounds=bounds,
band_names=band_names,
dataset_statistics=dataset_statistics,
)

def as_masked(self) -> numpy.ma.MaskedArray:
Expand Down Expand Up @@ -391,6 +406,15 @@ def apply_color_formula(self, color_formula: Optional[str]):
def apply_expression(self, expression: str) -> "ImageData":
"""Apply expression to the image data."""
blocks = get_expression_blocks(expression)

stats = self.dataset_statistics
if stats:
res = []
for prod in itertools.product(*stats): # type: ignore
res.append(apply_expression(blocks, self.band_names, numpy.array(prod)))

stats = list(zip([min(r) for r in zip(*res)], [max(r) for r in zip(*res)]))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only really interesting part of the addition (because technically you can access the metadata directly with rasterio). Here we calculate the result of the expression on the min/max values, so we get the min/max of the expression output.


return ImageData(
apply_expression(blocks, self.band_names, self.data),
self.mask,
Expand All @@ -399,6 +423,7 @@ def apply_expression(self, expression: str) -> "ImageData":
bounds=self.bounds,
band_names=blocks,
metadata=self.metadata,
dataset_statistics=stats,
)

def post_process(
Expand Down Expand Up @@ -473,7 +498,7 @@ def render(
kwargs.update({"crs": self.crs})

data = self.data.copy()
datatype_range = (dtype_ranges[str(data.dtype)],)
datatype_range = self.dataset_statistics or (dtype_ranges[str(data.dtype)],)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If present, we use the dataset min/max for the auto rescaling!


if not colormap:
if img_format in ["PNG"] and data.dtype not in ["uint8", "uint16"]:
Expand Down
15 changes: 15 additions & 0 deletions rio_tiler/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def read(
boundless=boundless,
)

stats = []
for ix in indexes:
tags = dataset.tags(ix)
if all(
stat in tags for stat in ["STATISTICS_MINIMUM", "STATISTICS_MAXIMUM"]
):
stat_min = float(tags.get("STATISTICS_MINIMUM"))
stat_max = float(tags.get("STATISTICS_MAXIMUM"))
stats.append((stat_min, stat_max))

# We only add dataset statistics if we have them for all the indexes
dataset_statistics = stats if len(stats) == len(indexes) else None

if force_binary_mask:
mask = numpy.where(mask != 0, numpy.uint8(255), numpy.uint8(0))

Expand All @@ -209,6 +222,7 @@ def read(
bounds=out_bounds,
crs=dataset.crs,
band_names=[f"b{idx}" for idx in indexes],
dataset_statistics=dataset_statistics,
)

return img
Expand Down Expand Up @@ -387,6 +401,7 @@ def part(
bounds=bounds,
crs=img.crs,
band_names=img.band_names,
dataset_statistics=img.dataset_statistics,
)

return read(
Expand Down
Binary file modified tests/fixtures/blue.tif
Binary file not shown.
Binary file added tests/fixtures/cog_rgb.tif
Binary file not shown.
Binary file modified tests/fixtures/green.tif
Binary file not shown.
Binary file modified tests/fixtures/red.tif
Binary file not shown.
13 changes: 13 additions & 0 deletions tests/test_io_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,16 @@ def raise_for_status(self):
s3_get.assert_called_once()
assert s3_get.call_args[1]["request_pays"]
assert s3_get.call_args[0] == ("somewhereovertherainbow.io", "mystac.json")


@patch("rio_tiler.io.cogeo.rasterio")
def test_img_dataset_stats(rio):
    """Check that dataset statistics are forwarded to the image by STACReader."""
    rio.open = mock_rasterio_open

    # Expected (min, max) pairs: one per requested asset, then one for the
    # single-band ratio expression.
    expected_asset_stats = [(6883, 62785), (6101, 65035)]
    expected_ratio_stats = [(6883 / 65035, 62785 / 6101)]

    with STACReader(STAC_PATH) as reader:
        assets_preview = reader.preview(assets=("green", "red"))
        assert assets_preview.dataset_statistics == expected_asset_stats

        ratio_preview = reader.preview(expression="green_b1/red_b1")
        assert ratio_preview.dataset_statistics == expected_ratio_stats
41 changes: 41 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy
import pytest
import rasterio
from rasterio.io import MemoryFile

from rio_tiler.errors import InvalidDatatypeWarning
from rio_tiler.models import ImageData
Expand Down Expand Up @@ -135,3 +136,43 @@ def test_apply_expression():
assert img2.width == 256
assert img2.height == 256
assert img2.band_names == ["b1+b2"]


def test_dataset_statistics():
    """Make sure statistics are preserved on expression and used by render."""
    data = numpy.zeros((2, 256, 256), dtype="uint8")
    data[0, 0:10, 0:10] = 0
    data[0, 10:11, 10:11] = 100
    data[1, 0:10, 0:10] = 100
    data[1, 10:11, 10:11] = 200
    img = ImageData(data, dataset_statistics=[(0, 100), (0, 200)])

    # A single expression block should yield a single (min, max) pair.
    img2 = img.apply_expression("b1+b2")
    assert img2.dataset_statistics == [(0, 300)]

    # Several expression blocks should yield one (min, max) pair per block,
    # and the computed data should stay within those ranges.
    img2 = img.apply_expression("b1+b2;b1*b2;b1/b1")
    assert img2.dataset_statistics == [(0, 300), (0, 20000), (0, 1)]
    assert img2.data[0].min() == 0
    assert img2.data[0].max() == 300
    assert img2.data[1].min() == 0
    assert img2.data[1].max() == 20000
    assert img2.data[2].min() == 0
    assert img2.data[2].max() == 1

    data = numpy.zeros((1, 256, 256), dtype="int16")
    data[0, 0:10, 0:10] = 0
    data[0, 10:11, 10:11] = 1

    # With dataset statistics of (0, 1), the rendered PNG is expected to span
    # the full 0..255 output range.
    rendered = ImageData(data, dataset_statistics=[(0, 1)]).render(img_format="PNG")
    with MemoryFile(rendered) as mem:
        with mem.open() as dst:
            arr = dst.read(indexes=1)
            assert arr.min() == 0
            assert arr.max() == 255

    # Without dataset statistics, the rendered values are expected not to hit
    # either end of the 0..255 range.
    rendered = ImageData(data).render(img_format="PNG")
    with MemoryFile(rendered) as mem:
        with mem.open() as dst:
            arr = dst.read(indexes=1)
            assert arr.min() != 0
            assert arr.max() != 255