Skip to content

Commit

Permalink
TST: Add more tests for rasterio engine
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Apr 13, 2021
1 parent 8fafded commit 362d19e
Show file tree
Hide file tree
Showing 6 changed files with 586 additions and 492 deletions.
4 changes: 2 additions & 2 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def _generate_spatial_coords(affine, width, height):
}
else:
return {
"xc": (("x", "y"), new_spatial_coords["x"]),
"yc": (("x", "y"), new_spatial_coords["y"]),
"xc": (("y", "x"), new_spatial_coords["x"]),
"yc": (("y", "x"), new_spatial_coords["y"]),
}


Expand Down
12 changes: 10 additions & 2 deletions rioxarray/xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from . import _io
from .exceptions import RioXarrayError

CAN_OPEN_EXTS = {
"asc",
Expand All @@ -28,11 +29,13 @@ class RasterioBackend(xr.backends.common.BackendEntrypoint):
def open_dataset(
self,
filename_or_obj,
drop_variables=None,
mask_and_scale=True,
drop_variables=None, # SKIP FROM XARRAY
parse_coordinates=None,
chunks=None,
cache=None,
lock=None,
masked=False,
mask_and_scale=True,
variable=None,
group=None,
default_name="band_data",
Expand All @@ -53,6 +56,11 @@ def open_dataset(
)
if isinstance(ds, xr.DataArray):
ds = ds.to_dataset()
if not isinstance(ds, xr.Dataset):
raise RioXarrayError(
"Multiple resolution sets found. "
"Use 'variable' or 'group' to filter."
)
return ds

def guess_can_open(self, filename_or_obj):
Expand Down
23 changes: 23 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os
from distutils.version import LooseVersion

import pyproj
import pytest
import rasterio
from numpy.testing import assert_almost_equal, assert_array_equal

import rioxarray
from rioxarray.raster_array import UNWANTED_RIO_ATTRS

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data")
TEST_INPUT_DATA_DIR = os.path.join(TEST_DATA_DIR, "input")
TEST_COMPARE_DATA_DIR = os.path.join(TEST_DATA_DIR, "compare")
PYPROJ_LT_3 = LooseVersion(pyproj.__version__) < LooseVersion("3")
RASTERIO_LT_122 = LooseVersion(rasterio.__version__) < LooseVersion("1.2.2")


# xarray.testing.assert_equal(input_xarray, compare_xarray)
Expand Down Expand Up @@ -84,3 +91,19 @@ def _assert_xarrays_equal(
assert input_xarray.rio.grid_mapping == compare_xarray.rio.grid_mapping
for unwanted_attr in UNWANTED_RIO_ATTRS:
assert unwanted_attr not in input_xarray.attrs


def open_rasterio_engine(file_name_or_object, **kwargs):
# FIXME: change to the next xarray version after release
xr = pytest.importorskip("xarray", minversion="0.17.1.dev0")
return xr.open_dataset(file_name_or_object, engine="rasterio", **kwargs)


@pytest.fixture(
params=[
rioxarray.open_rasterio,
open_rasterio_engine,
]
)
def open_rasterio(request):
return request.param
Loading

0 comments on commit 362d19e

Please sign in to comment.