Skip to content

Commit

Permalink
Add an `include_value_counts` option to `zonal_statistics` and start a test for it.
Browse files Browse the repository at this point in the history
  • Loading branch information
phargogh committed Aug 10, 2022
1 parent 4929c2f commit 9e2820f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 7 deletions.
43 changes: 36 additions & 7 deletions src/pygeoprocessing/geoprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,8 @@ def interpolate_points(
def zonal_statistics(
base_raster_path_band, aggregate_vector_path,
aggregate_layer_name=None, ignore_nodata=True,
polygons_might_overlap=True, working_dir=None):
polygons_might_overlap=True, include_value_counts=False,
working_dir=None):
"""Collect stats on pixel values which lie within polygons.
This function summarizes raster statistics including min, max,
Expand Down Expand Up @@ -1183,6 +1184,10 @@ def zonal_statistics(
computationally expensive for cases where there are many polygons.
this flag to False directs the function rasterize in one
step.
include_value_counts (boolean): If True, the function tallies the
number of pixels of each value under the polygon. This is useful
for classified rasters but could exhaust available memory when run
on a continuous (floating-point) raster. Defaults to False.
working_dir (string): If not None, indicates where temporary files
should be created during this run.
Expand All @@ -1194,7 +1199,12 @@ def zonal_statistics(
'max': 1,
'sum': 1.7,
'count': 3,
'nodata_count': 1
'nodata_count': 1,
'value_counts': {
2: 5,
4: 1,
14: 2,
}
}
}
Expand Down Expand Up @@ -1234,12 +1244,24 @@ def zonal_statistics(

# clip base raster to aggregating vector intersection
raster_info = get_raster_info(base_raster_path_band[0])

if (raster_info['datatype'] in {gdal.GDT_Float32, gdal.GDT_Float64}
and include_value_counts):
LOGGER.warning(
"Value counts requested on a floating-point raster, which can "
"cause excessive memory usage.")

# -1 here because bands are 1 indexed
raster_nodata = raster_info['nodata'][base_raster_path_band[1]-1]
temp_working_dir = tempfile.mkdtemp(dir=working_dir)
clipped_raster_path = os.path.join(
temp_working_dir, 'clipped_raster.tif')

sample_aggregate_dict = {
'min': None, 'max': None, 'count': 0, 'nodata_count': 0, 'sum': 0.0}
if include_value_counts:
sample_aggregate_dict['value_counts'] = {}

try:
align_and_resize_raster_stack(
[base_raster_path_band[0]], [clipped_raster_path], ['near'],
Expand All @@ -1254,9 +1276,7 @@ def zonal_statistics(
"aggregate vector %s does not intersect with the raster %s",
aggregate_vector_path, base_raster_path_band)
aggregate_stats = collections.defaultdict(
lambda: {
'min': None, 'max': None, 'count': 0, 'nodata_count': 0,
'sum': 0.0})
lambda: sample_aggregate_dict)
for feature in aggregate_layer:
_ = aggregate_stats[feature.GetFID()]
return dict(aggregate_stats)
Expand Down Expand Up @@ -1297,8 +1317,7 @@ def zonal_statistics(
iterblocks((agg_fid_raster_path, 1), offset_only=True))
agg_fid_raster = gdal.OpenEx(
agg_fid_raster_path, gdal.GA_Update | gdal.OF_RASTER)
aggregate_stats = collections.defaultdict(lambda: {
'min': None, 'max': None, 'count': 0, 'nodata_count': 0, 'sum': 0.0})
aggregate_stats = collections.defaultdict(lambda: sample_aggregate_dict)
last_time = time.time()
LOGGER.info("processing %d disjoint polygon sets", len(disjoint_fid_sets))
for set_index, disjoint_fid_set in enumerate(disjoint_fid_sets):
Expand Down Expand Up @@ -1403,6 +1422,12 @@ def zonal_statistics(
masked_clipped_block.size)
aggregate_stats[agg_fid]['sum'] += numpy.sum(
masked_clipped_block)

if include_value_counts:
aggregate_stats[agg_fid]['value_counts'] = dict(
pair for pair in zip(
*numpy.unique(masked_clipped_block,
return_counts=True)))
unset_fids = aggregate_layer_fid_set.difference(aggregate_stats)
LOGGER.debug(
"unset_fids: %s of %s ", len(unset_fids),
Expand Down Expand Up @@ -1487,6 +1512,10 @@ def zonal_statistics(
aggregate_stats[unset_fid]['count'] = valid_unset_fid_block.size
aggregate_stats[unset_fid]['nodata_count'] = numpy.count_nonzero(
unset_fid_nodata_mask)
if include_value_counts:
aggregate_stats[unset_fid]['value_counts'] = dict(
pair for pair in zip(*numpy.unique(valid_unset_fid_block,
return_counts=True)))

unset_fids = aggregate_layer_fid_set.difference(aggregate_stats)
LOGGER.debug(
Expand Down
73 changes: 73 additions & 0 deletions tests/test_geoprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,79 @@ def test_zonal_statistics(self):
'sum': 0.0}}
self.assertEqual(result, expected_result)

def test_zonal_statistics_value_counts(self):
    """PGP.geoprocessing: test zonal stats function (value counts)."""
    # Three aggregating polygons, chosen to exercise distinct cases:
    #   * polygon_a covers the full 9x9 raster extent (81 pixels)
    #   * polygon_b covers exactly the single upper-left pixel
    #   * polygon_c lies far outside the raster, so it overlaps nothing
    pixel_size = 30.0
    n_pixels = 9
    origin = (444720, 3751320)
    extent = pixel_size * n_pixels
    polygon_a = shapely.geometry.Polygon([
        (origin[0], origin[1]),
        (origin[0], origin[1] - extent),
        (origin[0] + extent, origin[1] - extent),
        (origin[0] + extent, origin[1]),
        (origin[0], origin[1])])
    polygon_b = shapely.geometry.Polygon([
        (origin[0], origin[1]),
        (origin[0], origin[1] - pixel_size),
        (origin[0] + pixel_size, origin[1] - pixel_size),
        (origin[0] + pixel_size, origin[1]),
        (origin[0], origin[1])])
    # Anchor polygon_c at coordinates well away from the raster's bounds.
    far_x = origin[1] * 2
    far_y = origin[1] * 3
    polygon_c = shapely.geometry.Polygon([
        (far_x, far_y),
        (far_x, far_y - pixel_size),
        (far_x + pixel_size, far_y - pixel_size),
        (far_x + pixel_size, far_y),
        (far_x, far_y)])
    aggregating_vector_path = os.path.join(
        self.workspace_dir, 'aggregate_vector')
    _geometry_to_vector(
        [polygon_a, polygon_b, polygon_c], aggregating_vector_path)

    # A 9x9 float32 raster of all 1s, with no nodata value defined.
    pixel_matrix = numpy.ones((n_pixels, n_pixels), numpy.float32)
    target_nodata = None
    raster_path = os.path.join(self.workspace_dir, 'raster.tif')
    _array_to_raster(pixel_matrix, target_nodata, raster_path)

    with capture_logging(
            logging.getLogger('pygeoprocessing')) as log_messages:
        result = pygeoprocessing.zonal_statistics(
            (raster_path, 1), aggregating_vector_path,
            aggregate_layer_name=None,
            ignore_nodata=True,
            include_value_counts=True,
            polygons_might_overlap=True)

    # The raster is float32, so exactly one memory-usage warning is
    # expected from requesting value counts.
    self.assertEqual(len(log_messages), 1)
    self.assertEqual(log_messages[0].level, logging.WARNING)

    expected_result = {
        0: {'count': 81,
            'max': 1.0,
            'min': 1.0,
            'nodata_count': 0,
            'sum': 81.0,
            'value_counts': {1.0: 81}},
        1: {'count': 1,
            'max': 1.0,
            'min': 1.0,
            'nodata_count': 0,
            'sum': 1.0,
            'value_counts': {1.0: 1}},
        # polygon_c never touches the raster: empty stats, empty counts.
        2: {'min': None,
            'max': None,
            'count': 0,
            'nodata_count': 0,
            'sum': 0.0,
            'value_counts': {}},
    }
    self.assertEqual(result, expected_result)

def test_zonal_statistics_nodata(self):
"""PGP.geoprocessing: test zonal stats function with non-overlap."""
# create aggregating polygon
Expand Down

0 comments on commit 9e2820f

Please sign in to comment.