Skip to content

Commit

Permalink
Fix failure to create error report when filter_features is an empty list
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed May 12, 2022
1 parent c53ba4d commit fcc44f0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion erroranalysis/erroranalysis/analyzer/error_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def create_error_report(self,
num_leaves=num_leaves,
min_child_samples=min_child_samples)
matrix = None
if filter_features is not None:
if filter_features:
matrix = self.compute_matrix(filter_features,
None,
None)
Expand Down
15 changes: 11 additions & 4 deletions erroranalysis/tests/test_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def test_error_report_housing(self):
run_error_analyzer(model, X_test, y_test, feature_names,
categorical_features)

def test_error_report_housing_pandas(self):
@pytest.mark.parametrize('filter_features', [None, [], ['MedInc', 'HouseAge']])
def test_error_report_housing_pandas(self, filter_features):
X_train, X_test, y_train, y_test, feature_names = \
create_housing_data()
X_train = create_dataframe(X_train, feature_names)
Expand All @@ -61,7 +62,7 @@ def test_error_report_housing_pandas(self):
for model in models:
categorical_features = []
run_error_analyzer(model, X_test, y_test, feature_names,
categorical_features)
categorical_features, filter_features=filter_features)


def is_valid_uuid(id):
Expand All @@ -73,7 +74,8 @@ def is_valid_uuid(id):


def run_error_analyzer(model, X_test, y_test, feature_names,
categorical_features, expect_user_warnings=False):
categorical_features, expect_user_warnings=False,
filter_features=None):
if expect_user_warnings and pd.__version__[0] == '0':
with pytest.warns(UserWarning,
match='which has issues with pandas version'):
Expand All @@ -84,7 +86,7 @@ def run_error_analyzer(model, X_test, y_test, feature_names,
model_analyzer = ModelAnalyzer(model, X_test, y_test,
feature_names,
categorical_features)
report1 = model_analyzer.create_error_report(filter_features=None,
report1 = model_analyzer.create_error_report(filter_features=filter_features,
max_depth=3,
num_leaves=None,
compute_importances=True)
Expand All @@ -109,6 +111,11 @@ def run_error_analyzer(model, X_test, y_test, feature_names,
assert ea_deserialized.importances == report1.importances
assert ea_deserialized.root_stats == report1.root_stats

if not filter_features:
assert ea_deserialized.matrix == None
else:
assert ea_deserialized.matrix != None

# validate error report does not modify original dataset in ModelAnalyzer
if isinstance(X_test, pd.DataFrame):
assert X_test.equals(model_analyzer.dataset)
Expand Down

0 comments on commit fcc44f0

Please sign in to comment.