Skip to content

Commit

Permalink
Add excludes cohort filter test in erroranalysis (#1171)
Browse files Browse the repository at this point in the history
* Add excludes cohort filter test in erroranalysis

Signed-off-by: Gaurav Gupta <[email protected]>

* Remove old references

Signed-off-by: Gaurav Gupta <[email protected]>

* Fix sorted imports

Signed-off-by: Gaurav Gupta <[email protected]>
  • Loading branch information
gaugup authored Jan 26, 2022
1 parent 60e4c66 commit efcfe07
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 30 deletions.
31 changes: 13 additions & 18 deletions erroranalysis/erroranalysis/_internal/cohort_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@
import numpy as np
import pandas as pd

from erroranalysis._internal.constants import (METHOD, METHOD_EXCLUDES,
METHOD_INCLUDES, PRED_Y,
ROW_INDEX, TRUE_Y, ModelTask)
from erroranalysis._internal.constants import (METHOD, PRED_Y, ROW_INDEX,
TRUE_Y, CohortFilterMethods,
ModelTask)
from erroranalysis._internal.metrics import get_ordered_classes

COLUMN = 'column'
METHOD_EQUAL = 'equal'
METHOD_GREATER = 'greater'
METHOD_LESS = 'less'
METHOD_LESS_AND_EQUAL = 'less and equal'
METHOD_GREATER_AND_EQUAL = 'greater and equal'
METHOD_RANGE = 'in the range of'
MODEL = 'model'
CLASSIFICATION_OUTCOME = 'Classification outcome'

Expand Down Expand Up @@ -176,24 +170,25 @@ def build_query(filters, categorical_features, categories):
method = filter[METHOD]
arg0 = str(filter['arg'][0])
colname = filter[COLUMN]
if method == METHOD_GREATER:
if method == CohortFilterMethods.METHOD_GREATER:
queries.append("`" + colname + "` > " + arg0)
elif method == METHOD_LESS:
elif method == CohortFilterMethods.METHOD_LESS:
queries.append("`" + colname + "` < " + arg0)
elif method == METHOD_LESS_AND_EQUAL:
elif method == CohortFilterMethods.METHOD_LESS_AND_EQUAL:
queries.append("`" + colname + "` <= " + arg0)
elif method == METHOD_GREATER_AND_EQUAL:
elif method == CohortFilterMethods.METHOD_GREATER_AND_EQUAL:
queries.append("`" + colname + "` >= " + arg0)
elif method == METHOD_RANGE:
elif method == CohortFilterMethods.METHOD_RANGE:
arg1 = str(filter['arg'][1])
queries.append("`" + colname + "` >= " + arg0 +
' & `' + colname + "` <= " + arg1)
elif method == METHOD_INCLUDES or method == METHOD_EXCLUDES:
elif method == CohortFilterMethods.METHOD_INCLUDES or \
method == CohortFilterMethods.METHOD_EXCLUDES:
query = build_bounds_query(filter, colname, method,
categorical_features,
categories)
queries.append(query)
elif method == METHOD_EQUAL:
elif method == CohortFilterMethods.METHOD_EQUAL:
is_categorical = False
if categorical_features:
is_categorical = colname in categorical_features
Expand Down Expand Up @@ -243,7 +238,7 @@ def build_bounds_query(filter, colname, method,
:rtype: str
"""
bounds = []
if method == METHOD_EXCLUDES:
if method == CohortFilterMethods.METHOD_EXCLUDES:
operator = " != "
else:
operator = " == "
Expand All @@ -260,7 +255,7 @@ def build_bounds_query(filter, colname, method,
else:
arg_val = arg
bounds.append("`{}`{}{}".format(colname, operator, arg_val))
if method == METHOD_EXCLUDES:
if method == CohortFilterMethods.METHOD_EXCLUDES:
return ' & '.join(bounds)
else:
return ' | '.join(bounds)
16 changes: 14 additions & 2 deletions erroranalysis/erroranalysis/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,20 @@
SPLIT_FEATURE = 'split_feature'
LEAF_INDEX = 'leaf_index'
METHOD = 'method'
METHOD_EXCLUDES = 'excludes'
METHOD_INCLUDES = 'includes'


class CohortFilterMethods:
    """Names of the methods a cohort filter can use.

    Each value is the string stored under a filter's 'method' key and is
    compared against when the pandas query for a cohort is built.
    """

    # Categorical membership: one `==` comparison per argument, OR-ed (|).
    METHOD_INCLUDES = 'includes'
    # Categorical exclusion: one `!=` comparison per argument, AND-ed (&).
    METHOD_EXCLUDES = 'excludes'
    # Equality filter on a single value.
    METHOD_EQUAL = 'equal'
    # Strict comparison: column > arg[0].
    METHOD_GREATER = 'greater'
    # Strict comparison: column < arg[0].
    METHOD_LESS = 'less'
    # Inclusive comparison: column <= arg[0].
    METHOD_LESS_AND_EQUAL = 'less and equal'
    # Inclusive comparison: column >= arg[0].
    METHOD_GREATER_AND_EQUAL = 'greater and equal'
    # Closed interval: column >= arg[0] and column <= arg[1].
    METHOD_RANGE = 'in the range of'


class ModelTask(str, Enum):
Expand Down
19 changes: 9 additions & 10 deletions erroranalysis/erroranalysis/_internal/surrogate_error_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@

from erroranalysis._internal.cohort_filter import filter_from_cohort
from erroranalysis._internal.constants import (DIFF, LEAF_INDEX, METHOD,
METHOD_EXCLUDES,
METHOD_INCLUDES, PRED_Y,
ROW_INDEX, SPLIT_FEATURE,
SPLIT_INDEX, TRUE_Y, Metrics,
ModelTask, error_metrics,
f1_metrics,
PRED_Y, ROW_INDEX,
SPLIT_FEATURE, SPLIT_INDEX,
TRUE_Y, CohortFilterMethods,
Metrics, ModelTask,
error_metrics, f1_metrics,
metric_to_display_name,
precision_metrics,
recall_metrics)
Expand Down Expand Up @@ -277,7 +276,7 @@ def create_categorical_arg(parent_threshold):

def create_categorical_query(method, arg, p_node_name, p_node_query,
parent, categories):
if method == METHOD_INCLUDES:
if method == CohortFilterMethods.METHOD_INCLUDES:
operation = "=="
else:
operation = "!="
Expand All @@ -296,7 +295,7 @@ def create_categorical_query(method, arg, p_node_name, p_node_query,
query = []
for argi in arg:
query.append(p_node_query + " " + operation + " " + str(argi))
if method == METHOD_INCLUDES:
if method == CohortFilterMethods.METHOD_INCLUDES:
query = " | ".join(query)
else:
query = " & ".join(query)
Expand Down Expand Up @@ -333,7 +332,7 @@ def node_to_dict(df, tree, nodeid, categories, json,
parent_threshold)
df = df[df[p_node_name_val] <= parent_threshold]
elif parent_decision_type == '==':
method = METHOD_INCLUDES
method = CohortFilterMethods.METHOD_INCLUDES
arg = create_categorical_arg(parent_threshold)
query, condition = create_categorical_query(method,
arg,
Expand All @@ -350,7 +349,7 @@ def node_to_dict(df, tree, nodeid, categories, json,
parent_threshold)
df = df[df[p_node_name_val] > parent_threshold]
elif parent_decision_type == '==':
method = METHOD_EXCLUDES
method = CohortFilterMethods.METHOD_EXCLUDES
arg = create_categorical_arg(parent_threshold)
query, condition = create_categorical_query(method,
arg,
Expand Down
24 changes: 24 additions & 0 deletions erroranalysis/tests/test_cohort_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@ def test_cohort_filter_includes(self):
model_task,
filters=filters)

def test_cohort_filter_excludes(self):
    """Check that an 'excludes' filter on EMBARKED selects only 'Q' rows.

    The expected cohort is built directly from the test data and handed
    to run_error_analyzer together with the equivalent filter.
    """
    (X_train, X_test, y_train, y_test,
     numeric, categorical) = create_simple_titanic_data()
    all_features = categorical + numeric
    model = create_titanic_pipeline(X_train, y_train)
    # Excluding category indexes 0 and 2 leaves only the rows whose
    # EMBARKED value is 'Q'.
    excludes_filter = {'arg': [0, 2],
                       'column': EMBARKED,
                       'method': 'excludes'}
    expected = create_validation_data(X_test, y_test)
    is_embarked_q = X_test[EMBARKED].isin(['Q'])
    expected = expected.loc[is_embarked_q]
    run_error_analyzer(expected,
                       model,
                       X_test,
                       y_test,
                       all_features,
                       categorical,
                       ModelTask.CLASSIFICATION,
                       filters=[excludes_filter])

def test_cohort_filter_classification_outcome(self):
X_train, X_test, y_train, y_test, numeric, categorical = \
create_simple_titanic_data()
Expand Down Expand Up @@ -242,6 +265,7 @@ def run_error_analyzer(validation_data,
filtered_data = filter_from_cohort(error_analyzer,
filters,
composite_filters)

# validate there is some data selected for each of the filters
assert validation_data.shape[0] > 0
assert validation_data.equals(filtered_data)

0 comments on commit efcfe07

Please sign in to comment.