Skip to content

Commit

Permalink
refactor benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Oct 4, 2024
1 parent a2302cb commit 6e84019
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
32 changes: 16 additions & 16 deletions tests/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@
def read_tile(dst, tile):
"""Benchmark rio-tiler.utils._tile_read."""
# We make sure to not store things in cache.
with rasterio.Env(GDAL_CACHEMAX=0, NUM_THREADS="all"):
with Reader(None, dataset=dst) as src:
return src.tile(*tile)
with Reader(None, dataset=dst) as src:
return src.tile(*tile)


data_types = list(dtype_ranges.keys())
Expand All @@ -93,16 +92,17 @@ def test_tile(
tile = benchmark_tiles[dataset_name][tile_name]

dst_info = datasets[dataset_name]
with MemoryFile(
dataset_fixture(
crs=dst_info["crs"],
bounds=list(dst_info["bounds"]),
dtype=data_type,
nodata_type=nodata_type,
width=256,
height=256,
)
) as memfile:
with memfile.open() as dst:
img = benchmark(read_tile, dst, tile)
assert img.data.dtype == data_type
with rasterio.Env(GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR", NUM_THREADS="all"):
with MemoryFile(
dataset_fixture(
crs=dst_info["crs"],
bounds=list(dst_info["bounds"]),
dtype=data_type,
nodata_type=nodata_type,
width=4000,
height=4000,
)
) as memfile:
with memfile.open() as dst:
img = benchmark(read_tile, dst, tile)
assert img.data.dtype == data_type
33 changes: 27 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import rasterio
from rasterio.crs import CRS
from rasterio.enums import ColorInterp
from rasterio.enums import Resampling as ResamplingEnums
from rasterio.io import MemoryFile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.shutil import copy
from rasterio.transform import from_bounds

with rasterio.Env() as env:
Expand All @@ -33,16 +36,18 @@ def _dataset(
width: int = 256,
height: int = 256,
):
max_value = 127 if dtype == "int8" else 255

# Data
arr = numpy.zeros((nband, height, width), dtype=dtype) + 1
arr[:, range(height), range(width)] = max_value
arr[:, range(height - 1, 0, -1), range(width - 1)] = max_value
arr[:, :, width // 2] = max_value
arr[:, height // 2, :] = max_value
arr[:, 0:128, 0:128] = 0

# Mask/Alpha
if dtype == "int8":
mask = numpy.zeros((1, height, width), dtype=dtype) + 127
else:
mask = numpy.zeros((1, height, width), dtype=dtype) + 255

mask = numpy.zeros((1, height, width), dtype=dtype) + max_value
mask[:, 0:128, 0:128] = 0

# Input Profile
Expand Down Expand Up @@ -89,6 +94,22 @@ def _dataset(
if nodata_type == "mask":
mem.write_mask(mask[0])

return BytesIO(memfile.read())
overview_level = get_maximum_overview_level(
mem.width, mem.height, minsize=512
)
overviews = [2**j for j in range(1, overview_level + 1)]
mem.build_overviews(overviews, ResamplingEnums.bilinear)

cog_profile = {
"interleave": "pixel",
"compress": "DEFLATE",
"tiled": True,
"blockxsize": 512,
"blockysize": 512,
}

with MemoryFile() as cogfile:
copy(mem, cogfile.name, copy_src_overviews=True, **cog_profile)
return BytesIO(cogfile.read())

return _dataset

0 comments on commit 6e84019

Please sign in to comment.