Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect edges in adjacency_matrix_to_graph #1202

Merged
merged 2 commits into from
Aug 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions dowhy/utils/graph_operations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from queue import LifoQueue

import graphviz
rahulbshrestha marked this conversation as resolved.
Show resolved Hide resolved
import networkx as nx
import numpy as np
from networkx.algorithms.dag import is_directed_acyclic_graph
Expand Down Expand Up @@ -37,18 +38,27 @@ def adjacency_matrix_to_graph(adjacency_matrix, labels=None):
:param labels: List of labels.
:returns: Graph in DOT format.
"""

if adjacency_matrix.ndim != 2:
raise ValueError("Adjacency matrix must have a dimension of 2.")

if isinstance(adjacency_matrix, np.matrix):
adjacency_matrix = np.asarray(adjacency_matrix)

# Only consider edges have absolute edge weight > 0.01
idx = np.abs(adjacency_matrix) > 0.01
dirs = np.where(idx)
import graphviz

d = graphviz.Digraph(engine="dot")
names = labels if labels else [f"x{i}" for i in range(len(adjacency_matrix))]
for name in names:
d.node(name)
for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):
d.edge(names[from_], names[to], label=str(coef))
return d

if labels is None:
labels = [f"x{i}" for i in range(len(adjacency_matrix))]

for label in labels:
d.node(label)

for from_, to, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):
d.edge(labels[from_], labels[to], label=str(coef))


def str_to_dot(string):
Expand Down
Loading