Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement list() method causal manager #1178

Merged
merged 3 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions responsibleai/responsibleai/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class CounterfactualManagerKeys(object):
COUNTERFACTUALS = 'counterfactuals'


class CausalManagerKeys(object):
"""Provide constants for CausalManager key properties."""
CAUSAL_EFFECTS = 'causal_effects'


class SKLearn(object):
"""Provide scikit-learn related constants."""

Expand Down
12 changes: 12 additions & 0 deletions responsibleai/responsibleai/_tools/causal/causal_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,15 @@ def __init__(
self.verbose = verbose
self.random_state = random_state
self.categorical_features = categorical_features

def get_config_as_dict(self):
"""Returns the dictionary representation of configuration
in the CausalConfig.

The dictionary contains the different parameters required for
computing the causal effects.

:return: The dictionary representation of the CausalConfig.
:rtype: dict
"""
return self.__dict__
23 changes: 21 additions & 2 deletions responsibleai/responsibleai/managers/causal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from econml.solutions.causal_analysis import CausalAnalysis

from responsibleai._data_validations import validate_train_test_categories
from responsibleai._internal.constants import ManagerNames
from responsibleai._internal.constants import (CausalManagerKeys,
ListProperties, ManagerNames)
from responsibleai._tools.causal.causal_config import CausalConfig
from responsibleai._tools.causal.causal_constants import (DefaultParams,
ModelTypes,
Expand Down Expand Up @@ -328,7 +329,25 @@ def get(self):
return self._results

def list(self):
pass
"""List information about the CausalManager.

:return: A dictionary of properties.
:rtype: Dict
gaugup marked this conversation as resolved.
Show resolved Hide resolved
"""
props = {ListProperties.MANAGER_TYPE: self.name}
causal_props_list = []
for result in self._results:
causal_config_dict = result.config.get_config_as_dict()
causal_config_dict['global_effects_computed'] = \
gaugup marked this conversation as resolved.
Show resolved Hide resolved
True if result.global_effects is not None else False
causal_config_dict['local_effects_computed'] = \
True if result.local_effects is not None else False
causal_config_dict['policies_computed'] = \
True if result.policies is not None else False
causal_props_list.append(causal_config_dict)
props[CausalManagerKeys.CAUSAL_EFFECTS] = causal_props_list

return props

def get_data(self):
"""Get causal data
Expand Down
24 changes: 24 additions & 0 deletions responsibleai/tests/causal_manager_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
CausalPolicyGains,
CausalPolicyTreeInternal,
CausalPolicyTreeLeaf)
from responsibleai._internal.constants import (CausalManagerKeys,
ListProperties, ManagerNames)
from responsibleai._tools.causal.causal_result import CausalResult
from responsibleai.exceptions import UserConfigValidationException

Expand Down Expand Up @@ -57,6 +59,9 @@ def validate_causal(rai_insights, data, target_column,
assert len(results) == 1
_check_causal_result(results[0])

_check_causal_properties(rai_insights.causal.list(),
expected_causal_effects=1)
gaugup marked this conversation as resolved.
Show resolved Hide resolved

results = rai_insights.causal.get_data()
assert results is not None
assert isinstance(results, list)
Expand All @@ -71,12 +76,31 @@ def validate_causal(rai_insights, data, target_column,
assert isinstance(results, list)
assert len(results) == 2

_check_causal_properties(rai_insights.causal.list(),
expected_causal_effects=2)

# Add a bad configuration
with pytest.raises(UserConfigValidationException):
rai_insights.causal.add(treatment_features,
nuisance_model='fake_model')


def _check_causal_properties(
causal_props, expected_causal_effects):
assert causal_props[ListProperties.MANAGER_TYPE] == \
ManagerNames.CAUSAL
assert causal_props[
CausalManagerKeys.CAUSAL_EFFECTS] is not None
assert len(
causal_props[CausalManagerKeys.CAUSAL_EFFECTS]) == \
expected_causal_effects

for causal_effect in causal_props[CausalManagerKeys.CAUSAL_EFFECTS]:
assert causal_effect['global_effects_computed']
assert causal_effect['local_effects_computed']
assert causal_effect['policies_computed']


def _check_causal_result(causal_result, is_serialized=False):
assert len(causal_result.id) > 0

Expand Down