From fde5bf03b7fe94007b50dad6aec09fcecddfd414 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 19 Feb 2024 05:57:27 -0800 Subject: [PATCH] adding a CountMethod to default pixel selction and including tests (#676) * adding a CountMethod to default pixel selction and including tests * setting count pixel method to output uint8 when mosaic stack is < 256 * running linter --------- Co-authored-by: Ryan McCarthy --- rio_tiler/mosaic/methods/defaults.py | 33 ++++++++++++++++++++++++++++ tests/test_mosaic.py | 10 +++++++++ 2 files changed, 43 insertions(+) diff --git a/rio_tiler/mosaic/methods/defaults.py b/rio_tiler/mosaic/methods/defaults.py index 8e163808..d93aef4c 100644 --- a/rio_tiler/mosaic/methods/defaults.py +++ b/rio_tiler/mosaic/methods/defaults.py @@ -188,3 +188,36 @@ def feed(self, array: Optional[numpy.ma.MaskedArray]): mask = numpy.where(pidex, array.mask, self.mosaic.mask) self.mosaic = numpy.ma.where(pidex, array, self.mosaic) self.mosaic.mask = mask + + +@dataclass +class CountMethod(MosaicMethodBase): + """Stack the arrays and return the valid pixel count.""" + + stack: List[numpy.ma.MaskedArray] = field(default_factory=list, init=False) + + @property + def data(self) -> Optional[numpy.ma.MaskedArray]: + """Return valid data count of the data stack.""" + if self.stack: + data = numpy.ma.count(numpy.ma.stack(self.stack, axis=0), axis=0) + + # only need unint8 for small mosaic stacks + if len(self.stack) < 256: + data = data.astype(numpy.uint8) + + # only need the counts from one band + if len(data.shape) > 2: + data = data[0] + + # mask is always empty + mask = numpy.zeros(data.shape, dtype=bool) + array = numpy.ma.MaskedArray(data, mask) + + return array + + return None + + def feed(self, array: Optional[numpy.ma.MaskedArray]): + """Add array to the stack.""" + self.stack.append(array) diff --git a/tests/test_mosaic.py b/tests/test_mosaic.py index 5fb1a2d3..ca2cf9df 100644 --- a/tests/test_mosaic.py +++ b/tests/test_mosaic.py @@ -285,6 +285,16 @@ class aClass(object): assert t.dtype == "uint16" assert m.dtype == "uint8" + # Test count pixel selection + (t, m), _ = mosaic.mosaic_reader( + assets, _read_tile, x, y, z, pixel_selection=defaults.CountMethod() + ) + assert t.shape == (1, 256, 256) + assert m.shape == (256, 256) + assert m.all() + assert t.dtype == "uint8" + assert m.dtype == "uint8" + def mock_rasterio_open(asset): """Mock rasterio Open."""