Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fill_value in shift #2470

Merged
merged 17 commits into from
Dec 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ Enhancements
- 0d slices of ndarrays are now obtained directly through indexing, rather than
extracting and wrapping a scalar, avoiding unnecessary copying. By `Daniel
Wennberg <https://github.com/danielwe>`_.
- Added support for ``fill_value`` with
:py:meth:`~xarray.DataArray.shift` and :py:meth:`~xarray.Dataset.shift`
By `Maximilian Roos <https://github.com/max-sixty>`_

Bug fixes
~~~~~~~~~
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np
import pandas as pd

from . import computation, groupby, indexing, ops, resample, rolling, utils
from . import (
computation, dtypes, groupby, indexing, ops, resample, rolling, utils)
from ..plot.plot import _PlotMethods
from .accessors import DatetimeAccessor
from .alignment import align, reindex_like_indexers
Expand Down Expand Up @@ -2085,7 +2086,7 @@ def diff(self, dim, n=1, label='upper'):
ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label)
return self._from_temp_dataset(ds)

def shift(self, shifts=None, **shifts_kwargs):
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
"""Shift this array by an offset along one or more dimensions.

Only the data is moved; coordinates stay in place. Values shifted from
Expand All @@ -2098,6 +2099,8 @@ def shift(self, shifts=None, **shifts_kwargs):
Integer offset to shift along each of the given dimensions.
Positive offsets shift to the right; negative offsets shift to the
left.
fill_value: scalar, optional
Value to use for newly missing values
**shifts_kwargs:
The keyword arguments form of ``shifts``.
One of shifts or shifts_kwarg must be provided.
Expand All @@ -2122,8 +2125,9 @@ def shift(self, shifts=None, **shifts_kwargs):
Coordinates:
* x (x) int64 0 1 2
"""
ds = self._to_temp_dataset().shift(shifts=shifts, **shifts_kwargs)
return self._from_temp_dataset(ds)
variable = self.variable.shift(
shifts=shifts, fill_value=fill_value, **shifts_kwargs)
return self._replace(variable=variable)

def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
"""Roll this array by an offset along one or more dimensions.
Expand Down
15 changes: 9 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import xarray as xr

from . import (
alignment, duck_array_ops, formatting, groupby, indexing, ops, pdcompat,
resample, rolling, utils)
alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops,
pdcompat, resample, rolling, utils)
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .alignment import align
from .common import (
Expand Down Expand Up @@ -3476,7 +3476,7 @@ def diff(self, dim, n=1, label='upper'):
else:
return difference

def shift(self, shifts=None, **shifts_kwargs):
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
"""Shift this dataset by an offset along one or more dimensions.

Only data variables are moved; coordinates stay in place. This is
Expand All @@ -3488,6 +3488,8 @@ def shift(self, shifts=None, **shifts_kwargs):
Integer offset to shift along each of the given dimensions.
Positive offsets shift to the right; negative offsets shift to the
left.
fill_value: scalar, optional
Value to use for newly missing values
**shifts_kwargs:
The keyword arguments form of ``shifts``.
One of shifts or shifts_kwarg must be provided.
Expand Down Expand Up @@ -3522,9 +3524,10 @@ def shift(self, shifts=None, **shifts_kwargs):
variables = OrderedDict()
for name, var in iteritems(self.variables):
if name in self.data_vars:
var_shifts = dict((k, v) for k, v in shifts.items()
if k in var.dims)
variables[name] = var.shift(**var_shifts)
var_shifts = {k: v for k, v in shifts.items()
if k in var.dims}
variables[name] = var.shift(
fill_value=fill_value, shifts=var_shifts)
else:
variables[name] = var

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def wrapped_func(self, **kwargs):
else:
shift = (-self.window // 2) + 1
valid = (slice(None), ) * axis + (slice(-shift, None), )
padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)})
padded = padded.pad_with_fill_value({self.dim: (0, -shift)})

if isinstance(padded.data, dask_array_type):
values = dask_rolling_wrapper(func, padded,
Expand Down
20 changes: 13 additions & 7 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def squeeze(self, dim=None):
dims = common.get_squeeze_dims(self, dim)
return self.isel({d: 0 for d in dims})

def _shift_one_dim(self, dim, count):
def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
axis = self.get_axis_num(dim)

if count > 0:
Expand All @@ -944,7 +944,11 @@ def _shift_one_dim(self, dim, count):
keep = slice(None)

trimmed_data = self[(slice(None),) * axis + (keep,)].data
dtype, fill_value = dtypes.maybe_promote(self.dtype)

if fill_value is dtypes.NA:
dtype, fill_value = dtypes.maybe_promote(self.dtype)
else:
dtype = self.dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if filler is not compatible with self.dtype?
For example, feeding np.nan to an int array.
Probably it is a part of user responsibility and we do not need to take care of this, but I am just curious of it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, NumPy should raise an error... But it may not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this is the issue I'm looking at ref #2470 (comment)), good foresight @fujiisoup !


shape = list(self.shape)
shape[axis] = min(abs(count), shape[axis])
Expand All @@ -956,12 +960,12 @@ def _shift_one_dim(self, dim, count):
else:
full = np.full

nans = full(shape, fill_value, dtype=dtype)
filler = full(shape, fill_value, dtype=dtype)

if count > 0:
arrays = [nans, trimmed_data]
arrays = [filler, trimmed_data]
else:
arrays = [trimmed_data, nans]
arrays = [trimmed_data, filler]

data = duck_array_ops.concatenate(arrays, axis)

Expand All @@ -973,7 +977,7 @@ def _shift_one_dim(self, dim, count):

return type(self)(self.dims, data, self._attrs, fastpath=True)

def shift(self, shifts=None, **shifts_kwargs):
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
"""
Return a new Variable with shifted data.

Expand All @@ -983,6 +987,8 @@ def shift(self, shifts=None, **shifts_kwargs):
Integer offset to shift along each of the given dimensions.
Positive offsets shift to the right; negative offsets shift to the
left.
fill_value: scalar, optional
Value to use for newly missing values
**shifts_kwargs:
The keyword arguments form of ``shifts``.
One of shifts or shifts_kwarg must be provided.
Expand All @@ -995,7 +1001,7 @@ def shift(self, shifts=None, **shifts_kwargs):
shifts = either_dict_or_kwargs(shifts, shifts_kwargs, 'shift')
result = self
for dim, count in shifts.items():
result = result._shift_one_dim(dim, count)
result = result._shift_one_dim(dim, count, fill_value=fill_value)
return result

def pad_with_fill_value(self, pad_widths=None, fill_value=dtypes.NA,
Expand Down
18 changes: 13 additions & 5 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DataArray, Dataset, IndexVariable, Variable, align, broadcast)
from xarray.coding.times import CFDatetimeCoder, _import_cftime
from xarray.convert import from_cdms2
from xarray.core import dtypes
from xarray.core.common import ALL_DIMS, full_like
from xarray.core.pycompat import OrderedDict, iteritems
from xarray.tests import (
Expand Down Expand Up @@ -3128,12 +3129,19 @@ def test_coordinate_diff(self):
actual = lon.diff('lon')
assert_equal(expected, actual)

@pytest.mark.parametrize('offset', [-5, -2, -1, 0, 1, 2, 5])
def test_shift(self, offset):
@pytest.mark.parametrize('offset', [-5, 0, 1, 2])
@pytest.mark.parametrize('fill_value, dtype',
[(2, int), (dtypes.NA, float)])
def test_shift(self, offset, fill_value, dtype):
arr = DataArray([1, 2, 3], dims='x')
actual = arr.shift(x=1)
expected = DataArray([np.nan, 1, 2], dims='x')
assert_identical(expected, actual)
actual = arr.shift(x=1, fill_value=fill_value)
if fill_value == dtypes.NA:
# if we supply the default, we expect the missing value for a
# float array
fill_value = np.nan
expected = DataArray([fill_value, 1, 2], dims='x')
assert_identical(expected, actual)
assert actual.dtype == dtype

arr = DataArray([1, 2, 3], [('x', ['a', 'b', 'c'])])
expected = DataArray(arr.to_pandas().shift(offset))
Expand Down
13 changes: 9 additions & 4 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from xarray import (
ALL_DIMS, DataArray, Dataset, IndexVariable, MergeError, Variable, align,
backends, broadcast, open_dataset, set_options)
from xarray.core import indexing, npcompat, utils
from xarray.core import dtypes, indexing, npcompat, utils
from xarray.core.common import full_like
from xarray.core.pycompat import (
OrderedDict, integer_types, iteritems, unicode_type)
Expand Down Expand Up @@ -3917,12 +3917,17 @@ def test_dataset_diff_exception_label_str(self):
with raises_regex(ValueError, '\'label\' argument has to'):
ds.diff('dim2', label='raise_me')

def test_shift(self):
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_shift(self, fill_value):
coords = {'bar': ('x', list('abc')), 'x': [-4, 3, 2]}
attrs = {'meta': 'data'}
ds = Dataset({'foo': ('x', [1, 2, 3])}, coords, attrs)
actual = ds.shift(x=1)
expected = Dataset({'foo': ('x', [np.nan, 1, 2])}, coords, attrs)
actual = ds.shift(x=1, fill_value=fill_value)
if fill_value == dtypes.NA:
# if we supply the default, we expect the missing value for a
# float array
fill_value = np.nan
expected = Dataset({'foo': ('x', [fill_value, 1, 2])}, coords, attrs)
assert_identical(expected, actual)

with raises_regex(ValueError, 'dimensions'):
Expand Down
32 changes: 20 additions & 12 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytz

from xarray import Coordinate, Dataset, IndexVariable, Variable, set_options
from xarray.core import indexing
from xarray.core import dtypes, indexing
from xarray.core.common import full_like, ones_like, zeros_like
from xarray.core.indexing import (
BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter,
Expand Down Expand Up @@ -1179,33 +1179,41 @@ def test_indexing_0d_unicode(self):
expected = Variable((), u'tmax')
assert_identical(actual, expected)

def test_shift(self):
@pytest.mark.parametrize('fill_value', [dtypes.NA, 2, 2.0])
def test_shift(self, fill_value):
v = Variable('x', [1, 2, 3, 4, 5])

assert_identical(v, v.shift(x=0))
assert v is not v.shift(x=0)

expected = Variable('x', [np.nan, 1, 2, 3, 4])
assert_identical(expected, v.shift(x=1))

expected = Variable('x', [np.nan, np.nan, 1, 2, 3])
assert_identical(expected, v.shift(x=2))

expected = Variable('x', [2, 3, 4, 5, np.nan])
assert_identical(expected, v.shift(x=-1))
if fill_value == dtypes.NA:
# if we supply the default, we expect the missing value for a
# float array
fill_value_exp = np.nan
else:
fill_value_exp = fill_value

expected = Variable('x', [fill_value_exp, 1, 2, 3, 4])
assert_identical(expected, v.shift(x=1, fill_value=fill_value))

expected = Variable('x', [2, 3, 4, 5, fill_value_exp])
assert_identical(expected, v.shift(x=-1, fill_value=fill_value))

expected = Variable('x', [np.nan] * 5)
assert_identical(expected, v.shift(x=5))
assert_identical(expected, v.shift(x=6))
expected = Variable('x', [fill_value_exp] * 5)
assert_identical(expected, v.shift(x=5, fill_value=fill_value))
assert_identical(expected, v.shift(x=6, fill_value=fill_value))

with raises_regex(ValueError, 'dimension'):
v.shift(z=0)

v = Variable('x', [1, 2, 3, 4, 5], {'foo': 'bar'})
assert_identical(v, v.shift(x=0))

expected = Variable('x', [np.nan, 1, 2, 3, 4], {'foo': 'bar'})
assert_identical(expected, v.shift(x=1))
expected = Variable('x', [fill_value_exp, 1, 2, 3, 4], {'foo': 'bar'})
assert_identical(expected, v.shift(x=1, fill_value=fill_value))

def test_shift2d(self):
v = Variable(('x', 'y'), [[1, 2], [3, 4]])
Expand Down