Skip to content

Commit

Permalink
revert to set_dims
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 committed Oct 27, 2023
1 parent a5918c3 commit 8091265
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def _dataset_concat(

# case where concat dimension is a coordinate or data_var but not a dimension
if (dim in coord_names or dim in data_names) and dim not in dim_names:
datasets = [ds.set_dims(dim) for ds in datasets]
datasets = [ds.expand_dims(dim) for ds in datasets]

# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2537,7 +2537,7 @@ def expand_dims(
See Also
--------
Dataset.set_dims
Dataset.expand_dims
Examples
--------
Expand All @@ -2549,13 +2549,13 @@ def expand_dims(
Add new dimension of length 2:
>>> da.set_dims(dim={"y": 2})
>>> da.expand_dims(dim={"y": 2})
<xarray.DataArray (y: 2, x: 5)>
array([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]])
Dimensions without coordinates: y, x
>>> da.set_dims(dim={"y": 2}, axis=1)
>>> da.expand_dims(dim={"y": 2}, axis=1)
<xarray.DataArray (x: 5, y: 2)>
array([[0, 0],
[1, 1],
Expand All @@ -2566,7 +2566,7 @@ def expand_dims(
Add a new dimension with coordinates from array:
>>> da.set_dims(dim={"y": np.arange(5)}, axis=0)
>>> da.expand_dims(dim={"y": np.arange(5)}, axis=0)
<xarray.DataArray (y: 5, x: 5)>
array([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
Expand All @@ -2587,7 +2587,7 @@ def expand_dims(
dim = {dim: 1}

dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
ds = self._to_temp_dataset().set_dims(dim, axis)
ds = self._to_temp_dataset().expand_dims(dim, axis)
return self._from_temp_dataset(ds)

def set_index(
Expand Down
14 changes: 7 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4482,7 +4482,7 @@ def expand_dims(
# Expand the dataset with a new dimension called "time"
>>> dataset.set_dims(dim="time")
>>> dataset.expand_dims(dim="time")
<xarray.Dataset>
Dimensions: (time: 1)
Dimensions without coordinates: time
Expand All @@ -4502,7 +4502,7 @@ def expand_dims(
# Expand the dataset with a new dimension called "time" using axis argument
>>> dataset_1d.set_dims(dim="time", axis=0)
>>> dataset_1d.expand_dims(dim="time", axis=0)
<xarray.Dataset>
Dimensions: (time: 1, x: 3)
Dimensions without coordinates: time, x
Expand All @@ -4522,7 +4522,7 @@ def expand_dims(
# Expand the dataset with a new dimension called "time" using axis argument
>>> dataset_2d.set_dims(dim="time", axis=2)
>>> dataset_2d.expand_dims(dim="time", axis=2)
<xarray.Dataset>
Dimensions: (y: 3, x: 4, time: 1)
Dimensions without coordinates: y, x, time
Expand All @@ -4531,7 +4531,7 @@ def expand_dims(
See Also
--------
DataArray.set_dims
DataArray.expand_dims
"""
if dim is None:
pass
Expand Down Expand Up @@ -5143,7 +5143,7 @@ def _stack_once(
vdims = list(var.dims) + add_dims
shape = [self.dims[d] for d in vdims]
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims}) # type: ignore
stacked_var = exp_var.stack(**{new_dim: dims})
new_variables[name] = stacked_var
stacked_var_names.append(name)
else:
Expand Down Expand Up @@ -5327,7 +5327,7 @@ def stack_dataarray(da):

return (
da.assign_coords(**missing_stack_coords)
.set_dims(missing_stack_dims)
.expand_dims(missing_stack_dims)
.stack({new_dim: (variable_dim,) + stacking_dims})
)

Expand Down Expand Up @@ -7337,7 +7337,7 @@ def to_dask_dataframe(

# Broadcast then flatten the array:
var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks)
dask_array = var_new_dims._data.reshape(-1) # type: ignore
dask_array = var_new_dims._data.reshape(-1)

series = dd.from_dask_array(dask_array, columns=name, meta=df_meta)
series_list.append(series)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def _binary_op(self, other, f, reflexive=False):
for var in other.coords:
if other[var].ndim == 0:
other[var] = (
other[var].drop_vars(var).set_dims({name: other.sizes[name]})
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
)

# need to handle NaNs in group or elements that don't belong to any bins
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_variable_property(prop):
(do("notnull"), True),
(do("roll"), True),
(do("round"), True),
(do("expand_dims", dims=("x", "y", "z")), True),
(do("set_dims", dims=("x", "y", "z")), True),
(do("stack", dimensions={"flat": ("x", "y")}), True),
(do("to_base_variable"), True),
(do("transpose"), True),
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -2147,7 +2147,7 @@ def test_concat(self, unit, error, dtype):
assert_units_equal(expected, actual)
assert_identical(expected, actual)

def test_expand_dims(self, dtype):
def test_set_dims(self, dtype):
array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
variable = xr.Variable(("x", "y"), array)

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,7 @@ def test_get_axis_num(self):
with pytest.raises(ValueError, match=r"not found in array dim"):
v.get_axis_num("foobar")

def test_expand_dims(self):
def test_set_dims(self):
v = Variable(["x"], [0, 1])
actual = v.set_dims(["x", "y"])
expected = Variable(["x", "y"], [[0], [1]])
Expand All @@ -1627,7 +1627,7 @@ def test_expand_dims(self):
with pytest.raises(ValueError, match=r"must be a superset"):
v.set_dims(["z"])

def test_expand_dims_object_dtype(self):
def test_set_dims_object_dtype(self):
v = Variable([], ("a", 1))
actual = v.set_dims(("x",), (3,))
exp_values = np.empty((3,), dtype=object)
Expand Down

0 comments on commit 8091265

Please sign in to comment.