forked from akelleh/causality
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
144 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
def get_directed_edges(edge: tuple, arrows: list) -> list: | ||
assert len(arrows) >= 1 | ||
if len(arrows) == 1: | ||
endVertex = arrows[0] | ||
startVertex = edge[int(not edge.index(endVertex))] | ||
return [(startVertex, endVertex)] | ||
else: | ||
return [edge, edge[::-1]] | ||
|
||
|
||
def as_digraph(graph): | ||
digraph = graph.to_directed() | ||
drop_dummy_edges(graph, digraph) | ||
return digraph | ||
|
||
def drop_dummy_edges(graph, digraph): | ||
edges_to_drop = [] | ||
for edge, edge_metadata in digraph.edges.items(): | ||
if edge not in list(graph.edges): | ||
edges_to_drop.append(edge) | ||
for edge in edges_to_drop: | ||
digraph.remove_edge(*edge) | ||
|
||
def get_edges_ICstar(digraph): | ||
edges_ICstar = {'marked':[], | ||
'undirected': [], | ||
'directed': [], | ||
} | ||
for edge, metadata in digraph.edges.items(): | ||
# marked | ||
if metadata['marked']: | ||
edges_ICstar['marked'].append(get_directed_edges(edge, metadata['arrows'])[0]) | ||
else: | ||
# undirected | ||
if len(metadata['arrows']) == 0: | ||
edges_ICstar['undirected'].append(edge) | ||
else: | ||
directed_edges = get_directed_edges(edge, metadata['arrows']) | ||
for e in directed_edges: edges_ICstar['directed'].append(e) | ||
return edges_ICstar | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,38 @@ | ||
import networkx as nx | ||
import matplotlib.pyplot as plt | ||
|
||
def _drop_dummy_edges(graph, digraph): | ||
edges_to_drop = [] | ||
for edge, edge_metadata in digraph.edges.items(): | ||
if edge not in list(graph.edges): | ||
edges_to_drop.append(edge) | ||
for edge in edges_to_drop: | ||
digraph.remove_edge(*edge) | ||
from causality.inference import get_edges_ICstar | ||
from causality.inference import as_digraph | ||
|
||
|
||
def _split_marked_edges(digraph): | ||
marked_edges = [] | ||
unmarked_edges = [] | ||
for edge, edge_metadata in digraph.edges.items(): | ||
if edge_metadata['marked']: | ||
marked_edges.append(edge) | ||
else: | ||
unmarked_edges.append(edge) | ||
return marked_edges, unmarked_edges | ||
|
||
|
||
def plot_DAG(graph, plot_attributes=None): | ||
def plot_marked_partially_directed_graph(graph, plot_attributes=None): | ||
if plot_attributes is None: | ||
plot_attributes = {} | ||
|
||
unmarked_edge_color = plot_attributes.get('unmarked_edge_color', 'black') | ||
marked_edge_color = plot_attributes.get('unmarked_edge_color', 'red') | ||
marked_edge_color = plot_attributes.get('unmarked_edge_color', 'black') | ||
arrowsize = plot_attributes.get('arrowsize', 25) | ||
|
||
digraph = graph.to_directed() | ||
|
||
_drop_dummy_edges(graph, digraph) | ||
|
||
marked_edges, unmarked_edges = _split_marked_edges(digraph) | ||
digraph = as_digraph(graph) | ||
edges_ICstar = get_edges_ICstar(digraph) | ||
|
||
pos = nx.spring_layout(digraph) | ||
nx.draw_networkx_nodes(digraph, pos) | ||
nx.draw_networkx_labels(digraph, pos) | ||
nx.draw_networkx_edges(digraph, pos, arrows=True, edgelist=unmarked_edges, | ||
|
||
# directed edges | ||
nx.draw_networkx_edges(digraph, pos, arrows=True, edgelist=edges_ICstar['directed'], | ||
edge_color=unmarked_edge_color, | ||
arrowsize=arrowsize) | ||
# undirected edges | ||
nx.draw_networkx_edges(digraph, pos, arrows=False, edgelist=edges_ICstar['undirected'], | ||
edge_color=unmarked_edge_color, | ||
arrowsize=arrowsize) | ||
nx.draw_networkx_edges(digraph, pos, arrows=True, edgelist=marked_edges, | ||
edge_color=marked_edge_color, arrowsize=arrowsize) | ||
|
||
# marked edges | ||
nx.draw_networkx_edges(digraph, pos, arrows=True, edgelist=edges_ICstar['marked'], | ||
edge_color=marked_edge_color, | ||
arrowsize=arrowsize) | ||
nx.draw_networkx_edge_labels(digraph, pos, arrows=True, edgelist=edges_ICstar['marked'], | ||
edge_labels={e: '*' for e in edges_ICstar['marked']}, arrowsize=arrowsize) | ||
plt.axis('off') | ||
return pos, digraph | ||
return pos, digraph, edges_ICstar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from causality.inference.search import IC | ||
from causality.inference.independence_tests import RobustRegressionTest | ||
|
||
from causality.inference import drop_dummy_edges | ||
from causality.inference import get_directed_edges | ||
from causality.inference import get_edges_ICstar | ||
from causality.inference import as_digraph | ||
|
||
from networkx.classes.digraph import DiGraph | ||
|
||
def _make_DAG(): | ||
# generate some toy data: | ||
np.random.seed(1) | ||
SIZE = 2000 | ||
x1 = np.random.normal(size=SIZE) | ||
x2 = x1 + np.random.normal(size=SIZE) | ||
x3 = x1 + np.random.normal(size=SIZE) | ||
x4 = x2 + x3 + np.random.normal(size=SIZE) | ||
x5 = x4 + np.random.normal(size=SIZE) | ||
|
||
X = pd.DataFrame({'x1': x1, 'x2': x2, 'x3': x3, 'x4': x4, 'x5': x5}) | ||
|
||
# define the variable types: 'c' is 'continuous'. The variables defined here | ||
# are the ones the search is performed over -- NOT all the variables defined | ||
# in the data frame. | ||
variable_types = {'x1': 'c', 'x2': 'c', 'x3': 'c', 'x4': 'c', 'x5': 'c'} | ||
|
||
# run the search | ||
ic_algorithm = IC(RobustRegressionTest) | ||
graph = ic_algorithm.search(X, variable_types) | ||
return graph | ||
|
||
|
||
def test_get_directed_edges(): | ||
edge = ('a', 'b') | ||
arrows = ['a'] | ||
directed_edges = get_directed_edges(edge, arrows) | ||
assert directed_edges == [('b', 'a')] | ||
|
||
arrows = ['b'] | ||
directed_edges = get_directed_edges(edge, arrows) | ||
assert directed_edges == [('a', 'b')] | ||
|
||
arrows = ['a','b'] | ||
directed_edges = get_directed_edges(edge, arrows) | ||
assert ('a', 'b') in directed_edges | ||
assert ('b', 'a') in directed_edges | ||
assert len(directed_edges) == 2 | ||
|
||
|
||
def test_as_digraph(): | ||
graph = _make_DAG() | ||
digraph = as_digraph(graph) | ||
for edge, metadata in graph.edges.items(): | ||
assert edge in digraph.edges | ||
assert isinstance(digraph, DiGraph) | ||
|
||
|
||
def test_drop_dummy_edges(): | ||
graph = _make_DAG() | ||
digraph = as_digraph(graph) | ||
drop_dummy_edges(graph, digraph) | ||
assert set(graph.edges) == set(digraph.edges) | ||
return digraph | ||
|
||
|
||
def test_get_edges_ICstar(): | ||
digraph = test_drop_dummy_edges() | ||
edges_ICstar = get_edges_ICstar(digraph) | ||
assert edges_ICstar['marked'] == [('x4', 'x5')] | ||
assert set(edges_ICstar['undirected']) == {('x1', 'x2'), ('x1', 'x3')} | ||
assert set(edges_ICstar['directed']) == {('x2', 'x4'), ('x3', 'x4')} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters