diff --git a/dowhy/causal_identifier/complete_adjustment.py b/dowhy/causal_identifier/complete_adjustment.py new file mode 100644 index 0000000000..f2b5deec05 --- /dev/null +++ b/dowhy/causal_identifier/complete_adjustment.py @@ -0,0 +1,68 @@ +import networkx as nx + +import pywhy_graphs + +def adjustable(self, G, X, Y, Z=None): + + if Z is None: + Z = set() + + #check amenability + if not self._is_amenable(): + return False + + #check if z contains any node from the forbidden set + + if not self._check_forbidden_set(): + return False + + #find the proper back-door graph + proper_back_door_graph = self._proper_backdoor_graph() + + #check if z m-seperates x and y in Gpbd + if not pywhy_graphs.m_seperated(proper_back_door_graph, X, Y, Z): + return False + + return True + +def _is_amenable(G, X, Y): + dp = G.directed_paths(G, X, Y) + pdp = pywhy_graphs.possibly_directed_paths(G, dp) + ppdp = pywhy_graphs.proper_paths(G, pdp) + visible_edges = frozenset(pywhy_graphs.get_visible_edges(G, X)) + for elem in ppdp: + first_edge = elem[0] + if first_edge in visible_edges and first_edge[0] in X: + continue + else: + return False + return True + +def _check_forbidden_set(G,X,Y,Z): + + if Z is None: + Z = set() + + forbidden_set = pywhy_graphs.find_forbidden_set(G, X, Y) + if len(Z.intersection(forbidden_set)) > 0: + return False + else: + return True + +def _proper_backdoor_graph(G,X,Y): + ppdp = pywhy_graphs.proper_possibly_directed_path(G, X, Y) + visible_edges = pywhy_graphs.get_visible_edges(G) # assuming all are directed edges + x_vedges = [] + for elem in visible_edges: + if elem[0] in X: + x_vedges.append(elem) + x_vedges = frozenset(x_vedges) + all_edges = [] + for elem in ppdp: + all_edges.extend(elem) + all_edges = frozenset(all_edges) + to_remove = all_edges.intersection(x_vedges) + G = G.copy() + for elem in to_remove: + G.remove_edge(elem[0], elem[1], G.directed_edge_name) + return G \ No newline at end of file diff --git a/tests/causal_identifiers/test_complete_adjustment.py b/tests/causal_identifiers/test_complete_adjustment.py new file mode 100644 index 0000000000..b9240cd828 --- /dev/null +++ b/tests/causal_identifiers/test_complete_adjustment.py @@ -0,0 +1,143 @@ +import pytest + +from pywhy_graphs import MAG, PAG + +from dowhy.causal_identifier.complete_adjustment import CompleteAdjustment + +def test_complete_adjsutment(): + + # DAGs + + G = MAG() + G.add_edge("Z", "X", G.directed_edge_name) + G.add_edge("Z", "Y", G.directed_edge_name) + G.add_edge("X", "Y", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"X"}, {"Y"}) + + assert cad.adjustable() + + + G = MAG() + G.add_edge("X", "Z", G.directed_edge_name) + G.add_edge("X", "Y", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"X"}, {"Y"}) + + assert cad.adjustable() + + + G = MAG() + G.add_edge("X", "Z", G.directed_edge_name) + G.add_edge("Z", "Y", G.directed_edge_name) + G.add_edge("U", "X", G.directed_edge_name) + G.add_edge("U", "Y", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"X"}, {"Y"}) + + assert not cad.adjustable() + + + # CPDAGs + + G = PAG() + G.add_edge("I", "X", G.directed_edge_name) + G.add_edge("Z", "X", G.directed_edge_name) + G.add_edge("A", "X", G.directed_edge_name) + G.add_edge("X", "Y", G.directed_edge_name) + G.add_edge("Z", "Y", G.directed_edge_name) + G.add_edge("B", "Y", G.directed_edge_name) + G.add_edge("B", "Z", G.circle_edge_name) + G.add_edge("Z", "B", G.circle_edge_name) + G.add_edge("A", "B", G.circle_edge_name) + G.add_edge("B", "A", G.circle_edge_name) + G.add_edge("A", "Z", G.circle_edge_name) + G.add_edge("Z", "A", G.circle_edge_name) + G.add_edge("A", "I", G.circle_edge_name) + G.add_edge("I", "A", G.circle_edge_name) + + cad = CompleteAdjustment(G, {"X"},{"Y"}) + + assert cad.adjustable() + + # MAG + + G = MAG() + G.add_edge("A", "B", G.directed_edge_name) + G.add_edge("B", "C", G.directed_edge_name) + G.add_edge("C", "D", G.directed_edge_name) + G.add_edge("D", "E", G.directed_edge_name) + G.add_edge("A", "E", G.directed_edge_name) + G.add_edge("F", "C", G.directed_edge_name) + G.add_edge("F", "E", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"A", "D"}, {"E","F"}) + + assert cad.adjustable() + + G = MAG() + G.add_edge("A", "B", G.directed_edge_name) + G.add_edge("B", "C", G.directed_edge_name) + G.add_edge("C", "D", G.directed_edge_name) + G.add_edge("D", "E", G.directed_edge_name) + G.add_edge("A", "E", G.directed_edge_name) + G.add_edge("F", "C", G.directed_edge_name) + G.add_edge("F", "E", G.directed_edge_name) + G.add_edge("A", "F", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"A", "D"}, {"E","F"} ) + + assert not cad.adjustable() + + + G = MAG() + G.add_edge("A", "B", G.directed_edge_name) + G.add_edge("B", "C", G.directed_edge_name) + G.add_edge("C", "D", G.directed_edge_name) + G.add_edge("D", "E", G.directed_edge_name) + G.add_edge("A", "F", G.directed_edge_name) + G.add_edge("F", "E", G.directed_edge_name) + G.add_edge("G", "F", G.directed_edge_name) + G.add_edge("G", "C", G.directed_edge_name) + G.add_edge("H", "A", G.directed_edge_name) + G.add_edge("I", "A", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"A", "D"}, {"E"} ) + + assert cad.adjustable() + + G = MAG() + G.add_edge("B", "A", G.directed_edge_name) + G.add_edge("C", "B", G.directed_edge_name) + G.add_edge("C", "D", G.directed_edge_name) + G.add_edge("E", "D", G.directed_edge_name) + G.add_edge("E", "F", G.directed_edge_name) + G.add_edge("F", "A", G.directed_edge_name) + G.add_edge("A", "D", G.directed_edge_name) + + cad = CompleteAdjustment(G, {"A", "C"}, {"D"}) + + assert not cad.adjustable() + + # PAG + + G = PAG() + G.add_edge("A", "B", G.directed_edge_name) + G.add_edge("B", "C", G.directed_edge_name) + G.add_edge("C", "D", G.directed_edge_name) + G.add_edge("D", "E", G.directed_edge_name) + G.add_edge("A", "F", G.directed_edge_name) + G.add_edge("F", "E", G.directed_edge_name) + G.add_edge("F", "C", G.bidirected_edge_name) + G.add_edge("H", "A", G.directed_edge_name) + G.add_edge("I", "A", G.directed_edge_name) + G.add_edge("A", "H", G.circle_edge_name) + G.add_edge("A", "I", G.circle_edge_name) + + cad = CompleteAdjustment(G, {"A", "D"}, {"E"} ) + + assert cad.adjustable() + + + +