diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index d95f220c6..6d940cb0a 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -2,24 +2,24 @@ name: CI Tests on: push: branches: - - master - - dev - - bugfix - - 'release/*' - - dtypes + - master + - dev + - bugfix + - "release/*" + - dtypes pull_request: branches: - - master - - dev - - bugfix - - 'release/*' - - dtypes + - master + - dev + - bugfix + - "release/*" + - dtypes env: DEFAULT_PYTHON: 3.8 CI: "true" # Increase this value to reset cache if environment.yml has not changed - CACHE_VERSION: 2 + CACHE_VERSION: 3 jobs: codestyle: @@ -73,7 +73,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9"] defaults: run: shell: bash -l {0} @@ -135,16 +135,13 @@ jobs: tests: name: > - CI Tests (${{ matrix.python-version }}, - ${{ matrix.os }}, - pandas-${{ matrix.pandas-version }}) + CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.6", "3.7", "3.8", "3.9"] - pandas-version: ["latest", "0.25.3"] + python-version: ["3.7", "3.8", "3.9"] defaults: run: @@ -186,28 +183,28 @@ jobs: nox -db conda -r -v --non-interactive - --session "tests-${{ matrix.python-version }}(extra='core', pandas='${{ matrix.pandas-version }}')" + --session "tests-${{ matrix.python-version }}(extra='core')" - name: Unit Tests - Hypotheses run: > nox -db conda -r -v --non-interactive - --session "tests-${{ matrix.python-version }}(extra='hypotheses', pandas='${{ matrix.pandas-version }}')" + --session "tests-${{ matrix.python-version }}(extra='hypotheses')" - name: Unit Tests - IO run: > nox -db conda -r -v --non-interactive - --session "tests-${{ matrix.python-version }}(extra='io', pandas='${{ matrix.pandas-version }}')" + --session "tests-${{ matrix.python-version }}(extra='io')" - name: Unit Tests - Strategies run: > nox -db conda -r -v --non-interactive - --session "tests-${{ matrix.python-version }}(extra='strategies', pandas='${{ matrix.pandas-version }}')" + --session "tests-${{ matrix.python-version }}(extra='strategies')" - name: Upload coverage to Codecov uses: "codecov/codecov-action@v1" @@ -217,4 +214,4 @@ jobs: nox -db conda -r -v --non-interactive - --session "docs-${{ matrix.python-version }}(pandas='${{ matrix.pandas-version }}')" + --session "docs-${{ matrix.python-version }}" diff --git a/docs/source/API_reference.rst b/docs/source/API_reference.rst index 1234b3e11..cca8a43d5 100644 --- a/docs/source/API_reference.rst +++ b/docs/source/API_reference.rst @@ -93,7 +93,7 @@ Pandas Data Types :template: pandas_dtype_class.rst :nosignatures: - pandera.dtypes.PandasDtype + pandera.dtypes.DataType Decorators diff --git a/docs/source/_templates/enum_class.rst b/docs/source/_templates/enum_class.rst deleted file mode 100644 index c10df62d9..000000000 --- a/docs/source/_templates/enum_class.rst +++ /dev/null @@ -1,51 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: PandasDtype - :show-inheritance: - :exclude-members: - - .. autoattribute:: str_alias - .. automethod:: from_str_alias - .. automethod:: from_pandas_api_type - - - - -.. autoclass:: {{ objname }} - - {% block attributes %} - {% if attributes %} - .. rubric:: Attributes - - .. 
autosummary:: - :nosignatures: - - {% for item in attributes %} - ~{{ name }}.{{ item }} - {%- endfor %} - - {% endif %} - {% endblock %} - - {% block methods %} - {% if methods %} - .. rubric:: Methods - - .. autosummary:: - :nosignatures: - :toctree: methods - - {% for item in methods %} - {%- if item not in inherited_members %} - ~{{ name }}.{{ item }} - {%- endif %} - {%- endfor %} - {% endif %} - - {%- if '__call__' in members %} - ~{{ name }}.__call__ - {%- endif %} - - {% endblock %} diff --git a/docs/source/_templates/pandas_dtype_class.rst b/docs/source/_templates/pandas_dtype_class.rst deleted file mode 100644 index 3bc84a901..000000000 --- a/docs/source/_templates/pandas_dtype_class.rst +++ /dev/null @@ -1,25 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - :show-inheritance: - :exclude-members: - - {% block attributes %} - {% if attributes %} - .. rubric:: Attributes - - .. autosummary:: - :nosignatures: - - {% for item in attributes %} - ~{{ name }}.{{ item }} - {%- endfor %} - - {% endif %} - {% endblock %} - - .. autoattribute:: str_alias - .. automethod:: from_str_alias - .. automethod:: from_pandas_api_type diff --git a/docs/source/conf.py b/docs/source/conf.py index 273c90172..f82b1237e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,6 @@ import doctest import inspect import logging as pylogging -import subprocess # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -17,6 +16,7 @@ # import os import shutil +import subprocess import sys from sphinx.util import logging diff --git a/docs/source/dataframe_schemas.rst b/docs/source/dataframe_schemas.rst index c6a9f8f4d..ec8b0fd7f 100644 --- a/docs/source/dataframe_schemas.rst +++ b/docs/source/dataframe_schemas.rst @@ -80,7 +80,7 @@ nullable. In order to accept null values, you need to explicitly specify df = pd.DataFrame({"column1": [5, 1, np.nan]}) non_null_schema = DataFrameSchema({ - "column1": Column(pa.Int, Check(lambda x: x > 0)) + "column1": Column(pa.Float, Check(lambda x: x > 0)) }) non_null_schema.validate(df) @@ -91,18 +91,11 @@ nullable. In order to accept null values, you need to explicitly specify ... SchemaError: non-nullable series contains null values: {2: nan} -.. note:: Due to a known limitation in - `pandas prior to version 0.24.0 `_, - integer arrays cannot contain ``NaN`` values, so this schema will return - a DataFrame where ``column1`` is of type ``float``. - :class:`~pandera.dtypes.PandasDtype` does not currently support the nullable integer - array type, but you can still use the "Int64" string alias for nullable - integer arrays .. testcode:: null_values_in_columns null_schema = DataFrameSchema({ - "column1": Column(pa.Int, Check(lambda x: x > 0), nullable=True) + "column1": Column(pa.Float, Check(lambda x: x > 0), nullable=True) }) print(null_schema.validate(df)) @@ -401,7 +394,7 @@ schema, specify ``strict=True``: Traceback (most recent call last): ... 
- SchemaError: column 'column2' not in DataFrameSchema {'column1': } + SchemaError: column 'column2' not in DataFrameSchema {'column1': } Alternatively, if your DataFrame contains columns that are not in the schema, and you would like these to be dropped on validation, @@ -626,13 +619,17 @@ Some examples of where this can be provided to pandas are: }, ) - df = pd.DataFrame.from_dict( - { - "a": {"column1": 1, "column2": "valueA", "column3": True}, - "b": {"column1": 1, "column2": "valueB", "column3": True}, - }, - orient="index" - ).astype(schema.dtype).sort_index(axis=1) + df = ( + pd.DataFrame.from_dict( + { + "a": {"column1": 1, "column2": "valueA", "column3": True}, + "b": {"column1": 1, "column2": "valueB", "column3": True}, + }, + orient="index", + ) + .astype({col: str(dtype) for col, dtype in schema.dtypes.items()}) + .sort_index(axis=1) + ) print(schema.validate(df)) @@ -718,11 +715,11 @@ data pipeline: + 'col1': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=True name=None, @@ -756,15 +753,15 @@ the pipeline output. + 'column2': }, checks=[], coerce=True, - pandas_dtype=None, + dtype=None, index= - + + ] coerce=False, strict=False, diff --git a/docs/source/extensions.rst b/docs/source/extensions.rst index c0c3e4173..de4928bd8 100644 --- a/docs/source/extensions.rst +++ b/docs/source/extensions.rst @@ -94,20 +94,20 @@ The corresponding strategy for this check would be: import pandera.strategies as st def equals_strategy( - pandas_dtype: pa.PandasDtype, + pandera_dtype: pa.DataType, strategy: Optional[st.SearchStrategy] = None, *, value, ): if strategy is None: return st.pandas_dtype_strategy( - pandas_dtype, strategy=hypothesis.strategies.just(value), + pandera_dtype, strategy=hypothesis.strategies.just(value), ) return strategy.filter(lambda x: x == value) As you may notice, the ``pandera`` strategy interface is has two arguments followed by keyword-only arguments that match the check function keyword-only -check statistics. The ``pandas_dtype`` positional argument is useful for +check statistics. The ``pandera_dtype`` positional argument is useful for ensuring the correct data type. In the above example, we're using the :func:`~pandera.strategies.pandas_dtype_strategy` strategy to make sure the generated ``value`` is of the correct data type. @@ -147,7 +147,7 @@ would look like: :skipif: SKIP_STRATEGY def in_between_strategy( - pandas_dtype: pa.PandasDtype, + pandera_dtype: pa.DataType, strategy: Optional[st.SearchStrategy] = None, *, min_value, @@ -155,7 +155,7 @@ would look like: ): if strategy is None: return st.pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, max_value=max_value, exclude_min=False, diff --git a/docs/source/index.rst b/docs/source/index.rst index 9ea8ae79b..3396d18e6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -155,7 +155,7 @@ Quick Start You can pass the built-in python types that are supported by pandas, or strings representing the `legal pandas datatypes `_, -or pandera's ``PandasDtype`` enum: +or pandera's ``DataType``: .. 
testcode:: quick_start @@ -171,13 +171,13 @@ or pandera's ``PandasDtype`` enum: # pandas > 1.0.0 support native "string" type "str_column2": pa.Column("str"), - # pandera PandasDtype enum + # pandera DataType "int_column3": pa.Column(pa.Int), "float_column3": pa.Column(pa.Float), "str_column3": pa.Column(pa.String), }) -For more details on data types, see :class:`~pandera.dtypes.PandasDtype` +For more details on data types, see :class:`~pandera.dtypes.DataType` Schema Model diff --git a/docs/source/lazy_validation.rst b/docs/source/lazy_validation.rst index 6654aeb7b..55f1423a6 100644 --- a/docs/source/lazy_validation.rst +++ b/docs/source/lazy_validation.rst @@ -15,9 +15,9 @@ exception: * a column specified in the schema is not present in the dataframe. * if ``strict=True``, a column in the dataframe is not specified in the schema. -* the ``pandas_dtype`` does not match. +* the ``data type`` does not match. * if ``coerce=True``, the dataframe column cannot be coerced into the specified - ``pandas_dtype``. + ``data type``. * the :class:`~pandera.checks.Check` specified in one of the columns returns ``False`` or a boolean series containing at least one ``False`` value. @@ -94,8 +94,8 @@ of all schemas and schema components gives you the option of doing just this: schema_context column check DataFrameSchema column_in_dataframe [date_column] 1 column_in_schema [unknown_column] 1 - Column float_column pandas_dtype('float64') [int64] 1 - int_column pandas_dtype('int64') [object] 1 + Column float_column dtype('float64') [int64] 1 + int_column dtype('int64') [object] 1 str_column equal_to(a) [b, d] 2 Usage Tip @@ -135,8 +135,8 @@ catch these errors and inspect the failure cases in a more granular form: schema_context column check check_number \ 0 DataFrameSchema None column_in_schema None 1 DataFrameSchema None column_in_dataframe None - 2 Column int_column pandas_dtype('int64') None - 3 Column float_column pandas_dtype('float64') None + 2 Column int_column dtype('int64') None + 3 Column float_column dtype('float64') None 4 Column str_column equal_to(a) 0 failure_case index diff --git a/docs/source/schema_inference.rst b/docs/source/schema_inference.rst index a60125470..218ef5adc 100644 --- a/docs/source/schema_inference.rst +++ b/docs/source/schema_inference.rst @@ -36,14 +36,14 @@ is a simple example: - 'column2': - 'column3': + 'column1': + 'column2': + 'column3': }, checks=[], coerce=True, - pandas_dtype=None, - index=, + dtype=None, + index=, strict=False name=None, ordered=False @@ -96,19 +96,12 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr :skipif: SKIP from pandas import Timestamp - from pandera import ( - DataFrameSchema, - Column, - Check, - Index, - MultiIndex, - PandasDtype, - ) + from pandera import DataFrameSchema, Column, Check, Index, MultiIndex schema = DataFrameSchema( columns={ "column1": Column( - pandas_dtype=PandasDtype.Int64, + dtype=pandera.engines.numpy_engine.Int64, checks=[ Check.greater_than_or_equal_to(min_value=5.0), Check.less_than_or_equal_to(max_value=20.0), @@ -120,7 +113,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr regex=False, ), "column2": Column( - pandas_dtype=PandasDtype.String, + dtype=pandera.engines.numpy_engine.Object, checks=None, nullable=False, allow_duplicates=True, @@ -129,7 +122,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr regex=False, ), "column3": Column( - pandas_dtype=PandasDtype.DateTime, + 
dtype=pandera.engines.pandas_engine.DateTime, checks=[ Check.greater_than_or_equal_to( min_value=Timestamp("2010-01-01 00:00:00") @@ -146,7 +139,7 @@ You can also write your schema to a python script with :func:`~pandera.io.to_scr ), }, index=Index( - pandas_dtype=PandasDtype.Int64, + dtype=pandera.engines.numpy_engine.Int64, checks=[ Check.greater_than_or_equal_to(min_value=0.0), Check.less_than_or_equal_to(max_value=2.0), @@ -187,7 +180,7 @@ is a convenience method for this functionality. version: 0.6.4 columns: column1: - pandas_dtype: int64 + dtype: int64 nullable: false checks: greater_than_or_equal_to: 5.0 @@ -197,7 +190,7 @@ is a convenience method for this functionality. required: true regex: false column2: - pandas_dtype: str + dtype: object nullable: false checks: null allow_duplicates: true @@ -205,7 +198,7 @@ is a convenience method for this functionality. required: true regex: false column3: - pandas_dtype: datetime64[ns] + dtype: datetime64[ns] nullable: false checks: greater_than_or_equal_to: '2010-01-01 00:00:00' @@ -216,7 +209,7 @@ is a convenience method for this functionality. regex: false checks: null index: - - pandas_dtype: int64 + - dtype: int64 nullable: false checks: greater_than_or_equal_to: 0.0 diff --git a/docs/source/schema_models.rst b/docs/source/schema_models.rst index a8f6fd0d0..65a4970cc 100644 --- a/docs/source/schema_models.rst +++ b/docs/source/schema_models.rst @@ -72,7 +72,7 @@ Basic Usage Traceback (most recent call last): ... - pandera.errors.SchemaError: > failed element-wise validator 0: + pandera.errors.SchemaError: failed element-wise validator 0: failure cases: index failure_case @@ -121,13 +121,13 @@ You can easily convert a :class:`~pandera.model.SchemaModel` class into a )> - 'month': )> - 'day': )> + 'year': + 'month': + 'day': }, checks=[], coerce=False, - pandas_dtype=None, + dtype=None, index=None, strict=False name=None, @@ -165,12 +165,8 @@ however, a couple of gotchas. Dtype aliases ^^^^^^^^^^^^^ -The enumeration :class:`~pandera.dtypes.PandasDtype` is not directly supported because -the type parameter of a :class:`typing.Generic` cannot be an enumeration [#dtypes]_. -Instead, you can use the :mod:`pandera.typing` counterparts: -:data:`pandera.typing.Category`, :data:`pandera.typing.Float32`, ... - -:green:`✔` Good: +:mod:`pandera.typing` aliases will be deprecated in a future version, +please use :class:`~pandera.dtypes.DataType` subclasses instead. .. code-block:: @@ -180,21 +176,6 @@ Instead, you can use the :mod:`pandera.typing` counterparts: class Schema(pa.SchemaModel): a: Series[String] -:red:`✘` Bad: - -.. testcode:: dataframe_schema_model - :skipif: SKIP_PANDAS_LT_V1 - - class Schema(pa.SchemaModel): - a: Series[pa.PandasDtype.String] - -.. testoutput:: dataframe_schema_model - :skipif: SKIP_PANDAS_LT_V1 - - Traceback (most recent call last): - ... - TypeError: python type '' not recognized as pandas data type - Type Vs instance ^^^^^^^^^^^^^^^^ @@ -437,8 +418,8 @@ the class-based API: )> - )> + + ] coerce=True, strict=True, @@ -531,7 +512,7 @@ Column/Index checks Traceback (most recent call last): ... - pandera.errors.SchemaError: > failed series validator 1: + pandera.errors.SchemaError: failed series validator 1: .. _schema_model_dataframe_check: @@ -611,7 +592,7 @@ The custom checks are inherited and therefore can be overwritten by the subclass Traceback (most recent call last): ... 
- pandera.errors.SchemaError: > failed element-wise validator 0: + pandera.errors.SchemaError: failed element-wise validator 0: failure cases: index failure_case @@ -700,11 +681,3 @@ the class scope, and it will respect the alias. 2020 a 0 99 101 - - -Footnotes ---------- - -.. [#dtypes] It is actually possible to use a PandasDtype by encasing it in a - :class:`typing.Literal` like ``Series[Literal[PandasDtype.Category]]``. - :mod:`pandera.typing` defines aliases to reduce boilerplate. diff --git a/environment.yml b/environment.yml index 5b0107bd3..d15718413 100644 --- a/environment.yml +++ b/environment.yml @@ -24,13 +24,13 @@ dependencies: # testing - isort >= 5.7.0 - codecov - - mypy + - mypy >= 0.902 # mypy no longer bundle stubs for third-party libraries - pylint >= 2.7.2 - pytest - pytest-cov - pytest-xdist - setuptools >= 52.0.0 - - nox + - nox = 2020.12.31 # pinning due to UnicodeDecodeError, see https://github.com/pandera-dev/pandera/pull/504/checks?check_run_id=2841360122 - importlib_metadata # required if python < 3.8 # documentation @@ -51,3 +51,6 @@ dependencies: - pip: - furo + - types-click + - types-pyyaml + - types-pkg_resources diff --git a/noxfile.py b/noxfile.py index 9c95dd188..bcd132984 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,16 +3,15 @@ import os import shutil import sys -from typing import Dict, List, cast +from typing import Dict, List # setuptools must be imported before distutils ! -import setuptools # pylint:disable=unused-import +import setuptools # pylint:disable=unused-import # noqa: F401 from distutils.core import run_setup # pylint:disable=wrong-import-order import nox from nox import Session from pkg_resources import Requirement, parse_requirements -from packaging import version nox.options.sessions = ( @@ -26,13 +25,13 @@ ) DEFAULT_PYTHON = "3.8" -PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] +PYTHON_VERSIONS = ["3.7", "3.8", "3.9"] PACKAGE = "pandera" SOURCE_PATHS = PACKAGE, "tests", "noxfile.py" REQUIREMENT_PATH = "requirements-dev.txt" -ALWAYS_USE_PIP = ["furo"] +ALWAYS_USE_PIP = ["furo", "types-click", "types-pyyaml", "types-pkg_resources"] CI_RUN = os.environ.get("CI") == "true" if CI_RUN: @@ -169,21 +168,18 @@ def install_from_requirements(session: Session, *packages: str) -> None: def install_extras( session: Session, - pandas: str = "latest", extra: str = "core", - force_pip=False, + force_pip: bool = False, ) -> None: """Install dependencies.""" - pandas_version = "" if pandas == "latest" else f"=={pandas}" specs = [ - spec if spec != "pandas" else f"pandas{pandas_version}" + spec if spec != "pandas" else "pandas" for spec in REQUIRES[extra].values() if spec not in ALWAYS_USE_PIP ] if extra == "core": specs.append(REQUIRES["all"]["hypothesis"]) - session.install(*ALWAYS_USE_PIP) if ( isinstance(session.virtualenv, nox.virtualenv.CondaEnv) and not force_pip @@ -193,6 +189,8 @@ def install_extras( else: print("using pip installer") session.install(*specs) + + session.install(*ALWAYS_USE_PIP) # always use pip for these packages session.install("-e", ".", "--no-deps") # install pandera @@ -280,31 +278,11 @@ def lint(session: Session) -> None: @nox.session(python=PYTHON_VERSIONS) def mypy(session: Session) -> None: """Type-check using mypy.""" - python_version = version.parse(cast(str, session.python)) - install_extras( - session, - extra="all", - # this is a hack until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == 
version.parse("3.7"), - ) + install_extras(session, extra="all") args = session.posargs or SOURCE_PATHS session.run("mypy", "--follow-imports=silent", *args, silent=True) -def _invalid_python_pandas_versions(session: Session, pandas: str) -> bool: - python_version = version.parse(cast(str, session.python)) - if pandas == "0.25.3" and ( - python_version >= version.parse("3.9") - # this is just a bandaid until support for 0.25.3 is dropped - or python_version == version.parse("3.7") - ): - print("Python 3.9 does not support pandas 0.25.3") - return True - return False - - EXTRA_NAMES = [ extra for extra in REQUIRES @@ -313,21 +291,12 @@ def _invalid_python_pandas_versions(session: Session, pandas: str) -> bool: @nox.session(python=PYTHON_VERSIONS) -@nox.parametrize("pandas", ["0.25.3", "latest"]) @nox.parametrize("extra", EXTRA_NAMES) -def tests(session: Session, pandas: str, extra: str) -> None: +def tests(session: Session, extra: str) -> None: """Run the test suite.""" - if _invalid_python_pandas_versions(session, pandas): - return - python_version = version.parse(cast(str, session.python)) install_extras( session, - pandas, extra, - # this is a hack until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == version.parse("3.7"), ) if session.posargs: @@ -356,38 +325,15 @@ def tests(session: Session, pandas: str, extra: str) -> None: @nox.session(python=PYTHON_VERSIONS) -@nox.parametrize("pandas", ["0.25.3", "latest"]) -def docs(session: Session, pandas: str) -> None: +def docs(session: Session) -> None: """Build the documentation.""" - if _invalid_python_pandas_versions(session, pandas): - return - python_version = version.parse(cast(str, session.python)) - install_extras( - session, - pandas, - extra="all", - # this is a hack until typed-ast conda package starts working again, - # basically this issue comes up: - # https://github.com/python/mypy/pull/2906 - force_pip=python_version == version.parse("3.7"), - ) + install_extras(session, extra="all", force_pip=True) session.chdir("docs") - shutil.rmtree(os.path.join("_build"), ignore_errors=True) - args = session.posargs or [ - "-v", - "-v", - "-W", - "-E", - "-b=doctest", - "source", - "_build", - ] - session.run("sphinx-build", *args) - # build html docs if not CI_RUN and not session.posargs: - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + shutil.rmtree(os.path.join("_build"), ignore_errors=True) + shutil.rmtree(os.path.join("generated"), ignore_errors=True) session.run( "sphinx-build", "-W", @@ -398,3 +344,15 @@ def docs(session: Session, pandas: str) -> None: "source", os.path.join("_build", "html", ""), ) + else: + shutil.rmtree(os.path.join("_build"), ignore_errors=True) + args = session.posargs or [ + "-v", + "-v", + "-W", + "-E", + "-b=doctest", + "source", + "_build", + ] + session.run("sphinx-build", *args) diff --git a/pandera/__init__.py b/pandera/__init__.py index 0c804fb14..c0968ff43 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -1,18 +1,18 @@ """A flexible and expressive pandas validation library.""" -from pandera.dtypes_ import ( +import platform + +from pandera.dtypes import ( Bool, Category, Complex, Complex64, Complex128, - Complex256, DataType, DateTime, Float, Float16, Float32, Float64, - Float128, Int, Int8, Int16, @@ -44,7 +44,6 @@ from . 
import constants, errors, pandas_accessor from .checks import Check from .decorators import check_input, check_io, check_output, check_types -from .dtypes import LEGACY_PANDAS, PandasDtype from .hypotheses import Hypothesis from .model import SchemaModel from .model_components import Field, check, dataframe_check @@ -52,3 +51,6 @@ from .schema_inference import infer_schema from .schemas import DataFrameSchema, SeriesSchema from .version import __version__ + +if platform.system() != "Windows": + from pandera.dtypes import Complex256, Float128 diff --git a/pandera/dtypes.py b/pandera/dtypes.py index bf68edd54..f0ccde032 100644 --- a/pandera/dtypes.py +++ b/pandera/dtypes.py @@ -1,458 +1,484 @@ -# pylint: disable=no-member,too-many-public-methods -"""Schema datatypes.""" - -from enum import Enum -from typing import Optional, Union - -import numpy as np -import pandas as pd -from packaging import version - -PandasExtensionType = pd.core.dtypes.base.ExtensionDtype - -LEGACY_PANDAS = version.parse(pd.__version__).major < 1 # type: ignore -NUMPY_NONNULLABLE_INT_DTYPES = [ - "int", - "int_", - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", -] - -NUMPY_TYPES = frozenset( - [item for sublist in np.sctypes.values() for item in sublist] # type: ignore -).union( - frozenset([np.complex_, np.int_, np.uint, np.float_, np.str_, np.bool_]) +"""Pandera data types.""" +# pylint:disable=too-many-ancestors +import dataclasses +import inspect +from abc import ABC +from typing import ( + Any, + Callable, + Iterable, + Optional, + Tuple, + Type, + TypeVar, + Union, ) -# for int and float dtype, delegate string representation to the -# default based on OS. In Windows, pandas defaults to int64 while numpy -# defaults to int32. -_DEFAULT_PANDAS_INT_TYPE = str(pd.Series([1]).dtype) -_DEFAULT_PANDAS_FLOAT_TYPE = str(pd.Series([1.0]).dtype) -_DEFAULT_PANDAS_COMPLEX_TYPE = str(pd.Series([complex(1)]).dtype) -_DEFAULT_NUMPY_INT_TYPE = str(np.dtype(int)) -_DEFAULT_NUMPY_FLOAT_TYPE = str(np.dtype(float)) +class DataType(ABC): + """Base class of all Pandera data types.""" + + continuous: bool = False + + def __init__(self): + if self.__class__ is DataType: + raise TypeError( + f"{self.__class__.__name__} may not be instantiated." + ) + + def coerce(self, data_container: Any): + """Coerce data container to the dtype.""" + raise NotImplementedError() + + def __call__(self, data_container: Any): + """Coerce data container to the dtype.""" + return self.coerce(data_container) + + def check(self, pandera_dtype: "DataType") -> bool: + """Check that pandera :class:`DataType`s are equivalent.""" + return self == pandera_dtype + + def __repr__(self) -> str: + return f"DataType({str(self)})" + + def __str__(self) -> str: + raise NotImplementedError() + + def __hash__(self) -> int: + raise NotImplementedError() -def is_extension_dtype(dtype): - """Check if a value is a pandas extension type or instance of one.""" - return isinstance(dtype, PandasExtensionType) or ( - isinstance(dtype, type) and issubclass(dtype, PandasExtensionType) - ) +_Dtype = TypeVar("_Dtype", bound=DataType) +_DataTypeClass = Type[_Dtype] -class PandasDtype(Enum): - # pylint: disable=line-too-long,invalid-name - """Enumerate all valid pandas data types. - - ``pandera`` follows the - `numpy data types `_ - subscribed to by ``pandas`` and by default supports using the numpy data - type string aliases to validate DataFrame or Series dtypes. - - This class simply enumerates the valid numpy dtypes for pandas arrays. 
- For convenience ``PandasDtype`` enums can all be accessed in the top-level - ``pandera`` name space via the same enum name. - - :examples: - - >>> import pandas as pd - >>> import pandera as pa - >>> - >>> - >>> pa.SeriesSchema(pa.Int).validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - >>> pa.SeriesSchema(pa.Float).validate(pd.Series([1.1, 2.3, 3.4])) - 0 1.1 - 1 2.3 - 2 3.4 - dtype: float64 - >>> pa.SeriesSchema(pa.String).validate(pd.Series(["a", "b", "c"])) - 0 a - 1 b - 2 c - dtype: object - - Alternatively, you can use built-in python scalar types for integers, - floats, booleans, and strings: - - >>> pa.SeriesSchema(int).validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - - You can also use the pandas string aliases in the schema definition: - - >>> pa.SeriesSchema("int").validate(pd.Series([1, 2, 3])) - 0 1 - 1 2 - 2 3 - dtype: int64 - - .. note:: - ``pandera`` also offers limited support for - `pandas extension types `_, - however since the release of pandas 1.0.0 there are backwards - incompatible extension types like the ``Integer`` array. The extension - types, e.g. ``pd.IntDtype64()`` and their string alias should work - when supplied to the ``pandas_dtype`` argument, unless otherwise - specified below, but this functionality is only tested for - pandas >= 1.0.0. Extension types in earlier versions are not guaranteed - to work as the ``pandas_dtype`` argument in schemas or schema - components. +def immutable( + pandera_dtype_cls: Optional[_DataTypeClass] = None, **dataclass_kwargs: Any +) -> Union[_DataTypeClass, Callable[[_DataTypeClass], _DataTypeClass]]: + """:func:`dataclasses.dataclass` decorator with different default values: + `frozen=True`, `init=False`, `repr=False`. + + In addition, `init=False` disables inherited `__init__` method to ensure + the DataType's default attributes are not altered during initialization. + + :param dtype: :class:`DataType` to decorate. + :param dataclass_kwargs: Keywords arguments forwarded to + :func:`dataclasses.dataclass`. 
+ :returns: Immutable :class:`~pandera.dtypes.DataType` """ + kwargs = {"frozen": True, "init": False, "repr": False} + kwargs.update(dataclass_kwargs) - Bool = "bool" #: ``"bool"`` numpy dtype - DateTime = "datetime64[ns]" #: ``"datetime64[ns]"`` numpy dtype - Timedelta = "timedelta64[ns]" #: ``"timedelta64[ns]"`` numpy dtype - Category = "category" #: pandas ``"categorical"`` datatype - Float = "float" #: ``"float"`` numpy dtype - Float16 = "float16" #: ``"float16"`` numpy dtype - Float32 = "float32" #: ``"float32"`` numpy dtype - Float64 = "float64" #: ``"float64"`` numpy dtype - Int = "int" #: ``"int"`` numpy dtype - Int8 = "int8" #: ``"int8"`` numpy dtype - Int16 = "int16" #: ``"int16"`` numpy dtype - Int32 = "int32" #: ``"int32"`` numpy dtype - Int64 = "int64" #: ``"int64"`` numpy dtype - UInt8 = "uint8" #: ``"uint8"`` numpy dtype - UInt16 = "uint16" #: ``"uint16"`` numpy dtype - UInt32 = "uint32" #: ``"uint32"`` numpy dtype - UInt64 = "uint64" #: ``"uint64"`` numpy dtype - INT8 = "Int8" #: ``"Int8"`` pandas dtype:: pandas 0.24.0+ - INT16 = "Int16" #: ``"Int16"`` pandas dtype: pandas 0.24.0+ - INT32 = "Int32" #: ``"Int32"`` pandas dtype: pandas 0.24.0+ - INT64 = "Int64" #: ``"Int64"`` pandas dtype: pandas 0.24.0+ - UINT8 = "UInt8" #: ``"UInt8"`` pandas dtype: pandas 0.24.0+ - UINT16 = "UInt16" #: ``"UInt16"`` pandas dtype: pandas 0.24.0+ - UINT32 = "UInt32" #: ``"UInt32"`` pandas dtype: pandas 0.24.0+ - UINT64 = "UInt64" #: ``"UInt64"`` pandas dtype: pandas 0.24.0+ - Object = "object" #: ``"object"`` numpy dtype - Complex = "complex" #: ``"complex"`` numpy dtype - Complex64 = "complex64" #: ``"complex"`` numpy dtype - Complex128 = "complex128" #: ``"complex"`` numpy dtype - Complex256 = "complex256" #: ``"complex"`` numpy dtype - String = "str" #: ``"str"`` numpy dtype - - #: ``"string"`` pandas dtypes: pandas 1.0.0+. For <1.0.0, this enum will - #: fall back on the str-as-object-array representation. - STRING = "string" - - @property - def str_alias(self): - """Get datatype string alias.""" - return { - "int": _DEFAULT_PANDAS_INT_TYPE, - "float": _DEFAULT_PANDAS_FLOAT_TYPE, - "complex": _DEFAULT_PANDAS_COMPLEX_TYPE, - "str": "object", - "string": "object" if LEGACY_PANDAS else "string", - }.get(self.value, self.value) - - @classmethod - def from_str_alias(cls, str_alias: str) -> "PandasDtype": - """Get PandasDtype from string alias. 
- - :param: pandas dtype string alias from - https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html#basics-dtypes - :returns: pandas dtype - """ - pandas_dtype = { - "bool": cls.Bool, - "datetime64[ns]": cls.DateTime, - "timedelta64[ns]": cls.Timedelta, - "category": cls.Category, - "float": cls.Float, - "float16": cls.Float16, - "float32": cls.Float32, - "float64": cls.Float64, - "int": cls.Int, - "int8": cls.Int8, - "int16": cls.Int16, - "int32": cls.Int32, - "int64": cls.Int64, - "uint8": cls.UInt8, - "uint16": cls.UInt16, - "uint32": cls.UInt32, - "uint64": cls.UInt64, - "Int8": cls.INT8, - "Int16": cls.INT16, - "Int32": cls.INT32, - "Int64": cls.INT64, - "UInt8": cls.UINT8, - "UInt16": cls.UINT16, - "UInt32": cls.UINT32, - "UInt64": cls.UINT64, - "object": cls.Object, - "complex": cls.Complex, - "complex64": cls.Complex64, - "complex128": cls.Complex128, - "complex256": cls.Complex256, - "str": cls.String, - "string": cls.String if LEGACY_PANDAS else cls.STRING, - }.get(str_alias) - - if pandas_dtype is None: - raise TypeError( - f"pandas dtype string alias '{str_alias}' not recognized" - ) + def _wrapper(pandera_dtype_cls: _DataTypeClass) -> _DataTypeClass: + immutable_dtype = dataclasses.dataclass(**kwargs)(pandera_dtype_cls) + if not kwargs["init"]: - return pandas_dtype - - @classmethod - def from_pandas_api_type(cls, pandas_api_type: str) -> "PandasDtype": - """Get PandasDtype enum from pandas api type. - - :param pandas_api_type: string output from - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.api.types.infer_dtype.html - :returns: pandas dtype - """ - if pandas_api_type.startswith("mixed"): - return cls.Object - - pandas_dtype = { - "string": cls.String, - "floating": cls.Float, - "integer": cls.Int, - "categorical": cls.Category, - "boolean": cls.Bool, - "datetime64": cls.DateTime, - "datetime": cls.DateTime, - "timedelta64": cls.Timedelta, - "timedelta": cls.Timedelta, - }.get(pandas_api_type) - - if pandas_dtype is None: - raise TypeError( - f"pandas api type '{pandas_api_type}' not recognized" - ) + def __init__(self): # pylint:disable=unused-argument + pass - return pandas_dtype - - @classmethod - def from_python_type(cls, python_type: type) -> "PandasDtype": - """Get PandasDtype enum from built-in python type. - - :param python_type: built-in python type. Allowable types are: - str, int, float, and bool. - """ - pandas_dtype = { - bool: cls.Bool, - str: cls.String, - int: cls.Int, - float: cls.Float, - object: cls.Object, - complex: cls.Complex, - }.get(python_type) - - if pandas_dtype is None: - raise TypeError( - f"python type '{python_type}' not recognized as pandas data type" - ) + # delattr(immutable_dtype, "__init__") doesn't work because + # super.__init__ would still exist. 
+ setattr(immutable_dtype, "__init__", __init__) + + return immutable_dtype + + if pandera_dtype_cls is None: + return _wrapper + + return _wrapper(pandera_dtype_cls) + + +############################################################################### +# number +############################################################################### + + +@immutable +class _Number(DataType): + """Semantic representation of a numeric data type.""" + + exact: Optional[bool] = None + + def check(self, pandera_dtype: "DataType") -> bool: + if self.__class__ is _Number: + return isinstance(pandera_dtype, _Number) + return super().check(pandera_dtype) + + +@immutable +class _PhysicalNumber(_Number): + + bit_width: Optional[int] = None + _base_name: Optional[str] = dataclasses.field( + default=None, init=False, repr=False + ) + + def __eq__(self, obj: object) -> bool: + if isinstance(obj, type(self)): + return obj.bit_width == self.bit_width + return super().__eq__(obj) + + def __str__(self) -> str: + return f"{self._base_name}{self.bit_width}" + + +############################################################################### +# boolean +############################################################################### + + +@immutable +class Bool(_Number): + """Semantic representation of a boolean data type.""" + + def __str__(self) -> str: + return "bool" + + +Boolean = Bool +############################################################################### +# signed integer +############################################################################### - return pandas_dtype - - @classmethod - def from_numpy_type(cls, numpy_type: np.dtype) -> "PandasDtype": - """Get PandasDtype enum from numpy type. - - :param numpy_type: numpy data type. - """ - pd_dtype = pd.api.types.pandas_dtype(numpy_type) - return cls.from_str_alias(pd_dtype.name) - - @classmethod - def get_dtype( - cls, - pandas_dtype_arg: Union[ - str, - type, - "PandasDtype", - "pd.core.dtypes.dtypes.ExtensionDtype", - np.dtype, - ], - ) -> Optional[ - Union["PandasDtype", "pd.core.dtypes.dtypes.ExtensionDtype"] - ]: - """Get PandasDtype from schema argument. - - :param pandas_dtype_arg: ``pandas_dtype`` argument specified in schema - definition. - """ - dtype_ = pandas_dtype_arg - - if dtype_ is None: - return dtype_ - elif isinstance(dtype_, PandasDtype): - return pandas_dtype_arg - elif is_extension_dtype(dtype_): - if isinstance(dtype_, type): - try: - # Convert to str here because some pandas dtypes allow - # an empty constructor for compatibility but fail on - # str(). e.g: PeriodDtype - str(dtype_().name) - return dtype_() - except (TypeError, AttributeError) as err: - raise TypeError( - f"Pandas dtype {dtype_} cannot be instantiated: " - f"{err}\n Usage Tip: Use an instance or a string " - "representation." - ) from err - return dtype_ - - if dtype_ in NUMPY_TYPES: - dtype_ = cls.from_numpy_type(dtype_) # type: ignore - elif isinstance(dtype_, str): - dtype_ = cls.from_str_alias(dtype_) - elif isinstance(dtype_, type): - dtype_ = cls.from_python_type(dtype_) - - if isinstance(dtype_, PandasDtype): - return dtype_ - raise TypeError( - "type of `pandas_dtype` argument not recognized: " - f"{type(pandas_dtype_arg)}. 
Please specify a pandera PandasDtype " - "enum, built-in python type, pandas data type, pandas data type " - "string alias, or numpy data type string alias" + +@immutable(eq=False) +class Int(_PhysicalNumber): # type: ignore + """Semantic representation of an integer data type.""" + + _base_name = "int" + exact = True + bit_width = 64 + signed: bool = dataclasses.field(default=True, init=False) + + def check(self, pandera_dtype: DataType) -> bool: + return ( + isinstance(pandera_dtype, Int) + and self.signed == pandera_dtype.signed + and self.bit_width == pandera_dtype.bit_width ) - @classmethod - def get_str_dtype(cls, pandas_dtype_arg) -> Optional[str]: - """Get pandas-compatible string representation of dtype.""" - pandas_dtype = cls.get_dtype(pandas_dtype_arg) - if pandas_dtype is None: - return pandas_dtype - elif isinstance(pandas_dtype, PandasDtype): - return pandas_dtype.str_alias - return str(pandas_dtype) - - def __eq__(self, other): - # pylint: disable=comparison-with-callable - # see https://github.com/PyCQA/pylint/issues/2306 - if other is None: - return False - other_dtype = PandasDtype.get_dtype(other) - if self.value == "string" and LEGACY_PANDAS: - return PandasDtype.String.value == other_dtype.value - elif self.value == "string": - return self.value == other_dtype.value - return self.str_alias == other_dtype.str_alias - - def __hash__(self): - if self is PandasDtype.Int: - hash_obj = _DEFAULT_PANDAS_INT_TYPE - elif self is PandasDtype.Float: - hash_obj = _DEFAULT_PANDAS_FLOAT_TYPE - else: - hash_obj = self.str_alias - return id(hash_obj) - - @property - def numpy_dtype(self): - """Get numpy data type.""" - if self is PandasDtype.Category: - raise TypeError( - "the pandas Categorical data type doesn't have a numpy " - "equivalent." - ) + def __str__(self) -> str: + if self.__class__ is Int: + return "int" + return super().__str__() + + +@immutable +class Int64(Int): + """Semantic representation of an integer data type stored in 64 bits.""" + + bit_width = 64 + + +@immutable +class Int32(Int64): + """Semantic representation of an integer data type stored in 32 bits.""" + + bit_width = 32 + + +@immutable +class Int16(Int32): + """Semantic representation of an integer data type stored in 16 bits.""" + + bit_width = 16 + + +@immutable +class Int8(Int16): + """Semantic representation of an integer data type stored in 8 bits.""" + + bit_width = 8 + + +############################################################################### +# unsigned integer +############################################################################### + + +@immutable +class UInt(Int): + """Semantic representation of an unsigned integer data type.""" + + _base_name = "uint" + signed: bool = dataclasses.field(default=False, init=False) + + def __str__(self) -> str: + if self.__class__ is UInt: + return "uint" + return super().__str__() + + +@immutable +class UInt64(UInt): + """Semantic representation of an unsigned integer data type stored + in 64 bits.""" + + bit_width = 64 + + +@immutable +class UInt32(UInt64): + """Semantic representation of an unsigned integer data type stored + in 32 bits.""" + + bit_width = 32 + + +@immutable +class UInt16(UInt32): + """Semantic representation of an unsigned integer data type stored + in 16 bits.""" + + bit_width = 16 + + +@immutable +class UInt8(UInt16): + """Semantic representation of an unsigned integer data type stored + in 8 bits.""" + + bit_width = 8 + + +############################################################################### +# float 
+############################################################################### + + +@immutable(eq=False) +class Float(_PhysicalNumber): # type: ignore + """Semantic representation of a floating data type.""" - # pylint: disable=comparison-with-callable - if self.value in {"str", "string"}: - dtype = np.dtype("str") - else: - dtype = np.dtype(self.str_alias.lower()) - return dtype - - @property - def is_int(self) -> bool: - """Return True if PandasDtype is an integer.""" - return self.value.lower().startswith("int") - - @property - def is_nullable_int(self) -> bool: - """Return True if PandasDtype is a nullable integer.""" - return self.value.startswith("Int") - - @property - def is_nonnullable_int(self) -> bool: - """Return True if PandasDtype is a non-nullable integer.""" - return self.value.startswith("int") - - @property - def is_uint(self) -> bool: - """Return True if PandasDtype is an unsigned integer.""" - return self.value.lower().startswith("uint") - - @property - def is_nullable_uint(self) -> bool: - """Return True if PandasDtype is a nullable unsigned integer.""" - return self.value.startswith("UInt") - - @property - def is_nonnullable_uint(self) -> bool: - """Return True if PandasDtype is a non-nullable unsigned integer.""" - return self.value.startswith("uint") - - @property - def is_float(self) -> bool: - """Return True if PandasDtype is a float.""" - return self.value.startswith("float") - - @property - def is_complex(self) -> bool: - """Return True if PandasDtype is a complex number.""" - return self.value.startswith("complex") - - @property - def is_bool(self) -> bool: - """Return True if PandasDtype is a boolean.""" - return self is PandasDtype.Bool - - @property - def is_string(self) -> bool: - """Return True if PandasDtype is a string.""" - return self in [PandasDtype.String, PandasDtype.STRING] - - @property - def is_category(self) -> bool: - """Return True if PandasDtype is a category.""" - return self is PandasDtype.Category - - @property - def is_datetime(self) -> bool: - """Return True if PandasDtype is a datetime.""" - return self is PandasDtype.DateTime - - @property - def is_timedelta(self) -> bool: - """Return True if PandasDtype is a timedelta.""" - return self is PandasDtype.Timedelta - - @property - def is_object(self) -> bool: - """Return True if PandasDtype is an object.""" - return self is PandasDtype.Object - - @property - def is_continuous(self) -> bool: - """Return True if PandasDtype is a continuous datatype.""" + _base_name = "float" + continuous = True + exact = False + bit_width = 64 + + def check(self, pandera_dtype: DataType) -> bool: return ( - self.is_int - or self.is_uint - or self.is_float - or self.is_complex - or self.is_datetime - or self.is_timedelta + isinstance(pandera_dtype, Float) + and self.bit_width == pandera_dtype.bit_width ) + + def __str__(self) -> str: + if self.__class__ is Float: + return "float" + return super().__str__() + + +@immutable +class Float128(Float): + """Semantic representation of a floating data type stored in 128 bits.""" + + bit_width = 128 + + +@immutable +class Float64(Float128): + """Semantic representation of a floating data type stored in 64 bits.""" + + bit_width = 64 + + +@immutable +class Float32(Float64): + """Semantic representation of a floating data type stored in 32 bits.""" + + bit_width = 32 + + +@immutable +class Float16(Float32): + """Semantic representation of a floating data type stored in 16 bits.""" + + bit_width = 16 + + 
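# ---------------------------------------------------------------------------
# [Editor's aside -- illustration only, not part of the patch] A minimal
# sketch of how the numeric hierarchy defined above behaves, assuming the
# module layout this diff introduces (pandera.dtypes). Equality is by bit
# width via _PhysicalNumber.__eq__, and check() requires an isinstance match
# plus an equal bit width (and matching signedness for integers).

from pandera.dtypes import Float, Float32, Float64, Int, UInt64

assert Float64() == Float()          # both default to 64 bits
assert Float64() != Float32()        # bit widths differ, classes differ
assert Float().check(Float64())      # isinstance match + same bit width
assert not Float().check(Float32())  # bit widths differ (64 vs 32)
assert not Int().check(UInt64())     # signedness differs
assert str(Float32()) == "float32"   # str() comes from _PhysicalNumber.__str__
# ---------------------------------------------------------------------------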
+############################################################################### +# complex +############################################################################### + + +@immutable(eq=False) +class Complex(_PhysicalNumber): # type: ignore + """Semantic representation of a complex number data type.""" + + _base_name = "complex" + bit_width = 128 + + def check(self, pandera_dtype: DataType) -> bool: + return ( + isinstance(pandera_dtype, Complex) + and self.bit_width == pandera_dtype.bit_width + ) + + def __str__(self) -> str: + if self.__class__ is Complex: + return "complex" + return super().__str__() + + +@immutable +class Complex256(Complex): + """Semantic representation of a complex number data type stored + in 256 bits.""" + + bit_width = 256 + + +@immutable +class Complex128(Complex): + """Semantic representation of a complex number data type stored + in 128 bits.""" + + bit_width = 128 + + +@immutable +class Complex64(Complex128): + """Semantic representation of a complex number data type stored + in 64 bits.""" + + bit_width = 64 + + +############################################################################### +# nominal +############################################################################### + + +@immutable(init=True) +class Category(DataType): # type: ignore + """Semantic representation of a categorical data type.""" + + categories: Optional[Tuple[Any]] = None # tuple to ensure safe hash + ordered: bool = False + + def __init__( + self, categories: Optional[Iterable[Any]] = None, ordered: bool = False + ): + # Define __init__ to avoid exposing pylint errors to end users. + super().__init__() + if categories is not None and not isinstance(categories, tuple): + object.__setattr__(self, "categories", tuple(categories)) + object.__setattr__(self, "ordered", ordered) + + def check(self, pandera_dtype: "DataType") -> bool: + if isinstance(pandera_dtype, Category) and ( + self.categories is None or pandera_dtype.categories is None + ): + # Category without categories is a superset of any Category + # Allow end-users to not list categories when validating. 
+            return True
+
+        return super().check(pandera_dtype)
+
+    def __str__(self) -> str:
+        return "category"
+
+
+@immutable
+class String(DataType):
+    """Semantic representation of a string data type."""
+
+    def __str__(self) -> str:
+        return "string"
+
+
+###############################################################################
+# time
+###############################################################################
+
+
+@immutable
+class Date(DataType):
+    """Semantic representation of a date data type."""
+
+    def __str__(self) -> str:
+        return "date"
+
+
+@immutable
+class Timestamp(Date):
+    """Semantic representation of a timestamp data type."""
+
+    def __str__(self) -> str:
+        return "timestamp"
+
+
+DateTime = Timestamp
+
+
+@immutable
+class Timedelta(DataType):
+    """Semantic representation of a delta time data type."""
+
+    def __str__(self) -> str:
+        return "timedelta"
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def is_subdtype(
+    arg1: Union[DataType, Type[DataType]],
+    arg2: Union[DataType, Type[DataType]],
+) -> bool:
+    """Return True if ``arg1`` is lower than or equal to ``arg2`` in the DataType hierarchy."""
+    arg1_cls = arg1 if inspect.isclass(arg1) else arg1.__class__
+    arg2_cls = arg2 if inspect.isclass(arg2) else arg2.__class__
+    return issubclass(arg1_cls, arg2_cls)  # type: ignore
+
+
+def is_int(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is an integer."""
+    return is_subdtype(pandera_dtype, Int)
+
+
+def is_uint(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is
+    an unsigned integer."""
+    return is_subdtype(pandera_dtype, UInt)
+
+
+def is_float(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a float."""
+    return is_subdtype(pandera_dtype, Float)
+
+
+def is_complex(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a complex number."""
+    return is_subdtype(pandera_dtype, Complex)
+
+
+def is_numeric(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is numeric."""
+    return is_subdtype(pandera_dtype, _Number)
+
+
+def is_bool(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a boolean."""
+    return is_subdtype(pandera_dtype, Bool)
+
+
+def is_string(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a string."""
+    return is_subdtype(pandera_dtype, String)
+
+
+def is_category(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a category."""
+    return is_subdtype(pandera_dtype, Category)
+
+
+def is_datetime(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a datetime."""
+    return is_subdtype(pandera_dtype, DateTime)
+
+
+def is_timedelta(pandera_dtype: Union[DataType, Type[DataType]]) -> bool:
+    """Return True if :class:`pandera.dtypes.DataType` is a timedelta."""
+    return is_subdtype(pandera_dtype, Timedelta)
diff --git a/pandera/dtypes_.py b/pandera/dtypes_.py
deleted file mode 100644
index 18ec1f50d..000000000
--- a/pandera/dtypes_.py
+++ /dev/null
@@ -1,381 +0,0 @@
-"""Pandera data 
types.""" -# pylint:disable=too-many-ancestors -import dataclasses -from abc import ABC -from typing import ( - Any, - Callable, - Iterable, - Optional, - Tuple, - Type, - TypeVar, - Union, -) - - -class DataType(ABC): - """Base class of all Pandera data types.""" - - def __init__(self): - if self.__class__ is DataType: - raise TypeError( - f"{self.__class__.__name__} may not be instantiated." - ) - - def coerce(self, data_container: Any): - """Coerce data container to the dtype.""" - raise NotImplementedError() - - def __call__(self, data_container: Any): - """Coerce data container to the dtype.""" - return self.coerce(data_container) - - def check(self, pandera_dtype: "DataType") -> bool: - """Check that pandera :class:`DataType`s are equivalent.""" - if not isinstance(pandera_dtype, DataType): - return False - return self == pandera_dtype - - def __repr__(self) -> str: - return f"DataType({str(self)})" - - def __str__(self) -> str: - raise NotImplementedError() - - def __hash__(self) -> int: - raise NotImplementedError() - - -_Dtype = TypeVar("_Dtype", bound=DataType) -_DataTypeClass = Type[_Dtype] - - -def immutable( - pandera_dtype_cls: Optional[_DataTypeClass] = None, **dataclass_kwargs: Any -) -> Union[_DataTypeClass, Callable[[_DataTypeClass], _DataTypeClass]]: - """:func:`dataclasses.dataclass` decorator with different default values: - `frozen=True`, `init=False`, `repr=False`. - - In addition, `init=False` disables inherited `__init__` method to ensure - the DataType's default attributes are not altered during initialization. - - :param dtype: :class:`DataType` to decorate. - :param dataclass_kwargs: Keywords arguments forwarded to - :func:`dataclasses.dataclass`. - :returns: Immutable :class:`~pandera.dtypes.DataType` - """ - kwargs = {"frozen": True, "init": False, "repr": False} - kwargs.update(dataclass_kwargs) - - def _wrapper(pandera_dtype_cls: _DataTypeClass) -> _DataTypeClass: - immutable_dtype = dataclasses.dataclass(**kwargs)(pandera_dtype_cls) - if not kwargs["init"]: - - def __init__(self): # pylint:disable=unused-argument - pass - - # delattr(immutable_dtype, "__init__") doesn't work because - # super.__init__ would still exist. 
- setattr(immutable_dtype, "__init__", __init__) - - return immutable_dtype - - if pandera_dtype_cls is None: - return _wrapper - - return _wrapper(pandera_dtype_cls) - - -################################################################################ -# boolean -################################################################################ - - -@immutable -class Bool(DataType): - """Semantic representation of a boolean data type.""" - - def __str__(self) -> str: - return "bool" - - -Boolean = Bool - -################################################################################ -# number -################################################################################ - - -@immutable -class _Number(DataType): - """Semantic representation of a numeric data type.""" - - continuous: Optional[bool] = None - exact: Optional[bool] = None - - def check(self, pandera_dtype: "DataType") -> bool: - if self.__class__ is _Number: - return isinstance(pandera_dtype, (Int, Float, Complex)) - return super().check(pandera_dtype) - - -@immutable -class _PhysicalNumber(_Number): - - bit_width: Optional[int] = None - _base_name: Optional[str] = dataclasses.field( - default=None, init=False, repr=False - ) - - def __eq__(self, obj: object) -> bool: - if isinstance(obj, type(self)): - return obj.bit_width == self.bit_width - return super().__eq__(obj) - - def __str__(self) -> str: - return f"{self._base_name}{self.bit_width}" - - -################################################################################ -## signed integer -################################################################################ - - -@immutable(eq=False) -class Int(_PhysicalNumber): # type: ignore - """Semantic representation of an integer data type.""" - - _base_name = "int" - continuous = False - exact = True - bit_width = 64 - signed: bool = dataclasses.field(default=True, init=False) - - -@immutable -class Int64(Int, _PhysicalNumber): - """Semantic representation of an integer data type stored in 64 bits.""" - - bit_width = 64 - - -@immutable -class Int32(Int64): - """Semantic representation of an integer data type stored in 32 bits.""" - - bit_width = 32 - - -@immutable -class Int16(Int32): - """Semantic representation of an integer data type stored in 16 bits.""" - - bit_width = 16 - - -@immutable -class Int8(Int16): - """Semantic representation of an integer data type stored in 8 bits.""" - - bit_width = 8 - - -################################################################################ -## unsigned integer -################################################################################ - - -@immutable -class UInt(Int): - """Semantic representation of an unsigned integer data type.""" - - _base_name = "uint" - signed: bool = dataclasses.field(default=False, init=False) - - -@immutable -class UInt64(UInt): - """Semantic representation of an unsigned integer data type stored - in 64 bits.""" - - bit_width = 64 - - -@immutable -class UInt32(UInt64): - """Semantic representation of an unsigned integer data type stored - in 32 bits.""" - - bit_width = 32 - - -@immutable -class UInt16(UInt32): - """Semantic representation of an unsigned integer data type stored - in 16 bits.""" - - bit_width = 16 - - -@immutable -class UInt8(UInt16): - """Semantic representation of an unsigned integer data type stored - in 8 bits.""" - - bit_width = 8 - - -################################################################################ -## float -################################################################################ - - 
-@immutable(eq=False) -class Float(_PhysicalNumber): # type: ignore - """Semantic representation of a floating data type.""" - - _base_name = "float" - continuous = True - exact = False - bit_width = 64 - - -@immutable -class Float128(Float): - """Semantic representation of a floating data type stored in 128 bits.""" - - bit_width = 128 - - -@immutable -class Float64(Float128): - """Semantic representation of a floating data type stored in 64 bits.""" - - bit_width = 64 - - -@immutable -class Float32(Float64): - """Semantic representation of a floating data type stored in 32 bits.""" - - bit_width = 32 - - -@immutable -class Float16(Float32): - """Semantic representation of a floating data type stored in 16 bits.""" - - bit_width = 16 - - -################################################################################ -## complex -################################################################################ - - -@immutable(eq=False) -class Complex(_PhysicalNumber): # type: ignore - """Semantic representation of a complex number data type.""" - - _base_name = "complex" - bit_width = 128 - - -@immutable -class Complex256(Complex): - """Semantic representation of a complex number data type stored - in 256 bits.""" - - bit_width = 256 - - -@immutable -class Complex128(Complex): - """Semantic representation of a complex number data type stored - in 128 bits.""" - - bit_width = 128 - - -@immutable -class Complex64(Complex128): - """Semantic representation of a complex number data type stored - in 64 bits.""" - - bit_width = 64 - - -################################################################################ -# nominal -################################################################################ - - -@immutable(init=True) -class Category(DataType): # type: ignore - """Semantic representation of a categorical data type.""" - - categories: Optional[Tuple[Any]] = None # tuple to ensure safe hash - ordered: bool = False - - def __init__( - self, categories: Optional[Iterable[Any]] = None, ordered: bool = False - ): - # Define __init__ to avoid exposing pylint errors to end users. - super().__init__() - if categories is not None and not isinstance(categories, tuple): - object.__setattr__(self, "categories", tuple(categories)) - object.__setattr__(self, "ordered", ordered) - - def check(self, pandera_dtype: "DataType") -> bool: - if isinstance(pandera_dtype, Category) and ( - self.categories is None or pandera_dtype.categories is None - ): - # Category without categories is a superset of any Category - # Allow end-users to not list categories when validating. 
- return True - - return super().check(pandera_dtype) - - def __str__(self) -> str: - return "category" - - -@immutable -class String(DataType): - """Semantic representation of a string data type.""" - - def __str__(self) -> str: - return "string" - - -################################################################################ -# time -################################################################################ - - -@immutable -class Date(DataType): - """Semantic representation of a date data type.""" - - def __str__(self) -> str: - return "date" - - -@immutable -class Timestamp(Date): - """Semantic representation of a timestamp data type.""" - - def __str__(self) -> str: - return "timestamp" - - -DateTime = Timestamp - - -@immutable -class Timedelta(DataType): - """Semantic representation of a delta time data type.""" - - def __str__(self) -> str: - return "timedelta" diff --git a/pandera/engines/engine.py b/pandera/engines/engine.py index 77ffa5c98..bf01e0473 100644 --- a/pandera/engines/engine.py +++ b/pandera/engines/engine.py @@ -12,6 +12,8 @@ Callable, Dict, List, + Set, + Tuple, Type, TypeVar, Union, @@ -20,7 +22,7 @@ import typing_inspect -from pandera.dtypes_ import DataType +from pandera.dtypes import DataType _DataType = TypeVar("_DataType", bound=DataType) _Engine = TypeVar("_Engine", bound="Engine") @@ -58,16 +60,17 @@ class Engine(ABCMeta): """ _registry: Dict["Engine", _DtypeRegistry] = {} - _base_pandera_dtypes: Type[DataType] + _registered_dtypes: Set[Type[DataType]] + _base_pandera_dtypes: Tuple[Type[DataType]] def __new__(cls, name, bases, namespace, **kwargs): - base_pandera_dtypes = kwargs.pop("base_pandera_dtypes") - try: # allow multiple base datatypes - base_pandera_dtypes = tuple(base_pandera_dtypes) + try: + namespace["_base_pandera_dtypes"] = tuple(base_pandera_dtypes) except TypeError: - pass - namespace["_base_pandera_dtypes"] = base_pandera_dtypes + namespace["_base_pandera_dtypes"] = (base_pandera_dtypes,) + + namespace["_registered_dtypes"] = set() engine = super().__new__(cls, name, bases, namespace, **kwargs) @functools.singledispatch @@ -82,8 +85,12 @@ def _check_source_dtype(cls, data_type: Any) -> None: inspect.isclass(data_type) and issubclass(data_type, cls._base_pandera_dtypes) ): + base_names = [ + f"{base.__module__}.{base.__qualname__}" + for base in cls._base_pandera_dtypes + ] raise ValueError( - f"{cls._base_pandera_dtypes.__name__} subclasses cannot be registered" + f"Subclasses of {base_names} cannot be registered" f" with {cls.__name__}." ) @@ -136,25 +143,26 @@ def register_dtype( The classmethod ``from_parametrized_dtype`` will also be registered. 
""" - def _wrapper(pandera_dtype: Union[DataType, Type[DataType]]): - if not inspect.isclass(pandera_dtype): + def _wrapper(pandera_dtype_cls: Union[DataType, Type[DataType]]): + if not inspect.isclass(pandera_dtype_cls): raise ValueError( f"{cls.__name__}.register_dtype can only decorate a class, " - + f"got {pandera_dtype}" + + f"got {pandera_dtype_cls}" ) if equivalents: - cls._register_equivalents(pandera_dtype, *equivalents) + cls._register_equivalents(pandera_dtype_cls, *equivalents) - if "from_parametrized_dtype" in pandera_dtype.__dict__: - cls._register_from_parametrized_dtype(pandera_dtype) + if "from_parametrized_dtype" in pandera_dtype_cls.__dict__: + cls._register_from_parametrized_dtype(pandera_dtype_cls) elif not equivalents: warnings.warn( - f"register_dtype({pandera_dtype}) on a class without a " + f"register_dtype({pandera_dtype_cls}) on a class without a " + "'from_parametrized_dtype' classmethod has no effect." ) - return pandera_dtype + cls._registered_dtypes.add(pandera_dtype_cls) + return pandera_dtype_cls if pandera_dtype_cls: return _wrapper(pandera_dtype_cls) @@ -190,3 +198,8 @@ def dtype(cls: _EngineType, data_type: Any) -> _DataType: raise TypeError( f"Data type '{data_type}' not understood by {cls.__name__}." ) from None + + def get_registered_dtypes(cls) -> List[Type[DataType]]: + """Return :class:`pandera.dtypes.DataType`s registered + with this engine.""" + return list(cls._registered_dtypes) diff --git a/pandera/engines/numpy_engine.py b/pandera/engines/numpy_engine.py index 269e1edfc..c5fd62ff2 100644 --- a/pandera/engines/numpy_engine.py +++ b/pandera/engines/numpy_engine.py @@ -5,18 +5,21 @@ import dataclasses import datetime import inspect +import platform import warnings from typing import Any, Dict, List, Union import numpy as np -from .. import dtypes_ -from ..dtypes_ import immutable +from .. import dtypes +from ..dtypes import immutable from . import engine +WINDOWS_PLATFORM = platform.system() == "Windows" + @immutable(init=True) -class DataType(dtypes_.DataType): +class DataType(dtypes.DataType): """Base `DataType` for boxing Numpy data types.""" type: np.dtype = dataclasses.field( @@ -53,7 +56,7 @@ class Engine( # pylint:disable=too-few-public-methods """Numpy data type engine.""" @classmethod - def dtype(cls, data_type: Any) -> dtypes_.DataType: + def dtype(cls, data_type: Any) -> dtypes.DataType: """Convert input into a numpy-compatible Pandera :class:`DataType` object.""" try: @@ -63,7 +66,8 @@ def dtype(cls, data_type: Any) -> dtypes_.DataType: np_dtype = np.dtype(data_type).type except TypeError: raise TypeError( - f"data type '{data_type}' not understood by {cls.__name__}." + f"data type '{data_type}' not understood by " + f"{cls.__name__}." 
) from None try: @@ -72,22 +76,22 @@ def dtype(cls, data_type: Any) -> dtypes_.DataType: return DataType(data_type) -################################################################################ +############################################################################### # boolean -################################################################################ +############################################################################### @Engine.register_dtype( - equivalents=["bool", bool, np.bool_, dtypes_.Bool, dtypes_.Bool()] + equivalents=["bool", bool, np.bool_, dtypes.Bool, dtypes.Bool()] ) @immutable -class Bool(DataType, dtypes_.Bool): +class Bool(DataType, dtypes.Bool): type = np.dtype("bool") def _build_number_equivalents( builtin_name: str, pandera_name: str, sizes: List[int] -) -> Dict[int, List[Union[type, str, np.dtype, dtypes_.DataType]]]: +) -> Dict[int, List[Union[type, str, np.dtype, dtypes.DataType]]]: """Return a dict of equivalent builtin, numpy, pandera dtypes indexed by size in bit_width.""" builtin_type = getattr(builtins, builtin_name, None) @@ -98,7 +102,7 @@ def _build_number_equivalents( # e.g.: np.int64 np.dtype(builtin_name).type, # e.g: pandera.dtypes.Int - getattr(dtypes_, pandera_name), + getattr(dtypes, pandera_name), ] if builtin_type: default_equivalents.append(builtin_type) @@ -110,10 +114,10 @@ def _build_number_equivalents( # e.g.: numpy.int64 getattr(np, f"{builtin_name}{bit_width}"), # e.g.: pandera.dtypes.Int64 - getattr(dtypes_, f"{pandera_name}{bit_width}"), - getattr(dtypes_, f"{pandera_name}{bit_width}")(), + getattr(dtypes, f"{pandera_name}{bit_width}"), + getattr(dtypes, f"{pandera_name}{bit_width}")(), # e.g.: pandera.dtypes.Int(64) - getattr(dtypes_, pandera_name)(), + getattr(dtypes, pandera_name)(), ) ) | set(default_equivalents if bit_width == default_size else []) @@ -122,9 +126,9 @@ def _build_number_equivalents( } -################################################################################ -## signed integer -################################################################################ +############################################################################### +# signed integer +############################################################################### _int_equivalents = _build_number_equivalents( builtin_name="int", pandera_name="Int", sizes=[64, 32, 16, 8] @@ -133,7 +137,7 @@ def _build_number_equivalents( @Engine.register_dtype(equivalents=_int_equivalents[64]) @immutable -class Int64(DataType, dtypes_.Int64): +class Int64(DataType, dtypes.Int64): type = np.dtype("int64") bit_width: int = 64 @@ -160,9 +164,9 @@ class Int8(Int16): bit_width: int = 8 -################################################################################ -## unsigned integer -################################################################################ +############################################################################### +# unsigned integer +############################################################################### _uint_equivalents = _build_number_equivalents( builtin_name="uint", @@ -173,7 +177,7 @@ class Int8(Int16): @Engine.register_dtype(equivalents=_uint_equivalents[64]) @immutable -class UInt64(DataType, dtypes_.UInt64): +class UInt64(DataType, dtypes.UInt64): type = np.dtype("uint64") bit_width: int = 64 @@ -199,29 +203,40 @@ class UInt8(UInt16): bit_width: int = 8 -################################################################################ -## float 
-################################################################################ +############################################################################### +# float +############################################################################### _float_equivalents = _build_number_equivalents( builtin_name="float", pandera_name="Float", - sizes=[128, 64, 32, 16], + sizes=[64, 32, 16] if WINDOWS_PLATFORM else [128, 64, 32, 16], ) -@Engine.register_dtype(equivalents=_float_equivalents[128]) -@immutable -class Float128(DataType, dtypes_.Float128): - type = np.dtype("float128") - bit_width: int = 128 +if not WINDOWS_PLATFORM: + # not supported in windows + # https://github.com/winpython/winpython/issues/613 + @Engine.register_dtype(equivalents=_float_equivalents[128]) + @immutable + class Float128(DataType, dtypes.Float128): + type = np.dtype("float128") + bit_width: int = 128 + @Engine.register_dtype(equivalents=_float_equivalents[64]) + @immutable + class Float64(Float128): + type = np.dtype("float64") + bit_width: int = 64 -@Engine.register_dtype(equivalents=_float_equivalents[64]) -@immutable -class Float64(Float128): - type = np.dtype("float64") - bit_width: int = 64 + +else: + + @Engine.register_dtype(equivalents=_float_equivalents[64]) + @immutable + class Float64(DataType, dtypes.Float64): # type: ignore + type = np.dtype("float64") + bit_width: int = 64 @Engine.register_dtype(equivalents=_float_equivalents[32]) @@ -238,29 +253,40 @@ class Float16(Float32): bit_width: int = 16 -################################################################################ -## complex -################################################################################ +############################################################################### +# complex +############################################################################### _complex_equivalents = _build_number_equivalents( builtin_name="complex", pandera_name="Complex", - sizes=[256, 128, 64], + sizes=[128, 64] if WINDOWS_PLATFORM else [256, 128, 64], ) -@Engine.register_dtype(equivalents=_complex_equivalents[256]) -@immutable -class Complex256(DataType, dtypes_.Complex256): - type = np.dtype("complex256") - bit_width: int = 256 +if not WINDOWS_PLATFORM: + # not supported in windows + # https://github.com/winpython/winpython/issues/613 + @Engine.register_dtype(equivalents=_complex_equivalents[256]) + @immutable + class Complex256(DataType, dtypes.Complex256): + type = np.dtype("complex256") + bit_width: int = 256 + @Engine.register_dtype(equivalents=_complex_equivalents[128]) + @immutable + class Complex128(Complex256): + type = np.dtype("complex128") # type: ignore + bit_width: int = 128 -@Engine.register_dtype(equivalents=_complex_equivalents[128]) -@immutable -class Complex128(Complex256): - type = np.dtype("complex128") # type: ignore - bit_width: int = 128 + +else: + + @Engine.register_dtype(equivalents=_complex_equivalents[128]) + @immutable + class Complex128(DataType, dtypes.Complex128): # type: ignore + type = np.dtype("complex128") # type: ignore + bit_width: int = 128 @Engine.register_dtype(equivalents=_complex_equivalents[64]) @@ -270,14 +296,14 @@ class Complex64(Complex128): bit_width: int = 64 -################################################################################ +############################################################################### # string -################################################################################ 
+############################################################################### @Engine.register_dtype(equivalents=["str", "string", str, np.str_]) @immutable -class String(DataType, dtypes_.String): +class String(DataType, dtypes.String): type = np.dtype("str") def coerce(self, data_container: np.ndarray) -> np.ndarray: @@ -286,13 +312,13 @@ def coerce(self, data_container: np.ndarray) -> np.ndarray: data_container[notna] = data_container[notna].astype(str) return data_container - def check(self, pandera_dtype: "dtypes_.DataType") -> bool: + def check(self, pandera_dtype: "dtypes.DataType") -> bool: return isinstance(pandera_dtype, (Object, type(self))) -################################################################################ +############################################################################### # object -################################################################################ +############################################################################### @Engine.register_dtype(equivalents=["object", "O", object, np.object_]) @@ -301,21 +327,21 @@ class Object(DataType): type = np.dtype("object") -################################################################################ +############################################################################### # time -################################################################################ +############################################################################### @Engine.register_dtype( equivalents=[ datetime.datetime, np.datetime64, - dtypes_.Timestamp, - dtypes_.Timestamp(), + dtypes.Timestamp, + dtypes.Timestamp(), ] ) @immutable -class DateTime64(DataType, dtypes_.Timestamp): +class DateTime64(DataType, dtypes.Timestamp): type = np.dtype("datetime64") @@ -323,10 +349,10 @@ class DateTime64(DataType, dtypes_.Timestamp): equivalents=[ datetime.datetime, np.timedelta64, - dtypes_.Timedelta, - dtypes_.Timedelta(), + dtypes.Timedelta, + dtypes.Timedelta(), ] ) @immutable -class Timedelta64(DataType, dtypes_.Timedelta): - type = np.dtype("timedelta64") +class Timedelta64(DataType, dtypes.Timedelta): + type = np.dtype("timedelta64[ns]") diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index 25a7dfd80..44a8f2fb2 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -10,16 +10,19 @@ import dataclasses import datetime import inspect +import platform import warnings from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import pandas as pd -from .. import dtypes_ -from ..dtypes_ import immutable +from .. import dtypes +from ..dtypes import immutable from . 
import engine, numpy_engine +WINDOWS_PLATFORM = platform.system() == "Windows" + PandasObject = Union[pd.Series, pd.Index, pd.DataFrame] PandasExtensionType = pd.core.dtypes.base.ExtensionDtype PandasDtype = Union[pd.core.dtypes.base.ExtensionDtype, np.dtype, type] @@ -34,7 +37,7 @@ def is_extension_dtype(pd_dtype: PandasDtype) -> bool: @immutable(init=True) -class DataType(dtypes_.DataType): +class DataType(dtypes.DataType): """Base `DataType` for boxing Pandas data types.""" type: Any = dataclasses.field(repr=False, init=False) @@ -56,7 +59,7 @@ def __post_init__(self): def coerce(self, data_container: PandasObject) -> PandasObject: return data_container.astype(self.type) - def check(self, pandera_dtype: dtypes_.DataType) -> bool: + def check(self, pandera_dtype: dtypes.DataType) -> bool: try: pandera_dtype = Engine.dtype(pandera_dtype) except TypeError: @@ -93,7 +96,8 @@ def dtype(cls, data_type: Any) -> "DataType": except (TypeError, AttributeError) as err: raise TypeError( f" dtype {data_type} cannot be instantiated: {err}\n" - "Usage Tip: Use an instance or a string representation." + "Usage Tip: Use an instance or a string " + "representation." ) from None else: # let pandas transform any acceptable value @@ -104,18 +108,30 @@ def dtype(cls, data_type: Any) -> "DataType": try: return engine.Engine.dtype(cls, np_or_pd_dtype) - except TypeError as err: + except TypeError: return DataType(np_or_pd_dtype) + @classmethod + def numpy_dtype(cls, pandera_dtype: dtypes.DataType) -> np.dtype: + """Convert a pandera data type to a numpy data type.""" + pandera_dtype = engine.Engine.dtype(cls, pandera_dtype) + + alias = str(pandera_dtype).lower() + if alias == "boolean": + alias = "bool" + elif alias == "string": + alias = "str" + return np.dtype(alias) + -################################################################################ +############################################################################### # boolean -################################################################################ +############################################################################### Engine.register_dtype( numpy_engine.Bool, - equivalents=["bool", bool, np.bool_, dtypes_.Bool, dtypes_.Bool()], + equivalents=["bool", bool, np.bool_, dtypes.Bool, dtypes.Bool()], ) @@ -123,15 +139,15 @@ def dtype(cls, data_type: Any) -> "DataType": equivalents=["boolean", pd.BooleanDtype, pd.BooleanDtype()], ) @immutable -class Bool(DataType, dtypes_.Bool): +class Bool(DataType, dtypes.Bool): type = pd.BooleanDtype() BOOL = Bool -################################################################################ +############################################################################### # number -################################################################################ +############################################################################### def _register_numpy_numbers( @@ -141,7 +157,12 @@ def _register_numpy_numbers( with the pandas engine.""" builtin_type = getattr(builtins, builtin_name, None) # uint doesn't exist - default_pd_dtype = pd.Series([1], dtype=builtin_name).dtype + + # default to int64 regardless of OS + default_pd_dtype = { + "int": np.dtype("int64"), + "uint": np.dtype("uint64"), + }.get(builtin_name, pd.Series([1], dtype=builtin_name).dtype) for bit_width in sizes: # e.g.: numpy.int64 @@ -150,21 +171,19 @@ def _register_numpy_numbers( equivalents = set( ( np_dtype, - getattr(np, f"{builtin_name}{bit_width}"), # e.g.: pandera.dtypes.Int64 - getattr(dtypes_, 
f"{pandera_name}{bit_width}"), - getattr(dtypes_, f"{pandera_name}{bit_width}")(), + getattr(dtypes, f"{pandera_name}{bit_width}"), + getattr(dtypes, f"{pandera_name}{bit_width}")(), ) ) if np_dtype == default_pd_dtype: equivalents |= set( ( - # e.g: numpy.int_ default_pd_dtype, - # e.g: pandera.dtypes.Int - getattr(dtypes_, pandera_name), - getattr(dtypes_, pandera_name)(), + builtin_name, + getattr(dtypes, pandera_name), + getattr(dtypes, pandera_name)(), ) ) if builtin_type: @@ -178,12 +197,13 @@ def _register_numpy_numbers( equivalents.add("integer") numpy_data_type = getattr(numpy_engine, f"{pandera_name}{bit_width}") + print(f"EQUIVALENTS FOR {numpy_data_type}: {list(equivalents)}") Engine.register_dtype(numpy_data_type, equivalents=list(equivalents)) -################################################################################ -## signed integer -################################################################################ +############################################################################### +# signed integer +############################################################################### _register_numpy_numbers( builtin_name="int", @@ -194,7 +214,7 @@ def _register_numpy_numbers( @Engine.register_dtype(equivalents=[pd.Int64Dtype, pd.Int64Dtype()]) @immutable -class Int64(DataType, dtypes_.Int): +class Int64(DataType, dtypes.Int): type = pd.Int64Dtype() bit_width: int = 64 @@ -231,9 +251,9 @@ class Int8(Int16): INT8 = Int8 -################################################################################ -## unsigned integer -################################################################################ +############################################################################### +# unsigned integer +############################################################################### _register_numpy_numbers( builtin_name="uint", @@ -244,7 +264,7 @@ class Int8(Int16): @Engine.register_dtype(equivalents=[pd.UInt64Dtype, pd.UInt64Dtype()]) @immutable -class UInt64(DataType, dtypes_.UInt): +class UInt64(DataType, dtypes.UInt): type = pd.UInt64Dtype() bit_width: int = 64 @@ -275,41 +295,41 @@ class UInt8(UInt16): UINT16 = UInt16 UINT8 = UInt8 -# ################################################################################ -# ## float -# ################################################################################ +# ############################################################################### +# # float +# ############################################################################### _register_numpy_numbers( builtin_name="float", pandera_name="Float", - sizes=[128, 64, 32, 16], + sizes=[64, 32, 16] if WINDOWS_PLATFORM else [128, 64, 32, 16], ) -# ################################################################################ -# ## complex -# ################################################################################ +# ############################################################################### +# # complex +# ############################################################################### _register_numpy_numbers( builtin_name="complex", pandera_name="Complex", - sizes=[128, 64], + sizes=[128, 64] if WINDOWS_PLATFORM else [256, 128, 64], ) -# ################################################################################ +# ############################################################################### # # nominal -# ################################################################################ +# 
############################################################################### @Engine.register_dtype( equivalents=[ "category", "categorical", - dtypes_.Category, + dtypes.Category, pd.CategoricalDtype, ] ) @immutable(init=True) -class Category(DataType, dtypes_.Category): +class Category(DataType, dtypes.Category): type: pd.CategoricalDtype = dataclasses.field(default=None, init=False) def __init__( # pylint:disable=super-init-not-called @@ -317,7 +337,7 @@ def __init__( # pylint:disable=super-init-not-called categories: Optional[Iterable[Any]] = None, ordered: bool = False, ) -> None: - dtypes_.Category.__init__(self, categories, ordered) + dtypes.Category.__init__(self, categories, ordered) object.__setattr__( self, "type", @@ -326,7 +346,7 @@ def __init__( # pylint:disable=super-init-not-called @classmethod def from_parametrized_dtype( - cls, cat: Union[dtypes_.Category, pd.CategoricalDtype] + cls, cat: Union[dtypes.Category, pd.CategoricalDtype] ): """Convert a categorical to a Pandera :class:`~pandera.dtypes.pandas_engine.Category`.""" @@ -339,7 +359,7 @@ def from_parametrized_dtype( equivalents=["string", pd.StringDtype, pd.StringDtype()] ) @immutable -class String(DataType, dtypes_.String): +class String(DataType, dtypes.String): type = pd.StringDtype() @@ -347,7 +367,7 @@ class String(DataType, dtypes_.String): @Engine.register_dtype( - equivalents=["str", str, dtypes_.String, dtypes_.String(), np.str_] + equivalents=["str", str, dtypes.String, dtypes.String(), np.str_] ) @immutable class NpString(numpy_engine.String): @@ -361,7 +381,7 @@ def coerce(self, data_container: PandasObject) -> np.ndarray: data_container.isna(), data_container.astype(str) ) - def check(self, pandera_dtype: dtypes_.DataType) -> bool: + def check(self, pandera_dtype: dtypes.DataType) -> bool: return isinstance(pandera_dtype, (numpy_engine.Object, type(self))) @@ -379,9 +399,9 @@ def check(self, pandera_dtype: dtypes_.DataType) -> bool: ], ) -# ################################################################################ +# ############################################################################### # # time -# ################################################################################ +# ############################################################################### _PandasDatetime = Union[np.datetime64, pd.DatetimeTZDtype] @@ -394,13 +414,13 @@ def check(self, pandera_dtype: dtypes_.DataType) -> bool: "datetime64", datetime.datetime, np.datetime64, - dtypes_.Timestamp, - dtypes_.Timestamp(), + dtypes.Timestamp, + dtypes.Timestamp(), pd.Timestamp, ] ) @immutable(init=True) -class DateTime(DataType, dtypes_.Timestamp): +class DateTime(DataType, dtypes.Timestamp): type: Optional[_PandasDatetime] = dataclasses.field( default=None, init=False ) @@ -412,7 +432,7 @@ class DateTime(DataType, dtypes_.Timestamp): def __post_init__(self): if self.tz is None: - type_ = np.dtype("datetime64") + type_ = np.dtype("datetime64[ns]") else: type_ = pd.DatetimeTZDtype(self.unit, self.tz) # DatetimeTZDtype converted tz to tzinfo for us @@ -438,21 +458,21 @@ def from_parametrized_dtype(cls, pd_dtype: pd.DatetimeTZDtype): return cls(unit=pd_dtype.unit, tz=pd_dtype.tz) # type: ignore def __str__(self) -> str: - if self.type == np.dtype("datetime64"): + if self.type == np.dtype("datetime64[ns]"): return "datetime64[ns]" return str(self.type) Engine.register_dtype( - numpy_engine.DateTime64, + numpy_engine.Timedelta64, equivalents=[ "timedelta", "timedelta64", datetime.timedelta, np.timedelta64, pd.Timedelta, - 
dtypes_.Timedelta,
-        dtypes_.Timedelta(),
+        dtypes.Timedelta,
+        dtypes.Timedelta(),
     ],
 )
 
 
@@ -475,9 +495,9 @@ def from_parametrized_dtype(cls, pd_dtype: pd.PeriodDtype):
         return cls(freq=pd_dtype.freq)  # type: ignore
 
 
-# ################################################################################
+# ###############################################################################
 # # misc
-# ################################################################################
+# ###############################################################################
 
 
 @Engine.register_dtype(equivalents=[pd.SparseDtype])
@@ -522,4 +542,4 @@ def __post_init__(self):
     def from_parametrized_dtype(cls, pd_dtype: pd.IntervalDtype):
         """Convert a :class:`pandas.IntervalDtype` to a Pandera
         :class:`~pandera.engines.pandas_engine.Interval`."""
-        return cls(subdtype=pd_dtype.subtype)  # type: ignore
+        return cls(subtype=pd_dtype.subtype)  # type: ignore
diff --git a/pandera/io.py b/pandera/io.py
index 91953b2c2..5f228460a 100644
--- a/pandera/io.py
+++ b/pandera/io.py
@@ -10,8 +10,9 @@
 
 import pandera.errors
 
+from . import dtypes
 from .checks import Check
-from .dtypes import PandasDtype
+from .engines import pandas_engine
 from .schema_components import Column
 from .schema_statistics import get_dataframe_schema_statistics
 from .schemas import DataFrameSchema
@@ -29,18 +30,20 @@
     ) from exc
 
 
-SCHEMA_TYPES = {"dataframe"}
 DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
-NOT_JSON_SERIALIZABLE = {PandasDtype.DateTime, PandasDtype.Timedelta}
 
 
-def _serialize_check_stats(check_stats, pandas_dtype=None):
+def _get_qualified_name(cls: type) -> str:
+    return f"{cls.__module__}.{cls.__qualname__}"
+
+
+def _serialize_check_stats(check_stats, dtype=None):
     """Serialize check statistics into json/yaml-compatible format."""
 
     def handle_stat_dtype(stat):
-        if pandas_dtype == PandasDtype.DateTime:
+        if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
             return stat.strftime(DATETIME_FORMAT)
-        elif pandas_dtype == PandasDtype.Timedelta:
+        elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype):
             # serialize to int in nanoseconds
             return stat.delta
 
@@ -83,15 +86,15 @@ def _serialize_component_stats(component_stats):
     serialized_checks = {}
     for check_name, check_stats in component_stats["checks"].items():
         serialized_checks[check_name] = _serialize_check_stats(
-            check_stats, component_stats["pandas_dtype"]
+            check_stats, component_stats["dtype"]
         )
 
-    pandas_dtype = component_stats.get("pandas_dtype")
-    if pandas_dtype:
-        pandas_dtype = pandas_dtype.value
+    dtype = component_stats.get("dtype")
+    if dtype:
+        dtype = str(dtype)
 
     return {
-        "pandas_dtype": pandas_dtype,
+        "dtype": dtype,
         "nullable": component_stats["nullable"],
         "checks": serialized_checks,
         **{
@@ -141,11 +144,11 @@ def _serialize_schema(dataframe_schema):
     }
 
 
-def _deserialize_check_stats(check, serialized_check_stats, pandas_dtype=None):
+def _deserialize_check_stats(check, serialized_check_stats, dtype=None):
     def handle_stat_dtype(stat):
-        if pandas_dtype == PandasDtype.DateTime:
+        if pandas_engine.Engine.dtype(dtypes.DateTime).check(dtype):
            return pd.to_datetime(stat, format=DATETIME_FORMAT)
-        elif pandas_dtype == PandasDtype.Timedelta:
+        elif pandas_engine.Engine.dtype(dtypes.Timedelta).check(dtype):
             # serialize to int in nanoseconds
             return pd.to_timedelta(stat, unit="ns")
         return stat
@@ -162,20 +165,20 @@ def handle_stat_dtype(stat):
 def 
_deserialize_component_stats(serialized_component_stats): - pandas_dtype = serialized_component_stats.get("pandas_dtype") - if pandas_dtype: - pandas_dtype = PandasDtype.from_str_alias(pandas_dtype) + dtype = serialized_component_stats.get("dtype") + if dtype: + dtype = pandas_engine.Engine.dtype(dtype) checks = serialized_component_stats.get("checks") if checks is not None: checks = [ _deserialize_check_stats( - getattr(Check, check_name), check_stats, pandas_dtype + getattr(Check, check_name), check_stats, dtype ) for check_name, check_stats in checks.items() ] return { - "pandas_dtype": pandas_dtype, + "dtype": dtype, "checks": checks, **{ key: serialized_component_stats.get(key) @@ -280,7 +283,7 @@ def _write_yaml(obj, stream): SCRIPT_TEMPLATE = """ from pandera import ( - DataFrameSchema, Column, Check, Index, MultiIndex, PandasDtype + DataFrameSchema, Column, Check, Index, MultiIndex ) schema = DataFrameSchema( @@ -294,7 +297,7 @@ def _write_yaml(obj, stream): COLUMN_TEMPLATE = """ Column( - pandas_dtype={pandas_dtype}, + dtype={dtype}, checks={checks}, nullable={nullable}, allow_duplicates={allow_duplicates}, @@ -305,7 +308,7 @@ def _write_yaml(obj, stream): """ INDEX_TEMPLATE = ( - "Index(pandas_dtype={pandas_dtype},checks={checks}," + "Index(dtype={dtype},checks={checks}," "nullable={nullable},coerce={coerce},name={name})" ) @@ -336,8 +339,9 @@ def _format_checks(checks_dict): def _format_index(index_statistics): index = [] for properties in index_statistics: + dtype = properties.get("dtype") index_code = INDEX_TEMPLATE.format( - pandas_dtype=f"PandasDtype.{properties['pandas_dtype'].name}", + dtype=f"{_get_qualified_name(dtype.__class__)}", checks=( "None" if properties["checks"] is None @@ -376,12 +380,10 @@ def to_script(dataframe_schema, path_or_buf=None): columns = {} for colname, properties in statistics["columns"].items(): - pandas_dtype = properties.get("pandas_dtype") + dtype = properties.get("dtype") column_code = COLUMN_TEMPLATE.format( - pandas_dtype=( - None - if pandas_dtype is None - else f"PandasDtype.{properties['pandas_dtype'].name}" + dtype=( + None if dtype is None else _get_qualified_name(dtype.__class__) ), checks=_format_checks(properties["checks"]), nullable=properties["nullable"], @@ -444,9 +446,9 @@ def __init__(self, field, primary_keys) -> None: self.type = field.get("type", "string") @property - def pandas_dtype(self) -> str: + def dtype(self) -> str: """Determine what type of field this is, so we can feed that into - :class:`~pandera.dtypes.PandasDtype`. If no type is specified in the + :class:`~pandera.dtypes.DataType`. If no type is specified in the frictionless schema, we default to string values. 
:returns: the pandas-compatible representation of this field type as a @@ -562,7 +564,7 @@ def to_pandera_column(self) -> Dict: "checks": self.checks, "coerce": self.coerce, "nullable": self.nullable, - "pandas_dtype": self.pandas_dtype, + "dtype": self.dtype, "required": self.required, "name": self.name, "regex": self.regex, diff --git a/pandera/schema_components.py b/pandera/schema_components.py index 0208c62ad..0d083af38 100644 --- a/pandera/schema_components.py +++ b/pandera/schema_components.py @@ -103,7 +103,7 @@ def _allow_groupby(self) -> bool: def properties(self) -> Dict[str, Any]: """Get column properties.""" return { - "dtype": self._dtype, + "dtype": self.dtype, "checks": self._checks, "nullable": self._nullable, "allow_duplicates": self._allow_duplicates, @@ -301,6 +301,9 @@ def example(self, size=None) -> pd.DataFrame: ) def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + def _compare_dict(obj): return { k: v if k != "_checks" else set(v) diff --git a/pandera/schema_inference.py b/pandera/schema_inference.py index f5f3184b1..69107407b 100644 --- a/pandera/schema_inference.py +++ b/pandera/schema_inference.py @@ -36,7 +36,7 @@ def infer_schema( def _create_index(index_statistics): index = [ Index( - properties["pandas_dtype"], + properties["dtype"], checks=parse_check_statistics(properties["checks"]), nullable=properties["nullable"], name=properties["name"], @@ -58,11 +58,10 @@ def infer_dataframe_schema(df: pd.DataFrame) -> DataFrameSchema: :returns: DataFrameSchema """ df_statistics = infer_dataframe_statistics(df) - schema = DataFrameSchema( columns={ colname: Column( - properties["pandas_dtype"], + properties["dtype"], checks=parse_check_statistics(properties["checks"]), nullable=properties["nullable"], ) @@ -83,7 +82,7 @@ def infer_series_schema(series) -> SeriesSchema: """ series_statistics = infer_series_statistics(series) schema = SeriesSchema( - pandas_dtype=series_statistics["pandas_dtype"], + dtype=series_statistics["dtype"], checks=parse_check_statistics(series_statistics["checks"]), nullable=series_statistics["nullable"], name=series_statistics["name"], diff --git a/pandera/schema_statistics.py b/pandera/schema_statistics.py index 6b30782dc..414970596 100644 --- a/pandera/schema_statistics.py +++ b/pandera/schema_statistics.py @@ -1,31 +1,12 @@ """Module for inferring the statistics of pandas objects.""" - import warnings from typing import Any, Dict, Union import pandas as pd +from . 
import dtypes from .checks import Check -from .dtypes import PandasDtype - -NUMERIC_DTYPES = frozenset( - [ - PandasDtype.Float, - PandasDtype.Float16, - PandasDtype.Float32, - PandasDtype.Float64, - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, - PandasDtype.DateTime, - ] -) +from .engines import pandas_engine def infer_dataframe_statistics(df: pd.DataFrame) -> Dict[str, Any]: @@ -34,7 +15,7 @@ def infer_dataframe_statistics(df: pd.DataFrame) -> Dict[str, Any]: inferred_column_dtypes = {col: _get_array_type(df[col]) for col in df} column_statistics = { col: { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(nullable_columns[col]), "checks": _get_array_check_statistics(df[col], dtype), } @@ -50,7 +31,7 @@ def infer_series_statistics(series: pd.Series) -> Dict[str, Any]: """Infer column and index statistics from a pandas Series.""" dtype = _get_array_type(series) return { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(series.isna().any()), "checks": _get_array_check_statistics(series, dtype), "name": series.name, @@ -63,7 +44,7 @@ def infer_index_statistics(index: Union[pd.Index, pd.MultiIndex]): def _index_stats(index_level): dtype = _get_array_type(index_level) return { - "pandas_dtype": dtype, + "dtype": dtype, "nullable": bool(index_level.isna().any()), "checks": _get_array_check_statistics(index_level, dtype), "name": index_level.name, @@ -105,7 +86,7 @@ def get_dataframe_schema_statistics(dataframe_schema): statistics = { "columns": { col_name: { - "pandas_dtype": column.pdtype, + "dtype": column.dtype, "nullable": column.nullable, "allow_duplicates": column.allow_duplicates, "coerce": column.coerce, @@ -128,7 +109,7 @@ def get_dataframe_schema_statistics(dataframe_schema): def _get_series_base_schema_statistics(series_schema_base): return { - "pandas_dtype": series_schema_base._pandas_dtype, + "dtype": series_schema_base.dtype, "nullable": series_schema_base.nullable, "checks": parse_checks(series_schema_base.checks), "coerce": series_schema_base.coerce, @@ -199,30 +180,31 @@ def parse_checks(checks) -> Union[Dict[str, Any], None]: def _get_array_type(x): # get most granular type possible - dtype = PandasDtype.from_str_alias(str(x.dtype)) + + data_type = pandas_engine.Engine.dtype(x.dtype) # for object arrays, try to infer dtype - if dtype is PandasDtype.Object: - dtype = PandasDtype.from_pandas_api_type( - pd.api.types.infer_dtype(x, skipna=True) - ) - return dtype + if data_type is pandas_engine.Engine.dtype("object"): + inferred_alias = pd.api.types.infer_dtype(x, skipna=True) + if inferred_alias != "string": + data_type = pandas_engine.Engine.dtype(inferred_alias) + return data_type def _get_array_check_statistics( - x, dtype: PandasDtype + x, data_type: dtypes.DataType ) -> Union[Dict[str, Any], None]: """Get check statistics from an array-like object.""" - if dtype is PandasDtype.DateTime: + if dtypes.is_datetime(data_type): check_stats = { "greater_than_or_equal_to": x.min(), "less_than_or_equal_to": x.max(), } - elif dtype in NUMERIC_DTYPES: + elif dtypes.is_numeric(data_type) and not dtypes.is_bool(data_type): check_stats = { "greater_than_or_equal_to": float(x.min()), "less_than_or_equal_to": float(x.max()), } - elif dtype is PandasDtype.Category: + elif dtypes.is_category(data_type): try: categories = x.cat.categories except AttributeError: diff --git a/pandera/schemas.py b/pandera/schemas.py index 
c9e57158f..d8f205345 100644 --- a/pandera/schemas.py +++ b/pandera/schemas.py @@ -12,12 +12,10 @@ import numpy as np import pandas as pd -from pandera import dtypes_ - from . import constants, errors from . import strategies as st from .checks import Check -from .dtypes_ import DataType +from .dtypes import DataType from .engines import pandas_engine from .error_formatters import ( format_generic_error_message, @@ -233,7 +231,7 @@ def _set_column_handler(column, column_name): } @property - def dtypes(self) -> Dict[str, dtypes_.DataType]: + def dtypes(self) -> Dict[str, DataType]: """ A pandas style dtypes dict where the keys are column names and values are pandas dtype for the column. Excludes columns where regex=True. @@ -246,7 +244,7 @@ def dtypes(self) -> Dict[str, dtypes_.DataType]: if regex_columns: warnings.warn( "Schema has columns specified as regex column names: %s " - "Use the `get_dtype` to get the datatypes for these " + "Use the `get_dtypes` to get the datatypes for these " "columns." % regex_columns, UserWarning, ) @@ -276,7 +274,7 @@ def get_dtypes(self, dataframe: pd.DataFrame) -> Dict[str, str]: @property def dtype( self, - ) -> dtypes_.DataType: + ) -> DataType: """Get the dtype property.""" return self._dtype # type: ignore @@ -770,9 +768,9 @@ def add_columns( ... ) - 'probability': - 'even_number': + 'category': + 'probability': + 'even_number': }, checks=[], coerce=False, @@ -822,7 +820,7 @@ def remove_columns(self, cols_to_remove: List[str]) -> "DataFrameSchema": >>> print(example_schema.remove_columns(["category"])) + 'probability': }, checks=[], coerce=False, @@ -882,8 +880,8 @@ def update_column(self, column_name: str, **kwargs) -> "DataFrameSchema": ... ) - 'probability': + 'category': + 'probability': }, checks=[], coerce=False, @@ -942,8 +940,8 @@ def update_columns( ... ) - 'probability': + 'category': + 'probability': }, checks=[], coerce=False, @@ -1024,8 +1022,8 @@ def rename_columns(self, rename_dict: Dict[str, str]) -> "DataFrameSchema": ... ) - 'probabilities': + 'categories': + 'probabilities': }, checks=[], coerce=False, @@ -1100,7 +1098,7 @@ def select_columns(self, columns: List[str]) -> "DataFrameSchema": >>> print(example_schema.select_columns(['category'])) + 'category': }, checks=[], coerce=False, @@ -1200,12 +1198,12 @@ def set_index( >>> print(example_schema.set_index(['category'])) + 'probability': }, checks=[], coerce=False, dtype=None, - index=, + index=, strict=False name=None, ordered=False @@ -1226,15 +1224,15 @@ def set_index( >>> print(example_schema.set_index(["column2"], append = True)) + 'column1': }, checks=[], coerce=False, dtype=None, index= - + + ] coerce=False, strict=False, @@ -1330,8 +1328,8 @@ def reset_index( >>> print(example_schema.reset_index()) - 'unique_id': + 'probability': + 'unique_id': }, checks=[], coerce=False, @@ -1359,13 +1357,13 @@ def reset_index( >>> print(example_schema.reset_index(level = ["unique_id1"])) - 'unique_id1': + 'category': + 'unique_id1': }, checks=[], coerce=False, dtype=None, - index=, + index=, strict=False name=None, ordered=False @@ -1564,7 +1562,7 @@ def name(self) -> Union[str, None]: @property def dtype( self, - ) -> dtypes_.DataType: + ) -> DataType: """Get the pandas dtype""" return self._dtype # type: ignore diff --git a/pandera/strategies.py b/pandera/strategies.py index 2c71d17fe..e5aea6644 100644 --- a/pandera/strategies.py +++ b/pandera/strategies.py @@ -10,7 +10,6 @@ See the :ref:`user guide` for more details. 
""" - import operator import re import warnings @@ -22,7 +21,15 @@ import numpy as np import pandas as pd -from .dtypes import PandasDtype +from .dtypes import ( + DataType, + is_category, + is_complex, + is_datetime, + is_float, + is_timedelta, +) +from .engines import numpy_engine, pandas_engine from .errors import BaseStrategyOnlyError, SchemaDefinitionError try: @@ -128,12 +135,16 @@ def set_pandas_index( return df_or_series -def verify_pandas_dtype(pandas_dtype, schema_type: str, name: Optional[str]): - """Verify that pandas_dtype argument is not None.""" - if pandas_dtype is None: +def verify_dtype( + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], + schema_type: str, + name: Optional[str], +): + """Verify that pandera_dtype argument is not None.""" + if pandera_dtype is None: raise SchemaDefinitionError( f"'{schema_type}' schema with name '{name}' has no specified " - "pandas_dtype. You need to specify one in order to synthesize " + "dtype. You need to specify one in order to synthesize " "data from a strategy." ) @@ -197,7 +208,7 @@ def _wrapper(cls, *args, **kwargs): MAX_DT_VALUE = 2 ** 63 - 1 -def numpy_time_dtypes(dtype, min_value=None, max_value=None): +def numpy_time_dtypes(dtype: np.dtype, min_value=None, max_value=None): """Create numpy strategy for datetime and timedelta data types. :param dtype: numpy datetime or timedelta datatype @@ -282,15 +293,24 @@ def build_complex(draw): return build_complex() +def to_numpy_dtype(pandera_dtype: DataType): + """Convert a :class:`~pandera.dtypes.DataType` to numpy dtype compatible + with hypothesis.""" + np_dtype = pandas_engine.Engine.numpy_dtype(pandera_dtype) + if np_dtype == np.dtype("object"): + np_dtype = np.dtype(str) + return np_dtype + + def pandas_dtype_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: DataType, strategy: Optional[SearchStrategy] = None, **kwargs, ) -> SearchStrategy: # pylint: disable=line-too-long,no-else-raise - """Strategy to generate data from a :class:`pandera.dtypes.PandasDtype`. + """Strategy to generate data from a :class:`pandera.dtypes.DataType`. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :kwargs: key-word arguments passed into @@ -306,35 +326,30 @@ def compat_kwargs(*args): # hypothesis doesn't support categoricals or objects, so we'll will need to # build a pandera-specific solution. - if pandas_dtype is PandasDtype.Category: + if is_category(pandera_dtype): raise TypeError( "data generation for the Category dtype is currently " "unsupported. Consider using a string or int dtype and " "Check.isin(values) to ensure a finite set of values." ) - # The object type falls back onto generating strings. 
- if pandas_dtype is PandasDtype.Object: - dtype = np.dtype("str") - else: - dtype = pandas_dtype.numpy_dtype - + np_dtype = to_numpy_dtype(pandera_dtype) if strategy: - return strategy.map(dtype.type) - elif pandas_dtype.is_datetime or pandas_dtype.is_timedelta: + return strategy.map(np_dtype.type) + elif is_datetime(pandera_dtype) or is_timedelta(pandera_dtype): return numpy_time_dtypes( - dtype, + np_dtype, **compat_kwargs("min_value", "max_value"), ) - elif pandas_dtype.is_complex: + elif is_complex(pandera_dtype): return numpy_complex_dtypes( - dtype, + np_dtype, **compat_kwargs( "min_value", "max_value", "allow_infinity", "allow_nan" ), ) return npst.from_dtype( - dtype, + np_dtype, **{ # type: ignore "allow_nan": False, "allow_infinity": False, @@ -344,14 +359,14 @@ def compat_kwargs(*args): def eq_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, value: Any, ) -> SearchStrategy: """Strategy to generate a single value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param value: value to generate. @@ -359,38 +374,38 @@ def eq_strategy( """ # override strategy preceding this one and generate value of the same type if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) - return st.just(value).map(pandas_dtype.numpy_dtype.type) + strategy = pandas_dtype_strategy(pandera_dtype) + return st.just(value).map(to_numpy_dtype(pandera_dtype).type) def ne_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, value: Any, ) -> SearchStrategy: """Strategy to generate anything except for a particular value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param value: value to avoid. :returns: ``hypothesis`` strategy """ if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) + strategy = pandas_dtype_strategy(pandera_dtype) return strategy.filter(lambda x: x != value) def gt_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values greater than a minimum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values larger than this. 
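# A hedged usage sketch of the renamed strategy API above; ``pa.Float64``
# and the ``strategies`` import path are assumptions based on this diff:
#
#     import pandera as pa
#     from pandera import strategies as st
#
#     st.eq_strategy(pa.Float64(), value=1.5).example()      # always 1.5
#     st.gt_strategy(pa.Float64(), min_value=0.0).example()  # strictly > 0.0;
#                                                            # floats exclude
#                                                            # the bound itself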
@@ -398,22 +413,22 @@ def gt_strategy( """ if strategy is None: strategy = pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, - exclude_min=True if pandas_dtype.is_float else None, + exclude_min=True if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x > min_value) def ge_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values greater than or equal to a minimum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values greater than or equal to this. @@ -421,22 +436,22 @@ def ge_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, - exclude_min=False if pandas_dtype.is_float else None, + exclude_min=False if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x >= min_value) def lt_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, max_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values less than a maximum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param max_value: generate values less than this. @@ -444,22 +459,22 @@ def lt_strategy( """ if strategy is None: strategy = pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, max_value=max_value, - exclude_max=True if pandas_dtype.is_float else None, + exclude_max=True if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x < max_value) def le_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, max_value: Union[int, float], ) -> SearchStrategy: """Strategy to generate values less than or equal to a maximum value. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param max_value: generate values less than or equal to this. @@ -467,15 +482,15 @@ def le_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, max_value=max_value, - exclude_max=False if pandas_dtype.is_float else None, + exclude_max=False if is_float(pandera_dtype) else None, ) return strategy.filter(lambda x: x <= max_value) def in_range_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: Union[int, float], @@ -485,7 +500,7 @@ def in_range_strategy( ) -> SearchStrategy: """Strategy to generate values within a particular range. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. 
:param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: generate values greater than this. @@ -496,7 +511,7 @@ def in_range_strategy( """ if strategy is None: return pandas_dtype_strategy( - pandas_dtype, + pandera_dtype, min_value=min_value, max_value=max_value, exclude_min=not include_min, @@ -510,14 +525,14 @@ def in_range_strategy( def isin_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, allowed_values: Sequence[Any], ) -> SearchStrategy: """Strategy to generate values within a finite set. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param allowed_values: set of allowable values. @@ -525,39 +540,39 @@ def isin_strategy( """ if strategy is None: return st.sampled_from(allowed_values).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x in allowed_values) def notin_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, forbidden_values: Sequence[Any], ) -> SearchStrategy: """Strategy to generate values excluding a set of forbidden values - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param forbidden_values: set of forbidden values. :returns: ``hypothesis`` strategy """ if strategy is None: - strategy = pandas_dtype_strategy(pandas_dtype) + strategy = pandas_dtype_strategy(pandera_dtype) return strategy.filter(lambda x: x not in forbidden_values) def str_matches_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, pattern: str, ) -> SearchStrategy: """Strategy to generate strings that patch a regex pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param pattern: regex pattern. @@ -565,7 +580,7 @@ def str_matches_strategy( """ if strategy is None: return st.from_regex(pattern, fullmatch=True).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) def matches(x): @@ -575,14 +590,14 @@ def matches(x): def str_contains_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, pattern: str, ) -> SearchStrategy: """Strategy to generate strings that contain a particular pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param pattern: regex pattern. 
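# The string strategies above are thin wrappers over
# ``hypothesis.strategies.from_regex``; a small self-contained sketch of
# the underlying calls (hypothesis is assumed installed):
#
#     from hypothesis import strategies as hst
#
#     hst.from_regex(r"foo[a-z]{3}", fullmatch=True).example()  # e.g. "fooabc"
#     hst.from_regex(r"foo\Z", fullmatch=False).example()       # ends in "foo"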
@@ -590,7 +605,7 @@ def str_contains_strategy( """ if strategy is None: return st.from_regex(pattern, fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) def contains(x): @@ -600,14 +615,14 @@ def contains(x): def str_startswith_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, string: str, ) -> SearchStrategy: """Strategy to generate strings that start with a specific string pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param string: string pattern. @@ -615,21 +630,21 @@ def str_startswith_strategy( """ if strategy is None: return st.from_regex(f"\\A{string}", fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x.startswith(string)) def str_endswith_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, string: str, ) -> SearchStrategy: """Strategy to generate strings that end with a specific string pattern. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param string: string pattern. @@ -637,14 +652,14 @@ def str_endswith_strategy( """ if strategy is None: return st.from_regex(f"{string}\\Z", fullmatch=False).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: x.endswith(string)) def str_length_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, min_value: int, @@ -652,7 +667,7 @@ def str_length_strategy( ) -> SearchStrategy: """Strategy to generate strings of a particular length - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param min_value: minimum string length. @@ -661,21 +676,21 @@ def str_length_strategy( """ if strategy is None: return st.text(min_size=min_value, max_size=max_value).map( - pandas_dtype.numpy_dtype.type + to_numpy_dtype(pandera_dtype).type ) return strategy.filter(lambda x: min_value <= len(x) <= max_value) def field_element_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, ) -> SearchStrategy: """Strategy to generate elements of a column or index. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -698,25 +713,25 @@ def undefined_check_strategy(elements, check): "definition. This can considerably slow down data-generation." 
) return ( - pandas_dtype_strategy(pandas_dtype) + pandas_dtype_strategy(pandera_dtype) if elements is None else elements ).filter(check._check_fn) for check in checks: if hasattr(check, "strategy"): - elements = check.strategy(pandas_dtype, elements) + elements = check.strategy(pandera_dtype, elements) elif check.element_wise: elements = undefined_check_strategy(elements, check) # NOTE: vectorized checks with undefined strategies should be handled # by the series/dataframe strategy. if elements is None: - elements = pandas_dtype_strategy(pandas_dtype) + elements = pandas_dtype_strategy(pandera_dtype) return elements def series_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -727,7 +742,7 @@ def series_strategy( ): """Strategy to generate a pandas Series. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -739,11 +754,11 @@ def series_strategy( :param size: number of elements in the Series. :returns: ``hypothesis`` strategy. """ - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) strategy = ( pdst.series( elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), index=pdst.range_indexes( min_size=0 if size is None else size, max_size=size ), @@ -751,7 +766,7 @@ def series_strategy( ) .filter(lambda x: x.shape[0] > 0) .map(lambda x: x.rename(name)) - .map(lambda x: x.astype(pandas_dtype.str_alias)) + .map(lambda x: x.astype(str(pandera_dtype))) ) if nullable: strategy = null_field_masks(strategy) @@ -777,7 +792,7 @@ def _check_fn(series): def column_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -787,7 +802,7 @@ def column_strategy( # pylint: disable=line-too-long """Create a data object describing a column in a DataFrame. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -797,18 +812,18 @@ def column_strategy( :param name: name of the Series. :returns: a `column `_ object. """ - verify_pandas_dtype(pandas_dtype, schema_type="column", name=name) - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + verify_dtype(pandera_dtype, schema_type="column", name=name) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) return pdst.column( name=name, elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), unique=not allow_duplicates, ) def index_strategy( - pandas_dtype: PandasDtype, + pandera_dtype: Union[numpy_engine.DataType, pandas_engine.DataType], strategy: Optional[SearchStrategy] = None, *, checks: Optional[Sequence] = None, @@ -819,7 +834,7 @@ def index_strategy( ): """Strategy to generate a pandas Index. 
- :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. :param checks: sequence of :class:`~pandera.checks.Check` s to constrain @@ -831,15 +846,15 @@ def index_strategy( :param size: number of elements in the Series. :returns: ``hypothesis`` strategy. """ - verify_pandas_dtype(pandas_dtype, schema_type="index", name=name) - elements = field_element_strategy(pandas_dtype, strategy, checks=checks) + verify_dtype(pandera_dtype, schema_type="index", name=name) + elements = field_element_strategy(pandera_dtype, strategy, checks=checks) strategy = pdst.indexes( elements=elements, - dtype=pandas_dtype.numpy_dtype, + dtype=to_numpy_dtype(pandera_dtype), min_size=0 if size is None else size, max_size=size, unique=not allow_duplicates, - ).map(lambda x: x.astype(pandas_dtype.str_alias)) + ).map(lambda x: x.astype(str(pandera_dtype))) if name is not None: strategy = strategy.map(lambda index: index.rename(name)) if nullable: @@ -848,7 +863,7 @@ def index_strategy( def dataframe_strategy( - pandas_dtype: Optional[PandasDtype] = None, + pandera_dtype: Optional[DataType] = None, strategy: Optional[SearchStrategy] = None, *, columns: Optional[Dict] = None, @@ -859,7 +874,7 @@ def dataframe_strategy( ): """Strategy to generate a pandas DataFrame. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: if specified, this will raise a BaseStrategyOnlyError, since it cannot be chained to a prior strategy. :param columns: a dictionary where keys are column names and values @@ -920,18 +935,18 @@ def make_row_strategy(col, checks): strategy = None for check in checks: if hasattr(check, "strategy"): - strategy = check.strategy(col.pdtype, strategy) + strategy = check.strategy(col.dtype, strategy) else: strategy = undefined_check_strategy( strategy=( - pandas_dtype_strategy(col.pdtype) + pandas_dtype_strategy(col.dtype) if strategy is None else strategy ), check=check, ) if strategy is None: - strategy = pandas_dtype_strategy(col.pdtype) + strategy = pandas_dtype_strategy(col.dtype) return strategy @composite @@ -978,9 +993,9 @@ def _dataframe_strategy(draw): # override the column datatype with dataframe-level datatype if # specified col_dtypes = { - col_name: col.dtype - if pandas_dtype is None - else pandas_dtype.str_alias + col_name: str(col.dtype) + if pandera_dtype is None + else str(pandera_dtype) for col_name, col in expanded_columns.items() } @@ -1031,7 +1046,7 @@ def _dataframe_strategy(draw): # pylint: disable=unused-argument def multiindex_strategy( - pandas_dtype: Optional[PandasDtype] = None, + pandera_dtype: Optional[DataType] = None, strategy: Optional[SearchStrategy] = None, *, indexes: Optional[List] = None, @@ -1039,7 +1054,7 @@ def multiindex_strategy( ): """Strategy to generate a pandas MultiIndex object. - :param pandas_dtype: :class:`pandera.dtypes.PandasDtype` instance. + :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. :param strategy: an optional hypothesis strategy. If specified, the pandas dtype strategy will be chained onto this strategy. 
:param indexes: a list of :class:`~pandera.schema_components.Index`
@@ -1055,7 +1070,7 @@ def multiindex_strategy(
     )
     indexes = [] if indexes is None else indexes
     index_dtypes = {
-        index.name if index.name is not None else i: index.dtype
+        index.name if index.name is not None else i: str(index.dtype)
         for i, index in enumerate(indexes)
     }
     nullable_index = {
diff --git a/pandera/typing.py b/pandera/typing.py
index d0f537579..f2d8829de 100644
--- a/pandera/typing.py
+++ b/pandera/typing.py
@@ -6,28 +6,28 @@
 import pandas as pd
 import typing_inspect

-from . import dtypes_
+from . import dtypes
 from .engines import numpy_engine, pandas_engine

 LEGACY_TYPING = sys.version_info[:2] < (3, 7)

-Bool = dtypes_.Bool  #: ``"bool"`` numpy dtype
-DateTime = dtypes_.DateTime  #: ``"datetime64[ns]"`` numpy dtype
-Timedelta = dtypes_.Timedelta  #: ``"timedelta64[ns]"`` numpy dtype
-Category = dtypes_.Category  #: pandas ``"categorical"`` datatype
-Float = dtypes_.Float  #: ``"float"`` numpy dtype
-Float16 = dtypes_.Float16  #: ``"float16"`` numpy dtype
-Float32 = dtypes_.Float32  #: ``"float32"`` numpy dtype
-Float64 = dtypes_.Float64  #: ``"float64"`` numpy dtype
-Int = dtypes_.Int  #: ``"int"`` numpy dtype
-Int8 = dtypes_.Int8  #: ``"int8"`` numpy dtype
-Int16 = dtypes_.Int16  #: ``"int16"`` numpy dtype
-Int32 = dtypes_.Int32  #: ``"int32"`` numpy dtype
-Int64 = dtypes_.Int64  #: ``"int64"`` numpy dtype
-UInt8 = dtypes_.UInt8  #: ``"uint8"`` numpy dtype
-UInt16 = dtypes_.UInt16  #: ``"uint16"`` numpy dtype
-UInt32 = dtypes_.UInt32  #: ``"uint32"`` numpy dtype
-UInt64 = dtypes_.UInt64  #: ``"uint64"`` numpy dtype
+Bool = dtypes.Bool  #: ``"bool"`` numpy dtype
+DateTime = dtypes.DateTime  #: ``"datetime64[ns]"`` numpy dtype
+Timedelta = dtypes.Timedelta  #: ``"timedelta64[ns]"`` numpy dtype
+Category = dtypes.Category  #: pandas ``"categorical"`` datatype
+Float = dtypes.Float  #: ``"float"`` numpy dtype
+Float16 = dtypes.Float16  #: ``"float16"`` numpy dtype
+Float32 = dtypes.Float32  #: ``"float32"`` numpy dtype
+Float64 = dtypes.Float64  #: ``"float64"`` numpy dtype
+Int = dtypes.Int  #: ``"int"`` numpy dtype
+Int8 = dtypes.Int8  #: ``"int8"`` numpy dtype
+Int16 = dtypes.Int16  #: ``"int16"`` numpy dtype
+Int32 = dtypes.Int32  #: ``"int32"`` numpy dtype
+Int64 = dtypes.Int64  #: ``"int64"`` numpy dtype
+UInt8 = dtypes.UInt8  #: ``"uint8"`` numpy dtype
+UInt16 = dtypes.UInt16  #: ``"uint16"`` numpy dtype
+UInt32 = dtypes.UInt32  #: ``"uint32"`` numpy dtype
+UInt64 = dtypes.UInt64  #: ``"uint64"`` numpy dtype
 INT8 = pandas_engine.INT8  #: ``"Int8"`` pandas dtype: pandas 0.24.0+
 INT16 = pandas_engine.INT16  #: ``"Int16"`` pandas dtype: pandas 0.24.0+
 INT32 = pandas_engine.INT32  #: ``"Int32"`` pandas dtype: pandas 0.24.0+
@@ -37,7 +37,7 @@
 UINT32 = pandas_engine.UINT32  #: ``"UInt32"`` pandas dtype: pandas 0.24.0+
 UINT64 = pandas_engine.UINT64  #: ``"UInt64"`` pandas dtype: pandas 0.24.0+
 Object = numpy_engine.Object  #: ``"object"`` numpy dtype
-String = dtypes_.String  #: ``"str"`` numpy dtype
+String = dtypes.String  #: ``"str"`` numpy dtype

 #: ``"string"`` pandas dtypes: pandas 1.0.0+. For <1.0.0, this enum will
 #: fall back on the str-as-object-array representation.
STRING = pandas_engine.STRING #: ``"str"`` numpy dtype diff --git a/requirements-dev.txt b/requirements-dev.txt index 455c67388..c95ac87a3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,13 +15,13 @@ frictionless black >= 20.8b1 isort >= 5.7.0 codecov -mypy +mypy >= 0.902 pylint >= 2.7.2 pytest pytest-cov pytest-xdist setuptools >= 52.0.0 -nox +nox == 2020.12.31 importlib_metadata sphinx == 3.5.4 sphinx_rtd_theme @@ -31,4 +31,7 @@ recommonmark twine asv pre_commit -furo \ No newline at end of file +furo +types-click +types-pyyaml +types-pkg_resources \ No newline at end of file diff --git a/setup.py b/setup.py index 915a1669c..85d39ad1e 100644 --- a/setup.py +++ b/setup.py @@ -38,13 +38,13 @@ install_requires=[ "packaging >= 20.0", "numpy >= 1.9.0", - "pandas >= 0.25.3", + "pandas >= 1.0", "typing_extensions >= 3.7.4.3 ; python_version<'3.8'", "typing_inspect >= 0.6.0", "wrapt", ], extras_require=extras_require, - python_requires=">=3.6", + python_requires=">=3.7", platforms="any", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -53,7 +53,6 @@ "Intended Audience :: Science/Research", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/tests/core/test_dtypes.py b/tests/core/test_dtypes.py index e12f69db0..48e303c10 100644 --- a/tests/core/test_dtypes.py +++ b/tests/core/test_dtypes.py @@ -5,6 +5,7 @@ import dataclasses import datetime import inspect +import platform from decimal import Decimal from typing import Any, Dict, List, Tuple @@ -19,10 +20,12 @@ import pandera as pa from pandera.engines import pandas_engine +WINDOWS_PLATFORM = platform.system() == "Windows" + # List dtype classes and associated pandas alias, # except for parameterizable dtypes that should also list examples of instances. int_dtypes = { - int: "int", + int: "int64", pa.Int: "int64", pa.Int8: "int8", pa.Int16: "int16", @@ -67,13 +70,12 @@ pa.Float16: "float16", pa.Float32: "float32", pa.Float64: "float64", - pa.Float128: "float128", np.float16: "float16", np.float32: "float32", np.float64: "float64", - np.float128: "float128", } + complex_dtypes = { complex: "complex", pa.Complex: "complex128", @@ -81,6 +83,21 @@ pa.Complex128: "complex128", } + +if not WINDOWS_PLATFORM: + float_dtypes.update( + { + pa.Float128: "float128", + np.float128: "float128", + } + ) + complex_dtypes.update( + { + pa.Complex256: "complex256", + np.complex256: "complex256", + } + ) + boolean_dtypes = {bool: "bool", pa.Bool: "bool", np.bool_: "bool"} nullable_boolean_dtypes = {pd.BooleanDtype: "boolean", pa.BOOL: "boolean"} @@ -170,13 +187,12 @@ def pretty_param(*values: Any, **kw: Any) -> ParameterSet: def pytest_generate_tests(metafunc: Metafunc) -> None: - """Inject dtype, alias, data fixtures from `dtype_fixtures`. - - Filter pandera.dtypes.pa.DataType classes if the test name contains "datatype". + """Inject `dtype`, `data_type` (filter pandera DataTypes), `alias`, `data` + fixtures from `dtype_fixtures`. 
""" fixtures = [ fixture - for fixture in ("dtype", "pd_dtype", "data") + for fixture in ("data_type", "dtype", "pd_dtype", "data") if fixture in metafunc.fixturenames ] arg_names = ",".join(fixtures) @@ -185,7 +201,7 @@ def pytest_generate_tests(metafunc: Metafunc) -> None: arg_values = [] for dtypes, data in dtype_fixtures: for dtype, pd_dtype in dtypes.items(): - if "datatype" in metafunc.function.__name__ and not ( + if "data_type" in fixtures and not ( isinstance(dtype, pa.DataType) or ( inspect.isclass(dtype) @@ -205,23 +221,23 @@ def pytest_generate_tests(metafunc: Metafunc) -> None: metafunc.parametrize(arg_names, arg_values) -def test_datatype_init(dtype: Any): +def test_datatype_init(data_type: Any): """Test that a default pa.DataType can be constructed.""" - if not inspect.isclass(dtype): + if not inspect.isclass(data_type): pytest.skip( "test_datatype_init tests pa.DataType classes, not instances." ) - assert isinstance(dtype(), pa.DataType) + assert isinstance(data_type(), pa.DataType) -def test_datatype_alias(dtype: Any, pd_dtype: Any): +def test_datatype_alias(data_type: Any, pd_dtype: Any): """Test that a default pa.DataType can be constructed.""" - assert str(pandas_engine.Engine.dtype(dtype)) == str(pd_dtype) + assert str(pandas_engine.Engine.dtype(data_type)) == str(pd_dtype) -def test_frozen_datatype(dtype: Any): +def test_frozen_datatype(data_type: Any): """Test that pa.DataType instances are immutable.""" - data_type = dtype() if inspect.isclass(dtype) else dtype + data_type = data_type() if inspect.isclass(data_type) else data_type with pytest.raises(dataclasses.FrozenInstanceError): data_type.foo = 1 @@ -254,28 +270,24 @@ def test_check_not_equivalent(dtype: Any): def test_coerce_no_cast(dtype: Any, pd_dtype: Any, data: List[Any]): """Test that dtypes can be coerced without casting.""" expected_dtype = pandas_engine.Engine.dtype(dtype) - print(pd_dtype) series = pd.Series(data, dtype=pd_dtype) coerced_series = expected_dtype.coerce(series) + assert series.equals(coerced_series) - print(expected_dtype) - print(series) - print(coerced_series) - print(coerced_series.dtype) - print(pandas_engine.Engine.dtype(coerced_series.dtype)) assert expected_dtype.check( pandas_engine.Engine.dtype(coerced_series.dtype) ) df = pd.DataFrame({"col": data}, dtype=pd_dtype) coerced_df = expected_dtype.coerce(df) + assert df.equals(coerced_df) assert expected_dtype.check( pandas_engine.Engine.dtype(coerced_df["col"].dtype) ) -def _flatten_dtypes_dict(*dtype_kinds): +def _flatten_dtypesdict(*dtype_kinds): return [ (datatype, pd_dtype) for dtype_kind in dtype_kinds @@ -283,7 +295,7 @@ def _flatten_dtypes_dict(*dtype_kinds): ] -numeric_dtypes = _flatten_dtypes_dict( +numeric_dtypes = _flatten_dtypesdict( int_dtypes, uint_dtypes, float_dtypes, @@ -291,13 +303,13 @@ def _flatten_dtypes_dict(*dtype_kinds): boolean_dtypes, ) -nullable_numeric_dtypes = _flatten_dtypes_dict( +nullable_numeric_dtypes = _flatten_dtypesdict( nullable_int_dtypes, nullable_uint_dtypes, nullable_boolean_dtypes, ) -nominal_dtypes = _flatten_dtypes_dict( +nominal_dtypes = _flatten_dtypesdict( string_dtypes, nullable_string_dtypes, category_dtypes, @@ -339,21 +351,21 @@ def test_coerce_string(): def test_default_numeric_dtypes(): """Test that default numeric dtypes int, float and complex are consistent.""" - default_int_dtype = pd.Series([1], dtype=int).dtype + default_int_dtype = pd.Series([1]).dtype assert ( pandas_engine.Engine.dtype(default_int_dtype) == pandas_engine.Engine.dtype(int) == 
pandas_engine.Engine.dtype("int") ) - default_float_dtype = pd.Series([1], dtype=float).dtype + default_float_dtype = pd.Series([1.0]).dtype assert ( pandas_engine.Engine.dtype(default_float_dtype) == pandas_engine.Engine.dtype(float) == pandas_engine.Engine.dtype("float") ) - default_complex_dtype = pd.Series([1], dtype=complex).dtype + default_complex_dtype = pd.Series([complex(1)]).dtype assert ( pandas_engine.Engine.dtype(default_complex_dtype) == pandas_engine.Engine.dtype(complex) @@ -393,3 +405,108 @@ def test_inferred_dtype(examples: pd.Series): inferred_datatype = pandas_engine.Engine.dtype(alias) actual_dtype = pandas_engine.Engine.dtype(pd.Series(examples).dtype) assert actual_dtype.check(inferred_datatype) + + +@pytest.mark.parametrize( + "int_dtype, expected", + [(dtype, True) for dtype in (*int_dtypes, *nullable_int_dtypes)] + + [("string", False)], # type:ignore +) +def test_is_int(int_dtype: Any, expected: bool): + """Test is_int.""" + pandera_dtype = pandas_engine.Engine.dtype(int_dtype) + assert pa.dtypes.is_int(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "uint_dtype, expected", + [(dtype, True) for dtype in (*uint_dtypes, *nullable_uint_dtypes)] + + [("string", False)], # type:ignore +) +def test_is_uint(uint_dtype: Any, expected: bool): + """Test is_uint.""" + pandera_dtype = pandas_engine.Engine.dtype(uint_dtype) + assert pa.dtypes.is_uint(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "float_dtype, expected", + [(dtype, True) for dtype in float_dtypes] + [("string", False)], # type: ignore +) +def test_is_float(float_dtype: Any, expected: bool): + """Test is_float.""" + pandera_dtype = pandas_engine.Engine.dtype(float_dtype) + assert pa.dtypes.is_float(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "complex_dtype, expected", + [(dtype, True) for dtype in complex_dtypes] + + [("string", False)], # type: ignore +) +def test_is_complex(complex_dtype: Any, expected: bool): + """Test is_complex.""" + pandera_dtype = pandas_engine.Engine.dtype(complex_dtype) + assert pa.dtypes.is_complex(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "bool_dtype, expected", + [(dtype, True) for dtype in (*boolean_dtypes, *nullable_boolean_dtypes)] + + [("string", False)], +) +def test_is_bool(bool_dtype: Any, expected: bool): + """Test is_bool.""" + pandera_dtype = pandas_engine.Engine.dtype(bool_dtype) + assert pa.dtypes.is_bool(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "string_dtype, expected", + [(dtype, True) for dtype in string_dtypes] + + [("int", False)], # type:ignore +) +def test_is_string(string_dtype: Any, expected: bool): + """Test is_string.""" + pandera_dtype = pandas_engine.Engine.dtype(string_dtype) + assert pa.dtypes.is_string(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "category_dtype, expected", + [(dtype, True) for dtype in category_dtypes] + [("string", False)], +) +def test_is_category(category_dtype: Any, expected: bool): + """Test is_category.""" + pandera_dtype = pandas_engine.Engine.dtype(category_dtype) + assert pa.dtypes.is_category(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "datetime_dtype, expected", + [(dtype, True) for dtype in timestamp_dtypes] + [("string", False)], +) +def test_is_datetime(datetime_dtype: Any, expected: bool): + """Test is_datetime.""" + pandera_dtype = pandas_engine.Engine.dtype(datetime_dtype) + assert pa.dtypes.is_datetime(pandera_dtype) == expected + + +@pytest.mark.parametrize( + "timedelta_dtype, expected", + [(dtype, True) for 
dtype in timedelta_dtypes] + [("string", False)],
+)
+def test_is_timedelta(timedelta_dtype: Any, expected: bool):
+    """Test is_timedelta."""
+    pandera_dtype = pandas_engine.Engine.dtype(timedelta_dtype)
+    assert pa.dtypes.is_timedelta(pandera_dtype) == expected
+
+
+@pytest.mark.parametrize(
+    "numeric_dtype, expected",
+    [(dtype, True) for dtype, _ in numeric_dtypes] + [("string", False)],
+)
+def test_is_numeric(numeric_dtype: Any, expected: bool):
+    """Test is_numeric."""
+    pandera_dtype = pandas_engine.Engine.dtype(numeric_dtype)
+    assert pa.dtypes.is_numeric(pandera_dtype) == expected
diff --git a/tests/core/test_engine.py b/tests/core/test_engine.py
index a85b395ac..a4b8bfefb 100644
--- a/tests/core/test_engine.py
+++ b/tests/core/test_engine.py
@@ -1,11 +1,12 @@
 """Tests Engine subclassing and registering DataTypes."""
 # pylint:disable=redefined-outer-name,unused-argument
 # pylint:disable=missing-function-docstring,missing-class-docstring
+import re
 from typing import Any, Generator, List, Union

 import pytest

-from pandera.dtypes_ import DataType
+from pandera.dtypes import DataType
 from pandera.engines.engine import Engine


@@ -158,14 +159,24 @@ def from_parametrized_dtype(cls, x: int):
     assert engine.dtype(42) == _DtypeB()


-def test_register_base_pandera_dtypes(engine: Engine):
+def test_register_base_pandera_dtypes():
     """Test that base datatype cannot be registered."""
+
+    class FakeEngine(  # pylint:disable=too-few-public-methods
+        metaclass=Engine, base_pandera_dtypes=(BaseDataType, BaseDataType)
+    ):
+        pass
+
     with pytest.raises(
         ValueError,
-        match="BaseDataType subclasses cannot be registered with FakeEngine.",
+        match=re.escape(
+            "Subclasses of ['tests.core.test_engine.BaseDataType', "
+            + "'tests.core.test_engine.BaseDataType'] "
+            + "cannot be registered with FakeEngine."
+ ), ): - @engine.register_dtype(equivalents=[SimpleDtype]) + @FakeEngine.register_dtype(equivalents=[SimpleDtype]) class _Dtype(BaseDataType): pass diff --git a/tests/core/test_extensions.py b/tests/core/test_extensions.py index 17fc3b24e..cf5aeb598 100644 --- a/tests/core/test_extensions.py +++ b/tests/core/test_extensions.py @@ -9,7 +9,7 @@ import pandera as pa import pandera.strategies as st -from pandera import PandasDtype, extensions +from pandera import DataType, extensions from pandera.checks import Check @@ -202,7 +202,7 @@ def test_register_check_with_strategy(custom_check_teardown): import hypothesis # pylint: disable=import-outside-toplevel,import-error def custom_ge_strategy( - pandas_dtype: PandasDtype, + pandas_dtype: DataType, strategy: Optional[st.SearchStrategy] = None, *, min_value: Any, @@ -222,7 +222,7 @@ def custom_ge_check(pandas_obj, *, min_value): return pandas_obj >= min_value check = Check.custom_ge_check(min_value=0) - strat = check.strategy(PandasDtype.Int) + strat = check.strategy(pa.Int) with pytest.warns(hypothesis.errors.NonInteractiveExampleWarning): assert strat.example() >= 0 diff --git a/tests/core/test_schema_statistics.py b/tests/core/test_schema_statistics.py index d82f85f9d..a722013ef 100644 --- a/tests/core/test_schema_statistics.py +++ b/tests/core/test_schema_statistics.py @@ -1,14 +1,34 @@ # pylint: disable=W0212 """Unit tests for inferring statistics of pandas objects.""" - import pandas as pd import pytest import pandera as pa -from pandera import PandasDtype, dtypes, schema_statistics - -DEFAULT_INT = PandasDtype.from_str_alias(dtypes._DEFAULT_PANDAS_INT_TYPE) -DEFAULT_FLOAT = PandasDtype.from_str_alias(dtypes._DEFAULT_PANDAS_FLOAT_TYPE) +from pandera import dtypes, schema_statistics +from pandera.engines import pandas_engine + +DEFAULT_FLOAT = pandas_engine.Engine.dtype(float) +DEFAULT_INT = pandas_engine.Engine.dtype(int) + +NUMERIC_TYPES = [ + pa.Int(), + pa.UInt(), + pa.Float(), + pa.Complex(), + pandas_engine.Engine.dtype("Int32"), + pandas_engine.Engine.dtype("UInt32"), +] +INTEGER_TYPES = [ + dtypes.Int(), + dtypes.Int8(), + dtypes.Int16(), + dtypes.Int32(), + dtypes.Int64(), + dtypes.UInt8(), + dtypes.UInt16(), + dtypes.UInt32(), + dtypes.UInt64(), +] def _create_dataframe(multi_index=False, nullable=False): @@ -54,28 +74,36 @@ def test_infer_dataframe_statistics(multi_index, nullable): if nullable: # bool and int dtypes are cast to float in the nullable case - assert stat_columns["int"]["pandas_dtype"] is DEFAULT_FLOAT - assert stat_columns["boolean"]["pandas_dtype"] is DEFAULT_FLOAT + assert DEFAULT_FLOAT.check(stat_columns["int"]["dtype"]) + assert DEFAULT_FLOAT.check(stat_columns["boolean"]["dtype"]) else: - assert stat_columns["int"]["pandas_dtype"] is DEFAULT_INT - assert stat_columns["boolean"]["pandas_dtype"] is pa.Bool + assert DEFAULT_INT.check(stat_columns["int"]["dtype"]) + assert pandas_engine.Engine.dtype(bool).check( + stat_columns["boolean"]["dtype"] + ) - assert stat_columns["float"]["pandas_dtype"] is DEFAULT_FLOAT - assert stat_columns["string"]["pandas_dtype"] is pa.String - assert stat_columns["datetime"]["pandas_dtype"] is pa.DateTime + assert DEFAULT_FLOAT.check(stat_columns["float"]["dtype"]) + assert pandas_engine.Engine.dtype(str).check( + stat_columns["string"]["dtype"] + ) + assert pandas_engine.Engine.dtype(pa.DateTime).check( + stat_columns["datetime"]["dtype"] + ) if multi_index: stat_indices = statistics["index"] for stat_index, name, dtype in zip( - stat_indices, ["int_index", "str_index"], 
[DEFAULT_INT, pa.String] + stat_indices, + ["int_index", "str_index"], + [DEFAULT_INT, pandas_engine.Engine.dtype(str)], ): assert stat_index["name"] == name - assert stat_index["pandas_dtype"] is dtype + assert dtype.check(stat_index["dtype"]) assert not stat_index["nullable"] else: stat_index = statistics["index"][0] assert stat_index["name"] == "int_index" - assert stat_index["pandas_dtype"] is DEFAULT_INT + assert stat_index["dtype"] == DEFAULT_INT assert not stat_index["nullable"] for properties in stat_columns.values(): @@ -123,14 +151,30 @@ def test_parse_check_statistics(check_stats, expectation): assert set(checks) == set(expectation) +def _test_statistics(statistics, expectations): + if not isinstance(statistics, list): + statistics = [statistics] + if not isinstance(expectations, list): + expectations = [expectations] + + for stats, expectation in zip(statistics, expectations): + stat_dtype = stats.pop("dtype") + expectation_dtype = expectation.pop("dtype") + + assert stats == expectation + assert expectation_dtype.check(stat_dtype) + + @pytest.mark.parametrize( "series, expectation", [ *[ [ - pd.Series([1, 2, 3], dtype=dtype.str_alias), + pd.Series( + [1, 2, 3], dtype=str(pandas_engine.Engine.dtype(data_type)) + ), { - "pandas_dtype": dtype, + "dtype": pandas_engine.Engine.dtype(data_type), "nullable": False, "checks": { "greater_than_or_equal_to": 1, @@ -139,25 +183,21 @@ def test_parse_check_statistics(check_stats, expectation): "name": None, }, ] - for dtype in ( - x - for x in schema_statistics.NUMERIC_DTYPES - if x != PandasDtype.DateTime - ) + for data_type in NUMERIC_TYPES ], [ pd.Series(["a", "b", "c", "a"], dtype="category"), { - "pandas_dtype": pa.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": False, "checks": {"isin": ["a", "b", "c"]}, "name": None, }, ], [ - pd.Series(["a", "b", "c", "a"], name="str_series"), + pd.Series(["a", "b", "c", "a"], dtype="string", name="str_series"), { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype("string"), "nullable": False, "checks": None, "name": "str_series", @@ -166,7 +206,7 @@ def test_parse_check_statistics(check_stats, expectation): [ pd.Series(pd.to_datetime(["20180101", "20180102", "20180103"])), { - "pandas_dtype": pa.DateTime, + "dtype": pandas_engine.Engine.dtype(pa.DateTime), "nullable": False, "checks": { "greater_than_or_equal_to": pd.Timestamp("20180101"), @@ -180,20 +220,7 @@ def test_parse_check_statistics(check_stats, expectation): def test_infer_series_schema_statistics(series, expectation): """Test series statistics are correctly inferred.""" statistics = schema_statistics.infer_series_statistics(series) - assert statistics == expectation - - -INTEGER_TYPES = [ - PandasDtype.Int, - PandasDtype.Int8, - PandasDtype.Int16, - PandasDtype.Int32, - PandasDtype.Int64, - PandasDtype.UInt8, - PandasDtype.UInt16, - PandasDtype.UInt32, - PandasDtype.UInt64, -] + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -202,10 +229,10 @@ def test_infer_series_schema_statistics(series, expectation): *[ [ 0, - pd.Series([1, 2, 3], dtype=dtype.value), + pd.Series([1, 2, 3], dtype=str(data_type)), { # introducing nans to integer arrays upcasts to float - "pandas_dtype": DEFAULT_FLOAT, + "dtype": DEFAULT_FLOAT, "nullable": True, "checks": { "greater_than_or_equal_to": 2, @@ -214,14 +241,14 @@ def test_infer_series_schema_statistics(series, expectation): "name": None, }, ] - for dtype in INTEGER_TYPES + for data_type in INTEGER_TYPES ], [ # introducing nans to integer 
arrays upcasts to float 0, pd.Series([True, False, True, False]), { - "pandas_dtype": DEFAULT_FLOAT, + "dtype": DEFAULT_FLOAT, "nullable": True, "checks": { "greater_than_or_equal_to": 0, @@ -234,7 +261,7 @@ def test_infer_series_schema_statistics(series, expectation): 0, pd.Series(["a", "b", "c", "a"], dtype="category"), { - "pandas_dtype": pa.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": True, "checks": {"isin": ["a", "b", "c"]}, "name": None, @@ -244,7 +271,7 @@ def test_infer_series_schema_statistics(series, expectation): 0, pd.Series(["a", "b", "c", "a"], name="str_series"), { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype(str), "nullable": True, "checks": None, "name": "str_series", @@ -254,7 +281,7 @@ def test_infer_series_schema_statistics(series, expectation): 2, pd.Series(pd.to_datetime(["20180101", "20180102", "20180103"])), { - "pandas_dtype": pa.DateTime, + "dtype": pandas_engine.Engine.dtype(pa.DateTime), "nullable": True, "checks": { "greater_than_or_equal_to": pd.Timestamp("20180101"), @@ -271,7 +298,7 @@ def test_infer_nullable_series_schema_statistics( """Test nullable series statistics are correctly inferred.""" series.iloc[null_index] = None statistics = schema_statistics.infer_series_statistics(series) - assert statistics == expectation + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -282,7 +309,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": None, - "pandas_dtype": PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 0, @@ -296,7 +323,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "int_index", - "pandas_dtype": PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 1, @@ -310,7 +337,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "str_index", - "pandas_dtype": PandasDtype.String, + "dtype": pandas_engine.Engine.dtype("object"), "nullable": False, "checks": None, }, @@ -324,7 +351,7 @@ def test_infer_nullable_series_schema_statistics( [ { "name": "int_index", - "pandas_dtype": PandasDtype.Int, + "dtype": DEFAULT_INT, "nullable": False, "checks": { "greater_than_or_equal_to": 10, @@ -333,7 +360,7 @@ def test_infer_nullable_series_schema_statistics( }, { "name": "str_index", - "pandas_dtype": PandasDtype.Category, + "dtype": pandas_engine.Engine.dtype(pa.Category), "nullable": False, "checks": {"isin": ["a", "b", "c"]}, }, @@ -354,7 +381,9 @@ def test_infer_index_statistics(index, expectation): with pytest.warns(UserWarning, match="^index type .+ not recognized"): schema_statistics.infer_index_statistics(index) else: - assert schema_statistics.infer_index_statistics(index) == expectation + _test_statistics( + schema_statistics.infer_index_statistics(index), expectation + ) def test_get_dataframe_schema_statistics(): @@ -362,7 +391,7 @@ def test_get_dataframe_schema_statistics(): schema = pa.DataFrameSchema( columns={ "int": pa.Column( - pa.Int, + int, checks=[ pa.Check.greater_than_or_equal_to(0), pa.Check.less_than_or_equal_to(100), @@ -370,18 +399,19 @@ def test_get_dataframe_schema_statistics(): nullable=True, ), "float": pa.Column( - pa.Float, + float, checks=[ pa.Check.greater_than_or_equal_to(50), pa.Check.less_than_or_equal_to(100), ], ), "str": pa.Column( - pa.String, checks=[pa.Check.isin(["foo", "bar", "baz"])] + str, + checks=[pa.Check.isin(["foo", "bar", "baz"])], ), }, index=pa.Index( - pa.Int, + int, 
checks=pa.Check.greater_than_or_equal_to(0), nullable=False, name="int_index", @@ -391,7 +421,7 @@ def test_get_dataframe_schema_statistics(): "checks": None, "columns": { "int": { - "pandas_dtype": pa.Int, + "dtype": DEFAULT_INT, "checks": { "greater_than_or_equal_to": {"min_value": 0}, "less_than_or_equal_to": {"max_value": 100}, @@ -403,7 +433,7 @@ def test_get_dataframe_schema_statistics(): "regex": False, }, "float": { - "pandas_dtype": pa.Float, + "dtype": DEFAULT_FLOAT, "checks": { "greater_than_or_equal_to": {"min_value": 50}, "less_than_or_equal_to": {"max_value": 100}, @@ -415,7 +445,7 @@ def test_get_dataframe_schema_statistics(): "regex": False, }, "str": { - "pandas_dtype": pa.String, + "dtype": pandas_engine.Engine.dtype(str), "checks": {"isin": {"allowed_values": ["foo", "bar", "baz"]}}, "nullable": False, "allow_duplicates": True, @@ -426,7 +456,7 @@ def test_get_dataframe_schema_statistics(): }, "index": [ { - "pandas_dtype": pa.Int, + "dtype": DEFAULT_INT, "checks": {"greater_than_or_equal_to": {"min_value": 0}}, "nullable": False, "coerce": False, @@ -442,7 +472,7 @@ def test_get_dataframe_schema_statistics(): def test_get_series_schema_statistics(): """Test that series schema statistics logic is correct.""" schema = pa.SeriesSchema( - pa.Int, + int, nullable=False, checks=[ pa.Check.greater_than_or_equal_to(0), @@ -451,7 +481,7 @@ def test_get_series_schema_statistics(): ) statistics = schema_statistics.get_series_schema_statistics(schema) assert statistics == { - "pandas_dtype": pa.Int, + "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { "greater_than_or_equal_to": {"min_value": 0}, @@ -467,7 +497,7 @@ def test_get_series_schema_statistics(): [ [ pa.Index( - pa.Int, + int, checks=[ pa.Check.greater_than_or_equal_to(10), pa.Check.less_than_or_equal_to(20), @@ -477,7 +507,7 @@ def test_get_series_schema_statistics(): ), [ { - "pandas_dtype": pa.Int, + "dtype": pandas_engine.Engine.dtype(int), "nullable": False, "checks": { "greater_than_or_equal_to": {"min_value": 10}, @@ -495,7 +525,7 @@ def test_get_index_schema_statistics(index_schema_component, expectation): statistics = schema_statistics.get_index_schema_statistics( index_schema_component ) - assert statistics == expectation + _test_statistics(statistics, expectation) @pytest.mark.parametrize( @@ -566,7 +596,10 @@ def test_parse_checks_and_statistics_roundtrip(checks, expectation): # pylint: disable=unused-argument def test_parse_checks_and_statistics_no_param(extra_registered_checks): - """Ensure that an edge case where a check does not have parameters is appropriately handled.""" + """ + Ensure that an edge case where a check does not have parameters is + appropriately handled. 
+ """ checks = [pa.Check.no_param_check()] expectation = {"no_param_check": {}} diff --git a/tests/core/test_schemas.py b/tests/core/test_schemas.py index b625b3d04..cd3988681 100644 --- a/tests/core/test_schemas.py +++ b/tests/core/test_schemas.py @@ -828,8 +828,8 @@ def test_add_and_remove_columns(): schema2.remove_columns(["foo", "bar"]) -def test_schema_get_dtype(): - """Test that schema dtype and get_dtype methods handle regex columns.""" +def test_schema_get_dtypes(): + """Test that schema dtype and get_dtypes methods handle regex columns.""" schema = DataFrameSchema( { "col1": Column(int), @@ -1311,6 +1311,10 @@ def test_schema_transformer_deprecated(): ) def test_schema_coerce_inplace_validation(inplace, from_dtype, to_dtype): """Test coercion logic for validation when inplace is True and False""" + from_dtype = ( + from_dtype if from_dtype is not int else str(Engine.dtype(from_dtype)) + ) + to_dtype = to_dtype if to_dtype is not int else str(Engine.dtype(to_dtype)) df = pd.DataFrame({"column": pd.Series([1, 2, 6], dtype=from_dtype)}) schema = DataFrameSchema({"column": Column(to_dtype, coerce=True)}) validated_df = schema.validate(df, inplace=inplace) diff --git a/tests/core/test_typing.py b/tests/core/test_typing.py index bc1d937cb..0539c4c3f 100644 --- a/tests/core/test_typing.py +++ b/tests/core/test_typing.py @@ -8,7 +8,7 @@ import pytest import pandera as pa -from pandera.dtypes_ import DataType +from pandera.dtypes import DataType from pandera.typing import LEGACY_TYPING, Series if not LEGACY_TYPING: diff --git a/tests/io/test_io.py b/tests/io/test_io.py index 19edde926..00da31761 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -9,9 +9,10 @@ import pytest from packaging import version -import pandera as pa +import pandera import pandera.extensions as pa_ext import pandera.typing as pat +from pandera.engines import pandas_engine try: from pandera import io @@ -41,69 +42,69 @@ def _create_schema(index="single"): if index == "multi": - index = pa.MultiIndex( + index = pandera.MultiIndex( [ - pa.Index(pa.Int, name="int_index0"), - pa.Index(pa.Int, name="int_index1"), - pa.Index(pa.Int, name="int_index2"), + pandera.Index(pandera.Int, name="int_index0"), + pandera.Index(pandera.Int, name="int_index1"), + pandera.Index(pandera.Int, name="int_index2"), ] ) elif index == "single": # make sure io modules can handle case when index name is None - index = pa.Index(pa.Int, name=None) + index = pandera.Index(pandera.Int, name=None) else: index = None - return pa.DataFrameSchema( + return pandera.DataFrameSchema( columns={ - "int_column": pa.Column( - pa.Int, + "int_column": pandera.Column( + pandera.Int, checks=[ - pa.Check.greater_than(0), - pa.Check.less_than(10), - pa.Check.in_range(0, 10), + pandera.Check.greater_than(0), + pandera.Check.less_than(10), + pandera.Check.in_range(0, 10), ], ), - "float_column": pa.Column( - pa.Float, + "float_column": pandera.Column( + pandera.Float, checks=[ - pa.Check.greater_than(-10), - pa.Check.less_than(20), - pa.Check.in_range(-10, 20), + pandera.Check.greater_than(-10), + pandera.Check.less_than(20), + pandera.Check.in_range(-10, 20), ], ), - "str_column": pa.Column( - pa.String, + "str_column": pandera.Column( + pandera.String, checks=[ - pa.Check.isin(["foo", "bar", "x", "xy"]), - pa.Check.str_length(1, 3), + pandera.Check.isin(["foo", "bar", "x", "xy"]), + pandera.Check.str_length(1, 3), ], ), - "datetime_column": pa.Column( - pa.DateTime, + "datetime_column": pandera.Column( + pandera.DateTime, checks=[ - 
pa.Check.greater_than(pd.Timestamp("20100101")), - pa.Check.less_than(pd.Timestamp("20200101")), + pandera.Check.greater_than(pd.Timestamp("20100101")), + pandera.Check.less_than(pd.Timestamp("20200101")), ], ), - "timedelta_column": pa.Column( - pa.Timedelta, + "timedelta_column": pandera.Column( + pandera.Timedelta, checks=[ - pa.Check.greater_than(pd.Timedelta(1000, unit="ns")), - pa.Check.less_than(pd.Timedelta(10000, unit="ns")), + pandera.Check.greater_than(pd.Timedelta(1000, unit="ns")), + pandera.Check.less_than(pd.Timedelta(10000, unit="ns")), ], ), - "optional_props_column": pa.Column( - pa.String, + "optional_props_column": pandera.Column( + pandera.String, nullable=True, allow_duplicates=True, coerce=True, required=False, regex=True, - checks=[pa.Check.str_length(1, 3)], + checks=[pandera.Check.str_length(1, 3)], ), - "notype_column": pa.Column( - checks=pa.Check.isin(["foo", "bar", "x", "xy"]), + "notype_column": pandera.Column( + checks=pandera.Check.isin(["foo", "bar", "x", "xy"]), ), }, index=index, @@ -114,10 +115,10 @@ def _create_schema(index="single"): YAML_SCHEMA = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int + dtype: int64 nullable: false checks: greater_than: 0 @@ -130,7 +131,7 @@ def _create_schema(index="single"): required: true regex: false float_column: - pandas_dtype: float + dtype: float64 nullable: false checks: greater_than: -10 @@ -143,7 +144,7 @@ def _create_schema(index="single"): required: true regex: false str_column: - pandas_dtype: str + dtype: str nullable: false checks: isin: @@ -159,7 +160,7 @@ def _create_schema(index="single"): required: true regex: false datetime_column: - pandas_dtype: datetime64[ns] + dtype: datetime64[ns] nullable: false checks: greater_than: '2010-01-01 00:00:00' @@ -169,7 +170,7 @@ def _create_schema(index="single"): required: true regex: false timedelta_column: - pandas_dtype: timedelta64[ns] + dtype: timedelta64[ns] nullable: false checks: greater_than: 1000 @@ -179,7 +180,7 @@ def _create_schema(index="single"): required: true regex: false optional_props_column: - pandas_dtype: str + dtype: str nullable: true checks: str_length: @@ -190,7 +191,7 @@ def _create_schema(index="single"): required: false regex: true notype_column: - pandas_dtype: null + dtype: null nullable: false checks: isin: @@ -204,7 +205,7 @@ def _create_schema(index="single"): regex: false checks: null index: -- pandas_dtype: int +- dtype: int64 nullable: false checks: null name: null @@ -216,21 +217,21 @@ def _create_schema(index="single"): def _create_schema_null_index(): - return pa.DataFrameSchema( + return pandera.DataFrameSchema( columns={ - "float_column": pa.Column( - pa.Float, + "float_column": pandera.Column( + pandera.Float, checks=[ - pa.Check.greater_than(-10), - pa.Check.less_than(20), - pa.Check.in_range(-10, 20), + pandera.Check.greater_than(-10), + pandera.Check.less_than(20), + pandera.Check.in_range(-10, 20), ], ), - "str_column": pa.Column( - pa.String, + "str_column": pandera.Column( + pandera.String, checks=[ - pa.Check.isin(["foo", "bar", "x", "xy"]), - pa.Check.str_length(1, 3), + pandera.Check.isin(["foo", "bar", "x", "xy"]), + pandera.Check.str_length(1, 3), ], ), }, @@ -240,10 +241,10 @@ def _create_schema_null_index(): YAML_SCHEMA_NULL_INDEX = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: float_column: - pandas_dtype: float + dtype: float64 nullable: false checks: greater_than: -10 @@ -252,7 
+253,7 @@ def _create_schema_null_index(): min_value: -10 max_value: 20 str_column: - pandas_dtype: str + dtype: str nullable: false checks: isin: @@ -271,28 +272,28 @@ def _create_schema_null_index(): def _create_schema_python_types(): - return pa.DataFrameSchema( + return pandera.DataFrameSchema( { - "int_column": pa.Column(int), - "float_column": pa.Column(float), - "str_column": pa.Column(str), - "object_column": pa.Column(object), + "int_column": pandera.Column(int), + "float_column": pandera.Column(float), + "str_column": pandera.Column(str), + "object_column": pandera.Column(object), } ) YAML_SCHEMA_PYTHON_TYPES = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object checks: null index: null coerce: false @@ -302,16 +303,16 @@ def _create_schema_python_types(): YAML_SCHEMA_MISSING_GLOBAL_CHECK = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object checks: unregistered_check: stat1: missing_str_stat @@ -324,20 +325,20 @@ def _create_schema_python_types(): YAML_SCHEMA_MISSING_COLUMN_CHECK = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: int_column: - pandas_dtype: int64 + dtype: int64 checks: unregistered_check: stat1: missing_str_stat stat2: 11 float_column: - pandas_dtype: float64 + dtype: float64 str_column: - pandas_dtype: str + dtype: str object_column: - pandas_dtype: object + dtype: object index: null coerce: false strict: false @@ -357,7 +358,7 @@ def test_inferred_schema_io(): "column3": ["a", "b", "c"], } ) - schema = pa.infer_schema(df) + schema = pandera.infer_schema(df) schema_yaml_str = schema.to_yaml() schema_from_yaml = io.from_yaml(schema_yaml_str) assert schema == schema_from_yaml @@ -371,6 +372,10 @@ def test_to_yaml(): """Test that to_yaml writes to yaml string.""" schema = _create_schema() yaml_str = io.to_yaml(schema) + with tempfile.NamedTemporaryFile("w+") as f: + f.write(yaml_str) + with tempfile.NamedTemporaryFile("w+") as f: + f.write(YAML_SCHEMA) assert yaml_str.strip() == YAML_SCHEMA.strip() yaml_str_schema_method = schema.to_yaml() @@ -412,7 +417,7 @@ def test_from_yaml_load_required_fields(): io.from_yaml("") with pytest.raises( - pa.errors.SchemaDefinitionError, match=".*must be a mapping.*" + pandera.errors.SchemaDefinitionError, match=".*must be a mapping.*" ): io.from_yaml( """ @@ -430,7 +435,7 @@ def test_io_yaml_file_obj(): output = schema.to_yaml(f) assert output is None f.seek(0) - schema_from_yaml = pa.DataFrameSchema.from_yaml(f) + schema_from_yaml = pandera.DataFrameSchema.from_yaml(f) assert schema_from_yaml == schema @@ -454,7 +459,7 @@ def test_io_yaml(index): with tempfile.NamedTemporaryFile("w+") as f: output = schema.to_yaml(Path(f.name)) assert output is None - schema_from_yaml = pa.DataFrameSchema.from_yaml(Path(f.name)) + schema_from_yaml = pandera.DataFrameSchema.from_yaml(Path(f.name)) assert schema_from_yaml == schema @@ -488,44 +493,48 @@ def test_to_script(index): def test_to_script_lambda_check(): """Test writing DataFrameSchema to a script with lambda check.""" - schema1 = pa.DataFrameSchema( + 
schema1 = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + "a": pandera.Column( + pandera.Int, + checks=pandera.Check( + lambda s: s.mean() > 5, element_wise=False + ), ), } ) with pytest.warns(UserWarning): - pa.io.to_script(schema1) + pandera.io.to_script(schema1) - schema2 = pa.DataFrameSchema( + schema2 = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + checks=pandera.Check(lambda s: s.mean() > 5, element_wise=False), ) with pytest.warns(UserWarning, match=".*registered checks.*"): - pa.io.to_script(schema2) + pandera.io.to_script(schema2) def test_to_yaml_lambda_check(): """Test writing DataFrameSchema to a yaml with lambda check.""" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, - checks=pa.Check(lambda s: s.mean() > 5, element_wise=False), + "a": pandera.Column( + pandera.Int, + checks=pandera.Check( + lambda s: s.mean() > 5, element_wise=False + ), ), } ) with pytest.warns(UserWarning): - pa.io.to_yaml(schema) + pandera.io.to_yaml(schema) def test_format_checks_warning(): @@ -555,24 +564,24 @@ def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool: return len(pandas_obj.columns) > column_count assert ( - len(pa.Check.REGISTERED_CUSTOM_CHECKS) == 1 + len(pandera.Check.REGISTERED_CUSTOM_CHECKS) == 1 ), "custom check is registered" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=[pa.Check.ncols_gt(column_count=5)], + checks=[pandera.Check.ncols_gt(column_count=5)], ) - serialized = pa.io.to_yaml(schema) - loaded = pa.io.from_yaml(serialized) + serialized = pandera.io.to_yaml(schema) + loaded = pandera.io.from_yaml(serialized) assert len(loaded.checks) == 1, "global check was stripped" - with pytest.raises(pa.errors.SchemaError): + with pytest.raises(pandera.errors.SchemaError): schema.validate(pd.DataFrame(data={"a": [1]})) assert ncols_gt_called, "did not call ncols_gt" @@ -581,17 +590,17 @@ def ncols_gt(pandas_obj: pd.DataFrame, column_count: int) -> bool: def test_to_yaml_custom_dataframe_check(): """Tests that writing DataFrameSchema with an unregistered check raises.""" - schema = pa.DataFrameSchema( + schema = pandera.DataFrameSchema( { - "a": pa.Column( - pa.Int, + "a": pandera.Column( + pandera.Int, ), }, - checks=[pa.Check(lambda obj: len(obj.index) > 1)], + checks=[pandera.Check(lambda obj: len(obj.index) > 1)], ) with pytest.warns(UserWarning, match=".*registered checks.*"): - pa.io.to_yaml(schema) + pandera.io.to_yaml(schema) # the unregistered column check case is tested in # `test_to_yaml_lambda_check` @@ -601,13 +610,13 @@ def test_to_yaml_bugfix_419(): """Ensure that GH#419 is fixed""" # pylint: disable=no-self-use - class CheckedSchemaModel(pa.SchemaModel): + class CheckedSchemaModel(pandera.SchemaModel): """Schema with a global check""" a: pat.Series[pat.Int64] b: pat.Series[pat.Int64] - @pa.dataframe_check() + @pandera.dataframe_check() def unregistered_check(self, _): """sample unregistered check""" ... 
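The fixtures below pin down the serialization side of the same rename: schemas now write each component's data type under a `dtype:` key, using the engine's string representation (e.g. `dtype: int64`), where earlier versions wrote `pandas_dtype:`. A short round-trip sketch of that behavior (editorial, not part of the patch; assumes the `io` extras, i.e. pyyaml, are installed):

import pandera

schema = pandera.DataFrameSchema({"a": pandera.Column(int)})
yaml_str = schema.to_yaml()  # the "a" column serializes as "dtype: int64"
assert "dtype: int64" in yaml_str
assert pandera.io.from_yaml(yaml_str) == schema  # should round-trip cleanly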
@@ -706,17 +715,17 @@ def unregistered_check(self, _): } # pandas dtype aliases to support testing across multiple pandas versions: -STR_DTYPE = pa.dtypes.PandasDtype.from_str_alias("string").value -STR_DTYPE_ALIAS = pa.dtypes.PandasDtype.from_str_alias("string").str_alias -INT_DTYPE = pa.dtypes.PandasDtype.from_str_alias("int").value -INT_DTYPE_ALIAS = pa.dtypes.PandasDtype.from_str_alias("int").str_alias +STR_DTYPE = pandas_engine.Engine.dtype("string") +STR_DTYPE_ALIAS = str(pandas_engine.Engine.dtype("string")) +INT_DTYPE = pandas_engine.Engine.dtype("int") +INT_DTYPE_ALIAS = str(pandas_engine.Engine.dtype("int")) YAML_FROM_FRICTIONLESS = f""" schema_type: dataframe -version: {pa.__version__} +version: {pandera.__version__} columns: integer_col: - pandas_dtype: {INT_DTYPE} + dtype: {INT_DTYPE} nullable: false checks: in_range: @@ -727,7 +736,7 @@ def unregistered_check(self, _): required: true regex: false integer_col_2: - pandas_dtype: {INT_DTYPE} + dtype: {INT_DTYPE} nullable: true checks: less_than_or_equal_to: 30 @@ -736,7 +745,7 @@ def unregistered_check(self, _): required: true regex: false string_col: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: @@ -747,7 +756,7 @@ def unregistered_check(self, _): required: true regex: false string_col_2: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_matches: ^\\d{{3}}[A-Z]$ @@ -756,7 +765,7 @@ def unregistered_check(self, _): required: true regex: false string_col_3: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: 3 @@ -765,7 +774,7 @@ def unregistered_check(self, _): required: true regex: false string_col_4: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: str_length: 3 @@ -774,7 +783,7 @@ def unregistered_check(self, _): required: true regex: false float_col: - pandas_dtype: category + dtype: category nullable: false checks: isin: @@ -786,7 +795,7 @@ def unregistered_check(self, _): required: true regex: false float_col_2: - pandas_dtype: float + dtype: float64 nullable: true checks: null allow_duplicates: true @@ -794,7 +803,7 @@ def unregistered_check(self, _): required: true regex: false date_col: - pandas_dtype: {STR_DTYPE} + dtype: {STR_DTYPE} nullable: true checks: greater_than_or_equal_to: '20201231' @@ -848,12 +857,12 @@ def unregistered_check(self, _): ) def test_frictionless_schema_parses_correctly(frictionless_schema): """Test parsing frictionless schema from yaml and json.""" - schema = pa.io.from_frictionless_schema(frictionless_schema) + schema = pandera.io.from_frictionless_schema(frictionless_schema) assert str(schema.to_yaml()).strip() == YAML_FROM_FRICTIONLESS.strip() assert isinstance( - schema, pa.schemas.DataFrameSchema + schema, pandera.schemas.DataFrameSchema ), "schema object not loaded successfully" df = schema.validate(VALID_FRICTIONLESS_DF) @@ -871,7 +880,7 @@ def test_frictionless_schema_parses_correctly(frictionless_schema): "date_col": STR_DTYPE_ALIAS, }, "dtypes not parsed correctly from frictionless schema" - with pytest.raises(pa.errors.SchemaErrors) as err: + with pytest.raises(pandera.errors.SchemaErrors) as err: schema.validate(INVALID_FRICTIONLESS_DF, lazy=True) # check we're capturing all errors according to the frictionless schema: assert err.value.failure_cases[["check", "failure_case"]].fillna( diff --git a/tests/strategies/test_strategies.py b/tests/strategies/test_strategies.py index 715c595eb..850ed1292 100644 --- a/tests/strategies/test_strategies.py +++ 
b/tests/strategies/test_strategies.py @@ -1,8 +1,6 @@ # pylint: disable=undefined-variable,redefined-outer-name,invalid-name,undefined-loop-variable # noqa """Unit tests for pandera data generating strategies.""" - import operator -import platform import re from typing import Any from unittest.mock import MagicMock @@ -14,6 +12,8 @@ import pandera as pa import pandera.strategies as strategies from pandera.checks import _CheckBase, register_check_statistics +from pandera.dtypes import is_category, is_complex, is_float +from pandera.engines import pandas_engine try: import hypothesis @@ -27,88 +27,98 @@ HAS_HYPOTHESIS = True -TYPE_ERROR_FMT = "data generation for the {} dtype is currently unsupported" - -SUPPORTED_DTYPES = [] -for pdtype in pa.PandasDtype: +SUPPORTED_DTYPES = set() +for data_type in pandas_engine.Engine.get_registered_dtypes(): if ( - pdtype is pa.PandasDtype.Complex256 and platform.system() == "Windows" - ) or pdtype is pa.Category: + # valid hypothesis.strategies.floats <=64 + getattr(data_type, "bit_width", -1) > 64 + or is_category(data_type) + or data_type + in (pandas_engine.Interval, pandas_engine.Period, pandas_engine.Sparse) + ): continue - SUPPORTED_DTYPES.append(pdtype) + SUPPORTED_DTYPES.add(pandas_engine.Engine.dtype(data_type)) NUMERIC_DTYPES = [ - pdtype for pdtype in SUPPORTED_DTYPES if pdtype.is_continuous + data_type for data_type in SUPPORTED_DTYPES if data_type.continuous ] NULLABLE_DTYPES = [ - pdtype - for pdtype in SUPPORTED_DTYPES - if not pdtype.is_complex - and not pdtype.is_category - and not pdtype.is_object + data_type + for data_type in SUPPORTED_DTYPES + if not is_complex(data_type) + and not is_category(data_type) + and not data_type == pandas_engine.Engine.dtype("object") ] NUMERIC_RANGE_CONSTANT = 10 DATE_RANGE_CONSTANT = np.timedelta64(NUMERIC_RANGE_CONSTANT, "D") COMPLEX_RANGE_CONSTANT = np.complex64( - complex(NUMERIC_RANGE_CONSTANT, NUMERIC_RANGE_CONSTANT) + complex(NUMERIC_RANGE_CONSTANT, NUMERIC_RANGE_CONSTANT) # type: ignore ) -@pytest.mark.parametrize("pdtype", [pa.Category]) -def test_unsupported_pandas_dtype_strategy(pdtype): +@pytest.mark.parametrize("data_type", [pa.Category]) +def test_unsupported_pandas_dtype_strategy(data_type): """Test unsupported pandas dtype strategy raises error.""" - with pytest.raises(TypeError, match=TYPE_ERROR_FMT.format(pdtype.name)): - strategies.pandas_dtype_strategy(pdtype) + with pytest.raises( + TypeError, + match="data generation for the Category dtype is currently unsupported", + ): + strategies.pandas_dtype_strategy(data_type) -@pytest.mark.parametrize("pdtype", SUPPORTED_DTYPES) +@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES) @hypothesis.given(st.data()) -def test_pandas_dtype_strategy(pdtype, data): +def test_pandas_dtype_strategy(data_type, data): """Test that series can be constructed from pandas dtype.""" - strategy = strategies.pandas_dtype_strategy(pdtype) + strategy = strategies.pandas_dtype_strategy(data_type) example = data.draw(strategy) - expected_type = ( - pdtype.String.numpy_dtype.type - if pdtype is pa.Object - else pdtype.numpy_dtype.type - ) - + expected_type = strategies.to_numpy_dtype(data_type).type assert example.dtype.type == expected_type - chained_strategy = strategies.pandas_dtype_strategy(pdtype, strategy) + chained_strategy = strategies.pandas_dtype_strategy(data_type, strategy) chained_example = data.draw(chained_strategy) assert chained_example.dtype.type == expected_type -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) 
+@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_check_strategy_continuous(pdtype, data): +def test_check_strategy_continuous(data_type, data): """Test built-in check strategies can generate continuous data.""" + np_dtype = strategies.to_numpy_dtype(data_type) value = data.draw( npst.from_dtype( - pdtype.numpy_dtype, + strategies.to_numpy_dtype(data_type), allow_nan=False, allow_infinity=False, ) ) - pdtype = pa.PandasDtype.Int - value = data.draw(npst.from_dtype(pdtype.numpy_dtype)) - assert data.draw(strategies.ne_strategy(pdtype, value=value)) != value - assert data.draw(strategies.eq_strategy(pdtype, value=value)) == value - assert data.draw(strategies.gt_strategy(pdtype, min_value=value)) > value - assert data.draw(strategies.ge_strategy(pdtype, min_value=value)) >= value - assert data.draw(strategies.lt_strategy(pdtype, max_value=value)) < value - assert data.draw(strategies.le_strategy(pdtype, max_value=value)) <= value + # don't overstep bounds of representation + hypothesis.assume(np.finfo(np_dtype).min < value < np.finfo(np_dtype).max) + + assert data.draw(strategies.ne_strategy(data_type, value=value)) != value + assert data.draw(strategies.eq_strategy(data_type, value=value)) == value + assert ( + data.draw(strategies.gt_strategy(data_type, min_value=value)) > value + ) + assert ( + data.draw(strategies.ge_strategy(data_type, min_value=value)) >= value + ) + assert ( + data.draw(strategies.lt_strategy(data_type, max_value=value)) < value + ) + assert ( + data.draw(strategies.le_strategy(data_type, max_value=value)) <= value + ) -def value_ranges(pdtype: pa.PandasDtype): +def value_ranges(data_type: pa.DataType): """Strategy to generate value range based on PandasDtype""" kwargs = dict( allow_nan=False, @@ -118,15 +128,19 @@ def value_ranges(pdtype: pa.PandasDtype): ) return ( st.tuples( - strategies.pandas_dtype_strategy(pdtype, strategy=None, **kwargs), - strategies.pandas_dtype_strategy(pdtype, strategy=None, **kwargs), + strategies.pandas_dtype_strategy( + data_type, strategy=None, **kwargs + ), + strategies.pandas_dtype_strategy( + data_type, strategy=None, **kwargs + ), ) .map(sorted) .filter(lambda x: x[0] < x[1]) ) -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @pytest.mark.parametrize( "strat_fn, arg_name, base_st_type, compare_op", [ @@ -143,17 +157,17 @@ def value_ranges(pdtype: pa.PandasDtype): suppress_health_check=[hypothesis.HealthCheck.too_slow], ) def test_check_strategy_chained_continuous( - pdtype, strat_fn, arg_name, base_st_type, compare_op, data + data_type, strat_fn, arg_name, base_st_type, compare_op, data ): """ Test built-in check strategies can generate continuous data building off of a parent strategy. 
""" - min_value, max_value = data.draw(value_ranges(pdtype)) + min_value, max_value = data.draw(value_ranges(data_type)) hypothesis.assume(min_value < max_value) value = min_value base_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, allow_nan=False, @@ -165,7 +179,7 @@ def test_check_strategy_chained_continuous( assert_base_st = st.just(value) elif base_st_type == "limit": assert_base_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, allow_nan=False, @@ -177,25 +191,25 @@ def test_check_strategy_chained_continuous( local_vars = locals() assert_value = local_vars[arg_name] example = data.draw( - strat_fn(pdtype, assert_base_st, **{arg_name: assert_value}) + strat_fn(data_type, assert_base_st, **{arg_name: assert_value}) ) assert compare_op(example, assert_value) -@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES) +@pytest.mark.parametrize("data_type", NUMERIC_DTYPES) @pytest.mark.parametrize("chained", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_in_range_strategy(pdtype, chained, data): +def test_in_range_strategy(data_type, chained, data): """Test the built-in in-range strategy can correctly generate data.""" - min_value, max_value = data.draw(value_ranges(pdtype)) + min_value, max_value = data.draw(value_ranges(data_type)) hypothesis.assume(min_value < max_value) base_st_in_range = None if chained: - if pdtype.is_float: + if is_float(data_type): base_st_kwargs = { "exclude_min": False, "exclude_max": False, @@ -205,13 +219,13 @@ def test_in_range_strategy(pdtype, chained, data): # constraining the strategy this way makes testing more efficient base_st_in_range = strategies.pandas_dtype_strategy( - pdtype, + data_type, min_value=min_value, max_value=max_value, **base_st_kwargs, ) strat = strategies.in_range_strategy( - pdtype, + data_type, base_st_in_range, min_value=min_value, max_value=max_value, @@ -221,18 +235,18 @@ def test_in_range_strategy(pdtype, chained, data): @pytest.mark.parametrize( - "pdtype", - [pdtype for pdtype in SUPPORTED_DTYPES if pdtype.is_continuous], + "data_type", + [data_type for data_type in SUPPORTED_DTYPES if data_type.continuous], ) @pytest.mark.parametrize("chained", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_isin_notin_strategies(pdtype, chained, data): +def test_isin_notin_strategies(data_type, chained, data): """Test built-in check strategies that rely on discrete values.""" value_st = strategies.pandas_dtype_strategy( - pdtype, + data_type, allow_nan=False, allow_infinity=False, exclude_min=False, @@ -245,17 +259,17 @@ def test_isin_notin_strategies(pdtype, chained, data): if chained: base_values = values + [data.draw(value_st) for _ in range(10)] isin_base_st = strategies.isin_strategy( - pdtype, allowed_values=base_values + data_type, allowed_values=base_values ) notin_base_st = strategies.notin_strategy( - pdtype, forbidden_values=base_values + data_type, forbidden_values=base_values ) isin_st = strategies.isin_strategy( - pdtype, isin_base_st, allowed_values=values + data_type, isin_base_st, allowed_values=values ) notin_st = strategies.notin_strategy( - pdtype, notin_base_st, forbidden_values=values + data_type, notin_base_st, forbidden_values=values ) assert data.draw(isin_st) in values assert data.draw(notin_st) not in values @@ -310,7 +324,7 @@ def 
test_str_pattern_checks(str_strat, pattern_fn, chained, data, pattern): st.integers(min_value=0, max_value=100), st.integers(min_value=0, max_value=100), ) - .map(sorted) + .map(sorted) # type: ignore .filter(lambda x: x[0] < x[1]) # type: ignore ), ) @@ -337,12 +351,12 @@ def test_register_check_strategy(data): # pylint: disable=unused-argument def custom_eq_strategy( - pandas_dtype: pa.PandasDtype, + pandas_dtype: pa.DataType, strategy: st.SearchStrategy = None, *, value: Any, ): - return st.just(value).map(pandas_dtype.numpy_dtype.type) + return st.just(value).map(strategies.to_numpy_dtype(pandas_dtype).type) # pylint: disable=no-member class CustomCheck(_CheckBase): @@ -366,7 +380,7 @@ def _custom_equals(series: pd.Series) -> pd.Series: ) check = CustomCheck.custom_equals(100) - result = data.draw(check.strategy(pa.Int)) + result = data.draw(check.strategy(pa.Int())) assert result == 100 @@ -407,13 +421,13 @@ def _custom_check(series: pd.Series) -> pd.Series: ) def test_series_strategy(data): """Test SeriesSchema strategy.""" - series_schema = pa.SeriesSchema(pa.Int, pa.Check.gt(0)) + series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0)) series_schema(data.draw(series_schema.strategy())) def test_series_example(): """Test SeriesSchema example method generate examples that pass.""" - series_schema = pa.SeriesSchema(pa.Int, pa.Check.gt(0)) + series_schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0)) for _ in range(10): series_schema(series_schema.example()) @@ -424,33 +438,25 @@ def test_series_example(): ) def test_column_strategy(data): """Test Column schema strategy.""" - column_schema = pa.Column(pa.Int, pa.Check.gt(0), name="column") + column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column") column_schema(data.draw(column_schema.strategy())) def test_column_example(): """Test Column schema example method generate examples that pass.""" - column_schema = pa.Column(pa.Int, pa.Check.gt(0), name="column") + column_schema = pa.Column(pa.Int(), pa.Check.gt(0), name="column") for _ in range(10): column_schema(column_schema.example()) -@pytest.mark.parametrize( - "pdtype", - SUPPORTED_DTYPES, -) -@pytest.mark.parametrize( - "size", - [None, 0, 1, 3, 5], -) +@pytest.mark.parametrize("data_type", SUPPORTED_DTYPES) +@pytest.mark.parametrize("size", [None, 0, 1, 3, 5]) @hypothesis.given(st.data()) -@hypothesis.settings( - suppress_health_check=[hypothesis.HealthCheck.too_slow], -) -def test_dataframe_strategy(pdtype, size, data): +@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) +def test_dataframe_strategy(data_type, size, data): """Test DataFrameSchema strategy.""" dataframe_schema = pa.DataFrameSchema( - {f"{pdtype.value}_col": pa.Column(pdtype)} + {f"{data_type}_col": pa.Column(data_type)} ) df_sample = data.draw(dataframe_schema.strategy(size=size)) if size == 0: @@ -461,15 +467,17 @@ def test_dataframe_strategy(pdtype, size, data): ) else: assert isinstance(dataframe_schema(df_sample), pd.DataFrame) - with pytest.raises(pa.errors.BaseStrategyOnlyError): - strategies.dataframe_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) - ) + # with pytest.raises(pa.errors.BaseStrategyOnlyError): + # strategies.dataframe_strategy( + # data_type, strategies.pandas_dtype_strategy(data_type) + # ) def test_dataframe_example(): """Test DataFrameSchema example method generate examples that pass.""" - schema = pa.DataFrameSchema({"column": pa.Column(pa.Int, pa.Check.gt(0))}) + schema = pa.DataFrameSchema( + {"column": pa.Column(pa.Int(), pa.Check.gt(0))} 
+    )
     for _ in range(10):
         schema(schema.example())
 
@@ -503,21 +511,16 @@ def test_dataframe_with_regex(regex, data, n_regex_columns):
     assert df.shape[1] == n_regex_columns
 
 
-@pytest.mark.parametrize("pdtype", NUMERIC_DTYPES)
+@pytest.mark.parametrize("data_type", NUMERIC_DTYPES)
 @hypothesis.settings(
     suppress_health_check=[hypothesis.HealthCheck.too_slow],
 )
 @hypothesis.given(st.data())
-def test_dataframe_checks(pdtype, data):
+def test_dataframe_checks(data_type, data):
     """Test dataframe strategy with checks defined at the dataframe level."""
-    if pa.LEGACY_PANDAS and pdtype in {
-        pa.PandasDtype.UInt64,
-        pa.PandasDtype.UINT64,
-    }:
-        pytest.xfail("pandas<1.0.0 leads to OverflowError for these dtypes.")
-    min_value, max_value = data.draw(value_ranges(pdtype))
+    min_value, max_value = data.draw(value_ranges(data_type))
     dataframe_schema = pa.DataFrameSchema(
-        {f"{pdtype.value}_col": pa.Column(pdtype) for _ in range(5)},
+        {f"{data_type}_col": pa.Column(data_type) for _ in range(5)},
         checks=pa.Check.in_range(min_value, max_value),
     )
     strat = dataframe_schema.strategy(size=5)
@@ -525,17 +528,19 @@
     dataframe_schema(example)
 
 
-@pytest.mark.parametrize("pdtype", [pa.Int, pa.Float, pa.String, pa.DateTime])
+@pytest.mark.parametrize(
+    "data_type", [pa.Int(), pa.Float, pa.String, pa.DateTime]
+)
 @hypothesis.given(st.data())
 @hypothesis.settings(
     suppress_health_check=[hypothesis.HealthCheck.too_slow],
 )
-def test_dataframe_strategy_with_indexes(pdtype, data):
+def test_dataframe_strategy_with_indexes(data_type, data):
     """Test dataframe strategy with index and multiindex components."""
-    dataframe_schema_index = pa.DataFrameSchema(index=pa.Index(pdtype))
+    dataframe_schema_index = pa.DataFrameSchema(index=pa.Index(data_type))
     dataframe_schema_multiindex = pa.DataFrameSchema(
         index=pa.MultiIndex(
-            [pa.Index(pdtype, name=f"index{i}") for i in range(3)]
+            [pa.Index(data_type, name=f"index{i}") for i in range(3)]
         )
     )
 
@@ -551,12 +556,14 @@
 )
 def test_index_strategy(data):
     """Test Index schema component strategy."""
-    pdtype = pa.PandasDtype.Int
-    index_schema = pa.Index(pdtype, allow_duplicates=False, name="index")
+    data_type = pa.Int()
+    index_schema = pa.Index(data_type, allow_duplicates=False, name="index")
     strat = index_schema.strategy(size=10)
     example = data.draw(strat)
+
     assert (~example.duplicated()).all()
-    assert example.dtype == pdtype.str_alias
+    actual_data_type = pandas_engine.Engine.dtype(example.dtype)
+    assert data_type.check(actual_data_type)
     index_schema(pd.DataFrame(index=example))
 
 
@@ -564,8 +571,8 @@ def test_index_example():
     """
     Test Index schema component example method generates examples that pass.
""" - pdtype = pa.PandasDtype.Int - index_schema = pa.Index(pdtype, allow_duplicates=False) + data_type = pa.Int() + index_schema = pa.Index(data_type, allow_duplicates=False) for _ in range(10): index_schema(pd.DataFrame(index=index_schema.example())) @@ -576,22 +583,25 @@ def test_index_example(): ) def test_multiindex_strategy(data): """Test MultiIndex schema component strategy.""" - pdtype = pa.PandasDtype.Float + data_type = pa.Float() multiindex = pa.MultiIndex( indexes=[ - pa.Index(pdtype, allow_duplicates=False, name="level_0"), - pa.Index(pdtype, nullable=True), - pa.Index(pdtype), + pa.Index(data_type, allow_duplicates=False, name="level_0"), + pa.Index(data_type, nullable=True), + pa.Index(data_type), ] ) strat = multiindex.strategy(size=10) example = data.draw(strat) for i in range(example.nlevels): - assert example.get_level_values(i).dtype == pdtype.str_alias + actual_data_type = pandas_engine.Engine.dtype( + example.get_level_values(i).dtype + ) + assert data_type.check(actual_data_type) with pytest.raises(pa.errors.BaseStrategyOnlyError): strategies.multiindex_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) + data_type, strategies.pandas_dtype_strategy(data_type) ) @@ -600,12 +610,12 @@ def test_multiindex_example(): Test MultiIndex schema component example method generates examples that pass. """ - pdtype = pa.PandasDtype.Float + data_type = pa.Float() multiindex = pa.MultiIndex( indexes=[ - pa.Index(pdtype, allow_duplicates=False, name="level_0"), - pa.Index(pdtype, nullable=True), - pa.Index(pdtype), + pa.Index(data_type, allow_duplicates=False, name="level_0"), + pa.Index(data_type, nullable=True), + pa.Index(data_type), ] ) for _ in range(10): @@ -613,21 +623,23 @@ def test_multiindex_example(): multiindex(pd.DataFrame(index=example)) -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @hypothesis.given(st.data()) -def test_field_element_strategy(pdtype, data): +def test_field_element_strategy(data_type, data): """Test strategy for generating elements in columns/indexes.""" - strategy = strategies.field_element_strategy(pdtype) + strategy = strategies.field_element_strategy(data_type) element = data.draw(strategy) - assert element.dtype.type == pdtype.numpy_dtype.type + + expected_type = strategies.to_numpy_dtype(data_type).type + assert element.dtype.type == expected_type with pytest.raises(pa.errors.BaseStrategyOnlyError): strategies.field_element_strategy( - pdtype, strategies.pandas_dtype_strategy(pdtype) + data_type, strategies.pandas_dtype_strategy(data_type) ) -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @pytest.mark.parametrize( "field_strategy", [strategies.index_strategy, strategies.series_strategy], @@ -637,20 +649,12 @@ def test_field_element_strategy(pdtype, data): @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_check_nullable_field_strategy(pdtype, field_strategy, nullable, data): +def test_check_nullable_field_strategy( + data_type, field_strategy, nullable, data +): """Test strategies for generating nullable column/index data.""" - - if ( - pa.LEGACY_PANDAS - and field_strategy is strategies.index_strategy - and (pdtype.is_nullable_int or pdtype.is_nullable_uint) - ): - pytest.skip( - "pandas version<1 does not handle nullable integer indexes" - ) - size = 5 - strat = field_strategy(pdtype, nullable=nullable, size=size) + strat = field_strategy(data_type, nullable=nullable, 
size=size) example = data.draw(strat) if nullable: @@ -659,22 +663,18 @@ def test_check_nullable_field_strategy(pdtype, field_strategy, nullable, data): assert example.notna().all() -@pytest.mark.parametrize("pdtype", NULLABLE_DTYPES) +@pytest.mark.parametrize("data_type", NULLABLE_DTYPES) @pytest.mark.parametrize("nullable", [True, False]) @hypothesis.given(st.data()) @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.too_slow], ) -def test_check_nullable_dataframe_strategy(pdtype, nullable, data): +def test_check_nullable_dataframe_strategy(data_type, nullable, data): """Test strategies for generating nullable DataFrame data.""" size = 5 # pylint: disable=no-value-for-parameter strat = strategies.dataframe_strategy( - columns={ - "col": pa.Column( - pandas_dtype=pdtype, nullable=nullable, name="col" - ) - }, + columns={"col": pa.Column(data_type, nullable=nullable, name="col")}, size=size, ) example = data.draw(strat) @@ -689,7 +689,7 @@ def test_check_nullable_dataframe_strategy(pdtype, nullable, data): [ [ pa.SeriesSchema( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda x: x > 0, element_wise=True), pa.Check(lambda x: x > -10, element_wise=True), @@ -699,7 +699,7 @@ def test_check_nullable_dataframe_strategy(pdtype, nullable, data): ], [ pa.SeriesSchema( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -731,7 +731,7 @@ def test_series_strategy_undefined_check_strategy(schema, warning, data): [ [ pa.DataFrameSchema( - columns={"column": pa.Column(pa.Int)}, + columns={"column": pa.Column(pa.Int())}, checks=[ pa.Check(lambda x: x > 0, element_wise=True), pa.Check(lambda x: x > -10, element_wise=True), @@ -743,7 +743,7 @@ def test_series_strategy_undefined_check_strategy(schema, warning, data): pa.DataFrameSchema( columns={ "column": pa.Column( - pa.Int, + pa.Int(), checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -755,7 +755,7 @@ def test_series_strategy_undefined_check_strategy(schema, warning, data): ], [ pa.DataFrameSchema( - columns={"column": pa.Column(pa.Int)}, + columns={"column": pa.Column(pa.Int())}, checks=[ pa.Check(lambda s: s > -10000), pa.Check(lambda s: s > -9999), @@ -835,4 +835,4 @@ def test_schema_component_with_no_pdtype(): strategies.index_strategy, ]: with pytest.raises(pa.errors.SchemaDefinitionError): - schema_component_strategy(pandas_dtype=None) + schema_component_strategy(pandera_dtype=None)
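
Taken together, the test changes above follow one migration pattern: `pa.PandasDtype` enum members become instantiated data types, and dtype introspection goes through the pandas engine. A minimal sketch of that pattern, assuming only the post-refactor API exercised in this diff (`pandas_engine.Engine.dtype`, `DataType.check`, instantiated `pa.Int()`):

    import pandas as pd
    import pandera as pa
    from pandera.engines import pandas_engine

    # Schemas now take data type *instances* rather than PandasDtype enum members.
    schema = pa.SeriesSchema(pa.Int(), pa.Check.gt(0))

    # The engine resolves any dtype spelling (string, numpy/pandas dtype,
    # pandera data type) to a canonical DataType instance.
    expected = pandas_engine.Engine.dtype("int64")

    # Equivalence is asserted via DataType.check(), as in test_index_strategy.
    actual = pandas_engine.Engine.dtype(pd.Series([1, 2, 3]).dtype)
    assert expected.check(actual)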