From b2b9f83142527f90cfde3a72b594c91989e33cb9 Mon Sep 17 00:00:00 2001 From: Vincent Sarago Date: Mon, 9 Sep 2024 12:59:32 +0200 Subject: [PATCH] enable dynamic definition of Reader for MultiBaseReader (#711) * enable dynamic definition of Reader for MultiBaseReader * add cast_to_sequence tool * update types * add tests --- CHANGES.md | 5 +- rio_tiler/io/base.py | 273 +++++++++++++++++++------------- rio_tiler/io/stac.py | 6 +- rio_tiler/reader.py | 7 +- rio_tiler/types.py | 1 + rio_tiler/utils.py | 8 + tests/fixtures/stac_raster.json | 7 + tests/test_io_stac.py | 36 ++++- 8 files changed, 227 insertions(+), 116 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 90b08ae5..f9ac7178 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,4 +1,7 @@ -# unreleased + +# Unreleased + +* Enable dynamic definition of Asset **reader** in `MultiBaseReader` (https://github.com/cogeotiff/rio-tiler/pull/711/) # 6.7.0 (2024-09-05) diff --git a/rio_tiler/io/base.py b/rio_tiler/io/base.py index 94839c3d..4778657f 100644 --- a/rio_tiler/io/base.py +++ b/rio_tiler/io/base.py @@ -25,7 +25,7 @@ from rio_tiler.models import BandStatistics, ImageData, Info, PointData from rio_tiler.tasks import multi_arrays, multi_points, multi_values from rio_tiler.types import AssetInfo, BBox, Indexes -from rio_tiler.utils import normalize_bounds +from rio_tiler.utils import cast_to_sequence, normalize_bounds @attr.s @@ -283,6 +283,10 @@ def _get_asset_info(self, asset: str) -> AssetInfo: """Validate asset name and construct url.""" ... + def _get_reader(self, asset_info: AssetInfo) -> Type[BaseReader]: + """Get Asset Reader.""" + return self.reader + def parse_expression(self, expression: str, asset_as_band: bool = False) -> Tuple: """Parse rio-tiler band math expression.""" input_assets = "|".join(self.assets) @@ -309,8 +313,7 @@ def _update_statistics( statistics: Optional[Sequence[Tuple[float, float]]] = None, ): """Update ImageData Statistics from AssetInfo.""" - if isinstance(indexes, int): - indexes = (indexes,) + indexes = cast_to_sequence(indexes) if indexes is None: indexes = tuple(range(1, img.count + 1)) @@ -322,7 +325,9 @@ def _update_statistics( img.dataset_statistics = [statistics[bidx - 1] for bidx in indexes] def info( - self, assets: Union[Sequence[str], str] = None, **kwargs: Any + self, + assets: Optional[Union[Sequence[str], str]] = None, + **kwargs: Any, ) -> Dict[str, Info]: """Return metadata from multiple assets. @@ -338,26 +343,27 @@ def info( "No `assets` option passed, will fetch info for all available assets.", UserWarning, ) - - assets = assets or self.assets - - if isinstance(assets, str): - assets = (assets,) + assets = cast_to_sequence(assets or self.assets) def _reader(asset: str, **kwargs: Any) -> Dict: asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: return src.info() return multi_values(assets, _reader, **kwargs) def statistics( self, - assets: Union[Sequence[str], str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset - asset_expression: Optional[Dict[str, str]] = None, # Expression for each asset + assets: Optional[Union[Sequence[str], str]] = None, + asset_indexes: Optional[Dict[str, Indexes]] = None, + asset_expression: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> Dict[str, Dict[str, BandStatistics]]: """Return array statistics for multiple assets. @@ -378,23 +384,24 @@ def statistics( UserWarning, ) - assets = assets or self.assets - - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets or self.assets) asset_indexes = asset_indexes or {} asset_expression = asset_expression or {} - def _reader(asset: str, *args, **kwargs) -> Dict: + def _reader(asset: str, *args: Any, **kwargs: Any) -> Dict: asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: return src.statistics( *args, - indexes=asset_indexes.get(asset, kwargs.pop("indexes", None)), # type: ignore - expression=asset_expression.get(asset), # type: ignore + indexes=asset_indexes.get(asset, kwargs.pop("indexes", None)), + expression=asset_expression.get(asset), **kwargs, ) @@ -402,9 +409,9 @@ def _reader(asset: str, *args, **kwargs) -> Dict: def merged_statistics( self, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, categorical: bool = False, categories: Optional[List[float]] = None, percentiles: Optional[List[int]] = None, @@ -436,7 +443,7 @@ def merged_statistics( "No `assets` option passed, will fetch statistics for all available assets.", UserWarning, ) - assets = assets or self.assets + assets = cast_to_sequence(assets or self.assets) data = self.preview( assets=assets, @@ -457,9 +464,9 @@ def tile( tile_x: int, tile_y: int, tile_z: int, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, asset_as_band: bool = False, **kwargs: Any, ) -> ImageData: @@ -483,9 +490,7 @@ def tile( f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds" ) - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets) if assets and expression: warnings.warn( "Both expression and assets passed; expression will overwrite assets parameter.", @@ -506,12 +511,17 @@ def tile( indexes = kwargs.pop("indexes", None) def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: - idx = asset_indexes.get(asset) or indexes # type: ignore + idx = asset_indexes.get(asset) or indexes asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: data = src.tile(*args, indexes=idx, **kwargs) self._update_statistics( @@ -545,9 +555,9 @@ def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: def part( self, bbox: BBox, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, asset_as_band: bool = False, **kwargs: Any, ) -> ImageData: @@ -564,9 +574,7 @@ def part( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets) if assets and expression: warnings.warn( "Both expression and assets passed; expression will overwrite assets parameter.", @@ -587,12 +595,17 @@ def part( indexes = kwargs.pop("indexes", None) def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: - idx = asset_indexes.get(asset) or indexes # type: ignore + idx = asset_indexes.get(asset) or indexes asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: data = src.part(*args, indexes=idx, **kwargs) self._update_statistics( @@ -625,9 +638,9 @@ def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: def preview( self, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, asset_as_band: bool = False, **kwargs: Any, ) -> ImageData: @@ -643,9 +656,7 @@ def preview( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets) if assets and expression: warnings.warn( "Both expression and assets passed; expression will overwrite assets parameter.", @@ -666,12 +677,17 @@ def preview( indexes = kwargs.pop("indexes", None) def _reader(asset: str, **kwargs: Any) -> ImageData: - idx = asset_indexes.get(asset) or indexes # type: ignore + idx = asset_indexes.get(asset) or indexes asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: data = src.preview(indexes=idx, **kwargs) self._update_statistics( @@ -706,9 +722,9 @@ def point( self, lon: float, lat: float, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, asset_as_band: bool = False, **kwargs: Any, ) -> PointData: @@ -726,9 +742,7 @@ def point( PointData """ - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets) if assets and expression: warnings.warn( "Both expression and assets passed; expression will overwrite assets parameter.", @@ -748,13 +762,18 @@ def point( # We fall back to `indexes` if provided indexes = kwargs.pop("indexes", None) - def _reader(asset: str, *args, **kwargs: Any) -> PointData: - idx = asset_indexes.get(asset) or indexes # type: ignore + def _reader(asset: str, *args: Any, **kwargs: Any) -> PointData: + idx = asset_indexes.get(asset) or indexes asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: data = src.point(*args, indexes=idx, **kwargs) metadata = data.metadata or {} @@ -782,9 +801,9 @@ def _reader(asset: str, *args, **kwargs: Any) -> PointData: def feature( self, shape: Dict, - assets: Union[Sequence[str], str] = None, + assets: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, - asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_indexes: Optional[Dict[str, Indexes]] = None, asset_as_band: bool = False, **kwargs: Any, ) -> ImageData: @@ -801,9 +820,7 @@ def feature( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(assets, str): - assets = (assets,) - + assets = cast_to_sequence(assets) if assets and expression: warnings.warn( "Both expression and assets passed; expression will overwrite assets parameter.", @@ -824,12 +841,17 @@ def feature( indexes = kwargs.pop("indexes", None) def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: - idx = asset_indexes.get(asset) or indexes # type: ignore + idx = asset_indexes.get(asset) or indexes asset_info = self._get_asset_info(asset) - url = asset_info["url"] + reader = self._get_reader(asset_info) + with self.ctx(**asset_info.get("env", {})): - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with reader( + asset_info["url"], + tms=self.tms, + **self.reader_options, + ) as src: data = src.feature(*args, indexes=idx, **kwargs) self._update_statistics( @@ -913,7 +935,11 @@ def parse_expression(self, expression: str) -> Tuple: return bands - def info(self, bands: Union[Sequence[str], str] = None, *args, **kwargs: Any) -> Info: + def info( + self, + bands: Optional[Union[Sequence[str], str]] = None, + **kwargs: Any, + ) -> Info: """Return metadata from multiple bands. Args: @@ -929,17 +955,18 @@ def info(self, bands: Union[Sequence[str], str] = None, *args, **kwargs: Any) -> UserWarning, ) - bands = bands or self.bands - - if isinstance(bands, str): - bands = (bands,) + bands = cast_to_sequence(bands or self.bands) def _reader(band: str, **kwargs: Any) -> Info: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: return src.info() - bands_metadata = multi_values(bands, _reader, *args, **kwargs) + bands_metadata = multi_values(bands, _reader, **kwargs) meta = { "bounds": self.geographic_bounds, @@ -965,7 +992,7 @@ def _reader(band: str, **kwargs: Any) -> Info: def statistics( self, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, categorical: bool = False, categories: Optional[List[float]] = None, @@ -996,7 +1023,7 @@ def statistics( "No `bands` option passed, will fetch statistics for all available bands.", UserWarning, ) - bands = bands or self.bands + bands = cast_to_sequence(bands or self.bands) data = self.preview( bands=bands, @@ -1016,7 +1043,7 @@ def tile( tile_x: int, tile_y: int, tile_z: int, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, **kwargs: Any, ) -> ImageData: @@ -1039,9 +1066,7 @@ def tile( f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds" ) - if isinstance(bands, str): - bands = (bands,) - + bands = cast_to_sequence(bands) if bands and expression: warnings.warn( "Both expression and bands passed; expression will overwrite bands parameter.", @@ -1058,11 +1083,19 @@ def tile( def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: data = src.tile(*args, **kwargs) + if data.metadata: data.metadata = {band: data.metadata} - data.band_names = [band] # use `band` as name instead of band index + + # use `band` as name instead of band index + data.band_names = [band] + return data img = multi_arrays(bands, _reader, tile_x, tile_y, tile_z, **kwargs) @@ -1075,7 +1108,7 @@ def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData: def part( self, bbox: BBox, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, **kwargs: Any, ) -> ImageData: @@ -1091,9 +1124,7 @@ def part( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(bands, str): - bands = (bands,) - + bands = cast_to_sequence(bands) if bands and expression: warnings.warn( "Both expression and bands passed; expression will overwrite bands parameter.", @@ -1110,11 +1141,19 @@ def part( def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: data = src.part(*args, **kwargs) + if data.metadata: data.metadata = {band: data.metadata} - data.band_names = [band] # use `band` as name instead of band index + + # use `band` as name instead of band index + data.band_names = [band] + return data img = multi_arrays(bands, _reader, bbox, **kwargs) @@ -1126,7 +1165,7 @@ def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData: def preview( self, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, **kwargs: Any, ) -> ImageData: @@ -1141,9 +1180,7 @@ def preview( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(bands, str): - bands = (bands,) - + bands = cast_to_sequence(bands) if bands and expression: warnings.warn( "Both expression and bands passed; expression will overwrite bands parameter.", @@ -1160,11 +1197,19 @@ def preview( def _reader(band: str, **kwargs: Any) -> ImageData: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: data = src.preview(**kwargs) + if data.metadata: data.metadata = {band: data.metadata} - data.band_names = [band] # use `band` as name instead of band index + + # use `band` as name instead of band index + data.band_names = [band] + return data img = multi_arrays(bands, _reader, **kwargs) @@ -1178,7 +1223,7 @@ def point( self, lon: float, lat: float, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, **kwargs: Any, ) -> PointData: @@ -1195,9 +1240,7 @@ def point( PointData """ - if isinstance(bands, str): - bands = (bands,) - + bands = cast_to_sequence(bands) if bands and expression: warnings.warn( "Both expression and bands passed; expression will overwrite bands parameter.", @@ -1212,13 +1255,21 @@ def point( "bands must be passed either via `expression` or `bands` options." ) - def _reader(band: str, *args, **kwargs: Any) -> PointData: + def _reader(band: str, *args: Any, **kwargs: Any) -> PointData: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: data = src.point(*args, **kwargs) + if data.metadata: data.metadata = {band: data.metadata} - data.band_names = [band] # use `band` as name instead of band index + + # use `band` as name instead of band index + data.band_names = [band] + return data data = multi_points(bands, _reader, lon, lat, **kwargs) @@ -1230,7 +1281,7 @@ def _reader(band: str, *args, **kwargs: Any) -> PointData: def feature( self, shape: Dict, - bands: Union[Sequence[str], str] = None, + bands: Optional[Union[Sequence[str], str]] = None, expression: Optional[str] = None, **kwargs: Any, ) -> ImageData: @@ -1246,9 +1297,7 @@ def feature( rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. """ - if isinstance(bands, str): - bands = (bands,) - + bands = cast_to_sequence(bands) if bands and expression: warnings.warn( "Both expression and bands passed; expression will overwrite bands parameter.", @@ -1265,11 +1314,19 @@ def feature( def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData: url = self._get_band_url(band) - with self.reader(url, tms=self.tms, **self.reader_options) as src: # type: ignore + with self.reader( + url, + tms=self.tms, + **self.reader_options, + ) as src: data = src.feature(*args, **kwargs) + if data.metadata: data.metadata = {band: data.metadata} - data.band_names = [band] # use `band` as name instead of band index + + # use `band` as name instead of band index + data.band_names = [band] + return data img = multi_arrays(bands, _reader, shape, **kwargs) diff --git a/rio_tiler/io/stac.py b/rio_tiler/io/stac.py index e2723a37..31c5989d 100644 --- a/rio_tiler/io/stac.py +++ b/rio_tiler/io/stac.py @@ -276,13 +276,13 @@ def _maxzoom(self): return self.tms.maxzoom def _get_asset_info(self, asset: str) -> AssetInfo: - """Validate asset names and return asset's url. + """Validate asset names and return asset's info. Args: asset (str): STAC asset name. Returns: - str: STAC asset href. + AssetInfo: STAC asset info. """ if asset not in self.assets: @@ -297,6 +297,8 @@ def _get_asset_info(self, asset: str) -> AssetInfo: url=asset_info.get_absolute_href() or asset_info.href, metadata=extras, ) + if asset_info.media_type: + info["media_type"] = asset_info.media_type if head := extras.get("file:header_size"): info["env"] = {"GDAL_INGESTED_BYTES_AT_OPEN": head} diff --git a/rio_tiler/reader.py b/rio_tiler/reader.py index 50e4ad46..5b8763c8 100644 --- a/rio_tiler/reader.py +++ b/rio_tiler/reader.py @@ -23,6 +23,7 @@ from rio_tiler.utils import _requested_tile_aligned_with_internal_tile as is_aligned from rio_tiler.utils import ( _round_window, + cast_to_sequence, get_vrt_transform, has_alpha_band, non_alpha_indexes, @@ -120,8 +121,7 @@ def read( ImageData """ - if isinstance(indexes, int): - indexes = (indexes,) + indexes = cast_to_sequence(indexes) if max_size and width and height: warnings.warn( @@ -529,8 +529,7 @@ def point( PointData """ - if isinstance(indexes, int): - indexes = (indexes,) + indexes = cast_to_sequence(indexes) with contextlib.ExitStack() as ctx: # Use WarpedVRT when User provided Nodata or VRT Option (cutline) diff --git a/rio_tiler/types.py b/rio_tiler/types.py index 267e4ce5..53244b82 100644 --- a/rio_tiler/types.py +++ b/rio_tiler/types.py @@ -59,6 +59,7 @@ class AssetInfo(TypedDict, total=False): """Asset Reader Options.""" url: str + media_type: str env: Optional[Dict] metadata: Optional[Dict] dataset_statistics: Optional[Sequence[Tuple[float, float]]] diff --git a/rio_tiler/utils.py b/rio_tiler/utils.py index f5bc15a3..3bd5c556 100644 --- a/rio_tiler/utils.py +++ b/rio_tiler/utils.py @@ -783,3 +783,11 @@ def _validate_shape_input(shape: Dict) -> Dict: raise RioTilerError("Invalid geometry") return shape + + +def cast_to_sequence(val: Optional[Any] = None) -> Sequence: + """Cast input to sequence if not Tuple of List.""" + if val is not None and not isinstance(val, (list, tuple)): + val = (val,) + + return val diff --git a/tests/fixtures/stac_raster.json b/tests/fixtures/stac_raster.json index 66e776aa..1177b979 100644 --- a/tests/fixtures/stac_raster.json +++ b/tests/fixtures/stac_raster.json @@ -144,6 +144,13 @@ } } ] + }, + "netcdf": { + "href": "http://somewhere-over-the-rainbow.io/some_netcdf.nc", + "type": "application/x-netcdf", + "roles": [ + "data" + ] } }, "bbox": [ diff --git a/tests/test_io_stac.py b/tests/test_io_stac.py index e4f1b2cf..b2c603c8 100644 --- a/tests/test_io_stac.py +++ b/tests/test_io_stac.py @@ -2,6 +2,7 @@ import json import os +from typing import Set, Type from unittest.mock import patch import attr @@ -18,8 +19,9 @@ MissingAssets, TileOutsideBounds, ) -from rio_tiler.io import Reader, STACReader +from rio_tiler.io import BaseReader, Reader, STACReader, XarrayReader from rio_tiler.models import BandStatistics +from rio_tiler.types import AssetInfo PREFIX = os.path.join(os.path.dirname(__file__), "fixtures") STAC_PATH = os.path.join(PREFIX, "stac.json") @@ -890,3 +892,35 @@ def test_expression_with_wrong_stac_stats(rio): expression="where((wrongstat>0.5),1,0)", asset_as_band=True, ) + + +def test_get_reader(): + """Should use the correct reader depending on the media type.""" + valid_types = { + "image/tiff; application=geotiff", + "application/x-netcdf", + } + + @attr.s + class CustomSTACReader(STACReader): + include_asset_types: Set[str] = attr.ib(default=valid_types) + + def _get_reader(self, asset_info: AssetInfo) -> Type[BaseReader]: + """Get Asset Reader.""" + asset_type = asset_info.get("media_type", None) + if asset_type and asset_type in [ + "application/x-netcdf", + ]: + return XarrayReader + + return Reader + + with CustomSTACReader(STAC_RASTER_PATH) as stac: + assert stac.assets == ["red", "green", "blue", "netcdf"] + info = stac._get_asset_info("netcdf") + assert info["media_type"] == "application/x-netcdf" + assert stac._get_reader(info) == XarrayReader + + info = stac._get_asset_info("red") + assert info["media_type"] == "image/tiff; application=geotiff" + assert stac._get_reader(info) == Reader