Skip to content

Commit

Permalink
adding a CountMethod to default pixel selction and including tests (#676
Browse files Browse the repository at this point in the history
)

* adding a CountMethod to default pixel selction and including tests

* setting count pixel method to output uint8 when mosaic stack is < 256

* running linter

---------

Co-authored-by: Ryan McCarthy <[email protected]>
  • Loading branch information
mccarthyryanc and Ryan McCarthy authored Feb 19, 2024
1 parent 874f141 commit fde5bf0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
33 changes: 33 additions & 0 deletions rio_tiler/mosaic/methods/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,36 @@ def feed(self, array: Optional[numpy.ma.MaskedArray]):
mask = numpy.where(pidex, array.mask, self.mosaic.mask)
self.mosaic = numpy.ma.where(pidex, array, self.mosaic)
self.mosaic.mask = mask


@dataclass
class CountMethod(MosaicMethodBase):
"""Stack the arrays and return the valid pixel count."""

stack: List[numpy.ma.MaskedArray] = field(default_factory=list, init=False)

@property
def data(self) -> Optional[numpy.ma.MaskedArray]:
"""Return valid data count of the data stack."""
if self.stack:
data = numpy.ma.count(numpy.ma.stack(self.stack, axis=0), axis=0)

# only need unint8 for small mosaic stacks
if len(self.stack) < 256:
data = data.astype(numpy.uint8)

# only need the counts from one band
if len(data.shape) > 2:
data = data[0]

# mask is always empty
mask = numpy.zeros(data.shape, dtype=bool)
array = numpy.ma.MaskedArray(data, mask)

return array

return None

def feed(self, array: Optional[numpy.ma.MaskedArray]):
"""Add array to the stack."""
self.stack.append(array)
10 changes: 10 additions & 0 deletions tests/test_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,16 @@ class aClass(object):
assert t.dtype == "uint16"
assert m.dtype == "uint8"

# Test count pixel selection
(t, m), _ = mosaic.mosaic_reader(
assets, _read_tile, x, y, z, pixel_selection=defaults.CountMethod()
)
assert t.shape == (1, 256, 256)
assert m.shape == (256, 256)
assert m.all()
assert t.dtype == "uint8"
assert m.dtype == "uint8"


def mock_rasterio_open(asset):
"""Mock rasterio Open."""
Expand Down

0 comments on commit fde5bf0

Please sign in to comment.