fix(warnings): sort warnings by priority (#11)
Core changes
- (QualityEngine) warnings is now a list
- (QualityEngine) updated store_warnings, fixed report methods
- (DataQuality) replicated QualityEngine updates
jfsantos-ds authored Sep 1, 2021
1 parent e182e3f commit ffac9f2
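The sorting relies on QualityWarning objects being hashable and orderable by priority, so that a single sorted(list(set(...))) call, as used in the new __clean_warnings helpers below, both deduplicates the accumulated warnings and ranks them. A minimal, self-contained sketch of that idea follows; the Priority values and the WarningStub class are illustrative stand-ins, not the library's actual definitions.

from dataclasses import dataclass
from enum import IntEnum


class Priority(IntEnum):
    # Illustrative levels: lower value means more severe, so an ascending sort puts it first.
    P1 = 1
    P2 = 2
    P3 = 3


@dataclass(frozen=True, order=True)
class WarningStub:
    # 'priority' is the first field, so the generated ordering compares it first.
    priority: Priority
    test: str
    description: str


warnings = [
    WarningStub(Priority.P2, 'Exact Duplicates', 'Found 10 duplicated rows.'),
    WarningStub(Priority.P1, 'Duplicate Columns', 'Found 1 duplicated column.'),
    WarningStub(Priority.P2, 'Exact Duplicates', 'Found 10 duplicated rows.'),  # exact duplicate entry
]

# Mirrors __clean_warnings: set() drops the duplicate entry (frozen dataclasses are hashable)
# and sorted() orders the remainder by priority, most severe first.
warnings = sorted(list(set(warnings)))
for warn in warnings:
    print(warn)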
Showing 15 changed files with 1,533 additions and 1,639 deletions.
41 changes: 23 additions & 18 deletions src/ydata_quality/core/data_quality.py
@@ -1,16 +1,18 @@
"""
Implementation of main class for Data Quality checks.
"""
from typing import List, Union, Optional, Callable
from collections import Counter
from typing import Callable, List, Optional, Union

import pandas as pd

from ydata_quality.core.warnings import QualityWarning, Priority
from ydata_quality.core.warnings import Priority, QualityWarning
from ydata_quality.drift import DriftAnalyser
from ydata_quality.duplicates import DuplicateChecker
from ydata_quality.labelling import LabelInspector
from ydata_quality.missings import MissingsProfiler
from ydata_quality.valued_missing_values import VMVIdentifier
from ydata_quality.drift import DriftAnalyser


class DataQuality:
"DataQuality contains the multiple data quality engines."
@@ -41,7 +43,7 @@ def __init__(self,
model: [DRIFT ANALYSIS] model wrapped by ModelWrapper used to test concept drift.
"""
self.df = df
self._warnings = set()
self._warnings = list()
self._engines = { # Default list of engines
'duplicates': DuplicateChecker(df=df, entities=entities),
'missings': MissingsProfiler(df=df, target=label),
@@ -55,22 +57,20 @@ def __init__(self,
else:
print('Label is not defined. Skipping LABELLING engine.')


@property
def warnings(self):
"Set of warnings generated by individual QualityEngines."
return self._warnings
def __clean_warnings(self):
"""Deduplicates and sorts the list of warnings."""
self._warnings = sorted(list(set(self._warnings))) # Sort unique warnings by priority

def get_warnings(self,
category: Optional[str] = None,
test: Optional[str] = None,
priority: Optional[Priority] = None) -> List[QualityWarning]:
"Retrieves warnings filtered by their properties."
filtered = list(self.warnings) # convert original set
filtered = [w for w in filtered if w.category == category] if category else filtered
self.__store_warnings()
self.__clean_warnings()
filtered = [w for w in self._warnings if w.category == category] if category else self._warnings
filtered = [w for w in filtered if w.test == test] if test else filtered
filtered = [w for w in filtered if w.priority == Priority(priority)] if priority else filtered
filtered.sort() # sort by priority
return filtered

@property
@@ -81,18 +81,23 @@ def engines(self):
def __store_warnings(self):
"Appends all warnings from individiual engines into warnings of DataQuality main class."
for engine in self.engines.values():
self._warnings = self._warnings.union(set(engine.get_warnings()))
self._warnings += engine.get_warnings()

def evaluate(self):
"Runs all the individual data quality checks and aggregates the results."
results = {name: engine.evaluate() for name, engine in self.engines.items()}
self.__store_warnings()
return results

def report(self):
"Prints a report containing all the warnings detected during the data quality analysis."
# TODO: Provide a count of warnings by priority
self.__store_warnings() # fetch all warnings from the engines
for warn in self.get_warnings():
print(warn)

self.__clean_warnings()
if not self._warnings:
print('No warnings found.')
else:
prio_counts = Counter([warn.priority.value for warn in self._warnings])
print('Warnings count by priority:')
print(*(f"\tPriority {prio}: {count} warning(s)" for prio, count in prio_counts.items()), sep='\n')
print(f'\tTOTAL: {len(self._warnings)} warning(s)')
print('List of warnings sorted by priority:')
print(*(f"\t{warn}" for warn in self._warnings), sep='\n')
49 changes: 27 additions & 22 deletions src/ydata_quality/core/engine.py
@@ -2,18 +2,21 @@
Implementation of abstract class for Data Quality engines.
"""
from abc import ABC
from collections import Counter
from typing import Optional

import pandas as pd
from ydata_quality.core.warnings import QualityWarning, Priority

from ydata_quality.core.warnings import Priority, QualityWarning
from ydata_quality.utils.modelling import infer_dtypes


class QualityEngine(ABC):
"Main class for running and storing data quality analysis."

def __init__(self, df: pd.DataFrame, label: str = None, dtypes: dict = None):
self._df = df
self._warnings = set()
self._warnings = list()
self._tests = []
self._label = label
self._dtypes = dtypes
@@ -23,12 +26,6 @@ def df(self):
"Target of data quality checks."
return self._df

@property
def warnings(self):
"Storage of all detected data quality warnings."
return self._warnings


@property
def label(self):
"Property that returns the label under inspection."
@@ -52,11 +49,10 @@ def dtypes(self):
def dtypes(self, dtypes: dict):
if not isinstance(dtypes, dict):
raise ValueError("Property 'dtypes' should be a dictionary.")
assert all(col in self.df.columns for col in dtypes), "All dtypes keys \
must be columns in the dataset."
assert all(col in self.df.columns for col in dtypes), "All dtypes keys must be columns in the dataset."
supported_dtypes = ['numerical', 'categorical']
assert all(dtype in supported_dtypes for dtype in dtypes.values()), "Assigned dtypes\
must be in the supported broad dtype list: {}.".format(supported_dtypes)
assert all(dtype in supported_dtypes for dtype in dtypes.values()), "Assigned dtypes must be in the supported \
broad dtype list: {}.".format(supported_dtypes)
df_col_set = set(self.df.columns)
dtypes_col_set = set(dtypes.keys())
missing_cols = df_col_set.difference(dtypes_col_set)
@@ -66,21 +62,24 @@ def dtypes(self, dtypes: dict):
dtypes[col] = dtype
self._dtypes = dtypes

def __clean_warnings(self):
"""Deduplicates and sorts the list of warnings."""
self._warnings = sorted(list(set(self._warnings))) # Sort unique warnings by priority

def store_warning(self, warning: QualityWarning):
"Adds a new warning to the internal 'warnings' storage."
self._warnings.add(warning)
self._warnings.append(warning)

def get_warnings(self,
category: Optional[str] = None,
test: Optional[str] = None,
priority: Optional[Priority] = None):
"Retrieves warnings filtered by their properties."
filtered = list(self.warnings) # convert original set
filtered = [w for w in filtered if w.category == category] if category else filtered
self.__clean_warnings()
filtered = [w for w in self._warnings if w.category == category] if category else self._warnings
filtered = [w for w in filtered if w.test == test] if test else filtered
filtered = [w for w in filtered if w.priority == Priority(priority)] if priority else filtered
filtered.sort() # sort by priority
return filtered
return filtered # sort by priority

@property
def tests(self):
@@ -89,14 +88,20 @@ def tests(self):

def report(self):
"Prints a report containing all the warnings detected during the data quality analysis."
# TODO: Provide a count of warnings by priority
self._warnings = set(sorted(self._warnings)) # Sort the warnings by priority
for warn in self.warnings:
print(warn)
self.__clean_warnings()
if not self._warnings:
print('No warnings found.')
else:
prio_counts = Counter([warn.priority.value for warn in self._warnings])
print('Warnings count by priority:')
print(*(f"\tPriority {prio}: {count} warning(s)" for prio, count in prio_counts.items()), sep='\n')
print(f'\tTOTAL: {len(self._warnings)} warning(s)')
print('List of warnings sorted by priority:')
print(*(f"\t{warn}" for warn in self._warnings), sep='\n')

def evaluate(self):
"Runs all the indidividual tests available within the same suite. Returns a dict of (name: results)."
self._warnings = set() # reset the warnings to avoid duplicates
self._warnings = list() # reset the warnings
results = {}
for test in self.tests:
try: # if anything fails
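The evaluate() hunk above also shows the run pattern each engine follows: reset the warning list, then execute every registered test while tolerating individual failures. Below is a tiny stand-alone sketch of that loop; the class, test names, and error handling are illustrative, not the library's exact code.

class MiniEngine:
    def __init__(self):
        self._warnings = []
        self._tests = ['check_a', 'check_b']

    def check_a(self):
        return 'ok'

    def check_b(self):
        raise ValueError('illustrative failure')

    def evaluate(self):
        self._warnings = []  # reset so repeated runs do not accumulate warnings
        results = {}
        for test in self._tests:
            try:  # keep going even if one test raises
                results[test] = getattr(self, test)()
            except Exception as exc:
                results[test] = f'{test} failed: {exc}'
        return results


print(MiniEngine().evaluate())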
2 changes: 1 addition & 1 deletion src/ydata_quality/data_expectations/engine.py
@@ -175,7 +175,7 @@ def evaluate(self, results_json_path: str, df: pd.DataFrame = None, error_tol: i
rel_error_tol (float): Defines the maximum fraction of failed expectations, overrides error_tol.
minimum_coverage (float): Minimum expected fraction of DataFrame columns covered by the expectation suite.
"""
self._warnings = set() # reset the warnings to avoid duplicates
self._warnings = list() # reset the warnings to avoid duplicates
df = df if isinstance(df, pd.DataFrame) else None
results = {}
results['Overall Assessment'] = self._overall_assessment(results_json_path, error_tol, rel_error_tol)
12 changes: 6 additions & 6 deletions src/ydata_quality/drift/engine.py
@@ -263,13 +263,13 @@ def sample_covariate_drift(self, p_thresh: float= 0.05) -> pd.DataFrame:
n_drifted_feats = sum(test_summary['Verdict']=='Drift')
n_invalid_tests = sum(test_summary['Verdict']=='Invalid test')
if n_drifted_feats>0:
self._warnings.add(
self.store_warning(
QualityWarning(
test='Sample covariate drift', category='Sampling', priority=2, data=test_summary,
description=f"""{n_drifted_feats} features accused drift in the sample test. The covariates of the test sample do not appear to be representative of the reference sample."""
))
elif n_invalid_tests>0:
self._warnings.add(
self.store_warning(
QualityWarning(
test='Sample covariate drift', category='Sampling', priority=3, data=test_summary,
description=f"""There were {n_invalid_tests} invalid tests found. This is likely due to a small test sample size. The data summary should be analyzed before considering the test conclusive."""
@@ -295,13 +295,13 @@ def sample_label_drift(self, p_thresh: float= 0.05) -> pd.Series:
index=['Statistic', 'Statistic Value', 'p-value', 'Verdict'])
test_summary['Verdict'] = 'OK' if p_val > p_thresh else ('Drift' if p_val>= 0 else 'Invalid test')
if test_summary['Verdict']=='Drift':
self._warnings.add(
self.store_warning(
QualityWarning(
test='Sample label drift', category='Sampling', priority=2, data=test_summary,
description="""The label accused drift in the sample test with a p-test of {:.4f}, which is under the threshold {:.2f}. The label of the test sample does not appear to be representative of the reference sample.""".format(p_val, p_thresh)
))
elif test_summary['Verdict']=='Invalid test':
self._warnings.add(
self.store_warning(
QualityWarning(
test='Sample label drift', category='Sampling', priority=3, data=test_summary,
description="The test was invalid. This is likely due to a small test sample size."
@@ -332,13 +332,13 @@ def sample_concept_drift(self, p_thresh: float= 0.05) -> pd.Series:
index=['Statistic', 'Statistic Value', 'p-value', 'Verdict'])
test_summary['Verdict'] = 'OK' if p_val > p_thresh else ('Drift' if p_val>= 0 else 'Invalid test')
if test_summary['Verdict']=='Drift':
self._warnings.add(
self.store_warning(
QualityWarning(
test='Concept drift', category='Sampling', priority=2, data=test_summary,
description="""There was concept drift detected with a p-test of {:.4f}, which is under the threshold {:.2f}. The model's predicted labels for the test sample do not appear to be representative of the distribution of labels predicted for the reference sample.""".format(p_val, p_thresh)
))
elif test_summary['Verdict']=='Invalid test':
self._warnings.add(
self.store_warning(
QualityWarning(
test='Concept drift', category='Sampling', priority=3, data=test_summary,
description="The test was invalid. This is likely due to a small test sample size."
6 changes: 3 additions & 3 deletions src/ydata_quality/duplicates/engine.py
@@ -51,7 +51,7 @@ def exact_duplicates(self):
"Returns a DataFrame filtered for exact duplicate records."
dups = self.__get_duplicates(self.df) # Filter for duplicate instances
if len(dups) > 0:
self._warnings.add(
self.store_warning(
QualityWarning(
test='Exact Duplicates', category='Duplicates', priority=2, data=dups,
description=f"Found {len(dups)} instances with exact duplicate feature values."
@@ -74,7 +74,7 @@ def entity_duplicates(self, entity: Optional[Union[str, List[str]]] = None):
if entity is not None: # entity is specified
dups = self.__get_entity_duplicates(self.df, entity)
if len(dups) > 0: # if we have any duplicates
self._warnings.add(
self.store_warning(
QualityWarning(
test='Entity Duplicates', category='Duplicates', priority=2, data=dups,
description=f"Found {len(dups)} duplicates after grouping by entities."
@@ -109,7 +109,7 @@ def duplicate_columns(self):
dups[col] = tgt_col # Store if they match

if len(dups) > 0:
self._warnings.add(
self.store_warning(
QualityWarning(
test='Duplicate Columns', category='Duplicates', priority=1, data=dups,
description=f"Found {len(dups)} columns with exactly the same feature values as other columns."
