diff --git a/docs/source/drop_invalid_rows.rst b/docs/source/drop_invalid_rows.rst
new file mode 100644
index 000000000..a1d48a018
--- /dev/null
+++ b/docs/source/drop_invalid_rows.rst
@@ -0,0 +1,110 @@
+.. currentmodule:: pandera
+
+.. _drop_invalid_rows:
+
+Dropping Invalid Rows
+=====================
+
+*New in version 0.16.0*
+
+If you wish to use the validation step to remove invalid data, you can pass
+``drop_invalid_rows=True`` to the schema object on creation. When
+``schema.validate()`` is called, any row that causes a data-level check to
+fail is removed from the dataframe before it is returned.
+
+``drop_invalid_rows=True`` prevents data-level schema errors from being
+raised; instead, the rows that cause the failures are removed.
+
+This functionality is available on ``DataFrameSchema``, ``SeriesSchema``, and
+``Column``, as well as on ``DataFrameModel`` schemas.
+
+Dropping invalid rows with :class:`~pandera.api.pandas.container.DataFrameSchema`:
+
+.. testcode:: drop_invalid_rows_data_frame_schema
+
+    import pandas as pd
+    import pandera as pa
+
+    from pandera import Check, Column, DataFrameSchema
+
+    df = pd.DataFrame({"counter": ["1", "2", "3"]})
+    schema = DataFrameSchema(
+        {"counter": Column(int, checks=[Check(lambda x: x >= 3)])},
+        drop_invalid_rows=True,
+    )
+
+    schema.validate(df, lazy=True)
+
+Dropping invalid rows with :class:`~pandera.api.pandas.array.SeriesSchema`:
+
+.. testcode:: drop_invalid_rows_series_schema
+
+    import pandas as pd
+    import pandera as pa
+
+    from pandera import Check, SeriesSchema
+
+    series = pd.Series(["1", "2", "3"])
+    schema = SeriesSchema(
+        int,
+        checks=[Check(lambda x: x >= 3)],
+        drop_invalid_rows=True,
+    )
+
+    schema.validate(series, lazy=True)
+
+Dropping invalid rows with :class:`~pandera.api.pandas.components.Column`:
+
+.. testcode:: drop_invalid_rows_column
+
+    import pandas as pd
+    import pandera as pa
+
+    from pandera import Check, Column
+
+    df = pd.DataFrame({"counter": ["1", "2", "3"]})
+    schema = Column(
+        int,
+        name="counter",
+        drop_invalid_rows=True,
+        checks=[Check(lambda x: x >= 3)],
+    )
+
+    schema.validate(df, lazy=True)
+
+Dropping invalid rows with :class:`~pandera.api.pandas.model.DataFrameModel`:
+
+.. testcode:: drop_invalid_rows_data_frame_model
+
+    import pandas as pd
+    import pandera as pa
+
+    from pandera import Check, DataFrameModel, Field
+
+    class MySchema(DataFrameModel):
+        counter: int = Field(in_range={"min_value": 3, "max_value": 5})
+
+        class Config:
+            drop_invalid_rows = True
+
+
+    MySchema.validate(
+        pd.DataFrame({"counter": [1, 2, 3, 4, 5, 6]}), lazy=True
+    )
+
+.. note::
+    In order to use ``drop_invalid_rows=True``, ``lazy=True`` must be passed
+    to ``schema.validate()``. :ref:`lazy_validation` enables all schema
+    errors to be collected and raised together, meaning all invalid rows can
+    be dropped together. This provides a clear API for ensuring that the
+    dataframe returned by validation contains only valid data.
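+
+If ``lazy=True`` is omitted while ``drop_invalid_rows=True`` is set, pandera
+raises a ``SchemaDefinitionError`` rather than validating. A minimal sketch,
+reusing the ``DataFrameSchema`` example above:
+
+.. code-block:: python
+
+    # raises SchemaDefinitionError:
+    # "When drop_invalid_rows is True, lazy must be set to True."
+    schema.validate(df)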
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 2a8b4d95e..8d8c4a6dd 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -360,6 +360,7 @@ page or reach out to the maintainers and pandera community on
    hypothesis
    dtypes
    decorators
+   drop_invalid_rows
    schema_inference
    lazy_validation
    data_synthesis_strategies
diff --git a/pandera/api/base/schema.py b/pandera/api/base/schema.py
index b5db52310..1e027bdc6 100644
--- a/pandera/api/base/schema.py
+++ b/pandera/api/base/schema.py
@@ -32,6 +32,7 @@ def __init__(
         name=None,
         title=None,
         description=None,
+        drop_invalid_rows=False,
     ):
         """Abstract base schema initializer."""
         self.dtype = dtype
@@ -40,6 +41,7 @@ def __init__(
         self.name = name
         self.title = title
         self.description = description
+        self.drop_invalid_rows = drop_invalid_rows

     def validate(
         self,
diff --git a/pandera/api/pandas/array.py b/pandera/api/pandas/array.py
index e10e04afe..384874f43 100644
--- a/pandera/api/pandas/array.py
+++ b/pandera/api/pandas/array.py
@@ -37,6 +37,7 @@ def __init__(
         title: Optional[str] = None,
         description: Optional[str] = None,
         default: Optional[Any] = None,
+        drop_invalid_rows: bool = False,
     ) -> None:
         """Initialize array schema.

@@ -63,6 +64,8 @@ def __init__(
         :param title: A human-readable label for the series.
         :param description: An arbitrary textual description of the series.
         :param default: The default value for missing values in the series.
+        :param drop_invalid_rows: if True, drop invalid rows on validation.
+
         """

         super().__init__(
@@ -72,6 +75,7 @@ def __init__(
             name=name,
             title=title,
             description=description,
+            drop_invalid_rows=drop_invalid_rows,
         )

         if checks is None:
@@ -300,6 +304,7 @@ def __init__(
         title: Optional[str] = None,
         description: Optional[str] = None,
         default: Optional[Any] = None,
+        drop_invalid_rows: bool = False,
     ) -> None:
         """Initialize series schema base object.

@@ -327,6 +332,7 @@ def __init__(
         :param title: A human-readable label for the series.
         :param description: An arbitrary textual description of the series.
         :param default: The default value for missing values in the series.
+        :param drop_invalid_rows: if True, drop invalid rows on validation.
         """

         super().__init__(
@@ -340,6 +346,7 @@ def __init__(
             title,
             description,
             default,
+            drop_invalid_rows,
         )
         self.index = index
diff --git a/pandera/api/pandas/components.py b/pandera/api/pandas/components.py
index 1012f1899..3047c67dd 100644
--- a/pandera/api/pandas/components.py
+++ b/pandera/api/pandas/components.py
@@ -30,6 +30,7 @@ def __init__(
         title: Optional[str] = None,
         description: Optional[str] = None,
         default: Optional[Any] = None,
+        drop_invalid_rows: bool = False,
     ) -> None:
         """Create column validator object.

@@ -54,6 +55,7 @@ def __init__(
         :param title: A human-readable label for the column.
         :param description: An arbitrary textual description of the column.
         :param default: The default value for missing values in the column.
+        :param drop_invalid_rows: if True, drop invalid rows on validation.

         :raises SchemaInitError: if impossible to build schema from parameters
@@ -85,6 +87,7 @@ def __init__(
             title=title,
             description=description,
             default=default,
+            drop_invalid_rows=drop_invalid_rows,
         )
         if (
             name is not None
diff --git a/pandera/api/pandas/container.py b/pandera/api/pandas/container.py
index 2bbb4e47f..4f1266c2a 100644
--- a/pandera/api/pandas/container.py
+++ b/pandera/api/pandas/container.py
@@ -46,6 +46,7 @@ def __init__(
         unique_column_names: bool = False,
         title: Optional[str] = None,
         description: Optional[str] = None,
+        drop_invalid_rows: bool = False,
     ) -> None:
         """Initialize DataFrameSchema validator.

@@ -77,6 +78,7 @@ def __init__(
         :param unique_column_names: whether or not column names must be unique.
         :param title: A human-readable label for the schema.
         :param description: An arbitrary textual description of the schema.
+        :param drop_invalid_rows: if True, drop invalid rows on validation.

         :raises SchemaInitError: if impossible to build schema from parameters
@@ -152,6 +154,7 @@ def __init__(
         self._unique = unique
         self.report_duplicates = report_duplicates
         self.unique_column_names = unique_column_names
+        self.drop_invalid_rows = drop_invalid_rows

         # this attribute is not meant to be accessed by users and is explicitly
         # set to True in the case that a schema is created by infer_schema.
diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py
index 11d2d985c..3755cc095 100644
--- a/pandera/api/pandas/model.py
+++ b/pandera/api/pandas/model.py
@@ -268,6 +268,7 @@ def to_schema(cls) -> DataFrameSchema:
             "title": cls.__config__.title,
             "description": cls.__config__.description or cls.__doc__,
             "unique_column_names": cls.__config__.unique_column_names,
+            "drop_invalid_rows": cls.__config__.drop_invalid_rows,
         }
         cls.__schema__ = DataFrameSchema(
             columns,
diff --git a/pandera/api/pandas/model_config.py b/pandera/api/pandas/model_config.py
index 39cb20e08..263955dcf 100644
--- a/pandera/api/pandas/model_config.py
+++ b/pandera/api/pandas/model_config.py
@@ -21,6 +21,7 @@ class BaseConfig(BaseModelConfig):  # pylint:disable=R0903
     title: Optional[str] = None  #: human-readable label for schema
     description: Optional[str] = None  #: arbitrary textual description
     coerce: bool = False  #: coerce types of all schema components
+    drop_invalid_rows: bool = False  #: drop invalid rows on validation

     #: make sure certain column combinations are unique
     unique: Optional[Union[str, List[str]]] = None
diff --git a/pandera/backends/base/__init__.py b/pandera/backends/base/__init__.py
index d18e73acb..026a22292 100644
--- a/pandera/backends/base/__init__.py
+++ b/pandera/backends/base/__init__.py
@@ -124,6 +124,10 @@ def failure_cases_metadata(
         """Get failure cases metadata for lazy validation."""
         raise NotImplementedError

+    def drop_invalid_rows(self, check_obj, error_handler):
+        """Remove invalid elements in a `check_obj` according to failures caught by the `error_handler`."""
+        raise NotImplementedError
+

 class BaseCheckBackend(ABC):
     """Abstract base class for a check backend implementation."""
diff --git a/pandera/backends/pandas/array.py b/pandera/backends/pandas/array.py
index 4714bd947..3b971c637 100644
--- a/pandera/backends/pandas/array.py
+++ b/pandera/backends/pandas/array.py
@@ -20,6 +20,7 @@
     SchemaError,
     SchemaErrors,
     SchemaErrorReason,
+    SchemaDefinitionError,
 )


@@ -45,6 +46,11 @@ def validate(
         error_handler = SchemaErrorHandler(lazy)
         check_obj = self.preprocess(check_obj, inplace)

+        if getattr(schema, "drop_invalid_rows", False) and not lazy:
+            raise SchemaDefinitionError(
"When drop_invalid_rows is True, lazy must be set to True." + ) + # fill nans with `default` if it's present if hasattr(schema, "default") and pd.notna(schema.default): check_obj.fillna(schema.default, inplace=True) @@ -55,6 +61,42 @@ def validate( except SchemaError as exc: error_handler.collect_error(exc.reason_code, exc) + # run the core checks + error_handler = self.run_checks_and_handle_errors( + error_handler, + schema, + check_obj, + head, + tail, + sample, + random_state, + ) + + if lazy and error_handler.collected_errors: + if getattr(schema, "drop_invalid_rows", False): + check_obj = self.drop_invalid_rows(check_obj, error_handler) + return check_obj + else: + raise SchemaErrors( + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, + ) + + return check_obj + + def run_checks_and_handle_errors( + self, + error_handler, + schema, + check_obj, + head, + tail, + sample, + random_state, + ): + """Run checks on schema""" + # pylint: disable=too-many-locals field_obj_subsample = self.subsample( check_obj if is_field(check_obj) else check_obj[schema.name], head, @@ -71,14 +113,15 @@ def validate( random_state, ) - # run the core checks - for core_check, args in ( + core_checks = [ (self.check_name, (field_obj_subsample, schema)), (self.check_nullable, (field_obj_subsample, schema)), (self.check_unique, (field_obj_subsample, schema)), (self.check_dtype, (field_obj_subsample, schema)), (self.run_checks, (check_obj_subsample, schema)), - ): + ] + + for core_check, args in core_checks: results = core_check(*args) if isinstance(results, CoreCheckResult): results = [results] @@ -106,13 +149,7 @@ def validate( original_exc=result.original_exc, ) - if lazy and error_handler.collected_errors: - raise SchemaErrors( - schema=schema, - schema_errors=error_handler.collected_errors, - data=check_obj, - ) - return check_obj + return error_handler def coerce_dtype( self, diff --git a/pandera/backends/pandas/base.py b/pandera/backends/pandas/base.py index 36194d8b7..0f44e7291 100644 --- a/pandera/backends/pandas/base.py +++ b/pandera/backends/pandas/base.py @@ -24,6 +24,7 @@ scalar_failure_case, ) from pandera.errors import FailureCaseMetadata, SchemaError, SchemaErrorReason +from pandera.error_handlers import SchemaErrorHandler class ColumnInfo(NamedTuple): @@ -149,3 +150,12 @@ def failure_cases_metadata( message=message, error_counts=error_counts, ) + + def drop_invalid_rows(self, check_obj, error_handler: SchemaErrorHandler): + """Remove invalid elements in a check obj according to failures in caught by the error handler.""" + errors = error_handler.collected_errors + for err in errors: + check_obj = check_obj.loc[ + ~check_obj.index.isin(err.failure_cases["index"]) + ] + return check_obj diff --git a/pandera/backends/pandas/components.py b/pandera/backends/pandas/components.py index 66c98b083..ea1b542fc 100644 --- a/pandera/backends/pandas/components.py +++ b/pandera/backends/pandas/components.py @@ -1,4 +1,5 @@ """Backend implementation for pandas schema components.""" +# pylint: disable=too-many-locals import traceback from copy import copy, deepcopy @@ -18,7 +19,12 @@ ) from pandera.backends.pandas.error_formatters import scalar_failure_case from pandera.error_handlers import SchemaErrorHandler -from pandera.errors import SchemaError, SchemaErrors, SchemaErrorReason +from pandera.errors import ( + SchemaError, + SchemaErrors, + SchemaErrorReason, + SchemaDefinitionError, +) class ColumnBackend(ArraySchemaBackend): @@ -42,6 +48,11 @@ def validate( error_handler = 

+        if getattr(schema, "drop_invalid_rows", False) and not lazy:
+            raise SchemaDefinitionError(
+                "When drop_invalid_rows is True, lazy must be set to True."
+            )
+
         if schema.name is None:
             raise SchemaError(
                 schema,
                 "column name is set to None. Pass the ``name`` argument when "
                 "initializing a Column object, or use the ``set_name`` "
                 "method.",
             )

-        def validate_column(check_obj, column_name):
+        def validate_column(check_obj, column_name, return_check_obj=False):
             try:
                 # pylint: disable=super-with-arguments
-                super(ColumnBackend, self).validate(
+                validated_check_obj = super(ColumnBackend, self).validate(
                     check_obj,
                     copy(schema).set_name(column_name),
                     head=head,
@@ -64,6 +75,10 @@ def validate_column(check_obj, column_name):
                     lazy=lazy,
                     inplace=inplace,
                 )
+
+                if return_check_obj:
+                    return validated_check_obj
+
             except SchemaErrors as err:
                 for err in err.schema_errors:
                     error_handler.collect_error(
@@ -95,7 +110,13 @@ def validate_column(check_obj, column_name):
                         check_obj[column_name].iloc[:, [i]], column_name
                     )
             else:
-                validate_column(check_obj, column_name)
+                if getattr(schema, "drop_invalid_rows", False):
+                    # replace the check_obj with the validated check_obj
+                    check_obj = validate_column(
+                        check_obj, column_name, return_check_obj=True
+                    )
+                else:
+                    validate_column(check_obj, column_name)

         if lazy and error_handler.collected_errors:
             raise SchemaErrors(
@@ -381,16 +402,8 @@ def validate(
         otherwise creates a copy of the data.

         :returns: validated DataFrame or Series.
         """
-        # pylint: disable=too-many-locals
         if schema.coerce:
-            try:
-                check_obj.index = self.coerce_dtype(
-                    check_obj.index, schema=schema  # type: ignore [arg-type]
-                )
-            except SchemaErrors as err:
-                if lazy:
-                    raise
-                raise err.schema_errors[0] from err
+            check_obj.index = self.__coerce_index(check_obj, schema, lazy)

         # Prevent data type coercion when the validate method is called because
         # it leads to some weird behavior when calling coerce_dtype within the
@@ -419,32 +432,9 @@ def validate(
         ):
             columns[name] = column.set_name(name)
         schema_copy.columns = columns
-
-        def to_dataframe(multiindex):
-            """
-            Emulate the behavior of pandas.MultiIndex.to_frame, but preserve
-            duplicate index names if they exist.
-            """
-            # NOTE: this is a hack to support pyspark.pandas
-            if type(multiindex).__module__.startswith("pyspark.pandas"):
-                df = multiindex.to_frame()
-            else:
-                df = pd.DataFrame(
-                    {
-                        i: multiindex.get_level_values(i)
-                        for i in range(multiindex.nlevels)
-                    }
-                )
-                df.columns = [
-                    i if name is None else name
-                    for i, name in enumerate(multiindex.names)
-                ]
-                df.index = multiindex
-            return df
-
         try:
             validation_result = super().validate(
-                to_dataframe(check_obj.index),
+                self.__to_dataframe(check_obj.index),
                 schema_copy,
                 head=head,
                 tail=tail,
@@ -480,3 +470,36 @@ def validate(
         assert is_table(validation_result)
         return check_obj
+
+    def __to_dataframe(self, multiindex):
+        """
+        Emulate the behavior of pandas.MultiIndex.to_frame, but preserve
+        duplicate index names if they exist.
+ """ + # NOTE: this is a hack to support pyspark.pandas + if type(multiindex).__module__.startswith("pyspark.pandas"): + df = multiindex.to_frame() + else: + df = pd.DataFrame( + { + i: multiindex.get_level_values(i) + for i in range(multiindex.nlevels) + } + ) + df.columns = [ + i if name is None else name + for i, name in enumerate(multiindex.names) + ] + df.index = multiindex + return df + + def __coerce_index(self, check_obj, schema, lazy): + """Coerce index""" + try: + return self.coerce_dtype( + check_obj.index, schema=schema # type: ignore [arg-type] + ) + except SchemaErrors as err: + if lazy: + raise + raise err.schema_errors[0] from err diff --git a/pandera/backends/pandas/container.py b/pandera/backends/pandas/container.py index 27a4520b1..1bf799fe2 100644 --- a/pandera/backends/pandas/container.py +++ b/pandera/backends/pandas/container.py @@ -56,6 +56,11 @@ def validate( if not is_table(check_obj): raise TypeError(f"expected pd.DataFrame, got {type(check_obj)}") + if getattr(schema, "drop_invalid_rows", False) and not lazy: + raise SchemaDefinitionError( + "When drop_invalid_rows is True, lazy must be set to True." + ) + error_handler = SchemaErrorHandler(lazy) check_obj = self.preprocess(check_obj, inplace=inplace) @@ -82,6 +87,49 @@ def validate( except SchemaErrors as exc: error_handler.collect_errors(exc) + # run the checks + error_handler = self.run_checks_and_handle_errors( + error_handler, + schema, + check_obj, + column_info, + sample, + components, + lazy, + head, + tail, + random_state, + ) + + if error_handler.collected_errors: + if getattr(schema, "drop_invalid_rows", False): + check_obj = self.drop_invalid_rows(check_obj, error_handler) + return check_obj + else: + raise SchemaErrors( + schema=schema, + schema_errors=error_handler.collected_errors, + data=check_obj, + ) + + return check_obj + + def run_checks_and_handle_errors( + self, + error_handler, + schema, + check_obj, + column_info, + sample, + components, + lazy, + head, + tail, + random_state, + ): + """Run checks on schema""" + # pylint: disable=too-many-locals + # subsample the check object if head, tail, or sample are specified sample = self.subsample(check_obj, head, tail, sample, random_state) @@ -93,7 +141,6 @@ def validate( (self.run_schema_component_checks, (sample, components, lazy)), (self.run_checks, (sample, schema)), ] - for check, args in core_checks: results = check(*args) # type: ignore [operator] if isinstance(results, CoreCheckResult): @@ -122,14 +169,7 @@ def validate( original_exc=result.original_exc, ) - if error_handler.collected_errors: - raise SchemaErrors( - schema=schema, - schema_errors=error_handler.collected_errors, - data=check_obj, - ) - - return check_obj + return error_handler def run_schema_component_checks( self, diff --git a/pandera/strategies/pandas_strategies.py b/pandera/strategies/pandas_strategies.py index ef03068b1..c6b35993c 100644 --- a/pandera/strategies/pandas_strategies.py +++ b/pandera/strategies/pandas_strategies.py @@ -827,7 +827,7 @@ def series_strategy( unique: bool = False, name: Optional[str] = None, size: Optional[int] = None, -) -> SearchStrategy[pd.Series]: +) -> SearchStrategy: """Strategy to generate a pandas Series. :param pandera_dtype: :class:`pandera.dtypes.DataType` instance. 
diff --git a/tests/core/test_schemas.py b/tests/core/test_schemas.py
index 8c74ea945..4d2377169 100644
--- a/tests/core/test_schemas.py
+++ b/tests/core/test_schemas.py
@@ -21,6 +21,8 @@
     SeriesSchema,
     String,
     errors,
+    Field,
+    DataFrameModel,
 )
 from pandera.dtypes import UniqueSettings
 from pandera.engines.pandas_engine import Engine
@@ -2056,3 +2058,145 @@ def _constructor(self):

     dataframe = MyDataFrame([1, 2, 3], columns=["x"])
     print(schema.validate(dataframe))
+
+
+@pytest.mark.parametrize(
+    "schema, obj, expected_obj",
+    [
+        (
+            DataFrameSchema(
+                {"numbers": Column(int, checks=[Check(lambda x: x >= 3)])},
+                drop_invalid_rows=True,
+            ),
+            pd.DataFrame({"numbers": [1, 2, 3, 4, 5]}),
+            pd.DataFrame({"numbers": [3, 4, 5]}),
+        ),
+        (
+            DataFrameSchema({"numbers": Column(str)}, drop_invalid_rows=True),
+            pd.DataFrame({"numbers": [1, 2, 3, 4, 5]}),
+            pd.DataFrame({"numbers": []}),
+        ),
+        (
+            DataFrameSchema(
+                {
+                    "letters": Column(str),
+                    "numbers": Column(int, checks=[Check(lambda x: x >= 3)]),
+                },
+                drop_invalid_rows=True,
+            ),
+            pd.DataFrame(
+                {
+                    "letters": ["a", "b", "c", "d", "e"],
+                    "numbers": [1, 2, 3, 4, 5],
+                }
+            ),
+            pd.DataFrame({"letters": ["c", "d", "e"], "numbers": [3, 4, 5]}),
+        ),
+    ],
+)
+def test_drop_invalid_for_dataframe_schema(schema, obj, expected_obj):
+    """Test drop_invalid_rows works as expected on DataFrameSchemaBackend.validate"""
+    actual_obj = schema.validate(obj, lazy=True)
+    actual_obj.index = expected_obj.index
+    actual_obj.numbers = actual_obj.numbers.astype(expected_obj.numbers.dtype)
+
+    pd.testing.assert_frame_equal(actual_obj, expected_obj)
+
+    with pytest.raises(errors.SchemaDefinitionError):
+        schema.validate(obj, lazy=False)
+
+
+@pytest.mark.parametrize(
+    "schema, obj, expected_obj",
+    [
+        (
+            SeriesSchema(
+                int,
+                checks=[Check(lambda x: x > 3)],
+                drop_invalid_rows=True,
+            ),
+            pd.Series([9, 6, 3]),
+            pd.Series([9, 6]),
+        ),
+        (
+            SeriesSchema(str, drop_invalid_rows=True),
+            pd.Series(["nine", 6, "three"]),
+            pd.Series(["nine", "three"]),
+        ),
+    ],
+)
+def test_drop_invalid_for_series_schema(schema, obj, expected_obj):
+    """Test drop_invalid_rows works as expected on SeriesSchemaBackend.validate"""
+    actual_obj = schema.validate(obj, lazy=True).reset_index(drop=True)
+    expected_obj = expected_obj.reset_index(drop=True)
+
+    pd.testing.assert_series_equal(actual_obj, expected_obj)
+
+    with pytest.raises(errors.SchemaDefinitionError):
+        schema.validate(obj, lazy=False)
+
+
+@pytest.mark.parametrize(
+    "col, obj, expected_obj",
+    [
+        (
+            Column(str, name="letters", drop_invalid_rows=True),
+            pd.DataFrame({"letters": [None, 1, "c"]}),
+            pd.DataFrame({"letters": ["c"]}),
+        )
+    ],
+)
+def test_drop_invalid_for_column(col, obj, expected_obj):
+    """Test drop_invalid_rows works as expected on ColumnBackend.validate"""
+    actual_obj = col.validate(obj, lazy=True)
+
+    pd.testing.assert_frame_equal(
+        expected_obj.reset_index(drop=True), actual_obj.reset_index(drop=True)
+    )
+
+    with pytest.raises(errors.SchemaDefinitionError):
+        col.validate(obj, lazy=False)
+
+
+def test_drop_invalid_for_model_schema():
+    """Test drop_invalid_rows works as expected on DataFrameModel.validate"""
+
+    class MySchema(DataFrameModel):
+        """Schema for the test"""
+
+        counter: int = Field(in_range={"min_value": 3, "max_value": 5})
+
+        class Config:
+            """Config for the schema model for the test"""
+
+            drop_invalid_rows = True
+
+    expected_obj = pd.DataFrame({"counter": [3, 4, 5]})
+
+    actual_obj = MySchema.validate(
+        pd.DataFrame({"counter": [1, 2, 3, 4, 5, 6]}), lazy=True
+    )
+
+    actual_obj.index = expected_obj.index
+    pd.testing.assert_frame_equal(expected_obj, actual_obj)
+
+    with pytest.raises(errors.SchemaDefinitionError):
+        MySchema.validate(actual_obj, lazy=False)
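+
+
+def test_drop_invalid_rows_is_passed_through_to_schema():
+    """Sketch: ``Config.drop_invalid_rows`` should propagate to the
+    ``DataFrameSchema`` returned by ``DataFrameModel.to_schema()``"""
+
+    class MySchema(DataFrameModel):
+        """Schema for the test"""
+
+        counter: int = Field(in_range={"min_value": 3, "max_value": 5})
+
+        class Config:
+            """Config for the schema model for the test"""
+
+            drop_invalid_rows = True
+
+    assert MySchema.to_schema().drop_invalid_rows is True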