From 47c7f86e7992ec508bde0dd2379243e78881084e Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Thu, 15 Jul 2021 18:43:00 -0400 Subject: [PATCH] decouple pandera and pandas dtypes (#559) * refactor PandasDtype into class hierarchy supported by engines * refactor DataFrameSchema based on DataType hierarchy * refactor SchemaModel based on DataType hierarchy * revert fix coerce=True and dtype=None should be a noop * apply code style * fix running tests/core with nox * consolidate dtype names * consolidate engine internal naming * disable inherited __init__ with immutable(init=False) * delete duplicated immutable * disambiguate dtype variables * add warning on base pandas_engine, numpy_engine.DataType init * fix pylint, mypy errors * fix DataFrameSchema.dtypes return type * enable CI on dtypes branch * Refactor inference, schema_statistics, strategies and io using the DataType hierarchy (#504) * fix pandas_engine.Interval * fix Timedelta64 registration with pandas_engine.Engine * add DataType helpers * add DataType.continuous attribute * add dtypes.is_numeric * refactor schema_statistics based on DataType hierarchy * refactor schema_inference based on DataType hierarchy * fix numpy_engine.Timedelta64.type * add is_subdtype helper * add Engine.get_registered_dtypes * fix Engine error when registering a base DataType * fix pandas_engine DateTime string alias * clean up test_dtypes * fix test_extensions * refactor strategies based on DataType hierarchy * refactor io based on DataType hierarchy * replace dtypes module by new DataType hierarchy * fix black * delete dtypes_.py * drop legacy pandas and python 3.6 from CI * fix mypy errors * fix ci-docs * fix conda dependencies * fix lint, update noxfile * simplify nox tests, fix test_io * update ci build * update nox * pin nox, handle windows data types * fix windows platform * fix pandas_engine on windows platform * fix test_dtypes on windows platform * force pip on docs CI * test out windows dtype stuff * more messing around with windows * more debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * revert ci * increase cache * testing Co-authored-by: cosmicBboy * Add DataTypes documentation (#536) * delete print statements * pin furo * fix generated docs not removed by nox * re-organize API section * replace aliased pandas_engine data types with their aliases * drop warning when calling Engine.register_dtype without arguments * add data types to api reference doc * add document for DataType refactor * unpin sphinx and drop sphinx_rtd_theme * add xdoctest * ignore prompt when copying example from doc * add doctest builder when running sphinx-build locally * fix dtypes doc examples * fix pandas_engine.DataType.check * fix pylint * remove whitespaces in dtypes doc * Update docs/source/dtypes.rst * Update dtypes.rst * update docs structure * update nox file * force pip on doctests * update test_schemas * fix docs session not overriding html with doctest output Co-authored-by: Niels Bantilan * add deprecation warnings for pandas_dtype and PandasDtype enum (#547) * remove auto-generated docs * add deprecation warnings, support pandas>=1.3.0 * add deprecation warnings for PandasDtype enum * fix sphinx * fix windows * fix windows * add support for pyarrow backed string data type (#548) * add support for pyarrow backed string data type * fix regression for pandas < 1.3.0 * add verbosity to test run * loosen strategies unit tests deadline, exclude windows ci * loosen 
test_strategies.py tests * use "dev" hypothesis profile for python 3.7 * add pandas==1.2.5 test * fix ci * ci typo * don't install environment.yml on unit tests * install nox in ci * remove environment.yml * update environment in ci Co-authored-by: cosmicBboy Co-authored-by: Jean-Francois Zinque --- .github/workflows/ci-tests.yml | 51 +- .gitignore | 2 +- Makefile | 2 +- docs/source/API_reference.rst | 167 --- .../_templates/{enum_class.rst => dtype.rst} | 28 +- docs/source/_templates/pandas_dtype_class.rst | 25 - docs/source/conf.py | 7 +- docs/source/data_synthesis_strategies.rst | 4 +- docs/source/dataframe_schemas.rst | 81 +- docs/source/dtypes.rst | 188 +++ docs/source/extensions.rst | 14 +- docs/source/index.rst | 9 +- docs/source/lazy_validation.rst | 14 +- docs/source/reference/core.rst | 35 + docs/source/reference/decorators.rst | 13 + docs/source/reference/dtypes.rst | 109 ++ docs/source/reference/errors.rst | 14 + docs/source/reference/extensions.rst | 11 + docs/source/reference/index.rst | 41 + docs/source/reference/io.rst | 16 + docs/source/reference/schema_inference.rst | 10 + docs/source/reference/schema_models.rst | 45 + docs/source/reference/strategies.rst | 11 + docs/source/schema_inference.rst | 35 +- docs/source/schema_models.rst | 49 +- environment.yml | 12 +- noxfile.py | 133 +-- pandera/__init__.py | 76 +- pandera/checks.py | 12 +- pandera/deprecations.py | 38 + pandera/dtypes.py | 914 ++++++++------- pandera/engines/__init__.py | 0 pandera/engines/engine.py | 219 ++++ pandera/engines/numpy_engine.py | 361 ++++++ pandera/engines/pandas_engine.py | 704 +++++++++++ pandera/io.py | 64 +- pandera/model.py | 4 +- pandera/schema_components.py | 42 +- pandera/schema_inference.py | 7 +- pandera/schema_statistics.py | 54 +- pandera/schemas.py | 350 +++--- pandera/strategies.py | 219 ++-- pandera/typing.py | 143 +-- requirements-dev.txt | 12 +- setup.py | 7 +- tests/conftest.py | 4 +- tests/core/test_decorators.py | 15 +- tests/core/test_deprecations.py | 53 + tests/core/test_dtypes.py | 1028 +++++++---------- tests/core/test_engine.py | 195 ++++ tests/core/test_extensions.py | 6 +- tests/core/test_model.py | 4 +- tests/core/test_model_components.py | 5 +- tests/core/test_schema_components.py | 67 +- tests/core/test_schema_statistics.py | 178 +-- tests/core/test_schemas.py | 468 +++----- tests/core/test_typing.py | 108 +- tests/io/test_io.py | 263 +++-- tests/strategies/test_strategies.py | 318 +++-- 59 files changed, 4326 insertions(+), 2738 deletions(-) delete mode 100644 docs/source/API_reference.rst rename docs/source/_templates/{enum_class.rst => dtype.rst} (57%) delete mode 100644 docs/source/_templates/pandas_dtype_class.rst create mode 100644 docs/source/dtypes.rst create mode 100644 docs/source/reference/core.rst create mode 100644 docs/source/reference/decorators.rst create mode 100644 docs/source/reference/dtypes.rst create mode 100644 docs/source/reference/errors.rst create mode 100644 docs/source/reference/extensions.rst create mode 100644 docs/source/reference/index.rst create mode 100644 docs/source/reference/io.rst create mode 100644 docs/source/reference/schema_inference.rst create mode 100644 docs/source/reference/schema_models.rst create mode 100644 docs/source/reference/strategies.rst create mode 100644 pandera/deprecations.py create mode 100644 pandera/engines/__init__.py create mode 100644 pandera/engines/engine.py create mode 100644 pandera/engines/numpy_engine.py create mode 100644 pandera/engines/pandas_engine.py create mode 100644 
tests/core/test_deprecations.py create mode 100644 tests/core/test_engine.py diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 196f68cd4..81fbf88a9 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -2,16 +2,18 @@ name: CI Tests on: push: branches: - - master - - dev - - bugfix - - 'release/*' + - master + - dev + - bugfix + - "release/*" + - dtypes pull_request: branches: - - master - - dev - - bugfix - - 'release/*' + - master + - dev + - bugfix + - "release/*" + - dtypes env: DEFAULT_PYTHON: 3.8 @@ -71,7 +73,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9"] defaults: run: shell: bash -l {0} @@ -133,16 +135,21 @@ jobs: tests: name: > - CI Tests (${{ matrix.python-version }}, - ${{ matrix.os }}, - pandas-${{ matrix.pandas-version }}) + CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.6", "3.7", "3.8", "3.9"] - pandas-version: ["latest", "0.25.3"] + python-version: ["3.7", "3.8", "3.9"] + pandas-version: ["1.2.5", "latest"] + # exclude these configurations until issue tracked here is fixed: + # https://github.com/pandera-dev/pandera/issues/555 + exclude: + - os: windows-latest + python-version: "3.7" + - os: windows-latest + python-version: "3.9" defaults: run: @@ -174,6 +181,13 @@ jobs: use-only-tar-bz2: true auto-activate-base: false + - name: Install pandas + run: | + if [[ "${{ matrix.pandas-version }}" != 'latest' ]] + then + mamba install -c conda-forge pandas==${{ matrix.pandas-version }} + fi + - name: Conda info run: | conda info @@ -210,9 +224,16 @@ jobs: - name: Upload coverage to Codecov uses: "codecov/codecov-action@v1" + - name: Check Docstrings + run: > + nox + -db conda -r -v + --non-interactive + --session "doctests-${{ matrix.python-version }}" + - name: Check Docs run: > nox -db conda -r -v --non-interactive - --session "docs-${{ matrix.python-version }}(pandas='${{ matrix.pandas-version }}')" + --session "docs-${{ matrix.python-version }}" diff --git a/.gitignore b/.gitignore index 0d8362231..39e188cf5 100644 --- a/.gitignore +++ b/.gitignore @@ -113,7 +113,7 @@ venv.bak/ /asv_bench/results/ # Docs -docs/source/generated +docs/source/reference/generated # Nox .nox diff --git a/Makefile b/Makefile index e035de64d..23d7e5d05 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ requirements: pip install -r requirements-dev.txt docs: - rm -rf docs/source/generated && \ + rm -rf docs/**/generated docs/**/methods docs/_build && \ python -m sphinx -E "docs/source" "docs/_build" -W && \ make -C docs doctest diff --git a/docs/source/API_reference.rst b/docs/source/API_reference.rst deleted file mode 100644 index 1234b3e11..000000000 --- a/docs/source/API_reference.rst +++ /dev/null @@ -1,167 +0,0 @@ -.. pandera package index documentation toctree - -.. currentmodule:: pandera - -API -=== - -The ``io`` module and built-in ``Hypothesis`` checks require a pandera -installation with the corresponding extension, see the -:ref:`installation` instructions for more details. - -Schemas -------- - -.. autosummary:: - :toctree: generated - :template: class.rst - :nosignatures: - - pandera.schemas.DataFrameSchema - pandera.schemas.SeriesSchema - - -Schema Components ------------------ - -.. 
autosummary:: - :toctree: generated - :template: class.rst - :nosignatures: - - pandera.schema_components.Column - pandera.schema_components.Index - pandera.schema_components.MultiIndex - - -Schema Models -------------- - -.. autosummary:: - :toctree: generated - :template: class.rst - :nosignatures: - - pandera.model.SchemaModel - -**Model Components** - -.. autosummary:: - :toctree: generated - :nosignatures: - - pandera.model_components.Field - pandera.model_components.check - pandera.model_components.dataframe_check - -**Typing** - -.. autosummary:: - :toctree: generated - :template: typing_module.rst - :nosignatures: - - pandera.typing - -**Config** - -.. autosummary:: - :toctree: generated - :template: model_component_class.rst - :nosignatures: - - pandera.model.BaseConfig - - -Checks ------- - -.. autosummary:: - :toctree: generated - :template: class.rst - :nosignatures: - - pandera.checks.Check - pandera.hypotheses.Hypothesis - - -Pandas Data Types ------------------ - -.. autosummary:: - :toctree: generated - :template: pandas_dtype_class.rst - :nosignatures: - - pandera.dtypes.PandasDtype - - -Decorators ----------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - pandera.decorators.check_input - pandera.decorators.check_output - pandera.decorators.check_io - pandera.decorators.check_types - - -Schema Inference ----------------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - pandera.schema_inference.infer_schema - - -IO Utils --------- - -.. autosummary:: - :toctree: generated - :nosignatures: - - pandera.io.from_yaml - pandera.io.to_yaml - pandera.io.to_script - - -Data Synthesis Strategies -------------------------- - -.. autosummary:: - :toctree: generated - :template: strategies_module.rst - :nosignatures: - - pandera.strategies - - -Extensions ----------- - -.. autosummary:: - :toctree: generated - :template: module.rst - :nosignatures: - - pandera.extensions - - -Errors ------- - -.. autosummary:: - :toctree: generated - :template: class.rst - :nosignatures: - - pandera.errors.SchemaError - pandera.errors.SchemaErrors - pandera.errors.SchemaInitError - pandera.errors.SchemaDefinitionError diff --git a/docs/source/_templates/enum_class.rst b/docs/source/_templates/dtype.rst similarity index 57% rename from docs/source/_templates/enum_class.rst rename to docs/source/_templates/dtype.rst index c10df62d9..7625a0dfe 100644 --- a/docs/source/_templates/enum_class.rst +++ b/docs/source/_templates/dtype.rst @@ -2,17 +2,6 @@ .. currentmodule:: {{ module }} -.. autoclass:: PandasDtype - :show-inheritance: - :exclude-members: - - .. autoattribute:: str_alias - .. automethod:: from_str_alias - .. automethod:: from_pandas_api_type - - - - .. 
autoclass:: {{ objname }} {% block attributes %} @@ -37,15 +26,16 @@ :nosignatures: :toctree: methods - {% for item in methods %} - {%- if item not in inherited_members %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} + {# Ignore the DateTime alias to avoid `WARNING: document isn't included in any toctree`#} + {% if objname != "DateTime" %} + {% for item in methods %} + ~{{ name }}.{{ item }} + {%- endfor %} - {%- if '__call__' in members %} - ~{{ name }}.__call__ + {%- if members and '__call__' in members %} + ~{{ name }}.__call__ + {%- endif %} {%- endif %} + {%- endif %} {% endblock %} diff --git a/docs/source/_templates/pandas_dtype_class.rst b/docs/source/_templates/pandas_dtype_class.rst deleted file mode 100644 index 3bc84a901..000000000 --- a/docs/source/_templates/pandas_dtype_class.rst +++ /dev/null @@ -1,25 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - :show-inheritance: - :exclude-members: - - {% block attributes %} - {% if attributes %} - .. rubric:: Attributes - - .. autosummary:: - :nosignatures: - - {% for item in attributes %} - ~{{ name }}.{{ item }} - {%- endfor %} - - {% endif %} - {% endblock %} - - .. autoattribute:: str_alias - .. automethod:: from_str_alias - .. automethod:: from_pandas_api_type diff --git a/docs/source/conf.py b/docs/source/conf.py index f82b1237e..32feb47d7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -162,7 +162,7 @@ .. role:: green """ -autosummary_generate = ["API_reference.rst"] +autosummary_generate = True autosummary_filename_map = { "pandera.Check": "pandera.Check", "pandera.check": "pandera.check_decorator", @@ -174,6 +174,11 @@ "pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None), } +# strip prompts +copybutton_prompt_text = ( + r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +) +copybutton_prompt_is_regexp = True # this is a workaround to filter out forward reference issue in # sphinx_autodoc_typehints diff --git a/docs/source/data_synthesis_strategies.rst b/docs/source/data_synthesis_strategies.rst index d49e4a115..3bbd410e0 100644 --- a/docs/source/data_synthesis_strategies.rst +++ b/docs/source/data_synthesis_strategies.rst @@ -4,8 +4,8 @@ .. _data synthesis strategies: -Data Synthesis Strategies (new) -=============================== +Data Synthesis Strategies +========================= *new in 0.6.0* diff --git a/docs/source/dataframe_schemas.rst b/docs/source/dataframe_schemas.rst index c6a9f8f4d..99431f4e0 100644 --- a/docs/source/dataframe_schemas.rst +++ b/docs/source/dataframe_schemas.rst @@ -10,7 +10,7 @@ DataFrame Schemas The :class:`~pandera.schemas.DataFrameSchema` class enables the specification of a schema that verifies the columns and index of a pandas ``DataFrame`` object. -The ``DataFrameSchema`` object consists of |column|_\s and an |index|_. +The :class:`~pandera.schemas.DataFrameSchema` object consists of |column|_\s and an |index|_. .. |column| replace:: ``Column`` .. |index| replace:: ``Index`` @@ -44,12 +44,25 @@ The ``DataFrameSchema`` object consists of |column|_\s and an |index|_. Column Validation ----------------- -A :class:`~pandera.schema_components.Column` must specify the properties of a column in a dataframe -object. It can be optionally verified for its data type, `null values`_ or +A :class:`~pandera.schema_components.Column` must specify the properties of a +column in a dataframe object. It can be optionally verified for its data type, +`null values`_ or duplicate values. 
The column can be coerced_ into the specified type, and the required_ parameter allows control over whether or not the column is allowed to be missing. +As with pandas, the data type can be specified as (see the sketch after this list): + +* a string alias, as long as it is recognized by pandas. +* a python type: `int`, `float`, `bool`, `str` +* a `numpy data type <https://numpy.org/doc/stable/user/basics.types.html>`_ +* a `pandas extension type <https://pandas.pydata.org/pandas-docs/stable/user_guide/basics.html#dtypes>`_: + it can be an instance (e.g. `pd.CategoricalDtype(["a", "b"])`) or a + class (e.g. `pandas.CategoricalDtype`) if it can be initialized with default + values. +* a pandera :class:`~pandera.dtypes.DataType`: it can also be an instance or a + class. +
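For instance, all of the following are valid ways to declare an integer column (a minimal sketch; note that ``pd.Int64Dtype()`` is the *nullable* extension type, which is distinct from numpy's ``int64``):

.. code-block:: python

    import numpy as np
    import pandas as pd
    import pandera as pa

    pa.Column("int64")          # pandas string alias
    pa.Column(int)              # python builtin type
    pa.Column(np.int64)         # numpy data type
    pa.Column(pd.Int64Dtype())  # pandas extension type instance
    pa.Column(pa.Int64)         # pandera DataType class
    pa.Column(pa.Int64())       # pandera DataType instance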
:ref:`Column checks` allow for the DataFrame's values to be checked against a user-provided function. ``Check`` objects also support :ref:`grouping` by a different column so that the user can make @@ -80,7 +93,7 @@ nullable. In order to accept null values, you need to explicitly specify df = pd.DataFrame({"column1": [5, 1, np.nan]}) non_null_schema = DataFrameSchema({ - "column1": Column(pa.Int, Check(lambda x: x > 0)) + "column1": Column(pa.Float, Check(lambda x: x > 0)) }) non_null_schema.validate(df) ... SchemaError: non-nullable series contains null values: {2: nan} -.. note:: Due to a known limitation in - `pandas prior to version 0.24.0 `_, - integer arrays cannot contain ``NaN`` values, so this schema will return - a DataFrame where ``column1`` is of type ``float``. - :class:`~pandera.dtypes.PandasDtype` does not currently support the nullable integer - array type, but you can still use the "Int64" string alias for nullable - integer arrays .. testcode:: null_values_in_columns null_schema = DataFrameSchema({ - "column1": Column(pa.Int, Check(lambda x: x > 0), nullable=True) + "column1": Column(pa.Float, Check(lambda x: x > 0), nullable=True) }) print(null_schema.validate(df)) @@ -277,7 +283,7 @@ objects can also be used to validate columns in a dataframe on their own: validated_df = df.pipe(column1_schema).pipe(column2_schema) -For multi-column use cases, the ``DataFrameSchema`` is still recommended, but +For multi-column use cases, the :class:`~pandera.schemas.DataFrameSchema` is still recommended, but if you have one or a small number of columns to verify, using ``Column`` objects by themselves is appropriate. @@ -309,12 +315,12 @@ a set of meaningfully grouped columns that have ``str`` names. }) schema = pa.DataFrameSchema({ - "num_var_*": pa.Column( + "num_var_.+": pa.Column( pa.Float, checks=pa.Check.greater_than_or_equal_to(0), regex=True, ), - "cat_var_*": pa.Column( + "cat_var_.+": pa.Column( pa.Category, checks=pa.Check.isin(categories), coerce=True, @@ -347,12 +353,12 @@ You can also regex pattern match on ``pd.MultiIndex`` columns: }) schema = pa.DataFrameSchema({ - ("num_var_*", "x*"): pa.Column( + ("num_var_.+", "x.+"): pa.Column( pa.Float, checks=pa.Check.greater_than_or_equal_to(0), regex=True, ), - ("cat_var_*", "y*"): pa.Column( + ("cat_var_.+", "y.+"): pa.Column( pa.Category, checks=pa.Check.isin(categories), coerce=True, @@ -401,7 +407,7 @@ schema, specify ``strict=True``: Traceback (most recent call last): ... - SchemaError: column 'column2' not in DataFrameSchema {'column1': } + SchemaError: column 'column2' not in DataFrameSchema {'column1': } Alternatively, if your DataFrame contains columns that are not in the schema, and you would like these to be dropped on validation, @@ -601,12 +607,13 @@ indexes by composing a list of ``pandera.Index`` objects. foo 2 3 -Get Pandas Datatypes --------------------- +Get Pandas Data Types +--------------------- Pandas provides a `dtype` parameter for casting a dataframe to a specific dtype -schema. ``DataFrameSchema`` provides a `dtype` property which returns a pandas -style dict. The keys of the dict are column names and values are the dtype. +schema. :class:`~pandera.schemas.DataFrameSchema` provides +a :attr:`~pandera.schemas.DataFrameSchema.dtypes` property which returns a +dictionary whose keys are column names and values are :class:`~pandera.dtypes.DataType`. Some examples of where this can be provided to pandas are: @@ -626,13 +633,17 @@ Some examples of where this can be provided to pandas are: }, ) - df = pd.DataFrame.from_dict( - { - "a": {"column1": 1, "column2": "valueA", "column3": True}, - "b": {"column1": 1, "column2": "valueB", "column3": True}, - }, - orient="index" - ).astype(schema.dtype).sort_index(axis=1) + df = ( + pd.DataFrame.from_dict( + { + "a": {"column1": 1, "column2": "valueA", "column3": True}, + "b": {"column1": 1, "column2": "valueB", "column3": True}, + }, + orient="index", + ) + .astype({col: str(dtype) for col, dtype in schema.dtypes.items()}) + .sort_index(axis=1) + ) print(schema.validate(df)) @@ -718,11 +729,11 @@ data pipeline: + 'col1': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=True name=None, @@ -756,15 +767,15 @@ the pipeline output. + 'column2': }, checks=[], coerce=True, - pandas_dtype=None, + dtype=None, index= - + + ] coerce=False, strict=False, diff --git a/docs/source/dtypes.rst b/docs/source/dtypes.rst new file mode 100644 index 000000000..b687d0d19 --- /dev/null +++ b/docs/source/dtypes.rst @@ -0,0 +1,188 @@ +.. pandera documentation for data types + +.. currentmodule:: pandera + +.. _dtypes: + +Pandera Data Types (new) +======================== + +*new in 0.7.0* + +Motivations +~~~~~~~~~~~ + +Pandera defines its own interface for data types in order to abstract the +specifics of dataframe-like data structures in the python ecosystem, such +as Apache Spark, Apache Arrow and xarray. + +.. note:: In the following section ``Pandera Data Type`` refers to a + :class:`pandera.dtypes.DataType` object whereas ``native data type`` refers + to data types used by third-party libraries that Pandera supports (e.g. pandas). + +Most of the time, this is transparent to end users, since pandera columns and +indexes accept native data types. However, it is possible to extend the pandera +interface by: + +* modifying the **data type check** performed during schema validation. +* modifying the behavior of the **coerce** argument for :class:`~pandera.schemas.DataFrameSchema`. +* adding your **own custom data types**. + +DataType basics +~~~~~~~~~~~~~~~ + +All pandera data types inherit from :class:`pandera.dtypes.DataType` and must +be hashable. + +A data type implements three key methods: + +* :meth:`pandera.dtypes.DataType.check` which validates that data types are equivalent. +* :meth:`pandera.dtypes.DataType.coerce` which coerces a data container + (e.g. :class:`pandas.Series`) to the data type. +* The dunder method ``__str__()`` which should output the native alias. + For example ``str(pandera.Float64) == "float64"`` + + +For pandera's validation methods to be aware of a data type, it has to be +registered with the targeted engine via :meth:`pandera.engines.engine.Engine.register_dtype`. +An engine is in charge of mapping a pandera :class:`~pandera.dtypes.DataType` +to a native data type counterpart belonging to a third-party library. The mapping +can be queried with :meth:`pandera.engines.engine.Engine.dtype`. + +As of pandera ``0.7.0``, only the pandas :class:`~pandera.engines.pandas_engine.Engine` +is supported. + +
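As a minimal sketch of this lookup (assuming the ``pandas_engine`` module layout introduced in this patch), querying the engine with a native alias returns the registered pandera data type:

.. code-block:: python

    import pandera as pa
    from pandera.engines import pandas_engine

    # resolve a native pandas alias to its pandera DataType counterpart
    dtype = pandas_engine.Engine.dtype("int64")
    assert isinstance(dtype, pa.DataType)
    print(dtype)  # int64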
+Example +~~~~~~~ + +Let's extend :class:`pandas.BooleanDtype` coercion to handle the string +literals ``"True"`` and ``"False"``. + +.. testcode:: dtypes + + import pandas as pd + import pandera as pa + from pandera import dtypes + from pandera.engines import pandas_engine + + + @pandas_engine.Engine.register_dtype # step 1 + @dtypes.immutable # step 2 + class LiteralBool(pandas_engine.BOOL): # step 3 + def coerce(self, series: pd.Series) -> pd.Series: + """Coerce a pandas.Series to a boolean dtype.""" + if pd.api.types.is_string_dtype(series): + series = series.replace({"True": 1, "False": 0}) + return series.astype("boolean") + + + data = pd.Series(["True", "False"], name="literal_bools") + + # step 4 + print( + pa.SeriesSchema(LiteralBool(), coerce=True, name="literal_bools") + .validate(data) + .dtype + ) + +.. testoutput:: dtypes + + boolean + +The example above performs the following steps: + +1. Register the data type with the pandas engine. +2. :func:`pandera.dtypes.immutable` creates an immutable (and hashable) + :func:`dataclass`. +3. Inherit :class:`pandera.engines.pandas_engine.BOOL`, which is the pandera + representation of :class:`pandas.BooleanDtype`. This is not mandatory, but + it makes our life easier because all the required methods are already + implemented. +4. Check that our new data type can coerce the string literals. + +So far, we have not overridden the default behavior: + +.. testcode:: dtypes + + import pandera as pa + + pa.SeriesSchema("boolean", coerce=True).validate(data) + + +.. testoutput:: dtypes + + Traceback (most recent call last): + ... + pandera.errors.SchemaError: Error while coercing 'literal_bools' to type boolean: Need to pass bool-like values + +To completely replace the default :class:`~pandera.engines.pandas_engine.BOOL`, +we need to supply all the equivalent representations to +:meth:`~pandera.engines.engine.Engine.register_dtype`. Behind the scenes, when +``pa.SeriesSchema("boolean")`` is called, the corresponding pandera data type +is looked up using :meth:`pandera.engines.engine.Engine.dtype`. + +.. testcode:: dtypes + + print(f"before: {pandas_engine.Engine.dtype('boolean').__class__}") + + + @pandas_engine.Engine.register_dtype( + equivalents=["boolean", pd.BooleanDtype, pd.BooleanDtype()], + ) + @dtypes.immutable + class LiteralBool(pandas_engine.BOOL): + def coerce(self, series: pd.Series) -> pd.Series: + """Coerce a pandas.Series to a boolean dtype.""" + if pd.api.types.is_string_dtype(series): + series = series.replace({"True": 1, "False": 0}) + return series.astype("boolean") + + + print(f"after: {pandas_engine.Engine.dtype('boolean').__class__}") + + for dtype in ["boolean", pd.BooleanDtype, pd.BooleanDtype()]: + pa.SeriesSchema(dtype, coerce=True).validate(data) + +.. testoutput:: dtypes + + before: + after: + +.. note:: For convenience, we specified both ``pd.BooleanDtype`` and + ``pd.BooleanDtype()`` as equivalents. That gives us more flexibility in + what pandera schemas can recognize (see last for-loop above).
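The changelog above also mentions an ``Engine.get_registered_dtypes`` helper; assuming it returns the registered :class:`~pandera.dtypes.DataType` classes, a quick sanity check of the registration would look like:

.. code-block:: python

    from pandera.engines import pandas_engine

    # LiteralBool should now appear among the pandas engine's registered types
    assert LiteralBool in pandas_engine.Engine.get_registered_dtypes()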
+Parametrized data types +~~~~~~~~~~~~~~~~~~~~~~~ + +Some data types can be parametrized. One common example is +:class:`pandas.CategoricalDtype`. + +The ``equivalents`` argument of +:meth:`~pandera.engines.engine.Engine.register_dtype` does not handle +this situation, but it will automatically register a :func:`classmethod` with +signature ``from_parametrized_dtype(cls, equivalent:...)`` if the decorated +:class:`~pandera.dtypes.DataType` defines it. The ``equivalent`` argument must +be type-annotated because it is leveraged to dispatch the input of +:meth:`~pandera.engines.engine.Engine.dtype` to the appropriate +``from_parametrized_dtype`` class method. + +For example, here is a snippet from :class:`pandera.engines.pandas_engine.Category`: + +.. code-block:: python + + import pandas as pd + from pandera import dtypes + + @classmethod + def from_parametrized_dtype( + cls, cat: Union[dtypes.Category, pd.CategoricalDtype] + ): + """Convert a categorical to + a Pandera :class:`pandera.engines.pandas_engine.Category`.""" + return cls(categories=cat.categories, ordered=cat.ordered) # type: ignore + + +.. note:: The dispatch mechanism relies on :func:`functools.singledispatch`. + Unlike the built-in implementation, :data:`typing.Union` is recognized. diff --git a/docs/source/extensions.rst b/docs/source/extensions.rst index c0c3e4173..dd9be7344 100644 --- a/docs/source/extensions.rst +++ b/docs/source/extensions.rst @@ -4,8 +4,8 @@ .. _extensions: -Extensions (new) -================ +Extensions +========== *new in 0.6.0* @@ -94,20 +94,20 @@ The corresponding strategy for this check would be: import pandera.strategies as st def equals_strategy( - pandas_dtype: pa.PandasDtype, + pandera_dtype: pa.DataType, strategy: Optional[st.SearchStrategy] = None, *, value, ): if strategy is None: return st.pandas_dtype_strategy( - pandas_dtype, strategy=hypothesis.strategies.just(value), + pandera_dtype, strategy=hypothesis.strategies.just(value), ) return strategy.filter(lambda x: x == value) As you may notice, the ``pandera`` strategy interface has two positional arguments followed by keyword-only arguments that match the check function's keyword-only -check statistics. The ``pandas_dtype`` positional argument is useful for +check statistics. The ``pandera_dtype`` positional argument is useful for ensuring the correct data type. In the above example, we're using the :func:`~pandera.strategies.pandas_dtype_strategy` strategy to make sure the generated ``value`` is of the correct data type. @@ -147,7 +147,7 @@ would look like: :skipif: SKIP_STRATEGY def in_between_strategy( - pandas_dtype: pa.PandasDtype, + pandera_dtype: pa.DataType, strategy: Optional[st.SearchStrategy] = None, *, min_value, max_value, ): if strategy is None: return st.pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, max_value=max_value, exclude_min=False, diff --git a/docs/source/index.rst b/docs/source/index.rst index 9ea8ae79b..c49cf791d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -155,7 +155,7 @@ Quick Start You can pass the built-in python types that are supported by pandas, or strings representing the `legal pandas datatypes `_, -or pandera's ``PandasDtype`` enum: +or pandera's ``DataType``: .. 
testcode:: quick_start @@ -171,13 +171,13 @@ or pandera's ``PandasDtype`` enum: # pandas > 1.0.0 support native "string" type "str_column2": pa.Column("str"), - # pandera PandasDtype enum + # pandera DataType "int_column3": pa.Column(pa.Int), "float_column3": pa.Column(pa.Float), "str_column3": pa.Column(pa.String), }) -For more details on data types, see :class:`~pandera.dtypes.PandasDtype` +For more details on data types, see :class:`~pandera.dtypes.DataType` Schema Model @@ -306,6 +306,7 @@ Submit issues, feature requests or bugfixes on series_schemas checks hypothesis + dtypes decorators schema_inference schema_models @@ -318,7 +319,7 @@ Submit issues, feature requests or bugfixes on :caption: Reference :hidden: - API_reference + reference/index .. toctree:: :maxdepth: 6 diff --git a/docs/source/lazy_validation.rst b/docs/source/lazy_validation.rst index 0ed7177da..76f735b2d 100644 --- a/docs/source/lazy_validation.rst +++ b/docs/source/lazy_validation.rst @@ -15,9 +15,9 @@ exception: * a column specified in the schema is not present in the dataframe. * if ``strict=True``, a column in the dataframe is not specified in the schema. -* the ``pandas_dtype`` does not match. +* the ``data type`` does not match. * if ``coerce=True``, the dataframe column cannot be coerced into the specified - ``pandas_dtype``. + ``data type``. * the :class:`~pandera.checks.Check` specified in one of the columns returns ``False`` or a boolean series containing at least one ``False`` value. @@ -94,8 +94,8 @@ of all schemas and schema components gives you the option of doing just this: schema_context column check DataFrameSchema column_in_dataframe [date_column] 1 column_in_schema [unknown_column] 1 - Column float_column pandas_dtype('float64') [int64] 1 - int_column pandas_dtype('int64') [object] 1 + Column float_column dtype('float64') [int64] 1 + int_column dtype('int64') [object] 1 str_column equal_to(a) [b, d] 2 Usage Tip @@ -135,9 +135,9 @@ catch these errors and inspect the failure cases in a more granular form: schema_context column check check_number \ 0 DataFrameSchema None column_in_schema None 1 DataFrameSchema None column_in_dataframe None - 2 Column int_column pandas_dtype('int64') None - 3 Column float_column pandas_dtype('float64') None - 4 Column float_column greater_than(0) 0 + 2 Column int_column dtype('int64') None + 3 Column float_column dtype('float64') None + 4 Column str_column equal_to(a) 0 5 Column str_column equal_to(a) 0 6 Column str_column equal_to(a) 0 diff --git a/docs/source/reference/core.rst b/docs/source/reference/core.rst new file mode 100644 index 000000000..c39175c6e --- /dev/null +++ b/docs/source/reference/core.rst @@ -0,0 +1,35 @@ +.. _api-core: + +Schemas +======= + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.schemas.DataFrameSchema + pandera.schemas.SeriesSchema + +Schema Components +================= + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.schema_components.Column + pandera.schema_components.Index + pandera.schema_components.MultiIndex + +Checks +====== + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.checks.Check + pandera.hypotheses.Hypothesis diff --git a/docs/source/reference/decorators.rst b/docs/source/reference/decorators.rst new file mode 100644 index 000000000..2506336f4 --- /dev/null +++ b/docs/source/reference/decorators.rst @@ -0,0 +1,13 @@ +.. _api-decorators: + +Decorators +========== + +.. 
autosummary:: + :toctree: generated + :nosignatures: + + pandera.decorators.check_input + pandera.decorators.check_output + pandera.decorators.check_io + pandera.decorators.check_types diff --git a/docs/source/reference/dtypes.rst b/docs/source/reference/dtypes.rst new file mode 100644 index 000000000..4e4db6cf1 --- /dev/null +++ b/docs/source/reference/dtypes.rst @@ -0,0 +1,109 @@ +.. _api-dtypes: + +Pandera Data Types +================== + +Library-agnostic dtypes +----------------------- + +.. autosummary:: + :toctree: generated + :template: dtype.rst + :nosignatures: + + pandera.dtypes.DataType + pandera.dtypes.Bool + pandera.dtypes.Timestamp + pandera.dtypes.DateTime + pandera.dtypes.Timedelta + pandera.dtypes.Category + pandera.dtypes.Float + pandera.dtypes.Float16 + pandera.dtypes.Float32 + pandera.dtypes.Float64 + pandera.dtypes.Float128 + pandera.dtypes.Int + pandera.dtypes.Int8 + pandera.dtypes.Int16 + pandera.dtypes.Int32 + pandera.dtypes.Int64 + pandera.dtypes.UInt + pandera.dtypes.UInt8 + pandera.dtypes.UInt16 + pandera.dtypes.UInt32 + pandera.dtypes.UInt64 + pandera.dtypes.Complex + pandera.dtypes.Complex64 + pandera.dtypes.Complex128 + pandera.dtypes.Complex256 + pandera.dtypes.String + + +Pandas-specific Dtypes +---------------------- + +Listed here for compatibility with pandera versions < 0.7. +Passing native pandas dtypes to pandera components is preferred. + +.. autosummary:: + :toctree: generated + :template: dtype.rst + :nosignatures: + + pandera.engines.pandas_engine.BOOL + pandera.engines.pandas_engine.INT8 + pandera.engines.pandas_engine.INT16 + pandera.engines.pandas_engine.INT32 + pandera.engines.pandas_engine.INT64 + pandera.engines.pandas_engine.UINT8 + pandera.engines.pandas_engine.UINT16 + pandera.engines.pandas_engine.UINT32 + pandera.engines.pandas_engine.UINT64 + pandera.engines.pandas_engine.STRING + pandera.engines.numpy_engine.Object + +Utility functions +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + pandera.dtypes.is_subdtype + pandera.dtypes.is_float + pandera.dtypes.is_int + pandera.dtypes.is_uint + pandera.dtypes.is_complex + pandera.dtypes.is_numeric + pandera.dtypes.is_bool + pandera.dtypes.is_string + pandera.dtypes.is_datetime + pandera.dtypes.is_timedelta + pandera.dtypes.immutable + +Engines +------- + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.engines.engine.Engine + pandera.engines.numpy_engine.Engine + pandera.engines.pandas_engine.Engine + + +PandasDtype Enum +---------------- + +.. warning:: + + This class is deprecated and will be removed from the pandera API in ``0.9.0``. + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.engines.pandas_engine.PandasDtype diff --git a/docs/source/reference/errors.rst b/docs/source/reference/errors.rst new file mode 100644 index 000000000..74fac1bde --- /dev/null +++ b/docs/source/reference/errors.rst @@ -0,0 +1,14 @@ +.. _api-errors: + +Errors +====== + +.. autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + pandera.errors.SchemaError + pandera.errors.SchemaErrors + pandera.errors.SchemaInitError + pandera.errors.SchemaDefinitionError diff --git a/docs/source/reference/extensions.rst b/docs/source/reference/extensions.rst new file mode 100644 index 000000000..617b5ed7a --- /dev/null +++ b/docs/source/reference/extensions.rst @@ -0,0 +1,11 @@ +.. _api-extensions: + +Extensions +========== + +.. 
autosummary:: + :toctree: generated + :template: module.rst + :nosignatures: + + pandera.extensions diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst new file mode 100644 index 000000000..3ec1d4713 --- /dev/null +++ b/docs/source/reference/index.rst @@ -0,0 +1,41 @@ +.. pandera package index documentation toctree + +.. currentmodule:: pandera + +API +=== + +.. list-table:: + :widths: 25 75 + + * - :ref:`Core ` + - The core objects for defining pandera schemas + * - :ref:`Data Types ` + - Data types for type checking and coercion. + * - :ref:`Schema Models ` + - Alternative class-based API for defining pandera schemas. + * - :ref:`Decorators ` + - Decorators for integrating pandera schemas with python functions. + * - :ref:`Schema Inference ` + - Bootstrap schemas from real data + * - :ref:`IO Utilities ` + - Utility functions for reading/writing schemas + * - :ref:`Strategies ` + - Module of functions for generating data from schemas. + * - :ref:`Extensions ` + - Utility functions for extending pandera functionality + * - :ref:`Errors ` + - Pandera-specific exceptions + +.. toctree:: + :hidden: + + core + schema_models + decorators + schema_inference + io + strategies + extensions + errors + dtypes diff --git a/docs/source/reference/io.rst b/docs/source/reference/io.rst new file mode 100644 index 000000000..2da272a14 --- /dev/null +++ b/docs/source/reference/io.rst @@ -0,0 +1,16 @@ +.. _api-io-utils: + +IO Utils +======== + +The ``io`` module and built-in ``Hypothesis`` checks require a pandera +installation with the corresponding extension, see the +:ref:`installation` instructions for more details. + +.. autosummary:: + :toctree: generated + :nosignatures: + + pandera.io.from_yaml + pandera.io.to_yaml + pandera.io.to_script diff --git a/docs/source/reference/schema_inference.rst b/docs/source/reference/schema_inference.rst new file mode 100644 index 000000000..179c151ac --- /dev/null +++ b/docs/source/reference/schema_inference.rst @@ -0,0 +1,10 @@ +.. _api-schema-inference: + +Schema Inference +================ + +.. autosummary:: + :toctree: generated + :nosignatures: + + pandera.schema_inference.infer_schema diff --git a/docs/source/reference/schema_models.rst b/docs/source/reference/schema_models.rst new file mode 100644 index 000000000..9468a3380 --- /dev/null +++ b/docs/source/reference/schema_models.rst @@ -0,0 +1,45 @@ +.. _api-schema-models: + +Schema Models +============= + +.. currentmodule:: pandera + +Schema Model +------------ + +.. autosummary:: + :toctree: generated + :template: class.rst + + pandera.model.SchemaModel + +Model Components +---------------- + +.. autosummary:: + :toctree: generated + + pandera.model_components.Field + pandera.model_components.check + pandera.model_components.dataframe_check + +Typing +------ + +.. autosummary:: + :toctree: generated + :template: typing_module.rst + :nosignatures: + + pandera.typing + +Config +------ + +.. autosummary:: + :toctree: generated + :template: model_component_class.rst + :nosignatures: + + pandera.model.BaseConfig diff --git a/docs/source/reference/strategies.rst b/docs/source/reference/strategies.rst new file mode 100644 index 000000000..16f9b1aaa --- /dev/null +++ b/docs/source/reference/strategies.rst @@ -0,0 +1,11 @@ +.. _api-strategies: + +Data Synthesis Strategies +========================= + +.. 
autosummary:: + :toctree: generated + :template: strategies_module.rst + :nosignatures: + + pandera.strategies diff --git a/docs/source/schema_inference.rst b/docs/source/schema_inference.rst index 33eda0505..be802a783 100644 --- a/docs/source/schema_inference.rst +++ b/docs/source/schema_inference.rst @@ -36,14 +36,14 @@ is a simple example: - 'column2': - 'column3': + 'column1': + 'column2': + 'column3': }, checks=[], coerce=True, - pandas_dtype=None, - index=, + dtype=None, + index=, strict=False name=None, ordered=False @@ -96,19 +96,12 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr :skipif: SKIP from pandas import Timestamp - from pandera import ( - DataFrameSchema, - Column, - Check, - Index, - MultiIndex, - PandasDtype, - ) + from pandera import DataFrameSchema, Column, Check, Index, MultiIndex schema = DataFrameSchema( columns={ "column1": Column( - pandas_dtype=PandasDtype.Int64, + dtype=pandera.engines.numpy_engine.Int64, checks=[ Check.greater_than_or_equal_to(min_value=5.0), Check.less_than_or_equal_to(max_value=20.0), @@ -120,7 +113,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr regex=False, ), "column2": Column( - pandas_dtype=PandasDtype.String, + dtype=pandera.engines.numpy_engine.Object, checks=None, nullable=False, allow_duplicates=True, @@ -129,7 +122,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr regex=False, ), "column3": Column( - pandas_dtype=PandasDtype.DateTime, + dtype=pandera.engines.pandas_engine.DateTime, checks=[ Check.greater_than_or_equal_to( min_value=Timestamp("2010-01-01 00:00:00") @@ -146,7 +139,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr ), }, index=Index( - pandas_dtype=PandasDtype.Int64, + dtype=pandera.engines.numpy_engine.Int64, checks=[ Check.greater_than_or_equal_to(min_value=0.0), Check.less_than_or_equal_to(max_value=2.0), @@ -187,7 +180,7 @@ is a convenience method for this functionality. version: 0.6.5 columns: column1: - pandas_dtype: int64 + dtype: int64 nullable: false checks: greater_than_or_equal_to: 5.0 @@ -197,7 +190,7 @@ is a convenience method for this functionality. required: true regex: false column2: - pandas_dtype: str + dtype: object nullable: false checks: null allow_duplicates: true @@ -205,7 +198,7 @@ is a convenience method for this functionality. required: true regex: false column3: - pandas_dtype: datetime64[ns] + dtype: datetime64[ns] nullable: false checks: greater_than_or_equal_to: '2010-01-01 00:00:00' @@ -216,7 +209,7 @@ is a convenience method for this functionality. regex: false checks: null index: - - pandas_dtype: int64 + - dtype: int64 nullable: false checks: greater_than_or_equal_to: 0.0 diff --git a/docs/source/schema_models.rst b/docs/source/schema_models.rst index a8f6fd0d0..65a4970cc 100644 --- a/docs/source/schema_models.rst +++ b/docs/source/schema_models.rst @@ -72,7 +72,7 @@ Basic Usage Traceback (most recent call last): ... - pandera.errors.SchemaError: > failed element-wise validator 0: + pandera.errors.SchemaError: failed element-wise validator 0: failure cases: index failure_case @@ -121,13 +121,13 @@ You can easily convert a :class:`~pandera.model.SchemaModel` class into a )> - 'month': )> - 'day': )> + 'year': + 'month': + 'day': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -165,12 +165,8 @@ however, a couple of gotchas. 
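For reference, the conversion mentioned above is done with ``SchemaModel.to_schema()``. A minimal sketch (``InputSchema`` is an illustrative name):

.. code-block:: python

    import pandera as pa
    from pandera.typing import Series

    class InputSchema(pa.SchemaModel):
        year: Series[int]

    # the class-based schema converts to the object-based API
    schema = InputSchema.to_schema()
    assert isinstance(schema, pa.DataFrameSchema)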
Dtype aliases ^^^^^^^^^^^^^ -The enumeration :class:`~pandera.dtypes.PandasDtype` is not directly supported because -the type parameter of a :class:`typing.Generic` cannot be an enumeration [#dtypes]_. -Instead, you can use the :mod:`pandera.typing` counterparts: -:data:`pandera.typing.Category`, :data:`pandera.typing.Float32`, ... - -:green:`✔` Good: +:mod:`pandera.typing` aliases will be deprecated in a future version, +please use :class:`~pandera.dtypes.DataType` subclasses instead. .. code-block:: @@ -180,21 +176,6 @@ Instead, you can use the :mod:`pandera.typing` counterparts: class Schema(pa.SchemaModel): a: Series[String] -:red:`✘` Bad: - -.. testcode:: dataframe_schema_model - :skipif: SKIP_PANDAS_LT_V1 - - class Schema(pa.SchemaModel): - a: Series[pa.PandasDtype.String] - -.. testoutput:: dataframe_schema_model - :skipif: SKIP_PANDAS_LT_V1 - - Traceback (most recent call last): - ... - TypeError: python type '' not recognized as pandas data type - Type Vs instance ^^^^^^^^^^^^^^^^ @@ -437,8 +418,8 @@ the class-based API: )> - )> + + ] coerce=True, strict=True, @@ -531,7 +512,7 @@ Column/Index checks Traceback (most recent call last): ... - pandera.errors.SchemaError: > failed series validator 1: + pandera.errors.SchemaError: failed series validator 1: .. _schema_model_dataframe_check: @@ -611,7 +592,7 @@ The custom checks are inherited and therefore can be overwritten by the subclass Traceback (most recent call last): ... - pandera.errors.SchemaError: > failed element-wise validator 0: + pandera.errors.SchemaError: failed element-wise validator 0: failure cases: index failure_case @@ -700,11 +681,3 @@ the class scope, and it will respect the alias. 2020 a 0 99 101 - - -Footnotes ---------- - -.. [#dtypes] It is actually possible to use a PandasDtype by encasing it in a - :class:`typing.Literal` like ``Series[Literal[PandasDtype.Category]]``. - :mod:`pandera.typing` defines aliases to reduce boilerplate. diff --git a/environment.yml b/environment.yml index 080fab82d..659f9147e 100644 --- a/environment.yml +++ b/environment.yml @@ -17,6 +17,7 @@ dependencies: - typing_inspect >= 0.6.0 - typing_extensions >= 3.7.4.3 - frictionless + - pyarrow # testing and dependencies - black >= 20.8b1 @@ -24,19 +25,19 @@ dependencies: # testing - isort >= 5.7.0 - codecov - - mypy = 0.812 # TODO: update codebase to be 0.902+ compatible + - mypy >= 0.902 # mypy no longer bundle stubs for third-party libraries - pylint >= 2.7.2 - pytest - pytest-cov - pytest-xdist - pytest-asyncio + - xdoctest - setuptools >= 52.0.0 - - nox = 2020.12.31 # TODO: update codebase to be 2021.6.6+ compatible + - nox = 2020.12.31 # pinning due to UnicodeDecodeError, see https://github.com/pandera-dev/pandera/pull/504/checks?check_run_id=2841360122 - importlib_metadata # required if python < 3.8 # documentation - - sphinx = 3.5.4 # pinned due to doc-building error https://github.com/pandera-dev/pandera/runs/2601459267 - - sphinx_rtd_theme + - sphinx - sphinx-autodoc-typehints - sphinx-copybutton - recommonmark @@ -52,3 +53,6 @@ dependencies: - pip: - furo==2021.6.18b36 + - types-click + - types-pyyaml + - types-pkg_resources diff --git a/noxfile.py b/noxfile.py index c30627e45..7ebab94bb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,16 +3,15 @@ import os import shutil import sys -from typing import Dict, List, cast +from typing import Dict, List # setuptools must be imported before distutils ! 
-import setuptools # pylint:disable=unused-import +import setuptools # pylint:disable=unused-import # noqa: F401 from distutils.core import run_setup # pylint:disable=wrong-import-order import nox from nox import Session from pkg_resources import Requirement, parse_requirements -from packaging import version nox.options.sessions = ( @@ -23,16 +22,18 @@ "mypy", "tests", "docs", + "doctests", ) DEFAULT_PYTHON = "3.8" -PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] +PYTHON_VERSIONS = ["3.7", "3.8", "3.9"] +PANDAS_VERSIONS = ["1.2.5", "latest"] PACKAGE = "pandera" SOURCE_PATHS = PACKAGE, "tests", "noxfile.py" REQUIREMENT_PATH = "requirements-dev.txt" -ALWAYS_USE_PIP = ["furo"] +ALWAYS_USE_PIP = ["furo", "types-click", "types-pyyaml", "types-pkg_resources"] CI_RUN = os.environ.get("CI") == "true" if CI_RUN: @@ -169,13 +170,13 @@ def install_from_requirements(session: Session, *packages: str) -> None: def install_extras( session: Session, - pandas: str = "latest", extra: str = "core", - force_pip=False, + force_pip: bool = False, + pandas: str = "latest", ) -> None: """Install dependencies.""" - pandas_version = "" if pandas == "latest" else f"=={pandas}" specs, pip_specs = [], [] + pandas_version = "" if pandas == "latest" else f"=={pandas}" for spec in REQUIRES[extra].values(): if spec.split("==")[0] in ALWAYS_USE_PIP: pip_specs.append(spec) @@ -186,9 +187,11 @@ def install_extras( if extra == "core": specs.append(REQUIRES["all"]["hypothesis"]) + # CI installs conda dependencies, so only run this for local runs if ( isinstance(session.virtualenv, nox.virtualenv.CondaEnv) and not force_pip + and not CI_RUN ): print("using conda installer") conda_install(session, *specs) @@ -196,8 +199,8 @@ def install_extras( print("using pip installer") session.install(*specs) - session.install(*pip_specs) # always use pip for these packages + session.install(*pip_specs) session.install("-e", ".", "--no-deps") # install pandera @@ -284,31 +287,11 @@ def lint(session: Session) -> None: @nox.session(python=PYTHON_VERSIONS) def mypy(session: Session) -> None: """Type-check using mypy.""" - python_version = version.parse(cast(str, session.python)) - install_extras( - session, - extra="all", - # this is a hack until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == version.parse("3.7"), - ) + install_extras(session, extra="all") args = session.posargs or SOURCE_PATHS session.run("mypy", "--follow-imports=silent", *args, silent=True) -def _invalid_python_pandas_versions(session: Session, pandas: str) -> bool: - python_version = version.parse(cast(str, session.python)) - if pandas == "0.25.3" and ( - python_version >= version.parse("3.9") - # this is just a bandaid until support for 0.25.3 is dropped - or python_version == version.parse("3.7") - ): - print("Python 3.9 does not support pandas 0.25.3") - return True - return False - - EXTRA_NAMES = [ extra for extra in REQUIRES @@ -317,22 +300,11 @@ def _invalid_python_pandas_versions(session: Session, pandas: str) -> bool: @nox.session(python=PYTHON_VERSIONS) -@nox.parametrize("pandas", ["0.25.3", "latest"]) +@nox.parametrize("pandas", PANDAS_VERSIONS) @nox.parametrize("extra", EXTRA_NAMES) def tests(session: Session, pandas: str, extra: str) -> None: """Run the test suite.""" - if _invalid_python_pandas_versions(session, pandas): - return - python_version = version.parse(cast(str, session.python)) - install_extras( - session, - pandas, - extra, - # this is a hack 
until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == version.parse("3.7"), - ) + install_extras(session, extra, pandas=pandas) if session.posargs: args = session.posargs @@ -340,11 +312,15 @@ def tests(session: Session, pandas: str, extra: str) -> None: path = f"tests/{extra}/" if extra != "all" else "tests" args = [] if extra == "strategies": + # strategies tests runs very slowly in python 3.7: + # https://github.com/pandera-dev/pandera/issues/556 + # as a stop-gap, use the "dev" profile for 3.7 + profile = "ci" if CI_RUN and session.python != "3.7" else "dev" # enable threading via pytest-xdist args = [ "-n=auto", "-q", - f"--hypothesis-profile={'ci' if CI_RUN else 'dev'}", + f"--hypothesis-profile={profile}", ] args += [ f"--cov={PACKAGE}", @@ -360,45 +336,44 @@ def tests(session: Session, pandas: str, extra: str) -> None: @nox.session(python=PYTHON_VERSIONS) -@nox.parametrize("pandas", ["0.25.3", "latest"]) -def docs(session: Session, pandas: str) -> None: +def doctests(session: Session) -> None: """Build the documentation.""" - if _invalid_python_pandas_versions(session, pandas): - return - python_version = version.parse(cast(str, session.python)) - install_extras( - session, - pandas, - extra="all", - # this is a hack until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == version.parse("3.7"), - ) - session.chdir("docs") + install_extras(session, extra="all", force_pip=True) + session.run("xdoctest", PACKAGE, "--quiet") - shutil.rmtree(os.path.join("_build"), ignore_errors=True) - args = session.posargs or [ - "-v", - "-v", - "-W", - "-E", - "-b=doctest", - "source", - "_build", - ] - session.run("sphinx-build", *args) + +@nox.session(python=PYTHON_VERSIONS) +def docs(session: Session) -> None: + """Build the documentation.""" + install_extras(session, extra="all", force_pip=True) + session.chdir("docs") # build html docs if not CI_RUN and not session.posargs: - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) - session.run( - "sphinx-build", + shutil.rmtree("_build", ignore_errors=True) + shutil.rmtree( + os.path.join("source", "reference", "generated"), + ignore_errors=True, + ) + for builder in ["doctest", "html"]: + session.run( + "sphinx-build", + "-W", + "-T", + f"-b={builder}", + "-d", + os.path.join("_build", "doctrees", ""), + "source", + os.path.join("_build", builder, ""), + ) + else: + shutil.rmtree(os.path.join("_build"), ignore_errors=True) + args = session.posargs or [ + "-v", "-W", - "-T", - "-b=html", - "-d", - os.path.join("_build", "doctrees", ""), + "-E", + "-b=doctest", "source", - os.path.join("_build", "html", ""), - ) + "_build", + ] + session.run("sphinx-build", *args) diff --git a/pandera/__init__.py b/pandera/__init__.py index 477e5e8c1..e7f15bd69 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -1,9 +1,52 @@ """A flexible and expressive pandas validation library.""" +import platform + +from pandera.dtypes import ( + Bool, + Category, + Complex, + Complex64, + Complex128, + DataType, + DateTime, + Float, + Float16, + Float32, + Float64, + Int, + Int8, + Int16, + Int32, + Int64, + String, + Timedelta, + Timestamp, + UInt, + UInt8, + UInt16, + UInt32, + UInt64, +) +from pandera.engines.numpy_engine import Object +from pandera.engines.pandas_engine import ( + BOOL, + INT8, + INT16, + INT32, + INT64, + 
PANDAS_1_3_0_PLUS, + STRING, + UINT8, + UINT16, + UINT32, + UINT64, +) +from pandera.engines.pandas_engine import _PandasDtype as PandasDtype +from pandera.engines.pandas_engine import pandas_version from . import constants, errors, pandas_accessor from .checks import Check from .decorators import check_input, check_io, check_output, check_types -from .dtypes import LEGACY_PANDAS, PandasDtype from .hypotheses import Hypothesis from .model import SchemaModel from .model_components import Field, check, dataframe_check @@ -12,32 +55,5 @@ from .schemas import DataFrameSchema, SeriesSchema from .version import __version__ -# pylint: disable=invalid-name -Bool = PandasDtype.Bool -DateTime = PandasDtype.DateTime -Category = PandasDtype.Category -Float = PandasDtype.Float -Float16 = PandasDtype.Float16 -Float32 = PandasDtype.Float32 -Float64 = PandasDtype.Float64 -Int = PandasDtype.Int -Int8 = PandasDtype.Int8 -Int16 = PandasDtype.Int16 -Int32 = PandasDtype.Int32 -Int64 = PandasDtype.Int64 -UInt8 = PandasDtype.UInt8 -UInt16 = PandasDtype.UInt16 -UInt32 = PandasDtype.UInt32 -UInt64 = PandasDtype.UInt64 -INT8 = PandasDtype.INT8 -INT16 = PandasDtype.INT16 -INT32 = PandasDtype.INT32 -INT64 = PandasDtype.INT64 -UINT8 = PandasDtype.UINT8 -UINT16 = PandasDtype.UINT16 -UINT32 = PandasDtype.UINT32 -UINT64 = PandasDtype.UINT64 -Object = PandasDtype.Object -String = PandasDtype.String -STRING = PandasDtype.STRING -Timedelta = PandasDtype.Timedelta +if platform.system() != "Windows": + from pandera.dtypes import Complex256, Float128 diff --git a/pandera/checks.py b/pandera/checks.py index c5f026eaf..e4a66c0c0 100644 --- a/pandera/checks.py +++ b/pandera/checks.py @@ -470,13 +470,13 @@ def __eq__(self, other: object) -> bool: are_strategy_fn_objects_equal = True are_all_other_check_attributes_equal = { - i: v - for i, v in self.__dict__.items() - if i not in ["_check_fn", "strategy"] + k: v + for k, v in self.__dict__.items() + if k not in ["_check_fn", "strategy"] } == { - i: v - for i, v in other.__dict__.items() - if i not in ["_check_fn", "strategy"] + k: v + for k, v in other.__dict__.items() + if k not in ["_check_fn", "strategy"] } return ( diff --git a/pandera/deprecations.py b/pandera/deprecations.py new file mode 100644 index 000000000..f06b489bb --- /dev/null +++ b/pandera/deprecations.py @@ -0,0 +1,38 @@ +"""Utility functions for deprecating features.""" + +import inspect +import warnings +from functools import wraps + +from pandera.errors import SchemaInitError + + +def deprecate_pandas_dtype(fn): + """ + __init__ decorator for raising SchemaInitError or warnings based on + the dtype and pandas_dtype input. + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + """__init__ method wrapper for raising deprecation warning.""" + sig = inspect.signature(fn) + bound_args = sig.bind(*args, **kwargs) + dtype = bound_args.arguments.get("dtype", None) + pandas_dtype = bound_args.arguments.get("pandas_dtype", None) + + msg = ( + "`pandas_dtype` is deprecated and will be removed as an " + "option in pandera v0.9.0, use `dtype` instead." + ) + + if dtype is not None and pandas_dtype is not None: + raise SchemaInitError( + f"`dtype` and `pandas_dtype` cannot both be specified. 
{msg}" + ) + if pandas_dtype is not None: + warnings.warn(msg, DeprecationWarning) + + return fn(*args, **kwargs) + + return wrapper diff --git a/pandera/dtypes.py b/pandera/dtypes.py index b57c74954..30b746e4a 100644 --- a/pandera/dtypes.py +++ b/pandera/dtypes.py @@ -1,460 +1,488 @@ -# pylint: disable=no-member,too-many-public-methods -"""Schema datatypes.""" - -from enum import Enum -from typing import Optional, Type, Union - -import numpy as np -import pandas as pd -from packaging import version - -PandasExtensionType = pd.core.dtypes.base.ExtensionDtype - -PANDAS_VERSION = version.parse(pd.__version__) -LEGACY_PANDAS = PANDAS_VERSION.major < 1 # type: ignore -PANDAS_1_3_0_PLUS = PANDAS_VERSION.release >= (1, 3, 0) # type: ignore -NUMPY_NONNULLABLE_INT_DTYPES = [ - "int", - "int_", - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", -] - -NUMPY_TYPES = frozenset( - [item for sublist in np.sctypes.values() for item in sublist] # type: ignore -).union( - frozenset([np.complex_, np.int_, np.uint, np.float_, np.str_, np.bool_]) +"""Pandera data types.""" +# pylint:disable=too-many-ancestors +import dataclasses +import inspect +from abc import ABC +from typing import ( + Any, + Callable, + Iterable, + Optional, + Tuple, + Type, + TypeVar, + Union, ) -# for int and float dtype, delegate string representation to the -# default based on OS. In Windows, pandas defaults to int64 while numpy -# defaults to int32. -_DEFAULT_PANDAS_INT_TYPE = str(pd.Series([1]).dtype) -_DEFAULT_PANDAS_FLOAT_TYPE = str(pd.Series([1.0]).dtype) -_DEFAULT_PANDAS_COMPLEX_TYPE = str(pd.Series([complex(1)]).dtype) -_DEFAULT_NUMPY_INT_TYPE = str(np.dtype(int)) -_DEFAULT_NUMPY_FLOAT_TYPE = str(np.dtype(float)) +class DataType(ABC): + """Base class of all Pandera data types.""" + + continuous: Optional[bool] = None + """Whether the number data type is continuous.""" + + def __init__(self): + if self.__class__ is DataType: + raise TypeError( + f"{self.__class__.__name__} may not be instantiated." + ) + + def coerce(self, data_container: Any): + """Coerce data container to the data type.""" + raise NotImplementedError() + + def __call__(self, data_container: Any): + """Coerce data container to the data type.""" + return self.coerce(data_container) + + def check(self, pandera_dtype: "DataType") -> bool: + """Check that pandera :class:`~pandera.dtypes.DataType` are + equivalent.""" + return self == pandera_dtype + + def __repr__(self) -> str: + return f"DataType({str(self)})" + + def __str__(self) -> str: + raise NotImplementedError() + + def __hash__(self) -> int: + raise NotImplementedError() -def is_extension_dtype(dtype): - """Check if a value is a pandas extension type or instance of one.""" - return isinstance(dtype, PandasExtensionType) or ( - isinstance(dtype, type) and issubclass(dtype, PandasExtensionType) - ) +_Dtype = TypeVar("_Dtype", bound=DataType) +_DataTypeClass = Type[_Dtype] -class PandasDtype(Enum): - # pylint: disable=line-too-long,invalid-name - """Enumerate all valid pandas data types. - - ``pandera`` follows the - `numpy data types `_ - subscribed to by ``pandas`` and by default supports using the numpy data - type string aliases to validate DataFrame or Series dtypes. - - This class simply enumerates the valid numpy dtypes for pandas arrays. - For convenience ``PandasDtype`` enums can all be accessed in the top-level - ``pandera`` name space via the same enum name. 
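With the enum on its way out, the same validations are expressed through the new DataType hierarchy exported at the top level of the package. A minimal before/after sketch, assuming only the new top-level exports listed in pandera/__init__.py above::

    import pandas as pd
    import pandera as pa

    # previously: pa.SeriesSchema(pa.PandasDtype.Int64)
    pa.SeriesSchema(pa.Int64).validate(pd.Series([1, 2, 3]))

    # built-in python types and pandas string aliases keep working
    pa.SeriesSchema(int).validate(pd.Series([1, 2, 3]))
    pa.SeriesSchema("int64").validate(pd.Series([1, 2, 3]))
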
- - :examples: - - >>> import pandas as pd - >>> import pandera as pa - >>> - >>> - >>> pa.SeriesSchema(pa.Int).validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - >>> pa.SeriesSchema(pa.Float).validate(pd.Series([1.1, 2.3, 3.4])) - 0 1.1 - 1 2.3 - 2 3.4 - dtype: float64 - >>> pa.SeriesSchema(pa.String).validate(pd.Series(["a", "b", "c"])) - 0 a - 1 b - 2 c - dtype: object - - Alternatively, you can use built-in python scalar types for integers, - floats, booleans, and strings: - - >>> pa.SeriesSchema(int).validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - - You can also use the pandas string aliases in the schema definition: - - >>> pa.SeriesSchema("int").validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - - .. note:: - ``pandera`` also offers limited support for - `pandas extension types `_, - however since the release of pandas 1.0.0 there are backwards - incompatible extension types like the ``Integer`` array. The extension - types, e.g. ``pd.IntDtype64()`` and their string alias should work - when supplied to the ``pandas_dtype`` argument, unless otherwise - specified below, but this functionality is only tested for - pandas >= 1.0.0. Extension types in earlier versions are not guaranteed - to work as the ``pandas_dtype`` argument in schemas or schema - components. +def immutable( + pandera_dtype_cls: Optional[_DataTypeClass] = None, **dataclass_kwargs: Any +) -> Union[_DataTypeClass, Callable[[_DataTypeClass], _DataTypeClass]]: + """:func:`dataclasses.dataclass` decorator with different default values: + `frozen=True`, `init=False`, `repr=False`. + + In addition, `init=False` disables inherited `__init__` method to ensure + the DataType's default attributes are not altered during initialization. + + :param dtype: :class:`DataType` to decorate. + :param dataclass_kwargs: Keywords arguments forwarded to + :func:`dataclasses.dataclass`. 
+ :returns: Immutable :class:`DataType` """ + kwargs = {"frozen": True, "init": False, "repr": False} + kwargs.update(dataclass_kwargs) - Bool = "bool" #: ``"bool"`` numpy dtype - DateTime = "datetime64[ns]" #: ``"datetime64[ns]"`` numpy dtype - Timedelta = "timedelta64[ns]" #: ``"timedelta64[ns]"`` numpy dtype - Category = "category" #: pandas ``"categorical"`` datatype - Float = "float" #: ``"float"`` numpy dtype - Float16 = "float16" #: ``"float16"`` numpy dtype - Float32 = "float32" #: ``"float32"`` numpy dtype - Float64 = "float64" #: ``"float64"`` numpy dtype - Int = "int" #: ``"int"`` numpy dtype - Int8 = "int8" #: ``"int8"`` numpy dtype - Int16 = "int16" #: ``"int16"`` numpy dtype - Int32 = "int32" #: ``"int32"`` numpy dtype - Int64 = "int64" #: ``"int64"`` numpy dtype - UInt8 = "uint8" #: ``"uint8"`` numpy dtype - UInt16 = "uint16" #: ``"uint16"`` numpy dtype - UInt32 = "uint32" #: ``"uint32"`` numpy dtype - UInt64 = "uint64" #: ``"uint64"`` numpy dtype - INT8 = "Int8" #: ``"Int8"`` pandas dtype:: pandas 0.24.0+ - INT16 = "Int16" #: ``"Int16"`` pandas dtype: pandas 0.24.0+ - INT32 = "Int32" #: ``"Int32"`` pandas dtype: pandas 0.24.0+ - INT64 = "Int64" #: ``"Int64"`` pandas dtype: pandas 0.24.0+ - UINT8 = "UInt8" #: ``"UInt8"`` pandas dtype: pandas 0.24.0+ - UINT16 = "UInt16" #: ``"UInt16"`` pandas dtype: pandas 0.24.0+ - UINT32 = "UInt32" #: ``"UInt32"`` pandas dtype: pandas 0.24.0+ - UINT64 = "UInt64" #: ``"UInt64"`` pandas dtype: pandas 0.24.0+ - Object = "object" #: ``"object"`` numpy dtype - Complex = "complex" #: ``"complex"`` numpy dtype - Complex64 = "complex64" #: ``"complex"`` numpy dtype - Complex128 = "complex128" #: ``"complex"`` numpy dtype - Complex256 = "complex256" #: ``"complex"`` numpy dtype - String = "str" #: ``"str"`` numpy dtype - - #: ``"string"`` pandas dtypes: pandas 1.0.0+. For <1.0.0, this enum will - #: fall back on the str-as-object-array representation. - STRING = "string" - - @property - def str_alias(self): - """Get datatype string alias.""" - return { - "int": _DEFAULT_PANDAS_INT_TYPE, - "float": _DEFAULT_PANDAS_FLOAT_TYPE, - "complex": _DEFAULT_PANDAS_COMPLEX_TYPE, - "str": "object", - "string": "object" if LEGACY_PANDAS else "string", - }.get(self.value, self.value) - - @classmethod - def from_str_alias(cls, str_alias: str) -> "PandasDtype": - """Get PandasDtype from string alias. 
- - :param: pandas dtype string alias from - https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html#basics-dtypes - :returns: pandas dtype - """ - pandas_dtype = { - "bool": cls.Bool, - "datetime64[ns]": cls.DateTime, - "timedelta64[ns]": cls.Timedelta, - "category": cls.Category, - "float": cls.Float, - "float16": cls.Float16, - "float32": cls.Float32, - "float64": cls.Float64, - "int": cls.Int, - "int8": cls.Int8, - "int16": cls.Int16, - "int32": cls.Int32, - "int64": cls.Int64, - "uint8": cls.UInt8, - "uint16": cls.UInt16, - "uint32": cls.UInt32, - "uint64": cls.UInt64, - "Int8": cls.INT8, - "Int16": cls.INT16, - "Int32": cls.INT32, - "Int64": cls.INT64, - "UInt8": cls.UINT8, - "UInt16": cls.UINT16, - "UInt32": cls.UINT32, - "UInt64": cls.UINT64, - "object": cls.Object, - "complex": cls.Complex, - "complex64": cls.Complex64, - "complex128": cls.Complex128, - "complex256": cls.Complex256, - "str": cls.String, - "string": cls.String if LEGACY_PANDAS else cls.STRING, - }.get(str_alias) - - if pandas_dtype is None: - raise TypeError( - f"pandas dtype string alias '{str_alias}' not recognized" - ) + def _wrapper(pandera_dtype_cls: _DataTypeClass) -> _DataTypeClass: + immutable_dtype = dataclasses.dataclass(**kwargs)(pandera_dtype_cls) + if not kwargs["init"]: - return pandas_dtype - - @classmethod - def from_pandas_api_type(cls, pandas_api_type: str) -> "PandasDtype": - """Get PandasDtype enum from pandas api type. - - :param pandas_api_type: string output from - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.api.types.infer_dtype.html - :returns: pandas dtype - """ - if pandas_api_type.startswith("mixed"): - return cls.Object - - pandas_dtype = { - "string": cls.String, - "floating": cls.Float, - "integer": cls.Int, - "categorical": cls.Category, - "boolean": cls.Bool, - "datetime64": cls.DateTime, - "datetime": cls.DateTime, - "timedelta64": cls.Timedelta, - "timedelta": cls.Timedelta, - }.get(pandas_api_type) - - if pandas_dtype is None: - raise TypeError( - f"pandas api type '{pandas_api_type}' not recognized" - ) + def __init__(self): # pylint:disable=unused-argument + pass - return pandas_dtype - - @classmethod - def from_python_type(cls, python_type: Type) -> "PandasDtype": - """Get PandasDtype enum from built-in python type. - - :param python_type: built-in python type. Allowable types are: - str, int, float, and bool. - """ - pandas_dtype = { - bool: cls.Bool, - str: cls.String, - int: cls.Int, - float: cls.Float, - object: cls.Object, - complex: cls.Complex, - }.get(python_type) - - if pandas_dtype is None: - raise TypeError( - f"python type '{python_type}' not recognized as pandas data type" - ) + # delattr(immutable_dtype, "__init__") doesn't work because + # super.__init__ would still exist. 
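+            # Overriding __init__ with a no-op instead guarantees that
+            # inherited initializers cannot alter the frozen dataclass
+            # defaults at construction time.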
+ setattr(immutable_dtype, "__init__", __init__) + + return immutable_dtype + + if pandera_dtype_cls is None: + return _wrapper + + return _wrapper(pandera_dtype_cls) + + +############################################################################### +# number +############################################################################### + + +@immutable +class _Number(DataType): + """Semantic representation of a numeric data type.""" + + exact: Optional[bool] = None + """Whether the data type is an exact representation of a number.""" + + def check(self, pandera_dtype: "DataType") -> bool: + if self.__class__ is _Number: + return isinstance(pandera_dtype, _Number) + return super().check(pandera_dtype) + + +@immutable + +class _PhysicalNumber(_Number): + bit_width: Optional[int] = None + """Number of bits used by the machine representation.""" + _base_name: Optional[str] = dataclasses.field( + default=None, init=False, repr=False + ) + + def __eq__(self, obj: object) -> bool: + if isinstance(obj, type(self)): + return obj.bit_width == self.bit_width + return super().__eq__(obj) + + def __str__(self) -> str: + return f"{self._base_name}{self.bit_width}" + + +############################################################################### +# boolean +############################################################################### + + +@immutable +class Bool(_Number): + """Semantic representation of a boolean data type.""" + + def __str__(self) -> str: + return "bool" + + +############################################################################### +# signed integer +############################################################################### - return pandas_dtype - - @classmethod - def from_numpy_type(cls, numpy_type: Type[np.generic]) -> "PandasDtype": - """Get PandasDtype enum from numpy type. - - :param numpy_type: numpy data type. - """ - pd_dtype = pd.api.types.pandas_dtype(numpy_type) - return cls.from_str_alias(pd_dtype.name) - - @classmethod - def get_dtype( - cls, - pandas_dtype_arg: Union[ - str, - type, - "PandasDtype", - "pd.core.dtypes.dtypes.ExtensionDtype", - np.dtype, - ], - ) -> Optional[ - Union["PandasDtype", "pd.core.dtypes.dtypes.ExtensionDtype"] - ]: - """Get PandasDtype from schema argument. - - :param pandas_dtype_arg: ``pandas_dtype`` argument specified in schema - definition. - """ - dtype_ = pandas_dtype_arg - - if dtype_ is None: - return dtype_ - elif isinstance(dtype_, PandasDtype): - return pandas_dtype_arg - elif is_extension_dtype(dtype_): - if isinstance(dtype_, type): - try: - # Convert to str here because some pandas dtypes allow - # an empty constructor for compatibility but fail on - # str(). e.g: PeriodDtype - str(dtype_().name) - return dtype_() - except (TypeError, AttributeError) as err: - raise TypeError( - f"Pandas dtype {dtype_} cannot be instantiated: " - f"{err}\n Usage Tip: Use an instance or a string " - "representation." - ) from err - return dtype_ - - if dtype_ in NUMPY_TYPES: - dtype_ = cls.from_numpy_type(dtype_) # type: ignore - elif isinstance(dtype_, str): - dtype_ = cls.from_str_alias(dtype_) - elif isinstance(dtype_, type): - dtype_ = cls.from_python_type(dtype_) - - if isinstance(dtype_, PandasDtype): - return dtype_ - raise TypeError( - "type of `pandas_dtype` argument not recognized: " - f"{type(pandas_dtype_arg)}. 
Please specify a pandera PandasDtype " - "enum, built-in python type, pandas data type, pandas data type " - "string alias, or numpy data type string alias" + +@immutable(eq=False) +class Int(_PhysicalNumber): # type: ignore + """Semantic representation of an integer data type.""" + + _base_name = "int" + exact = True + bit_width = 64 + signed: bool = dataclasses.field(default=True, init=False) + """Whether the integer data type is signed.""" + + def check(self, pandera_dtype: DataType) -> bool: + return ( + isinstance(pandera_dtype, Int) + and self.signed == pandera_dtype.signed + and self.bit_width == pandera_dtype.bit_width ) - @classmethod - def get_str_dtype(cls, pandas_dtype_arg) -> Optional[str]: - """Get pandas-compatible string representation of dtype.""" - pandas_dtype = cls.get_dtype(pandas_dtype_arg) - if pandas_dtype is None: - return pandas_dtype - elif isinstance(pandas_dtype, PandasDtype): - return pandas_dtype.str_alias - return str(pandas_dtype) - - def __eq__(self, other): - # pylint: disable=comparison-with-callable - # see https://github.com/PyCQA/pylint/issues/2306 - if other is None: - return False - other_dtype = PandasDtype.get_dtype(other) - if self.value == "string" and LEGACY_PANDAS: - return PandasDtype.String.value == other_dtype.value - elif self.value == "string": - return self.value == other_dtype.value - return self.str_alias == other_dtype.str_alias - - def __hash__(self): - if self is PandasDtype.Int: - hash_obj = _DEFAULT_PANDAS_INT_TYPE - elif self is PandasDtype.Float: - hash_obj = _DEFAULT_PANDAS_FLOAT_TYPE - else: - hash_obj = self.str_alias - return id(hash_obj) - - @property - def numpy_dtype(self): - """Get numpy data type.""" - if self is PandasDtype.Category: - raise TypeError( - "the pandas Categorical data type doesn't have a numpy " - "equivalent." 
- ) + def __str__(self) -> str: + if self.__class__ is Int: + return "int" + return super().__str__() + + +@immutable +class Int64(Int): + """Semantic representation of an integer data type stored in 64 bits.""" + + bit_width = 64 + + +@immutable +class Int32(Int64): + """Semantic representation of an integer data type stored in 32 bits.""" + + bit_width = 32 + + +@immutable +class Int16(Int32): + """Semantic representation of an integer data type stored in 16 bits.""" + + bit_width = 16 + + +@immutable +class Int8(Int16): + """Semantic representation of an integer data type stored in 8 bits.""" + + bit_width = 8 + + +############################################################################### +# unsigned integer +############################################################################### + + +@immutable +class UInt(Int): + """Semantic representation of an unsigned integer data type.""" + + _base_name = "uint" + signed: bool = dataclasses.field(default=False, init=False) + + def __str__(self) -> str: + if self.__class__ is UInt: + return "uint" + return super().__str__() + + +@immutable +class UInt64(UInt): + """Semantic representation of an unsigned integer data type stored + in 64 bits.""" + + bit_width = 64 + + +@immutable +class UInt32(UInt64): + """Semantic representation of an unsigned integer data type stored + in 32 bits.""" + + bit_width = 32 + + +@immutable +class UInt16(UInt32): + """Semantic representation of an unsigned integer data type stored + in 16 bits.""" + + bit_width = 16 + + +@immutable +class UInt8(UInt16): + """Semantic representation of an unsigned integer data type stored + in 8 bits.""" + + bit_width = 8 + + +############################################################################### +# float +############################################################################### + + +@immutable(eq=False) +class Float(_PhysicalNumber): # type: ignore + """Semantic representation of a floating data type.""" - # pylint: disable=comparison-with-callable - if self.value in {"str", "string"}: - dtype = np.dtype("str") - else: - dtype = np.dtype(self.str_alias.lower()) - return dtype - - @property - def is_int(self) -> bool: - """Return True if PandasDtype is an integer.""" - return self.value.lower().startswith("int") - - @property - def is_nullable_int(self) -> bool: - """Return True if PandasDtype is a nullable integer.""" - return self.value.startswith("Int") - - @property - def is_nonnullable_int(self) -> bool: - """Return True if PandasDtype is a non-nullable integer.""" - return self.value.startswith("int") - - @property - def is_uint(self) -> bool: - """Return True if PandasDtype is an unsigned integer.""" - return self.value.lower().startswith("uint") - - @property - def is_nullable_uint(self) -> bool: - """Return True if PandasDtype is a nullable unsigned integer.""" - return self.value.startswith("UInt") - - @property - def is_nonnullable_uint(self) -> bool: - """Return True if PandasDtype is a non-nullable unsigned integer.""" - return self.value.startswith("uint") - - @property - def is_float(self) -> bool: - """Return True if PandasDtype is a float.""" - return self.value.startswith("float") - - @property - def is_complex(self) -> bool: - """Return True if PandasDtype is a complex number.""" - return self.value.startswith("complex") - - @property - def is_bool(self) -> bool: - """Return True if PandasDtype is a boolean.""" - return self is PandasDtype.Bool - - @property - def is_string(self) -> bool: - """Return True if PandasDtype is a string.""" - 
return self in [PandasDtype.String, PandasDtype.STRING] - - @property - def is_category(self) -> bool: - """Return True if PandasDtype is a category.""" - return self is PandasDtype.Category - - @property - def is_datetime(self) -> bool: - """Return True if PandasDtype is a datetime.""" - return self is PandasDtype.DateTime - - @property - def is_timedelta(self) -> bool: - """Return True if PandasDtype is a timedelta.""" - return self is PandasDtype.Timedelta - - @property - def is_object(self) -> bool: - """Return True if PandasDtype is an object.""" - return self is PandasDtype.Object - - @property - def is_continuous(self) -> bool: - """Return True if PandasDtype is a continuous datatype.""" + _base_name = "float" + continuous = True + exact = False + bit_width = 64 + + def check(self, pandera_dtype: DataType) -> bool: return ( - self.is_int - or self.is_uint - or self.is_float - or self.is_complex - or self.is_datetime - or self.is_timedelta + isinstance(pandera_dtype, Float) + and self.bit_width == pandera_dtype.bit_width ) + + def __str__(self) -> str: + if self.__class__ is Float: + return "float" + return super().__str__() + + +@immutable +class Float128(Float): + """Semantic representation of a floating data type stored in 128 bits.""" + + bit_width = 128 + + +@immutable +class Float64(Float128): + """Semantic representation of a floating data type stored in 64 bits.""" + + bit_width = 64 + + +@immutable +class Float32(Float64): + """Semantic representation of a floating data type stored in 32 bits.""" + + bit_width = 32 + + +@immutable +class Float16(Float32): + """Semantic representation of a floating data type stored in 16 bits.""" + + bit_width = 16 + + +############################################################################### +# complex +############################################################################### + + +@immutable(eq=False) +class Complex(_PhysicalNumber): # type: ignore + """Semantic representation of a complex number data type.""" + + _base_name = "complex" + bit_width = 128 + + def check(self, pandera_dtype: DataType) -> bool: + return ( + isinstance(pandera_dtype, Complex) + and self.bit_width == pandera_dtype.bit_width + ) + + def __str__(self) -> str: + if self.__class__ is Complex: + return "complex" + return super().__str__() + + +@immutable +class Complex256(Complex): + """Semantic representation of a complex number data type stored + in 256 bits.""" + + bit_width = 256 + + +@immutable +class Complex128(Complex): + """Semantic representation of a complex number data type stored + in 128 bits.""" + + bit_width = 128 + + +@immutable +class Complex64(Complex128): + """Semantic representation of a complex number data type stored + in 64 bits.""" + + bit_width = 64 + + +############################################################################### +# nominal +############################################################################### + + +@immutable(init=True) +class Category(DataType): # type: ignore + """Semantic representation of a categorical data type.""" + + categories: Optional[Tuple[Any]] = None # tuple to ensure safe hash + ordered: bool = False + + def __init__( + self, categories: Optional[Iterable[Any]] = None, ordered: bool = False + ): + # Define __init__ to avoid exposing pylint errors to end users. 
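+        # The dataclass is frozen, so the attributes below are assigned
+        # via object.__setattr__ rather than plain attribute assignment.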
+        super().__init__()
+        if categories is not None:
+            object.__setattr__(self, "categories", tuple(categories))
+        object.__setattr__(self, "ordered", ordered)
+
+    def check(self, pandera_dtype: "DataType") -> bool:
+        if isinstance(pandera_dtype, Category) and (
+            self.categories is None or pandera_dtype.categories is None
+        ):
+            # Category without categories is a superset of any Category
+            # Allow end-users to not list categories when validating.
+            return True
+
+        return super().check(pandera_dtype)
+
+    def __str__(self) -> str:
+        return "category"
+
+
+@immutable
+class String(DataType):
+    """Semantic representation of a string data type."""
+
+    def __str__(self) -> str:
+        return "string"
+
+
+###############################################################################
+# time
+###############################################################################
+
+
+@immutable
+class Date(DataType):
+    """Semantic representation of a date data type."""
+
+    def __str__(self) -> str:
+        return "date"
+
+
+@immutable
+class Timestamp(Date):
+    """Semantic representation of a timestamp data type."""
+
+    def __str__(self) -> str:
+        return "timestamp"
+
+
+DateTime = Timestamp
+
+
+@immutable
+class Timedelta(DataType):
+    """Semantic representation of a timedelta data type."""
+
+    def __str__(self) -> str:
+        return "timedelta"
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def is_subdtype(
+    arg1: Union[DataType, Type[DataType]],
+    arg2: Union[DataType, Type[DataType]],
+) -> bool:
+    """Return True if ``arg1`` is at or below ``arg2`` in the DataType
+    class hierarchy."""
+    arg1_cls = arg1 if inspect.isclass(arg1) else arg1.__class__
+    arg2_cls = arg2 if inspect.isclass(arg2) else arg2.__class__
+    return issubclass(arg1_cls, arg2_cls)  # type: ignore
+
+
+def is_int(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is an integer."""
+    return is_subdtype(pandera_dtype, Int)
+
+
+def is_uint(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is
+    an unsigned integer."""
+    return is_subdtype(pandera_dtype, UInt)
+
+
+def is_float(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a float."""
+    return is_subdtype(pandera_dtype, Float)
+
+
+def is_complex(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a complex number."""
+    return is_subdtype(pandera_dtype, Complex)
+
+
+def is_numeric(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is numeric."""
+    return is_subdtype(pandera_dtype, _Number)
+
+
+def is_bool(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a boolean."""
+    return is_subdtype(pandera_dtype, Bool)
+
+
+def is_string(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a string."""
+    return is_subdtype(pandera_dtype, String)
+
+
+def is_category(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a category."""
+    return is_subdtype(pandera_dtype, Category)
+
+
+def is_datetime(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if 
:class:`pandera.dtypes.DataType` is a datetime.""" + return is_subdtype(pandera_dtype, DateTime) + + +def is_timedelta(pandera_dtype: Union[DataType, Type[DataType]]) -> bool: + """Return True if :class:`pandera.dtypes.DataType` is a timedelta.""" + return is_subdtype(pandera_dtype, Timedelta) diff --git a/pandera/engines/__init__.py b/pandera/engines/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pandera/engines/engine.py b/pandera/engines/engine.py new file mode 100644 index 000000000..3a6819fda --- /dev/null +++ b/pandera/engines/engine.py @@ -0,0 +1,219 @@ +"""Data types engine interface.""" +# https://github.com/PyCQA/pylint/issues/3268 +# pylint:disable=no-value-for-parameter +import functools +import inspect +from abc import ABCMeta +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + get_type_hints, +) + +import typing_inspect + +from pandera.dtypes import DataType + +_DataType = TypeVar("_DataType", bound=DataType) +_Engine = TypeVar("_Engine", bound="Engine") +_EngineType = Type[_Engine] + +if TYPE_CHECKING: + + class Dispatch: + """Only used for type annotation.""" + + def __call__(self, data_type: Any, **kwds: Any) -> Any: + pass + + @staticmethod + def register( + data_type: Any, func: Callable[[Any], DataType] + ) -> Callable[[Any], DataType]: + """Register a new implementation for the given cls.""" + + +else: + Dispatch = Callable[[Any], DataType] + + +@dataclass +class _DtypeRegistry: + dispatch: Dispatch + equivalents: Dict[Any, DataType] + + +class Engine(ABCMeta): + """Base Engine metaclass. + + Keep a registry of concrete Engines. + """ + + _registry: Dict["Engine", _DtypeRegistry] = {} + _registered_dtypes: Set[Type[DataType]] + _base_pandera_dtypes: Tuple[Type[DataType]] + + def __new__(cls, name, bases, namespace, **kwargs): + base_pandera_dtypes = kwargs.pop("base_pandera_dtypes") + try: + namespace["_base_pandera_dtypes"] = tuple(base_pandera_dtypes) + except TypeError: + namespace["_base_pandera_dtypes"] = (base_pandera_dtypes,) + + namespace["_registered_dtypes"] = set() + engine = super().__new__(cls, name, bases, namespace, **kwargs) + + @functools.singledispatch + def dtype(data_type: Any) -> DataType: + raise ValueError(f"Data type '{data_type}' not understood") + + cls._registry[engine] = _DtypeRegistry(dispatch=dtype, equivalents={}) + return engine + + def _check_source_dtype(cls, data_type: Any) -> None: + if isinstance(data_type, cls._base_pandera_dtypes) or ( + inspect.isclass(data_type) + and issubclass(data_type, cls._base_pandera_dtypes) + ): + base_names = [ + f"{base.__module__}.{base.__qualname__}" + for base in cls._base_pandera_dtypes + ] + raise ValueError( + f"Subclasses of {base_names} cannot be registered" + f" with {cls.__name__}." + ) + + def _register_from_parametrized_dtype( + cls, + pandera_dtype_cls: Type[DataType], + ) -> None: + method = pandera_dtype_cls.__dict__["from_parametrized_dtype"] + if not isinstance(method, classmethod): + raise ValueError( + f"{pandera_dtype_cls.__name__}.from_parametrized_dtype " + + "must be a classmethod." 
+ ) + func = method.__func__ + annotations = get_type_hints(func).values() + dtype = next(iter(annotations)) # get 1st annotation + # parse typing.Union + dtypes = typing_inspect.get_args(dtype) or [dtype] + + def _method(*args, **kwargs): + return func(pandera_dtype_cls, *args, **kwargs) + + for source_dtype in dtypes: + cls._check_source_dtype(source_dtype) + cls._registry[cls].dispatch.register(source_dtype, _method) + + def _register_equivalents( + cls, pandera_dtype_cls: Type[DataType], *source_dtypes: Any + ) -> None: + pandera_dtype = pandera_dtype_cls() # type: ignore + for source_dtype in source_dtypes: + cls._check_source_dtype(source_dtype) + cls._registry[cls].equivalents[source_dtype] = pandera_dtype + + def register_dtype( + cls: _EngineType, + pandera_dtype_cls: Type[_DataType] = None, + *, + equivalents: Optional[List[Any]] = None, + ) -> Callable: + """Register a Pandera :class:`~pandera.dtypes.DataType` with the engine, + as class decorator. + + :param pandera_dtype: The DataType to register. + :param equivalents: Equivalent scalar data type classes or + non-parametrized data type instances. + + .. note:: + The classmethod ``from_parametrized_dtype`` will also be + registered. See :ref:`here` for more usage details. + + :example: + + >>> import pandera as pa + >>> + >>> class MyDataType(pa.DataType): + ... pass + >>> + >>> class MyEngine( + ... metaclass=pa.engines.engine.Engine, + ... base_pandera_dtypes=MyDataType, + ... ): + ... pass + >>> + >>> @MyEngine.register_dtype(equivalents=[bool]) + ... class MyBool(MyDataType): + ... pass + + """ + + def _wrapper(pandera_dtype_cls: Type[_DataType]) -> Type[_DataType]: + if not inspect.isclass(pandera_dtype_cls): + raise ValueError( + f"{cls.__name__}.register_dtype can only decorate a class," + f" got {pandera_dtype_cls}" + ) + + if equivalents: + cls._register_equivalents(pandera_dtype_cls, *equivalents) + + if "from_parametrized_dtype" in pandera_dtype_cls.__dict__: + cls._register_from_parametrized_dtype(pandera_dtype_cls) + + cls._registered_dtypes.add(pandera_dtype_cls) + return pandera_dtype_cls + + if pandera_dtype_cls: + return _wrapper(pandera_dtype_cls) + + return _wrapper + + def dtype(cls: _EngineType, data_type: Any) -> _DataType: + """Convert input into a Pandera :class:`DataType` object.""" + if isinstance(data_type, cls._base_pandera_dtypes): + return data_type + + if inspect.isclass(data_type) and issubclass( + data_type, cls._base_pandera_dtypes + ): + try: + return data_type() + except (TypeError, AttributeError) as err: + raise TypeError( + f"DataType '{data_type.__name__}' cannot be instantiated: " + f"{err}\n " + + "Usage Tip: Use an instance or a string representation." + ) from err + + registry = cls._registry[cls] + + equivalent_data_type = registry.equivalents.get(data_type) + if equivalent_data_type is not None: + return equivalent_data_type + + try: + return registry.dispatch(data_type) + except (KeyError, ValueError): + raise TypeError( + f"Data type '{data_type}' not understood by {cls.__name__}." 
+ ) from None + + def get_registered_dtypes( # pylint:disable=W1401 + cls, + ) -> List[Type[DataType]]: + """Return the :class:`pandera.dtypes.DataType`\s registered + with this engine.""" + return list(cls._registered_dtypes) diff --git a/pandera/engines/numpy_engine.py b/pandera/engines/numpy_engine.py new file mode 100644 index 000000000..1a6982895 --- /dev/null +++ b/pandera/engines/numpy_engine.py @@ -0,0 +1,361 @@ +"""Numpy engine and data types.""" +# docstrings are inherited +# pylint:disable=missing-class-docstring,too-many-ancestors +import builtins +import dataclasses +import datetime +import inspect +import platform +import warnings +from typing import Any, Dict, List, Union + +import numpy as np + +from .. import dtypes +from ..dtypes import immutable +from . import engine + +WINDOWS_PLATFORM = platform.system() == "Windows" + + +@immutable(init=True) +class DataType(dtypes.DataType): + """Base `DataType` for boxing Numpy data types.""" + + type: np.dtype = dataclasses.field( + default=np.dtype("object"), repr=False, init=False + ) + """Native numpy dtype boxed by the data type.""" + + def __init__(self, dtype: Any): + super().__init__() + object.__setattr__(self, "type", np.dtype(dtype)) + dtype_cls = dtype if inspect.isclass(dtype) else dtype.__class__ + warnings.warn( + f"'{dtype_cls}' support is not guaranteed.\n" + + "Usage Tip: Consider writing a custom " + + "pandera.dtypes.DataType or opening an issue at " + + "https://github.com/pandera-dev/pandera" + ) + + def __post_init__(self): + object.__setattr__(self, "type", np.dtype(self.type)) + + def coerce(self, data_container: np.ndarray) -> np.ndarray: + return data_container.astype(self.type) + + def __str__(self) -> str: + return self.type.name + + def __repr__(self) -> str: + return f"DataType({self})" + + +class Engine( # pylint:disable=too-few-public-methods + metaclass=engine.Engine, base_pandera_dtypes=DataType +): + """Numpy data type engine.""" + + @classmethod + def dtype(cls, data_type: Any) -> dtypes.DataType: + """Convert input into a numpy-compatible + Pandera :class:`~pandera.dtypes.DataType` object.""" + try: + return engine.Engine.dtype(cls, data_type) + except TypeError: + try: + np_dtype = np.dtype(data_type).type + except TypeError: + raise TypeError( + f"data type '{data_type}' not understood by " + f"{cls.__name__}." 
+ ) from None + + try: + return engine.Engine.dtype(cls, np_dtype) + except TypeError: + return DataType(data_type) + + +############################################################################### +# boolean +############################################################################### + + +@Engine.register_dtype( + equivalents=["bool", bool, np.bool_, dtypes.Bool, dtypes.Bool()] +) +@immutable +class Bool(DataType, dtypes.Bool): + type = np.dtype("bool") + + +def _build_number_equivalents( + builtin_name: str, pandera_name: str, sizes: List[int] +) -> Dict[int, List[Union[type, str, np.dtype, dtypes.DataType]]]: + """Return a dict of equivalent builtin, numpy, pandera dtypes + indexed by size in bit_width.""" + builtin_type = getattr(builtins, builtin_name, None) + default_np_dtype = np.dtype(builtin_name) + default_size = int(default_np_dtype.name.replace(builtin_name, "")) + + default_equivalents = [ + # e.g.: np.int64 + np.dtype(builtin_name).type, + # e.g: pandera.dtypes.Int + getattr(dtypes, pandera_name), + ] + if builtin_type: + default_equivalents.append(builtin_type) + + return { + bit_width: list( + set( + ( + # e.g.: numpy.int64 + getattr(np, f"{builtin_name}{bit_width}"), + # e.g.: pandera.dtypes.Int64 + getattr(dtypes, f"{pandera_name}{bit_width}"), + getattr(dtypes, f"{pandera_name}{bit_width}")(), + # e.g.: pandera.dtypes.Int(64) + getattr(dtypes, pandera_name)(), + ) + ) + | set(default_equivalents if bit_width == default_size else []) + ) + for bit_width in sizes + } + + +############################################################################### +# signed integer +############################################################################### + +_int_equivalents = _build_number_equivalents( + builtin_name="int", pandera_name="Int", sizes=[64, 32, 16, 8] +) + + +@Engine.register_dtype(equivalents=_int_equivalents[64]) +@immutable +class Int64(DataType, dtypes.Int64): + + type = np.dtype("int64") + bit_width: int = 64 + + +@Engine.register_dtype(equivalents=_int_equivalents[32]) +@immutable +class Int32(Int64): + type = np.dtype("int32") # type: ignore + bit_width: int = 32 + + +@Engine.register_dtype(equivalents=_int_equivalents[16]) +@immutable +class Int16(Int32): + type = np.dtype("int16") # type: ignore + bit_width: int = 16 + + +@Engine.register_dtype(equivalents=_int_equivalents[8]) +@immutable +class Int8(Int16): + type = np.dtype("int8") # type: ignore + bit_width: int = 8 + + +############################################################################### +# unsigned integer +############################################################################### + +_uint_equivalents = _build_number_equivalents( + builtin_name="uint", + pandera_name="UInt", + sizes=[64, 32, 16, 8], +) + + +@Engine.register_dtype(equivalents=_uint_equivalents[64]) +@immutable +class UInt64(DataType, dtypes.UInt64): + type = np.dtype("uint64") + bit_width: int = 64 + + +@Engine.register_dtype(equivalents=_uint_equivalents[32]) +@immutable +class UInt32(UInt64): + type = np.dtype("uint32") # type: ignore + bit_width: int = 32 + + +@Engine.register_dtype(equivalents=_uint_equivalents[16]) +@immutable +class UInt16(UInt32): + type = np.dtype("uint16") # type: ignore + bit_width: int = 16 + + +@Engine.register_dtype(equivalents=_uint_equivalents[8]) +@immutable +class UInt8(UInt16): + type = np.dtype("uint8") # type: ignore + bit_width: int = 8 + + +############################################################################### +# float 
+############################################################################### + +_float_equivalents = _build_number_equivalents( + builtin_name="float", + pandera_name="Float", + sizes=[64, 32, 16] if WINDOWS_PLATFORM else [128, 64, 32, 16], +) + + +if not WINDOWS_PLATFORM: + # not supported in windows + # https://github.com/winpython/winpython/issues/613 + @Engine.register_dtype(equivalents=_float_equivalents[128]) + @immutable + class Float128(DataType, dtypes.Float128): + type = np.dtype("float128") + bit_width: int = 128 + + @Engine.register_dtype(equivalents=_float_equivalents[64]) + @immutable + class Float64(Float128): + type = np.dtype("float64") + bit_width: int = 64 + + +else: + + @Engine.register_dtype(equivalents=_float_equivalents[64]) + @immutable + class Float64(DataType, dtypes.Float64): # type: ignore + type = np.dtype("float64") + bit_width: int = 64 + + +@Engine.register_dtype(equivalents=_float_equivalents[32]) +@immutable +class Float32(Float64): + type = np.dtype("float32") # type: ignore + bit_width: int = 32 + + +@Engine.register_dtype(equivalents=_float_equivalents[16]) +@immutable +class Float16(Float32): + type = np.dtype("float16") # type: ignore + bit_width: int = 16 + + +############################################################################### +# complex +############################################################################### + +_complex_equivalents = _build_number_equivalents( + builtin_name="complex", + pandera_name="Complex", + sizes=[128, 64] if WINDOWS_PLATFORM else [256, 128, 64], +) + + +if not WINDOWS_PLATFORM: + # not supported in windows + # https://github.com/winpython/winpython/issues/613 + @Engine.register_dtype(equivalents=_complex_equivalents[256]) + @immutable + class Complex256(DataType, dtypes.Complex256): + type = np.dtype("complex256") + bit_width: int = 256 + + @Engine.register_dtype(equivalents=_complex_equivalents[128]) + @immutable + class Complex128(Complex256): + type = np.dtype("complex128") # type: ignore + bit_width: int = 128 + + +else: + + @Engine.register_dtype(equivalents=_complex_equivalents[128]) + @immutable + class Complex128(DataType, dtypes.Complex128): # type: ignore + type = np.dtype("complex128") # type: ignore + bit_width: int = 128 + + +@Engine.register_dtype(equivalents=_complex_equivalents[64]) +@immutable +class Complex64(Complex128): + type = np.dtype("complex64") # type: ignore + bit_width: int = 64 + + +############################################################################### +# string +############################################################################### + + +@Engine.register_dtype(equivalents=["str", "string", str, np.str_]) +@immutable +class String(DataType, dtypes.String): + type = np.dtype("str") + + def coerce(self, data_container: np.ndarray) -> np.ndarray: + data_container = data_container.astype(object) + notna = ~np.isnan(data_container) + data_container[notna] = data_container[notna].astype(str) + return data_container + + def check(self, pandera_dtype: "dtypes.DataType") -> bool: + return isinstance(pandera_dtype, (Object, type(self))) + + +############################################################################### +# object +############################################################################### + + +@Engine.register_dtype(equivalents=["object", "O", object, np.object_]) +@immutable +class Object(DataType): + """Semantic representation of a :class:`numpy.object_`.""" + + type = np.dtype("object") + + 
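With the boolean, numeric, string, and object types registered, `Engine.dtype` can resolve any of their equivalents, and falls back to boxing an unknown dtype in the base `DataType` (emitting the "support is not guaranteed" warning). A doctest-style sketch of the intended resolution, assuming a platform whose default integer is 64-bit::

    >>> import numpy as np
    >>> from pandera.engines import numpy_engine
    >>> numpy_engine.Engine.dtype(np.int64)
    DataType(int64)
    >>> numpy_engine.Engine.dtype(int)
    DataType(int64)
    >>> numpy_engine.Engine.dtype("float32")  # resolved via np.dtype("float32")
    DataType(float32)
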
+###############################################################################
+# time
+###############################################################################
+
+
+@Engine.register_dtype(
+    equivalents=[
+        datetime.datetime,
+        np.datetime64,
+        dtypes.Timestamp,
+        dtypes.Timestamp(),
+    ]
+)
+@immutable
+class DateTime64(DataType, dtypes.Timestamp):
+    type = np.dtype("datetime64")
+
+
+@Engine.register_dtype(
+    equivalents=[
+        datetime.timedelta,
+        np.timedelta64,
+        dtypes.Timedelta,
+        dtypes.Timedelta(),
+    ]
+)
+@immutable
+class Timedelta64(DataType, dtypes.Timedelta):
+    type = np.dtype("timedelta64[ns]")
diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py
new file mode 100644
index 000000000..03b0201b9
--- /dev/null
+++ b/pandera/engines/pandas_engine.py
@@ -0,0 +1,704 @@
+"""Pandas engine and data types."""
+# pylint:disable=too-many-ancestors
+
+# docstrings are inherited
+# pylint:disable=missing-class-docstring
+
+# pylint doesn't know about __init__ generated with dataclass
+# pylint:disable=unexpected-keyword-arg,no-value-for-parameter
+import builtins
+import dataclasses
+import datetime
+import inspect
+import platform
+import warnings
+from enum import Enum
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from packaging import version
+
+from .. import dtypes
+from ..dtypes import immutable
+from . import engine, numpy_engine
+
+
+def pandas_version():
+    """Return the pandas version."""
+
+    return version.parse(pd.__version__)
+
+
+PANDAS_1_3_0_PLUS = pandas_version().release >= (1, 3, 0)
+
+try:
+    from typing import Literal  # type: ignore
+except ImportError:
+    from typing_extensions import Literal  # type: ignore
+
+
+WINDOWS_PLATFORM = platform.system() == "Windows"
+
+PandasObject = Union[pd.Series, pd.Index, pd.DataFrame]
+PandasExtensionType = pd.core.dtypes.base.ExtensionDtype
+PandasDataType = Union[pd.core.dtypes.base.ExtensionDtype, np.dtype, type]
+
+
+def is_extension_dtype(pd_dtype: PandasDataType) -> bool:
+    """Check if a value is a pandas extension type or instance of one."""
+    return isinstance(pd_dtype, PandasExtensionType) or (
+        isinstance(pd_dtype, type)
+        and issubclass(pd_dtype, PandasExtensionType)
+    )
+
+
+@immutable(init=True)
+class DataType(dtypes.DataType):
+    """Base `DataType` for boxing Pandas data types."""
+
+    type: Any = dataclasses.field(repr=False, init=False)
+    """Native pandas dtype boxed by the data type."""
+
+    def __init__(self, dtype: Any):
+        super().__init__()
+        object.__setattr__(self, "type", pd.api.types.pandas_dtype(dtype))
+        dtype_cls = dtype if inspect.isclass(dtype) else dtype.__class__
+        warnings.warn(
+            f"'{dtype_cls}' support is not guaranteed.\n"
+            + "Usage Tip: Consider writing a custom "
+            + "pandera.dtypes.DataType or opening an issue at "
+            + "https://github.com/pandera-dev/pandera"
+        )
+
+    def __post_init__(self):
+        object.__setattr__(self, "type", pd.api.types.pandas_dtype(self.type))
+
+    def coerce(self, data_container: PandasObject) -> PandasObject:
+        return data_container.astype(self.type)
+
+    def check(self, pandera_dtype: dtypes.DataType) -> bool:
+        try:
+            pandera_dtype = Engine.dtype(pandera_dtype)
+        except TypeError:
+            return False
+
+        # attempt to compare the native pandas types when possible, so that
+        # subclasses can inherit this check
+        # (super() compares that the DataType classes are exactly the same)
+        try:
+            return self.type == pandera_dtype.type or super().check(
+                pandera_dtype
+            )
+        except TypeError:
+            return 
super().check(pandera_dtype) + + def __str__(self) -> str: + return str(self.type) + + def __repr__(self) -> str: + return f"DataType({self})" + + +class Engine( # pylint:disable=too-few-public-methods + metaclass=engine.Engine, + base_pandera_dtypes=(DataType, numpy_engine.DataType), +): + """Pandas data type engine.""" + + @classmethod + def dtype(cls, data_type: Any) -> "DataType": + """Convert input into a pandas-compatible + Pandera :class:`~pandera.dtypes.DataType` object.""" + try: + return engine.Engine.dtype(cls, data_type) + except TypeError: + if is_extension_dtype(data_type) and isinstance(data_type, type): + try: + np_or_pd_dtype = data_type() + # Convert to str here because some pandas dtypes allow + # an empty constructor for compatibility but fail on + # str(). e.g: PeriodDtype + str(np_or_pd_dtype.name) + except (TypeError, AttributeError) as err: + raise TypeError( + f" dtype {data_type} cannot be instantiated: {err}\n" + "Usage Tip: Use an instance or a string " + "representation." + ) from None + else: + # let pandas transform any acceptable value + # into a numpy or pandas dtype. + np_or_pd_dtype = pd.api.types.pandas_dtype(data_type) + if isinstance(np_or_pd_dtype, np.dtype): + np_or_pd_dtype = np_or_pd_dtype.type + + try: + return engine.Engine.dtype(cls, np_or_pd_dtype) + except TypeError: + return DataType(np_or_pd_dtype) + + @classmethod + def numpy_dtype(cls, pandera_dtype: dtypes.DataType) -> np.dtype: + """Convert a Pandera :class:`~pandera.dtypes.DataType + to a :class:`numpy.dtype`.""" + pandera_dtype = engine.Engine.dtype(cls, pandera_dtype) + + alias = str(pandera_dtype).lower() + if alias == "boolean": + alias = "bool" + elif alias.startswith("string"): + alias = "str" + return np.dtype(alias) + + +############################################################################### +# boolean +############################################################################### + + +Engine.register_dtype( + numpy_engine.Bool, + equivalents=["bool", bool, np.bool_, dtypes.Bool, dtypes.Bool()], +) + + +@Engine.register_dtype( + equivalents=["boolean", pd.BooleanDtype, pd.BooleanDtype()], +) +@immutable +class BOOL(DataType, dtypes.Bool): + """Semantic representation of a :class:`pandas.BooleanDtype`.""" + + type = pd.BooleanDtype() + + +############################################################################### +# number +############################################################################### + + +def _register_numpy_numbers( + builtin_name: str, pandera_name: str, sizes: List[int] +) -> None: + """Register pandera.engines.numpy_engine DataTypes + with the pandas engine.""" + + builtin_type = getattr(builtins, builtin_name, None) # uint doesn't exist + + # default to int64 regardless of OS + default_pd_dtype = { + "int": np.dtype("int64"), + "uint": np.dtype("uint64"), + }.get(builtin_name, pd.Series([1], dtype=builtin_name).dtype) + + for bit_width in sizes: + # e.g.: numpy.int64 + np_dtype = getattr(np, f"{builtin_name}{bit_width}") + + equivalents = set( + ( + np_dtype, + # e.g.: pandera.dtypes.Int64 + getattr(dtypes, f"{pandera_name}{bit_width}"), + getattr(dtypes, f"{pandera_name}{bit_width}")(), + ) + ) + + if np_dtype == default_pd_dtype: + equivalents |= set( + ( + default_pd_dtype, + builtin_name, + getattr(dtypes, pandera_name), + getattr(dtypes, pandera_name)(), + ) + ) + if builtin_type: + equivalents.add(builtin_type) + + # results from pd.api.types.infer_dtype + if builtin_type is float: + equivalents.add("floating") + 
equivalents.add("mixed-integer-float") + elif builtin_type is int: + equivalents.add("integer") + + numpy_data_type = getattr(numpy_engine, f"{pandera_name}{bit_width}") + Engine.register_dtype(numpy_data_type, equivalents=list(equivalents)) + + +############################################################################### +# signed integer +############################################################################### + +_register_numpy_numbers( + builtin_name="int", + pandera_name="Int", + sizes=[64, 32, 16, 8], +) + + +@Engine.register_dtype(equivalents=[pd.Int64Dtype, pd.Int64Dtype()]) +@immutable +class INT64(DataType, dtypes.Int): + """Semantic representation of a :class:`pandas.Int64Dtype`.""" + + type = pd.Int64Dtype() + bit_width: int = 64 + + +@Engine.register_dtype(equivalents=[pd.Int32Dtype, pd.Int32Dtype()]) +@immutable +class INT32(INT64): + """Semantic representation of a :class:`pandas.Int32Dtype`.""" + + type = pd.Int32Dtype() + bit_width: int = 32 + + +@Engine.register_dtype(equivalents=[pd.Int16Dtype, pd.Int16Dtype()]) +@immutable +class INT16(INT32): + """Semantic representation of a :class:`pandas.Int16Dtype`.""" + + type = pd.Int16Dtype() + bit_width: int = 16 + + +@Engine.register_dtype(equivalents=[pd.Int8Dtype, pd.Int8Dtype()]) +@immutable +class INT8(INT16): + """Semantic representation of a :class:`pandas.Int8Dtype`.""" + + type = pd.Int8Dtype() + bit_width: int = 8 + + +############################################################################### +# unsigned integer +############################################################################### + +_register_numpy_numbers( + builtin_name="uint", + pandera_name="UInt", + sizes=[64, 32, 16, 8], +) + + +@Engine.register_dtype(equivalents=[pd.UInt64Dtype, pd.UInt64Dtype()]) +@immutable +class UINT64(DataType, dtypes.UInt): + """Semantic representation of a :class:`pandas.UInt64Dtype`.""" + + type = pd.UInt64Dtype() + bit_width: int = 64 + + +@Engine.register_dtype(equivalents=[pd.UInt32Dtype, pd.UInt32Dtype()]) +@immutable +class UINT32(UINT64): + """Semantic representation of a :class:`pandas.UInt32Dtype`.""" + + type = pd.UInt32Dtype() + bit_width: int = 32 + + +@Engine.register_dtype(equivalents=[pd.UInt16Dtype, pd.UInt16Dtype()]) +@immutable +class UINT16(UINT32): + """Semantic representation of a :class:`pandas.UInt16Dtype`.""" + + type = pd.UInt16Dtype() + bit_width: int = 16 + + +@Engine.register_dtype(equivalents=[pd.UInt8Dtype, pd.UInt8Dtype()]) +@immutable +class UINT8(UINT16): + """Semantic representation of a :class:`pandas.UInt8Dtype`.""" + + type = pd.UInt8Dtype() + bit_width: int = 8 + + +# ############################################################################### +# # float +# ############################################################################### + +_register_numpy_numbers( + builtin_name="float", + pandera_name="Float", + sizes=[64, 32, 16] if WINDOWS_PLATFORM else [128, 64, 32, 16], +) + +# ############################################################################### +# # complex +# ############################################################################### + +_register_numpy_numbers( + builtin_name="complex", + pandera_name="Complex", + sizes=[128, 64] if WINDOWS_PLATFORM else [256, 128, 64], +) + +# ############################################################################### +# # nominal +# ############################################################################### + + +@Engine.register_dtype( + equivalents=[ + "category", + "categorical", + dtypes.Category, + 
pd.CategoricalDtype, + ] +) +@immutable(init=True) +class Category(DataType, dtypes.Category): + """Semantic representation of a :class:`pandas.CategoricalDtype`.""" + + type: pd.CategoricalDtype = dataclasses.field(default=None, init=False) + + def __init__( # pylint:disable=super-init-not-called + self, + categories: Optional[Iterable[Any]] = None, + ordered: bool = False, + ) -> None: + dtypes.Category.__init__(self, categories, ordered) + object.__setattr__( + self, + "type", + pd.CategoricalDtype(self.categories, self.ordered), + ) + + @classmethod + def from_parametrized_dtype( + cls, cat: Union[dtypes.Category, pd.CategoricalDtype] + ): + """Convert a categorical to + a Pandera :class:`pandera.dtypes.pandas_engine.Category`.""" + return cls( # type: ignore + categories=cat.categories, ordered=cat.ordered + ) + + +if PANDAS_1_3_0_PLUS: + + @Engine.register_dtype(equivalents=["string", pd.StringDtype]) + @immutable(init=True) + class STRING(DataType, dtypes.String): + """Semantic representation of a :class:`pandas.StringDtype`.""" + + type: pd.StringDtype = dataclasses.field(default=None, init=False) + storage: Optional[Literal["python", "pyarrow"]] = "python" + + def __post_init__(self): + if PANDAS_1_3_0_PLUS: + type_ = pd.StringDtype(self.storage) + else: + type_ = pd.StringDtype() + object.__setattr__(self, "type", type_) + + @classmethod + def from_parametrized_dtype(cls, pd_dtype: pd.StringDtype): + """Convert a :class:`pandas.StringDtype` to + a Pandera :class:`pandera.engines.pandas_engine.STRING`.""" + return cls(pd_dtype.storage) + + def __str__(self) -> str: + return repr(self.type) + + +else: + + @Engine.register_dtype( + equivalents=["string", pd.StringDtype, pd.StringDtype()] + ) # type: ignore + @immutable + class STRING(DataType, dtypes.String): # type: ignore + """Semantic representation of a :class:`pandas.StringDtype`.""" + + type = pd.StringDtype() + + +@Engine.register_dtype( + equivalents=["str", str, dtypes.String, dtypes.String(), np.str_] +) +@immutable +class NpString(numpy_engine.String): + """Specializes numpy_engine.String.coerce to handle pd.NA values.""" + + def coerce(self, data_container: PandasObject) -> np.ndarray: + # Convert to object first to avoid + # TypeError: object cannot be converted to an IntegerDtype + data_container = data_container.astype(object) + return data_container.where( + data_container.isna(), data_container.astype(str) + ) + + def check(self, pandera_dtype: dtypes.DataType) -> bool: + return isinstance(pandera_dtype, (numpy_engine.Object, type(self))) + + +Engine.register_dtype( + numpy_engine.Object, + equivalents=[ + "object", + "O", + "bytes", + "decimal", + "mixed-integer", + "mixed", + object, + np.object_, + ], +) + +# ############################################################################### +# # time +# ############################################################################### + + +_PandasDatetime = Union[np.datetime64, pd.DatetimeTZDtype] + + +@Engine.register_dtype( + equivalents=[ + "time", + "datetime", + "datetime64", + datetime.datetime, + np.datetime64, + dtypes.Timestamp, + dtypes.Timestamp(), + pd.Timestamp, + ] +) +@immutable(init=True) +class DateTime(DataType, dtypes.Timestamp): + type: Optional[_PandasDatetime] = dataclasses.field( + default=None, init=False + ) + unit: str = "ns" + tz: Optional[datetime.tzinfo] = None + to_datetime_kwargs: Dict[str, Any] = dataclasses.field( + default_factory=dict, compare=False, repr=False + ) + + def __post_init__(self): + if self.tz is None: + type_ = 
np.dtype("datetime64[ns]") + else: + type_ = pd.DatetimeTZDtype(self.unit, self.tz) + # DatetimeTZDtype converted tz to tzinfo for us + object.__setattr__(self, "tz", type_.tz) + + object.__setattr__(self, "type", type_) + + def coerce(self, data_container: PandasObject) -> PandasObject: + def _to_datetime(col: pd.Series) -> pd.Series: + col = pd.to_datetime(col, **self.to_datetime_kwargs) + return col.astype(self.type) + + if isinstance(data_container, pd.DataFrame): + # pd.to_datetime transforms a df input into a series. + # We actually want to coerce every columns. + return data_container.transform(_to_datetime) + return _to_datetime(data_container) + + @classmethod + def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype): + """Convert a :class:`pandas.DatetimeTZDtype` to + a Pandera :class:`pandera.engines.pandas_engine.DateTime`.""" + return cls(unit=pd_dtype.unit, tz=pd_dtype.tz) # type: ignore + + def __str__(self) -> str: + if self.type == np.dtype("datetime64[ns]"): + return "datetime64[ns]" + return str(self.type) + + +Engine.register_dtype( + numpy_engine.Timedelta64, + equivalents=[ + "timedelta", + "timedelta64", + datetime.timedelta, + np.timedelta64, + pd.Timedelta, + dtypes.Timedelta, + dtypes.Timedelta(), + ], +) + + +@Engine.register_dtype +@immutable(init=True) +class Period(DataType): + """Representation of pandas :class:`pd.Period`.""" + + type: pd.PeriodDtype = dataclasses.field(default=None, init=False) + freq: Union[str, pd.tseries.offsets.DateOffset] + + def __post_init__(self): + object.__setattr__(self, "type", pd.PeriodDtype(freq=self.freq)) + + @classmethod + def from_parametrized_dtype(cls, pd_dtype: pd.PeriodDtype): + """Convert a :class:`pandas.PeriodDtype` to + a Pandera :class:`pandera.engines.pandas_engine.Period`.""" + return cls(freq=pd_dtype.freq) # type: ignore + + +# ############################################################################### +# # misc +# ############################################################################### + + +@Engine.register_dtype(equivalents=[pd.SparseDtype]) +@immutable(init=True) +class Sparse(DataType): + """Representation of pandas :class:`pd.SparseDtype`.""" + + type: pd.SparseDtype = dataclasses.field(default=None, init=False) + dtype: PandasDataType = np.float_ + fill_value: Any = np.nan + + def __post_init__(self): + object.__setattr__( + self, + "type", + pd.SparseDtype(dtype=self.dtype, fill_value=self.fill_value), + ) + + @classmethod + def from_parametrized_dtype(cls, pd_dtype: pd.SparseDtype): + """Convert a :class:`pandas.SparseDtype` to + a Pandera :class:`pandera.engines.pandas_engine.Sparse`.""" + return cls( # type: ignore + dtype=pd_dtype.subtype, fill_value=pd_dtype.fill_value + ) + + +@Engine.register_dtype +@immutable(init=True) +class Interval(DataType): + """Representation of pandas :class:`pd.IntervalDtype`.""" + + type: pd.IntervalDtype = dataclasses.field(default=None, init=False) + subtype: Union[str, np.dtype] + + def __post_init__(self): + object.__setattr__( + self, "type", pd.IntervalDtype(subtype=self.subtype) + ) + + @classmethod + def from_parametrized_dtype(cls, pd_dtype: pd.IntervalDtype): + """Convert a :class:`pandas.IntervalDtype` to + a Pandera :class:`pandera.engines.pandas_engine.Interval`.""" + return cls(subtype=pd_dtype.subtype) # type: ignore + + +class PandasDtype(Enum): + # pylint: disable=line-too-long,invalid-name + """Enumerate all valid pandas data types. + + This class simply enumerates the valid numpy dtypes for pandas arrays. 
+ For convenience ``PandasDtype`` enums can all be accessed in the top-level + ``pandera`` name space via the same enum name. + + .. warning:: + + This class is deprecated and will be removed in pandera v0.9.0. Use + python types, pandas type string aliases, numpy dtypes, or pandas + dtypes instead. See :ref:`dtypes` for details. + + :examples: + + >>> import pandas as pd + >>> import pandera as pa + >>> + >>> + >>> pa.SeriesSchema(pa.PandasDtype.Int).validate(pd.Series([1, 2, 3])) + 0 1 + 1 2 + 2 3 + dtype: int64 + >>> pa.SeriesSchema(pa.PandasDtype.Float).validate(pd.Series([1.1, 2.3, 3.4])) + 0 1.1 + 1 2.3 + 2 3.4 + dtype: float64 + >>> pa.SeriesSchema(pa.PandasDtype.String).validate(pd.Series(["a", "b", "c"])) + 0 a + 1 b + 2 c + dtype: object + + """ + + # numpy data types + Bool = "bool" #: ``"bool"`` numpy dtype + DateTime = "datetime64" #: ``"datetime64[ns]"`` numpy dtype + Timedelta = "timedelta64" #: ``"timedelta64[ns]"`` numpy dtype + Float = "float" #: ``"float"`` numpy dtype + Float16 = "float16" #: ``"float16"`` numpy dtype + Float32 = "float32" #: ``"float32"`` numpy dtype + Float64 = "float64" #: ``"float64"`` numpy dtype + Int = "int" #: ``"int"`` numpy dtype + Int8 = "int8" #: ``"int8"`` numpy dtype + Int16 = "int16" #: ``"int16"`` numpy dtype + Int32 = "int32" #: ``"int32"`` numpy dtype + Int64 = "int64" #: ``"int64"`` numpy dtype + UInt8 = "uint8" #: ``"uint8"`` numpy dtype + UInt16 = "uint16" #: ``"uint16"`` numpy dtype + UInt32 = "uint32" #: ``"uint32"`` numpy dtype + UInt64 = "uint64" #: ``"uint64"`` numpy dtype + Object = "object" #: ``"object"`` numpy dtype + Complex = "complex" #: ``"complex"`` numpy dtype + Complex64 = "complex64" #: ``"complex64"`` numpy dtype + Complex128 = "complex128" #: ``"complex128"`` numpy dtype + Complex256 = "complex256" #: ``"complex256"`` numpy dtype + + # pandas data types + Category = "category" #: pandas ``"categorical"`` datatype + INT8 = "Int8" #: ``"Int8"`` pandas dtype: pandas 0.24.0+ + INT16 = "Int16" #: ``"Int16"`` pandas dtype: pandas 0.24.0+ + INT32 = "Int32" #: ``"Int32"`` pandas dtype: pandas 0.24.0+ + INT64 = "Int64" #: ``"Int64"`` pandas dtype: pandas 0.24.0+ + UINT8 = "UInt8" #: ``"UInt8"`` pandas dtype: pandas 0.24.0+ + UINT16 = "UInt16" #: ``"UInt16"`` pandas dtype: pandas 0.24.0+ + UINT32 = "UInt32" #: ``"UInt32"`` pandas dtype: pandas 0.24.0+ + UINT64 = "UInt64" #: ``"UInt64"`` pandas dtype: pandas 0.24.0+ + String = "str" #: ``"str"`` numpy dtype + + #: ``"string"`` pandas dtypes: pandas 1.0.0+. For <1.0.0, this enum will + #: fall back on the str-as-object-array representation. + STRING = "string" + + +# NOTE: This is a hack to raise a deprecation warning for users who are +# still using the PandasDtype enum. +# pylint:disable=invalid-name +class __PandasDtype__: + def __init__(self): + self.pandas_dtypes = PandasDtype + + def __getattr__(self, name): + warnings.warn( + "The PandasDtype class is deprecated and will be removed in " + "pandera v0.9.0. Use python types, pandas type string aliases, " + "numpy dtypes, or pandas dtypes instead.", + DeprecationWarning, + ) + return Engine.dtype(getattr(self.pandas_dtypes, name).value) + + def __iter__(self): + for k in self.pandas_dtypes: + yield k.name + + +_PandasDtype = __PandasDtype__() diff --git a/pandera/io.py b/pandera/io.py index 91953b2c2..5f228460a 100644 --- a/pandera/io.py +++ b/pandera/io.py @@ -10,8 +10,9 @@ import pandera.errors +from . 
import dtypes from .checks import Check -from .dtypes import PandasDtype +from .engines import pandas_engine from .schema_components import Column from .schema_statistics import get_dataframe_schema_statistics from .schemas import DataFrameSchema @@ -29,18 +30,20 @@ ) from exc -SCHEMA_TYPES = {"dataframe"} DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" -NOT_JSON_SERIALIZABLE = {PandasDtype.DateTime, PandasDtype.Timedelta} -def _serialize_check_stats(check_stats, pandas_dtype=None): +def _get_qualified_name(cls: type) -> str: + return f"{cls.__module__}.{cls.__qualname__}" + + +def _serialize_check_stats(check_stats, dtype=None): """Serialize check statistics into json/yaml-compatible format.""" def handle_stat_dtype(stat): - if pandas_dtype == PandasDtype.DateTime: + if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype): return stat.strftime(DATETIME_FORMAT) - elif pandas_dtype == PandasDtype.Timedelta: + elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype): # serialize to int in nanoseconds return stat.delta @@ -83,15 +86,15 @@ def _serialize_component_stats(component_stats): serialized_checks = {} for check_name, check_stats in component_stats["checks"].items(): serialized_checks[check_name] = _serialize_check_stats( - check_stats, component_stats["pandas_dtype"] + check_stats, component_stats["dtype"] ) - pandas_dtype = component_stats.get("pandas_dtype") - if pandas_dtype: - pandas_dtype = pandas_dtype.value + dtype = component_stats.get("dtype") + if dtype: + dtype = str(dtype) return { - "pandas_dtype": pandas_dtype, + "dtype": dtype, "nullable": component_stats["nullable"], "checks": serialized_checks, **{ @@ -141,11 +144,11 @@ def _serialize_schema(dataframe_schema): } -def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype=None): +def _deserialize_check_stats(check, serialized_check_stats, dtype=None): def handle_stat_dtype(stat): - if pandas_dtype == PandasDtype.DateTime: + if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype): return pd.to_datetime(stat, format=DATETIME_FORMAT) - elif pandas_dtype == PandasDtype.Timedelta: + elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype): # serialize to int in nanoseconds return pd.to_timedelta(stat, unit="ns") return stat @@ -162,20 +165,20 @@ def handle_stat_dtype(stat): def _deserialize_component_stats(serialized_component_stats): - pandas_dtype = serialized_component_stats.get("pandas_dtype") - if pandas_dtype: - pandas_dtype = PandasDtype.from_str_alias(pandas_dtype) + dtype = serialized_component_stats.get("dtype") + if dtype: + dtype = pandas_engine.Engine.dtype(dtype) checks = serialized_component_stats.get("checks") if checks is not None: checks = [ _deserialize_check_stats( - getattr(Check, check_name), check_stats, pandas_dtype + getattr(Check, check_name), check_stats, dtype ) for check_name, check_stats in checks.items() ] return { - "pandas_dtype": pandas_dtype, + "dtype": dtype, "checks": checks, **{ key: serialized_component_stats.get(key) @@ -280,7 +283,7 @@ def _write_yaml(obj, stream): SCRIPT_TEMPLATE = """ from pandera import ( - DataFrameSchema, Column, Check, Index, MultiIndex, PandasDtype + DataFrameSchema, Column, Check, Index, MultiIndex ) schema = DataFrameSchema( @@ -294,7 +297,7 @@ def _write_yaml(obj, stream): COLUMN_TEMPLATE = """ Column( - pandas_dtype={pandas_dtype}, + dtype={dtype}, checks={checks}, nullable={nullable}, allow_duplicates={allow_duplicates}, @@ -305,7 +308,7 @@ def _write_yaml(obj, stream): """ INDEX_TEMPLATE = ( - 
"Index(pandas_dtype={pandas_dtype},checks={checks}," + "Index(dtype={dtype},checks={checks}," "nullable={nullable},coerce={coerce},name={name})" ) @@ -336,8 +339,9 @@ def _format_checks(checks_dict): def _format_index(index_statistics): index = [] for properties in index_statistics: + dtype = properties.get("dtype") index_code = INDEX_TEMPLATE.format( - pandas_dtype=f"PandasDtype.{properties['pandas_dtype'].name}", + dtype=f"{_get_qualified_name(dtype.__class__)}", checks=( "None" if properties["checks"] is None @@ -376,12 +380,10 @@ def to_script(dataframe_schema, path_or_buf=None): columns = {} for colname, properties in statistics["columns"].items(): - pandas_dtype = properties.get("pandas_dtype") + dtype = properties.get("dtype") column_code = COLUMN_TEMPLATE.format( - pandas_dtype=( - None - if pandas_dtype is None - else f"PandasDtype.{properties['pandas_dtype'].name}" + dtype=( + None if dtype is None else _get_qualified_name(dtype.__class__) ), checks=_format_checks(properties["checks"]), nullable=properties["nullable"], @@ -444,9 +446,9 @@ def __init__(self, field, primary_keys) -> None: self.type = field.get("type", "string") @property - def pandas_dtype(self) -> str: + def dtype(self) -> str: """Determine what type of field this is, so we can feed that into - :class:`~pandera.dtypes.PandasDtype`. If no type is specified in the + :class:`~pandera.dtypes.DataType`. If no type is specified in the frictionless schema, we default to string values. :returns: the pandas-compatible representation of this field type as a @@ -562,7 +564,7 @@ def to_pandera_column(self) -> Dict: "checks": self.checks, "coerce": self.coerce, "nullable": self.nullable, - "pandas_dtype": self.pandas_dtype, + "dtype": self.dtype, "required": self.required, "name": self.name, "regex": self.regex, diff --git a/pandera/model.py b/pandera/model.py index 17c2e280b..676854ca0 100644 --- a/pandera/model.py +++ b/pandera/model.py @@ -220,8 +220,8 @@ def to_schema(cls) -> DataFrameSchema: ordered=cls.__config__.ordered, ) if cls not in MODEL_CACHE: - MODEL_CACHE[cls] = cls.__schema__ - return cls.__schema__ + MODEL_CACHE[cls] = cls.__schema__ # type: ignore + return cls.__schema__ # type: ignore @classmethod def to_yaml(cls, stream: Optional[os.PathLike] = None): diff --git a/pandera/schema_components.py b/pandera/schema_components.py index bdfc27113..8d7fc4f7e 100644 --- a/pandera/schema_components.py +++ b/pandera/schema_components.py @@ -9,6 +9,7 @@ from . import errors from . import strategies as st +from .deprecations import deprecate_pandas_dtype from .error_handlers import SchemaErrorHandler from .schemas import ( CheckList, @@ -26,9 +27,10 @@ def _is_valid_multiindex_tuple_str(x: Tuple[Any, ...]) -> bool: class Column(SeriesSchemaBase): """Validate types and properties of DataFrame columns.""" + @deprecate_pandas_dtype def __init__( self, - pandas_dtype: PandasDtypeInputTypes = None, + dtype: PandasDtypeInputTypes = None, checks: CheckList = None, nullable: bool = False, allow_duplicates: bool = True, @@ -36,10 +38,11 @@ def __init__( required: bool = True, name: Union[str, Tuple[str, ...], None] = None, regex: bool = False, + pandas_dtype: PandasDtypeInputTypes = None, ) -> None: """Create column validator object. - :param pandas_dtype: datatype of the column. A ``PandasDtype`` for + :param dtype: datatype of the column. A ``PandasDtype`` for type-checking dataframe. 
If a string is specified, then assumes one of the valid pandas string values: http://pandas.pydata.org/pandas-docs/stable/basics.html#dtypes @@ -54,6 +57,10 @@ def __init__( :param name: column name in dataframe to validate. :param regex: whether the ``name`` attribute should be treated as a regex pattern to apply to multiple columns in a dataframe. + :param pandas_dtype: alias of ``dtype`` for backwards compatibility. + + .. warning:: This option will be deprecated in 0.8.0 + :raises SchemaInitError: if impossible to build schema from parameters :example: @@ -74,7 +81,13 @@ def __init__( See :ref:`here` for more usage details. """ super().__init__( - pandas_dtype, checks, nullable, allow_duplicates, coerce + dtype, + checks, + nullable, + allow_duplicates, + coerce, + name, + pandas_dtype, ) if ( name is not None @@ -103,7 +116,7 @@ def _allow_groupby(self) -> bool: def properties(self) -> Dict[str, Any]: """Get column properties.""" return { - "pandas_dtype": self._pandas_dtype, + "dtype": self.dtype, "checks": self._checks, "nullable": self._nullable, "allow_duplicates": self._allow_duplicates, @@ -264,7 +277,7 @@ def strategy(self, *, size=None): def strategy_component(self): """Generate column data object for use by DataFrame strategy.""" return st.column_strategy( - self.pdtype, + self.dtype, checks=self.checks, allow_duplicates=self.allow_duplicates, name=self.name, @@ -293,6 +306,9 @@ def example(self, size=None) -> pd.DataFrame: ) def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + def _compare_dict(obj): return { k: v if k != "_checks" else set(v) @@ -351,9 +367,9 @@ def validate( check_obj.index = self.coerce_dtype(check_obj.index) # handles case where pandas native string type is not supported # by index. - obj_to_validate = pd.Series( - check_obj.index, name=check_obj.index.name - ).astype(self.dtype) + obj_to_validate = self.dtype.coerce( + pd.Series(check_obj.index, name=check_obj.index.name) + ) else: obj_to_validate = pd.Series( check_obj.index, name=check_obj.index.name @@ -381,7 +397,7 @@ def strategy(self, *, size: int = None): :returns: index strategy. """ return st.index_strategy( - self.pdtype, # type: ignore + self.dtype, # type: ignore checks=self.checks, nullable=self.nullable, allow_duplicates=self.allow_duplicates, @@ -393,7 +409,7 @@ def strategy(self, *, size: int = None): def strategy_component(self): """Generate column data object for use by MultiIndex strategy.""" return st.column_strategy( - self.pdtype, + self.dtype, checks=self.checks, allow_duplicates=self.allow_duplicates, name=self.name, @@ -439,7 +455,7 @@ def __init__( :param indexes: list of Index validators for each level of the MultiIndex index. :param coerce: Whether or not to coerce the MultiIndex to the - specified pandas_dtypes before validation + specified dtypes before validation :param strict: whether or not to accept columns in the MultiIndex that aren't defined in the ``indexes`` argument. :param name: name of schema component @@ -496,7 +512,7 @@ def __init__( "component is not ordered." ) columns[i if index.name is None else index.name] = Column( - pandas_dtype=index._pandas_dtype, + dtype=index._dtype, checks=index.checks, nullable=index._nullable, allow_duplicates=index._allow_duplicates, @@ -525,7 +541,7 @@ def coerce(self, value: bool) -> None: self._coerce = value def coerce_dtype(self, obj: pd.MultiIndex) -> pd.MultiIndex: - """Coerce type of a pd.Series by type specified in pandas_dtype. 
+ """Coerce type of a pd.Series by type specified in dtype. :param obj: multi-index to coerce. :returns: ``MultiIndex`` with coerced data type diff --git a/pandera/schema_inference.py b/pandera/schema_inference.py index 2493fe220..348518068 100644 --- a/pandera/schema_inference.py +++ b/pandera/schema_inference.py @@ -36,7 +36,7 @@ def infer_schema( def _create_index(index_statistics): index = [ Index( - properties["pandas_dtype"], + properties["dtype"], checks=parse_check_statistics(properties["checks"]), nullable=properties["nullable"], name=properties["name"], @@ -58,11 +58,10 @@ def infer_dataframe_schema(df: pd.DataFrame) -> DataFrameSchema: :returns: DataFrameSchema """ df_statistics = infer_dataframe_statistics(df) - schema = DataFrameSchema( columns={ colname: Column( - properties["pandas_dtype"], + properties["dtype"], checks=parse_check_statistics(properties["checks"]), nullable=properties["nullable"], ) @@ -83,7 +82,7 @@ def infer_series_schema(series) -> SeriesSchema: """ series_statistics = infer_series_statistics(series) schema = SeriesSchema( - pandas_dtype=series_statistics["pandas_dtype"], + dtype=series_statistics["dtype"], checks=parse_check_statistics(series_statistics["checks"]), nullable=series_statistics["nullable"], name=series_statistics["name"], diff --git a/pandera/schema_statistics.py b/pandera/schema_statistics.py index 1d8bdf1d7..22cf80368 100644 --- a/pandera/schema_statistics.py +++ b/pandera/schema_statistics.py @@ -1,31 +1,12 @@ """Module for inferring the statistics of pandas objects.""" - import warnings from typing import Any, Dict, Union import pandas as pd +from . import dtypes from .checks import Check -from .dtypes import PandasDtype - -NUMERIC_DTYPES = frozenset( - [ - PandasDtype.Float, - PandasDtype.Float16, - PandasDtype.Float32, - PandasDtype.Float64, - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, - PandasDtype.DateTime, - ] -) +from .engines import pandas_engine def infer_dataframe_statistics(df: pd.DataFrame) -> Dict[str, Any]: @@ -34,7 +15,7 @@ def infer_dataframe_statistics(df: pd.DataFrame) -> Dict[str, Any]: inferred_column_dtypes = {col: _get_array_type(df[col]) for col in df} column_statistics = { col: { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(nullable_columns[col]), "checks": _get_array_check_statistics(df[col], dtype), } @@ -50,7 +31,7 @@ def infer_series_statistics(series: pd.Series) -> Dict[str, Any]: """Infer column and index statistics from a pandas Series.""" dtype = _get_array_type(series) return { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(series.isna().any()), "checks": _get_array_check_statistics(series, dtype), "name": series.name, @@ -63,7 +44,7 @@ def infer_index_statistics(index: Union[pd.Index, pd.MultiIndex]): def _index_stats(index_level): dtype = _get_array_type(index_level) return { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(index_level.isna().any()), "checks": _get_array_check_statistics(index_level, dtype), "name": index_level.name, @@ -105,7 +86,7 @@ def get_dataframe_schema_statistics(dataframe_schema): statistics = { "columns": { col_name: { - "pandas_dtype": column.pdtype, + "dtype": column.dtype, "nullable": column.nullable, "allow_duplicates": column.allow_duplicates, "coerce": column.coerce, @@ -128,7 +109,7 @@ def get_dataframe_schema_statistics(dataframe_schema): def 
_get_series_base_schema_statistics(series_schema_base): return { - "pandas_dtype": series_schema_base._pandas_dtype, + "dtype": series_schema_base.dtype, "nullable": series_schema_base.nullable, "checks": parse_checks(series_schema_base.checks), "coerce": series_schema_base.coerce, @@ -194,30 +175,31 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]: def _get_array_type(x): # get most granular type possible - dtype = PandasDtype.from_str_alias(str(x.dtype)) + + data_type = pandas_engine.Engine.dtype(x.dtype) # for object arrays, try to infer dtype - if dtype is PandasDtype.Object: - dtype = PandasDtype.from_pandas_api_type( - pd.api.types.infer_dtype(x, skipna=True) - ) - return dtype + if data_type is pandas_engine.Engine.dtype("object"): + inferred_alias = pd.api.types.infer_dtype(x, skipna=True) + if inferred_alias != "string": + data_type = pandas_engine.Engine.dtype(inferred_alias) + return data_type def _get_array_check_statistics( - x, dtype: PandasDtype + x, data_type: dtypes.DataType ) -> Union[Dict[str, Any], None]: """Get check statistics from an array-like object.""" - if dtype is PandasDtype.DateTime: + if dtypes.is_datetime(data_type): check_stats = { "greater_than_or_equal_to": x.min(), "less_than_or_equal_to": x.max(), } - elif dtype in NUMERIC_DTYPES: + elif dtypes.is_numeric(data_type) and not dtypes.is_bool(data_type): check_stats = { "greater_than_or_equal_to": float(x.min()), "less_than_or_equal_to": float(x.max()), } - elif dtype is PandasDtype.Category: + elif dtypes.is_category(data_type): try: categories = x.cat.categories except AttributeError: diff --git a/pandera/schemas.py b/pandera/schemas.py index 8935d4180..da7018403 100644 --- a/pandera/schemas.py +++ b/pandera/schemas.py @@ -12,10 +12,12 @@ import numpy as np import pandas as pd -from . import constants, dtypes, errors +from . import constants, errors from . import strategies as st from .checks import Check -from .dtypes import PandasDtype, PandasExtensionType, is_extension_dtype +from .deprecations import deprecate_pandas_dtype +from .dtypes import DataType +from .engines import pandas_engine from .error_formatters import ( format_generic_error_message, format_vectorized_error_message, @@ -34,8 +36,8 @@ PandasDtypeInputTypes = Union[ str, type, - PandasDtype, - PandasExtensionType, + DataType, + pd.core.dtypes.base.ExtensionDtype, np.dtype, None, ] @@ -63,17 +65,19 @@ def _wrapper(schema, *args, **kwargs): class DataFrameSchema: # pylint: disable=too-many-public-methods """A light-weight pandas DataFrame validator.""" + @deprecate_pandas_dtype def __init__( self, columns: Optional[Dict[Any, Any]] = None, checks: CheckList = None, index=None, - pandas_dtype: PandasDtypeInputTypes = None, - transformer: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + dtype: PandasDtypeInputTypes = None, + transformer: Callable = None, coerce: bool = False, strict: Union[bool, str] = False, name: Optional[str] = None, ordered: bool = False, + pandas_dtype: PandasDtypeInputTypes = None, ) -> None: """Initialize DataFrameSchema validator. @@ -83,7 +87,7 @@ def __init__( :type columns: mapping of column names and column schema component. :param checks: dataframe-wide checks. :param index: specify the datatypes and properties of the index. - :param pandas_dtype: datatype of the dataframe. This overrides the data + :param dtype: datatype of the dataframe. This overrides the data types specified in any of the columns. 
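# Sketch (not part of the diff) of the dtype-introspection helpers used by
# the statistics code above; is_datetime / is_numeric / is_bool live in
# pandera.dtypes per this PR.
from pandera import dtypes
from pandera.engines import pandas_engine

dt = pandas_engine.Engine.dtype("datetime64[ns]")
assert dtypes.is_datetime(dt)
num = pandas_engine.Engine.dtype("int64")
assert dtypes.is_numeric(num) and not dtypes.is_bool(num)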
If a string is specified, then assumes one of the valid pandas string values: http://pandas.pydata.org/pandas-docs/stable/basics.html#dtypes. @@ -101,8 +105,13 @@ def __init__( are not present in the dataframe, will throw an error. :param name: name of the schema. :param ordered: whether or not to validate the columns order. + :param pandas_dtype: alias of ``dtype`` for backwards compatibility. + + .. warning:: This option will be deprecated in 0.8.0 :raises SchemaInitError: if impossible to build schema from parameters + :raises SchemaInitError: if ``dtype`` and ``pandas_dtype`` are both + supplied. :examples: @@ -166,7 +175,7 @@ def __init__( self.index = index self.strict = strict self.name = name - self._pandas_dtype = pandas_dtype + self.dtype = dtype or pandas_dtype # type: ignore self._coerce = coerce self._ordered = ordered self._validate_schema() @@ -236,10 +245,12 @@ def _set_column_handler(column, column_name): } @property - def dtype(self) -> Dict[str, str]: + def dtypes(self) -> Dict[str, DataType]: + # pylint:disable=anomalous-backslash-in-string """ - A pandas style dtype dict where the keys are column names and values - are pandas dtype for the column. Excludes columns where regex=True. + A dict where the keys are column names and values are + :class:`~pandera.dtypes.DataType` s for the column. Excludes columns + where `regex=True`. :returns: dictionary of columns and their associated dtypes. """ @@ -249,13 +260,13 @@ def dtype(self) -> Dict[str, str]: if regex_columns: warnings.warn( "Schema has columns specified as regex column names: %s " - "Use the `get_dtype` to get the datatypes for these " + "Use the `get_dtypes` to get the datatypes for these " "columns." % regex_columns, UserWarning, ) return {n: c.dtype for n, c in self.columns.items() if not c.regex} - def get_dtype(self, dataframe: pd.DataFrame) -> Dict[str, str]: + def get_dtypes(self, dataframe: pd.DataFrame) -> Dict[str, str]: """ Same as the ``dtype`` property, but expands columns where ``regex == True`` based on the supplied dataframe. @@ -277,60 +288,44 @@ def get_dtype(self, dataframe: pd.DataFrame) -> Dict[str, str]: } @property - def pandas_dtype( + def dtype( self, - ) -> Union[str, dtypes.PandasDtype, dtypes.PandasExtensionType]: - """Get the pandas dtype property.""" - return self._pandas_dtype + ) -> DataType: + """Get the dtype property.""" + return self._dtype # type: ignore - @pandas_dtype.setter - def pandas_dtype( - self, value: Union[str, dtypes.PandasDtype, dtypes.PandasExtensionType] - ) -> None: + @dtype.setter + def dtype(self, value: PandasDtypeInputTypes) -> None: """Set the pandas dtype property.""" - self._pandas_dtype = value - self.dtype # pylint: disable=pointless-statement - - @property - def pdtype(self) -> Optional[PandasDtype]: - """PandasDtype of the dataframe.""" - pandas_dtype = PandasDtype.get_str_dtype(self.pandas_dtype) - if pandas_dtype is None: - return pandas_dtype - return PandasDtype.from_str_alias(pandas_dtype) + self._dtype = pandas_engine.Engine.dtype(value) if value else None def _coerce_dtype(self, obj: pd.DataFrame) -> pd.DataFrame: - if self.pandas_dtype is dtypes.PandasDtype.String: - # only coerce non-null elements to string - return obj.where(obj.isna(), obj.astype(str)) - - if self.pdtype is None: + if self.dtype is None: raise ValueError( - "pandas_dtype argument is None. Must specify this argument " + "dtype argument is None. 
Must specify this argument " "to coerce dtype" ) + try: - return obj.astype(self.pdtype.str_alias) - except (ValueError, TypeError) as exc: - msg = ( - f"Error while coercing '{self.name}' to type {self.dtype}: " - f"{exc}" - ) + return self.dtype.coerce(obj) + except Exception as exc: raise errors.SchemaError( self, obj, - msg, + ( + f"Error while coercing '{self.name}' to type " + f"{self.dtype}: {exc}" + ), failure_cases=scalar_failure_case(str(obj.dtypes.to_dict())), - check=f"coerce_dtype('{self.pdtype.str_alias}')", + check=f"coerce_dtype('{self.dtype}')", ) from exc def coerce_dtype(self, obj: pd.DataFrame) -> pd.DataFrame: - """Coerce dataframe to the type specified in pandas_dtype. + """Coerce dataframe to the type specified in dtype. :param obj: dataframe to coerce. :returns: dataframe with coerced dtypes """ - error_handler = SchemaErrorHandler(lazy=True) def _try_coercion(coerce_fn, obj): @@ -354,14 +349,14 @@ def _try_coercion(coerce_fn, obj): ) elif ( (col_schema.coerce or self.coerce) - and self.pdtype is None + and self.dtype is None and colname in obj ): obj.loc[:, colname] = _try_coercion( col_schema.coerce_dtype, obj[colname] ) - if self.pdtype is not None: + if self.dtype is not None: obj = _try_coercion(self._coerce_dtype, obj) if self.index is not None and (self.index.coerce or self.coerce): index_schema = copy.deepcopy(self.index) @@ -571,10 +566,10 @@ def validate( if ( col.required or col_name in check_obj ) and col_name not in lazy_exclude_columns: - if self.pdtype is not None: + if self.dtype is not None: # override column dtype with dataframe dtype col = copy.deepcopy(col) - col.pandas_dtype = self.pdtype + col.dtype = self.dtype schema_components.append(col) if self.index is not None: @@ -656,17 +651,13 @@ def __call__( def __repr__(self) -> str: """Represent string for logging.""" - if isinstance(self._pandas_dtype, PandasDtype): - dtype = self._pandas_dtype.value - else: - dtype = self._pandas_dtype return ( f" - 'probability': - 'even_number': + 'category': + 'probability': + 'even_number': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -857,11 +843,11 @@ def remove_columns(self, cols_to_remove: List[str]) -> "DataFrameSchema": >>> print(example_schema.remove_columns(["category"])) + 'probability': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -912,17 +898,17 @@ def update_column(self, column_name: str, **kwargs) -> "DataFrameSchema": ... }) >>> print( ... example_schema.update_column( - ... 'category', pandas_dtype=pa.Category + ... 'category', dtype=pa.Category ... ) ... ) - 'probability': + 'category': + 'probability': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -972,17 +958,17 @@ def update_columns( >>> >>> print( ... example_schema.update_columns( - ... {"category": {"pandas_dtype":pa.Category}} + ... {"category": {"dtype":pa.Category}} ... ) ... ) - 'probability': + 'category': + 'probability': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -1059,12 +1045,12 @@ def rename_columns(self, rename_dict: Dict[str, str]) -> "DataFrameSchema": ... 
) - 'probabilities': + 'categories': + 'probabilities': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -1135,11 +1121,11 @@ def select_columns(self, columns: List[Any]) -> "DataFrameSchema": >>> print(example_schema.select_columns(['category'])) + 'category': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -1235,12 +1221,12 @@ def set_index( >>> print(example_schema.set_index(['category'])) + 'probability': }, checks=[], coerce=False, - pandas_dtype=None, - index=, + dtype=None, + index=, strict=False name=None, ordered=False @@ -1255,21 +1241,21 @@ def set_index( ... "column1": pa.Column(pa.String), ... "column2": pa.Column(pa.Int) ... }, - ... index=pa.Index(name = "column3", pandas_dtype = pa.Int) + ... index=pa.Index(name = "column3", dtype = pa.Int) ... ) >>> >>> print(example_schema.set_index(["column2"], append = True)) + 'column1': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index= - + + ] coerce=False, strict=False, @@ -1315,7 +1301,7 @@ def set_index( for col in keys_temp: ind_list.append( Index( - pandas_dtype=new_schema.columns[col].pandas_dtype, + dtype=new_schema.columns[col].dtype, name=col, checks=new_schema.columns[col].checks, nullable=new_schema.columns[col].nullable, @@ -1359,18 +1345,18 @@ def reset_index( >>> >>> example_schema = pa.DataFrameSchema( ... {"probability" : pa.Column(pa.Float)}, - ... index = pa.Index(name="unique_id", pandas_dtype=pa.Int) + ... index = pa.Index(name="unique_id", dtype=pa.Int) ... ) >>> >>> print(example_schema.reset_index()) - 'unique_id': + 'probability': + 'unique_id': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -1386,21 +1372,21 @@ def reset_index( >>> example_schema = pa.DataFrameSchema({ ... "category" : pa.Column(pa.String)}, ... index = pa.MultiIndex([ - ... pa.Index(name = "unique_id1", pandas_dtype = pa.Int), - ... pa.Index(name = "unique_id2", pandas_dtype = pa.String) + ... pa.Index(name = "unique_id1", dtype = pa.Int), + ... pa.Index(name = "unique_id2", dtype = pa.String) ... ] ... ) ... ) >>> print(example_schema.reset_index(level = ["unique_id1"])) - 'unique_id1': + 'category': + 'unique_id1': }, checks=[], coerce=False, - pandas_dtype=None, - index=, + dtype=None, + index=, strict=False name=None, ordered=False @@ -1447,9 +1433,7 @@ def reset_index( new_index if new_index is None else Index( - pandas_dtype=new_index.columns[ - list(new_index.columns)[0] - ].pandas_dtype, + dtype=new_index.columns[list(new_index.columns)[0]].dtype, checks=new_index.columns[list(new_index.columns)[0]].checks, nullable=new_index.columns[ list(new_index.columns)[0] @@ -1475,7 +1459,7 @@ def reset_index( new_schema = new_schema.add_columns( { k: Column( - pandas_dtype=v.dtype, + dtype=v.dtype, checks=v.checks, nullable=v.nullable, allow_duplicates=v.allow_duplicates, @@ -1494,18 +1478,20 @@ def reset_index( class SeriesSchemaBase: """Base series validator object.""" + @deprecate_pandas_dtype def __init__( self, - pandas_dtype: PandasDtypeInputTypes = None, + dtype: PandasDtypeInputTypes = None, checks: CheckList = None, nullable: bool = False, allow_duplicates: bool = True, coerce: bool = False, name: Any = None, + pandas_dtype: PandasDtypeInputTypes = None, ) -> None: """Initialize series schema base object. - :param pandas_dtype: datatype of the column. If a string is specified, + :param dtype: datatype of the column. 
If a string is specified, then assumes one of the valid pandas string values: http://pandas.pydata.org/pandas-docs/stable/basics.html#dtypes :param checks: If element_wise is True, then callable signature should @@ -1514,18 +1500,23 @@ def __init__( ``Callable[Any, bool]`` where the ``Any`` input is a scalar element in the column. Otherwise, the input is assumed to be a pandas.Series object. - :type checks: callable :param nullable: Whether or not column can contain null values. - :type nullable: bool - :param allow_duplicates: - :type allow_duplicates: bool + :param allow_duplicates: Whether or not column can contain duplicate + values. + :param coerce: If True, when schema.validate is called the column will + be coerced into the specified dtype. This has no effect on columns + where ``dtype=None``. + :param name: column name in dataframe to validate. + :param pandas_dtype: alias of ``dtype`` for backwards compatibility. + + .. warning:: This option will be deprecated in 0.8.0 + """ if checks is None: checks = [] if isinstance(checks, (Check, Hypothesis)): checks = [checks] - - self._pandas_dtype = pandas_dtype + self.dtype = dtype or pandas_dtype # type: ignore self._nullable = nullable self._allow_duplicates = allow_duplicates self._coerce = coerce @@ -1600,58 +1591,29 @@ def name(self) -> Union[str, None]: return self._name @property - def pandas_dtype( + def dtype( self, - ) -> Union[str, dtypes.PandasDtype, dtypes.PandasExtensionType]: + ) -> DataType: """Get the pandas dtype""" - return self._pandas_dtype + return self._dtype # type: ignore - @pandas_dtype.setter - def pandas_dtype( - self, value: Union[str, dtypes.PandasDtype, dtypes.PandasExtensionType] - ) -> None: + @dtype.setter + def dtype(self, value: PandasDtypeInputTypes) -> None: """Set the pandas dtype""" - self._pandas_dtype = value - self.dtype # pylint: disable=pointless-statement - - @property - def dtype(self) -> Optional[str]: - """String representation of the dtype.""" - return PandasDtype.get_str_dtype(self._pandas_dtype) - - @property - def pdtype(self) -> Optional[PandasDtype]: - """PandasDtype of the series.""" - if self.dtype is None: - return None - if isinstance(self.pandas_dtype, PandasDtype): - return self.pandas_dtype - return PandasDtype.from_str_alias(self.dtype) + self._dtype = pandas_engine.Engine.dtype(value) if value else None def coerce_dtype(self, obj: Union[pd.Series, pd.Index]) -> pd.Series: - """Coerce type of a pd.Series by type specified in pandas_dtype. + """Coerce type of a pd.Series by type specified in dtype. :param pd.Series series: One-dimensional ndarray with axis labels (including time series). 
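# Sketch (not part of the diff): coercion now delegates to DataType.coerce.
# For str columns, the NpString type defined earlier keeps nulls as-is
# instead of producing the literal string "nan".
import pandas as pd
import pandera as pa

schema = pa.SeriesSchema(str, coerce=True, nullable=True)
out = schema.validate(pd.Series([1, None, 3]))
# -> object series ["1.0", NaN, "3.0"]; the null survives coercion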
:returns: ``Series`` with coerced data type """ - if self._pandas_dtype is None: + if self.dtype is None: return obj - elif ( - self._pandas_dtype is PandasDtype.String - or self._pandas_dtype is str - or self._pandas_dtype == "str" - and self._pandas_dtype is not PandasDtype.Object - ): - # only coerce non-null elements to string, make sure series is of - # object dtype - return obj.astype(object).where( - obj.isna(), - obj.astype(str), - ) try: - return obj.astype(self.dtype) + return self.dtype.coerce(obj) except (ValueError, TypeError) as exc: msg = ( f"Error while coercing '{self.name}' to type " @@ -1748,33 +1710,7 @@ def validate( ), ) - series_dtype = series.dtype - if self._nullable: - series_no_nans = series.dropna() - if self.dtype in dtypes.NUMPY_NONNULLABLE_INT_DTYPES: - _series = series_no_nans.astype(self.dtype) - series_dtype = _series.dtype - if (_series != series_no_nans).any(): - # in case where dtype is meant to be int, make sure that - # casting to int results in equal values. - msg = ( - "after dropping null values, expected values in " - "series '%s' to be int, found: %s" - % (series.name, set(series)) - ) - error_handler.collect_error( - "unexpected_nullable_integer_type", - errors.SchemaError( - self, - check_obj, - msg, - failure_cases=reshape_failure_cases( - series_no_nans - ), - check="nullable_integer", - ), - ) - else: + if not self._nullable: nulls = series.isna() if sum(nulls) > 0: msg = "non-nullable series '%s' contains null values: %s" % ( @@ -1817,25 +1753,21 @@ def validate( ), ) - if is_extension_dtype(self._pandas_dtype): - target_dtype = PandasDtype.get_dtype(self._pandas_dtype) - else: - series_dtype = str(series_dtype) - target_dtype = self.dtype - if self._pandas_dtype is not None and series_dtype != target_dtype: - msg = "expected series '%s' to have type %s, got %s" % ( - series.name, - repr(target_dtype), - repr(series_dtype), + if self._dtype is not None and ( + not self._dtype.check(pandas_engine.Engine.dtype(series.dtype)) + ): + msg = ( + f"expected series '{series.name}' to have type {self._dtype}, " + + f"got {series.dtype}" ) error_handler.collect_error( - "wrong_pandas_dtype", + "wrong_dtype", errors.SchemaError( self, check_obj, msg, - failure_cases=scalar_failure_case(str(series_dtype)), - check=f"pandas_dtype('{self.dtype}')", + failure_cases=scalar_failure_case(str(series.dtype)), + check=f"dtype('{self.dtype}')", ), ) @@ -1906,7 +1838,7 @@ def strategy(self, *, size=None): :returns: a strategy that generates pandas Series objects. """ return st.series_strategy( - self.pdtype, + self.dtype, checks=self.checks, nullable=self.nullable, allow_duplicates=self.allow_duplicates, @@ -1931,32 +1863,30 @@ def example(self, size=None) -> pd.Series: return self.strategy(size=size).example() def __repr__(self): - if isinstance(self._pandas_dtype, PandasDtype): - dtype = self._pandas_dtype.value - else: - dtype = self._pandas_dtype return ( f"" + f"(name={self._name}, type={self.dtype!r})>" ) class SeriesSchema(SeriesSchemaBase): """Series validator.""" + @deprecate_pandas_dtype def __init__( self, - pandas_dtype: PandasDtypeInputTypes = None, + dtype: PandasDtypeInputTypes = None, checks: CheckList = None, index=None, nullable: bool = False, allow_duplicates: bool = True, coerce: bool = False, name: str = None, + pandas_dtype: PandasDtypeInputTypes = None, ) -> None: """Initialize series schema base object. - :param pandas_dtype: datatype of the column. If a string is specified, + :param dtype: datatype of the column. 
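# Hedged sketch (not part of the diff) of the new dtype check above: the
# series' actual dtype is wrapped via Engine.dtype and compared with
# DataType.check; mismatches raise a SchemaError tagged "wrong_dtype"
# (message shape approximate).
import pandas as pd
import pandera as pa

schema = pa.SeriesSchema("int64")
try:
    schema.validate(pd.Series([1.0, 2.0]))
except pa.errors.SchemaError as exc:
    print(exc)  # expected series to have type int64, got float64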
If a string is specified, then assumes one of the valid pandas string values: http://pandas.pydata.org/pandas-docs/stable/basics.html#dtypes :param checks: If element_wise is True, then callable signature should @@ -1965,15 +1895,27 @@ def __init__( ``Callable[Any, bool]`` where the ``Any`` input is a scalar element in the column. Otherwise, the input is assumed to be a pandas.Series object. - :type checks: callable :param index: specify the datatypes and properties of the index. :param nullable: Whether or not column can contain null values. - :type nullable: bool - :param allow_duplicates: - :type allow_duplicates: bool + :param allow_duplicates: Whether or not column can contain duplicate + values. + :param coerce: If True, when schema.validate is called the column will + be coerced into the specified dtype. This has no effect on columns + where ``pandas_dtype=None``. + :param name: series name. + :param pandas_dtype: alias of ``dtype`` for backwards compatibility. + + .. warning:: This option will be deprecated in 0.8.0 + """ super().__init__( - pandas_dtype, checks, nullable, allow_duplicates, coerce, name + dtype, + checks, + nullable, + allow_duplicates, + coerce, + name, + pandas_dtype, ) self.index = index diff --git a/pandera/strategies.py b/pandera/strategies.py index bf7684995..152976d77 100644 --- a/pandera/strategies.py +++ b/pandera/strategies.py @@ -10,7 +10,6 @@ See the :ref:`user guide` for more details. """ - import operator import re import warnings @@ -32,7 +31,15 @@ import numpy as np import pandas as pd -from .dtypes import PandasDtype +from .dtypes import ( + DataType, + is_category, + is_complex, + is_datetime, + is_float, + is_timedelta, +) +from .engines import numpy_engine, pandas_engine from .errors import BaseStrategyOnlyError, SchemaDefinitionError try: @@ -139,12 +146,16 @@ def set_pandas_index( return df_or_series -def verify_pandas_dtype(pandas_dtype, schema_type: str, name: Optional[str]): - """Verify that pandas_dtype argument is not None.""" - if pandas_dtype is None: +def verify_dtype( + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], + schema_type: str, + name: Optional[str], +): + """Verify that pandera_dtype argument is not None.""" + if pandera_dtype is None: raise SchemaDefinitionError( f"'{schema_type}' schema with name '{name}' has no specified " - "pandas_dtype. You need to specify one in order to synthesize " + "dtype. You need to specify one in order to synthesize " "data from a strategy." ) @@ -208,7 +219,7 @@ def _wrapper(cls, *args, **kwargs): MAX_DT_VALUE = 2 ** 63 - 1 -def numpy_time_dtypes(dtype, min_value=None, max_value=None): +def numpy_time_dtypes(dtype: np.dtype, min_value=None, max_value=None): """Create numpy strategy for datetime and timedelta data types. :param dtype: numpy datetime or timedelta datatype @@ -293,15 +304,24 @@ def build_complex(draw): return build_complex() +def to_numpy_dtype(pandera_dtype: DataType): + """Convert a :class:`~pandera.dtypes.DataType` to numpy dtype compatible + with hypothesis.""" + np_dtype = pandas_engine.Engine.numpy_dtype(pandera_dtype) + if np_dtype == np.dtype("object"): + np_dtype = np.dtype(str) + return np_dtype + + def pandas_dtype_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: DataType, strategy: Optional[SearchStrategy] = None, **kwargs, ) -> SearchStrategy: # pylint: disable=line-too-long,no-else-raise - """Strategy to generate data from a :class:`pandera.dtypes.PandasDtype`. + """Strategy to generate data from a :class:`pandera.dtypes.DataType`. 
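# Sketch (not part of the diff, names per this patch): to_numpy_dtype
# bridges engine DataTypes and hypothesis, mapping `object` onto str so
# string data can be drawn.
import numpy as np
from pandera import strategies
from pandera.engines import pandas_engine

assert strategies.to_numpy_dtype(
    pandas_engine.Engine.dtype(float)
) == np.dtype("float64")
assert strategies.to_numpy_dtype(
    pandas_engine.Engine.dtype(object)
) == np.dtype(str)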
- :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :kwargs: key-word arguments passed into @@ -317,35 +337,30 @@ def compat_kwargs(*args): # hypothesis doesn't support categoricals or objects, so we'll will need to # build a pandera-specific solution. - if pandas_dtype is PandasDtype.Category: + if is_category(pandera_dtype): raise TypeError( "data generation for the Category dtype is currently " "unsupported. Consider using a string or int dtype and " "Check.isin(values) to ensure a finite set of values." ) - # The object type falls back onto generating strings. - if pandas_dtype is PandasDtype.Object: - dtype = np.dtype("str") - else: - dtype = pandas_dtype.numpy_dtype - + np_dtype = to_numpy_dtype(pandera_dtype) if strategy: - return strategy.map(dtype.type) - elif pandas_dtype.is_datetime or pandas_dtype.is_timedelta: + return strategy.map(np_dtype.type) + elif is_datetime(pandera_dtype) or is_timedelta(pandera_dtype): return numpy_time_dtypes( - dtype, + np_dtype, **compat_kwargs("min_value", "max_value"), ) - elif pandas_dtype.is_complex: + elif is_complex(pandera_dtype): return numpy_complex_dtypes( - dtype, + np_dtype, **compat_kwargs( "min_value", "max_value", "allow_infinity", "allow_nan" ), ) return npst.from_dtype( - dtype, + np_dtype, **{ # type: ignore "allow_nan": False, "allow_infinity": False, @@ -355,14 +370,14 @@ def compat_kwargs(*args): def eq_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, value: Any, ) -> SearchStrategy: """Strategy to generate a single value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param value: value to generate. @@ -370,38 +385,38 @@ def eq_strategy( """ # override strategy preceding this one and generate value of the same type if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) - return st.just(value).map(pandas_dtype.numpy_dtype.type) + strategy = pandas_dtype_strategy(pandera_dtype) + return st.just(value).map(to_numpy_dtype(pandera_dtype).type) def ne_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, value: Any, ) -> SearchStrategy: """Strategy to generate anything except for a particular value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param value: value to avoid. :returns: ``hypothesis`` strategy """ if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) + strategy = pandas_dtype_strategy(pandera_dtype) return strategy.filter(lambda x: x != value) def gt_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values greater than a minimum value. 
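# Hedged usage sketch (not part of the diff): drawing raw values for an
# engine DataType with the renamed pandas_dtype_strategy; requires
# hypothesis to be installed.
from pandera import strategies
from pandera.engines import pandas_engine

strat = strategies.pandas_dtype_strategy(pandas_engine.Engine.dtype(int))
value = strat.example()  # a single int64-compatible value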
- :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values larger than this. @@ -409,22 +424,22 @@ def gt_strategy( """ if strategy is None: strategy = pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, - exclude_min=True if pandas_dtype.is_float else None, + exclude_min=True if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x > min_value) def ge_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values greater than or equal to a minimum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values greater than or equal to this. @@ -432,22 +447,22 @@ def ge_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, - exclude_min=False if pandas_dtype.is_float else None, + exclude_min=False if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x >= min_value) def lt_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, max_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values less than a maximum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param max_value: generate values less than this. @@ -455,22 +470,22 @@ def lt_strategy( """ if strategy is None: strategy = pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, max_value=max_value, - exclude_max=True if pandas_dtype.is_float else None, + exclude_max=True if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x < max_value) def le_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, max_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values less than or equal to a maximum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param max_value: generate values less than or equal to this. 
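# Sketch (not part of the diff): for float dtypes the strict comparison
# strategies exclude the boundary itself (exclude_min/exclude_max), wired
# via dtypes.is_float as shown above.
from pandera import strategies
from pandera.engines import pandas_engine

float_dt = pandas_engine.Engine.dtype(float)
gt_zero = strategies.gt_strategy(float_dt, min_value=0.0)
# every drawn example satisfies x > 0.0, never x == 0.0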
@@ -478,15 +493,15 @@ def le_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, max_value=max_value, - exclude_max=False if pandas_dtype.is_float else None, + exclude_max=False if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x <= max_value) def in_range_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], @@ -496,7 +511,7 @@ def in_range_strategy( ) -> SearchStrategy: """Strategy to generate values within a particular range. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values greater than this. @@ -507,7 +522,7 @@ def in_range_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, max_value=max_value, exclude_min=not include_min, @@ -521,14 +536,14 @@ def in_range_strategy( def isin_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, allowed_values: Sequence[Any], ) -> SearchStrategy: """Strategy to generate values within a finite set. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param allowed_values: set of allowable values. @@ -536,39 +551,39 @@ def isin_strategy( """ if strategy is None: return st.sampled_from(allowed_values).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x in allowed_values) def notin_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, forbidden_values: Sequence[Any], ) -> SearchStrategy: """Strategy to generate values excluding a set of forbidden values - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param forbidden_values: set of forbidden values. :returns: ``hypothesis`` strategy """ if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) + strategy = pandas_dtype_strategy(pandera_dtype) return strategy.filter(lambda x: x not in forbidden_values) def str_matches_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, pattern: str, ) -> SearchStrategy: """Strategy to generate strings that patch a regex pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param pattern: regex pattern. 
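# Hedged sketch (not part of the diff): isin_strategy samples from the
# allowed set and maps the values through the numpy type of the pandera
# dtype, per the hunk above.
from pandera import strategies
from pandera.engines import pandas_engine

int_dt = pandas_engine.Engine.dtype("int64")
members = strategies.isin_strategy(int_dt, allowed_values=[1, 2, 3])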
@@ -576,7 +591,7 @@ def str_matches_strategy( """ if strategy is None: return st.from_regex(pattern, fullmatch=True).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) def matches(x): @@ -586,14 +601,14 @@ def matches(x): def str_contains_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, pattern: str, ) -> SearchStrategy: """Strategy to generate strings that contain a particular pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param pattern: regex pattern. @@ -601,7 +616,7 @@ def str_contains_strategy( """ if strategy is None: return st.from_regex(pattern, fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) def contains(x): @@ -611,14 +626,14 @@ def contains(x): def str_startswith_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, string: str, ) -> SearchStrategy: """Strategy to generate strings that start with a specific string pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param string: string pattern. @@ -626,21 +641,21 @@ def str_startswith_strategy( """ if strategy is None: return st.from_regex(f"\\A{string}", fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x.startswith(string)) def str_endswith_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, string: str, ) -> SearchStrategy: """Strategy to generate strings that end with a specific string pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param string: string pattern. @@ -648,14 +663,14 @@ def str_endswith_strategy( """ if strategy is None: return st.from_regex(f"{string}\\Z", fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x.endswith(string)) def str_length_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: int, @@ -663,7 +678,7 @@ def str_length_strategy( ) -> SearchStrategy: """Strategy to generate strings of a particular length - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: minimum string length. 
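# Sketch (not part of the diff): the string-constraint strategies now take
# an engine DataType and cast generated strings through to_numpy_dtype;
# the pattern here is illustrative.
from pandera import strategies
from pandera.engines import pandas_engine

str_dt = pandas_engine.Engine.dtype(str)
codes = strategies.str_matches_strategy(str_dt, pattern=r"[A-Z]{3}-\d{2}")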
@@ -672,21 +687,21 @@ def str_length_strategy( """ if strategy is None: return st.text(min_size=min_value, max_size=max_value).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: min_value <= len(x) <= max_value) def field_element_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, ) -> SearchStrategy: """Strategy to generate elements of a column or index. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -709,25 +724,25 @@ def undefined_check_strategy(elements, check): "definition. This can considerably slow down data-generation." ) return ( - pandas_dtype_strategy(pandas_dtype) + pandas_dtype_strategy(pandera_dtype) if elements is None else elements ).filter(check._check_fn) for check in checks: if hasattr(check, "strategy"): - elements = check.strategy(pandas_dtype, elements) + elements = check.strategy(pandera_dtype, elements) elif check.element_wise: elements = undefined_check_strategy(elements, check) # NOTE: vectorized checks with undefined strategies should be handled # by the series/dataframe strategy. if elements is None: - elements = pandas_dtype_strategy(pandas_dtype) + elements = pandas_dtype_strategy(pandera_dtype) return elements def series_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -738,7 +753,7 @@ def series_strategy( ): """Strategy to generate a pandas Series. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -750,11 +765,11 @@ def series_strategy( :param size: number of elements in the Series. :returns: ``hypothesis`` strategy. """ - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) strategy = ( pdst.series( elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), index=pdst.range_indexes( min_size=0 if size is None else size, max_size=size ), @@ -762,7 +777,7 @@ def series_strategy( ) .filter(lambda x: x.shape[0] > 0) .map(lambda x: x.rename(name)) - .map(lambda x: x.astype(pandas_dtype.str_alias)) + .map(lambda x: x.astype(str(pandera_dtype))) ) if nullable: strategy = null_field_masks(strategy) @@ -788,7 +803,7 @@ def _check_fn(series): def column_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -798,7 +813,7 @@ def column_strategy( # pylint: disable=line-too-long """Create a data object describing a column in a DataFrame. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. 
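# Hedged end-to-end sketch (not part of the diff): SeriesSchema.example
# drives series_strategy above; the drawn series already carries the
# schema's dtype and checks.
import pandera as pa

schema = pa.SeriesSchema(float, pa.Check.in_range(0, 1), name="p")
sample = schema.example(size=5)  # pd.Series of 5 floats in [0, 1]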
:param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -808,18 +823,18 @@ def column_strategy( :param name: name of the Series. :returns: a `column `_ object. """ - verify_pandas_dtype(pandas_dtype, schema_type="column", name=name) - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + verify_dtype(pandera_dtype, schema_type="column", name=name) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) return pdst.column( name=name, elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), unique=not allow_duplicates, ) def index_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -830,7 +845,7 @@ def index_strategy( ): """Strategy to generate a pandas Index. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -842,15 +857,15 @@ def index_strategy( :param size: number of elements in the Series. :returns: ``hypothesis`` strategy. """ - verify_pandas_dtype(pandas_dtype, schema_type="index", name=name) - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + verify_dtype(pandera_dtype, schema_type="index", name=name) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) strategy = pdst.indexes( elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), min_size=0 if size is None else size, max_size=size, unique=not allow_duplicates, - ).map(lambda x: x.astype(pandas_dtype.str_alias)) + ).map(lambda x: x.astype(str(pandera_dtype))) if name is not None: strategy = strategy.map(lambda index: index.rename(name)) if nullable: @@ -859,7 +874,7 @@ def index_strategy( def dataframe_strategy( - pandas_dtype: Optional[PandasDtype] = None, + pandera_dtype: Optional[DataType] = None, strategy: Optional[SearchStrategy] = None, *, columns: Optional[Dict] = None, @@ -870,7 +885,7 @@ def dataframe_strategy( ): """Strategy to generate a pandas DataFrame. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: if specified, this will raise a BaseStrategyOnlyError, since it cannot be chained to a prior strategy. 
    :param columns: a dictionary where keys are column names and values
@@ -931,18 +946,18 @@ def make_row_strategy(col, checks):
    strategy = None
    for check in checks:
        if hasattr(check, "strategy"):
-            strategy = check.strategy(col.pdtype, strategy)
+            strategy = check.strategy(col.dtype, strategy)
        else:
            strategy = undefined_check_strategy(
                strategy=(
-                    pandas_dtype_strategy(col.pdtype)
+                    pandas_dtype_strategy(col.dtype)
                    if strategy is None
                    else strategy
                ),
                check=check,
            )
    if strategy is None:
-        strategy = pandas_dtype_strategy(col.pdtype)
+        strategy = pandas_dtype_strategy(col.dtype)
    return strategy


@composite
@@ -989,9 +1004,9 @@ def _dataframe_strategy(draw):
        # override the column datatype with dataframe-level datatype if
        # specified
        col_dtypes = {
-            col_name: col.dtype
-            if pandas_dtype is None
-            else pandas_dtype.str_alias
+            col_name: str(col.dtype)
+            if pandera_dtype is None
+            else str(pandera_dtype)
            for col_name, col in expanded_columns.items()
        }

@@ -1042,7 +1057,7 @@ def _dataframe_strategy(draw):

# pylint: disable=unused-argument
def multiindex_strategy(
-    pandas_dtype: Optional[PandasDtype] = None,
+    pandera_dtype: Optional[DataType] = None,
    strategy: Optional[SearchStrategy] = None,
    *,
    indexes: Optional[List] = None,
@@ -1050,7 +1065,7 @@
):
    """Strategy to generate a pandas MultiIndex object.

-    :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance.
+    :param pandera_dtype: :class:`pandera.dtypes.DataType` instance.
    :param strategy: an optional hypothesis strategy. If specified, the
        pandas dtype strategy will be chained onto this strategy.
    :param indexes: a list of :class:`~pandera.schema_components.Index`
@@ -1066,7 +1081,7 @@
    )
    indexes = [] if indexes is None else indexes
    index_dtypes = {
-        index.name if index.name is not None else i: index.dtype
+        index.name if index.name is not None else i: str(index.dtype)
        for i, index in enumerate(indexes)
    }
    nullable_index = {
diff --git a/pandera/typing.py b/pandera/typing.py
index 6d00dc9ab..c662a4b94 100644
--- a/pandera/typing.py
+++ b/pandera/typing.py
@@ -6,52 +6,77 @@
import pandas as pd
import typing_inspect

-from .dtypes import PandasDtype, PandasExtensionType
-
-try:  # python 3.8+
-    from typing import Literal  # type: ignore
-except ImportError:
-    from typing_extensions import Literal  # type: ignore
-
+from . import dtypes
+from .engines import numpy_engine, pandas_engine

LEGACY_TYPING = sys.version_info[:2] < (3, 7)

+Bool = dtypes.Bool  #: ``"bool"`` numpy dtype
+DateTime = dtypes.DateTime  #: ``"datetime64[ns]"`` numpy dtype
+Timedelta = dtypes.Timedelta  #: ``"timedelta64[ns]"`` numpy dtype
+Category = dtypes.Category  #: pandas ``"categorical"`` datatype
+Float = dtypes.Float  #: ``"float"`` numpy dtype
+Float16 = dtypes.Float16  #: ``"float16"`` numpy dtype
+Float32 = dtypes.Float32  #: ``"float32"`` numpy dtype
+Float64 = dtypes.Float64  #: ``"float64"`` numpy dtype
+Int = dtypes.Int  #: ``"int"`` numpy dtype
+Int8 = dtypes.Int8  #: ``"int8"`` numpy dtype
+Int16 = dtypes.Int16  #: ``"int16"`` numpy dtype
+Int32 = dtypes.Int32  #: ``"int32"`` numpy dtype
+Int64 = dtypes.Int64  #: ``"int64"`` numpy dtype
+UInt8 = dtypes.UInt8  #: ``"uint8"`` numpy dtype
+UInt16 = dtypes.UInt16  #: ``"uint16"`` numpy dtype
+UInt32 = dtypes.UInt32  #: ``"uint32"`` numpy dtype
+UInt64 = dtypes.UInt64  #: ``"uint64"`` numpy dtype
+INT8 = pandas_engine.INT8  #: ``"Int8"`` pandas dtype: pandas 0.24.0+
+INT16 = pandas_engine.INT16  #: ``"Int16"`` pandas dtype: pandas 0.24.0+
+INT32 = pandas_engine.INT32  #: ``"Int32"`` pandas dtype: pandas 0.24.0+
+INT64 = pandas_engine.INT64  #: ``"Int64"`` pandas dtype: pandas 0.24.0+
+UINT8 = pandas_engine.UINT8  #: ``"UInt8"`` pandas dtype: pandas 0.24.0+
+UINT16 = pandas_engine.UINT16  #: ``"UInt16"`` pandas dtype: pandas 0.24.0+
+UINT32 = pandas_engine.UINT32  #: ``"UInt32"`` pandas dtype: pandas 0.24.0+
+UINT64 = pandas_engine.UINT64  #: ``"UInt64"`` pandas dtype: pandas 0.24.0+
+Object = numpy_engine.Object  #: ``"object"`` numpy dtype
+String = dtypes.String  #: ``"str"`` numpy dtype
+#: ``"string"`` pandas dtype: pandas 1.0.0+. For pandas <1.0.0, this data
+#: type falls back on the str-as-object-array representation.
+STRING = pandas_engine.STRING  #: ``"string"`` pandas dtype
+
GenericDtype = TypeVar(  # type: ignore
    "GenericDtype",
-    PandasDtype,
-    PandasExtensionType,
    bool,
    int,
    str,
    float,
-    Literal[PandasDtype.Bool],
-    Literal[PandasDtype.DateTime],
-    Literal[PandasDtype.Category],
-    Literal[PandasDtype.Float],
-    Literal[PandasDtype.Float16],
-    Literal[PandasDtype.Float32],
-    Literal[PandasDtype.Float64],
-    Literal[PandasDtype.Int],
-    Literal[PandasDtype.Int8],
-    Literal[PandasDtype.Int16],
-    Literal[PandasDtype.Int32],
-    Literal[PandasDtype.Int64],
-    Literal[PandasDtype.UInt8],
-    Literal[PandasDtype.UInt16],
-    Literal[PandasDtype.UInt32],
-    Literal[PandasDtype.UInt64],
-    Literal[PandasDtype.INT8],
-    Literal[PandasDtype.INT16],
-    Literal[PandasDtype.INT32],
-    Literal[PandasDtype.INT64],
-    Literal[PandasDtype.UINT8],
-    Literal[PandasDtype.UINT16],
-    Literal[PandasDtype.UINT32],
-    Literal[PandasDtype.UINT64],
-    Literal[PandasDtype.Object],
-    Literal[PandasDtype.String],
-    Literal[PandasDtype.STRING],
-    Literal[PandasDtype.Timedelta],
+    pd.core.dtypes.base.ExtensionDtype,
+    Bool,
+    DateTime,
+    Timedelta,
+    Category,
+    Float,
+    Float16,
+    Float32,
+    Float64,
+    Int,
+    Int8,
+    Int16,
+    Int32,
+    Int64,
+    UInt8,
+    UInt16,
+    UInt32,
+    UInt64,
+    INT8,
+    INT16,
+    INT32,
+    INT64,
+    UINT8,
+    UINT16,
+    UINT32,
+    UINT64,
+    Object,
+    String,
+    STRING,
    covariant=True,
)

Schema = TypeVar("Schema", bound="SchemaModel")  # type: ignore
@@ -157,47 +182,3 @@ def _parse_annotation(self, raw_annotation: Type) -> None:
        self.literal = typing_inspect.is_literal_type(self.arg)
        if self.literal:
            self.arg = typing_inspect.get_args(self.arg)[0]
-
-
-Bool = Literal[PandasDtype.Bool]  #: ``"bool"`` numpy dtype
-DateTime = Literal[PandasDtype.DateTime]  #: ``"datetime64[ns]"`` numpy dtype
-Timedelta = Literal[
-    PandasDtype.Timedelta
-]  #: ``"timedelta64[ns]"`` numpy dtype
-Category = Literal[PandasDtype.Category]  #: pandas ``"categorical"`` datatype
-Float = Literal[PandasDtype.Float]  #: ``"float"`` numpy dtype
-Float16 = Literal[PandasDtype.Float16]  #: ``"float16"`` numpy dtype
-Float32 = Literal[PandasDtype.Float32]  #: ``"float32"`` numpy dtype
-Float64 = Literal[PandasDtype.Float64]  #: ``"float64"`` numpy dtype
-Int = Literal[PandasDtype.Int]  #: ``"int"`` numpy dtype
-Int8 = Literal[PandasDtype.Int8]  #: ``"int8"`` numpy dtype
-Int16 = Literal[PandasDtype.Int16]  #: ``"int16"`` numpy dtype
-Int32 = Literal[PandasDtype.Int32]  #: ``"int32"`` numpy dtype
-Int64 = Literal[PandasDtype.Int64]  #: ``"int64"`` numpy dtype
-UInt8 = Literal[PandasDtype.UInt8]  #: ``"uint8"`` numpy dtype
-UInt16 = Literal[PandasDtype.UInt16]  #: ``"uint16"`` numpy dtype
-UInt32 = Literal[PandasDtype.UInt32]  #: ``"uint32"`` numpy dtype
-UInt64 = Literal[PandasDtype.UInt64]  #: ``"uint64"`` numpy dtype
-INT8 = Literal[PandasDtype.INT8]  #: ``"Int8"`` pandas dtype:: pandas 0.24.0+
-INT16 = Literal[PandasDtype.INT16]  #: ``"Int16"`` pandas dtype: pandas 0.24.0+
-INT32 = Literal[PandasDtype.INT32]  #: ``"Int32"`` pandas dtype: pandas 0.24.0+
-INT64 = Literal[PandasDtype.INT64]  #: ``"Int64"`` pandas dtype: pandas 0.24.0+
-UINT8 = Literal[
-    PandasDtype.UINT8
-]  #: ``"UInt8"`` pandas dtype:: pandas 0.24.0+
-UINT16 = Literal[
-    PandasDtype.UINT16
-]  #: ``"UInt16"`` pandas dtype: pandas 0.24.0+
-UINT32 = Literal[
-    PandasDtype.UINT32
-]  #: ``"UInt32"`` pandas dtype: pandas 0.24.0+
-UINT64 = Literal[
-    PandasDtype.UINT64
-]  #: ``"UInt64"`` pandas dtype: pandas 0.24.0+
-Object = Literal[PandasDtype.Object]  #: ``"object"`` numpy dtype
-
-String = Literal[PandasDtype.String]  #: ``"str"`` numpy dtype
``"str"`` numpy dtype - -#: ``"string"`` pandas dtypes: pandas 1.0.0+. For <1.0.0, this enum will -#: fall back on the str-as-object-array representation. -STRING = Literal[PandasDtype.STRING] #: ``"str"`` numpy dtype diff --git a/requirements-dev.txt b/requirements-dev.txt index ad5471ee8..869e6c066 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,24 +12,28 @@ pyyaml >=5.1 typing_inspect >= 0.6.0 typing_extensions >= 3.7.4.3 frictionless +pyarrow black >= 20.8b1 isort >= 5.7.0 codecov -mypy == 0.812 +mypy >= 0.902 pylint >= 2.7.2 pytest pytest-cov pytest-xdist pytest-asyncio +xdoctest setuptools >= 52.0.0 nox == 2020.12.31 importlib_metadata -sphinx == 3.5.4 -sphinx_rtd_theme +sphinx sphinx-autodoc-typehints sphinx-copybutton recommonmark twine asv pre_commit -furo==2021.6.18b36 \ No newline at end of file +furo==2021.6.18b36 +types-click +types-pyyaml +types-pkg_resources \ No newline at end of file diff --git a/setup.py b/setup.py index 86b399ebe..d67ea3108 100644 --- a/setup.py +++ b/setup.py @@ -39,13 +39,15 @@ install_requires=[ "packaging >= 20.0", "numpy >= 1.9.0", - "pandas >= 0.25.3", + "pandas >= 1.0", "typing_extensions >= 3.7.4.3 ; python_version<'3.8'", "typing_inspect >= 0.6.0", "wrapt", + "frictionless", + "pyarrow", ], extras_require=extras_require, - python_requires=">=3.6", + python_requires=">=3.7", platforms="any", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -54,7 +56,6 @@ "Intended Audience :: Science/Research", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/tests/conftest.py b/tests/conftest.py index b904c26af..4a29e47bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,6 @@ if not HAS_HYPOTHESIS: collect_ignore.append("test_strategies.py") else: - settings.register_profile("ci", max_examples=100, deadline=5000) - settings.register_profile("dev", max_examples=10, deadline=2000) + settings.register_profile("ci", max_examples=100, deadline=None) + settings.register_profile("dev", max_examples=10, deadline=None) settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev")) diff --git a/tests/core/test_decorators.py b/tests/core/test_decorators.py index 635bee242..8a2a7005d 100644 --- a/tests/core/test_decorators.py +++ b/tests/core/test_decorators.py @@ -14,7 +14,6 @@ Field, Float, Int, - PandasDtype, SchemaModel, String, check_input, @@ -23,6 +22,7 @@ check_types, errors, ) +from pandera.engines.pandas_engine import Engine from pandera.typing import DataFrame, Index, Series try: @@ -356,7 +356,8 @@ def validate_inplace(df): # invalid out schema types for out_schema in [1, 5.0, "foo", {"foo": "bar"}, ["foo"]]: - @check_io(out=out_schema) # type: ignore[arg-type] # mypy finds correctly the wrong usage + # mypy finds correctly the wrong usage + @check_io(out=out_schema) # type: ignore[arg-type] def invalid_out_schema_type(df): return df @@ -656,8 +657,8 @@ def transform_in(df: DataFrame[InSchema]): return df df = transform_in(pd.DataFrame({"a": ["1"]}, index=["1"])) - expected = InSchema.to_schema().columns["a"].pandas_dtype - assert PandasDtype(str(df["a"].dtype)) == expected == PandasDtype("int") + expected = InSchema.to_schema().columns["a"].dtype + assert Engine.dtype(df["a"].dtype) == expected @check_types() def transform_out() -> DataFrame[OutSchema]: @@ -665,10 +666,8 @@ def transform_out() -> 
DataFrame[OutSchema]: return pd.DataFrame({"b": ["1"]}) out_df = transform_out() - expected = OutSchema.to_schema().columns["b"].pandas_dtype - assert ( - PandasDtype(str(out_df["b"].dtype)) == expected == PandasDtype("int") - ) + expected = OutSchema.to_schema().columns["b"].dtype + assert Engine.dtype(out_df["b"].dtype) == expected @pytest.mark.parametrize( diff --git a/tests/core/test_deprecations.py b/tests/core/test_deprecations.py new file mode 100644 index 000000000..f509cc44b --- /dev/null +++ b/tests/core/test_deprecations.py @@ -0,0 +1,53 @@ +"""Unit tests for deprecated features.""" + +import platform + +import pytest + +import pandera as pa + +WINDOWS_PLATFORM = platform.system() == "Windows" + + +@pytest.mark.parametrize( + "schema_cls,as_pos_arg", + [ + [pa.DataFrameSchema, False], + [pa.SeriesSchema, True], + [pa.Column, True], + [pa.Index, True], + ], +) +def test_deprecate_pandas_dtype(schema_cls, as_pos_arg): + """Test that pandas_dtype deprecation warnings/errors are raised.""" + assert schema_cls(dtype=int).dtype.check(pa.Int()) + assert schema_cls(pandas_dtype=int).dtype.check(pa.Int()) + + with pytest.warns(DeprecationWarning): + schema_cls(pandas_dtype=int) + with pytest.raises(pa.errors.SchemaInitError): + schema_cls(dtype=int, pandas_dtype=int) + + if as_pos_arg: + assert schema_cls(int).dtype.check(pa.Int()) + with pytest.raises(pa.errors.SchemaInitError): + schema_cls(int, pandas_dtype=int) + + +@pytest.mark.parametrize( + "schema_cls", + [ + pa.DataFrameSchema, + pa.SeriesSchema, + pa.Column, + pa.Index, + ], +) +def test_deprecate_pandas_dtype_enum(schema_cls): + """Test that using the PandasDtype enum raises a DeprecationWarning.""" + for attr in pa.PandasDtype: + if WINDOWS_PLATFORM and attr in {"Float128", "Complex256"}: + continue + with pytest.warns(DeprecationWarning): + pandas_dtype = getattr(pa.PandasDtype, attr) + schema_cls(dtype=pandas_dtype) diff --git a/tests/core/test_dtypes.py b/tests/core/test_dtypes.py index 6b8760cb0..03b6be048 100644 --- a/tests/core/test_dtypes.py +++ b/tests/core/test_dtypes.py @@ -1,654 +1,520 @@ """Tests a variety of python and pandas dtypes, and tests some specific coercion examples.""" - +# pylint doesn't know about __init__ generated with dataclass +# pylint:disable=unexpected-keyword-arg,no-value-for-parameter +import dataclasses +import datetime +import inspect import platform -from typing import Callable, List, Type +from decimal import Decimal +from typing import Any, Dict, List, Tuple +import hypothesis import numpy as np import pandas as pd import pytest -from packaging import version +from _pytest.mark.structures import ParameterSet +from _pytest.python import Metafunc +from hypothesis import strategies as st import pandera as pa -from pandera import ( - Bool, - Category, - Check, - Column, - DataFrameSchema, - DateTime, - Float, - Int, - Object, - PandasDtype, - SeriesSchema, - String, - Timedelta, -) -from pandera.dtypes import ( - _DEFAULT_NUMPY_FLOAT_TYPE, - _DEFAULT_NUMPY_INT_TYPE, - _DEFAULT_PANDAS_FLOAT_TYPE, - _DEFAULT_PANDAS_INT_TYPE, - PANDAS_1_3_0_PLUS, -) -from pandera.errors import SchemaError - -PANDAS_VERSION = version.parse(pd.__version__) -WINDOWS = platform.system() == "Windows" - -TESTABLE_DTYPES = [ - (Bool, "bool"), - (DateTime, "datetime64[ns]"), - (Category, "category"), - (Float, Float.str_alias), - (Int, Int.str_alias), - (Object, "object"), - (String, String.str_alias), - (Timedelta, "timedelta64[ns]"), - ("bool", "bool"), - ("datetime64[ns]", "datetime64[ns]"), - ("category", 
"category"), - ("float64", "float64"), -] +from pandera.engines import pandas_engine + +WINDOWS_PLATFORM = platform.system() == "Windows" + +# List dtype classes and associated pandas alias, +# except for parameterizable dtypes that should also list examples of +# instances. +int_dtypes = { + int: "int64", + pa.Int: "int64", + pa.Int8: "int8", + pa.Int16: "int16", + pa.Int32: "int32", + pa.Int64: "int64", + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", +} + + +nullable_int_dtypes = { + pandas_engine.INT8: "Int8", + pandas_engine.INT16: "Int16", + pandas_engine.INT32: "Int32", + pandas_engine.INT64: "Int64", +} + +uint_dtypes = { + pa.UInt: "uint64", + pa.UInt8: "uint8", + pa.UInt16: "uint16", + pa.UInt32: "uint32", + pa.UInt64: "uint64", + np.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", +} + +nullable_uint_dtypes = { + pandas_engine.UINT8: "UInt8", + pandas_engine.UINT16: "UInt16", + pandas_engine.UINT32: "UInt32", + pandas_engine.UINT64: "UInt64", +} + +float_dtypes = { + float: "float", + pa.Float: "float64", + pa.Float16: "float16", + pa.Float32: "float32", + pa.Float64: "float64", + np.float16: "float16", + np.float32: "float32", + np.float64: "float64", +} + + +complex_dtypes = { + complex: "complex", + pa.Complex: "complex128", + pa.Complex64: "complex64", + pa.Complex128: "complex128", +} + + +if not WINDOWS_PLATFORM: + float_dtypes.update( + { + pa.Float128: "float128", + np.float128: "float128", + } + ) + complex_dtypes.update( + { + pa.Complex256: "complex256", + np.complex256: "complex256", + } + ) +boolean_dtypes = {bool: "bool", pa.Bool: "bool", np.bool_: "bool"} +nullable_boolean_dtypes = {pd.BooleanDtype: "boolean", pa.BOOL: "boolean"} -def test_default_numeric_dtypes() -> None: - """Test that default numeric dtypes int and float are consistent.""" - assert str(pd.Series([1]).dtype) == _DEFAULT_PANDAS_INT_TYPE - assert pa.Int.str_alias == _DEFAULT_PANDAS_INT_TYPE - assert str(pd.Series([1], dtype=int).dtype) == _DEFAULT_NUMPY_INT_TYPE - assert str(pd.Series([1], dtype="int").dtype) == _DEFAULT_NUMPY_INT_TYPE +string_dtypes = { + str: "str", + pa.String: "str", + np.str_: "str", +} - assert str(pd.Series([1.0]).dtype) == _DEFAULT_PANDAS_FLOAT_TYPE - assert pa.Float.str_alias == _DEFAULT_PANDAS_FLOAT_TYPE - assert ( - str(pd.Series([1.0], dtype=float).dtype) == _DEFAULT_NUMPY_FLOAT_TYPE - ) - assert ( - str(pd.Series([1.0], dtype="float").dtype) == _DEFAULT_NUMPY_FLOAT_TYPE +nullable_string_dtypes = {pd.StringDtype: "string"} +if pa.PANDAS_1_3_0_PLUS: + nullable_string_dtypes.update( + {pd.StringDtype(storage="pyarrow"): "string[pyarrow]"} ) +object_dtypes = {object: "object", np.object_: "object"} + +category_dtypes = { + pa.Category: "category", + pa.Category(["A", "B"], ordered=True): pd.CategoricalDtype( + ["A", "B"], ordered=True + ), + pd.CategoricalDtype(["A", "B"], ordered=True): pd.CategoricalDtype( + ["A", "B"], ordered=True + ), +} + +timestamp_dtypes = { + datetime.datetime: "datetime64[ns]", + np.datetime64: "datetime64[ns]", + pa.Timestamp: "datetime64[ns]", + pd.DatetimeTZDtype(tz="CET"): "datetime64[ns, CET]", + pandas_engine.DateTime: "datetime64[ns]", + pandas_engine.DateTime(unit="ns", tz="CET"): "datetime64[ns, CET]", # type: ignore +} + +timedelta_dtypes = { + datetime.timedelta: "timedelta64", + datetime.timedelta: "timedelta64", + np.timedelta64: "timedelta64", + pd.Timedelta: "timedelta64", + pa.Timedelta: "timedelta64", +} + +period_dtypes = {pd.PeriodDtype(freq="D"): 
"period[D]"} +# Series.astype does not accept a string alias for SparseDtype. +sparse_dtypes = { + pd.SparseDtype: pd.SparseDtype(), + pd.SparseDtype(np.float64): pd.SparseDtype(np.float64), +} +interval_dtypes = {pd.IntervalDtype(subtype=np.int64): "interval[int64]"} + +dtype_fixtures: List[Tuple[Dict, List]] = [ + (int_dtypes, [-1]), + (nullable_int_dtypes, [-1, None]), + (uint_dtypes, [1]), + (nullable_uint_dtypes, [1, None]), + (float_dtypes, [1.0]), + (complex_dtypes, [complex(1)]), + (boolean_dtypes, [True, False]), + (nullable_boolean_dtypes, [True, None]), + (string_dtypes, ["A", "B"]), + (object_dtypes, ["A", "B"]), + (nullable_string_dtypes, [1, 2, None]), + (category_dtypes, [1, 2, None]), + ( + timestamp_dtypes, + pd.to_datetime(["2019/01/01", "2018/05/21"]).to_series(), + ), + ( + period_dtypes, + pd.to_datetime(["2019/01/01", "2018/05/21"]) + .to_period("D") + .to_series(), + ), + (sparse_dtypes, pd.Series([1, None], dtype=pd.SparseDtype(float))), + (interval_dtypes, pd.interval_range(-10.0, 10.0).to_series()), +] -def test_numeric_dtypes() -> None: - """Test every numeric type can be validated properly by schema.validate""" - for dtype in [pa.Float, pa.Float16, pa.Float32, pa.Float64]: - assert all( - isinstance( - schema.validate( - pd.DataFrame( - {"col": [-123.1, -7654.321, 1.0, 1.1, 1199.51, 5.1]}, - dtype=dtype.str_alias, - ) - ), - pd.DataFrame, - ) - for schema in [ - DataFrameSchema({"col": Column(dtype, nullable=False)}), - DataFrameSchema( - {"col": Column(dtype.str_alias, nullable=False)} - ), - ] - ) - for dtype in [pa.Int, pa.Int8, pa.Int16, pa.Int32, pa.Int64]: - assert all( - isinstance( - schema.validate( - pd.DataFrame( - {"col": [-712, -4, -321, 0, 1, 777, 5, 123, 9000]}, - dtype=dtype.str_alias, - ) - ), - pd.DataFrame, - ) - for schema in [ - DataFrameSchema({"col": Column(dtype, nullable=False)}), - DataFrameSchema( - {"col": Column(dtype.str_alias, nullable=False)} - ), - ] +def pretty_param(*values: Any, **kw: Any) -> ParameterSet: + """Return a pytest parameter with a human-readable id.""" + id_ = kw.pop("id", None) + if not id_: + id_ = "-".join( + f"{val.__module__}.{val.__name__}" + if inspect.isclass(val) + else repr(val) + for val in values ) + return pytest.param(*values, id=id_, **kw) + - for dtype in [pa.UInt8, pa.UInt16, pa.UInt32, pa.UInt64]: - assert all( - isinstance( - schema.validate( - pd.DataFrame( - {"col": [1, 777, 5, 123, 9000]}, dtype=dtype.str_alias +def pytest_generate_tests(metafunc: Metafunc) -> None: + """Inject `dtype`, `data_type` (filter pandera DataTypes), `alias`, `data` + fixtures from `dtype_fixtures`. 
+ """ + fixtures = [ + fixture + for fixture in ("data_type", "dtype", "pd_dtype", "data") + if fixture in metafunc.fixturenames + ] + arg_names = ",".join(fixtures) + + if arg_names: + arg_values = [] + for dtypes, data in dtype_fixtures: + for dtype, pd_dtype in dtypes.items(): + if "data_type" in fixtures and not ( + isinstance(dtype, pa.DataType) + or ( + inspect.isclass(dtype) + and issubclass(dtype, pa.DataType) ) - ), - pd.DataFrame, - ) - for schema in [ - DataFrameSchema({"col": Column(dtype, nullable=False)}), - DataFrameSchema( - {"col": Column(dtype.str_alias, nullable=False)} - ), - ] - ) + ): + # not a pa.DataType class or instance + continue + params = [dtype] + if "pd_dtype" in fixtures: + params.append(pd_dtype) + if "data" in fixtures: + params.append(data) + arg_values.append(pretty_param(*params)) -@pytest.mark.skipif( - PANDAS_VERSION.release < (1, 0, 0), # type: ignore - reason="pandas >= 1.0.0 required", -) -@pytest.mark.parametrize( - "dtype", - [ - pa.INT8, - pa.INT16, - pa.INT32, - pa.INT64, - pa.UINT8, - pa.UINT16, - pa.UINT32, - pa.UINT64, - ], -) -@pytest.mark.parametrize("coerce", [True, False]) -def test_pandas_nullable_int_dtype( - dtype: pa.PandasDtype, coerce: bool -) -> None: - """Test that pandas nullable int dtype can be specified in a schema.""" - assert all( - isinstance( - schema.validate( - pd.DataFrame( - # keep max range to 127 in order to support Int8 - {"col": range(128)}, - **({} if coerce else {"dtype": dtype.str_alias}), - ) - ), - pd.DataFrame, + metafunc.parametrize(arg_names, arg_values) + + +def test_datatype_init(data_type: Any): + """Test that a default pa.DataType can be constructed.""" + if not inspect.isclass(data_type): + pytest.skip( + "test_datatype_init tests pa.DataType classes, not instances." 
        )
-        for schema in [
-            DataFrameSchema(
-                {"col": Column(dtype, nullable=False)}, coerce=coerce
-            ),
-            DataFrameSchema(
-                {"col": Column(dtype.str_alias, nullable=False)}, coerce=coerce
-            ),
-        ]
-    )
+    assert isinstance(data_type(), pa.DataType)
+
+
+def test_datatype_alias(data_type: Any, pd_dtype: Any):
+    """Test that a pa.DataType maps to the expected pandas dtype alias."""
+    assert str(pandas_engine.Engine.dtype(data_type)) == str(pd_dtype)
+
+
+def test_frozen_datatype(data_type: Any):
+    """Test that pa.DataType instances are immutable."""
+    data_type = data_type() if inspect.isclass(data_type) else data_type
+    with pytest.raises(dataclasses.FrozenInstanceError):
+        data_type.foo = 1

-@pytest.mark.parametrize("str_alias", ["foo", "bar", "baz", "asdf", "qwerty"])
-def test_unrecognized_str_aliases(str_alias: str) -> None:
-    """Test that unrecognized string aliases are supported."""
+def test_invalid_pandas_extension_dtype():
+    """Test that an invalid dtype is rejected."""
    with pytest.raises(TypeError):
-        PandasDtype.from_str_alias(str_alias)
-
-
-def test_category_dtype() -> None:
-    """Test the category type can be validated properly by schema.validate"""
-    schema = DataFrameSchema(
-        columns={
-            "col": Column(
-                pa.Category,
-                checks=[
-                    Check(lambda s: set(s) == {"A", "B", "C"}),
-                    Check(
-                        lambda s: s.cat.categories.tolist() == ["A", "B", "C"]
-                    ),
-                    Check(lambda s: s.isin(["A", "B", "C"])),
-                ],
-                nullable=False,
-            ),
-        },
-        coerce=False,
-    )
-    validated_df = schema.validate(
-        pd.DataFrame(
-            {"col": pd.Series(["A", "B", "A", "B", "C"], dtype="category")}
-        )
+        pandas_engine.Engine.dtype(
+            pd.PeriodDtype
+        )  # PeriodDtype has required parameters
+
+
+def test_check_equivalent(dtype: Any, pd_dtype: Any):
+    """Test that a pandas-compatible dtype can be validated by check()."""
+    actual_dtype = pandas_engine.Engine.dtype(pd_dtype)
+    expected_dtype = pandas_engine.Engine.dtype(dtype)
+    assert actual_dtype.check(expected_dtype)
+
+
+def test_check_not_equivalent(dtype: Any):
+    """Test that check() rejects non-equivalent dtypes."""
+    if str(pandas_engine.Engine.dtype(dtype)) == "object":
+        actual_dtype = pandas_engine.Engine.dtype(int)
+    else:
+        actual_dtype = pandas_engine.Engine.dtype(object)
+    expected_dtype = pandas_engine.Engine.dtype(dtype)
+    assert actual_dtype.check(expected_dtype) is False
+
+
+def test_coerce_no_cast(dtype: Any, pd_dtype: Any, data: List[Any]):
+    """Test that dtypes can be coerced without casting."""
+    expected_dtype = pandas_engine.Engine.dtype(dtype)
+    series = pd.Series(data, dtype=pd_dtype)
+    coerced_series = expected_dtype.coerce(series)
+
+    assert series.equals(coerced_series)
+    assert expected_dtype.check(
+        pandas_engine.Engine.dtype(coerced_series.dtype)
+    )

-    assert isinstance(validated_df, pd.DataFrame)
-
-
-def test_category_dtype_coerce() -> None:
-    """Test coercion of the category type is validated properly by
-    schema.validate and fails safely."""
-    columns = {
-        "col": Column(
-            pa.Category,
-            checks=Check(lambda s: set(s) == {"A", "B", "C"}),
-            nullable=False,
-        ),
-    }
-
-    with pytest.raises(SchemaError):
-        DataFrameSchema(columns=columns, coerce=False).validate(
-            pd.DataFrame(
-                {"col": pd.Series(["A", "B", "A", "B", "C"], dtype="object")}
-            )
-        )
-    validated_df = DataFrameSchema(columns=columns, coerce=True).validate(
-        pd.DataFrame(
-            {"col": pd.Series(["A", "B", "A", "B", "C"], dtype="object")}
-        )
+    df = pd.DataFrame({"col": data}, dtype=pd_dtype)
+    coerced_df = expected_dtype.coerce(df)
+
+    assert df.equals(coerced_df)
+    assert expected_dtype.check(
pandas_engine.Engine.dtype(coerced_df["col"].dtype) ) - assert isinstance(validated_df, pd.DataFrame) -def helper_type_validation( - dataframe_type, schema_type, debugging: bool = False -) -> None: - """ - Helper function for using same or different dtypes for the dataframe and - the schema_type - """ - df = pd.DataFrame({"column1": [dataframe_type(1)]}) - if debugging: - print(dataframe_type, df.column1) - schema = pa.DataFrameSchema({"column1": pa.Column(schema_type)}) - if debugging: - print(schema) - schema(df) +def _flatten_dtypesdict(*dtype_kinds): + return [ + (datatype, pd_dtype) + for dtype_kind in dtype_kinds + for datatype, pd_dtype in dtype_kind.items() + ] -@pytest.mark.parametrize( - "type1, type2", - [ - (np.complex_, np.complex_), - (np.complex_, np.complex128), - (np.complex128, np.complex_), - (np.float_, np.float_), - (np.float_, np.float64), - (np.int_, np.int32 if WINDOWS and PANDAS_1_3_0_PLUS else np.int64), - # unsigned ints are converted to signed ints if passed as a scalar - (np.uint, np.int32 if WINDOWS and PANDAS_1_3_0_PLUS else np.int64), - (np.bool_, np.bool_), - (np.str_, np.str_) - # np.object, np.void and bytes are not tested - ], +numeric_dtypes = _flatten_dtypesdict( + int_dtypes, + uint_dtypes, + float_dtypes, + complex_dtypes, + boolean_dtypes, ) -def test_valid_numpy_type_scalar_conversions( - type1: np.generic, type2: np.generic -) -> None: - """Test correct conversions of numpy dtypes""" - try: - helper_type_validation(type1, type2) - except: # pylint: disable=bare-except # noqa E722 - # No exceptions since it should cover all exceptions for debug - # purpose - # Rerun test with debug information - print(f"Error on types: {type1}, {type2}") - helper_type_validation(type1, type2, True) - - -@pytest.mark.skipif( - PANDAS_VERSION.release >= (1, 3, 0), # type: ignore - reason="pandas < 1.3.0 converts number types to default", + +nullable_numeric_dtypes = _flatten_dtypesdict( + nullable_int_dtypes, + nullable_uint_dtypes, + nullable_boolean_dtypes, ) -@pytest.mark.parametrize( - "type1, type2", - [ - # Pandas < 1.3.0 always converts complex numbers to np.complex128 - (np.complex64, np.complex128), - (np.complex128, np.complex128), - # Pandas < 1.3.0 always converts float numbers to np.float64 - (np.float16, np.float64), - (np.float32, np.float64), - (np.float64, np.float64), - # Pandas < 1.3.0 always converts int numbers to np.int64 - (np.int16, np.int64), - (np.int32, np.int64), - (np.int64, np.int64), - # Pandas < 1.3.0 always converts int numbers to np.int64 - (np.uint8, np.int64), - (np.uint16, np.int64), - (np.uint32, np.int64), - (np.uint64, np.int64), - # np.object, np.void and bytes are not tested - ], + +nominal_dtypes = _flatten_dtypesdict( + string_dtypes, + nullable_string_dtypes, + category_dtypes, ) -def test_valid_numpy_type_scalar_conversions_pandas_pre_1_3_0( - type1: np.generic, type2: np.generic -) -> None: - """Test correct conversions of numpy dtypes""" - try: - helper_type_validation(type1, type2) - except: # pylint: disable=bare-except # noqa E722 - # No exceptions since it should cover all exceptions for debug - # purpose - # Rerun test with debug inforation - print(f"Error on types: {type1}, {type2}") - helper_type_validation(type1, type2, True) @pytest.mark.parametrize( - "type1, type2", + "dtypes, examples", [ - (np.complex_, np.int_), - (np.int_, np.complex_), - (float, np.complex_), - (np.complex_, float), - (np.int_, np.float_), - (np.uint8, np.float_), - (np.complex_, str), + (numeric_dtypes, [1]), + (nullable_numeric_dtypes, 
[1, None]), + (nominal_dtypes, ["A", "B"]), ], ) -def test_invalid_numpy_type_conversions( - type1: np.generic, type2: np.generic -) -> None: - """Test various numpy dtypes""" - with pytest.raises(SchemaError): - helper_type_validation(type1, type2) - - PandasDtype.from_numpy_type(np.float_) - with pytest.raises(TypeError): - PandasDtype.from_numpy_type(pd.DatetimeIndex) +@hypothesis.given(st.data()) +def test_coerce_cast(dtypes, examples, data): + """Test that dtypes can be coerced with casting.""" + _, from_pd_dtype = data.draw(st.sampled_from(dtypes)) + to_datatype, _ = data.draw(st.sampled_from(dtypes)) + expected_dtype = pandas_engine.Engine.dtype(to_datatype) -def test_datetime() -> None: - """Test datetime types can be validated properly by schema.validate""" - schema = DataFrameSchema( - columns={ - "col": Column( - pa.DateTime, - checks=Check(lambda s: s.min() > pd.Timestamp("2015")), - ) - } - ) - - validated_df = schema.validate( - pd.DataFrame( - {"col": pd.to_datetime(["2019/01/01", "2018/05/21", "2016/03/10"])} - ) - ) + series = pd.Series(examples, dtype=from_pd_dtype) + coerced_dtype = expected_dtype.coerce(series).dtype + assert expected_dtype.check(pandas_engine.Engine.dtype(coerced_dtype)) - assert isinstance(validated_df, pd.DataFrame) + df = pd.DataFrame({"col": examples}, dtype=from_pd_dtype) + coerced_dtype = expected_dtype.coerce(df)["col"].dtype + assert expected_dtype.check(pandas_engine.Engine.dtype(coerced_dtype)) - with pytest.raises(SchemaError): - schema.validate(pd.DataFrame({"col": pd.to_datetime(["2010/01/01"])})) +def test_coerce_string(): + """Test that strings can be coerced.""" + data = pd.Series([1, None], dtype="Int32") + coerced = pandas_engine.Engine.dtype(str).coerce(data).to_list() + assert isinstance(coerced[0], str) + assert pd.isna(coerced[1]) -@pytest.mark.skipif( - PANDAS_VERSION.release < (1, 0, 0), # type: ignore - reason="pandas >= 1.0.0 required", -) -def test_pandas_extension_types() -> None: - """Test pandas extension data type happy path.""" - # pylint: disable=no-member - test_params = [ - ( - pd.CategoricalDtype( - ["a", "b", "c"] if PANDAS_1_3_0_PLUS else None - ), - pd.Series(["a", "a", "b", "b", "c", "c"], dtype="category"), - None, - ), - ( - pd.DatetimeTZDtype(tz="UTC"), - pd.Series( - pd.date_range(start="20200101", end="20200301"), - dtype="datetime64[ns, utc]", - ), - None, - ), - (pd.Int64Dtype(), pd.Series(range(10), dtype="Int64"), None), - ( - pd.StringDtype(), - pd.Series(["foo", "bar", "baz"], dtype="string"), - None, - ), - ( - pd.PeriodDtype(freq="D"), - pd.Series(pd.period_range("1/1/2019", "1/1/2020", freq="D")), - None, - ), - ( - pd.SparseDtype("float"), - pd.Series(range(100)) - .where(lambda s: s < 5, other=np.nan) - .astype("Sparse[float]"), - {"nullable": True}, - ), - (pd.BooleanDtype(), pd.Series([1, 0, 0, 1, 1], dtype="boolean"), None), - ( - ( - # pylint:disable=unexpected-keyword-arg - pd.IntervalDtype(subtype="int64", closed="right") - if PANDAS_1_3_0_PLUS - else pd.IntervalDtype(subtype="int64") - ), - pd.Series(pd.IntervalIndex.from_breaks([0, 1, 2, 3, 4])), - None, - ), - ] - for dtype, data, series_kwargs in test_params: - series_kwargs = {} if series_kwargs is None else series_kwargs - series_schema = SeriesSchema(pandas_dtype=dtype, **series_kwargs) # type: ignore - assert isinstance(series_schema.validate(data), pd.Series) +def test_default_numeric_dtypes(): + """ + Test that default numeric dtypes int, float and complex are consistent. 
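+
+    e.g. ``pd.Series([1]).dtype`` (the platform default, typically
+    ``int64``) should resolve to the same pandera data type as ``int``
+    and ``"int"``.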
+ """ + default_int_dtype = pd.Series([1]).dtype + assert ( + pandas_engine.Engine.dtype(default_int_dtype) + == pandas_engine.Engine.dtype(int) + == pandas_engine.Engine.dtype("int") + ) -def test_python_builtin_types() -> None: - """Test support python data types can be used for validation.""" - schema = DataFrameSchema( - { - "int_col": Column(int), - "float_col": Column(float), - "str_col": Column(str), - "bool_col": Column(bool), - "object_col": Column(object), - "complex_col": Column(complex), - } + default_float_dtype = pd.Series([1.0]).dtype + assert ( + pandas_engine.Engine.dtype(default_float_dtype) + == pandas_engine.Engine.dtype(float) + == pandas_engine.Engine.dtype("float") ) - df = pd.DataFrame( - { - "int_col": [1, 2, 3], - "float_col": [1.0, 2.0, 3.0], - "str_col": list("abc"), - "bool_col": [True, False, True], - "object_col": [[1], 1, {"foo": "bar"}], - "complex_col": [complex(1), complex(2), complex(3)], - } + + default_complex_dtype = pd.Series([complex(1)]).dtype + assert ( + pandas_engine.Engine.dtype(default_complex_dtype) + == pandas_engine.Engine.dtype(complex) + == pandas_engine.Engine.dtype("complex") ) - assert isinstance(schema(df), pd.DataFrame) - assert schema.dtype["int_col"] == PandasDtype.Int.str_alias - assert schema.dtype["float_col"] == PandasDtype.Float.str_alias - assert schema.dtype["str_col"] == PandasDtype.String.str_alias - assert schema.dtype["bool_col"] == PandasDtype.Bool.str_alias - assert schema.dtype["object_col"] == PandasDtype.Object.str_alias - assert schema.dtype["complex_col"] == PandasDtype.Complex.str_alias - - -@pytest.mark.parametrize("python_type", [list, dict, set]) -def test_python_builtin_types_not_supported(python_type: Type) -> None: - """Test unsupported python data types raise a type error.""" - with pytest.raises(TypeError): - Column(python_type) @pytest.mark.parametrize( - "pandas_api_type,pandas_dtype", + "examples", [ - ["string", PandasDtype.String], - ["floating", PandasDtype.Float], - ["integer", PandasDtype.Int], - ["categorical", PandasDtype.Category], - ["boolean", PandasDtype.Bool], - ["datetime64", PandasDtype.DateTime], - ["datetime", PandasDtype.DateTime], - ["timedelta64", PandasDtype.Timedelta], - ["timedelta", PandasDtype.Timedelta], - ["mixed-integer", PandasDtype.Object], + pretty_param(param) + for param in [ + ["A", "B"], # string + [b"foo", b"bar"], # bytes + [1, 2, 3], # integer + ["a", datetime.date(2013, 1, 1)], # mixed + ["a", 1], # mixed-integer + [1, 2, 3.5, "foo"], # mixed-integer-float + [1.0, 2.0, 3.5], # floating + [Decimal(1), Decimal(2.0)], # decimal + [pd.Timestamp("20130101")], # datetime + [datetime.date(2013, 1, 1)], # date + [datetime.timedelta(0, 1, 1)], # timedelta + pd.Series(list("aabc")).astype("category"), # categorical + [Decimal(1), Decimal(2.0)], # decimal + ] ], ) -def test_pandas_api_types( - pandas_api_type: str, pandas_dtype: pa.PandasDtype -) -> None: - """Test pandas api type conversion.""" - assert PandasDtype.from_pandas_api_type(pandas_api_type) is pandas_dtype +def test_inferred_dtype(examples: pd.Series): + """Test compatibility with pd.api.types.infer_dtype's outputs.""" + alias = pd.api.types.infer_dtype(examples) + if "mixed" in alias or alias in ("date", "string"): + # infer_dtype returns "string", "date" + # whereas a Series will default to a "np.object_" dtype + inferred_datatype = pandas_engine.Engine.dtype(object) + else: + inferred_datatype = pandas_engine.Engine.dtype(alias) + actual_dtype = pandas_engine.Engine.dtype(pd.Series(examples).dtype) + assert 
actual_dtype.check(inferred_datatype) @pytest.mark.parametrize( - "invalid_pandas_api_type", - [ - "foo", - "bar", - "baz", - "this is not a type", - ], + "int_dtype, expected", + [(dtype, True) for dtype in (*int_dtypes, *nullable_int_dtypes)] + + [("string", False)], # type:ignore ) -def test_pandas_api_type_exception(invalid_pandas_api_type: str) -> None: - """Test unsupported values for pandas api type conversion.""" - with pytest.raises(TypeError): - PandasDtype.from_pandas_api_type(invalid_pandas_api_type) +def test_is_int(int_dtype: Any, expected: bool): + """Test is_int.""" + pandera_dtype = pandas_engine.Engine.dtype(int_dtype) + assert pa.dtypes.is_int(pandera_dtype) == expected @pytest.mark.parametrize( - "pandas_dtype", (pandas_dtype for pandas_dtype in PandasDtype) + "uint_dtype, expected", + [(dtype, True) for dtype in (*uint_dtypes, *nullable_uint_dtypes)] + + [("string", False)], # type:ignore ) -def test_pandas_dtype_equality(pandas_dtype: pa.PandasDtype) -> None: - """Test __eq__ implementation.""" - assert pandas_dtype is not None # pylint:disable=singleton-comparison - assert pandas_dtype == pandas_dtype.value +def test_is_uint(uint_dtype: Any, expected: bool): + """Test is_uint.""" + pandera_dtype = pandas_engine.Engine.dtype(uint_dtype) + assert pa.dtypes.is_uint(pandera_dtype) == expected -@pytest.mark.parametrize("pdtype", PandasDtype) -def test_dtype_none_comparison(pdtype: pa.PandasDtype) -> None: - """Test that comparing PandasDtype to None is False.""" - assert pdtype is not None +@pytest.mark.parametrize( + "float_dtype, expected", + [(dtype, True) for dtype in float_dtypes] + [("string", False)], # type: ignore +) +def test_is_float(float_dtype: Any, expected: bool): + """Test is_float.""" + pandera_dtype = pandas_engine.Engine.dtype(float_dtype) + assert pa.dtypes.is_float(pandera_dtype) == expected @pytest.mark.parametrize( - "property_fn, pdtypes", - [ - [ - lambda x: x.is_int, - [ - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - PandasDtype.INT8, - PandasDtype.INT16, - PandasDtype.INT32, - PandasDtype.INT64, - ], - ], - [ - lambda x: x.is_nullable_int, - [ - PandasDtype.INT8, - PandasDtype.INT16, - PandasDtype.INT32, - PandasDtype.INT64, - ], - ], - [ - lambda x: x.is_nonnullable_int, - [ - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - ], - ], - [ - lambda x: x.is_uint, - [ - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, - PandasDtype.UINT8, - PandasDtype.UINT16, - PandasDtype.UINT32, - PandasDtype.UINT64, - ], - ], - [ - lambda x: x.is_nullable_uint, - [ - PandasDtype.UINT8, - PandasDtype.UINT16, - PandasDtype.UINT32, - PandasDtype.UINT64, - ], - ], - [ - lambda x: x.is_nonnullable_uint, - [ - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, - ], - ], - [ - lambda x: x.is_float, - [ - PandasDtype.Float, - PandasDtype.Float16, - PandasDtype.Float32, - PandasDtype.Float64, - ], - ], - [ - lambda x: x.is_complex, - [ - PandasDtype.Complex, - PandasDtype.Complex64, - PandasDtype.Complex128, - PandasDtype.Complex256, - ], - ], - [lambda x: x.is_bool, [PandasDtype.Bool]], - [lambda x: x.is_string, [PandasDtype.String, PandasDtype.String]], - [lambda x: x.is_category, [PandasDtype.Category]], - [lambda x: x.is_datetime, [PandasDtype.DateTime]], - [lambda x: x.is_timedelta, [PandasDtype.Timedelta]], - [lambda x: x.is_object, [PandasDtype.Object]], - [ - lambda x: x.is_continuous, 
-        [
-            PandasDtype.Int,
-            PandasDtype.Int8,
-            PandasDtype.Int16,
-            PandasDtype.Int32,
-            PandasDtype.Int64,
-            PandasDtype.INT8,
-            PandasDtype.INT16,
-            PandasDtype.INT32,
-            PandasDtype.INT64,
-            PandasDtype.UInt8,
-            PandasDtype.UInt16,
-            PandasDtype.UInt32,
-            PandasDtype.UInt64,
-            PandasDtype.UINT8,
-            PandasDtype.UINT16,
-            PandasDtype.UINT32,
-            PandasDtype.UINT64,
-            PandasDtype.Float,
-            PandasDtype.Float16,
-            PandasDtype.Float32,
-            PandasDtype.Float64,
-            PandasDtype.Complex,
-            PandasDtype.Complex64,
-            PandasDtype.Complex128,
-            PandasDtype.Complex256,
-            PandasDtype.DateTime,
-            PandasDtype.Timedelta,
-        ],
-    ],
-    ],
+    "complex_dtype, expected",
+    [(dtype, True) for dtype in complex_dtypes]
+    + [("string", False)],  # type: ignore
)
-def test_dtype_is_checks(
-    property_fn: Callable[..., bool], pdtypes: List[pa.PandasDtype]
-) -> None:
-    """Test all the pandas dtype is_* properties."""
-    for pdtype in pdtypes:
-        assert property_fn(pdtype)
+def test_is_complex(complex_dtype: Any, expected: bool):
+    """Test is_complex."""
+    pandera_dtype = pandas_engine.Engine.dtype(complex_dtype)
+    assert pa.dtypes.is_complex(pandera_dtype) == expected


-def test_category_dtype_exception() -> None:
-    """Test that category dtype has no numpy dtype equivalent."""
-    with pytest.raises(TypeError):
-        # pylint: disable=pointless-statement
-        PandasDtype.Category.numpy_dtype
+@pytest.mark.parametrize(
+    "bool_dtype, expected",
+    [(dtype, True) for dtype in (*boolean_dtypes, *nullable_boolean_dtypes)]
+    + [("string", False)],
+)
+def test_is_bool(bool_dtype: Any, expected: bool):
+    """Test is_bool."""
+    pandera_dtype = pandas_engine.Engine.dtype(bool_dtype)
+    assert pa.dtypes.is_bool(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "string_dtype, expected",
+    [(dtype, True) for dtype in string_dtypes]
+    + [("int", False)],  # type:ignore
+)
+def test_is_string(string_dtype: Any, expected: bool):
+    """Test is_string."""
+    pandera_dtype = pandas_engine.Engine.dtype(string_dtype)
+    assert pa.dtypes.is_string(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "category_dtype, expected",
+    [(dtype, True) for dtype in category_dtypes] + [("string", False)],
+)
+def test_is_category(category_dtype: Any, expected: bool):
+    """Test is_category."""
+    pandera_dtype = pandas_engine.Engine.dtype(category_dtype)
+    assert pa.dtypes.is_category(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "datetime_dtype, expected",
+    [(dtype, True) for dtype in timestamp_dtypes] + [("string", False)],
+)
+def test_is_datetime(datetime_dtype: Any, expected: bool):
+    """Test is_datetime."""
+    pandera_dtype = pandas_engine.Engine.dtype(datetime_dtype)
+    assert pa.dtypes.is_datetime(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "timedelta_dtype, expected",
+    [(dtype, True) for dtype in timedelta_dtypes] + [("string", False)],
+)
+def test_is_timedelta(timedelta_dtype: Any, expected: bool):
+    """Test is_timedelta."""
+    pandera_dtype = pandas_engine.Engine.dtype(timedelta_dtype)
+    assert pa.dtypes.is_timedelta(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "numeric_dtype, expected",
+    [(dtype, True) for dtype, _ in numeric_dtypes] + [("string", False)],
+)
+def test_is_numeric(numeric_dtype: Any, expected: bool):
+    """Test is_numeric."""
+    pandera_dtype = pandas_engine.Engine.dtype(numeric_dtype)
+    assert pa.dtypes.is_numeric(pandera_dtype) == expected
diff --git a/tests/core/test_engine.py b/tests/core/test_engine.py
new file mode 100644
index 000000000..6e38121c0
--- /dev/null
+++ b/tests/core/test_engine.py
@@ -0,0 +1,195 @@
+"""Tests Engine subclassing and registering DataTypes."""
+# pylint:disable=redefined-outer-name,unused-argument
+# pylint:disable=missing-function-docstring,missing-class-docstring
+import re
+from typing import Any, Generator, List, Union
+
+import pytest
+
+from pandera.dtypes import DataType
+from pandera.engines.engine import Engine
+
+
+class BaseDataType(DataType):
+    def __eq__(self, obj: object) -> bool:
+        if isinstance(obj, type(self)):
+            return True
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self.__class__.__name__)
+
+
+class SimpleDtype(BaseDataType):
+    pass
+
+
+@pytest.fixture
+def equivalents() -> List[Any]:
+    return [int, "int", 1]
+
+
+@pytest.fixture
+def engine() -> Generator[Engine, None, None]:
+    class FakeEngine(  # pylint:disable=too-few-public-methods
+        metaclass=Engine, base_pandera_dtypes=BaseDataType
+    ):
+        pass
+
+    yield FakeEngine
+
+    del FakeEngine
+
+
+def test_register_equivalents(engine: Engine, equivalents: List[Any]):
+    """Test that a dtype with equivalents can be registered."""
+    engine.register_dtype(SimpleDtype, equivalents=equivalents)
+
+    for equivalent in equivalents:
+        assert engine.dtype(equivalent) == SimpleDtype()
+
+    with pytest.raises(
+        TypeError, match="Data type 'foo' not understood by FakeEngine"
+    ):
+        engine.dtype("foo")
+
+
+def test_register_from_parametrized_dtype(engine: Engine):
+    """Test that a dtype with from_parametrized_dtype can be registered."""
+
+    @engine.register_dtype
+    class _Dtype(BaseDataType):
+        @classmethod
+        def from_parametrized_dtype(cls, x: int):
+            return x
+
+    assert engine.dtype(42) == 42
+
+    with pytest.raises(
+        TypeError, match="Data type 'foo' not understood by FakeEngine"
+    ):
+        engine.dtype("foo")
+
+
+def test_register_from_parametrized_dtype_union(engine: Engine):
+    """Test that a dtype with from_parametrized_dtype and Union annotation
+    can be registered.
+    """
+
+    @engine.register_dtype
+    class _Dtype(BaseDataType):
+        @classmethod
+        def from_parametrized_dtype(cls, x: Union[int, str]):
+            return x
+
+    assert engine.dtype(42) == 42
+
+
+def test_register_notclassmethod_from_parametrized_dtype(engine: Engine):
+    """Test that a dtype with invalid from_parametrized_dtype
+    cannot be registered.
+    """
+
+    with pytest.raises(
+        ValueError,
+        match="_InvalidDtype.from_parametrized_dtype must be a classmethod.",
+    ):
+
+        @engine.register_dtype
+        class _InvalidDtype(BaseDataType):
+            def from_parametrized_dtype(  # pylint:disable=no-self-argument,no-self-use
+                cls, x: int
+            ):
+                return x
+
+
+def test_register_dtype_complete(engine: Engine, equivalents: List[Any]):
+    """Test that a dtype with equivalents and from_parametrized_dtype
+    can be registered.
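+
+    As exercised below: ``42`` and ``"foo"`` resolve through
+    ``from_parametrized_dtype``, while every registered equivalent maps
+    directly to a ``_Dtype`` instance.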
+ """ + + @engine.register_dtype(equivalents=equivalents) + class _Dtype(BaseDataType): + @classmethod + def from_parametrized_dtype(cls, x: Union[int, str]): + return x + + assert engine.dtype(42) == 42 + assert engine.dtype("foo") == "foo" + + for equivalent in equivalents: + assert engine.dtype(equivalent) == _Dtype() + + with pytest.raises( + TypeError, + match="Data type '' not understood by FakeEngine", + ): + engine.dtype(str) + + +def test_register_dtype_overwrite(engine: Engine): + """Test that register_dtype overwrites existing registrations.""" + + @engine.register_dtype(equivalents=["foo"]) + class _DtypeA(BaseDataType): + @classmethod + def from_parametrized_dtype(cls, x: Union[int, str]): + return _DtypeA() + + assert engine.dtype("foo") == _DtypeA() + assert engine.dtype("bar") == _DtypeA() + assert engine.dtype(42) == _DtypeA() + + @engine.register_dtype(equivalents=["foo"]) + class _DtypeB(BaseDataType): + @classmethod + def from_parametrized_dtype(cls, x: int): + return _DtypeB() + + assert engine.dtype("foo") == _DtypeB() + assert engine.dtype("bar") == _DtypeA() + assert engine.dtype(42) == _DtypeB() + + +def test_register_base_pandera_dtypes(): + """Test that base datatype cannot be registered.""" + + class FakeEngine( # pylint:disable=too-few-public-methods + metaclass=Engine, base_pandera_dtypes=(BaseDataType, BaseDataType) + ): + pass + + with pytest.raises( + ValueError, + match=re.escape( + "Subclasses of ['tests.core.test_engine.BaseDataType', " + + "'tests.core.test_engine.BaseDataType'] " + + "cannot be registered with FakeEngine." + ), + ): + + @FakeEngine.register_dtype(equivalents=[SimpleDtype]) + class _Dtype(BaseDataType): + pass + + +def test_return_base_dtype(engine: Engine): + """Test that Engine.dtype returns back base datatypes.""" + assert engine.dtype(SimpleDtype()) == SimpleDtype() + assert engine.dtype(SimpleDtype) == SimpleDtype() + + class ParametrizedDtypec(BaseDataType): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + + def __eq__(self, obj: object) -> bool: + if not isinstance(obj, ParametrizedDtypec): + return NotImplemented + return obj.x == self.x + + assert engine.dtype(ParametrizedDtypec(1)) == ParametrizedDtypec(1) + with pytest.raises( + TypeError, match="DataType 'ParametrizedDtypec' cannot be instantiated" + ): + engine.dtype(ParametrizedDtypec) diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 5fd633dd1..0691b6615 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -9,7 +9,7 @@ import pandera as pa import pandera.strategies as st -from pandera import PandasDtype, extensions +from pandera import DataType, extensions from pandera.checks import Check @@ -206,7 +206,7 @@ def test_register_check_with_strategy(custom_check_teardown: None) -> None: import hypothesis # pylint: disable=import-outside-toplevel,import-error def custom_ge_strategy( - pandas_dtype: PandasDtype, + pandas_dtype: DataType, strategy: Optional[st.SearchStrategy] = None, *, min_value: Any, @@ -226,7 +226,7 @@ def custom_ge_check(pandas_obj, *, min_value): return pandas_obj >= min_value check = Check.custom_ge_check(min_value=0) - strat = check.strategy(PandasDtype.Int) + strat = check.strategy(pa.Int) with pytest.warns(hypothesis.errors.NonInteractiveExampleWarning): assert strat.example() >= 0 diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 4c2b61028..bbe6aa915 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -81,7 +81,7 @@ class 
        d: Series[Decimal]  # type: ignore

    with pytest.raises(
-        TypeError, match="python type '<class 'decimal.Decimal'>'"
+        TypeError, match="dtype '<class 'decimal.Decimal'>' not understood"
    ):
        InvalidDtype.to_schema()

@@ -480,7 +480,7 @@ class ChildEmpty(Mid):
    )
    expected_child_override_attr = expected_mid.rename_columns(
        {"_b": "b"}
-    ).update_column("b", pandas_dtype=int)
+    ).update_column("b", dtype=int)
    expected_child_override_alias = expected_mid.rename_columns(
        {"_b": "new_b"}
    )
diff --git a/tests/core/test_model_components.py b/tests/core/test_model_components.py
index f57bfaddc..3fe9e16be 100644
--- a/tests/core/test_model_components.py
+++ b/tests/core/test_model_components.py
@@ -5,6 +5,7 @@
import pytest

import pandera as pa
+from pandera.engines.pandas_engine import Engine


def test_field_to_column() -> None:
@@ -15,7 +16,7 @@ def test_field_to_column() -> None:
            pa.DateTime, required=value
        )
        assert isinstance(col, pa.Column)
-        assert col.dtype == pa.DateTime.value
+        assert col.dtype == Engine.dtype(pa.DateTime)
        assert col.properties[flag] == value
        assert col.required == value

@@ -26,7 +27,7 @@ def test_field_to_index() -> None:
    for value in [True, False]:
        index = pa.Field(**{flag: value}).to_index(pa.DateTime)  # type: ignore[arg-type]
        assert isinstance(index, pa.Index)
-        assert index.dtype == pa.DateTime.value
+        assert index.dtype == Engine.dtype(pa.DateTime)
        assert getattr(index, flag) == value

diff --git a/tests/core/test_schema_components.py b/tests/core/test_schema_components.py
index 282d1c1e0..f1ee287e8 100644
--- a/tests/core/test_schema_components.py
+++ b/tests/core/test_schema_components.py
@@ -3,7 +3,6 @@
import copy
from typing import Any, List, Optional, Tuple, Type, Union

-import numpy as np
import pandas as pd
import pytest

@@ -16,14 +15,11 @@
    Index,
    Int,
    MultiIndex,
-    Object,
-    PandasDtype,
    SeriesSchema,
    String,
    errors,
)
-
-from .test_dtypes import TESTABLE_DTYPES
+from pandera.engines.pandas_engine import Engine


def test_column() -> None:
@@ -48,26 +44,7 @@ def test_column() -> None:
        Column(Int)(data)


-def test_coerce_nullable_object_column() -> None:
-    """Test that Object dtype coercing preserves object types."""
-    df_objects_with_na = pd.DataFrame(
-        {"col": [1, 2.0, [1, 2, 3], {"a": 1}, np.nan, None]}
-    )
-
-    column_schema = Column(Object, name="col", coerce=True, nullable=True)
-
-    validated_df = column_schema.validate(df_objects_with_na)
-    assert isinstance(validated_df, pd.DataFrame)
-    assert pd.isna(validated_df["col"].iloc[-1])
-    assert pd.isna(validated_df["col"].iloc[-2])
-    for i in range(4):
-        isinstance(
-            validated_df["col"].iloc[i],
-            type(df_objects_with_na["col"].iloc[i]),
-        )
-
-
-def test_column_in_dataframe_schema() -> None:
+def test_column_in_dataframe_schema():
    """Test that a Column check returns a dataframe."""
    schema = DataFrameSchema(
        {"a": Column(Int, Check(lambda x: x > 0, element_wise=True))}
@@ -96,18 +73,13 @@ def test_index_schema():
        schema.validate(pd.DataFrame(index=range(1, 20)))


-@pytest.mark.parametrize("pdtype", [Float, Int, String, String])
-def test_index_schema_coerce(pdtype: PandasDtype) -> None:
+@pytest.mark.parametrize("dtype", [Float, Int, String])
+def test_index_schema_coerce(dtype):
    """Test that index can be type-coerced."""
-    schema = DataFrameSchema(index=Index(pdtype, coerce=True))
+    schema = DataFrameSchema(index=Index(dtype, coerce=True))
    df = pd.DataFrame(index=pd.Index([1, 2, 3, 4], dtype="int64"))
-    validated_df = schema(df)
-    # pandas-native "string" dtype doesn't apply to indexes
-    assert (
-        validated_df.index.dtype == "object"
-
if pdtype is String - else pdtype.str_alias - ) + validated_index_dtype = Engine.dtype(schema(df).index.dtype) + assert schema.index.dtype.check(validated_index_dtype) def test_multi_index_columns() -> None: @@ -215,10 +187,8 @@ def test_multi_index_schema_coerce() -> None: ) validated_df = schema(df) for level_i in range(validated_df.index.nlevels): - assert ( - validated_df.index.get_level_values(level_i).dtype - == indexes[level_i].dtype - ) + index_dtype = validated_df.index.get_level_values(level_i).dtype + assert indexes[level_i].dtype.check(Engine.dtype(index_dtype)) def tests_multi_index_subindex_coerce() -> None: @@ -253,15 +223,7 @@ def tests_multi_index_subindex_coerce() -> None: schema(data, lazy=True) -@pytest.mark.parametrize("pandas_dtype, expected", TESTABLE_DTYPES) -def test_column_dtype_property( - pandas_dtype: Union[PandasDtype, str], expected: str -) -> None: - """Tests that the dtypes provided by Column match pandas dtypes""" - assert Column(pandas_dtype).dtype == expected - - -def test_schema_component_equality_operators() -> None: +def test_schema_component_equality_operators(): """Test the usage of == for Column, Index and MultiIndex.""" column = Column(Int, Check(lambda s: s >= 0)) index = Index(Int, [Check(lambda x: 1 <= x <= 11, element_wise=True)]) @@ -572,18 +534,17 @@ def test_column_type_can_be_set() -> None: column_a = Column(Int, name="a") changed_type = Float - column_a.pandas_dtype = Float + column_a.dtype = Float - assert column_a.pandas_dtype == changed_type - assert column_a.dtype == changed_type.str_alias + assert column_a.dtype == Engine.dtype(changed_type) for invalid_dtype in ("foobar", "bar"): with pytest.raises(TypeError): - column_a.pandas_dtype = invalid_dtype + column_a.dtype = invalid_dtype for invalid_dtype in (1, 2.2, ["foo", 1, 1.1], {"b": 1}): with pytest.raises(TypeError): - column_a.pandas_dtype = invalid_dtype + column_a.dtype = invalid_dtype @pytest.mark.parametrize( diff --git a/tests/core/test_schema_statistics.py b/tests/core/test_schema_statistics.py index 5b0b5f9bb..edf6d8836 100644 --- a/tests/core/test_schema_statistics.py +++ b/tests/core/test_schema_statistics.py @@ -1,15 +1,34 @@ # pylint: disable=W0212 """Unit tests for inferring statistics of pandas objects.""" - import pandas as pd import pytest import pandera as pa -from pandera import PandasDtype, dtypes, schema_statistics -from pandera.dtypes import PANDAS_1_3_0_PLUS - -DEFAULT_INT = PandasDtype.from_str_alias(dtypes._DEFAULT_PANDAS_INT_TYPE) -DEFAULT_FLOAT = PandasDtype.from_str_alias(dtypes._DEFAULT_PANDAS_FLOAT_TYPE) +from pandera import dtypes, schema_statistics +from pandera.engines import pandas_engine + +DEFAULT_FLOAT = pandas_engine.Engine.dtype(float) +DEFAULT_INT = pandas_engine.Engine.dtype(int) + +NUMERIC_TYPES = [ + pa.Int(), + pa.UInt(), + pa.Float(), + pa.Complex(), + pandas_engine.Engine.dtype("Int32"), + pandas_engine.Engine.dtype("UInt32"), +] +INTEGER_TYPES = [ + dtypes.Int(), + dtypes.Int8(), + dtypes.Int16(), + dtypes.Int32(), + dtypes.Int64(), + dtypes.UInt8(), + dtypes.UInt16(), + dtypes.UInt32(), + dtypes.UInt64(), +] def _create_dataframe( @@ -55,33 +74,41 @@ def test_infer_dataframe_statistics(multi_index: bool, nullable: bool) -> None: statistics = schema_statistics.infer_dataframe_statistics(dataframe) stat_columns = statistics["columns"] - if PANDAS_1_3_0_PLUS: + if pa.pandas_version().release >= (1, 3, 0): if nullable: - assert stat_columns["int"]["pandas_dtype"] is DEFAULT_FLOAT + assert DEFAULT_FLOAT.check(stat_columns["int"]["dtype"]) 
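+            # introducing nulls upcasts the int column to float,
+            # hence DEFAULT_FLOAT above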
else: - assert stat_columns["int"]["pandas_dtype"] is DEFAULT_INT + assert DEFAULT_INT.check(stat_columns["int"]["dtype"]) else: if nullable: - assert stat_columns["boolean"]["pandas_dtype"] is DEFAULT_FLOAT + assert DEFAULT_FLOAT.check(stat_columns["boolean"]["dtype"]) else: - assert stat_columns["boolean"]["pandas_dtype"] is pa.Bool + assert pandas_engine.Engine.dtype(bool).check( + stat_columns["boolean"]["dtype"] + ) - assert stat_columns["float"]["pandas_dtype"] is DEFAULT_FLOAT - assert stat_columns["string"]["pandas_dtype"] is pa.String - assert stat_columns["datetime"]["pandas_dtype"] is pa.DateTime + assert DEFAULT_FLOAT.check(stat_columns["float"]["dtype"]) + assert pandas_engine.Engine.dtype(str).check( + stat_columns["string"]["dtype"] + ) + assert pandas_engine.Engine.dtype(pa.DateTime).check( + stat_columns["datetime"]["dtype"] + ) if multi_index: stat_indices = statistics["index"] for stat_index, name, dtype in zip( - stat_indices, ["int_index", "str_index"], [DEFAULT_INT, pa.String] + stat_indices, + ["int_index", "str_index"], + [DEFAULT_INT, pandas_engine.Engine.dtype(str)], ): assert stat_index["name"] == name - assert stat_index["pandas_dtype"] is dtype + assert dtype.check(stat_index["dtype"]) assert not stat_index["nullable"] else: stat_index = statistics["index"][0] assert stat_index["name"] == "int_index" - assert stat_index["pandas_dtype"] is DEFAULT_INT + assert stat_index["dtype"] == DEFAULT_INT assert not stat_index["nullable"] for properties in stat_columns.values(): @@ -129,14 +156,30 @@ def test_parse_check_statistics(check_stats, expectation) -> None: assert set(checks) == set(expectation) +def _test_statistics(statistics, expectations): + if not isinstance(statistics, list): + statistics = [statistics] + if not isinstance(expectations, list): + expectations = [expectations] + + for stats, expectation in zip(statistics, expectations): + stat_dtype = stats.pop("dtype") + expectation_dtype = expectation.pop("dtype") + + assert stats == expectation + assert expectation_dtype.check(stat_dtype) + + @pytest.mark.parametrize( "series, expectation", [ *[ [ - pd.Series([1, 2, 3], dtype=dtype.str_alias), + pd.Series( + [1, 2, 3], dtype=str(pandas_engine.Engine.dtype(data_type)) + ), { - "pandas_dtype": dtype, + "dtype": pandas_engine.Engine.dtype(data_type), "nullable": False, "checks": { "greater_than_or_equal_to": 1, @@ -145,25 +188,21 @@ def test_parse_check_statistics(check_stats, expectation) -> None: "name": None, }, ] - for dtype in ( - x - for x in schema_statistics.NUMERIC_DTYPES - if x != PandasDtype.DateTime - ) + for data_type in NUMERIC_TYPES ], [ pd.Series(["a", "b", "c", "a"], dtype="category"), { - "pandas_dtype": pa.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": False, "checks": {"isin": ["a", "b", "c"]}, "name": None, }, ], [ - pd.Series(["a", "b", "c", "a"], name="str_series"), + pd.Series(["a", "b", "c", "a"], dtype="string", name="str_series"), { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype("string"), "nullable": False, "checks": None, "name": "str_series", @@ -172,7 +211,7 @@ def test_parse_check_statistics(check_stats, expectation) -> None: [ pd.Series(pd.to_datetime(["20180101", "20180102", "20180103"])), { - "pandas_dtype": pa.DateTime, + "dtype": pandas_engine.Engine.dtype(pa.DateTime), "nullable": False, "checks": { "greater_than_or_equal_to": pd.Timestamp("20180101"), @@ -186,20 +225,7 @@ def test_parse_check_statistics(check_stats, expectation) -> None: def 
test_infer_series_schema_statistics(series, expectation) -> None: """Test series statistics are correctly inferred.""" statistics = schema_statistics.infer_series_statistics(series) - assert statistics == expectation - - -INTEGER_TYPES = [ - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, -] + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -208,10 +234,10 @@ def test_infer_series_schema_statistics(series, expectation) -> None: *[ [ 0, - pd.Series([1, 2, 3], dtype=dtype.value), + pd.Series([1, 2, 3], dtype=str(data_type)), { # introducing nans to integer arrays upcasts to float - "pandas_dtype": DEFAULT_FLOAT, + "dtype": DEFAULT_FLOAT, "nullable": True, "checks": { "greater_than_or_equal_to": 2, @@ -220,7 +246,7 @@ def test_infer_series_schema_statistics(series, expectation) -> None: "name": None, }, ] - for dtype in INTEGER_TYPES + for data_type in INTEGER_TYPES ], [ # introducing nans to bool arrays upcasts to float except @@ -228,13 +254,15 @@ def test_infer_series_schema_statistics(series, expectation) -> None: 0, pd.Series([True, False, True, False]), { - "pandas_dtype": ( - pa.Bool if PANDAS_1_3_0_PLUS else DEFAULT_FLOAT + "dtype": ( + pandas_engine.Engine.dtype(pa.BOOL) + if pa.PANDAS_1_3_0_PLUS + else DEFAULT_FLOAT ), "nullable": True, "checks": ( None - if PANDAS_1_3_0_PLUS + if pa.PANDAS_1_3_0_PLUS else { "greater_than_or_equal_to": 0, "less_than_or_equal_to": 1, @@ -247,7 +275,7 @@ def test_infer_series_schema_statistics(series, expectation) -> None: 0, pd.Series(["a", "b", "c", "a"], dtype="category"), { - "pandas_dtype": pa.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": True, "checks": {"isin": ["a", "b", "c"]}, "name": None, @@ -257,7 +285,7 @@ def test_infer_series_schema_statistics(series, expectation) -> None: 0, pd.Series(["a", "b", "c", "a"], name="str_series"), { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype(str), "nullable": True, "checks": None, "name": "str_series", @@ -267,7 +295,7 @@ def test_infer_series_schema_statistics(series, expectation) -> None: 2, pd.Series(pd.to_datetime(["20180101", "20180102", "20180103"])), { - "pandas_dtype": pa.DateTime, + "dtype": pandas_engine.Engine.dtype(pa.DateTime), "nullable": True, "checks": { "greater_than_or_equal_to": pd.Timestamp("20180101"), @@ -284,7 +312,7 @@ def test_infer_nullable_series_schema_statistics( """Test nullable series statistics are correctly inferred.""" series.iloc[null_index] = None statistics = schema_statistics.infer_series_statistics(series) - assert statistics == expectation + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -295,7 +323,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": None, - "pandas_dtype": PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 0, @@ -309,7 +337,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "int_index", - "pandas_dtype": PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 1, @@ -323,7 +351,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "str_index", - "pandas_dtype": PandasDtype.String, + "dtype": pandas_engine.Engine.dtype("object"), "nullable": False, "checks": None, }, @@ -337,7 +365,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "int_index", - "pandas_dtype": 
PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 10, @@ -346,7 +374,7 @@ def test_infer_nullable_series_schema_statistics( }, { "name": "str_index", - "pandas_dtype": PandasDtype.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": False, "checks": {"isin": ["a", "b", "c"]}, }, @@ -367,7 +395,9 @@ def test_infer_index_statistics(index, expectation): with pytest.warns(UserWarning, match="^index type .+ not recognized"): schema_statistics.infer_index_statistics(index) else: - assert schema_statistics.infer_index_statistics(index) == expectation + _test_statistics( + schema_statistics.infer_index_statistics(index), expectation + ) def test_get_dataframe_schema_statistics(): @@ -375,7 +405,7 @@ def test_get_dataframe_schema_statistics(): schema = pa.DataFrameSchema( columns={ "int": pa.Column( - pa.Int, + int, checks=[ pa.Check.greater_than_or_equal_to(0), pa.Check.less_than_or_equal_to(100), @@ -383,18 +413,19 @@ def test_get_dataframe_schema_statistics(): nullable=True, ), "float": pa.Column( - pa.Float, + float, checks=[ pa.Check.greater_than_or_equal_to(50), pa.Check.less_than_or_equal_to(100), ], ), "str": pa.Column( - pa.String, checks=[pa.Check.isin(["foo", "bar", "baz"])] + str, + checks=[pa.Check.isin(["foo", "bar", "baz"])], ), }, index=pa.Index( - pa.Int, + int, checks=pa.Check.greater_than_or_equal_to(0), nullable=False, name="int_index", @@ -404,7 +435,7 @@ def test_get_dataframe_schema_statistics(): "checks": None, "columns": { "int": { - "pandas_dtype": pa.Int, + "dtype": DEFAULT_INT, "checks": { "greater_than_or_equal_to": {"min_value": 0}, "less_than_or_equal_to": {"max_value": 100}, @@ -416,7 +447,7 @@ def test_get_dataframe_schema_statistics(): "regex": False, }, "float": { - "pandas_dtype": pa.Float, + "dtype": DEFAULT_FLOAT, "checks": { "greater_than_or_equal_to": {"min_value": 50}, "less_than_or_equal_to": {"max_value": 100}, @@ -428,7 +459,7 @@ def test_get_dataframe_schema_statistics(): "regex": False, }, "str": { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype(str), "checks": {"isin": {"allowed_values": ["foo", "bar", "baz"]}}, "nullable": False, "allow_duplicates": True, @@ -439,7 +470,7 @@ def test_get_dataframe_schema_statistics(): }, "index": [ { - "pandas_dtype": pa.Int, + "dtype": DEFAULT_INT, "checks": {"greater_than_or_equal_to": {"min_value": 0}}, "nullable": False, "coerce": False, @@ -455,7 +486,7 @@ def test_get_dataframe_schema_statistics(): def test_get_series_schema_statistics(): """Test that series schema statistics logic is correct.""" schema = pa.SeriesSchema( - pa.Int, + int, nullable=False, checks=[ pa.Check.greater_than_or_equal_to(0), @@ -464,7 +495,7 @@ def test_get_series_schema_statistics(): ) statistics = schema_statistics.get_series_schema_statistics(schema) assert statistics == { - "pandas_dtype": pa.Int, + "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { "greater_than_or_equal_to": {"min_value": 0}, @@ -480,7 +511,7 @@ def test_get_series_schema_statistics(): [ [ pa.Index( - pa.Int, + int, checks=[ pa.Check.greater_than_or_equal_to(10), pa.Check.less_than_or_equal_to(20), @@ -490,7 +521,7 @@ def test_get_series_schema_statistics(): ), [ { - "pandas_dtype": pa.Int, + "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { "greater_than_or_equal_to": {"min_value": 10}, @@ -508,7 +539,7 @@ def test_get_index_schema_statistics(index_schema_component, expectation): statistics = 
schema_statistics.get_index_schema_statistics( index_schema_component ) - assert statistics == expectation + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -579,7 +610,10 @@ def test_parse_checks_and_statistics_roundtrip(checks, expectation): # pylint: disable=unused-argument def test_parse_checks_and_statistics_no_param(extra_registered_checks): - """Ensure that an edge case where a check does not have parameters is appropriately handled.""" + """ + Ensure that an edge case where a check does not have parameters is + appropriately handled. + """ checks = [pa.Check.no_param_check()] expectation = {"no_param_check": {}} diff --git a/tests/core/test_schemas.py b/tests/core/test_schemas.py index a2e78fd75..0a52b41f0 100644 --- a/tests/core/test_schemas.py +++ b/tests/core/test_schemas.py @@ -2,6 +2,7 @@ # pylint: disable=too-many-lines,redefined-outer-name import copy +from datetime import datetime, timedelta from functools import partial from typing import ( Any, @@ -20,29 +21,18 @@ import pytest from pandera import ( - STRING, - Bool, Category, Check, Column, DataFrameSchema, - DateTime, - Float, Index, - Int, MultiIndex, - Object, - PandasDtype, SeriesSchema, - String, - Timedelta, errors, ) -from pandera.dtypes import LEGACY_PANDAS +from pandera.engines.pandas_engine import Engine from pandera.schemas import SeriesSchemaBase -from .test_dtypes import TESTABLE_DTYPES - def test_dataframe_schema() -> None: """Tests the Checking of a DataFrame that has a wide variety of types and @@ -51,25 +41,25 @@ def test_dataframe_schema() -> None: """ schema = DataFrameSchema( { - "a": Column(Int, Check(lambda x: x > 0, element_wise=True)), + "a": Column(int, Check(lambda x: x > 0, element_wise=True)), "b": Column( - Float, Check(lambda x: 0 <= x <= 10, element_wise=True) + float, Check(lambda x: 0 <= x <= 10, element_wise=True) ), - "c": Column(String, Check(lambda x: set(x) == {"x", "y", "z"})), - "d": Column(Bool, Check(lambda x: x.mean() > 0.5)), + "c": Column(str, Check(lambda x: set(x) == {"x", "y", "z"})), + "d": Column(bool, Check(lambda x: x.mean() > 0.5)), "e": Column( Category, Check(lambda x: set(x) == {"c1", "c2", "c3"}) ), - "f": Column(Object, Check(lambda x: x.isin([(1,), (2,), (3,)]))), + "f": Column(object, Check(lambda x: x.isin([(1,), (2,), (3,)]))), "g": Column( - DateTime, + datetime, Check( lambda x: x >= pd.Timestamp("2015-01-01"), element_wise=True, ), ), "i": Column( - Timedelta, + timedelta, Check( lambda x: x < pd.Timedelta(10, unit="D"), element_wise=True ), @@ -113,11 +103,11 @@ def test_dataframe_schema() -> None: def test_dataframe_schema_equality() -> None: """Test DataframeSchema equality.""" - schema = DataFrameSchema({"a": Column(Int)}) + schema = DataFrameSchema({"a": Column(int)}) assert schema == copy.copy(schema) assert schema != "schema" assert DataFrameSchema(coerce=True) != DataFrameSchema(coerce=False) - assert schema != schema.update_column("a", pandas_dtype=Float) + assert schema != schema.update_column("a", dtype=float) assert schema != schema.update_column("a", checks=Check.eq(1)) @@ -127,7 +117,10 @@ def test_dataframe_schema_strict() -> None: not present in the dataframe. 
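# ---------------------------------------------------------------------------
# NOTE (editorial sketch, not part of the patch): the statistics tests above
# assert that inferred and extracted statistics now carry engine DataType
# objects under a "dtype" key, where they previously carried PandasDtype
# enum values under "pandas_dtype". A minimal sketch of the round trip,
# assuming get_dataframe_schema_statistics behaves as exercised above:
import pandera as pa
from pandera import schema_statistics
from pandera.engines import pandas_engine

schema = pa.DataFrameSchema({"x": pa.Column(int), "y": pa.Column(str)})
stats = schema_statistics.get_dataframe_schema_statistics(schema)
# column statistics record resolved DataType objects, not string aliases:
assert stats["columns"]["x"]["dtype"] == pandas_engine.Engine.dtype(int)
assert stats["columns"]["y"]["dtype"] == pandas_engine.Engine.dtype(str)
# ---------------------------------------------------------------------------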
""" schema = DataFrameSchema( - {"a": Column(Int, nullable=True), "b": Column(Int, nullable=True)}, + { + "a": Column(int, nullable=True), + "b": Column(int, nullable=True), + }, strict=True, ) df = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) @@ -142,7 +135,10 @@ def test_dataframe_schema_strict() -> None: with pytest.raises(errors.SchemaInitError): DataFrameSchema( - {"a": Column(Int, nullable=True), "b": Column(Int, nullable=True)}, + { + "a": Column(int, nullable=True), + "b": Column(int, nullable=True), + }, strict="foobar", ) @@ -155,7 +151,7 @@ def test_dataframe_schema_strict() -> None: def test_dataframe_schema_strict_regex() -> None: """Test that strict dataframe schema checks for regex matches.""" schema = DataFrameSchema( - {"foo_*": Column(Int, regex=True)}, + {"foo_*": Column(int, regex=True)}, strict=True, ) df = pd.DataFrame({"foo_%d" % i: range(10) for i in range(5)}) @@ -170,25 +166,26 @@ def test_dataframe_schema_strict_regex() -> None: ) -def test_dataframe_pandas_dtype_coerce() -> None: +def test_dataframe_dtype_coerce(): """ Test that pandas dtype specified at the dataframe level overrides column data types. """ schema = DataFrameSchema( columns={f"column_{i}": Column(float) for i in range(5)}, - pandas_dtype=int, + dtype=int, coerce=True, ) - df = pd.DataFrame({f"column_{i}": range(10) for i in range(5)}).astype( - float + df = pd.DataFrame( + {f"column_{i}": range(10) for i in range(5)}, dtype=float ) - assert (schema(df).dtypes == Int.str_alias).all() + int_alias = str(Engine.dtype(int)) + assert (schema(df).dtypes == int_alias).all() - # test that pandas_dtype in columns are preserved + # test that dtype in schema.columns are preserved for col in schema.columns.values(): - assert col.pandas_dtype is float + assert col.dtype == Engine.dtype(float) # raises SchemeError if dataframe can't be coerced with pytest.raises(errors.SchemaErrors): @@ -199,31 +196,26 @@ def test_dataframe_pandas_dtype_coerce() -> None: schema(pd.DataFrame({"foo": list("abcdef")}), lazy=True) # test that original dataframe dtypes are preserved - assert (df.dtypes == Float.str_alias).all() - - # test case where pandas_dtype is string - schema.pandas_dtype = str - assert (schema(df).dtypes == "object").all() + float_alias = str(Engine.dtype(float)) + assert (df.dtypes == float_alias).all() - schema.pandas_dtype = PandasDtype.String - assert (schema(df).dtypes == "object").all() - - # raises ValueError if _coerce_dtype is called when pandas_dtype is None - schema.pandas_dtype = None + # raises ValueError if _coerce_dtype is called when dtype is None + schema.dtype = None with pytest.raises(ValueError): schema._coerce_dtype(df) # test setting coerce as false at the dataframe level no longer coerces # columns to int schema.coerce = False - assert (schema(df).dtypes == "float64").all() + pd_dtypes = [Engine.dtype(pd_dtype) for pd_dtype in schema(df).dtypes] + assert all(pd_dtype == Engine.dtype(float) for pd_dtype in pd_dtypes) def test_dataframe_coerce_regex() -> None: """Test dataframe pandas dtype coercion for regex columns""" schema = DataFrameSchema( columns={"column_": Column(float, regex=True, required=False)}, - pandas_dtype=int, + dtype=int, coerce=True, ) @@ -257,15 +249,15 @@ def test_dataframe_reset_column_name() -> None: [ ( { - "a": Column(Int, required=False), - "b": Column(Int, required=False), + "a": Column(int, required=False), + "b": Column(int, required=False), }, None, ), ( None, MultiIndex( - indexes=[Index(Int, name="a"), Index(Int, name="b")], + 
indexes=[Index(int, name="a"), Index(int, name="b")], ), ), ], @@ -325,15 +317,18 @@ def test_series_schema() -> None: SeriesSchema("int").validate(pd.Series([1, 2, 3])) int_schema = SeriesSchema( - Int, Check(lambda x: 0 <= x <= 100, element_wise=True) + int, Check(lambda x: 0 <= x <= 100, element_wise=True) ) assert isinstance( int_schema.validate(pd.Series([0, 30, 50, 100])), pd.Series ) + def f(series): + return series.isin(["foo", "bar", "baz"]) + str_schema = SeriesSchema( - String, - Check(lambda s: s.isin(["foo", "bar", "baz"])), + str, + Check(f), nullable=True, coerce=True, ) @@ -354,21 +349,18 @@ def test_series_schema() -> None: with pytest.raises(TypeError): int_schema.validate(TypeError) - non_duplicate_schema = SeriesSchema(Int, allow_duplicates=False) + non_duplicate_schema = SeriesSchema(int, allow_duplicates=False) with pytest.raises(errors.SchemaError): non_duplicate_schema.validate(pd.Series([0, 1, 2, 3, 4, 1])) # when series name doesn't match schema - named_schema = SeriesSchema(Int, name="my_series") + named_schema = SeriesSchema(int, name="my_series") with pytest.raises(errors.SchemaError, match=r"^Expected .+ to have name"): named_schema.validate(pd.Series(range(5), name="your_series")) # when series floats are declared to be integer - with pytest.raises( - errors.SchemaError, - match=r"^after dropping null values, expected values in series", - ): - SeriesSchema(Int, nullable=True).validate( + with pytest.raises(errors.SchemaError): + SeriesSchema(int, nullable=True).validate( pd.Series([1.1, 2.3, 5.5, np.nan]) ) @@ -377,7 +369,7 @@ def test_series_schema() -> None: errors.SchemaError, match=r"^non-nullable series .+ contains null values", ): - SeriesSchema(Float, nullable=False).validate( + SeriesSchema(float, nullable=False).validate( pd.Series([1.1, 2.3, 5.5, np.nan]) ) @@ -386,7 +378,7 @@ def test_series_schema() -> None: errors.SchemaError, match="Error while coercing", ): - SeriesSchema(Float, coerce=True).validate(pd.Series(list("abcdefg"))) + SeriesSchema(float, coerce=True).validate(pd.Series(list("abcdefg"))) def test_series_schema_checks() -> None: @@ -413,7 +405,7 @@ def test_series_schema_multiple_validators() -> None: """Tests how multiple Checks on a Series Schema are handled both successfully and when errors are expected.""" schema = SeriesSchema( - Int, + int, [ Check(lambda x: 0 <= x <= 50, element_wise=True), Check(lambda s: (s == 21).any()), @@ -431,18 +423,18 @@ def test_series_schema_multiple_validators() -> None: def test_series_schema_with_index(coerce: bool) -> None: """Test SeriesSchema with Index and MultiIndex components.""" schema_with_index = SeriesSchema( - pandas_dtype=Int, - index=Index(Int, coerce=coerce), + dtype=int, + index=Index(int, coerce=coerce), ) validated_series = schema_with_index(pd.Series([1, 2, 3], index=[1, 2, 3])) assert isinstance(validated_series, pd.Series) schema_with_multiindex = SeriesSchema( - pandas_dtype=Int, + dtype=int, index=MultiIndex( [ - Index(Int, coerce=coerce), - Index(String, coerce=coerce), + Index(int, coerce=coerce), + Index(str, coerce=coerce), ] ), ) @@ -500,8 +492,8 @@ def test_dataframe_schema_check_function_types( """Tests a DataFrameSchema against a variety of Check conditions.""" schema = DataFrameSchema( { - "a": Column(Int, Check(check_function, element_wise=False)), - "b": Column(Float, Check(check_function, element_wise=False)), + "a": Column(int, Check(check_function, element_wise=False)), + "b": Column(float, Check(check_function, element_wise=False)), } ) df = pd.DataFrame({"a": 
[1, 2, 3], "b": [1.1, 2.5, 9.9]}) @@ -512,20 +504,7 @@ def test_dataframe_schema_check_function_types( schema.validate(df) -def test_nullable_int_in_dataframe() -> None: - """Tests handling of nullability when datatype is integers.""" - df = pd.DataFrame({"column1": [5, 1, np.nan]}) - null_schema = DataFrameSchema( - {"column1": Column(Int, Check(lambda x: x > 0), nullable=True)} - ) - assert isinstance(null_schema.validate(df), pd.DataFrame) - - # test case where column is an object - df = df.astype({"column1": "object"}) - assert isinstance(null_schema.validate(df), pd.DataFrame) - - -def test_coerce_dtype_in_dataframe() -> None: +def test_coerce_dtype_in_dataframe(): """Tests coercions of datatypes, especially regarding nullable integers.""" df = pd.DataFrame( { @@ -538,31 +517,30 @@ def test_coerce_dtype_in_dataframe() -> None: # specify `coerce` at the Column level schema1 = DataFrameSchema( { - "column1": Column(Int, Check(lambda x: x > 0), coerce=True), - "column2": Column(DateTime, coerce=True), - "column3": Column(String, coerce=True, nullable=True), + "column1": Column(int, Check(lambda x: x > 0), coerce=True), + "column2": Column(datetime, coerce=True), } ) # specify `coerce` at the DataFrameSchema level schema2 = DataFrameSchema( { - "column1": Column(Int, Check(lambda x: x > 0)), - "column2": Column(DateTime), - "column3": Column(String, nullable=True), + "column1": Column(int, Check(lambda x: x > 0)), + "column2": Column(datetime), }, coerce=True, ) for schema in [schema1, schema2]: result = schema.validate(df) - assert result.column1.dtype == Int.str_alias - assert result.column2.dtype == DateTime.str_alias - for _, x in result.column3.iteritems(): - assert pd.isna(x) or isinstance(x, str) + column1_datatype = Engine.dtype(result.column1.dtype) + assert column1_datatype == Engine.dtype(int) + + column2_datatype = Engine.dtype(result.column2.dtype) + assert column2_datatype == Engine.dtype(datetime) # make sure that correct error is raised when null values are present # in a float column that's coerced to an int - schema = DataFrameSchema({"column4": Column(Int, coerce=True)}) + schema = DataFrameSchema({"column4": Column(int, coerce=True)}) with pytest.raises( errors.SchemaError, match=r"^Error while coercing .* to type u{0,1}int[0-9]{1,2}: " @@ -571,74 +549,7 @@ def test_coerce_dtype_in_dataframe() -> None: schema.validate(df) -@pytest.mark.parametrize( - "data, dtype, nonnull_idx", - [ - # some values are null - [["foobar", "foo", "bar", "baz", np.nan, np.nan], str, 4], - [["foobar", "foo", "bar", "baz", None, None], str, 4], - # some values are null, non-null values are not strings - [[1.0, 2.0, 3.0, 4.0, np.nan, np.nan], float, 4], - [[1, 2, 3, 4, None, None], "Int64", 4], - # all values are null - [[np.nan] * 6, object, 0], - [[None] * 6, object, 0], - [[np.nan] * 6, float, 0], - [[None] * 6, float, 0], - ], -) -@pytest.mark.parametrize("string_type", [String, str, "str", STRING, "string"]) -@pytest.mark.parametrize("nullable", [True, False]) -def test_coerce_dtype_nullable_str( - data, dtype, nonnull_idx: int, string_type, nullable: bool -) -> None: - """Tests how null values are handled with string dtypes.""" - if LEGACY_PANDAS and ( - dtype == "Int64" or string_type in {STRING, "string"} - ): - pytest.skip("Skipping data types that depend on pandas>1.0.0") - dataframe = pd.DataFrame({"col": pd.Series(data, dtype=dtype)}) - schema = DataFrameSchema( - {"col": Column(string_type, coerce=True, nullable=nullable)} - ) - - if not nullable: - with 
pytest.raises(errors.SchemaError): - schema.validate(dataframe) - return - - validated_df = schema.validate(dataframe) - assert isinstance(validated_df, pd.DataFrame) - for i, element in validated_df["col"].iteritems(): - if i < nonnull_idx: - assert isinstance(element, str) - else: - assert pd.isna(element) - - -@pytest.mark.parametrize( - "data, expected_type", - [ - [{"a": 1, "b": 2, "c": 3}, dict], - [[1, 2, 3, 4], list], - [[1, {"a": 5}], list], - [{1, 2, 3}, set], - ], -) -@pytest.mark.parametrize("dtype", ["object", object, Object]) -def test_coerce_object_dtype( - data, expected_type: Type[Iterable], dtype -) -> None: - """Test coercing on object dtype.""" - schema = DataFrameSchema({"col": Column(dtype)}, coerce=True) - df = pd.DataFrame({"col": [data] * 3}) - validated_df = schema(df) - assert isinstance(validated_df, pd.DataFrame) - for _, x in validated_df["col"].iteritems(): - assert isinstance(x, expected_type) - - -def test_no_dtype_dataframe() -> None: +def test_no_dtype_dataframe(): """Test how nullability is handled in DataFrameSchemas where no type is specified.""" schema = DataFrameSchema({"col": Column(nullable=False)}) @@ -688,7 +599,7 @@ def test_required() -> None: isn't available. """ schema = DataFrameSchema( - {"col1": Column(Int, required=False), "col2": Column(String)} + {"col1": Column(int, required=False), "col2": Column(str)} ) df_ok_1 = pd.DataFrame({"col2": ["hello", "world"]}) @@ -741,7 +652,7 @@ def test_head_dataframe_schema() -> None: ) schema = DataFrameSchema( - columns={"col1": Column(Int, Check(lambda s: s >= 0))} + columns={"col1": Column(int, Check(lambda s: s >= 0))} ) # Validating with head of 100 should pass @@ -757,7 +668,7 @@ def test_tail_dataframe_schema() -> None: ) schema = DataFrameSchema( - columns={"col1": Column(Int, Check(lambda s: s < 0))} + columns={"col1": Column(int, Check(lambda s: s < 0))} ) # Validating with tail of 1000 should pass @@ -772,7 +683,7 @@ def test_sample_dataframe_schema() -> None: # assert all values -1 schema = DataFrameSchema( - columns={"col1": Column(Int, Check(lambda s: s == -1))} + columns={"col1": Column(int, Check(lambda s: s == -1))} ) for seed in [11, 123456, 9000, 654]: @@ -786,11 +697,11 @@ def test_dataframe_schema_str_repr() -> None: printing/logging of a DataFrameSchema.""" schema = DataFrameSchema( columns={ - "col1": Column(Int), - "col2": Column(String), - "col3": Column(DateTime), + "col1": Column(int), + "col2": Column(str), + "col3": Column(datetime), }, - index=Index(Int, name="my_index"), + index=Index(int, name="my_index"), ) for x in [schema.__str__(), schema.__repr__()]: @@ -804,62 +715,52 @@ def test_dataframe_schema_dtype_property() -> None: """Test that schema.dtype returns the matching Column types.""" schema = DataFrameSchema( columns={ - "col1": Column(Int), - "col2": Column(String), - "col3": Column(STRING), - "col4": Column(DateTime), - "col5": Column("uint16"), + "col1": Column(int), + "col2": Column(str), + "col3": Column(datetime), + "col4": Column("uint16"), } ) - assert schema.dtype == { - "col1": "int64", - "col2": "object", - "col3": ("object" if LEGACY_PANDAS else "string"), - "col4": "datetime64[ns]", - "col5": "uint16", + assert schema.dtypes == { + "col1": Engine.dtype("int64"), + "col2": Engine.dtype("str"), + "col3": Engine.dtype("datetime64[ns]"), + "col4": Engine.dtype("uint16"), } -@pytest.mark.parametrize("pandas_dtype, expected", TESTABLE_DTYPES) -def test_series_schema_dtype_property( - pandas_dtype: Union[PandasDtype, str], expected: str -) -> None: - 
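# ---------------------------------------------------------------------------
# NOTE (editorial sketch, not part of the patch): the schema.dtype ->
# schema.dtypes rename asserted above means the property now maps column
# names to engine DataType objects rather than raw string aliases. A small
# sketch, assuming the platform default integer resolves to int64 as in the
# test expectations:
import pandera as pa
from pandera.engines.pandas_engine import Engine

schema = pa.DataFrameSchema(
    {"a": pa.Column(int), "b": pa.Column("uint16")}
)
assert schema.dtypes == {
    "a": Engine.dtype("int64"),
    "b": Engine.dtype("uint16"),
}
# ---------------------------------------------------------------------------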
"""Tests every type of allowed dtype.""" - assert SeriesSchema(pandas_dtype).dtype == expected - - -def test_schema_equality_operators() -> None: +def test_schema_equality_operators(): """Test the usage of == for DataFrameSchema, SeriesSchema and SeriesSchemaBase.""" df_schema = DataFrameSchema( { - "col1": Column(Int, Check(lambda s: s >= 0)), - "col2": Column(String, Check(lambda s: s >= 2)), + "col1": Column(int, Check(lambda s: s >= 0)), + "col2": Column(str, Check(lambda s: s >= 2)), }, strict=True, ) df_schema_columns_in_different_order = DataFrameSchema( { - "col2": Column(String, Check(lambda s: s >= 2)), - "col1": Column(Int, Check(lambda s: s >= 0)), + "col2": Column(str, Check(lambda s: s >= 2)), + "col1": Column(int, Check(lambda s: s >= 0)), }, strict=True, ) series_schema = SeriesSchema( - String, + str, checks=[Check(lambda s: s.str.startswith("foo"))], nullable=False, allow_duplicates=True, name="my_series", ) series_schema_base = SeriesSchemaBase( - String, + str, checks=[Check(lambda s: s.str.startswith("foo"))], nullable=False, allow_duplicates=True, name="my_series", ) - not_equal_schema = DataFrameSchema({"col1": Column(String)}, strict=False) + not_equal_schema = DataFrameSchema({"col1": Column(str)}, strict=False) assert df_schema == copy.deepcopy(df_schema) assert df_schema != not_equal_schema @@ -875,7 +776,7 @@ def test_add_and_remove_columns() -> None: modify the original underlying DataFrameSchema.""" schema1 = DataFrameSchema( { - "col1": Column(Int, Check(lambda s: s >= 0)), + "col1": Column(int, Check(lambda s: s >= 0)), }, strict=True, ) @@ -885,8 +786,8 @@ def test_add_and_remove_columns() -> None: # test that add_columns doesn't modify schema1 after add_columns: schema2 = schema1.add_columns( { - "col2": Column(String, Check(lambda x: x <= 0)), - "col3": Column(Object, Check(lambda x: x == 0)), + "col2": Column(str, Check(lambda x: x <= 0)), + "col3": Column(object, Check(lambda x: x == 0)), } ) @@ -897,9 +798,9 @@ def test_add_and_remove_columns() -> None: # test that add_columns changed schema1 into schema2: expected_schema_2 = DataFrameSchema( { - "col1": Column(Int, Check(lambda s: s >= 0)), - "col2": Column(String, Check(lambda x: x <= 0)), - "col3": Column(Object, Check(lambda x: x == 0)), + "col1": Column(int, Check(lambda s: s >= 0)), + "col2": Column(str, Check(lambda x: x <= 0)), + "col3": Column(object, Check(lambda x: x == 0)), }, strict=True, ) @@ -914,8 +815,8 @@ def test_add_and_remove_columns() -> None: # test that remove_columns has removed the changes as expected: expected_schema_3 = DataFrameSchema( { - "col1": Column(Int, Check(lambda s: s >= 0)), - "col3": Column(Object, Check(lambda x: x == 0)), + "col1": Column(int, Check(lambda s: s >= 0)), + "col3": Column(object, Check(lambda x: x == 0)), }, strict=True, ) @@ -926,7 +827,7 @@ def test_add_and_remove_columns() -> None: schema4 = schema2.remove_columns(["col2", "col3"]) expected_schema_4 = DataFrameSchema( - {"col1": Column(Int, Check(lambda s: s >= 0))}, strict=True + {"col1": Column(int, Check(lambda s: s >= 0))}, strict=True ) assert schema4 == expected_schema_4 == schema1 @@ -936,12 +837,12 @@ def test_add_and_remove_columns() -> None: schema2.remove_columns(["foo", "bar"]) -def test_schema_get_dtype() -> None: - """Test that schema dtype and get_dtype methods handle regex columns.""" +def test_schema_get_dtypes(): + """Test that schema dtype and get_dtypes methods handle regex columns.""" schema = DataFrameSchema( { - "col1": Column(Int), - "var*": Column(Float, regex=True), + 
"col1": Column(int), + "var*": Column(float, regex=True), } ) @@ -955,7 +856,7 @@ def test_schema_get_dtype() -> None: ) with pytest.warns(UserWarning) as record: - assert schema.dtype == {"col1": Int.str_alias} + assert schema.dtypes == {"col1": Engine.dtype(int)} assert len(record) == 1 assert ( record[0] # type: ignore[union-attr] @@ -963,11 +864,11 @@ def test_schema_get_dtype() -> None: .startswith("Schema has columns specified as regex column names:") ) - assert schema.get_dtype(data) == { - "col1": Int.str_alias, - "var1": Float.str_alias, - "var2": Float.str_alias, - "var3": Float.str_alias, + assert schema.get_dtypes(data) == { + "col1": Engine.dtype(int), + "var1": Engine.dtype(float), + "var2": Engine.dtype(float), + "var3": Engine.dtype(float), } @@ -983,24 +884,24 @@ def _assert_bool_case(old_schema, new_schema): assert not getattr(old_schema.columns["col"], bool_kwarg) assert getattr(new_schema.columns["col"], bool_kwarg) - return ( - Column(Int, **{bool_kwarg: False}), # type: ignore[arg-type] + return [ + Column(int, **{bool_kwarg: False}), "col", {bool_kwarg: True}, _assert_bool_case, - ) + ] @pytest.mark.parametrize( "column, column_to_update, update, assertion_fn", [ [ - Column(Int), + Column(int), "col", - {"pandas_dtype": String}, + {"dtype": str}, lambda old, new: [ - old.columns["col"].pandas_dtype is Int, - new.columns["col"].pandas_dtype is String, + old.columns["col"].dtype is int, + new.columns["col"].dtype is str, ], ], *[ @@ -1014,7 +915,7 @@ def _assert_bool_case(old_schema, new_schema): ] ], [ - Column(Int, checks=Check.greater_than(0)), + Column(int, checks=Check.greater_than(0)), "col", {"checks": Check.less_than(10)}, lambda old, new: [ @@ -1023,8 +924,8 @@ def _assert_bool_case(old_schema, new_schema): ], ], # error cases - [Column(Int), "col", {"name": "renamed_col"}, ValueError], - [Column(Int), "foobar", {}, ValueError], + [Column(int), "col", {"name": "renamed_col"}, ValueError], + [Column(int), "foobar", {}, ValueError], ], ) def test_dataframe_schema_update_column( @@ -1049,7 +950,7 @@ def test_rename_columns() -> None: rename_dict = {"col1": "col1_new_name", "col2": "col2_new_name"} schema_original = DataFrameSchema( - columns={"col1": Column(Int), "col2": Column(Float)} + columns={"col1": Column(int), "col2": Column(float)} ) schema_renamed = schema_original.rename_columns(rename_dict) @@ -1077,9 +978,9 @@ def test_rename_columns() -> None: ["col1", "col2"], DataFrameSchema( columns={ - "col1": Column(Int), - "col2": Column(Int), - "col3": Column(Int), + "col1": Column(int), + "col2": Column(int), + "col3": Column(int), } ), ), @@ -1087,10 +988,10 @@ def test_rename_columns() -> None: [("col1", "col1b"), ("col2", "col2b")], DataFrameSchema( columns={ - ("col1", "col1a"): Column(Int), - ("col1", "col1b"): Column(Int), - ("col2", "col2a"): Column(Int), - ("col2", "col2b"): Column(Int), + ("col1", "col1a"): Column(int), + ("col1", "col1b"): Column(int), + ("col2", "col2a"): Column(int), + ("col2", "col2b"): Column(int), } ), ), @@ -1113,16 +1014,16 @@ def test_lazy_dataframe_validation_error() -> None: """Test exceptions on lazy dataframe validation.""" schema = DataFrameSchema( columns={ - "int_col": Column(Int, Check.greater_than(5)), - "int_col2": Column(Int), - "float_col": Column(Float, Check.less_than(0)), - "str_col": Column(String, Check.isin(["foo", "bar"])), - "not_in_dataframe": Column(Int), + "int_col": Column(int, Check.greater_than(5)), + "int_col2": Column(int), + "float_col": Column(float, Check.less_than(0)), + "str_col": Column(str, 
Check.isin(["foo", "bar"])), + "not_in_dataframe": Column(int), }, checks=Check( lambda df: df != 1, error="dataframe_not_equal_1", ignore_na=False ), - index=Index(String, name="str_index"), + index=Index(str, name="str_index"), strict=True, ) @@ -1147,7 +1048,7 @@ def test_lazy_dataframe_validation_error() -> None: }, "Column": { "greater_than(5)": [1, 2], - "pandas_dtype('int64')": ["object"], + "dtype('int64')": ["object"], "less_than(0)": [1, 3], }, } @@ -1232,9 +1133,9 @@ def test_lazy_dataframe_validation_nullable() -> None: """ schema = DataFrameSchema( columns={ - "int_column": Column(Int, nullable=False), - "float_column": Column(Float, nullable=False), - "str_column": Column(String, nullable=False), + "int_column": Column(int, nullable=False), + "float_column": Column(float, nullable=False), + "str_column": Column(str, nullable=False), }, strict=True, ) @@ -1400,7 +1301,7 @@ def test_lazy_dataframe_scalar_false_check( { "data": pd.Series([0.1]), "schema_errors": { - "SeriesSchema": {"pandas_dtype('int64')": ["float64"]}, + "SeriesSchema": {"dtype('int64')": ["float64"]}, }, }, ], @@ -1411,7 +1312,7 @@ def test_lazy_dataframe_scalar_false_check( { "data": pd.Series([1, 2, 3], index=list("abc")), "schema_errors": { - "Index": {"pandas_dtype('int64')": ["object"]}, + "Index": {"dtype('int64')": ["object"]}, }, }, ], @@ -1423,7 +1324,7 @@ def test_lazy_dataframe_scalar_false_check( "data": pd.Series(["1", "foo", "bar"]), "schema_errors": { "SeriesSchema": { - "pandas_dtype('float64')": ["object"], + "dtype('float64')": ["object"], "coerce_dtype('float64')": ["object"], }, }, @@ -1456,7 +1357,7 @@ def test_lazy_dataframe_scalar_false_check( # TypeError raised in python=3.5 'TypeError("unorderable types: str() > int()")', ], - "pandas_dtype('int64')": ["object"], + "dtype('int64')": ["object"], }, }, }, @@ -1464,7 +1365,7 @@ def test_lazy_dataframe_scalar_false_check( # case: multiple series checks don't satisfy schema [ Column( - Int, + int, checks=[Check.greater_than(1), Check.less_than(3)], name="column", ), @@ -1477,7 +1378,7 @@ def test_lazy_dataframe_scalar_false_check( }, ], [ - Index(String, checks=Check.isin(["a", "b", "c"])), + Index(str, checks=Check.isin(["a", "b", "c"])), pd.DataFrame({"col": [1, 2, 3]}, index=["a", "b", "d"]), { # expect that the data in the SchemaError is the pd.Index cast @@ -1491,8 +1392,8 @@ def test_lazy_dataframe_scalar_false_check( [ MultiIndex( indexes=[ - Index(Int, checks=Check.greater_than(0), name="index0"), - Index(Int, checks=Check.less_than(0), name="index1"), + Index(int, checks=Check.greater_than(0), name="index0"), + Index(int, checks=Check.less_than(0), name="index1"), ] ), pd.DataFrame( @@ -1570,10 +1471,10 @@ def test_schema_coerce_inplace_validation( inplace: bool, from_dtype: Type, to_dtype: Type ) -> None: """Test coercion logic for validation when inplace is True and False""" - - to_dtype = PandasDtype.from_python_type(to_dtype).str_alias - from_dtype = PandasDtype.from_python_type(from_dtype).str_alias - + from_dtype = ( + from_dtype if from_dtype is not int else str(Engine.dtype(from_dtype)) + ) + to_dtype = to_dtype if to_dtype is not int else str(Engine.dtype(to_dtype)) df = pd.DataFrame({"column": pd.Series([1, 2, 6], dtype=from_dtype)}) schema = DataFrameSchema({"column": Column(to_dtype, coerce=True)}) validated_df = schema.validate(df, inplace=inplace) @@ -1592,10 +1493,10 @@ def schema_simple() -> DataFrameSchema: """Simple schema fixture.""" schema = DataFrameSchema( columns={ - "col1": Column(pandas_dtype=Int), - 
"col2": Column(pandas_dtype=Float), + "col1": Column(dtype=int), + "col2": Column(dtype=float), }, - index=Index(pandas_dtype=String, name="ind0"), + index=Index(dtype=str, name="ind0"), ) return schema @@ -1605,13 +1506,13 @@ def schema_multiindex() -> DataFrameSchema: """Fixture for schema with MultiIndex.""" schema = DataFrameSchema( columns={ - "col1": Column(pandas_dtype=Int), - "col2": Column(pandas_dtype=Float), + "col1": Column(dtype=int), + "col2": Column(dtype=float), }, index=MultiIndex( [ - Index(pandas_dtype=String, name="ind0"), - Index(pandas_dtype=String, name="ind1"), + Index(dtype=str, name="ind0"), + Index(dtype=str, name="ind1"), ] ), ) @@ -1699,51 +1600,43 @@ def test_update_columns(schema_simple: DataFrameSchema) -> None: """Catch-all test for update columns functionality""" # Basic function - test_schema = schema_simple.update_columns({"col2": {"pandas_dtype": Int}}) + test_schema = schema_simple.update_columns({"col2": {"dtype": int}}) assert ( schema_simple.columns["col1"].properties == test_schema.columns["col1"].properties ) - assert test_schema.columns["col2"].pandas_dtype == Int + assert test_schema.columns["col2"].dtype == Engine.dtype(int) # Multiple columns, multiple properties test_schema = schema_simple.update_columns( { - "col1": {"pandas_dtype": Category, "coerce": True}, - "col2": {"pandas_dtype": Int, "allow_duplicates": False}, + "col1": {"dtype": Category, "coerce": True}, + "col2": {"dtype": int, "allow_duplicates": False}, } ) - assert test_schema.columns["col1"].pandas_dtype == Category + assert test_schema.columns["col1"].dtype == Engine.dtype(Category) assert test_schema.columns["col1"].coerce is True - assert test_schema.columns["col2"].pandas_dtype == Int + assert test_schema.columns["col2"].dtype == Engine.dtype(int) assert test_schema.columns["col2"].allow_duplicates is False # Errors with pytest.raises(errors.SchemaInitError): - schema_simple.update_columns({"col3": {"pandas_dtype": Int}}) + schema_simple.update_columns({"col3": {"dtype": int}}) with pytest.raises(errors.SchemaInitError): schema_simple.update_columns({"col1": {"name": "foo"}}) with pytest.raises(errors.SchemaInitError): - schema_simple.update_columns({"ind0": {"pandas_dtype": Int}}) + schema_simple.update_columns({"ind0": {"dtype": int}}) -@pytest.mark.parametrize("pdtype", list(PandasDtype) + [None]) # type: ignore -def test_series_schema_pdtype(pdtype: Optional[PandasDtype]) -> None: - """Series schema pdtype property should return PandasDtype.""" - if pdtype is None: - series_schema = SeriesSchema(pdtype) - assert series_schema.pdtype is None - return - for pandas_dtype_input in [ - pdtype, - pdtype.str_alias, - pdtype.value, - ]: - series_schema = SeriesSchema(pandas_dtype_input) - if pdtype is PandasDtype.STRING and LEGACY_PANDAS: - assert series_schema.pdtype == PandasDtype.String - else: - assert series_schema.pdtype == pdtype +@pytest.mark.parametrize("dtype", [int, None]) # type: ignore +def test_series_schema_dtype(dtype): + """Series schema dtype property should return + a Engine-compatible dtype.""" + if dtype is None: + series_schema = SeriesSchema(dtype) + assert series_schema.dtype is None + else: + assert SeriesSchema(dtype).dtype == Engine.dtype(dtype) @pytest.mark.parametrize( @@ -1788,7 +1681,7 @@ def test_dataframe_duplicated_columns(data, error, schema) -> None: columns={"col": Column(int)}, checks=Check.gt(0), index=Index(int), - pandas_dtype=int, + dtype=int, coerce=True, strict=True, name="schema", @@ -1798,7 +1691,7 @@ def 
test_dataframe_duplicated_columns(data, error, schema) -> None: "columns", "checks", "index", - "pandas_dtype", + "dtype", "coerce", "strict", "name", @@ -1834,7 +1727,6 @@ def test_schema_str_repr(schema, fields: List[str]) -> None: schema.__str__(), schema.__repr__(), ]: - print(x) assert x.startswith(f"") for field in fields: diff --git a/tests/core/test_typing.py b/tests/core/test_typing.py index a268ae9eb..78f58c801 100644 --- a/tests/core/test_typing.py +++ b/tests/core/test_typing.py @@ -8,7 +8,7 @@ import pytest import pandera as pa -from pandera.dtypes import LEGACY_PANDAS, PANDAS_1_3_0_PLUS, PandasDtype +from pandera.dtypes import DataType from pandera.typing import LEGACY_TYPING, Series if not LEGACY_TYPING: @@ -127,7 +127,7 @@ class SchemaUINT64(pa.SchemaModel): def _test_literal_pandas_dtype( - model: Type[pa.SchemaModel], pandas_dtype: PandasDtype + model: Type[pa.SchemaModel], pandas_dtype: DataType ): schema = model.to_schema() expected = pa.Column(pandas_dtype, name="col").dtype @@ -159,13 +159,12 @@ def _test_literal_pandas_dtype( ], ) def test_literal_legacy_pandas_dtype( - model: Type[pa.SchemaModel], pandas_dtype: PandasDtype + model: Type[pa.SchemaModel], pandas_dtype: DataType ): """Test literal annotations with the legacy pandas dtypes.""" _test_literal_pandas_dtype(model, pandas_dtype) -@pytest.mark.skipif(LEGACY_PANDAS, reason="pandas >= 1.0.0 required") @pytest.mark.parametrize( "model, pandas_dtype", [ @@ -180,7 +179,7 @@ def test_literal_legacy_pandas_dtype( ], ) def test_literal_new_pandas_dtype( - model: Type[pa.SchemaModel], pandas_dtype: PandasDtype + model: Type[pa.SchemaModel], pandas_dtype: DataType ): """Test literal annotations with the new nullable pandas dtypes.""" _test_literal_pandas_dtype(model, pandas_dtype) @@ -313,7 +312,7 @@ class SchemaAnnotatedCategoricalDtype(pa.SchemaModel): class SchemaAnnotatedDatetimeTZDtype(pa.SchemaModel): col: Series[Annotated[pd.DatetimeTZDtype, "ns", "est"]] - if PANDAS_1_3_0_PLUS: + if pa.PANDAS_1_3_0_PLUS: class SchemaAnnotatedIntervalDtype(pa.SchemaModel): col: Series[Annotated[pd.IntervalDtype, "int32", "both"]] @@ -347,7 +346,7 @@ class SchemaAnnotatedSparseDtype(pa.SchemaModel): pd.IntervalDtype, ( {"subtype": "int32", "closed": "both"} - if PANDAS_1_3_0_PLUS + if pa.PANDAS_1_3_0_PLUS else {"subtype": "int32"} ), ), @@ -395,58 +394,65 @@ def test_pandas_extension_dtype_redundant_field(): SchemaRedundantField.to_schema() -if not LEGACY_PANDAS: +class SchemaInt8Dtype(pa.SchemaModel): + col: Series[pd.Int8Dtype] - class SchemaInt8Dtype(pa.SchemaModel): - col: Series[pd.Int8Dtype] - class SchemaInt16Dtype(pa.SchemaModel): - col: Series[pd.Int16Dtype] +class SchemaInt16Dtype(pa.SchemaModel): + col: Series[pd.Int16Dtype] - class SchemaInt32Dtype(pa.SchemaModel): - col: Series[pd.Int32Dtype] - class SchemaInt64Dtype(pa.SchemaModel): - col: Series[pd.Int64Dtype] +class SchemaInt32Dtype(pa.SchemaModel): + col: Series[pd.Int32Dtype] - class SchemaUInt8Dtype(pa.SchemaModel): - col: Series[pd.UInt8Dtype] - class SchemaUInt16Dtype(pa.SchemaModel): - col: Series[pd.UInt16Dtype] +class SchemaInt64Dtype(pa.SchemaModel): + col: Series[pd.Int64Dtype] - class SchemaUInt32Dtype(pa.SchemaModel): - col: Series[pd.UInt32Dtype] - class SchemaUInt64Dtype(pa.SchemaModel): - col: Series[pd.UInt64Dtype] +class SchemaUInt8Dtype(pa.SchemaModel): + col: Series[pd.UInt8Dtype] - class SchemaStringDtype(pa.SchemaModel): - col: Series[pd.StringDtype] - class SchemaBooleanDtype(pa.SchemaModel): - col: Series[pd.BooleanDtype] +class 
SchemaUInt16Dtype(pa.SchemaModel): + col: Series[pd.UInt16Dtype] - @pytest.mark.skipif(LEGACY_PANDAS, reason="pandas >= 1.0.0 required") - @pytest.mark.parametrize( - "model, dtype, has_mandatory_args", - [ - (SchemaInt8Dtype, pd.Int8Dtype, False), - (SchemaInt16Dtype, pd.Int16Dtype, False), - (SchemaInt32Dtype, pd.Int32Dtype, False), - (SchemaInt64Dtype, pd.Int64Dtype, False), - (SchemaUInt8Dtype, pd.UInt8Dtype, False), - (SchemaUInt16Dtype, pd.UInt16Dtype, False), - (SchemaUInt32Dtype, pd.UInt32Dtype, False), - (SchemaUInt64Dtype, pd.UInt64Dtype, False), - (SchemaStringDtype, pd.StringDtype, False), - (SchemaBooleanDtype, pd.BooleanDtype, False), - ], - ) - def test_new_pandas_extension_dtype_class( - model, - dtype: pd.core.dtypes.base.ExtensionDtype, - has_mandatory_args: bool, - ): - """Test type annotations with the new nullable pandas dtypes.""" - _test_default_annotated_dtype(model, dtype, has_mandatory_args) + +class SchemaUInt32Dtype(pa.SchemaModel): + col: Series[pd.UInt32Dtype] + + +class SchemaUInt64Dtype(pa.SchemaModel): + col: Series[pd.UInt64Dtype] + + +class SchemaStringDtype(pa.SchemaModel): + col: Series[pd.StringDtype] + + +class SchemaBooleanDtype(pa.SchemaModel): + col: Series[pd.BooleanDtype] + + +@pytest.mark.parametrize( + "model, dtype, has_mandatory_args", + [ + (SchemaInt8Dtype, pd.Int8Dtype, False), + (SchemaInt16Dtype, pd.Int16Dtype, False), + (SchemaInt32Dtype, pd.Int32Dtype, False), + (SchemaInt64Dtype, pd.Int64Dtype, False), + (SchemaUInt8Dtype, pd.UInt8Dtype, False), + (SchemaUInt16Dtype, pd.UInt16Dtype, False), + (SchemaUInt32Dtype, pd.UInt32Dtype, False), + (SchemaUInt64Dtype, pd.UInt64Dtype, False), + (SchemaStringDtype, pd.StringDtype, False), + (SchemaBooleanDtype, pd.BooleanDtype, False), + ], +) +def test_new_pandas_extension_dtype_class( + model, + dtype: pd.core.dtypes.base.ExtensionDtype, + has_mandatory_args: bool, +): + """Test type annotations with the new nullable pandas dtypes.""" + _test_default_annotated_dtype(model, dtype, has_mandatory_args) diff --git a/tests/io/test_io.py b/tests/io/test_io.py index 032adc710..f1b7507c5 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -9,9 +9,10 @@ import pytest from packaging import version -import pandera as pa +import pandera import pandera.extensions as pa_ext import pandera.typing as pat +from pandera.engines import pandas_engine try: from pandera import io @@ -41,69 +42,69 @@ def _create_schema(index="single"): if index == "multi": - index = pa.MultiIndex( + index = pandera.MultiIndex( [ - pa.Index(pa.Int, name="int_index0"), - pa.Index(pa.Int, name="int_index1"), - pa.Index(pa.Int, name="int_index2"), + pandera.Index(pandera.Int, name="int_index0"), + pandera.Index(pandera.Int, name="int_index1"), + pandera.Index(pandera.Int, name="int_index2"), ] ) elif index == "single": # make sure io modules can handle case when index name is None - index = pa.Index(pa.Int, name=None) + index = pandera.Index(pandera.Int, name=None) else: index = None - return pa.DataFrameSchema( + return pandera.DataFrameSchema( columns={ - "int_column": pa.Column( - pa.Int, + "int_column": pandera.Column( + pandera.Int, checks=[ - pa.Check.greater_than(0), - pa.Check.less_than(10), - pa.Check.in_range(0, 10), + pandera.Check.greater_than(0), + pandera.Check.less_than(10), + pandera.Check.in_range(0, 10), ], ), - "float_column": pa.Column( - pa.Float, + "float_column": pandera.Column( + pandera.Float, checks=[ - pa.Check.greater_than(-10), - pa.Check.less_than(20), - pa.Check.in_range(-10, 20), + 
pandera.Check.greater_than(-10), + pandera.Check.less_than(20), + pandera.Check.in_range(-10, 20), ], ), - "str_column": pa.Column( - pa.String, + "str_column": pandera.Column( + pandera.String, checks=[ - pa.Check.isin(["foo", "bar", "x", "xy"]), - pa.Check.str_length(1, 3), + pandera.Check.isin(["foo", "bar", "x", "xy"]), + pandera.Check.str_length(1, 3), ], ), - "datetime_column": pa.Column( - pa.DateTime, + "datetime_column": pandera.Column( + pandera.DateTime, checks=[ - pa.Check.greater_than(pd.Timestamp("20100101")), - pa.Check.less_than(pd.Timestamp("20200101")), + pandera.Check.greater_than(pd.Timestamp("20100101")), + pandera.Check.less_than(pd.Timestamp("20200101")), ], ), - "timedelta_column": pa.Column( - pa.Timedelta, + "timedelta_column": pandera.Column( + pandera.Timedelta, checks=[ - pa.Check.greater_than(pd.Timedelta(1000, unit="ns")), - pa.Check.less_than(pd.Timedelta(10000, unit="ns")), + pandera.Check.greater_than(pd.Timedelta(1000, unit="ns")), + pandera.Check.less_than(pd.Timedelta(10000, unit="ns")), ], ), - "optional_props_column": pa.Column( - pa.String, + "optional_props_column": pandera.Column( + pandera.String, nullable=True, allow_duplicates=True, coerce=True, required=False, regex=True, - checks=[pa.Check.str_length(1, 3)], + checks=[pandera.Check.str_length(1, 3)], ), - "notype_column": pa.Column( - checks=pa.Check.isin(["foo", "bar", "x", "xy"]), + "notype_column": pandera.Column( + checks=pandera.Check.isin(["foo", "bar", "x", "xy"]), ), }, index=index, @@ -114,10 +115,10 @@ def _create_schema(index="single"): YAML_SCHEMA = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int + dtype: int64 nullable: false checks: greater_than: 0 @@ -130,7 +131,7 @@ def _create_schema(index="single"): required: true regex: false float_column: - pandas_dtype: float + dtype: float64 nullable: false checks: greater_than: -10 @@ -143,7 +144,7 @@ def _create_schema(index="single"): required: true regex: false str_column: - pandas_dtype: str + dtype: str nullable: false checks: isin: @@ -159,7 +160,7 @@ def _create_schema(index="single"): required: true regex: false datetime_column: - pandas_dtype: datetime64[ns] + dtype: datetime64[ns] nullable: false checks: greater_than: '2010-01-01 00:00:00' @@ -169,7 +170,7 @@ def _create_schema(index="single"): required: true regex: false timedelta_column: - pandas_dtype: timedelta64[ns] + dtype: timedelta64[ns] nullable: false checks: greater_than: 1000 @@ -179,7 +180,7 @@ def _create_schema(index="single"): required: true regex: false optional_props_column: - pandas_dtype: str + dtype: str nullable: true checks: str_length: @@ -190,7 +191,7 @@ def _create_schema(index="single"): required: false regex: true notype_column: - pandas_dtype: null + dtype: null nullable: false checks: isin: @@ -204,7 +205,7 @@ def _create_schema(index="single"): regex: false checks: null index: -- pandas_dtype: int +- dtype: int64 nullable: false checks: null name: null @@ -216,21 +217,21 @@ def _create_schema(index="single"): def _create_schema_null_index(): - return pa.DataFrameSchema( + return pandera.DataFrameSchema( columns={ - "float_column": pa.Column( - pa.Float, + "float_column": pandera.Column( + pandera.Float, checks=[ - pa.Check.greater_than(-10), - pa.Check.less_than(20), - pa.Check.in_range(-10, 20), + pandera.Check.greater_than(-10), + pandera.Check.less_than(20), + pandera.Check.in_range(-10, 20), ], ), - "str_column": pa.Column( - pa.String, + "str_column": 
pandera.Column( + pandera.String, checks=[ - pa.Check.isin(["foo", "bar", "x", "xy"]), - pa.Check.str_length(1, 3), + pandera.Check.isin(["foo", "bar", "x", "xy"]), + pandera.Check.str_length(1, 3), ], ), }, @@ -240,10 +241,10 @@ def _create_schema_null_index(): YAML_SCHEMA_NULL_INDEX = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: float_column: - pandas_dtype: float + dtype: float64 nullable: false checks: greater_than: -10 @@ -252,7 +253,7 @@ def _create_schema_null_index(): min_value: -10 max_value: 20 str_column: - pandas_dtype: str + dtype: str nullable: false checks: isin: @@ -271,28 +272,28 @@ def _create_schema_null_index(): def _create_schema_python_types(): - return pa.DataFrameSchema( + return pandera.DataFrameSchema( { - "int_column": pa.Column(int), - "float_column": pa.Column(float), - "str_column": pa.Column(str), - "object_column": pa.Column(object), + "int_column": pandera.Column(int), + "float_column": pandera.Column(float), + "str_column": pandera.Column(str), + "object_column": pandera.Column(object), } ) YAML_SCHEMA_PYTHON_TYPES = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object checks: null index: null coerce: false @@ -302,16 +303,16 @@ def _create_schema_python_types(): YAML_SCHEMA_MISSING_GLOBAL_CHECK = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object checks: unregistered_check: stat1: missing_str_stat @@ -324,20 +325,20 @@ def _create_schema_python_types(): YAML_SCHEMA_MISSING_COLUMN_CHECK = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 checks: unregistered_check: stat1: missing_str_stat stat2: 11 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object index: null coerce: false strict: false @@ -357,7 +358,7 @@ def test_inferred_schema_io(): "column3": ["a", "b", "c"], } ) - schema = pa.infer_schema(df) + schema = pandera.infer_schema(df) schema_yaml_str = schema.to_yaml() schema_from_yaml = io.from_yaml(schema_yaml_str) assert schema == schema_from_yaml @@ -371,6 +372,10 @@ def test_to_yaml(): """Test that to_yaml writes to yaml string.""" schema = _create_schema() yaml_str = io.to_yaml(schema) + with tempfile.NamedTemporaryFile("w+") as f: + f.write(yaml_str) + with tempfile.NamedTemporaryFile("w+") as f: + f.write(YAML_SCHEMA) assert yaml_str.strip() == YAML_SCHEMA.strip() yaml_str_schema_method = schema.to_yaml() @@ -412,7 +417,7 @@ def test_from_yaml_load_required_fields(): io.from_yaml("") with pytest.raises( - pa.errors.SchemaDefinitionError, match=".*must be a mapping.*" + pandera.errors.SchemaDefinitionError, match=".*must be a mapping.*" ): io.from_yaml( """ @@ -430,7 +435,7 @@ def test_io_yaml_file_obj(): output = schema.to_yaml(f) assert output is None f.seek(0) - schema_from_yaml = pa.DataFrameSchema.from_yaml(f) + schema_from_yaml = pandera.DataFrameSchema.from_yaml(f) assert schema_from_yaml == schema @@ -454,7 
+459,7 @@ def test_io_yaml(index): with tempfile.NamedTemporaryFile("w+") as f: output = schema.to_yaml(Path(f.name)) assert output is None - schema_from_yaml = pa.DataFrameSchema.from_yaml(Path(f.name)) + schema_from_yaml = pandera.DataFrameSchema.from_yaml(Path(f.name)) assert schema_from_yaml == schema @@ -488,44 +493,48 @@ def test_to_script(index): def test_to_script_lambda_check(): """Test writing DataFrameSchema to a script with lambda check.""" - schema1 = pa.DataFrameSchema( + schema1 = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + "a": pandera.Column( + pandera.Int, + checks=pandera.Check( + lambda s: s.mean() > 5, element_wise=False + ), ), } ) with pytest.warns(UserWarning): - pa.io.to_script(schema1) + pandera.io.to_script(schema1) - schema2 = pa.DataFrameSchema( + schema2 = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + checks=pandera.Check(lambda s: s.mean() > 5, element_wise=False), ) with pytest.warns(UserWarning, match=".*registered checks.*"): - pa.io.to_script(schema2) + pandera.io.to_script(schema2) def test_to_yaml_lambda_check(): """Test writing DataFrameSchema to a yaml with lambda check.""" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + "a": pandera.Column( + pandera.Int, + checks=pandera.Check( + lambda s: s.mean() > 5, element_wise=False + ), ), } ) with pytest.warns(UserWarning): - pa.io.to_yaml(schema) + pandera.io.to_yaml(schema) def test_format_checks_warning(): @@ -555,24 +564,24 @@ def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool: return len(pandas_obj.columns) > column_count assert ( - len(pa.Check.REGISTERED_CUSTOM_CHECKS) == 1 + len(pandera.Check.REGISTERED_CUSTOM_CHECKS) == 1 ), "custom check is registered" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=[pa.Check.ncols_gt(column_count=5)], + checks=[pandera.Check.ncols_gt(column_count=5)], ) - serialized = pa.io.to_yaml(schema) - loaded = pa.io.from_yaml(serialized) + serialized = pandera.io.to_yaml(schema) + loaded = pandera.io.from_yaml(serialized) assert len(loaded.checks) == 1, "global check was stripped" - with pytest.raises(pa.errors.SchemaError): + with pytest.raises(pandera.errors.SchemaError): schema.validate(pd.DataFrame(data={"a": [1]})) assert ncols_gt_called, "did not call ncols_gt" @@ -581,17 +590,17 @@ def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool: def test_to_yaml_custom_dataframe_check(): """Tests that writing DataFrameSchema with an unregistered check raises.""" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=[pa.Check(lambda obj: len(obj.index) > 1)], + checks=[pandera.Check(lambda obj: len(obj.index) > 1)], ) with pytest.warns(UserWarning, match=".*registered checks.*"): - pa.io.to_yaml(schema) + pandera.io.to_yaml(schema) # the unregistered column check case is tested in # `test_to_yaml_lambda_check` @@ -601,13 +610,13 @@ def test_to_yaml_bugfix_419(): """Ensure that GH#419 is fixed""" # pylint: disable=no-self-use - class CheckedSchemaModel(pa.SchemaModel): + class CheckedSchemaModel(pandera.SchemaModel): """Schema with a global check""" a: 
pat.Series[pat.Int64] b: pat.Series[pat.Int64] - @pa.dataframe_check() + @pandera.dataframe_check() def unregistered_check(self, _): """sample unregistered check""" ... @@ -706,17 +715,17 @@ def unregistered_check(self, _): } # pandas dtype aliases to support testing across multiple pandas versions: -STR_DTYPE = pa.dtypes.PandasDtype.from_str_alias("string").value -STR_DTYPE_ALIAS = pa.dtypes.PandasDtype.from_str_alias("string").str_alias -INT_DTYPE = pa.dtypes.PandasDtype.from_str_alias("int").value -INT_DTYPE_ALIAS = pa.dtypes.PandasDtype.from_str_alias("int").str_alias +STR_DTYPE = pandas_engine.Engine.dtype("string") +STR_DTYPE_ALIAS = str(pandas_engine.Engine.dtype("string")) +INT_DTYPE = pandas_engine.Engine.dtype("int") +INT_DTYPE_ALIAS = str(pandas_engine.Engine.dtype("int")) YAML_FROM_FRICTIONLESS = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: integer_col: - pandas_dtype: {INT_DTYPE} + dtype: {INT_DTYPE} nullable: false checks: in_range: @@ -727,7 +736,7 @@ def unregistered_check(self, _): required: true regex: false integer_col_2: - pandas_dtype: {INT_DTYPE} + dtype: {INT_DTYPE} nullable: true checks: less_than_or_equal_to: 30 @@ -736,7 +745,7 @@ def unregistered_check(self, _): required: true regex: false string_col: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: @@ -747,7 +756,7 @@ def unregistered_check(self, _): required: true regex: false string_col_2: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_matches: ^\\d{{3}}[A-Z]$ @@ -756,7 +765,7 @@ def unregistered_check(self, _): required: true regex: false string_col_3: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: 3 @@ -765,7 +774,7 @@ def unregistered_check(self, _): required: true regex: false string_col_4: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: 3 @@ -774,7 +783,7 @@ def unregistered_check(self, _): required: true regex: false float_col: - pandas_dtype: category + dtype: category nullable: false checks: isin: @@ -786,7 +795,7 @@ def unregistered_check(self, _): required: true regex: false float_col_2: - pandas_dtype: float + dtype: float64 nullable: true checks: null allow_duplicates: true @@ -794,7 +803,7 @@ def unregistered_check(self, _): required: true regex: false date_col: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: greater_than_or_equal_to: '20201231' @@ -848,12 +857,12 @@ def unregistered_check(self, _): ) def test_frictionless_schema_parses_correctly(frictionless_schema): """Test parsing frictionless schema from yaml and json.""" - schema = pa.io.from_frictionless_schema(frictionless_schema) + schema = pandera.io.from_frictionless_schema(frictionless_schema) assert str(schema.to_yaml()).strip() == YAML_FROM_FRICTIONLESS.strip() assert isinstance( - schema, pa.schemas.DataFrameSchema + schema, pandera.schemas.DataFrameSchema ), "schema object not loaded successfully" df = schema.validate(VALID_FRICTIONLESS_DF) @@ -871,7 +880,7 @@ def test_frictionless_schema_parses_correctly(frictionless_schema): "date_col": STR_DTYPE_ALIAS, }, "dtypes not parsed correctly from frictionless schema" - with pytest.raises(pa.errors.SchemaErrors) as err: + with pytest.raises(pandera.errors.SchemaErrors) as err: schema.validate(INVALID_FRICTIONLESS_DF, lazy=True) # check we're capturing all errors according to the frictionless schema: assert err.value.failure_cases[["check", "failure_case"]].fillna( diff --git 
a/tests/strategies/test_strategies.py b/tests/strategies/test_strategies.py index ee293f582..f5ed43dd5 100644 --- a/tests/strategies/test_strategies.py +++ b/tests/strategies/test_strategies.py @@ -1,8 +1,6 @@ # pylint: disable=undefined-variable,redefined-outer-name,invalid-name,undefined-loop-variable # noqa """Unit tests for pandera data generating strategies.""" - import operator -import platform import re from typing import Any, Callable, Optional from unittest.mock import MagicMock @@ -14,6 +12,8 @@ import pandera as pa from pandera import strategies from pandera.checks import _CheckBase, register_check_statistics +from pandera.dtypes import is_category, is_complex, is_float +from pandera.engines import pandas_engine try: import hypothesis @@ -27,88 +27,102 @@ HAS_HYPOTHESIS = True -TYPE_ERROR_FMT = "data generation for the {} dtype is currently unsupported" - -SUPPORTED_DTYPES = [] -for pdtype in pa.PandasDtype: +SUPPORTED_DTYPES = set() +for data_type in pandas_engine.Engine.get_registered_dtypes(): if ( - pdtype is pa.PandasDtype.Complex256 and platform.system() == "Windows" - ) or pdtype is pa.Category: + # valid hypothesis.strategies.floats <=64 + getattr(data_type, "bit_width", -1) > 64 + or is_category(data_type) + or data_type + in (pandas_engine.Interval, pandas_engine.Period, pandas_engine.Sparse) + ): continue - SUPPORTED_DTYPES.append(pdtype) + SUPPORTED_DTYPES.add(pandas_engine.Engine.dtype(data_type)) NUMERIC_DTYPES = [ - pdtype for pdtype in SUPPORTED_DTYPES if pdtype.is_continuous + data_type for data_type in SUPPORTED_DTYPES if data_type.continuous ] NULLABLE_DTYPES = [ - pdtype - for pdtype in SUPPORTED_DTYPES - if not pdtype.is_complex - and not pdtype.is_category - and not pdtype.is_object + data_type + for data_type in SUPPORTED_DTYPES + if not is_complex(data_type) + and not is_category(data_type) + and not data_type == pandas_engine.Engine.dtype("object") ] NUMERIC_RANGE_CONSTANT = 10 DATE_RANGE_CONSTANT = np.timedelta64(NUMERIC_RANGE_CONSTANT, "D") COMPLEX_RANGE_CONSTANT = np.complex64( - complex(NUMERIC_RANGE_CONSTANT, NUMERIC_RANGE_CONSTANT) + complex(NUMERIC_RANGE_CONSTANT, NUMERIC_RANGE_CONSTANT) # type: ignore ) -@pytest.mark.parametrize("pdtype", [pa.Category]) -def test_unsupported_pandas_dtype_strategy(pdtype: pa.PandasDtype) -> None: +@pytest.mark.parametrize("data_type", [pa.Category]) +def test_unsupported_pandas_dtype_strategy(data_type): """Test unsupported pandas dtype strategy raises error.""" - with pytest.raises(TypeError, match=TYPE_ERROR_FMT.format(pdtype.name)): - strategies.pandas_dtype_strategy(pdtype) + with pytest.raises( + TypeError, + match="data generation for the Category dtype is currently unsupported", + ): + strategies.pandas_dtype_strategy(data_type) -@pytest.mark.parametrize("pdtype", SUPPORTED_DTYPES) +@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES) @hypothesis.given(st.data()) -def test_pandas_dtype_strategy(pdtype: pa.PandasDtype, data) -> None: +@hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + max_examples=20, +) +def test_pandas_dtype_strategy(data_type, data): """Test that series can be constructed from pandas dtype.""" - strategy = strategies.pandas_dtype_strategy(pdtype) + strategy = strategies.pandas_dtype_strategy(data_type) example = data.draw(strategy) - expected_type = ( - pdtype.String.numpy_dtype.type # type: ignore[attr-defined] - if pdtype is pa.Object - else pdtype.numpy_dtype.type - ) - + expected_type = strategies.to_numpy_dtype(data_type).type assert 
example.dtype.type == expected_type - chained_strategy = strategies.pandas_dtype_strategy(pdtype, strategy) + chained_strategy = strategies.pandas_dtype_strategy(data_type, strategy) chained_example = data.draw(chained_strategy) assert chained_example.dtype.type == expected_type -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_check_strategy_continuous(pdtype: pa.PandasDtype, data) -> None: +def test_check_strategy_continuous(data_type, data): """Test built-in check strategies can generate continuous data.""" + np_dtype = strategies.to_numpy_dtype(data_type) value = data.draw( npst.from_dtype( - pdtype.numpy_dtype, + strategies.to_numpy_dtype(data_type), allow_nan=False, allow_infinity=False, ) ) - pdtype = pa.PandasDtype.Int - value = data.draw(npst.from_dtype(pdtype.numpy_dtype)) - assert data.draw(strategies.ne_strategy(pdtype, value=value)) != value - assert data.draw(strategies.eq_strategy(pdtype, value=value)) == value - assert data.draw(strategies.gt_strategy(pdtype, min_value=value)) > value - assert data.draw(strategies.ge_strategy(pdtype, min_value=value)) >= value - assert data.draw(strategies.lt_strategy(pdtype, max_value=value)) < value - assert data.draw(strategies.le_strategy(pdtype, max_value=value)) <= value + # don't overstep bounds of representation + hypothesis.assume(np.finfo(np_dtype).min < value < np.finfo(np_dtype).max) + + assert data.draw(strategies.ne_strategy(data_type, value=value)) != value + assert data.draw(strategies.eq_strategy(data_type, value=value)) == value + assert ( + data.draw(strategies.gt_strategy(data_type, min_value=value)) > value + ) + assert ( + data.draw(strategies.ge_strategy(data_type, min_value=value)) >= value + ) + assert ( + data.draw(strategies.lt_strategy(data_type, max_value=value)) < value + ) + assert ( + data.draw(strategies.le_strategy(data_type, max_value=value)) <= value + ) -def value_ranges(pdtype: pa.PandasDtype): +def value_ranges(data_type: pa.DataType): """Strategy to generate value range based on PandasDtype""" kwargs = dict( allow_nan=False, @@ -118,15 +132,19 @@ def value_ranges(pdtype: pa.PandasDtype): ) return ( st.tuples( - strategies.pandas_dtype_strategy(pdtype, strategy=None, **kwargs), - strategies.pandas_dtype_strategy(pdtype, strategy=None, **kwargs), + strategies.pandas_dtype_strategy( + data_type, strategy=None, **kwargs + ), + strategies.pandas_dtype_strategy( + data_type, strategy=None, **kwargs + ), ) .map(sorted) .filter(lambda x: x[0] < x[1]) ) -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @pytest.mark.parametrize( "strat_fn, arg_name, base_st_type, compare_op", [ @@ -143,22 +161,17 @@ def value_ranges(pdtype: pa.PandasDtype): suppress_health_check=[hypothesis.HealthCheck.too_slow], ) def test_check_strategy_chained_continuous( - pdtype: pa.PandasDtype, - strat_fn: Callable, - arg_name: str, - base_st_type: str, - compare_op, - data, -) -> None: + data_type, strat_fn, arg_name, base_st_type, compare_op, data +): """ Test built-in check strategies can generate continuous data building off of a parent strategy. 
""" - min_value, max_value = data.draw(value_ranges(pdtype)) + min_value, max_value = data.draw(value_ranges(data_type)) hypothesis.assume(min_value < max_value) value = min_value base_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, allow_nan=False, @@ -170,7 +183,7 @@ def test_check_strategy_chained_continuous( assert_base_st = st.just(value) elif base_st_type == "limit": assert_base_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, allow_nan=False, @@ -182,27 +195,25 @@ def test_check_strategy_chained_continuous( local_vars = locals() assert_value = local_vars[arg_name] example = data.draw( - strat_fn(pdtype, assert_base_st, **{arg_name: assert_value}) + strat_fn(data_type, assert_base_st, **{arg_name: assert_value}) ) assert compare_op(example, assert_value) -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @pytest.mark.parametrize("chained", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_in_range_strategy( - pdtype: pa.PandasDtype, chained: bool, data -) -> None: +def test_in_range_strategy(data_type, chained, data): """Test the built-in in-range strategy can correctly generate data.""" - min_value, max_value = data.draw(value_ranges(pdtype)) + min_value, max_value = data.draw(value_ranges(data_type)) hypothesis.assume(min_value < max_value) base_st_in_range = None if chained: - if pdtype.is_float: + if is_float(data_type): base_st_kwargs = { "exclude_min": False, "exclude_max": False, @@ -212,13 +223,13 @@ def test_in_range_strategy( # constraining the strategy this way makes testing more efficient base_st_in_range = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, **base_st_kwargs, # type: ignore[arg-type] ) strat = strategies.in_range_strategy( - pdtype, + data_type, base_st_in_range, min_value=min_value, max_value=max_value, @@ -228,20 +239,18 @@ def test_in_range_strategy( @pytest.mark.parametrize( - "pdtype", - [pdtype for pdtype in SUPPORTED_DTYPES if pdtype.is_continuous], + "data_type", + [data_type for data_type in SUPPORTED_DTYPES if data_type.continuous], ) @pytest.mark.parametrize("chained", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_isin_notin_strategies( - pdtype: pa.PandasDtype, chained: bool, data -) -> None: +def test_isin_notin_strategies(data_type, chained, data): """Test built-in check strategies that rely on discrete values.""" value_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, allow_nan=False, allow_infinity=False, exclude_min=False, @@ -254,17 +263,17 @@ def test_isin_notin_strategies( if chained: base_values = values + [data.draw(value_st) for _ in range(10)] isin_base_st = strategies.isin_strategy( - pdtype, allowed_values=base_values + data_type, allowed_values=base_values ) notin_base_st = strategies.notin_strategy( - pdtype, forbidden_values=base_values + data_type, forbidden_values=base_values ) isin_st = strategies.isin_strategy( - pdtype, isin_base_st, allowed_values=values + data_type, isin_base_st, allowed_values=values ) notin_st = strategies.notin_strategy( - pdtype, notin_base_st, forbidden_values=values + data_type, notin_base_st, forbidden_values=values ) assert data.draw(isin_st) in values assert data.draw(notin_st) not in values @@ -329,7 
+338,8 @@ def test_str_pattern_checks( .filter(lambda x: x[0] < x[1]) # type: ignore ), ) -def test_str_length_checks(chained: bool, data, value_range) -> None: +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) +def test_str_length_checks(chained, data, value_range): """Test built-in check strategies for string length.""" min_value, max_value = value_range base_st = None @@ -352,12 +362,12 @@ def test_register_check_strategy(data) -> None: # pylint: disable=unused-argument def custom_eq_strategy( - pandas_dtype: pa.PandasDtype, + pandas_dtype: pa.DataType, strategy: st.SearchStrategy = None, *, value: Any, ): - return st.just(value).map(pandas_dtype.numpy_dtype.type) + return st.just(value).map(strategies.to_numpy_dtype(pandas_dtype).type) # pylint: disable=no-member class CustomCheck(_CheckBase): @@ -381,7 +391,7 @@ def _custom_equals(series: pd.Series) -> pd.Series: ) check = CustomCheck.custom_equals(100) - result = data.draw(check.strategy(pa.Int)) + result = data.draw(check.strategy(pa.Int())) assert result == 100 @@ -423,13 +433,13 @@ def _custom_check(series: pd.Series) -> pd.Series: ) def test_series_strategy(data) -> None: """Test SeriesSchema strategy.""" - series_schema = pa.SeriesSchema(pa.Int, pa.Check.gt(0)) + series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0)) series_schema(data.draw(series_schema.strategy())) def test_series_example() -> None: """Test SeriesSchema example method generate examples that pass.""" - series_schema = pa.SeriesSchema(pa.Int, pa.Check.gt(0)) + series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0)) for _ in range(10): series_schema(series_schema.example()) @@ -440,35 +450,25 @@ def test_series_example() -> None: ) def test_column_strategy(data) -> None: """Test Column schema strategy.""" - column_schema = pa.Column(pa.Int, pa.Check.gt(0), name="column") + column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column") column_schema(data.draw(column_schema.strategy())) def test_column_example(): """Test Column schema example method generate examples that pass.""" - column_schema = pa.Column(pa.Int, pa.Check.gt(0), name="column") + column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column") for _ in range(10): column_schema(column_schema.example()) -@pytest.mark.parametrize( - "pdtype", - SUPPORTED_DTYPES, -) -@pytest.mark.parametrize( - "size", - [None, 0, 1, 3, 5], -) +@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES) +@pytest.mark.parametrize("size", [None, 0, 1, 3, 5]) @hypothesis.given(st.data()) -@hypothesis.settings( - suppress_health_check=[hypothesis.HealthCheck.too_slow], -) -def test_dataframe_strategy( - pdtype: pa.PandasDtype, size: Optional[int], data -) -> None: +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) +def test_dataframe_strategy(data_type, size, data): """Test DataFrameSchema strategy.""" dataframe_schema = pa.DataFrameSchema( - {f"{pdtype.value}_col": pa.Column(pdtype)} + {f"{data_type}_col": pa.Column(data_type)} ) df_sample = data.draw(dataframe_schema.strategy(size=size)) if size == 0: @@ -479,15 +479,17 @@ def test_dataframe_strategy( ) else: assert isinstance(dataframe_schema(df_sample), pd.DataFrame) - with pytest.raises(pa.errors.BaseStrategyOnlyError): - strategies.dataframe_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) - ) + # with pytest.raises(pa.errors.BaseStrategyOnlyError): + # strategies.dataframe_strategy( + # data_type, strategies.pandas_dtype_strategy(data_type) + # ) def test_dataframe_example() -> None: 
"""Test DataFrameSchema example method generate examples that pass.""" - schema = pa.DataFrameSchema({"column": pa.Column(pa.Int, pa.Check.gt(0))}) + schema = pa.DataFrameSchema( + {"column": pa.Column(pa.Int(), pa.Check.gt(0))} + ) for _ in range(10): schema(schema.example()) @@ -521,21 +523,16 @@ def test_dataframe_with_regex(regex: str, data, n_regex_columns: int) -> None: assert df.shape[1] == n_regex_columns -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) @hypothesis.given(st.data()) -def test_dataframe_checks(pdtype: pa.PandasDtype, data) -> None: +def test_dataframe_checks(data_type, data): """Test dataframe strategy with checks defined at the dataframe level.""" - if pa.LEGACY_PANDAS and pdtype in { - pa.PandasDtype.UInt64, - pa.PandasDtype.UINT64, - }: - pytest.xfail("pandas<1.0.0 leads to OverflowError for these dtypes.") - min_value, max_value = data.draw(value_ranges(pdtype)) + min_value, max_value = data.draw(value_ranges(data_type)) dataframe_schema = pa.DataFrameSchema( - {f"{pdtype.value}_col": pa.Column(pdtype) for _ in range(5)}, + {f"{data_type}_col": pa.Column(data_type) for _ in range(5)}, checks=pa.Check.in_range(min_value, max_value), ) strat = dataframe_schema.strategy(size=5) @@ -543,17 +540,19 @@ def test_dataframe_checks(pdtype: pa.PandasDtype, data) -> None: dataframe_schema(example) -@pytest.mark.parametrize("pdtype", [pa.Int, pa.Float, pa.String, pa.DateTime]) +@pytest.mark.parametrize( + "data_type", [pa.Int(), pa.Float, pa.String, pa.DateTime] +) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_dataframe_strategy_with_indexes(pdtype: pa.PandasDtype, data) -> None: +def test_dataframe_strategy_with_indexes(data_type, data): """Test dataframe strategy with index and multiindex components.""" - dataframe_schema_index = pa.DataFrameSchema(index=pa.Index(pdtype)) + dataframe_schema_index = pa.DataFrameSchema(index=pa.Index(data_type)) dataframe_schema_multiindex = pa.DataFrameSchema( index=pa.MultiIndex( - [pa.Index(pdtype, name=f"index{i}") for i in range(3)] + [pa.Index(data_type, name=f"index{i}") for i in range(3)] ) ) @@ -569,12 +568,14 @@ def test_dataframe_strategy_with_indexes(pdtype: pa.PandasDtype, data) -> None: ) def test_index_strategy(data) -> None: """Test Index schema component strategy.""" - pdtype = pa.PandasDtype.Int - index_schema = pa.Index(pdtype, allow_duplicates=False, name="index") + data_type = pa.Int() + index_schema = pa.Index(data_type, allow_duplicates=False, name="index") strat = index_schema.strategy(size=10) example = data.draw(strat) + assert (~example.duplicated()).all() - assert example.dtype == pdtype.str_alias + actual_data_type = pandas_engine.Engine.dtype(example.dtype) + assert data_type.check(actual_data_type) index_schema(pd.DataFrame(index=example)) @@ -582,8 +583,8 @@ def test_index_example() -> None: """ Test Index schema component example method generates examples that pass. 
""" - pdtype = pa.PandasDtype.Int - index_schema = pa.Index(pdtype, allow_duplicates=False) + data_type = pa.Int() + index_schema = pa.Index(data_type, allow_duplicates=False) for _ in range(10): index_schema(pd.DataFrame(index=index_schema.example())) @@ -594,22 +595,25 @@ def test_index_example() -> None: ) def test_multiindex_strategy(data) -> None: """Test MultiIndex schema component strategy.""" - pdtype = pa.PandasDtype.Float + data_type = pa.Float() multiindex = pa.MultiIndex( indexes=[ - pa.Index(pdtype, allow_duplicates=False, name="level_0"), - pa.Index(pdtype, nullable=True), - pa.Index(pdtype), + pa.Index(data_type, allow_duplicates=False, name="level_0"), + pa.Index(data_type, nullable=True), + pa.Index(data_type), ] ) strat = multiindex.strategy(size=10) example = data.draw(strat) for i in range(example.nlevels): - assert example.get_level_values(i).dtype == pdtype.str_alias + actual_data_type = pandas_engine.Engine.dtype( + example.get_level_values(i).dtype + ) + assert data_type.check(actual_data_type) with pytest.raises(pa.errors.BaseStrategyOnlyError): strategies.multiindex_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) + data_type, strategies.pandas_dtype_strategy(data_type) ) @@ -618,12 +622,12 @@ def test_multiindex_example() -> None: Test MultiIndex schema component example method generates examples that pass. """ - pdtype = pa.PandasDtype.Float + data_type = pa.Float() multiindex = pa.MultiIndex( indexes=[ - pa.Index(pdtype, allow_duplicates=False, name="level_0"), - pa.Index(pdtype, nullable=True), - pa.Index(pdtype), + pa.Index(data_type, allow_duplicates=False, name="level_0"), + pa.Index(data_type, nullable=True), + pa.Index(data_type), ] ) for _ in range(10): @@ -631,21 +635,23 @@ def test_multiindex_example() -> None: multiindex(pd.DataFrame(index=example)) -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @hypothesis.given(st.data()) -def test_field_element_strategy(pdtype: pa.PandasDtype, data) -> None: +def test_field_element_strategy(data_type, data): """Test strategy for generating elements in columns/indexes.""" - strategy = strategies.field_element_strategy(pdtype) + strategy = strategies.field_element_strategy(data_type) element = data.draw(strategy) - assert element.dtype.type == pdtype.numpy_dtype.type + + expected_type = strategies.to_numpy_dtype(data_type).type + assert element.dtype.type == expected_type with pytest.raises(pa.errors.BaseStrategyOnlyError): strategies.field_element_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) + data_type, strategies.pandas_dtype_strategy(data_type) ) -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @pytest.mark.parametrize( "field_strategy", [strategies.index_strategy, strategies.series_strategy], @@ -656,21 +662,11 @@ def test_field_element_strategy(pdtype: pa.PandasDtype, data) -> None: suppress_health_check=[hypothesis.HealthCheck.too_slow], ) def test_check_nullable_field_strategy( - pdtype: pa.PandasDtype, field_strategy, nullable: bool, data -) -> None: + data_type, field_strategy, nullable, data +): """Test strategies for generating nullable column/index data.""" - - if ( - pa.LEGACY_PANDAS - and field_strategy is strategies.index_strategy - and (pdtype.is_nullable_int or pdtype.is_nullable_uint) - ): - pytest.skip( - "pandas version<1 does not handle nullable integer indexes" - ) - size = 5 - strat = field_strategy(pdtype, nullable=nullable, size=size) + strat = 
field_strategy(data_type, nullable=nullable, size=size) example = data.draw(strat) if nullable: @@ -679,24 +675,18 @@ def test_check_nullable_field_strategy( assert example.notna().all() -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @pytest.mark.parametrize("nullable", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_check_nullable_dataframe_strategy( - pdtype: pa.PandasDtype, nullable: bool, data -) -> None: +def test_check_nullable_dataframe_strategy(data_type, nullable, data): """Test strategies for generating nullable DataFrame data.""" size = 5 # pylint: disable=no-value-for-parameter strat = strategies.dataframe_strategy( - columns={ - "col": pa.Column( - pandas_dtype=pdtype, nullable=nullable, name="col" - ) - }, + columns={"col": pa.Column(data_type, nullable=nullable, name="col")}, size=size, ) example = data.draw(strat) @@ -711,7 +701,7 @@ def test_check_nullable_dataframe_strategy( [ [ pa.SeriesSchema( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda x: x > 0, element_wise=True), pa.Check(lambda x: x > -10, element_wise=True), @@ -721,7 +711,7 @@ def test_check_nullable_dataframe_strategy( ], [ pa.SeriesSchema( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -755,7 +745,7 @@ def test_series_strategy_undefined_check_strategy( [ [ pa.DataFrameSchema( - columns={"column": pa.Column(pa.Int)}, + columns={"column": pa.Column(pa.Int())}, checks=[ pa.Check(lambda x: x > 0, element_wise=True), pa.Check(lambda x: x > -10, element_wise=True), @@ -767,7 +757,7 @@ def test_series_strategy_undefined_check_strategy( pa.DataFrameSchema( columns={ "column": pa.Column( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -779,7 +769,7 @@ def test_series_strategy_undefined_check_strategy( ], [ pa.DataFrameSchema( - columns={"column": pa.Column(pa.Int)}, + columns={"column": pa.Column(pa.Int())}, checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -877,4 +867,4 @@ def test_schema_component_with_no_pdtype() -> None: strategies.index_strategy, ]: with pytest.raises(pa.errors.SchemaDefinitionError): - schema_component_strategy(pandas_dtype=None) # type: ignore[operator] + schema_component_strategy(pandera_dtype=None)
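
Note on the serialization changes pinned down by the tests/io/test_io.py hunks above: the YAML fixtures replace the "pandas_dtype" key with "dtype", and loose aliases are normalized to canonical names ("float" becomes "float64"). A minimal round-trip sketch of that behavior, using only the public to_yaml/from_yaml API these tests exercise (the column name is illustrative):

    import pandera

    schema = pandera.DataFrameSchema(
        {"float_column": pandera.Column(float, nullable=False)}
    )
    yaml_str = schema.to_yaml()

    # columns now serialize under the new key, with a normalized alias:
    assert "dtype: float64" in yaml_str
    assert "pandas_dtype" not in yaml_str

    # and the round trip reproduces an equal schema, as asserted in
    # test_inferred_schema_io above:
    assert pandera.io.from_yaml(yaml_str) == schema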
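
The tests/strategies/test_strategies.py hunks likewise document the replacement for the PandasDtype enum: dtypes are resolved through pandas_engine.Engine and inspected with the pandera.dtypes predicates. A sketch of the equivalences these tests rely on, assuming a 64-bit platform (the "int" alias resolves to a platform-dependent width):

    import pandas as pd
    import pandera as pa
    from pandera.dtypes import is_float
    from pandera.engines import pandas_engine

    # Engine.dtype normalizes string aliases, python builtins, numpy/pandas
    # dtypes, and pandera DataType classes into one DataType hierarchy:
    int_dtype = pandas_engine.Engine.dtype("int64")
    assert int_dtype == pandas_engine.Engine.dtype(pa.Int)  # 64-bit platforms

    # enum properties such as pdtype.is_float become module-level predicates:
    assert is_float(pandas_engine.Engine.dtype(float))

    # DataType.check replaces comparisons against the old str_alias, as in
    # test_index_strategy and test_multiindex_strategy above:
    series = pd.Series([1, 2, 3], dtype="int64")
    assert int_dtype.check(pandas_engine.Engine.dtype(series.dtype))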
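
For data synthesis, the same hunks swap the enum's numpy_dtype property for the strategies.to_numpy_dtype helper and pass DataType instances into every strategy function; note also the keyword rename from pandas_dtype= to pandera_dtype= on the final line of the diff. A condensed sketch of the new calling convention, using only names that appear in the tests above:

    import numpy as np
    import pandera as pa
    from pandera import strategies
    from pandera.engines import pandas_engine

    data_type = pandas_engine.Engine.dtype("float64")

    # strategy constructors take a DataType instance, not an enum member;
    # the result is a hypothesis strategy suitable for @hypothesis.given:
    strat = strategies.pandas_dtype_strategy(data_type)

    # the numpy representation comes from a helper rather than the old
    # pdtype.numpy_dtype property:
    np_dtype = strategies.to_numpy_dtype(data_type)
    assert np_dtype == np.dtype("float64")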