Fix to_yaml serialization dropping global checks (#428)
* increase cache version

* ci: add dataframe checks tests

* bugfix: allow serialization of dataframe checks to_yaml

* ci: add test to ensure serialization of lambda check fails

* bugfix: ensure checks with no parameters generate appropriate schema

* wip: allow looking up registered checks

* fix: compare checks by name rather than by object equality

* ci: black

* ci: lint

* enh: use REGISTERED_CUSTOM_CHECKS for attribute lookup, add dir method

* enh: add to_yaml method to Schema, add unit test

* ci: disable typechecking on _CheckMeta

* ci: isort

* ci: doctests

* ci: improve coverage

* ci: codecov

* fix unrecognized check dtype during (de)serialization

* fix handle_stat__dtype closures

* enh: move metaclass onto _CheckBase

* ci: add test that ensures to_yaml warns on unregistered checks

* ci: revert adding duplicate test

Co-authored-by: cosmicBboy <[email protected]>
Co-authored-by: Jean-Francois Zinque <[email protected]>
3 people authored Mar 23, 2021
1 parent c85ce63 commit 32543d4
Showing 13 changed files with 358 additions and 48 deletions.
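In short, dataframe-level ("global") checks now survive a to_yaml/from_yaml round trip. A minimal sketch of the fixed behavior (the column and check names are hypothetical; the custom check is registered the same way as in the test fixtures added below):

```python
import pandas as pd
import pandera as pa
import pandera.extensions as pa_ext


# hypothetical custom check; registering it is what makes it (de)serializable
@pa_ext.register_check_method()
def nonempty(df: pd.DataFrame) -> bool:
    return df.shape[0] > 0


schema = pa.DataFrameSchema(
    columns={"col": pa.Column(int)},
    # dataframe-level ("global") check; previously dropped by to_yaml
    checks=pa.Check.nonempty(),
)

yaml_str = schema.to_yaml()  # should now contain a top-level "checks" entry
roundtripped = pa.DataFrameSchema.from_yaml(yaml_str)
assert roundtripped.checks  # the global check is preserved
```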
1 change: 1 addition & 0 deletions docs/source/schema_inference.rst
@@ -214,6 +214,7 @@ is a convenience method for this functionality.
coerce: false
required: true
regex: false
checks: null
index:
- pandas_dtype: int64
nullable: false
89 changes: 69 additions & 20 deletions pandera/checks.py
@@ -3,9 +3,21 @@
import inspect
import operator
import re
-from collections import namedtuple
+from collections import ChainMap, namedtuple
from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    no_type_check,
+)

import pandas as pd

@@ -51,7 +63,45 @@ def _wrapper(cls, *args, **kwargs):
return register_check_statistics_decorator


-class _CheckBase:
_T = TypeVar("_T", bound="_CheckBase")


class _CheckMeta(type): # pragma: no cover
"""Check metaclass."""

REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {} # noqa

def __getattr__(cls, name: str) -> Any:
"""Prevent attribute errors for registered checks."""
attr = ChainMap(cls.__dict__, cls.REGISTERED_CUSTOM_CHECKS).get(name)
if attr is None:
raise AttributeError(
f"'{cls}' object has no attribute '{name}'. "
"Make sure any custom checks have been registered "
"using the extensions api."
)
return attr

def __dir__(cls) -> Iterable[str]:
"""Allow custom checks to show up as attributes when autocompleting."""
return chain(super().__dir__(), cls.REGISTERED_CUSTOM_CHECKS.keys())

# pylint: disable=line-too-long
# mypy has limited metaclass support so this doesn't pass typecheck
# see https://mypy.readthedocs.io/en/stable/metaclasses.html#gotchas-and-limitations-of-metaclass-support
# pylint: enable=line-too-long
@no_type_check
def __contains__(cls: Type[_T], item: Union[_T, str]) -> bool:
"""Allow lookups for registered checks."""
if isinstance(item, cls):
name = item.name
return hasattr(cls, name)

# assume item is str
return hasattr(cls, item)


+class _CheckBase(metaclass=_CheckMeta):
"""Check base class."""

def __init__(
@@ -397,9 +447,11 @@ def __call__(
)

    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+
        are_check_fn_objects_equal = (
-            self.__dict__["_check_fn"].__code__.co_code
-            == other.__dict__["_check_fn"].__code__.co_code
+            self._get_check_fn_code() == other._get_check_fn_code()
        )

try:
@@ -427,8 +479,18 @@ def __eq__(self, other):
and are_all_other_check_attributes_equal
)

def _get_check_fn_code(self):
check_fn = self.__dict__["_check_fn"]
try:
code = check_fn.__code__.co_code
except AttributeError:
# try accessing the functools.partial wrapper
code = check_fn.func.__code__.co_code

return code

    def __hash__(self):
-        return hash(self.__dict__["_check_fn"].__code__.co_code)
+        return hash(self._get_check_fn_code())

def __repr__(self):
return (
@@ -438,22 +500,9 @@ def __repr__(self):
)


-class _CheckMeta(type):  # pragma: no cover
-    """Check metaclass."""
-
-    def __getattr__(cls, name: str) -> Any:
-        """Prevent attribute errors for registered checks."""
-        attr = cls.__dict__.get(name)
-        if attr is None:
-            raise AttributeError(f"'{cls}' object has no attribute '{name}'")
-        return attr


-class Check(_CheckBase, metaclass=_CheckMeta):
+class Check(_CheckBase):
    """Check a pandas Series or DataFrame for certain properties."""

-    REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {}  # noqa

@classmethod
@st.register_check_strategy(st.eq_strategy)
@register_check_statistics(["value"])
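The `_CheckMeta` metaclass added above is what makes name-based lookups work on the `Check` class. A rough sketch of the intended usage (`no_such_check` is a deliberately unregistered, made-up name):

```python
import pandera as pa

# membership is resolved by name, for both Check instances and plain strings
assert pa.Check.greater_than(0) in pa.Check  # built-in checks are always registered
assert "no_such_check" not in pa.Check

# attribute access on an unknown name now raises a more descriptive error
try:
    pa.Check.no_such_check
except AttributeError as exc:
    print(exc)  # message points at the extensions api
```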
5 changes: 2 additions & 3 deletions pandera/extensions.py
@@ -161,9 +161,8 @@ def check_method(cls, *args, **kwargs):
if strategy is not None:
check_method = st.register_check_strategy(strategy)(check_method)

-        setattr(Check, check_fn.__name__, classmethod(check_method))
-        Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = getattr(
-            Check, check_fn.__name__
+        Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = partial(
+            check_method, Check
        )

return register_check_wrapper(check_fn)
46 changes: 39 additions & 7 deletions pandera/io.py
@@ -25,7 +25,7 @@
NOT_JSON_SERIALIZABLE = {PandasDtype.DateTime, PandasDtype.Timedelta}


-def _serialize_check_stats(check_stats, pandas_dtype):
+def _serialize_check_stats(check_stats, pandas_dtype=None):
"""Serialize check statistics into json/yaml-compatible format."""

def handle_stat_dtype(stat):
@@ -34,6 +34,7 @@ def handle_stat_dtype(stat):
elif pandas_dtype == PandasDtype.Timedelta:
# serialize to int in nanoseconds
return stat.delta

return stat

# for unary checks, return a single value instead of a dictionary
@@ -47,18 +48,37 @@ def handle_stat_dtype(stat):
return serialized_check_stats


def _serialize_dataframe_stats(dataframe_checks):
"""
Serialize global dataframe check statistics into json/yaml-compatible format.
"""
serialized_checks = {}

for check_name, check_stats in dataframe_checks.items():
# The case that `check_name` is not registered is handled in `parse_checks`,
# so we know that `check_name` exists.

# infer dtype of statistics and serialize them
serialized_checks[check_name] = _serialize_check_stats(check_stats)

return serialized_checks


def _serialize_component_stats(component_stats):
"""
Serialize column or index statistics into json/yaml-compatible format.
"""
# pylint: disable=import-outside-toplevel
from pandera.checks import Check

serialized_checks = None
if component_stats["checks"] is not None:
serialized_checks = {}
        for check_name, check_stats in component_stats["checks"].items():
-            if check_stats is None:
+            if check_name not in Check:
                warnings.warn(
                    f"Check {check_name} cannot be serialized. This check will be "
-                    f"ignored"
+                    "ignored. Did you forget to register it with the extension API?"
                )
else:
serialized_checks[check_name] = _serialize_check_stats(
@@ -93,7 +113,7 @@ def _serialize_schema(dataframe_schema):

statistics = get_dataframe_schema_statistics(dataframe_schema)

-    columns, index = None, None
+    columns, index, checks = None, None, None
if statistics["columns"] is not None:
columns = {
col_name: _serialize_component_stats(column_stats)
@@ -106,17 +126,21 @@
for index_stats in statistics["index"]
]

if statistics["checks"] is not None:
checks = _serialize_dataframe_stats(statistics["checks"])

return {
"schema_type": "dataframe",
"version": __version__,
"columns": columns,
"checks": checks,
"index": index,
"coerce": dataframe_schema.coerce,
"strict": dataframe_schema.strict,
}


-def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype):
+def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype=None):
def handle_stat_dtype(stat):
if pandas_dtype == PandasDtype.DateTime:
return pd.to_datetime(stat, format=DATETIME_FORMAT)
@@ -173,9 +197,9 @@ def _deserialize_component_stats(serialized_component_stats):

def _deserialize_schema(serialized_schema):
# pylint: disable=import-outside-toplevel
-    from pandera import Column, DataFrameSchema, Index, MultiIndex
+    from pandera import Check, Column, DataFrameSchema, Index, MultiIndex

-    columns, index = None, None
+    columns, index, checks = None, None, None
if serialized_schema["columns"] is not None:
columns = {
col_name: Column(**_deserialize_component_stats(column_stats))
@@ -188,6 +212,13 @@ def _deserialize_schema(serialized_schema):
for index_component in serialized_schema["index"]
]

if serialized_schema["checks"] is not None:
# handles unregistered checks by raising AttributeErrors from getattr
checks = [
_deserialize_check_stats(getattr(Check, check_name), check_stats)
for check_name, check_stats in serialized_schema["checks"].items()
]

if index is None:
pass
elif len(index) == 1:
@@ -199,6 +230,7 @@

return DataFrameSchema(
columns=columns,
checks=checks,
index=index,
coerce=serialized_schema["coerce"],
strict=serialized_schema["strict"],
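Deserialization of global checks goes through `getattr(Check, check_name)`, so loading a schema that references an unregistered check is expected to fail loudly with the `AttributeError` raised by `_CheckMeta.__getattr__`. A sketch of that failure mode, assuming `from_yaml` also accepts a raw YAML string (as the round trip with `to_yaml` output suggests) and keeping only the keys the deserializer reads; the check name is made up:

```python
import pandera as pa

YAML_WITH_UNKNOWN_CHECK = """
columns: null
checks:
  some_unregistered_check: null
index: null
coerce: false
strict: false
"""

try:
    pa.DataFrameSchema.from_yaml(YAML_WITH_UNKNOWN_CHECK)
except AttributeError as exc:
    print(exc)  # suggests registering the check via the extensions api
```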
8 changes: 8 additions & 0 deletions pandera/model.py
@@ -1,5 +1,6 @@
"""Class-based api"""
import inspect
import os
import re
import sys
import typing
@@ -170,6 +171,13 @@ def to_schema(cls) -> DataFrameSchema:
MODEL_CACHE[cls] = cls.__schema__
return cls.__schema__

@classmethod
def to_yaml(cls, stream: Optional[os.PathLike] = None):
"""
Convert `Schema` to yaml using `io.to_yaml`.
"""
return cls.to_schema().to_yaml(stream)

@classmethod
@pd.util.Substitution(validate_doc=DataFrameSchema.validate.__doc__)
def validate(
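For the class-based API this means a `SchemaModel` subclass can be dumped directly. A minimal, hypothetical example:

```python
import pandera as pa
from pandera.typing import Series


class MySchema(pa.SchemaModel):
    col: Series[int]


yaml_str = MySchema.to_yaml()  # no stream given, so the YAML is returned as a string
```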
13 changes: 12 additions & 1 deletion pandera/schema_statistics.py
@@ -115,6 +115,7 @@ def get_dataframe_schema_statistics(dataframe_schema):
}
for col_name, column in dataframe_schema.columns.items()
},
"checks": parse_checks(dataframe_schema.checks),
"index": (
None
if dataframe_schema.index is None
@@ -158,7 +159,17 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]:
check_statistics = {}
_check_memo = {}
for check in checks:
-        check_statistics[check.name] = check.statistics
+        if check not in Check:
+            warnings.warn(
+                "Only registered checks may be serialized to statistics. "
+                "Did you forget to register it with the extension API? "
+                f"Check `{check.name}` will be skipped."
+            )
+            continue
+
+        check_statistics[check.name] = (
+            {} if check.statistics is None else check.statistics
+        )
_check_memo[check.name] = check

# raise ValueError on incompatible checks
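The practical effect of the guard above: an unregistered check (for example an ad-hoc lambda) is skipped with a warning instead of silently producing statistics that cannot be serialized. A sketch, with a hypothetical check name:

```python
import warnings

import pandera as pa

schema = pa.DataFrameSchema(
    columns={"col": pa.Column(int)},
    # ad-hoc lambda check: never registered, so it cannot be serialized
    checks=pa.Check(lambda df: True, name="ad_hoc_lambda"),
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    schema.to_yaml()

assert any("will be skipped" in str(w.message) for w in caught)
```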
6 changes: 3 additions & 3 deletions pandera/schemas.py
@@ -3,6 +3,7 @@

import copy
import itertools
import os
import warnings
from functools import wraps
from pathlib import Path
@@ -1186,17 +1187,16 @@ def from_yaml(cls, yaml_schema) -> "DataFrameSchema":

return pandera.io.from_yaml(yaml_schema)

-    def to_yaml(self, fp: Union[str, Path] = None):
+    def to_yaml(self, stream: Optional[os.PathLike] = None):
        """Write DataFrameSchema to yaml file.
-        :param dataframe_schema: schema to write to file or dump to string.
+        :param stream: file stream to write to. If None, dumps to string.
        :returns: yaml string if stream is None, otherwise returns None.
        """
        # pylint: disable=import-outside-toplevel,cyclic-import
        import pandera.io

-        return pandera.io.to_yaml(self, fp)
+        return pandera.io.to_yaml(self, stream=stream)

def set_index(
self, keys: List[str], drop: bool = True, append: bool = False
33 changes: 33 additions & 0 deletions tests/core/checks_fixtures.py
@@ -0,0 +1,33 @@
"""Pytest fixtures for testing custom checks."""
import unittest.mock as mock

import pandas as pd
import pytest

import pandera as pa
import pandera.extensions as pa_ext

__all__ = "custom_check_teardown", "extra_registered_checks"


@pytest.fixture(scope="function")
def custom_check_teardown():
"""Remove all custom checks after execution of each pytest function."""
yield
for check_name in list(pa.Check.REGISTERED_CUSTOM_CHECKS):
del pa.Check.REGISTERED_CUSTOM_CHECKS[check_name]


@pytest.fixture(scope="function")
def extra_registered_checks():
"""temporarily registers custom checks onto the Check class"""
# pylint: disable=unused-variable
with mock.patch(
"pandera.Check.REGISTERED_CUSTOM_CHECKS", new_callable=dict
):
# register custom checks here
@pa_ext.register_check_method()
def no_param_check(_: pd.DataFrame) -> bool:
return True

yield
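A hypothetical test built on these fixtures might look like the following; while `extra_registered_checks` is active, `no_param_check` is available on the `Check` class and disappears again once the patch exits:

```python
import pandera as pa


def test_no_param_check_is_registered(extra_registered_checks):
    # pylint: disable=unused-argument
    assert "no_param_check" in pa.Check.REGISTERED_CUSTOM_CHECKS
    assert pa.Check.no_param_check() in pa.Check
```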
4 changes: 4 additions & 0 deletions tests/core/conftest.py
@@ -0,0 +1,4 @@
"""Registers fixtures for core"""

# pylint: disable=unused-import
from .checks_fixtures import custom_check_teardown, extra_registered_checks