Skip to content

Commit

Permalink
refactor benchmark (#743)
Browse files Browse the repository at this point in the history
* refactor benchmark

* optional filled

* update benchmark names
  • Loading branch information
vincentsarago authored Oct 4, 2024
1 parent a2302cb commit bb33ffb
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 33 deletions.
49 changes: 22 additions & 27 deletions tests/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,38 +71,33 @@
def read_tile(dst, tile):
"""Benchmark rio-tiler.utils._tile_read."""
# We make sure to not store things in cache.
with rasterio.Env(GDAL_CACHEMAX=0, NUM_THREADS="all"):
with Reader(None, dataset=dst) as src:
return src.tile(*tile)
with Reader(None, dataset=dst) as src:
return src.tile(*tile)


data_types = list(dtype_ranges.keys())
nodata_type = ["nodata", "alpha", "mask", "none"]


@pytest.mark.parametrize("tile_name", ["full"])
@pytest.mark.parametrize("dataset_name", ["equator", "dateline"])
@pytest.mark.parametrize("data_type", list(dtype_ranges.keys()))
@pytest.mark.parametrize("nodata_type", ["nodata", "alpha", "mask", "none"])
def test_tile(
nodata_type, data_type, dataset_name, tile_name, dataset_fixture, benchmark
):
def test_tile(nodata_type, data_type, dataset_name, dataset_fixture, benchmark):
"""Test tile read for multiple combination of datatype/mask/tile extent."""
benchmark.name = f"{data_type}-{nodata_type}"
benchmark.group = f"{dataset_name} - {tile_name} tile "
tile = benchmark_tiles[dataset_name][tile_name]
benchmark.name = f"{dataset_name}-{data_type}-{nodata_type}"
benchmark.fullname = f"{dataset_name}-{data_type}-{nodata_type}"
benchmark.group = dataset_name

tile = benchmark_tiles[dataset_name]["full"]
dst_info = datasets[dataset_name]
with MemoryFile(
dataset_fixture(
crs=dst_info["crs"],
bounds=list(dst_info["bounds"]),
dtype=data_type,
nodata_type=nodata_type,
width=256,
height=256,
)
) as memfile:
with memfile.open() as dst:
img = benchmark(read_tile, dst, tile)
assert img.data.dtype == data_type
with rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR", NUM_THREADS="all"):
with MemoryFile(
dataset_fixture(
crs=dst_info["crs"],
bounds=list(dst_info["bounds"]),
dtype=data_type,
nodata_type=nodata_type,
width=4000,
height=4000,
filled=True,
)
) as memfile:
with memfile.open() as dst:
img = benchmark(read_tile, dst, tile)
assert img.data.dtype == data_type
36 changes: 30 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import rasterio
from rasterio.crs import CRS
from rasterio.enums import ColorInterp
from rasterio.enums import Resampling as ResamplingEnums
from rasterio.io import MemoryFile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.shutil import copy
from rasterio.transform import from_bounds

with rasterio.Env() as env:
Expand All @@ -32,17 +35,22 @@ def _dataset(
nband: int = 3,
width: int = 256,
height: int = 256,
filled: bool = False,
):
max_value = 127 if dtype == "int8" else 255

# Data
arr = numpy.zeros((nband, height, width), dtype=dtype) + 1
if filled:
arr[:, range(height), range(width)] = max_value
arr[:, range(height - 1, 0, -1), range(width - 1)] = max_value
arr[:, :, width // 2] = max_value
arr[:, height // 2, :] = max_value

arr[:, 0:128, 0:128] = 0

# Mask/Alpha
if dtype == "int8":
mask = numpy.zeros((1, height, width), dtype=dtype) + 127
else:
mask = numpy.zeros((1, height, width), dtype=dtype) + 255

mask = numpy.zeros((1, height, width), dtype=dtype) + max_value
mask[:, 0:128, 0:128] = 0

# Input Profile
Expand Down Expand Up @@ -89,6 +97,22 @@ def _dataset(
if nodata_type == "mask":
mem.write_mask(mask[0])

return BytesIO(memfile.read())
overview_level = get_maximum_overview_level(
mem.width, mem.height, minsize=512
)
overviews = [2**j for j in range(1, overview_level + 1)]
mem.build_overviews(overviews, ResamplingEnums.bilinear)

cog_profile = {
"interleave": "pixel",
"compress": "DEFLATE",
"tiled": True,
"blockxsize": 512,
"blockysize": 512,
}

with MemoryFile() as cogfile:
copy(mem, cogfile.name, copy_src_overviews=True, **cog_profile)
return BytesIO(cogfile.read())

return _dataset

0 comments on commit bb33ffb

Please sign in to comment.