diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index 4a0ac7d67f9..65c5bc2a02b 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -12,7 +12,7 @@
     BackendEntrypoint,
 )
 from .locks import SerializableLock, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import cfgrib
@@ -86,62 +86,58 @@ def get_encoding(self):
         return encoding
 
 
-def guess_can_open_cfgrib(store_spec):
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-    return ext in {".grib", ".grib2", ".grb", ".grb2"}
-
-
-def open_backend_dataset_cfgrib(
-    filename_or_obj,
-    *,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    lock=None,
-    indexpath="{path}.{short_hash}.idx",
-    filter_by_keys={},
-    read_keys=[],
-    encode_cf=("parameter", "time", "geography", "vertical"),
-    squeeze=True,
-    time_dims=("time", "step"),
-):
-
-    store = CfGribDataStore(
+class CfgribBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+        return ext in {".grib", ".grib2", ".grb", ".grb2"}
+
+    def open_dataset(
+        self,
         filename_or_obj,
-        indexpath=indexpath,
-        filter_by_keys=filter_by_keys,
-        read_keys=read_keys,
-        encode_cf=encode_cf,
-        squeeze=squeeze,
-        time_dims=time_dims,
-        lock=lock,
-    )
-
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+        *,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        lock=None,
+        indexpath="{path}.{short_hash}.idx",
+        filter_by_keys={},
+        read_keys=[],
+        encode_cf=("parameter", "time", "geography", "vertical"),
+        squeeze=True,
+        time_dims=("time", "step"),
+    ):
+
+        store = CfGribDataStore(
+            filename_or_obj,
+            indexpath=indexpath,
+            filter_by_keys=filter_by_keys,
+            read_keys=read_keys,
+            encode_cf=encode_cf,
+            squeeze=squeeze,
+            time_dims=time_dims,
+            lock=lock,
         )
-    return ds
-
-
-cfgrib_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
-)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_cfgrib:
-    BACKEND_ENTRYPOINTS["cfgrib"] = cfgrib_backend
+    BACKEND_ENTRYPOINTS["cfgrib"] = CfgribBackendEntrypoint
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index adb70658fab..e2905d0866b 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -1,7 +1,7 @@
 import logging
 import time
 import traceback
-from typing import Dict
+from typing import Dict, Tuple, Type, Union
 
 import numpy as np
 
@@ -344,12 +344,14 @@ def encode(self, variables, attributes):
 
 
 class BackendEntrypoint:
-    __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters")
+    open_dataset_parameters: Union[Tuple, None] = None
 
-    def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
-        self.open_dataset = open_dataset
-        self.open_dataset_parameters = open_dataset_parameters
-        self.guess_can_open = guess_can_open
+    def open_dataset(self):
+        raise NotImplementedError
+
+    def guess_can_open(self, store_spec):
+        return False
 
 
-BACKEND_ENTRYPOINTS: Dict[str, BackendEntrypoint] = {}
+
+BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {}
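
With BackendEntrypoint now a plain class, a backend is declared by subclassing
it instead of instantiating BackendEntrypoint with callbacks. A minimal sketch
of a custom backend under the new API (the class body, the ".my" extension,
and the stand-in Dataset are hypothetical, for illustration only):

    import xarray as xr
    from xarray.backends.common import BackendEntrypoint

    class MyBackendEntrypoint(BackendEntrypoint):
        # open_dataset_parameters is left as None here, so
        # plugins.detect_parameters fills it in from the open_dataset
        # signature (minus "self").

        def guess_can_open(self, store_spec):
            # Claim only paths with a made-up ".my" extension.
            try:
                return str(store_spec).endswith(".my")
            except TypeError:
                return False

        def open_dataset(self, filename_or_obj, *, drop_variables=None):
            # Stand-in decode step; a real backend would read filename_or_obj.
            ds = xr.Dataset({"x": ("dim", [1, 2, 3])})
            if drop_variables:
                ds = ds.drop_vars(drop_variables)
            return ds
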
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index 562600de4b6..aa892c4f89c 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -23,7 +23,7 @@
     _get_datatype,
     _nc4_require_group,
 )
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import h5netcdf
@@ -328,62 +328,61 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)
 
 
-def guess_can_open_h5netcdf(store_spec):
-    try:
-        return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
-    except TypeError:
-        pass
-
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-
-    return ext in {".nc", ".nc4", ".cdf"}
-
-
-def open_backend_dataset_h5netcdf(
-    filename_or_obj,
-    *,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    format=None,
-    group=None,
-    lock=None,
-    invalid_netcdf=None,
-    phony_dims=None,
-):
-
-    store = H5NetCDFStore.open(
+class H5netcdfBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        try:
+            return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
+        except TypeError:
+            pass
+
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+
+        return ext in {".nc", ".nc4", ".cdf"}
+
+    def open_dataset(
+        self,
         filename_or_obj,
-        format=format,
-        group=group,
-        lock=lock,
-        invalid_netcdf=invalid_netcdf,
-        phony_dims=phony_dims,
-    )
+        *,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        format=None,
+        group=None,
+        lock=None,
+        invalid_netcdf=None,
+        phony_dims=None,
+    ):
 
-    ds = open_backend_dataset_store(
-        store,
-        mask_and_scale=mask_and_scale,
-        decode_times=decode_times,
-        concat_characters=concat_characters,
-        decode_coords=decode_coords,
-        drop_variables=drop_variables,
-        use_cftime=use_cftime,
-        decode_timedelta=decode_timedelta,
-    )
-    return ds
+        store = H5NetCDFStore.open(
+            filename_or_obj,
+            format=format,
+            group=group,
+            lock=lock,
+            invalid_netcdf=invalid_netcdf,
+            phony_dims=phony_dims,
+        )
+        store_entrypoint = StoreBackendEntrypoint()
+
+        ds = store_entrypoint.open_dataset(
+            store,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            concat_characters=concat_characters,
+            decode_coords=decode_coords,
+            drop_variables=drop_variables,
+            use_cftime=use_cftime,
+            decode_timedelta=decode_timedelta,
+        )
+        return ds
 
 
-h5netcdf_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
-)
 
 if has_h5netcdf:
-    BACKEND_ENTRYPOINTS["h5netcdf"] = h5netcdf_backend
+    BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 5bb4eec837b..e3d87aaf83f 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -22,7 +22,7 @@
 from .file_manager import CachingFileManager, DummyFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock
 from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import netCDF4
@@ -512,65 +512,62 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)
 
 
-def guess_can_open_netcdf4(store_spec):
-    if isinstance(store_spec, str) and is_remote_uri(store_spec):
-        return True
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-    return ext in {".nc", ".nc4", ".cdf"}
-
-
-def open_backend_dataset_netcdf4(
-    filename_or_obj,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    group=None,
-    mode="r",
-    format="NETCDF4",
-    clobber=True,
-    diskless=False,
-    persist=False,
-    lock=None,
-    autoclose=False,
-):
+class NetCDF4BackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        if isinstance(store_spec, str) and is_remote_uri(store_spec):
+            return True
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+        return ext in {".nc", ".nc4", ".cdf"}
 
-    store = NetCDF4DataStore.open(
+    def open_dataset(
+        self,
         filename_or_obj,
-        mode=mode,
-        format=format,
-        group=group,
-        clobber=clobber,
-        diskless=diskless,
-        persist=persist,
-        lock=lock,
-        autoclose=autoclose,
-    )
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        group=None,
+        mode="r",
+        format="NETCDF4",
+        clobber=True,
+        diskless=False,
+        persist=False,
+        lock=None,
+        autoclose=False,
+    ):
 
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+        store = NetCDF4DataStore.open(
+            filename_or_obj,
+            mode=mode,
+            format=format,
+            group=group,
+            clobber=clobber,
+            diskless=diskless,
+            persist=persist,
+            lock=lock,
+            autoclose=autoclose,
         )
-    return ds
-
-netcdf4_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4
-)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_netcdf4:
-    BACKEND_ENTRYPOINTS["netcdf4"] = netcdf4_backend
+    BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint
" "*args and **kwargs is not supported" ) - return tuple(parameters) + if name != "self": + parameters_list.append(name) + return tuple(parameters_list) def create_engines_dict(backend_entrypoints): @@ -57,8 +60,8 @@ def create_engines_dict(backend_entrypoints): return engines -def set_missing_parameters(engines): - for name, backend in engines.items(): +def set_missing_parameters(backend_entrypoints): + for name, backend in backend_entrypoints.items(): if backend.open_dataset_parameters is None: open_dataset = backend.open_dataset backend.open_dataset_parameters = detect_parameters(open_dataset) @@ -70,7 +73,10 @@ def build_engines(entrypoints): external_backend_entrypoints = create_engines_dict(pkg_entrypoints) backend_entrypoints.update(external_backend_entrypoints) set_missing_parameters(backend_entrypoints) - return backend_entrypoints + engines = {} + for name, backend in backend_entrypoints.items(): + engines[name] = backend() + return engines @functools.lru_cache(maxsize=1) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index c2bfd519bed..80485fce459 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -11,7 +11,7 @@ ) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock -from .store import open_backend_dataset_store +from .store import StoreBackendEntrypoint try: from PseudoNetCDF import pncopen @@ -100,57 +100,55 @@ def close(self): self._manager.close() -def open_backend_dataset_pseudonetcdf( - filename_or_obj, - mask_and_scale=False, - decode_times=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - use_cftime=None, - decode_timedelta=None, - mode=None, - lock=None, - **format_kwargs, -): - - store = PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, mode=mode, **format_kwargs +class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): + + # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, + # unless the open_dataset_parameters are explicity defined like this: + open_dataset_parameters = ( + "filename_or_obj", + "mask_and_scale", + "decode_times", + "concat_characters", + "decode_coords", + "drop_variables", + "use_cftime", + "decode_timedelta", + "mode", + "lock", ) - with close_on_error(store): - ds = open_backend_dataset_store( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, + def open_dataset( + self, + filename_or_obj, + mask_and_scale=False, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode=None, + lock=None, + **format_kwargs, + ): + store = PseudoNetCDFDataStore.open( + filename_or_obj, lock=lock, mode=mode, **format_kwargs ) - return ds - - -# *args and **kwargs are not allowed in open_backend_dataset_ kwargs, -# unless the open_dataset_parameters are explicity defined like this: -open_dataset_parameters = ( - "filename_or_obj", - "mask_and_scale", - "decode_times", - "concat_characters", - "decode_coords", - "drop_variables", - "use_cftime", - "decode_timedelta", - "mode", - "lock", -) -pseudonetcdf_backend = BackendEntrypoint( - open_dataset=open_backend_dataset_pseudonetcdf, - open_dataset_parameters=open_dataset_parameters, -) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = 
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index c2bfd519bed..80485fce459 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -11,7 +11,7 @@
 )
 from .file_manager import CachingFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     from PseudoNetCDF import pncopen
@@ -100,57 +100,55 @@ def close(self):
         self._manager.close()
 
 
-def open_backend_dataset_pseudonetcdf(
-    filename_or_obj,
-    mask_and_scale=False,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    mode=None,
-    lock=None,
-    **format_kwargs,
-):
-
-    store = PseudoNetCDFDataStore.open(
-        filename_or_obj, lock=lock, mode=mode, **format_kwargs
+class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
+
+    # *args and **kwargs are not allowed in open_backend_dataset_ kwargs,
+    # unless the open_dataset_parameters are explicitly defined like this:
+    open_dataset_parameters = (
+        "filename_or_obj",
+        "mask_and_scale",
+        "decode_times",
+        "concat_characters",
+        "decode_coords",
+        "drop_variables",
+        "use_cftime",
+        "decode_timedelta",
+        "mode",
+        "lock",
     )
 
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+    def open_dataset(
+        self,
+        filename_or_obj,
+        mask_and_scale=False,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        mode=None,
+        lock=None,
+        **format_kwargs,
+    ):
+        store = PseudoNetCDFDataStore.open(
+            filename_or_obj, lock=lock, mode=mode, **format_kwargs
         )
-    return ds
-
-
-# *args and **kwargs are not allowed in open_backend_dataset_ kwargs,
-# unless the open_dataset_parameters are explicity defined like this:
-open_dataset_parameters = (
-    "filename_or_obj",
-    "mask_and_scale",
-    "decode_times",
-    "concat_characters",
-    "decode_coords",
-    "drop_variables",
-    "use_cftime",
-    "decode_timedelta",
-    "mode",
-    "lock",
-)
-pseudonetcdf_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_pseudonetcdf,
-    open_dataset_parameters=open_dataset_parameters,
-)
+
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_pseudonetcdf:
-    BACKEND_ENTRYPOINTS["pseudonetcdf"] = pseudonetcdf_backend
+    BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint
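
PseudoNetCDF is the one backend whose open_dataset still takes **format_kwargs,
so signature inspection would hit the TypeError raised by detect_parameters;
declaring open_dataset_parameters on the class skips detection entirely. Any
backend in the same situation would follow the same pattern (a sketch; the
class name and **reader_kwargs are hypothetical):

    from xarray.backends.common import BackendEntrypoint

    class FlexibleBackendEntrypoint(BackendEntrypoint):
        # Declared explicitly because **reader_kwargs would make
        # detect_parameters raise TypeError.
        open_dataset_parameters = ("filename_or_obj", "drop_variables")

        def open_dataset(self, filename_or_obj, drop_variables=None, **reader_kwargs):
            raise NotImplementedError("illustrative stub")
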
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index c5ce943a10a..7f8622ca66e 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -11,7 +11,7 @@
     BackendEntrypoint,
     robust_getitem,
 )
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import pydap.client
@@ -107,45 +107,41 @@ def get_dimensions(self):
         return Frozen(self.ds.dimensions)
 
 
-def guess_can_open_pydap(store_spec):
-    return isinstance(store_spec, str) and is_remote_uri(store_spec)
+class PydapBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        return isinstance(store_spec, str) and is_remote_uri(store_spec)
 
-
-def open_backend_dataset_pydap(
-    filename_or_obj,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    session=None,
-):
-
-    store = PydapDataStore.open(
+    def open_dataset(
+        self,
         filename_or_obj,
-        session=session,
-    )
-
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        session=None,
+    ):
+        store = PydapDataStore.open(
+            filename_or_obj,
+            session=session,
         )
-    return ds
-
-pydap_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
-)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_pydap:
-    BACKEND_ENTRYPOINTS["pydap"] = pydap_backend
+    BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 261daa69880..41c99efd076 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -11,7 +11,7 @@
 )
 from .file_manager import CachingFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import Nio
@@ -97,41 +97,40 @@ def close(self):
         self._manager.close()
 
 
-def open_backend_dataset_pynio(
-    filename_or_obj,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    mode="r",
-    lock=None,
-):
-
-    store = NioDataStore(
+class PynioBackendEntrypoint(BackendEntrypoint):
+    def open_dataset(
+        self,
         filename_or_obj,
-        mode=mode,
-        lock=lock,
-    )
-
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        mode="r",
+        lock=None,
+    ):
+        store = NioDataStore(
+            filename_or_obj,
+            mode=mode,
+            lock=lock,
         )
-    return ds
-
-pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_pynio:
-    BACKEND_ENTRYPOINTS["pynio"] = pynio_backend
+    BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index df51d07d686..ddc157ed8e4 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -15,7 +15,7 @@
 from .file_manager import CachingFileManager, DummyFileManager
 from .locks import ensure_lock, get_write_lock
 from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import scipy.io
@@ -232,56 +232,54 @@ def close(self):
         self._manager.close()
 
 
-def guess_can_open_scipy(store_spec):
-    try:
-        return read_magic_number(store_spec).startswith(b"CDF")
-    except TypeError:
-        pass
+class ScipyBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        try:
+            return read_magic_number(store_spec).startswith(b"CDF")
+        except TypeError:
+            pass
 
-    try:
-        _, ext = os.path.splitext(store_spec)
-    except TypeError:
-        return False
-    return ext in {".nc", ".nc4", ".cdf", ".gz"}
-
-
-def open_backend_dataset_scipy(
-    filename_or_obj,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    mode="r",
-    format=None,
-    group=None,
-    mmap=None,
-    lock=None,
-):
-
-    store = ScipyDataStore(
-        filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
-    )
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
-        )
-    return ds
+        try:
+            _, ext = os.path.splitext(store_spec)
+        except TypeError:
+            return False
+        return ext in {".nc", ".nc4", ".cdf", ".gz"}
+
+    def open_dataset(
+        self,
+        filename_or_obj,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        mode="r",
+        format=None,
+        group=None,
+        mmap=None,
+        lock=None,
+    ):
+        store = ScipyDataStore(
+            filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
+        )
 
-scipy_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy
-)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_scipy:
-    BACKEND_ENTRYPOINTS["scipy"] = scipy_backend
+    BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index 66fca0d39c3..d57b3ab9df8 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -3,47 +3,43 @@
 from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint
 
 
-def guess_can_open_store(store_spec):
-    return isinstance(store_spec, AbstractDataStore)
-
-
-def open_backend_dataset_store(
-    store,
-    *,
-    mask_and_scale=True,
-    decode_times=True,
-    concat_characters=True,
-    decode_coords=True,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-):
-    vars, attrs = store.load()
-    encoding = store.get_encoding()
-
-    vars, attrs, coord_names = conventions.decode_cf_variables(
-        vars,
-        attrs,
-        mask_and_scale=mask_and_scale,
-        decode_times=decode_times,
-        concat_characters=concat_characters,
-        decode_coords=decode_coords,
-        drop_variables=drop_variables,
-        use_cftime=use_cftime,
-        decode_timedelta=decode_timedelta,
-    )
-
-    ds = Dataset(vars, attrs=attrs)
-    ds = ds.set_coords(coord_names.intersection(vars))
-    ds.set_close(store.close)
-    ds.encoding = encoding
-
-    return ds
-
-
-store_backend = BackendEntrypoint(
-    open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store
-)
-
-
-BACKEND_ENTRYPOINTS["store"] = store_backend
+class StoreBackendEntrypoint(BackendEntrypoint):
+    def guess_can_open(self, store_spec):
+        return isinstance(store_spec, AbstractDataStore)
+
+    def open_dataset(
+        self,
+        store,
+        *,
+        mask_and_scale=True,
+        decode_times=True,
+        concat_characters=True,
+        decode_coords=True,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+    ):
+        vars, attrs = store.load()
+        encoding = store.get_encoding()
+
+        vars, attrs, coord_names = conventions.decode_cf_variables(
+            vars,
+            attrs,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            concat_characters=concat_characters,
+            decode_coords=decode_coords,
+            drop_variables=drop_variables,
+            use_cftime=use_cftime,
+            decode_timedelta=decode_timedelta,
+        )
+
+        ds = Dataset(vars, attrs=attrs)
+        ds = ds.set_coords(coord_names.intersection(vars))
+        ds.set_close(store.close)
+        ds.encoding = encoding
+
+        return ds
+
+
+BACKEND_ENTRYPOINTS["store"] = StoreBackendEntrypoint
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index ceeb23cac9b..1d667a38b53 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -15,7 +15,7 @@
     BackendEntrypoint,
     _encode_variable_name,
 )
-from .store import open_backend_dataset_store
+from .store import StoreBackendEntrypoint
 
 try:
     import zarr
@@ -670,49 +670,48 @@ def open_zarr(
     return ds
 
 
-def open_backend_dataset_zarr(
-    filename_or_obj,
-    mask_and_scale=True,
-    decode_times=None,
-    concat_characters=None,
-    decode_coords=None,
-    drop_variables=None,
-    use_cftime=None,
-    decode_timedelta=None,
-    group=None,
-    mode="r",
-    synchronizer=None,
-    consolidated=False,
-    consolidate_on_close=False,
-    chunk_store=None,
-):
-
-    store = ZarrStore.open_group(
+class ZarrBackendEntrypoint(BackendEntrypoint):
+    def open_dataset(
+        self,
         filename_or_obj,
-        group=group,
-        mode=mode,
-        synchronizer=synchronizer,
-        consolidated=consolidated,
-        consolidate_on_close=consolidate_on_close,
-        chunk_store=chunk_store,
-    )
-
-    with close_on_error(store):
-        ds = open_backend_dataset_store(
-            store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+        mask_and_scale=True,
+        decode_times=None,
+        concat_characters=None,
+        decode_coords=None,
+        drop_variables=None,
+        use_cftime=None,
+        decode_timedelta=None,
+        group=None,
+        mode="r",
+        synchronizer=None,
+        consolidated=False,
+        consolidate_on_close=False,
+        chunk_store=None,
+    ):
+        store = ZarrStore.open_group(
+            filename_or_obj,
+            group=group,
+            mode=mode,
+            synchronizer=synchronizer,
+            consolidated=consolidated,
+            consolidate_on_close=consolidate_on_close,
+            chunk_store=chunk_store,
        )
-    return ds
-
-zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr)
+        store_entrypoint = StoreBackendEntrypoint()
+        with close_on_error(store):
+            ds = store_entrypoint.open_dataset(
+                store,
+                mask_and_scale=mask_and_scale,
+                decode_times=decode_times,
+                concat_characters=concat_characters,
+                decode_coords=decode_coords,
+                drop_variables=drop_variables,
+                use_cftime=use_cftime,
+                decode_timedelta=decode_timedelta,
+            )
+        return ds
 
 
 if has_zarr:
-    BACKEND_ENTRYPOINTS["zarr"] = zarr_backend
+    BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint
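
Every backend above now ends the same way: build a backend-specific data
store, then delegate CF decoding to StoreBackendEntrypoint.open_dataset. The
store entrypoint is also registered on its own under "store" and works
directly on any AbstractDataStore; for example (a sketch assuming xarray's
in-memory test store, which is not part of this patch):

    from xarray.backends.memory import InMemoryDataStore
    from xarray.backends.store import StoreBackendEntrypoint

    # Decode an (empty) in-memory store exactly the way the file-based
    # backends decode theirs.
    store = InMemoryDataStore()
    ds = StoreBackendEntrypoint().open_dataset(store)
    print(ds)
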
"decoder") + plugins.set_missing_parameters({"engine": backend}) + assert backend.open_dataset_parameters == ("filename_or_obj", "decoder") def test_set_missing_parameters_raise_error(): - backend = common.BackendEntrypoint(dummy_open_dataset_args) + backend = DummyBackendEntrypointKwargs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = common.BackendEntrypoint( - dummy_open_dataset_args, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - - backend = common.BackendEntrypoint(dummy_open_dataset_kwargs) + backend = DummyBackendEntrypointArgs() with pytest.raises(TypeError): plugins.set_missing_parameters({"engine": backend}) - backend = common.BackendEntrypoint( - dummy_open_dataset_kwargs, ("filename_or_obj", "decoder") - ) - plugins.set_missing_parameters({"engine": backend}) - -@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=dummy_cfgrib)) +@mock.patch( + "pkg_resources.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) def test_build_engines(): - dummy_cfgrib_pkg_entrypoint = pkg_resources.EntryPoint.parse( + dummy_pkg_entrypoint = pkg_resources.EntryPoint.parse( "cfgrib = xarray.tests.test_plugins:backend_1" ) - backend_entrypoints = plugins.build_engines([dummy_cfgrib_pkg_entrypoint]) - assert backend_entrypoints["cfgrib"] is dummy_cfgrib + backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1) assert backend_entrypoints["cfgrib"].open_dataset_parameters == ( "filename_or_obj", "decoder",