diff --git a/docs/source/schema_inference.rst b/docs/source/schema_inference.rst
index 04f184b53..bba6889bf 100644
--- a/docs/source/schema_inference.rst
+++ b/docs/source/schema_inference.rst
@@ -214,6 +214,7 @@ is a convenience method for this functionality.
         coerce: false
         required: true
         regex: false
+    checks: null
     index:
     - pandas_dtype: int64
       nullable: false
diff --git a/pandera/checks.py b/pandera/checks.py
index 8046b4ba2..0345dc39b 100644
--- a/pandera/checks.py
+++ b/pandera/checks.py
@@ -3,9 +3,21 @@
 import inspect
 import operator
 import re
-from collections import namedtuple
+from collections import ChainMap, namedtuple
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from itertools import chain
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    no_type_check,
+)
 
 import pandas as pd
 
@@ -51,7 +63,45 @@ def _wrapper(cls, *args, **kwargs):
     return register_check_statistics_decorator
 
 
-class _CheckBase:
+_T = TypeVar("_T", bound="_CheckBase")
+
+
+class _CheckMeta(type):  # pragma: no cover
+    """Check metaclass."""
+
+    REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {}  # noqa
+
+    def __getattr__(cls, name: str) -> Any:
+        """Prevent attribute errors for registered checks."""
+        attr = ChainMap(cls.__dict__, cls.REGISTERED_CUSTOM_CHECKS).get(name)
+        if attr is None:
+            raise AttributeError(
+                f"'{cls}' object has no attribute '{name}'. "
+                "Make sure any custom checks have been registered "
+                "using the extensions api."
+            )
+        return attr
+
+    def __dir__(cls) -> Iterable[str]:
+        """Allow custom checks to show up as attributes when autocompleting."""
+        return chain(super().__dir__(), cls.REGISTERED_CUSTOM_CHECKS.keys())
+
+    # pylint: disable=line-too-long
+    # mypy has limited metaclass support so this doesn't pass typecheck
+    # see https://mypy.readthedocs.io/en/stable/metaclasses.html#gotchas-and-limitations-of-metaclass-support
+    # pylint: enable=line-too-long
+    @no_type_check
+    def __contains__(cls: Type[_T], item: Union[_T, str]) -> bool:
+        """Allow lookups for registered checks."""
+        if isinstance(item, cls):
+            name = item.name
+            return hasattr(cls, name)
+
+        # assume item is str
+        return hasattr(cls, item)
+
+
+class _CheckBase(metaclass=_CheckMeta):
     """Check base class."""
 
     def __init__(
@@ -397,9 +447,11 @@ def __call__(
         )
 
     def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return NotImplemented
+
         are_check_fn_objects_equal = (
-            self.__dict__["_check_fn"].__code__.co_code
-            == other.__dict__["_check_fn"].__code__.co_code
+            self._get_check_fn_code() == other._get_check_fn_code()
         )
 
         try:
@@ -427,8 +479,18 @@ def __eq__(self, other):
             and are_all_other_check_attributes_equal
         )
 
+    def _get_check_fn_code(self):
+        check_fn = self.__dict__["_check_fn"]
+        try:
+            code = check_fn.__code__.co_code
+        except AttributeError:
+            # try accessing the functools.partial wrapper
+            code = check_fn.func.__code__.co_code
+
+        return code
+
     def __hash__(self):
-        return hash(self.__dict__["_check_fn"].__code__.co_code)
+        return hash(self._get_check_fn_code())
 
     def __repr__(self):
         return (
@@ -438,22 +500,9 @@ def __repr__(self):
         )
 
 
-class _CheckMeta(type):  # pragma: no cover
-    """Check metaclass."""
-
-    def __getattr__(cls, name: str) -> Any:
-        """Prevent attribute errors for registered checks."""
-        attr = cls.__dict__.get(name)
-        if attr is None:
-            raise AttributeError(f"'{cls}' object has no attribute '{name}'")
-        return attr
-
-
-class Check(_CheckBase, metaclass=_CheckMeta):
+class Check(_CheckBase):
     """Check a pandas Series or DataFrame for certain properties."""
 
-    REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {}  # noqa
-
     @classmethod
     @st.register_check_strategy(st.eq_strategy)
     @register_check_statistics(["value"])
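
The metaclass above is the heart of the change: class-level attribute access on Check now falls back to the REGISTERED_CUSTOM_CHECKS dict, and the new __contains__ lets callers ask whether a check, by name or by instance, is registered. Below is a standalone sketch of that pattern in plain Python rather than pandera itself; DemoCheck and is_positive are illustrative names.

from collections import ChainMap
from typing import Any, Callable, Dict


class _RegistryMeta(type):
    """Class-level attribute lookup falls back to a registry of custom checks."""

    REGISTERED_CUSTOM_CHECKS: Dict[str, Callable] = {}

    def __getattr__(cls, name: str) -> Any:
        attr = ChainMap(cls.__dict__, cls.REGISTERED_CUSTOM_CHECKS).get(name)
        if attr is None:
            raise AttributeError(f"{cls!r} has no attribute {name!r}")
        return attr

    def __contains__(cls, item) -> bool:
        # accept either a registered name or an object exposing `.name`
        name = item if isinstance(item, str) else item.name
        return hasattr(cls, name)


class DemoCheck(metaclass=_RegistryMeta):
    """Stand-in for pandera's Check class."""


# registering a custom check is just a dict insertion...
DemoCheck.REGISTERED_CUSTOM_CHECKS["is_positive"] = lambda values: all(v > 0 for v in values)

# ...after which it is reachable as a class attribute and via `in`
assert "is_positive" in DemoCheck
assert DemoCheck.is_positive([1, 2, 3])
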
diff --git a/pandera/extensions.py b/pandera/extensions.py
index 308476f9f..32f0eb606 100644
--- a/pandera/extensions.py
+++ b/pandera/extensions.py
@@ -161,9 +161,8 @@ def check_method(cls, *args, **kwargs):
         if strategy is not None:
             check_method = st.register_check_strategy(strategy)(check_method)
 
-        setattr(Check, check_fn.__name__, classmethod(check_method))
-        Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = getattr(
-            Check, check_fn.__name__
+        Check.REGISTERED_CUSTOM_CHECKS[check_fn.__name__] = partial(
+            check_method, Check
         )
 
     return register_check_wrapper(check_fn)
diff --git a/pandera/io.py b/pandera/io.py
index ef3f3451d..08f4f7a5d 100644
--- a/pandera/io.py
+++ b/pandera/io.py
@@ -25,7 +25,7 @@
 NOT_JSON_SERIALIZABLE = {PandasDtype.DateTime, PandasDtype.Timedelta}
 
 
-def _serialize_check_stats(check_stats, pandas_dtype):
+def _serialize_check_stats(check_stats, pandas_dtype=None):
     """Serialize check statistics into json/yaml-compatible format."""
 
     def handle_stat_dtype(stat):
@@ -34,6 +34,7 @@ def handle_stat_dtype(stat):
         elif pandas_dtype == PandasDtype.Timedelta:
             # serialize to int in nanoseconds
             return stat.delta
+        return stat
 
     # for unary checks, return a single value instead of a dictionary
@@ -47,18 +48,37 @@ def handle_stat_dtype(stat):
     return serialized_check_stats
 
 
+def _serialize_dataframe_stats(dataframe_checks):
+    """
+    Serialize global dataframe check statistics into json/yaml-compatible format.
+    """
+    serialized_checks = {}
+
+    for check_name, check_stats in dataframe_checks.items():
+        # The case that `check_name` is not registered is handled in `parse_checks`,
+        # so we know that `check_name` exists.
+
+        # infer dtype of statistics and serialize them
+        serialized_checks[check_name] = _serialize_check_stats(check_stats)
+
+    return serialized_checks
+
+
 def _serialize_component_stats(component_stats):
     """
     Serialize column or index statistics into json/yaml-compatible format.
     """
+    # pylint: disable=import-outside-toplevel
+    from pandera.checks import Check
+
     serialized_checks = None
     if component_stats["checks"] is not None:
         serialized_checks = {}
         for check_name, check_stats in component_stats["checks"].items():
-            if check_stats is None:
+            if check_name not in Check:
                 warnings.warn(
                     f"Check {check_name} cannot be serialized. This check will be "
-                    f"ignored"
+                    "ignored. Did you forget to register it with the extension API?"
                 )
             else:
                 serialized_checks[check_name] = _serialize_check_stats(
@@ -93,7 +113,7 @@ def _serialize_schema(dataframe_schema):
     statistics = get_dataframe_schema_statistics(dataframe_schema)
 
-    columns, index = None, None
+    columns, index, checks = None, None, None
     if statistics["columns"] is not None:
         columns = {
             col_name: _serialize_component_stats(column_stats)
@@ -106,17 +126,21 @@ def _serialize_schema(dataframe_schema):
             for index_stats in statistics["index"]
         ]
 
+    if statistics["checks"] is not None:
+        checks = _serialize_dataframe_stats(statistics["checks"])
+
     return {
         "schema_type": "dataframe",
         "version": __version__,
         "columns": columns,
+        "checks": checks,
         "index": index,
         "coerce": dataframe_schema.coerce,
         "strict": dataframe_schema.strict,
     }
 
 
-def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype):
+def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype=None):
     def handle_stat_dtype(stat):
         if pandas_dtype == PandasDtype.DateTime:
             return pd.to_datetime(stat, format=DATETIME_FORMAT)
@@ -173,9 +197,9 @@ def _deserialize_component_stats(serialized_component_stats):
 
 def _deserialize_schema(serialized_schema):
     # pylint: disable=import-outside-toplevel
-    from pandera import Column, DataFrameSchema, Index, MultiIndex
+    from pandera import Check, Column, DataFrameSchema, Index, MultiIndex
 
-    columns, index = None, None
+    columns, index, checks = None, None, None
     if serialized_schema["columns"] is not None:
         columns = {
             col_name: Column(**_deserialize_component_stats(column_stats))
@@ -188,6 +212,13 @@ def _deserialize_schema(serialized_schema):
             for index_component in serialized_schema["index"]
         ]
 
+    if serialized_schema["checks"] is not None:
+        # handles unregistered checks by raising AttributeErrors from getattr
+        checks = [
+            _deserialize_check_stats(getattr(Check, check_name), check_stats)
+            for check_name, check_stats in serialized_schema["checks"].items()
+        ]
+
     if index is None:
         pass
     elif len(index) == 1:
@@ -199,6 +230,7 @@ def _deserialize_schema(serialized_schema):
 
     return DataFrameSchema(
         columns=columns,
+        checks=checks,
         index=index,
         coerce=serialized_schema["coerce"],
         strict=serialized_schema["strict"],
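
The serialization changes above only round-trip checks that the Check class knows about. A hedged usage sketch, mirroring the test_to_yaml_registered_dataframe_check test added later in this diff; ncols_gt is an illustrative name registered on the spot, not a built-in check.

import pandas as pd

import pandera as pa
import pandera.extensions as pa_ext


@pa_ext.register_check_method(statistics=["column_count"])
def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool:
    """DataFrame-level check: require more than `column_count` columns."""
    return len(pandas_obj.columns) > column_count


schema = pa.DataFrameSchema(
    columns={"a": pa.Column(pa.Int)},
    checks=[pa.Check.ncols_gt(column_count=1)],
)

# dataframe-level check statistics now survive the yaml round trip
yaml_str = schema.to_yaml()
roundtripped = pa.io.from_yaml(yaml_str)
assert len(roundtripped.checks) == 1

# an unregistered lambda check would instead trigger the new UserWarning
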
diff --git a/pandera/model.py b/pandera/model.py
index 540367b56..1efecfac2 100644
--- a/pandera/model.py
+++ b/pandera/model.py
@@ -1,5 +1,6 @@
 """Class-based api"""
 import inspect
+import os
 import re
 import sys
 import typing
@@ -170,6 +171,13 @@ def to_schema(cls) -> DataFrameSchema:
             MODEL_CACHE[cls] = cls.__schema__
         return cls.__schema__
 
+    @classmethod
+    def to_yaml(cls, stream: Optional[os.PathLike] = None):
+        """
+        Convert `Schema` to yaml using `io.to_yaml`.
+        """
+        return cls.to_schema().to_yaml(stream)
+
     @classmethod
     @pd.util.Substitution(validate_doc=DataFrameSchema.validate.__doc__)
     def validate(
diff --git a/pandera/schema_statistics.py b/pandera/schema_statistics.py
index 3271442a1..6b30782dc 100644
--- a/pandera/schema_statistics.py
+++ b/pandera/schema_statistics.py
@@ -115,6 +115,7 @@ def get_dataframe_schema_statistics(dataframe_schema):
             }
             for col_name, column in dataframe_schema.columns.items()
         },
+        "checks": parse_checks(dataframe_schema.checks),
         "index": (
             None
             if dataframe_schema.index is None
@@ -158,7 +159,17 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]:
     check_statistics = {}
     _check_memo = {}
     for check in checks:
-        check_statistics[check.name] = check.statistics
+        if check not in Check:
+            warnings.warn(
+                "Only registered checks may be serialized to statistics. "
+                "Did you forget to register it with the extension API? "
+                f"Check `{check.name}` will be skipped."
+            )
+            continue
+
+        check_statistics[check.name] = (
+            {} if check.statistics is None else check.statistics
+        )
         _check_memo[check.name] = check
 
     # raise ValueError on incompatible checks
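
The new SchemaModel.to_yaml classmethod simply chains to_schema() into DataFrameSchema.to_yaml(), which is what the GH#419 regression test at the end of this diff exercises. A minimal sketch, assuming a model with no custom checks; ExampleModel and its fields are illustrative.

import pandera as pa
import pandera.typing as pat


class ExampleModel(pa.SchemaModel):
    """Minimal class-based schema with no custom checks."""

    a: pat.Series[pat.Int64]
    b: pat.Series[pat.Int64]


# equivalent to ExampleModel.to_schema().to_yaml(None): returns a YAML string
yaml_str = ExampleModel.to_yaml()
assert "schema_type: dataframe" in yaml_str
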
" + f"Check `{check.name}` will be skipped." + ) + continue + + check_statistics[check.name] = ( + {} if check.statistics is None else check.statistics + ) _check_memo[check.name] = check # raise ValueError on incompatible checks diff --git a/pandera/schemas.py b/pandera/schemas.py index cbad5e4d2..e197f648c 100644 --- a/pandera/schemas.py +++ b/pandera/schemas.py @@ -3,6 +3,7 @@ import copy import itertools +import os import warnings from functools import wraps from pathlib import Path @@ -1186,17 +1187,16 @@ def from_yaml(cls, yaml_schema) -> "DataFrameSchema": return pandera.io.from_yaml(yaml_schema) - def to_yaml(self, fp: Union[str, Path] = None): + def to_yaml(self, stream: Optional[os.PathLike] = None): """Write DataFrameSchema to yaml file. - :param dataframe_schema: schema to write to file or dump to string. :param stream: file stream to write to. If None, dumps to string. :returns: yaml string if stream is None, otherwise returns None. """ # pylint: disable=import-outside-toplevel,cyclic-import import pandera.io - return pandera.io.to_yaml(self, fp) + return pandera.io.to_yaml(self, stream=stream) def set_index( self, keys: List[str], drop: bool = True, append: bool = False diff --git a/tests/core/checks_fixtures.py b/tests/core/checks_fixtures.py new file mode 100644 index 000000000..baa99e81e --- /dev/null +++ b/tests/core/checks_fixtures.py @@ -0,0 +1,33 @@ +"""Pytest fixtures for testing custom checks.""" +import unittest.mock as mock + +import pandas as pd +import pytest + +import pandera as pa +import pandera.extensions as pa_ext + +__all__ = "custom_check_teardown", "extra_registered_checks" + + +@pytest.fixture(scope="function") +def custom_check_teardown(): + """Remove all custom checks after execution of each pytest function.""" + yield + for check_name in list(pa.Check.REGISTERED_CUSTOM_CHECKS): + del pa.Check.REGISTERED_CUSTOM_CHECKS[check_name] + + +@pytest.fixture(scope="function") +def extra_registered_checks(): + """temporarily registers custom checks onto the Check class""" + # pylint: disable=unused-variable + with mock.patch( + "pandera.Check.REGISTERED_CUSTOM_CHECKS", new_callable=dict + ): + # register custom checks here + @pa_ext.register_check_method() + def no_param_check(_: pd.DataFrame) -> bool: + return True + + yield diff --git a/tests/core/conftest.py b/tests/core/conftest.py new file mode 100644 index 000000000..06f637dd9 --- /dev/null +++ b/tests/core/conftest.py @@ -0,0 +1,4 @@ +"""Registers fixtures for core""" + +# pylint: disable=unused-import +from .checks_fixtures import custom_check_teardown, extra_registered_checks diff --git a/tests/core/test_checks.py b/tests/core/test_checks.py index 5f2f3678d..b2fb9f027 100644 --- a/tests/core/test_checks.py +++ b/tests/core/test_checks.py @@ -353,12 +353,14 @@ def test_reshape_failure_cases_exceptions(): def test_check_equality_operators(): - """Test the usage of == between a Check and an entirely different Check.""" + """Test the usage of == between a Check and an entirely different Check, + and a non-Check.""" check = Check(lambda g: g["foo"]["col1"].iat[0] == 1, groupby="col3") not_equal_check = Check(lambda x: x.isna().sum() == 0) assert check == copy.deepcopy(check) assert check != not_equal_check + assert check != "not a check" def test_equality_operators_functional_equivalence(): diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 689268a0a..17fc3b24e 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -13,13 +13,9 @@ from pandera.checks 
diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py
index 689268a0a..17fc3b24e 100644
--- a/tests/core/test_extensions.py
+++ b/tests/core/test_extensions.py
@@ -13,13 +13,9 @@
 from pandera.checks import Check
 
 
-@pytest.fixture(scope="function")
-def custom_check_teardown():
-    """Remove all custom checks after execution of each pytest function."""
-    yield
-    for check_name in list(pa.Check.REGISTERED_CUSTOM_CHECKS):
-        delattr(pa.Check, check_name)
-        del pa.Check.REGISTERED_CUSTOM_CHECKS[check_name]
+def test_custom_checks_in_dir(extra_registered_checks):
+    """Ensures that autocomplete works with registered custom checks."""
+    assert "no_param_check" in dir(pa.Check)
 
 
 @pytest.mark.parametrize(
diff --git a/tests/core/test_schema_statistics.py b/tests/core/test_schema_statistics.py
index 640468eff..d82f85f9d 100644
--- a/tests/core/test_schema_statistics.py
+++ b/tests/core/test_schema_statistics.py
@@ -388,6 +388,7 @@ def test_get_dataframe_schema_statistics():
         ),
     )
     expectation = {
+        "checks": None,
        "columns": {
             "int": {
                 "pandas_dtype": pa.Int,
@@ -561,3 +562,19 @@ def test_parse_checks_and_statistics_roundtrip(checks, expectation):
     check_statistics = {check.name: check.statistics for check in checks}
     check_list = schema_statistics.parse_check_statistics(check_statistics)
     assert set(check_list) == set(checks)
+
+
+# pylint: disable=unused-argument
+def test_parse_checks_and_statistics_no_param(extra_registered_checks):
+    """Ensure that an edge case where a check does not have parameters is appropriately handled."""
+
+    checks = [pa.Check.no_param_check()]
+    expectation = {"no_param_check": {}}
+    assert schema_statistics.parse_checks(checks) == expectation
+
+    check_statistics = {check.name: check.statistics for check in checks}
+    check_list = schema_statistics.parse_check_statistics(check_statistics)
+    assert set(check_list) == set(checks)
+
+
+# pylint: enable=unused-argument
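
For reference, the parse_checks behavior those tests rely on can be seen directly. A hedged sketch, assuming the built-in greater_than check whose only statistic is min_value; the lambda check is unregistered, so it is skipped with the new warning.

import warnings

import pandera as pa
from pandera import schema_statistics

registered = pa.Check.greater_than(0)
unregistered = pa.Check(lambda s: s.sum() > 0)  # its name is not in the registry

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    stats = schema_statistics.parse_checks([registered, unregistered])

# only the registered check contributes statistics
assert stats == {"greater_than": {"min_value": 0}}
assert any("will be skipped" in str(w.message) for w in caught)
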
diff --git a/tests/io/test_io.py b/tests/io/test_io.py
index 848c9108b..d82542a01 100644
--- a/tests/io/test_io.py
+++ b/tests/io/test_io.py
@@ -2,6 +2,7 @@
 
 import platform
 import tempfile
+import unittest.mock as mock
 from pathlib import Path
 
 import pandas as pd
@@ -9,6 +10,8 @@
 from packaging import version
 
 import pandera as pa
+import pandera.extensions as pa_ext
+import pandera.typing as pat
 
 try:
     from pandera import io
@@ -99,7 +102,9 @@ def _create_schema(index="single"):
                 regex=True,
                 checks=[pa.Check.str_length(1, 3)],
             ),
-            "empty_column": pa.Column(),
+            "notype_column": pa.Column(
+                checks=pa.Check.isin(["foo", "bar", "x", "xy"]),
+            ),
         },
         index=index,
         coerce=False,
@@ -184,14 +189,20 @@ def _create_schema(index="single"):
       coerce: true
       required: false
       regex: true
-  empty_column:
+  notype_column:
     pandas_dtype: null
     nullable: false
-    checks: null
+    checks:
+      isin:
+      - foo
+      - bar
+      - x
+      - xy
     allow_duplicates: true
     coerce: false
     required: true
     regex: false
+checks: null
 index:
 - pandas_dtype: int
   nullable: false
@@ -253,6 +264,7 @@ def _create_schema_null_index():
       min_value: 1
      max_value: 3
 index: null
+checks: null
 coerce: false
 strict: false
 """
@@ -281,6 +293,51 @@ def _create_schema_python_types():
     pandas_dtype: str
   object_column:
     pandas_dtype: object
+checks: null
 index: null
 coerce: false
 strict: false
 """
+
+
+YAML_SCHEMA_MISSING_GLOBAL_CHECK = f"""
+schema_type: dataframe
+version: {pa.__version__}
+columns:
+  int_column:
+    pandas_dtype: int64
+  float_column:
+    pandas_dtype: float64
+  str_column:
+    pandas_dtype: str
+  object_column:
+    pandas_dtype: object
+checks:
+  unregistered_check:
+    stat1: missing_str_stat
+    stat2: 11
+index: null
+coerce: false
+strict: false
+"""
+
+
+YAML_SCHEMA_MISSING_COLUMN_CHECK = f"""
+schema_type: dataframe
+version: {pa.__version__}
+columns:
+  int_column:
+    pandas_dtype: int64
+    checks:
+      unregistered_check:
+        stat1: missing_str_stat
+        stat2: 11
+  float_column:
+    pandas_dtype: float64
+  str_column:
+    pandas_dtype: str
+  object_column:
+    pandas_dtype: object
+index: null
+coerce: false
+strict: false
+"""
@@ -292,7 +349,7 @@ def _create_schema_python_types():
     reason="pyyaml >= 5.1.0 required",
 )
 def test_inferred_schema_io():
-    """Test that inferred schema can be writted to yaml."""
+    """Test that inferred schema can be written to yaml."""
     df = pd.DataFrame(
         {
             "column1": [5, 10, 20],
@@ -340,6 +397,16 @@ def test_from_yaml(yaml_str, schema_creator):
     assert expected_schema == schema_from_yaml
 
 
+def test_from_yaml_unregistered_checks():
+    """Test that from_yaml raises an exception when deserializing unregistered checks."""
+
+    with pytest.raises(AttributeError, match=".*custom checks.*"):
+        io.from_yaml(YAML_SCHEMA_MISSING_COLUMN_CHECK)
+
+    with pytest.raises(AttributeError, match=".*custom checks.*"):
+        io.from_yaml(YAML_SCHEMA_MISSING_GLOBAL_CHECK)
+
+
 def test_io_yaml_file_obj():
     """Test read and write operation on file object."""
     schema = _create_schema()
@@ -407,7 +474,7 @@ def test_to_script(index):
 
 def test_to_script_lambda_check():
     """Test writing DataFrameSchema to a script with lambda check."""
-    schema = pa.DataFrameSchema(
+    schema1 = pa.DataFrameSchema(
         {
             "a": pa.Column(
                 pa.Int,
@@ -417,7 +484,19 @@ def test_to_script_lambda_check():
     )
 
     with pytest.warns(UserWarning):
-        pa.io.to_script(schema)
+        pa.io.to_script(schema1)
+
+    schema2 = pa.DataFrameSchema(
+        {
+            "a": pa.Column(
+                pa.Int,
+            ),
+        },
+        checks=pa.Check(lambda s: s.mean() > 5, element_wise=False),
+    )
+
+    with pytest.warns(UserWarning, match=".*registered checks.*"):
+        pa.io.to_script(schema2)
 
 
 def test_to_yaml_lambda_check():
@@ -433,3 +512,82 @@ def test_to_yaml_lambda_check():
 
     with pytest.warns(UserWarning):
         pa.io.to_yaml(schema)
+
+
+@mock.patch("pandera.Check.REGISTERED_CUSTOM_CHECKS", new_callable=dict)
+def test_to_yaml_registered_dataframe_check(_):
+    """Tests that writing DataFrameSchema with a registered dataframe check works."""
+    ncols_gt_called = False
+
+    @pa_ext.register_check_method(statistics=["column_count"])
+    def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool:
+        """test registered dataframe check"""
+
+        # pylint: disable=unused-variable
+        nonlocal ncols_gt_called
+        ncols_gt_called = True
+        assert isinstance(column_count, int), "column_count must be integral"
+        assert isinstance(
+            pandas_obj, pd.DataFrame
+        ), "ncols_gt should only be applied to DataFrame"
+        return len(pandas_obj.columns) > column_count
+
+    assert (
+        len(pa.Check.REGISTERED_CUSTOM_CHECKS) == 1
+    ), "custom check is registered"
+
+    schema = pa.DataFrameSchema(
+        {
+            "a": pa.Column(
+                pa.Int,
+            ),
+        },
+        checks=[pa.Check.ncols_gt(column_count=5)],
+    )
+
+    serialized = pa.io.to_yaml(schema)
+    loaded = pa.io.from_yaml(serialized)
+
+    assert len(loaded.checks) == 1, "global check was stripped"
+
+    with pytest.raises(pa.errors.SchemaError):
+        schema.validate(pd.DataFrame(data={"a": [1]}))
+
+    assert ncols_gt_called, "did not call ncols_gt"
+
+
+def test_to_yaml_custom_dataframe_check():
+    """Tests that writing DataFrameSchema with an unregistered dataframe check emits a warning."""
+
+    schema = pa.DataFrameSchema(
+        {
+            "a": pa.Column(
+                pa.Int,
+            ),
+        },
+        checks=[pa.Check(lambda obj: len(obj.index) > 1)],
+    )
+
+    with pytest.warns(UserWarning, match=".*registered checks.*"):
+        pa.io.to_yaml(schema)
+
+    # the unregistered column check case is tested in `test_to_yaml_lambda_check`
+
+
+def test_to_yaml_bugfix_419():
+    """Ensure that GH#419 is fixed"""
+    # pylint: disable=no-self-use
+
+    class CheckedSchemaModel(pa.SchemaModel):
+        """Schema with a global check"""
+
+        a: pat.Series[pat.Int64]
+        b: pat.Series[pat.Int64]
+
+        @pa.dataframe_check()
+        def unregistered_check(self, _):
+            """sample unregistered check"""
+            ...
+
+    with pytest.warns(UserWarning, match=".*registered checks.*"):
+        CheckedSchemaModel.to_yaml()
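
Taken together with the schemas.py change, the serialization entry points now share one signature. A rough usage sketch, assuming to_yaml keeps accepting a path-like stream the way the old fp argument did.

import tempfile
from pathlib import Path

import pandera as pa

schema = pa.DataFrameSchema({"a": pa.Column(pa.Int)})

# stream=None dumps the schema to a YAML string
yaml_str = schema.to_yaml()
assert "schema_type: dataframe" in yaml_str

# passing a path-like stream writes the file instead of returning a string
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "schema.yml"
    schema.to_yaml(target)
    loaded = pa.DataFrameSchema.from_yaml(target)
    assert "a" in loaded.columns
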