Skip to content

Commit

Permalink
Implement list() method causal manager (#1178)
Browse files Browse the repository at this point in the history
* Implement list() method in CausalManager

Signed-off-by: Gaurav Gupta <[email protected]>

* Fix sorted imports

Signed-off-by: Gaurav Gupta <[email protected]>

* Fix code review comments

Signed-off-by: Gaurav Gupta <[email protected]>
  • Loading branch information
gaugup authored Jan 27, 2022
1 parent efcfe07 commit 39a7ff8
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 3 deletions.
8 changes: 8 additions & 0 deletions responsibleai/responsibleai/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ class CounterfactualManagerKeys(object):
COUNTERFACTUALS = 'counterfactuals'


class CausalManagerKeys(object):
"""Provide constants for CausalManager key properties."""
CAUSAL_EFFECTS = 'causal_effects'
GLOBAL_EFFECTS_COMPUTED = 'global_effects_computed'
LOCAL_EFFECTS_COMPUTED = 'local_effects_computed'
POLICIES_COMPUTED = 'policies_computed'


class SKLearn(object):
"""Provide scikit-learn related constants."""

Expand Down
12 changes: 12 additions & 0 deletions responsibleai/responsibleai/_tools/causal/causal_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,15 @@ def __init__(
self.verbose = verbose
self.random_state = random_state
self.categorical_features = categorical_features

def get_config_as_dict(self):
"""Returns the dictionary representation of configuration
in the CausalConfig.
The dictionary contains the different parameters required for
computing the causal effects.
:return: The dictionary representation of the CausalConfig.
:rtype: dict
"""
return self.__dict__
23 changes: 21 additions & 2 deletions responsibleai/responsibleai/managers/causal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from econml.solutions.causal_analysis import CausalAnalysis

from responsibleai._data_validations import validate_train_test_categories
from responsibleai._internal.constants import ManagerNames
from responsibleai._internal.constants import (CausalManagerKeys,
ListProperties, ManagerNames)
from responsibleai._tools.causal.causal_config import CausalConfig
from responsibleai._tools.causal.causal_constants import (DefaultParams,
ModelTypes,
Expand Down Expand Up @@ -328,7 +329,25 @@ def get(self):
return self._results

def list(self):
pass
"""List information about the CausalManager.
:return: A dictionary of properties.
:rtype: dict
"""
props = {ListProperties.MANAGER_TYPE: self.name}
causal_props_list = []
for result in self._results:
causal_config_dict = result.config.get_config_as_dict()
causal_config_dict[CausalManagerKeys.GLOBAL_EFFECTS_COMPUTED] = \
True if result.global_effects is not None else False
causal_config_dict[CausalManagerKeys.LOCAL_EFFECTS_COMPUTED] = \
True if result.local_effects is not None else False
causal_config_dict[CausalManagerKeys.POLICIES_COMPUTED] = \
True if result.policies is not None else False
causal_props_list.append(causal_config_dict)
props[CausalManagerKeys.CAUSAL_EFFECTS] = causal_props_list

return props

def get_data(self):
"""Get causal data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def list(self):
"""List information about the CounterfactualManager.
:return: A dictionary of properties.
:rtype: Dict
:rtype: dict
"""
props = {ListProperties.MANAGER_TYPE: self.name}
counterfactual_props_list = []
Expand Down
26 changes: 26 additions & 0 deletions responsibleai/tests/causal_manager_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
CausalPolicyGains,
CausalPolicyTreeInternal,
CausalPolicyTreeLeaf)
from responsibleai._internal.constants import (CausalManagerKeys,
ListProperties, ManagerNames)
from responsibleai._tools.causal.causal_result import CausalResult
from responsibleai.exceptions import UserConfigValidationException

Expand Down Expand Up @@ -43,6 +45,8 @@ def validate_causal(rai_insights, data, target_column,
nuisance_model='automl',
upper_bound_on_cat_expansion=max_cat_expansion)
rai_insights.causal.compute()
_check_causal_properties(rai_insights.causal.list(),
expected_causal_effects=0)
return

# Add the first configuration
Expand All @@ -57,6 +61,9 @@ def validate_causal(rai_insights, data, target_column,
assert len(results) == 1
_check_causal_result(results[0])

_check_causal_properties(rai_insights.causal.list(),
expected_causal_effects=1)

results = rai_insights.causal.get_data()
assert results is not None
assert isinstance(results, list)
Expand All @@ -71,12 +78,31 @@ def validate_causal(rai_insights, data, target_column,
assert isinstance(results, list)
assert len(results) == 2

_check_causal_properties(rai_insights.causal.list(),
expected_causal_effects=2)

# Add a bad configuration
with pytest.raises(UserConfigValidationException):
rai_insights.causal.add(treatment_features,
nuisance_model='fake_model')


def _check_causal_properties(
causal_props, expected_causal_effects):
assert causal_props[ListProperties.MANAGER_TYPE] == \
ManagerNames.CAUSAL
assert causal_props[
CausalManagerKeys.CAUSAL_EFFECTS] is not None
assert len(
causal_props[CausalManagerKeys.CAUSAL_EFFECTS]) == \
expected_causal_effects

for causal_effect in causal_props[CausalManagerKeys.CAUSAL_EFFECTS]:
assert causal_effect['global_effects_computed']
assert causal_effect['local_effects_computed']
assert causal_effect['policies_computed']


def _check_causal_result(causal_result, is_serialized=False):
assert len(causal_result.id) > 0

Expand Down

0 comments on commit 39a7ff8

Please sign in to comment.