Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ENH] support for xarray DataArray & mtypes #3255

Merged
merged 13 commits into from
Aug 26, 2022
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ all_extras = [
"tensorflow",
"tsfresh>=0.17.0; python_version < '3.10'",
"tslearn>=0.5.2",
"xarray",
]

dev = [
Expand Down
69 changes: 69 additions & 0 deletions sktime/datatypes/_series/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import numpy as np
import pandas as pd

from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation.series import is_in_valid_index_types

VALID_INDEX_TYPES = (pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)
Expand Down Expand Up @@ -235,3 +236,71 @@ def _index_equally_spaced(index):
all_equal = np.all(diffs == diffs[0])

return all_equal


if _check_soft_dependencies("xarray", severity="none"):
    import xarray as xr

    def check_xrdataarray_series(obj, return_metadata=False, var_name="obj"):
        """Check whether obj is a valid xr.DataArray representation of a Series.

        Parameters
        ----------
        obj : object
            the object to check
        return_metadata : bool, default=False
            whether to also return an error message and a metadata dict
        var_name : str, default="obj"
            name of obj, used in error messages

        Returns
        -------
        valid : bool - whether obj is a valid xr.DataArray Series
        msg : str or None - error message, only returned if return_metadata=True
        metadata : dict or None - metadata, only returned if return_metadata=True
        """
        metadata = {}

        # helper: shape the return value depending on return_metadata
        def ret(valid, msg, metadata, return_metadata):
            if return_metadata:
                return valid, msg, metadata
            return valid

        if not isinstance(obj, xr.DataArray):
            msg = f"{var_name} must be a xarray.DataArray, found {type(obj)}"
            return ret(False, msg, None, return_metadata)

        # we now know obj is a xr.DataArray
        # without multi-indexing, only up to two dimensions are possible
        if len(obj.dims) > 2:
            # report the actual dims, not type(obj.dims) which is always tuple
            msg = f"{var_name} must have two or fewer dimensions, found {obj.dims}"
            return ret(False, msg, None, return_metadata)

        # the first dimension is the (time) index of the series in sktime
        index = obj.indexes[obj.dims[0]]

        metadata["is_empty"] = len(index) < 1 or len(obj.values) < 1
        # the second dimension, if present, is the set of columns
        metadata["is_univariate"] = len(obj.dims) == 1 or len(obj[obj.dims[1]]) < 2

        # check that columns (dimension names) are unique
        # NOTE: was an assert before; assert is stripped under -O and is
        # inconsistent with the ret(False, ...) reporting of the other checks
        if len(obj.dims) != len(set(obj.dims)):
            msg = f"{var_name} must have " f"unique column indices, but found {obj.dims}"
            return ret(False, msg, None, return_metadata)

        # check whether the time index is of valid type
        if not is_in_valid_index_types(index):
            msg = (
                f"{type(index)} is not supported for {var_name}, use "
                f"one of {VALID_INDEX_TYPES} or integer index instead."
            )
            return ret(False, msg, None, return_metadata)

        # check that the dtype is not object
        if obj.dtype == "object":
            msg = f"{var_name} should not have column of 'object' dtype"
            return ret(False, msg, None, return_metadata)

        # check time index is ordered in time
        # is_monotonic_increasing is the non-deprecated spelling of is_monotonic
        if not index.is_monotonic_increasing:
            msg = (
                f"The (time) index of {var_name} must be sorted "
                f"monotonically increasing, but found: {index}"
            )
            return ret(False, msg, None, return_metadata)

        if FREQ_SET_CHECK and isinstance(index, pd.DatetimeIndex):
            if index.freq is None:
                msg = f"{var_name} has DatetimeIndex, but no freq attribute set."
                return ret(False, msg, None, return_metadata)

        # check whether index is equally spaced or if there are any nans
        # compute only if needed, as these checks can be expensive
        if return_metadata:
            metadata["is_equally_spaced"] = _index_equally_spaced(index)
            metadata["has_nans"] = obj.isnull().values.any()

        return ret(True, None, metadata, return_metadata)

    check_dict[("xr.DataArray", "Series")] = check_xrdataarray_series
71 changes: 70 additions & 1 deletion sktime/datatypes/_series/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
##############################################################
# methods to convert one machine type to another machine type
##############################################################
from sktime.datatypes._registry import MTYPE_LIST_SERIES
from sktime.utils.validation._dependencies import _check_soft_dependencies

convert_dict = dict()

Expand All @@ -47,7 +49,7 @@ def convert_identity(obj, store=None):


# register the identity conversion for every mtype converting to itself
_SELF_CONVERTIBLE_MTYPES = ("pd.Series", "pd.DataFrame", "np.ndarray", "xr.DataArray")
for _mtype in _SELF_CONVERTIBLE_MTYPES:
    convert_dict[(_mtype, _mtype, "Series")] = convert_identity


Expand Down Expand Up @@ -172,3 +174,70 @@ def convert_np_to_UvS_as_Series(obj: np.ndarray, store=None) -> pd.Series:


convert_dict[("np.ndarray", "pd.Series", "Series")] = convert_np_to_UvS_as_Series


# obtain other conversions from/to an mtype via chaining through an anchor mtype
def _concat(fun1, fun2):
    """Compose two converter functions: apply fun1 first, then fun2.

    The store argument is passed through to both converters.
    """

    def concat_fun(obj, store=None):
        obj1 = fun1(obj, store=store)
        obj2 = fun2(obj1, store=store)
        return obj2

    return concat_fun


def _extend_conversions(mtype, anchor_mtype, convert_dict, mtype_list=None):
    """Obtain all conversions from/to mtype by chaining through anchor_mtype.

    Mutates convert_dict in place: for every other mtype tp in mtype_list,
    adds (mtype, tp) and (tp, mtype) conversions composed via anchor_mtype,
    whenever the corresponding anchor conversions are already registered.

    Parameters
    ----------
    mtype : str
        the mtype to extend conversions for
    anchor_mtype : str
        the mtype to chain conversions through
    convert_dict : dict
        conversion dictionary, keyed by (from_mtype, to_mtype, scitype) tuples
    mtype_list : list of str, optional
        mtypes considered as conversion sources/targets;
        defaults to MTYPE_LIST_SERIES
    """
    if mtype_list is None:
        mtype_list = MTYPE_LIST_SERIES

    # all keys share the same scitype in this module; read it off any key
    scitype = next(iter(convert_dict))[2]

    for tp in set(mtype_list).difference([mtype, anchor_mtype]):
        if (anchor_mtype, tp, scitype) in convert_dict:
            convert_dict[(mtype, tp, scitype)] = _concat(
                convert_dict[(mtype, anchor_mtype, scitype)],
                convert_dict[(anchor_mtype, tp, scitype)],
            )
        if (tp, anchor_mtype, scitype) in convert_dict:
            convert_dict[(tp, mtype, scitype)] = _concat(
                convert_dict[(tp, anchor_mtype, scitype)],
                convert_dict[(anchor_mtype, mtype, scitype)],
            )


if _check_soft_dependencies("xarray", severity="none"):
    import xarray as xr

    def convert_xrdataarray_to_Mvs_as_Series(
        obj: xr.DataArray, store=None
    ) -> pd.DataFrame:
        """Convert xr.DataArray to pd.DataFrame, both of Series scitype.

        If store is a dict, the coordinate names are remembered in
        store["coords"] so a later back-conversion can restore them.
        """
        if not isinstance(obj, xr.DataArray):
            raise TypeError("input must be a xr.DataArray")

        if isinstance(store, dict):
            store["coords"] = list(obj.coords.keys())

        # first dimension is the time index; second (if present) the columns
        index = obj.indexes[obj.dims[0]]
        columns = obj.indexes[obj.dims[1]] if len(obj.dims) == 2 else None
        return pd.DataFrame(obj.values, index=index, columns=columns)

    convert_dict[
        ("xr.DataArray", "pd.DataFrame", "Series")
    ] = convert_xrdataarray_to_Mvs_as_Series

    def convert_Mvs_to_xrdatarray_as_Series(
        obj: pd.DataFrame, store=None
    ) -> xr.DataArray:
        """Convert pd.DataFrame to xr.DataArray, both of Series scitype.

        If store carries coordinate names from a previous conversion,
        they are restored on the result via rename.
        """
        if not isinstance(obj, pd.DataFrame):
            # bugfix: message previously claimed a xr.DataArray was expected
            raise TypeError("input must be a pd.DataFrame")

        result = xr.DataArray(obj.values, coords=[obj.index, obj.columns])
        if isinstance(store, dict) and "coords" in store:
            result = result.rename(
                dict(zip(list(result.coords.keys()), store["coords"]))
            )
        return result

    convert_dict[
        ("pd.DataFrame", "xr.DataArray", "Series")
    ] = convert_Mvs_to_xrdatarray_as_Series

    _extend_conversions("xr.DataArray", "pd.DataFrame", convert_dict)
36 changes: 36 additions & 0 deletions sktime/datatypes/_series/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import numpy as np
import pandas as pd

from sktime.utils.validation._dependencies import _check_soft_dependencies

example_dict = dict()
example_dict_lossy = dict()
example_dict_metadata = dict()
Expand All @@ -54,6 +56,18 @@
example_dict[("np.ndarray", "Series", 0)] = arr
example_dict_lossy[("np.ndarray", "Series", 0)] = True

# xr.DataArray example for Series scitype, fixture 0 (guarded: xarray is optional)
if _check_soft_dependencies("xarray", severity="none"):
    import xarray as xr

    # univariate series: integer time index [0..3], single column "a"
    da = xr.DataArray(
        [[1], [4], [0.5], [-3]],
        coords=[[0, 1, 2, 3], ["a"]],
    )

    # lossy=False: index and column labels are preserved by conversion
    example_dict[("xr.DataArray", "Series", 0)] = da
    example_dict_lossy[("xr.DataArray", "Series", 0)] = False


example_dict_metadata[("Series", 0)] = {
"is_univariate": True,
"is_equally_spaced": True,
Expand All @@ -76,6 +90,16 @@

example_dict[("np.ndarray", "Series", 1)] = arr
example_dict_lossy[("np.ndarray", "Series", 1)] = True
# xr.DataArray example for Series scitype, fixture 1 (guarded: xarray is optional)
if _check_soft_dependencies("xarray", severity="none"):
    import xarray as xr

    # bivariate series: integer time index [0..3], columns "a" and "b"
    da = xr.DataArray(
        [[1, 3], [4, 7], [0.5, 2], [-3, -3 / 7]],
        coords=[[0, 1, 2, 3], ["a", "b"]],
    )

    # lossy=False: index and column labels are preserved by conversion
    example_dict[("xr.DataArray", "Series", 1)] = da
    example_dict_lossy[("xr.DataArray", "Series", 1)] = False

example_dict_metadata[("Series", 1)] = {
"is_univariate": False,
Expand All @@ -100,6 +124,18 @@
example_dict[("np.ndarray", "Series", 2)] = arr
example_dict_lossy[("np.ndarray", "Series", 2)] = True

# xr.DataArray example for Series scitype, fixture 2 (guarded: xarray is optional)
if _check_soft_dependencies("xarray", severity="none"):
    import xarray as xr

    # bivariate series: integer time index [0..3], columns "a" and "b"
    da = xr.DataArray(
        [[1, 3], [4, 7], [0.5, 2], [3, 3 / 7]],
        coords=[[0, 1, 2, 3], ["a", "b"]],
    )

    # lossy=False: index and column labels are preserved by conversion
    example_dict[("xr.DataArray", "Series", 2)] = da
    example_dict_lossy[("xr.DataArray", "Series", 2)] = False


example_dict_metadata[("Series", 2)] = {
"is_univariate": False,
"is_equally_spaced": True,
Expand Down
5 changes: 5 additions & 0 deletions sktime/datatypes/_series/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
"Series",
"2D numpy.ndarray with rows=samples, cols=variables, index=integers",
),
(
"xr.DataArray",
"Series",
"xr.DataArray representation of a uni- or multivariate series",
),
]

MTYPE_LIST_SERIES = pd.DataFrame(MTYPE_REGISTER_SERIES)[0].values
5 changes: 3 additions & 2 deletions sktime/utils/_testing/deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ def ret(is_equal, msg):
)
elif type(x).__name__ == "ForecastingHorizon":
return ret(*_fh_equals(x, y, return_msg=True))
elif x != y:
elif isinstance(x != y, bool) and x != y:
return ret(False, f" !=, {x} != {y}")
elif np.any(x != y):
benHeid marked this conversation as resolved.
Show resolved Hide resolved
return ret(False, f" !=, {x} != {y}")

return ret(True, "")


Expand Down