Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Implement the Generalised Complete Adjustment Criterion #1148

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions dowhy/causal_identifier/complete_adjustment.py
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import networkx as nx

import pywhy_graphs

class CompleteAdjustment:

def __init__(self, graph, x, y, z=None):
self._graph = graph
self._X = x
self._Y = y
if z is None:
self._Z = set()
else:
self._Z = z

def adjustable(self, G):
#check amenability
if not self._is_amenable():
return False

#check if z contains any node from the forbidden set

if not self._check_forbidden_set():
return False

#find the proper back-door graph
proper_back_door_graph = self._proper_backdoor_graph()

#check if z m-seperates x and y in Gpbd
if not pywhy_graphs.m_seperated(proper_back_door_graph, self._X, self._Y, self._Z):
return False

return True

def _is_amenable(self):
dp = self._graph.directed_paths(self._X, self._Y)
pdp = pywhy_graphs.possibly_directed_paths(self._graph, dp)
ppdp = pywhy_graphs.proper_paths(self._graph, pdp)
visible_edges = frozenset(pywhy_graphs.get_visible_edges(self._graph, self._X))
for elem in ppdp:
first_edge = elem[0]
if first_edge in visible_edges and first_edge[0] in self._X:
continue
else:
return False
return True

def _check_forbidden_set(self):
forbidden_set = pywhy_graphs.find_forbidden_set(self._graph, self._X, self._Y)
if len(self._Z.intersection(forbidden_set)) > 0:
return False
else:
return True

def _proper_backdoor_graph(self):
dp = self._graph.directed_paths(self._X, self._Y)
pdp = pywhy_graphs.possibly_directed_paths(self._graph, dp)
ppdp = pywhy_graphs.proper_paths(self._graph, pdp)
aryan26roy marked this conversation as resolved.
Show resolved Hide resolved
visible_edges = pywhy_graphs.get_visible_edges(self._graph) # assuming all are directed edges
x_vedges = []
for elem in visible_edges:
if elem[0] in self._X:
x_vedges.append(elem)
x_vedges = frozenset(x_vedges)
all_edges = []
for elem in ppdp:
all_edges.extend(elem)
all_edges = frozenset(all_edges)
to_remove = all_edges.intersection(x_vedges)
G = self._graph.copy()
for elem in to_remove:
G.remove_edge(elem[0], elem[1], G.directed_edge_name)
return G
143 changes: 143 additions & 0 deletions tests/causal_identifiers/test_complete_adjustment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import pytest

from pywhy_graphs import MAG, PAG

from dowhy.causal_identifier.complete_adjustment import CompleteAdjustment

def test_complete_adjsutment():

# DAGs

G = MAG()
G.add_edge("Z", "X", G.directed_edge_name)
G.add_edge("Z", "Y", G.directed_edge_name)
G.add_edge("X", "Y", G.directed_edge_name)

cad = CompleteAdjustment(G, {"X"}, {"Y"})

assert cad.adjustable()


G = MAG()
G.add_edge("X", "Z", G.directed_edge_name)
G.add_edge("X", "Y", G.directed_edge_name)

cad = CompleteAdjustment(G, {"X"}, {"Y"})

assert cad.adjustable()


G = MAG()
G.add_edge("X", "Z", G.directed_edge_name)
G.add_edge("Z", "Y", G.directed_edge_name)
G.add_edge("U", "X", G.directed_edge_name)
G.add_edge("U", "Y", G.directed_edge_name)

cad = CompleteAdjustment(G, {"X"}, {"Y"})

assert not cad.adjustable()


# CPDAGs

G = PAG()
G.add_edge("I", "X", G.directed_edge_name)
G.add_edge("Z", "X", G.directed_edge_name)
G.add_edge("A", "X", G.directed_edge_name)
G.add_edge("X", "Y", G.directed_edge_name)
G.add_edge("Z", "Y", G.directed_edge_name)
G.add_edge("B", "Y", G.directed_edge_name)
G.add_edge("B", "Z", G.circle_edge_name)
G.add_edge("Z", "B", G.circle_edge_name)
G.add_edge("A", "B", G.circle_edge_name)
G.add_edge("B", "A", G.circle_edge_name)
G.add_edge("A", "Z", G.circle_edge_name)
G.add_edge("Z", "A", G.circle_edge_name)
G.add_edge("A", "I", G.circle_edge_name)
G.add_edge("I", "A", G.circle_edge_name)

cad = CompleteAdjustment(G, {"X"},{"Y"})

assert cad.adjustable()

# MAG

G = MAG()
G.add_edge("A", "B", G.directed_edge_name)
G.add_edge("B", "C", G.directed_edge_name)
G.add_edge("C", "D", G.directed_edge_name)
G.add_edge("D", "E", G.directed_edge_name)
G.add_edge("A", "E", G.directed_edge_name)
G.add_edge("F", "C", G.directed_edge_name)
G.add_edge("F", "E", G.directed_edge_name)

cad = CompleteAdjustment(G, {"A", "D"}, {"E","F"})

assert cad.adjustable()

G = MAG()
G.add_edge("A", "B", G.directed_edge_name)
G.add_edge("B", "C", G.directed_edge_name)
G.add_edge("C", "D", G.directed_edge_name)
G.add_edge("D", "E", G.directed_edge_name)
G.add_edge("A", "E", G.directed_edge_name)
G.add_edge("F", "C", G.directed_edge_name)
G.add_edge("F", "E", G.directed_edge_name)
G.add_edge("A", "F", G.directed_edge_name)

cad = CompleteAdjustment(G, {"A", "D"}, {"E","F"} )

assert not cad.adjustable()


G = MAG()
G.add_edge("A", "B", G.directed_edge_name)
G.add_edge("B", "C", G.directed_edge_name)
G.add_edge("C", "D", G.directed_edge_name)
G.add_edge("D", "E", G.directed_edge_name)
G.add_edge("A", "F", G.directed_edge_name)
G.add_edge("F", "E", G.directed_edge_name)
G.add_edge("G", "F", G.directed_edge_name)
G.add_edge("G", "C", G.directed_edge_name)
G.add_edge("H", "A", G.directed_edge_name)
G.add_edge("I", "A", G.directed_edge_name)

cad = CompleteAdjustment(G, {"A", "D"}, {"E"} )

assert cad.adjustable()

G = MAG()
G.add_edge("B", "A", G.directed_edge_name)
G.add_edge("C", "B", G.directed_edge_name)
G.add_edge("C", "D", G.directed_edge_name)
G.add_edge("E", "D", G.directed_edge_name)
G.add_edge("E", "F", G.directed_edge_name)
G.add_edge("F", "A", G.directed_edge_name)
G.add_edge("A", "D", G.directed_edge_name)

cad = CompleteAdjustment(G, {"A", "C"}, {"D"})

assert not cad.adjustable()

# PAG

G = PAG()
G.add_edge("A", "B", G.directed_edge_name)
G.add_edge("B", "C", G.directed_edge_name)
G.add_edge("C", "D", G.directed_edge_name)
G.add_edge("D", "E", G.directed_edge_name)
G.add_edge("A", "F", G.directed_edge_name)
G.add_edge("F", "E", G.directed_edge_name)
G.add_edge("F", "C", G.bidirected_edge_name)
G.add_edge("H", "A", G.directed_edge_name)
G.add_edge("I", "A", G.directed_edge_name)
G.add_edge("A", "H", G.circle_edge_name)
G.add_edge("A", "I", G.circle_edge_name)

cad = CompleteAdjustment(G, {"A", "D"}, {"E"} )

assert cad.adjustable()




Loading