From 39a7ff80b3677eb3c95fd869efa0c686d0f681d5 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta <47334368+gaugup@users.noreply.github.com> Date: Thu, 27 Jan 2022 09:35:45 -0800 Subject: [PATCH] Implement list() method causal manager (#1178) * Implement list() method in CausalManager Signed-off-by: Gaurav Gupta * Fix sorted imports Signed-off-by: Gaurav Gupta * Fix code review comments Signed-off-by: Gaurav Gupta --- .../responsibleai/_internal/constants.py | 8 ++++++ .../_tools/causal/causal_config.py | 12 +++++++++ .../responsibleai/managers/causal_manager.py | 23 ++++++++++++++-- .../managers/counterfactual_manager.py | 2 +- .../tests/causal_manager_validator.py | 26 +++++++++++++++++++ 5 files changed, 68 insertions(+), 3 deletions(-) diff --git a/responsibleai/responsibleai/_internal/constants.py b/responsibleai/responsibleai/_internal/constants.py index dc9ab30a04..07f24076b8 100644 --- a/responsibleai/responsibleai/_internal/constants.py +++ b/responsibleai/responsibleai/_internal/constants.py @@ -52,6 +52,14 @@ class CounterfactualManagerKeys(object): COUNTERFACTUALS = 'counterfactuals' +class CausalManagerKeys(object): + """Provide constants for CausalManager key properties.""" + CAUSAL_EFFECTS = 'causal_effects' + GLOBAL_EFFECTS_COMPUTED = 'global_effects_computed' + LOCAL_EFFECTS_COMPUTED = 'local_effects_computed' + POLICIES_COMPUTED = 'policies_computed' + + class SKLearn(object): """Provide scikit-learn related constants.""" diff --git a/responsibleai/responsibleai/_tools/causal/causal_config.py b/responsibleai/responsibleai/_tools/causal/causal_config.py index e8ce41229e..d4984e9ba1 100644 --- a/responsibleai/responsibleai/_tools/causal/causal_config.py +++ b/responsibleai/responsibleai/_tools/causal/causal_config.py @@ -40,3 +40,15 @@ def __init__( self.verbose = verbose self.random_state = random_state self.categorical_features = categorical_features + + def get_config_as_dict(self): + """Returns the dictionary representation of configuration + in the CausalConfig. + + The dictionary contains the different parameters required for + computing the causal effects. + + :return: The dictionary representation of the CausalConfig. + :rtype: dict + """ + return self.__dict__ diff --git a/responsibleai/responsibleai/managers/causal_manager.py b/responsibleai/responsibleai/managers/causal_manager.py index 7fe7e90649..7324011d0c 100644 --- a/responsibleai/responsibleai/managers/causal_manager.py +++ b/responsibleai/responsibleai/managers/causal_manager.py @@ -10,7 +10,8 @@ from econml.solutions.causal_analysis import CausalAnalysis from responsibleai._data_validations import validate_train_test_categories -from responsibleai._internal.constants import ManagerNames +from responsibleai._internal.constants import (CausalManagerKeys, + ListProperties, ManagerNames) from responsibleai._tools.causal.causal_config import CausalConfig from responsibleai._tools.causal.causal_constants import (DefaultParams, ModelTypes, @@ -328,7 +329,25 @@ def get(self): return self._results def list(self): - pass + """List information about the CausalManager. + + :return: A dictionary of properties. + :rtype: dict + """ + props = {ListProperties.MANAGER_TYPE: self.name} + causal_props_list = [] + for result in self._results: + causal_config_dict = result.config.get_config_as_dict() + causal_config_dict[CausalManagerKeys.GLOBAL_EFFECTS_COMPUTED] = \ + True if result.global_effects is not None else False + causal_config_dict[CausalManagerKeys.LOCAL_EFFECTS_COMPUTED] = \ + True if result.local_effects is not None else False + causal_config_dict[CausalManagerKeys.POLICIES_COMPUTED] = \ + True if result.policies is not None else False + causal_props_list.append(causal_config_dict) + props[CausalManagerKeys.CAUSAL_EFFECTS] = causal_props_list + + return props def get_data(self): """Get causal data diff --git a/responsibleai/responsibleai/managers/counterfactual_manager.py b/responsibleai/responsibleai/managers/counterfactual_manager.py index 1d920737b8..843e13c053 100644 --- a/responsibleai/responsibleai/managers/counterfactual_manager.py +++ b/responsibleai/responsibleai/managers/counterfactual_manager.py @@ -583,7 +583,7 @@ def list(self): """List information about the CounterfactualManager. :return: A dictionary of properties. - :rtype: Dict + :rtype: dict """ props = {ListProperties.MANAGER_TYPE: self.name} counterfactual_props_list = [] diff --git a/responsibleai/tests/causal_manager_validator.py b/responsibleai/tests/causal_manager_validator.py index 7bec445c23..fc639c35a5 100644 --- a/responsibleai/tests/causal_manager_validator.py +++ b/responsibleai/tests/causal_manager_validator.py @@ -11,6 +11,8 @@ CausalPolicyGains, CausalPolicyTreeInternal, CausalPolicyTreeLeaf) +from responsibleai._internal.constants import (CausalManagerKeys, + ListProperties, ManagerNames) from responsibleai._tools.causal.causal_result import CausalResult from responsibleai.exceptions import UserConfigValidationException @@ -43,6 +45,8 @@ def validate_causal(rai_insights, data, target_column, nuisance_model='automl', upper_bound_on_cat_expansion=max_cat_expansion) rai_insights.causal.compute() + _check_causal_properties(rai_insights.causal.list(), + expected_causal_effects=0) return # Add the first configuration @@ -57,6 +61,9 @@ def validate_causal(rai_insights, data, target_column, assert len(results) == 1 _check_causal_result(results[0]) + _check_causal_properties(rai_insights.causal.list(), + expected_causal_effects=1) + results = rai_insights.causal.get_data() assert results is not None assert isinstance(results, list) @@ -71,12 +78,31 @@ def validate_causal(rai_insights, data, target_column, assert isinstance(results, list) assert len(results) == 2 + _check_causal_properties(rai_insights.causal.list(), + expected_causal_effects=2) + # Add a bad configuration with pytest.raises(UserConfigValidationException): rai_insights.causal.add(treatment_features, nuisance_model='fake_model') +def _check_causal_properties( + causal_props, expected_causal_effects): + assert causal_props[ListProperties.MANAGER_TYPE] == \ + ManagerNames.CAUSAL + assert causal_props[ + CausalManagerKeys.CAUSAL_EFFECTS] is not None + assert len( + causal_props[CausalManagerKeys.CAUSAL_EFFECTS]) == \ + expected_causal_effects + + for causal_effect in causal_props[CausalManagerKeys.CAUSAL_EFFECTS]: + assert causal_effect['global_effects_computed'] + assert causal_effect['local_effects_computed'] + assert causal_effect['policies_computed'] + + def _check_causal_result(causal_result, is_serialized=False): assert len(causal_result.id) > 0