From 3cd08e85be75449a2b1a13b3245bb112d5d5d749 Mon Sep 17 00:00:00 2001 From: Jean-Francois Zinque Date: Thu, 13 May 2021 23:28:38 +0200 Subject: [PATCH] disable inherited __init__ with immutable(init=False) --- pandera/dtypes_.py | 99 ++++++++++++++++++++++---------- pandera/engines/numpy_engine.py | 32 ++++++----- pandera/engines/pandas_engine.py | 35 +++++------ 3 files changed, 105 insertions(+), 61 deletions(-) diff --git a/pandera/dtypes_.py b/pandera/dtypes_.py index 3e6aeb83f..f9eb822d6 100644 --- a/pandera/dtypes_.py +++ b/pandera/dtypes_.py @@ -1,25 +1,6 @@ +import dataclasses import functools -from dataclasses import dataclass, field -from typing import Any, Tuple, Type, Union - -try: # python 3.8+ - from typing import Literal # type: ignore -except ImportError: - from typing_extensions import Literal # type: ignore - - -def immutable(dtype=None, **kwargs) -> Type: - dataclass_kwargs = {"frozen": True, "init": False, "repr": False} - dataclass_kwargs.update(kwargs) - - if dtype is None: - return functools.partial(dataclass, **dataclass_kwargs) - return dataclass(**dataclass_kwargs)(dtype) - - -class DisableInitMixin: - def __init__(self) -> None: - pass +from typing import Any, Tuple, Type class DataType: @@ -37,6 +18,11 @@ def coerce(self, obj: Any): """Coerce object to the dtype.""" raise NotImplementedError() + def check(self, datatype: "DataType") -> bool: + if not isinstance(datatype, DataType): + return False + return self == datatype + def __repr__(self) -> str: return f"DataType({str(self)})" @@ -44,13 +30,64 @@ def __str__(self) -> str: """Must be implemented by subclasses.""" raise NotImplementedError() - def check(self, datatype: "DataType") -> bool: - if not isinstance(datatype, DataType): - return False - return self == datatype - def __hash__(self) -> int: - pass + raise NotImplementedError() + + +def immutable( + dtype: Type[DataType] = None, **dataclass_kwargs: Any +) -> Type[DataType]: + """:func:`dataclasses.dataclass` decorator with different default values: + `frozen=True`, `init=False`, `repr=False`. + + :param dtype: :class:`DataType` to decorate. + :param dataclass_kwargs: Keywords arguments forwarded to + :func:`dataclasses.dataclass`. + :returns: Immutable :class:`DataType` + """ + kwargs = {"frozen": True, "init": False, "repr": False} + kwargs.update(dataclass_kwargs) + + +# if dtype is None: +# return functools.partial(dataclasses.dataclass, **kwargs) +# return dataclasses.dataclass(**kwargs)(dtype) + + +def immutable( + dtype: Type[DataType] = None, **dataclass_kwargs: Any +) -> Type[DataType]: + """:func:`dataclasses.dataclass` decorator with different default values: + `frozen=True`, `init=False`, `repr=False`. + + In addition, `init=False` disables inherited `__init__` method to ensure + the DataType's default attributes are not altered during initialization. + + :param dtype: :class:`DataType` to decorate. + :param dataclass_kwargs: Keywords arguments forwarded to + :func:`dataclasses.dataclass`. + :returns: Immutable :class:`DataType` + """ + kwargs = {"frozen": True, "init": False, "repr": False} + kwargs.update(dataclass_kwargs) + + def _wrapper(dtype): + immutable_dtype = dataclasses.dataclass(**kwargs)(dtype) + if not kwargs["init"]: + + def __init__(self): + pass + + # delattr(immutable_dtype, "__init__") doesn't work because + # super.__init__ would still exist. + setattr(immutable_dtype, "__init__", __init__) + + return immutable_dtype + + if dtype is None: + return _wrapper + + return _wrapper(dtype) ################################################################################ @@ -87,7 +124,7 @@ def check(self, datatype: "DataType") -> bool: @immutable class _PhysicalNumber(_Number): bit_width: int = None - _base_name: str = field(default=None, init=False, repr=False) + _base_name: str = dataclasses.field(default=None, init=False, repr=False) def __eq__(self, obj: object) -> bool: if isinstance(obj, type(self)): @@ -109,7 +146,7 @@ class Int(_PhysicalNumber): continuous = False exact = True bit_width = 64 - signed: bool = field(default=True, init=False) + signed: bool = dataclasses.field(default=True, init=False) @immutable @@ -140,7 +177,7 @@ class Int8(Int16): @immutable class UInt(Int): _base_name = "uint" - signed: bool = field(default=False, init=False) + signed: bool = dataclasses.field(default=False, init=False) @immutable @@ -228,7 +265,7 @@ class Complex64(Complex128): @immutable(init=True) -class Category(DisableInitMixin, DataType): +class Category(DataType): categories: Tuple[Any] = None # immutable sequence to ensure safe hash ordered: bool = False diff --git a/pandera/engines/numpy_engine.py b/pandera/engines/numpy_engine.py index 34c49481d..bb1f95acc 100644 --- a/pandera/engines/numpy_engine.py +++ b/pandera/engines/numpy_engine.py @@ -1,18 +1,23 @@ import builtins +import dataclasses import datetime -from dataclasses import field from typing import Any, List import numpy as np from .. import dtypes_ -from ..dtypes_ import DisableInitMixin, immutable +from ..dtypes_ import immutable from . import engine @immutable(init=True) class DataType(dtypes_.DataType): - type: np.dtype = field(default=np.dtype("object"), repr=False) + type: np.dtype = dataclasses.field( + default=np.dtype("object"), repr=False, init=False + ) + + def __init__(self, dtype: Any): + object.__setattr__(self, "type", np.dtype(dtype)) def __post_init__(self): object.__setattr__(self, "type", np.dtype(self.type)) @@ -39,6 +44,7 @@ def dtype(cls, data_type: Any) -> "DataType": raise TypeError( f"data type '{data_type}' not understood by {cls.__name__}." ) from None + try: return engine.Engine.dtype(cls, np_dtype) except TypeError: @@ -54,7 +60,7 @@ def dtype(cls, data_type: Any) -> "DataType": equivalents=["bool", bool, np.bool_, dtypes_.Bool, dtypes_.Bool()] ) @immutable -class Bool(DisableInitMixin, DataType, dtypes_.Bool): +class Bool(DataType, dtypes_.Bool): """representation of a boolean data type.""" type = np.dtype("bool") @@ -106,7 +112,7 @@ def _build_number_equivalents( @Engine.register_dtype(equivalents=_int_equivalents[64]) @immutable -class Int64(DisableInitMixin, DataType, dtypes_.Int64): +class Int64(DataType, dtypes_.Int64): type = np.dtype("int64") bit_width: int = 64 @@ -145,7 +151,7 @@ class Int8(Int16): @Engine.register_dtype(equivalents=_uint_equivalents[64]) @immutable -class UInt64(DisableInitMixin, DataType, dtypes_.UInt64): +class UInt64(DataType, dtypes_.UInt64): type = np.dtype("uint64") bit_width: int = 64 @@ -184,7 +190,7 @@ class UInt8(UInt16): @Engine.register_dtype(equivalents=_float_equivalents[128]) @immutable -class Float128(DisableInitMixin, DataType, dtypes_.Float128): +class Float128(DataType, dtypes_.Float128): type = np.dtype("float128") bit_width: int = 128 @@ -223,7 +229,7 @@ class Float16(Float32): @Engine.register_dtype(equivalents=_complex_equivalents[256]) @immutable -class Complex256(DisableInitMixin, DataType, dtypes_.Complex256): +class Complex256(DataType, dtypes_.Complex256): type = np.dtype("complex256") bit_width: int = 256 @@ -249,7 +255,7 @@ class Complex64(Complex128): @Engine.register_dtype(equivalents=["str", "string", str, np.str_]) @immutable -class String(DisableInitMixin, DataType, dtypes_.String): +class String(DataType, dtypes_.String): type = np.dtype("str") def coerce(self, arr: np.ndarray) -> np.ndarray: @@ -269,12 +275,10 @@ def check(self, datatype: "dtypes_.DataType") -> bool: @Engine.register_dtype(equivalents=["object", "O", object, np.object_]) @immutable -class Object(DisableInitMixin, DataType): +class Object(DataType): type = np.dtype("object") -Object = Object - ################################################################################ # time ################################################################################ @@ -289,7 +293,7 @@ class Object(DisableInitMixin, DataType): ] ) @immutable -class DateTime64(DisableInitMixin, DataType, dtypes_.Timestamp): +class DateTime64(DataType, dtypes_.Timestamp): type = np.dtype("datetime64") @@ -302,5 +306,5 @@ class DateTime64(DisableInitMixin, DataType, dtypes_.Timestamp): ] ) @immutable -class Timedelta64(DisableInitMixin, DataType, dtypes_.Timedelta): +class Timedelta64(DataType, dtypes_.Timedelta): type = np.dtype("timedelta64") diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index 2c3a7dee8..429965889 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -1,13 +1,13 @@ import builtins +import dataclasses import datetime -from dataclasses import field from typing import Any, Dict, List, Union import numpy as np import pandas as pd from .. import dtypes_ -from ..dtypes_ import DisableInitMixin, immutable +from ..dtypes_ import immutable from . import engine, numpy_engine PandasObject = Union[pd.Series, pd.Index, pd.DataFrame] @@ -23,7 +23,10 @@ def is_extension_dtype(dtype): @immutable(init=True) class DataType(dtypes_.DataType): - type: Any = field(repr=False) + type: Any = dataclasses.field(repr=False, init=False) + + def __init__(self, dtype: Any): + object.__setattr__(self, "type", pd.api.types.pandas_dtype(dtype)) def __post_init__(self): object.__setattr__(self, "type", pd.api.types.pandas_dtype(self.type)) @@ -93,7 +96,7 @@ def dtype(cls, obj: Any) -> "DataType": equivalents=["boolean", pd.BooleanDtype, pd.BooleanDtype()], ) @immutable -class Bool(DisableInitMixin, DataType, dtypes_.Bool): +class Bool(DataType, dtypes_.Bool): type = pd.BooleanDtype() @@ -164,7 +167,7 @@ def _register_numpy_numbers( @Engine.register_dtype(equivalents=[pd.Int64Dtype, pd.Int64Dtype()]) @immutable -class Int64(DisableInitMixin, DataType, dtypes_.Int): +class Int64(DataType, dtypes_.Int): type = pd.Int64Dtype() bit_width: int = 64 @@ -214,7 +217,7 @@ class Int8(Int16): @Engine.register_dtype(equivalents=[pd.UInt64Dtype, pd.UInt64Dtype()]) @immutable -class UInt64(DisableInitMixin, DataType, dtypes_.UInt): +class UInt64(DataType, dtypes_.UInt): type = pd.UInt64Dtype() bit_width: int = 64 @@ -280,7 +283,7 @@ class UInt8(UInt16): ) @immutable(init=True) class Category(DataType, dtypes_.Category): - type: pd.CategoricalDtype = field(default=None, init=False) + type: pd.CategoricalDtype = dataclasses.field(default=None, init=False) def __post_init__(self): dtypes_.Category.__post_init__(self) @@ -301,7 +304,7 @@ def from_parametrized_dtype( equivalents=["string", pd.StringDtype, pd.StringDtype()] ) @immutable -class String(DisableInitMixin, DataType, dtypes_.String): +class String(DataType, dtypes_.String): type = pd.StringDtype() @@ -357,12 +360,12 @@ def check(self, datatype: "DataType") -> bool: ) @immutable(init=True) class DateTime(DataType, dtypes_.Timestamp): - type: Union[np.datetime64, pd.DatetimeTZDtype] = field( + type: Union[np.datetime64, pd.DatetimeTZDtype] = dataclasses.field( default=None, init=False ) unit: str = "ns" tz: datetime.tzinfo = None - to_datetime_kwargs: Dict[str, Any] = field( + to_datetime_kwargs: Dict[str, Any] = dataclasses.field( default=None, compare=False, repr=False ) @@ -380,7 +383,7 @@ def coerce(self, obj: PandasObject) -> PandasObject: kwargs = self.to_datetime_kwargs or {} def _to_datetime(col: pd.Series) -> pd.Series: - return pd.to_datetime(col, **kwargs).astype(self.type) + return pd.to_datetime(col, **kwargs).astype(self.type).to_series() if isinstance(obj, pd.DataFrame): # pd.to_datetime transforms a df input into a series. @@ -412,12 +415,12 @@ def __str__(self) -> str: ) @immutable(init=True) class DateTime(DataType, dtypes_.Timestamp): - type: Union[np.datetime64, pd.DatetimeTZDtype] = field( + type: Union[np.datetime64, pd.DatetimeTZDtype] = dataclasses.field( default=None, init=False ) unit: str = "ns" tz: datetime.tzinfo = None - to_datetime_kwargs: Dict[str, Any] = field( + to_datetime_kwargs: Dict[str, Any] = dataclasses.field( default=None, compare=False, repr=False ) @@ -470,7 +473,7 @@ def __str__(self) -> str: @Engine.register_dtype @immutable(init=True) class Period(DataType): - type: pd.PeriodDtype = field(default=None, init=False) + type: pd.PeriodDtype = dataclasses.field(default=None, init=False) freq: Union[str, pd.tseries.offsets.DateOffset] def __post_init__(self): @@ -489,7 +492,7 @@ def from_parametrized_dtype(cls, pd_dtype: pd.PeriodDtype): @Engine.register_dtype(equivalents=[pd.SparseDtype]) @immutable(init=True) class Sparse(DataType): - type: pd.SparseDtype = field(default=None, init=False) + type: pd.SparseDtype = dataclasses.field(default=None, init=False) dtype: Union[str, PandasExtensionType, np.dtype, "type"] = np.float_ fill_value: Any = np.nan @@ -508,7 +511,7 @@ def from_parametrized_dtype(cls, pd_dtype: pd.SparseDtype): @Engine.register_dtype @immutable(init=True) class Interval(DataType): - type: pd.IntervalDtype = field(default=None, init=False) + type: pd.IntervalDtype = dataclasses.field(default=None, init=False) subtype: Union[str, np.dtype] def __post_init__(self):