Skip to content

Commit

Permalink
disable inherited __init__ with immutable(init=False)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Francois Zinque committed May 13, 2021
1 parent ec23e0c commit 3cd08e8
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 61 deletions.
99 changes: 68 additions & 31 deletions pandera/dtypes_.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,6 @@
import dataclasses
import functools
from dataclasses import dataclass, field
from typing import Any, Tuple, Type, Union

try: # python 3.8+
from typing import Literal # type: ignore
except ImportError:
from typing_extensions import Literal # type: ignore


def immutable(dtype=None, **kwargs) -> Type:
dataclass_kwargs = {"frozen": True, "init": False, "repr": False}
dataclass_kwargs.update(kwargs)

if dtype is None:
return functools.partial(dataclass, **dataclass_kwargs)
return dataclass(**dataclass_kwargs)(dtype)


class DisableInitMixin:
def __init__(self) -> None:
pass
from typing import Any, Tuple, Type


class DataType:
Expand All @@ -37,20 +18,76 @@ def coerce(self, obj: Any):
"""Coerce object to the dtype."""
raise NotImplementedError()

def check(self, datatype: "DataType") -> bool:
if not isinstance(datatype, DataType):
return False
return self == datatype

def __repr__(self) -> str:
return f"DataType({str(self)})"

def __str__(self) -> str:
"""Must be implemented by subclasses."""
raise NotImplementedError()

def check(self, datatype: "DataType") -> bool:
if not isinstance(datatype, DataType):
return False
return self == datatype

def __hash__(self) -> int:
pass
raise NotImplementedError()


def immutable(
dtype: Type[DataType] = None, **dataclass_kwargs: Any
) -> Type[DataType]:
""":func:`dataclasses.dataclass` decorator with different default values:
`frozen=True`, `init=False`, `repr=False`.
:param dtype: :class:`DataType` to decorate.
:param dataclass_kwargs: Keywords arguments forwarded to
:func:`dataclasses.dataclass`.
:returns: Immutable :class:`DataType`
"""
kwargs = {"frozen": True, "init": False, "repr": False}
kwargs.update(dataclass_kwargs)


# if dtype is None:
# return functools.partial(dataclasses.dataclass, **kwargs)
# return dataclasses.dataclass(**kwargs)(dtype)


def immutable(
dtype: Type[DataType] = None, **dataclass_kwargs: Any
) -> Type[DataType]:
""":func:`dataclasses.dataclass` decorator with different default values:
`frozen=True`, `init=False`, `repr=False`.
In addition, `init=False` disables inherited `__init__` method to ensure
the DataType's default attributes are not altered during initialization.
:param dtype: :class:`DataType` to decorate.
:param dataclass_kwargs: Keywords arguments forwarded to
:func:`dataclasses.dataclass`.
:returns: Immutable :class:`DataType`
"""
kwargs = {"frozen": True, "init": False, "repr": False}
kwargs.update(dataclass_kwargs)

def _wrapper(dtype):
immutable_dtype = dataclasses.dataclass(**kwargs)(dtype)
if not kwargs["init"]:

def __init__(self):
pass

# delattr(immutable_dtype, "__init__") doesn't work because
# super.__init__ would still exist.
setattr(immutable_dtype, "__init__", __init__)

return immutable_dtype

if dtype is None:
return _wrapper

return _wrapper(dtype)


################################################################################
Expand Down Expand Up @@ -87,7 +124,7 @@ def check(self, datatype: "DataType") -> bool:
@immutable
class _PhysicalNumber(_Number):
bit_width: int = None
_base_name: str = field(default=None, init=False, repr=False)
_base_name: str = dataclasses.field(default=None, init=False, repr=False)

def __eq__(self, obj: object) -> bool:
if isinstance(obj, type(self)):
Expand All @@ -109,7 +146,7 @@ class Int(_PhysicalNumber):
continuous = False
exact = True
bit_width = 64
signed: bool = field(default=True, init=False)
signed: bool = dataclasses.field(default=True, init=False)


@immutable
Expand Down Expand Up @@ -140,7 +177,7 @@ class Int8(Int16):
@immutable
class UInt(Int):
_base_name = "uint"
signed: bool = field(default=False, init=False)
signed: bool = dataclasses.field(default=False, init=False)


@immutable
Expand Down Expand Up @@ -228,7 +265,7 @@ class Complex64(Complex128):


@immutable(init=True)
class Category(DisableInitMixin, DataType):
class Category(DataType):
categories: Tuple[Any] = None # immutable sequence to ensure safe hash
ordered: bool = False

Expand Down
32 changes: 18 additions & 14 deletions pandera/engines/numpy_engine.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import builtins
import dataclasses
import datetime
from dataclasses import field
from typing import Any, List

import numpy as np

from .. import dtypes_
from ..dtypes_ import DisableInitMixin, immutable
from ..dtypes_ import immutable
from . import engine


@immutable(init=True)
class DataType(dtypes_.DataType):
type: np.dtype = field(default=np.dtype("object"), repr=False)
type: np.dtype = dataclasses.field(
default=np.dtype("object"), repr=False, init=False
)

def __init__(self, dtype: Any):
object.__setattr__(self, "type", np.dtype(dtype))

def __post_init__(self):
object.__setattr__(self, "type", np.dtype(self.type))
Expand All @@ -39,6 +44,7 @@ def dtype(cls, data_type: Any) -> "DataType":
raise TypeError(
f"data type '{data_type}' not understood by {cls.__name__}."
) from None

try:
return engine.Engine.dtype(cls, np_dtype)
except TypeError:
Expand All @@ -54,7 +60,7 @@ def dtype(cls, data_type: Any) -> "DataType":
equivalents=["bool", bool, np.bool_, dtypes_.Bool, dtypes_.Bool()]
)
@immutable
class Bool(DisableInitMixin, DataType, dtypes_.Bool):
class Bool(DataType, dtypes_.Bool):
"""representation of a boolean data type."""

type = np.dtype("bool")
Expand Down Expand Up @@ -106,7 +112,7 @@ def _build_number_equivalents(

@Engine.register_dtype(equivalents=_int_equivalents[64])
@immutable
class Int64(DisableInitMixin, DataType, dtypes_.Int64):
class Int64(DataType, dtypes_.Int64):
type = np.dtype("int64")
bit_width: int = 64

Expand Down Expand Up @@ -145,7 +151,7 @@ class Int8(Int16):

@Engine.register_dtype(equivalents=_uint_equivalents[64])
@immutable
class UInt64(DisableInitMixin, DataType, dtypes_.UInt64):
class UInt64(DataType, dtypes_.UInt64):
type = np.dtype("uint64")
bit_width: int = 64

Expand Down Expand Up @@ -184,7 +190,7 @@ class UInt8(UInt16):

@Engine.register_dtype(equivalents=_float_equivalents[128])
@immutable
class Float128(DisableInitMixin, DataType, dtypes_.Float128):
class Float128(DataType, dtypes_.Float128):
type = np.dtype("float128")
bit_width: int = 128

Expand Down Expand Up @@ -223,7 +229,7 @@ class Float16(Float32):

@Engine.register_dtype(equivalents=_complex_equivalents[256])
@immutable
class Complex256(DisableInitMixin, DataType, dtypes_.Complex256):
class Complex256(DataType, dtypes_.Complex256):
type = np.dtype("complex256")
bit_width: int = 256

Expand All @@ -249,7 +255,7 @@ class Complex64(Complex128):

@Engine.register_dtype(equivalents=["str", "string", str, np.str_])
@immutable
class String(DisableInitMixin, DataType, dtypes_.String):
class String(DataType, dtypes_.String):
type = np.dtype("str")

def coerce(self, arr: np.ndarray) -> np.ndarray:
Expand All @@ -269,12 +275,10 @@ def check(self, datatype: "dtypes_.DataType") -> bool:

@Engine.register_dtype(equivalents=["object", "O", object, np.object_])
@immutable
class Object(DisableInitMixin, DataType):
class Object(DataType):
type = np.dtype("object")


Object = Object

################################################################################
# time
################################################################################
Expand All @@ -289,7 +293,7 @@ class Object(DisableInitMixin, DataType):
]
)
@immutable
class DateTime64(DisableInitMixin, DataType, dtypes_.Timestamp):
class DateTime64(DataType, dtypes_.Timestamp):
type = np.dtype("datetime64")


Expand All @@ -302,5 +306,5 @@ class DateTime64(DisableInitMixin, DataType, dtypes_.Timestamp):
]
)
@immutable
class Timedelta64(DisableInitMixin, DataType, dtypes_.Timedelta):
class Timedelta64(DataType, dtypes_.Timedelta):
type = np.dtype("timedelta64")
Loading

0 comments on commit 3cd08e8

Please sign in to comment.