Skip to content
forked from pydata/xarray

Commit

Permalink
Support pandas copy-on-write behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Mar 16, 2024
1 parent fbcac76 commit 12c253d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 38 deletions.
6 changes: 5 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,11 @@ def _possibly_convert_objects(values):
as_series = pd.Series(values.ravel(), copy=False)
if as_series.dtype.kind in "mM":
as_series = _as_nanosecond_precision(as_series)
return np.asarray(as_series).reshape(values.shape)
result = np.asarray(as_series).reshape(values.shape)
if not result.flags.writeable:
# GH8843, pandas copy-on-write mode creates read-only arrays by default
result = result.copy()
return result


def _possibly_convert_datetime_or_timedelta_index(data):
Expand Down
13 changes: 12 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401
from xarray.core.indexing import ExplicitlyIndexed
from xarray.core.options import set_options
from xarray.core.variable import IndexVariable
from xarray.testing import ( # noqa: F401
assert_chunks_equal,
assert_duckarray_allclose,
Expand All @@ -36,6 +37,7 @@
except ImportError:
pass


# https://github.com/pydata/xarray/issues/7322
warnings.filterwarnings("ignore", "'urllib3.contrib.pyopenssl' module is deprecated")
warnings.filterwarnings("ignore", "Deprecated call to `pkg_resources.declare_namespace")
Expand All @@ -47,6 +49,15 @@
)


def assert_writeable(ds):
readonly = [
name
for name, var in ds.variables.items()
if not isinstance(var, IndexVariable) and not var.data.flags.writeable
]
assert not readonly, readonly


def _importorskip(
modname: str, minversion: str | None = None
) -> tuple[bool, pytest.MarkDecorator]:
Expand Down Expand Up @@ -326,7 +337,7 @@ def create_test_data(
numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64")
obj.coords["numbers"] = ("dim3", numbers_values)
obj.encoding = {"foo": "bar"}
assert all(obj.data.flags.writeable for obj in obj.variables.values())
assert_writeable(obj)
return obj


Expand Down
8 changes: 6 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2605,7 +2605,9 @@ def test_append_with_append_dim_no_overwrite(self) -> None:
# overwrite a coordinate;
# for mode='a-', this will not get written to the store
# because it does not have the append_dim as a dim
ds_to_append.lon.data[:] = -999
lon = ds_to_append.lon.to_numpy().copy()
lon[:] = -999
ds_to_append["lon"] = lon
ds_to_append.to_zarr(
store_target, mode="a-", append_dim="time", **self.version_kwargs
)
Expand All @@ -2615,7 +2617,9 @@ def test_append_with_append_dim_no_overwrite(self) -> None:
# by default, mode="a" will overwrite all coordinates.
ds_to_append.to_zarr(store_target, append_dim="time", **self.version_kwargs)
actual = xr.open_dataset(store_target, engine="zarr", **self.version_kwargs)
original2.lon.data[:] = -999
lon = original2.lon.to_numpy().copy()
lon[:] = -999
original2["lon"] = lon
assert_identical(original2, actual)

@requires_dask
Expand Down
52 changes: 18 additions & 34 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
assert_equal,
assert_identical,
assert_no_warnings,
assert_writeable,
create_test_data,
has_cftime,
has_dask,
Expand Down Expand Up @@ -96,11 +97,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]:
nt2 = 2
time1 = pd.date_range("2000-01-01", periods=nt1)
time2 = pd.date_range("2000-02-01", periods=nt2)
string_var = np.array(["ae", "bc", "df"], dtype=object)
string_var = np.array(["a", "bc", "def"], dtype=object)
string_var_to_append = np.array(["asdf", "asdfg"], dtype=object)
string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2")
string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2")
unicode_var = ["áó", "áó", "áó"]
unicode_var = np.array(["áó", "áó", "áó"])
datetime_var = np.array(
["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]"
)
Expand All @@ -119,17 +120,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]:
coords=[lat, lon, time1],
dims=["lat", "lon", "time"],
),
"string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length, coords=[time1], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var, coords=[time1], dims=["time"]
).astype(np.str_),
"datetime_var": xr.DataArray(
datetime_var, coords=[time1], dims=["time"]
),
"bool_var": xr.DataArray(bool_var, coords=[time1], dims=["time"]),
"string_var": ("time", string_var),
"string_var_fixed_length": ("time", string_var_fixed_length),
"unicode_var": ("time", unicode_var),
"datetime_var": ("time", datetime_var),
"bool_var": ("time", bool_var),
}
)

Expand All @@ -140,21 +135,11 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]:
coords=[lat, lon, time2],
dims=["lat", "lon", "time"],
),
"string_var": xr.DataArray(
string_var_to_append, coords=[time2], dims=["time"]
),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length_to_append, coords=[time2], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var[:nt2], coords=[time2], dims=["time"]
).astype(np.str_),
"datetime_var": xr.DataArray(
datetime_var_to_append, coords=[time2], dims=["time"]
),
"bool_var": xr.DataArray(
bool_var_to_append, coords=[time2], dims=["time"]
),
"string_var": ("time", string_var_to_append),
"string_var_fixed_length": ("time", string_var_fixed_length_to_append),
"unicode_var": ("time", unicode_var[:nt2]),
"datetime_var": ("time", datetime_var_to_append),
"bool_var": ("time", bool_var_to_append),
}
)

Expand All @@ -168,8 +153,9 @@ def create_append_test_data(seed=None) -> tuple[Dataset, Dataset, Dataset]:
}
)

assert all(objp.data.flags.writeable for objp in ds.variables.values())
assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values())
assert_writeable(ds)
assert_writeable(ds_to_append)
assert_writeable(ds_with_new_var)
return ds, ds_to_append, ds_with_new_var


Expand All @@ -182,10 +168,8 @@ def make_datasets(data, data_to_append) -> tuple[Dataset, Dataset]:
ds_to_append = xr.Dataset(
{"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]}
)
assert all(objp.data.flags.writeable for objp in ds.variables.values())
assert all(
objp.data.flags.writeable for objp in ds_to_append.variables.values()
)
assert_writeable(ds)
assert_writeable(ds_to_append)
return ds, ds_to_append

u2_strings = ["ab", "cd", "ef"]
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ def var():
return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))


@pytest.mark.parametrize(
"data",
[
np.array(["a", "bc", "def"], dtype=object),
np.array(["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[ns]"),
],
)
def test_as_compatible_data_writeable(data):
pd.set_option("mode.copy_on_write", True)
# GH8843, ensure writeable arrays for data_vars even with
# pandas copy-on-write mode
assert as_compatible_data(data).flags.writeable
pd.reset_option("mode.copy_on_write")


class VariableSubclassobjects(NamedArraySubclassobjects, ABC):
@pytest.fixture
def target(self, data):
Expand Down

0 comments on commit 12c253d

Please sign in to comment.