Added version checking and new tests
jsignell committed Jan 23, 2019
1 parent 79a35a1 commit ffad7d9
Showing 5 changed files with 64 additions and 25 deletions.
36 changes: 28 additions & 8 deletions intake_xarray/netcdf.py
@@ -1,5 +1,10 @@
 # -*- coding: utf-8 -*-
-import xarray as xr
+from distutils.version import LooseVersion
+try:
+    import xarray as xr
+    XARRAY_VERSION = LooseVersion(xr.__version__)
+except ImportError:
+    XARRAY_VERSION = None
 from intake.source.base import PatternMixin
 from intake.source.utils import reverse_format
 from .base import DataSourceMixin
@@ -10,38 +15,47 @@ class NetCDFSource(DataSourceMixin, PatternMixin):
     Parameters
     ----------
-    urlpath: str
+    urlpath : str
         Path to source file. May include glob "*" characters, format
         pattern strings, or a list of paths.
         Some examples:
-            - ``{{ CATALOG_DIR }}data/air.nc``
-            - ``{{ CATALOG_DIR }}data/*.nc``
-            - ``{{ CATALOG_DIR }}data/air_{year}.nc``
-    chunks: int or dict
+            - ``{{ CATALOG_DIR }}/data/air.nc``
+            - ``{{ CATALOG_DIR }}/data/*.nc``
+            - ``{{ CATALOG_DIR }}/data/air_{year}.nc``
+    chunks : int or dict, optional
         Chunks is used to load the new dataset into dask
         arrays. ``chunks={}`` loads the dataset with dask using a single
         chunk for all arrays.
-    path_as_pattern: bool or str, optional
+    concat_dim : str, optional
+        Name of the dimension along which to concatenate the files. Can
+        be new or pre-existing. Default is 'concat_dim'.
+    path_as_pattern : bool or str, optional
         Whether to treat the path as a pattern (i.e. ``data_{field}.nc``)
         and create new coordinates in the output corresponding to pattern
         fields. If str, it is treated as the pattern to match on. Default is True.
     """
     name = 'netcdf'

-    def __init__(self, urlpath, chunks, xarray_kwargs=None, metadata=None,
+    def __init__(self, urlpath, chunks=None, concat_dim='concat_dim',
+                 xarray_kwargs=None, metadata=None,
                  path_as_pattern=True, **kwargs):
         self.path_as_pattern = path_as_pattern
         self.urlpath = urlpath
         self.chunks = chunks
+        self.concat_dim = concat_dim
         self._kwargs = xarray_kwargs or kwargs
         self._ds = None
         super(NetCDFSource, self).__init__(metadata=metadata)

     def _open_dataset(self):
+        if not XARRAY_VERSION:
+            raise ImportError("xarray not available")
         url = self.urlpath
         kwargs = self._kwargs
         if "*" in url or isinstance(url, list):
             _open_dataset = xr.open_mfdataset
+            if 'concat_dim' not in kwargs:
+                kwargs.update(concat_dim=self.concat_dim)
             if self.pattern:
                 kwargs.update(preprocess=self._add_path_to_ds)
         else:
@@ -52,6 +66,12 @@ def _open_dataset(self):
     def _add_path_to_ds(self, ds):
         """Add path info to a coord for a particular file
         """
+        if not (XARRAY_VERSION > '0.11.1'):
+            raise ImportError("Your version of xarray is '{}'. "
+                "The guarantee that the source path is available on the "
+                "output of open_dataset was added in 0.11.2, so "
+                "pattern urlpaths are not supported.".format(XARRAY_VERSION))
+
         var = next(var for var in ds)
         new_coords = reverse_format(self.pattern, ds[var].encoding['source'])
         return ds.assign_coords(**new_coords)
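For orientation, a minimal usage sketch of the constructor after this change (file names here are hypothetical, and the pattern case assumes xarray >= 0.11.2, per the guard above):

    from intake_xarray.netcdf import NetCDFSource

    # Glob urlpath: files are combined with xr.open_mfdataset along the
    # default 'concat_dim' unless one is passed explicitly.
    source = NetCDFSource('data/air_*.nc', chunks={})
    ds = source.to_dask()

    # Pattern urlpath: the 'year' field is parsed out of each file name
    # and becomes a coordinate to concatenate along.
    source = NetCDFSource('data/air_{year:d}.nc', chunks={}, concat_dim='year')
    ds = source.to_dask()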
6 changes: 3 additions & 3 deletions tests/util.py → tests/conftest.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-import os
+import posixpath
 import pytest
 import shutil
 import tempfile
@@ -11,11 +11,11 @@
 TEST_DATA_DIR = 'tests/data'
 TEST_DATA = 'example_1.nc'
-TEST_URLPATH = os.path.join(TEST_DATA_DIR, TEST_DATA)
+TEST_URLPATH = posixpath.join(TEST_DATA_DIR, TEST_DATA)


 @pytest.fixture
-def cdf_source():
+def netcdf_source():
     return NetCDFSource(TEST_URLPATH, {})


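Two things happen in this rename: moving tests/util.py to tests/conftest.py lets pytest discover the fixtures automatically, which is why the explicit .util imports disappear from the test modules below; and posixpath.join keeps the test urlpath forward-slashed on every platform, which os.path.join does not guarantee. A small illustrative sketch (the Windows behavior noted in the comment is the assumed motivation):

    import posixpath

    # posixpath.join always uses '/', matching URL-style paths:
    posixpath.join('tests/data', 'example_1.nc')  # 'tests/data/example_1.nc'
    # On Windows, os.path.join('tests/data', 'example_1.nc') would instead
    # yield 'tests/data\\example_1.nc', which is not a valid urlpath.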
Binary file added tests/data/example_2.nc
1 change: 0 additions & 1 deletion tests/test_catalog.py
@@ -4,7 +4,6 @@
 import pytest

 from intake import open_catalog
-from .util import dataset  # noqa


 @pytest.fixture
46 changes: 33 additions & 13 deletions tests/test_intake_xarray.py
@@ -7,12 +7,10 @@

 here = os.path.dirname(__file__)

-from .util import TEST_URLPATH, cdf_source, zarr_source, dataset  # noqa
-

-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_discover(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_discover(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]
     r = source.discover()

     assert r['datashape'] is None
@@ -25,9 +23,9 @@ def test_discover(source, cdf_source, zarr_source, dataset):
     assert set(source.metadata['coords']) == set(dataset.coords.keys())


-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_read(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_read(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]

     ds = source.read_chunked()
     assert ds.temp.chunks
@@ -38,8 +36,8 @@ def test_read(source, cdf_source, zarr_source, dataset):
     assert np.all(ds.rh == dataset.rh)


-def test_read_partition_cdf(cdf_source):
-    source = cdf_source
+def test_read_partition_netcdf(netcdf_source):
+    source = netcdf_source
     with pytest.raises(TypeError):
         source.read_partition(None)
     out = source.read_partition(('temp', 0, 0, 0, 0))
@@ -48,6 +46,28 @@ def test_read_partition_cdf(cdf_source):
     assert np.all(out == expected)


+def test_read_list_of_netcdf_files():
+    from intake_xarray.netcdf import NetCDFSource
+    source = NetCDFSource([
+        os.path.join(here, 'data', 'example_1.nc'),
+        os.path.join(here, 'data', 'example_2.nc'),
+    ])
+    d = source.to_dask()
+    assert d.dims == {'lat': 5, 'lon': 10, 'level': 4, 'time': 1,
+                      'concat_dim': 2}
+
+
+def test_read_glob_pattern_of_netcdf_files():
+    from intake_xarray.netcdf import NetCDFSource
+
+    source = NetCDFSource(os.path.join(here, 'data', 'example_{num:d}.nc'),
+                          concat_dim='num')
+    d = source.to_dask()
+    assert d.dims == {'lat': 5, 'lon': 10, 'level': 4, 'time': 1,
+                      'num': 2}
+    assert (d.num.data == np.array([1, 2])).all()
+
+
 def test_read_partition_zarr(zarr_source):
     source = zarr_source
     with pytest.raises(TypeError):
@@ -57,9 +77,9 @@ def test_read_partition_zarr(zarr_source):
     assert np.all(out == expected)


-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_to_dask(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_to_dask(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]
     ds = source.to_dask()

     assert ds.dims == dataset.dims
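The new pattern test leans on intake's reverse_format helper (imported in netcdf.py above) to recover field values from each resolved path; a small sketch of the idea, with made-up file names:

    from intake.source.utils import reverse_format

    # Given the pattern and one concrete file name, reverse_format returns
    # the parsed fields; the ':d' spec yields an int, which is why the
    # test's 'num' coordinate comes out as [1, 2] rather than ['1', '2'].
    reverse_format('example_{num:d}.nc', 'example_1.nc')  # {'num': 1}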
