Skip to content

Commit

Permalink
Merge pull request #10 from aurelio-labs/vittorio/visualize_callback_…
Browse files Browse the repository at this point in the history
…edits

Fixed visualize and callback state
  • Loading branch information
jamescalam authored Nov 19, 2024
2 parents bbfb89d + f27676d commit ddea784
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions graphai/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def add_edge(self, source: _Node, destination: _Node):
edge = Edge(source, destination)
self.edges.append(edge)

def add_router(self, source: _Node, router: _Node, destinations: List[_Node]):
def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
if not router.is_router:
raise TypeError("A router object must be passed to the router parameter.")
self.add_edge(source, router)
[self.add_edge(source, router) for source in sources]
for destination in destinations:
self.add_edge(router, destination)

Expand Down Expand Up @@ -102,7 +102,7 @@ async def execute(self, input):
"by setting `max_steps` when initializing the Graph object."
)
# TODO JB: may need to add end callback here to close the queue for every execution
if self.callback:
if self.callback and "callback" in state:
await self.callback.close()
del state["callback"]
return state
Expand Down Expand Up @@ -144,37 +144,43 @@ def visualize(self):
for edge in self.edges:
G.add_edge(edge.source.name, edge.destination.name)

# Compute the topological generations
generations = list(nx.topological_generations(G))
y_max = len(generations)

# Create a dictionary to store the y-coordinate for each node
y_coord = {}
for i, generation in enumerate(generations):
for node in generation:
y_coord[node] = y_max - i - 1

# Set up the layout
pos = {}
for i, generation in enumerate(generations):
x = 0
for node in generation:
pos[node] = (x, y_coord[node])
x += 1

# Center each level horizontally
for i, generation in enumerate(generations):
x_center = sum(pos[node][0] for node in generation) / len(generation)
for node in generation:
pos[node] = (pos[node][0] - x_center, pos[node][1])

# Scale the layout
max_x = max(abs(p[0]) for p in pos.values())
max_y = max(abs(p[1]) for p in pos.values())
scale = min(0.8 / max_x, 0.8 / max_y)
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}

# Draw the graph
if nx.is_directed_acyclic_graph(G):
logger.info("The graph is acyclic. Visualization will use a topological layout.")
# Use topological layout if acyclic
# Compute the topological generations
generations = list(nx.topological_generations(G))
y_max = len(generations)

# Create a dictionary to store the y-coordinate for each node
y_coord = {}
for i, generation in enumerate(generations):
for node in generation:
y_coord[node] = y_max - i - 1

# Set up the layout
pos = {}
for i, generation in enumerate(generations):
x = 0
for node in generation:
pos[node] = (x, y_coord[node])
x += 1

# Center each level horizontally
for i, generation in enumerate(generations):
x_center = sum(pos[node][0] for node in generation) / len(generation)
for node in generation:
pos[node] = (pos[node][0] - x_center, pos[node][1])

# Scale the layout
max_x = max(abs(p[0]) for p in pos.values())
max_y = max(abs(p[1]) for p in pos.values())
scale = min(0.8 / max_x, 0.8 / max_y)
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}

else:
print("Warning: The graph contains cycles. Visualization will use a spring layout.")
pos = nx.spring_layout(G, k=1, iterations=50)

plt.figure(figsize=(8, 6))
nx.draw(G, pos, with_labels=True, node_color='lightblue',
node_size=3000, font_size=8, font_weight='bold',
Expand All @@ -184,6 +190,7 @@ def visualize(self):
plt.show()



class Edge:
def __init__(self, source, destination):
self.source = source
Expand Down

0 comments on commit ddea784

Please sign in to comment.