Skip to content

Commit

Permalink
copy the dtypes module to the namedarray package. (#8250)
Browse files Browse the repository at this point in the history
* move dtypes module to namedarray

* keep original dtypes

* revert utils changes

* Update xarray/namedarray/dtypes.py

Co-authored-by: Illviljan <[email protected]>

* Apply suggestions from code review

Co-authored-by: Illviljan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix missing imports

* update typing

* fix return types

* type fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* type fixes

---------

Co-authored-by: Illviljan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
4 people authored Oct 4, 2023
1 parent 36fe917 commit 8d54acf
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 1 deletion.
2 changes: 1 addition & 1 deletion xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def imag(self) -> Self:
"""
return self._replace(data=self.data.imag)

def __dask_tokenize__(self) -> Hashable | None:
def __dask_tokenize__(self) -> Hashable:
# Use v.data, instead of v._data, in order to cope with the wrappers
# around NetCDF and the like
from dask.base import normalize_token
Expand Down
199 changes: 199 additions & 0 deletions xarray/namedarray/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
from __future__ import annotations

import functools
import sys
from typing import Any, Literal

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

import numpy as np

from xarray.namedarray import utils

# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")


@functools.total_ordering
class AlwaysGreaterThan:
def __gt__(self, other: Any) -> Literal[True]:
return True

def __eq__(self, other: Any) -> bool:
return isinstance(other, type(self))


@functools.total_ordering
class AlwaysLessThan:
def __lt__(self, other: Any) -> Literal[True]:
return True

def __eq__(self, other: Any) -> bool:
return isinstance(other, type(self))


# Equivalence to np.inf (-np.inf) for object-type
INF = AlwaysGreaterThan()
NINF = AlwaysLessThan()


# Pairs of types that, if both found, should be promoted to object dtype
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
# https://numpy.org/doc/stable/reference/arrays.scalars.html
PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.str_), # numpy promotes to unicode
)


def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]:
"""Simpler equivalent of pandas.core.common._maybe_promote
Parameters
----------
dtype : np.dtype
Returns
-------
dtype : Promoted dtype that can hold missing values.
fill_value : Valid missing value for the promoted dtype.
"""
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if np.issubdtype(dtype, np.floating):
dtype_ = dtype
fill_value = np.nan
elif np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
dtype_ = dtype
elif np.issubdtype(dtype, np.integer):
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
fill_value = np.nan
elif np.issubdtype(dtype, np.complexfloating):
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
dtype_ = dtype
fill_value = np.datetime64("NaT")
else:
dtype_ = object
fill_value = np.nan

dtype_out = np.dtype(dtype_)
fill_value = dtype_out.type(fill_value)
return dtype_out, fill_value


NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}


def get_fill_value(dtype: np.dtype[np.generic]) -> Any:
"""Return an appropriate fill value for this dtype.
Parameters
----------
dtype : np.dtype
Returns
-------
fill_value : Missing value corresponding to this dtype.
"""
_, fill_value = maybe_promote(dtype)
return fill_value


def get_pos_infinity(
dtype: np.dtype[np.generic], max_for_int: bool = False
) -> float | complex | AlwaysGreaterThan:
"""Return an appropriate positive infinity for this dtype.
Parameters
----------
dtype : np.dtype
max_for_int : bool
Return np.iinfo(dtype).max instead of np.inf
Returns
-------
fill_value : positive infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, np.floating):
return np.inf

if issubclass(dtype.type, np.integer):
return np.iinfo(dtype.type).max if max_for_int else np.inf
if issubclass(dtype.type, np.complexfloating):
return np.inf + 1j * np.inf

return INF


def get_neg_infinity(
dtype: np.dtype[np.generic], min_for_int: bool = False
) -> float | complex | AlwaysLessThan:
"""Return an appropriate positive infinity for this dtype.
Parameters
----------
dtype : np.dtype
min_for_int : bool
Return np.iinfo(dtype).min instead of -np.inf
Returns
-------
fill_value : positive infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, np.floating):
return -np.inf

if issubclass(dtype.type, np.integer):
return np.iinfo(dtype.type).min if min_for_int else -np.inf
if issubclass(dtype.type, np.complexfloating):
return -np.inf - 1j * np.inf

return NINF


def is_datetime_like(
dtype: np.dtype[np.generic],
) -> TypeGuard[np.datetime64 | np.timedelta64]:
"""Check if a dtype is a subclass of the numpy datetime types"""
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def result_type(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
) -> np.dtype[np.generic]:
"""Like np.result_type, but with type promotion rules matching pandas.
Examples of changed behavior:
number + string -> object (not string)
bytes + unicode -> object (not unicode)
Parameters
----------
*arrays_and_dtypes : list of arrays and dtypes
The dtype is extracted from both numpy and dask arrays.
Returns
-------
numpy.dtype for the result.
"""
types = {np.result_type(t).type for t in arrays_and_dtypes}

for left, right in PROMOTE_TO_OBJECT:
if any(issubclass(t, left) for t in types) and any(
issubclass(t, right) for t in types
):
return np.dtype(object)

return np.result_type(*arrays_and_dtypes)
27 changes: 27 additions & 0 deletions xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import sys
from collections.abc import Hashable
from enum import Enum
from typing import TYPE_CHECKING, Any, Final, Protocol, TypeVar

Expand Down Expand Up @@ -134,3 +135,29 @@ def to_0d_object_array(
result = np.empty((), dtype=object)
result[()] = value
return result


class ReprObject:
"""Object that prints as the given value, for use with sentinel values."""

__slots__ = ("_value",)

_value: str

def __init__(self, value: str):
self._value = value

def __repr__(self) -> str:
return self._value

def __eq__(self, other: ReprObject | Any) -> bool:
# TODO: What type can other be? ArrayLike?
return self._value == other._value if isinstance(other, ReprObject) else False

def __hash__(self) -> int:
return hash((type(self), self._value))

def __dask_tokenize__(self) -> Hashable:
from dask.base import normalize_token

return normalize_token((type(self), self._value)) # type: ignore[no-any-return]

0 comments on commit 8d54acf

Please sign in to comment.