Skip to content

Commit

Permalink
008 plot marked dag (#9)
Browse files Browse the repository at this point in the history
* Resolve #6

* Resolve #8
  • Loading branch information
jaryaman authored Jun 28, 2020
1 parent db48186 commit 259f9be
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 33 deletions.
41 changes: 41 additions & 0 deletions causality/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
def get_directed_edges(edge: tuple, arrows: list) -> list:
    """Expand an edge and its arrowheads into explicit ``(tail, head)`` pairs.

    Parameters
    ----------
    edge : tuple
        A pair of vertices ``(u, v)``.
    arrows : list
        The vertices of ``edge`` that carry an arrowhead. Must be non-empty.

    Returns
    -------
    list
        ``[(tail, head)]`` when exactly one arrowhead is given, where ``head``
        is that vertex and ``tail`` is the other endpoint of ``edge``.
        Otherwise both orientations, ``[edge, edge[::-1]]``.

    Raises
    ------
    ValueError
        If ``arrows`` is empty, or if the single arrowhead is not an endpoint
        of ``edge``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an empty list fall through to the
    # "both orientations" branch.
    if len(arrows) < 1:
        raise ValueError("arrows must contain at least one vertex")

    if len(arrows) == 1:
        head = arrows[0]
        # ``edge.index(head)`` is 0 or 1; negating it selects the other endpoint.
        tail = edge[int(not edge.index(head))]
        return [(tail, head)]

    return [edge, edge[::-1]]


def as_digraph(graph):
    """Return a directed copy of *graph* with the surplus edges pruned.

    The copy comes from ``graph.to_directed()`` and is then passed, together
    with the original graph, to ``drop_dummy_edges``, which removes edges in
    place before the copy is returned.
    """
    directed = graph.to_directed()
    drop_dummy_edges(graph, directed)
    return directed

def drop_dummy_edges(graph, digraph):
    """Remove from *digraph*, in place, every edge that *graph* does not list.

    Parameters
    ----------
    graph
        Object whose ``edges`` attribute iterates over the edges to keep.
    digraph
        Object whose ``edges`` attribute iterates over its current edges and
        which provides ``remove_edge(*edge)``. It is modified in place.

    Returns
    -------
    None
    """
    # Build the lookup set once instead of re-listing graph.edges per edge.
    kept = set(graph.edges)
    # Collect first, remove afterwards, so the edge view is never mutated
    # while it is being iterated.
    stale = [edge for edge in digraph.edges if edge not in kept]
    for edge in stale:
        digraph.remove_edge(*edge)

def get_edges_ICstar(digraph):
    """Sort the edges of *digraph* into marked, undirected and directed lists.

    Each edge's metadata is read for the keys ``'marked'`` and ``'arrows'``:

    * a truthy ``'marked'`` puts the first pair returned by
      ``get_directed_edges`` into ``'marked'``;
    * otherwise an empty ``'arrows'`` puts the edge, as stored, into
      ``'undirected'``;
    * otherwise every pair returned by ``get_directed_edges`` goes into
      ``'directed'``.

    Returns a dict with the keys ``'marked'``, ``'undirected'`` and
    ``'directed'``, each mapping to a list of edge tuples.
    """
    marked, undirected, directed = [], [], []
    for edge, attrs in digraph.edges.items():
        if attrs['marked']:
            marked.append(get_directed_edges(edge, attrs['arrows'])[0])
            continue
        if len(attrs['arrows']) == 0:
            undirected.append(edge)
            continue
        directed.extend(get_directed_edges(edge, attrs['arrows']))
    return {
        'marked': marked,
        'undirected': undirected,
        'directed': directed,
    }

51 changes: 21 additions & 30 deletions causality/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,38 @@
import networkx as nx
import matplotlib.pyplot as plt

def _drop_dummy_edges(graph, digraph):
    """Remove from *digraph*, in place, every edge that *graph* does not list.

    ``graph.edges`` supplies the edges to keep; ``digraph.edges`` supplies the
    current edges and ``digraph.remove_edge(*edge)`` deletes one. Returns
    ``None``.
    """
    # Build the lookup set once instead of re-listing graph.edges per edge.
    kept = set(graph.edges)
    # Collect first, remove afterwards, so the edge view is never mutated
    # while it is being iterated.
    stale = [edge for edge in digraph.edges if edge not in kept]
    for edge in stale:
        digraph.remove_edge(*edge)
from causality.inference import get_edges_ICstar
from causality.inference import as_digraph


def _split_marked_edges(digraph):
    """Partition the edges of *digraph* by the truthiness of ``'marked'``.

    Returns ``(marked_edges, unmarked_edges)``, two lists of edges in the
    order they are yielded by ``digraph.edges.items()``.
    """
    buckets = {True: [], False: []}
    for edge, attrs in digraph.edges.items():
        buckets[bool(attrs['marked'])].append(edge)
    return buckets[True], buckets[False]


def plot_marked_partially_directed_graph(graph, plot_attributes=None):
    """Draw a marked, partially directed graph on the current matplotlib axes.

    Parameters
    ----------
    graph
        Graph to plot. It is converted with ``as_digraph`` and its edges are
        classified with ``get_edges_ICstar``.
    plot_attributes : dict, optional
        Drawing options. Recognised keys:

        * ``'unmarked_edge_color'`` (default ``'black'``): colour of the
          directed and undirected edges;
        * ``'marked_edge_color'`` (default ``'black'``): colour of the
          marked edges;
        * ``'arrowsize'`` (default ``25``): passed to the edge-drawing calls.

    Returns
    -------
    tuple
        ``(pos, digraph, edges_ICstar)``: the spring-layout positions, the
        directed graph that was drawn, and the dict of classified edges.
    """
    if plot_attributes is None:
        plot_attributes = {}

    unmarked_edge_color = plot_attributes.get('unmarked_edge_color', 'black')
    # Marked edges read their own key; reading 'unmarked_edge_color' here
    # made the marked colour impossible to set independently.
    marked_edge_color = plot_attributes.get('marked_edge_color', 'black')
    arrowsize = plot_attributes.get('arrowsize', 25)

    digraph = as_digraph(graph)
    edges_ICstar = get_edges_ICstar(digraph)

    pos = nx.spring_layout(digraph)
    nx.draw_networkx_nodes(digraph, pos)
    nx.draw_networkx_labels(digraph, pos)

    # directed edges
    nx.draw_networkx_edges(digraph, pos, arrows=True,
                           edgelist=edges_ICstar['directed'],
                           edge_color=unmarked_edge_color,
                           arrowsize=arrowsize)
    # undirected edges
    nx.draw_networkx_edges(digraph, pos, arrows=False,
                           edgelist=edges_ICstar['undirected'],
                           edge_color=unmarked_edge_color,
                           arrowsize=arrowsize)
    # marked edges: drawn as arrows and additionally labelled with '*'
    nx.draw_networkx_edges(digraph, pos, arrows=True,
                           edgelist=edges_ICstar['marked'],
                           edge_color=marked_edge_color,
                           arrowsize=arrowsize)
    # Only edge_labels is passed: arrows/edgelist/arrowsize are not
    # parameters of draw_networkx_edge_labels.
    nx.draw_networkx_edge_labels(
        digraph, pos,
        edge_labels={e: '*' for e in edges_ICstar['marked']},
    )

    plt.axis('off')
    return pos, digraph, edges_ICstar
74 changes: 74 additions & 0 deletions tests/unit/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pandas as pd
from causality.inference.search import IC
from causality.inference.independence_tests import RobustRegressionTest

from causality.inference import drop_dummy_edges
from causality.inference import get_directed_edges
from causality.inference import get_edges_ICstar
from causality.inference import as_digraph

from networkx.classes.digraph import DiGraph

def _make_DAG():
    """Generate the toy five-variable data set and run the IC search on it.

    With the NumPy seed fixed at 1, ``x1`` is normal noise, ``x2`` and ``x3``
    are each ``x1`` plus noise, ``x4`` is ``x2 + x3`` plus noise and ``x5`` is
    ``x4`` plus noise. Returns whatever ``IC(RobustRegressionTest).search``
    returns for that data.
    """
    np.random.seed(1)
    n_samples = 2000

    def noise():
        # One fresh normal sample vector per call; the call order below
        # fixes the random stream, so keep it unchanged.
        return np.random.normal(size=n_samples)

    x1 = noise()
    x2 = x1 + noise()
    x3 = x1 + noise()
    x4 = x2 + x3 + noise()
    x5 = x4 + noise()

    frame = pd.DataFrame({'x1': x1, 'x2': x2, 'x3': x3, 'x4': x4, 'x5': x5})

    # 'c' means 'continuous'. Only the variables listed here take part in
    # the search, which need not be every column of the data frame.
    variable_types = {name: 'c' for name in ('x1', 'x2', 'x3', 'x4', 'x5')}

    searcher = IC(RobustRegressionTest)
    return searcher.search(frame, variable_types)


def test_get_directed_edges():
    """Check get_directed_edges for one arrowhead at either end and for two."""
    edge = ('a', 'b')

    # A single arrowhead yields one pair pointing into that vertex.
    assert get_directed_edges(edge, ['a']) == [('b', 'a')]
    assert get_directed_edges(edge, ['b']) == [('a', 'b')]

    # Arrowheads at both ends yield both orientations, in any order.
    both_ways = get_directed_edges(edge, ['a', 'b'])
    assert ('a', 'b') in both_ways
    assert ('b', 'a') in both_ways
    assert len(both_ways) == 2


def test_as_digraph():
    """as_digraph must keep every edge of the source graph and return a DiGraph."""
    source = _make_DAG()
    converted = as_digraph(source)
    assert all(edge in converted.edges for edge in source.edges)
    assert isinstance(converted, DiGraph)


def test_drop_dummy_edges():
    """drop_dummy_edges must leave the digraph with exactly the graph's edges."""
    graph = _make_DAG()
    # Start from the raw directed copy. as_digraph already calls
    # drop_dummy_edges internally, so building the digraph with it would
    # make the call under test a no-op and the assertion vacuous.
    digraph = graph.to_directed()
    drop_dummy_edges(graph, digraph)
    assert set(graph.edges) == set(digraph.edges)
    # Returned so that other code in this module can reuse the pruned digraph.
    return digraph


def test_get_edges_ICstar():
    """get_edges_ICstar must classify the toy graph's edges as expected."""
    # Build the fixture directly rather than through another test function's
    # return value: tests should not depend on each other, and pytest warns
    # about test functions that return something other than None.
    digraph = as_digraph(_make_DAG())
    edges_ICstar = get_edges_ICstar(digraph)
    assert edges_ICstar['marked'] == [('x4', 'x5')]
    assert set(edges_ICstar['undirected']) == {('x1', 'x2'), ('x1', 'x3')}
    assert set(edges_ICstar['directed']) == {('x2', 'x4'), ('x3', 'x4')}
11 changes: 8 additions & 3 deletions tests/unit/plot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from causality.plot import plot_DAG
from causality.plot import plot_marked_partially_directed_graph
from networkx.classes.digraph import DiGraph

import numpy
import pandas as pd

from causality.inference.search import IC
from causality.inference.independence_tests import RobustRegressionTest
from networkx.classes.digraph import DiGraph



def _make_DAG():
Expand All @@ -31,6 +32,10 @@ def _make_DAG():

def test_plot_DAG():
    """The plot helper must return layout positions, a DiGraph and the edge dict."""
    graph = _make_DAG()
    # The plotting function returns three values; the older two-value
    # ``plot_DAG(graph)`` call is dropped.
    pos, digraph, edges_ICstar = plot_marked_partially_directed_graph(graph)
    assert isinstance(pos, dict)
    assert isinstance(digraph, DiGraph)
    assert isinstance(edges_ICstar, dict)
    assert 'marked' in edges_ICstar
    assert 'directed' in edges_ICstar
    assert 'undirected' in edges_ICstar

0 comments on commit 259f9be

Please sign in to comment.