Skip to content
forked from pydata/xarray

Commit

Permalink
bugfix.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 1, 2019
1 parent 4ee2963 commit 0711eb0
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 57 deletions.
54 changes: 28 additions & 26 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,34 +197,36 @@ def process_subset_opt(opt, subset):
equals[k] = getattr(variables[0], compat)(
var, equiv=lazy_array_equiv
)
if not equals[k]:
if equals[k] is not True:
# exit early if we know these are not equal or that
# equality cannot be determined i.e. one or all of
# the variables wraps a numpy array
break

if equals[k] is not None:
if equals[k] is False:
concat_over.add(k)
continue

# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True
if equals[k] is False:
concat_over.add(k)

elif equals[k] is None:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True

elif opt == "all":
concat_over.update(
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ def as_shared_dtype(scalars_or_arrays):

def lazy_array_equiv(arr1, arr2):
"""Like array_equal, but doesn't actually compare values.
Returns True or False when equality can be determined without computing.
Returns None when equality cannot be determined (e.g. one or both of arr1, arr2 are numpy arrays)
Returns True when arr1, arr2 identical or their dask names are equal.
Returns False when shapes are not equal.
Returns None when equality cannot be determined: one or both of arr1, arr2 are numpy arrays;
or their dask names are not equal
"""
if arr1 is arr2:
return True
Expand All @@ -193,6 +195,8 @@ def lazy_array_equiv(arr1, arr2):
# GH3068
if arr1.name == arr2.name:
return True
else:
return None
return None


Expand Down
6 changes: 3 additions & 3 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ def unique_variable(
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
if not equals:
if equals is not True:
break

# now compare values with minimum number of computes
if not equals:
if equals is None:
# now compare values with minimum number of computes
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
Expand Down
72 changes: 46 additions & 26 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,32 +1202,6 @@ def test_identical_coords_no_computes():
assert_identical(c, a)


def test_lazy_array_equiv():
    # Two identically-built lazy arrays: comparisons should stay lazy.
    lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    var1, var2 = lons1.variable, lons2.variable

    with raise_if_dask_computes():
        lons1.equals(lons2)
    with raise_if_dask_computes():
        var1.equals(var2 / 2, equiv=lazy_array_equiv)
    # Once either operand is a numpy array, the lazy check is inconclusive.
    assert var1.equals(var2.compute(), equiv=lazy_array_equiv) is None
    assert var1.compute().equals(var2.compute(), equiv=lazy_array_equiv) is None

    with raise_if_dask_computes():
        assert lons1.equals(lons1.transpose("y", "x"))

    compat_modes = (
        "broadcast_equals",
        "equals",
        "override",
        "identical",
        "no_conflicts",
    )
    with raise_if_dask_computes():
        for mode in compat_modes:
            xr.merge([lons1, lons2], compat=mode)


@pytest.mark.parametrize(
"obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
)
Expand Down Expand Up @@ -1315,3 +1289,49 @@ def test_normalize_token_with_backend(map_ds):
map_ds.to_netcdf(tmp_file)
read = xr.open_dataset(tmp_file)
assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)


@pytest.mark.parametrize(
    "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_variables(compat):
    # Three lazy variables: two equal-shaped zeros and one of a different shape.
    same_a = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
    same_b = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
    other = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2))

    check = getattr(same_a, compat)

    with raise_if_dask_computes():
        assert check(same_b, equiv=lazy_array_equiv)
    # values are actually equal, but we don't know that till we compute, return None
    with raise_if_dask_computes():
        assert check(same_b / 2, equiv=lazy_array_equiv) is None

    # shapes are not equal, return False without computes
    with raise_if_dask_computes():
        assert check(other, equiv=lazy_array_equiv) is False

    # if one or both arrays are numpy, return None
    assert check(same_b.compute(), equiv=lazy_array_equiv) is None
    assert (
        getattr(same_a.compute(), compat)(same_b.compute(), equiv=lazy_array_equiv)
        is None
    )

    with raise_if_dask_computes():
        assert check(same_b.transpose("y", "x"))


@pytest.mark.parametrize(
    "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_merge(compat):
    # Two equal-shaped lazy zeros and one incompatible-shaped lazy ones.
    arr_a = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    arr_b = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
    arr_c = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x"))

    with raise_if_dask_computes():
        xr.merge([arr_a, arr_b], compat=compat)
    # shapes are not equal; no computes necessary
    with raise_if_dask_computes(max_computes=0):
        with pytest.raises(ValueError):
            xr.merge([arr_a, arr_c], compat=compat)
    with raise_if_dask_computes(max_computes=2):
        xr.merge([arr_a, arr_b / 2], compat=compat)

0 comments on commit 0711eb0

Please sign in to comment.