diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index fd2c8adb98b23..2cb57c4e0fec5 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from collections import Counter +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum @@ -423,12 +423,19 @@ def prefixed(id: str) -> str: def reid(self) -> Graph: """Return a new graph with all nodes re-identified, using their unique, readable names where possible.""" - node_labels = {node.id: node.name for node in self.nodes.values()} - node_label_counts = Counter(node_labels.values()) + node_name_to_ids = defaultdict(list) + for node in self.nodes.values(): + node_name_to_ids[node.name].append(node.id) + + unique_labels = { + node_id: node_name if len(node_ids) == 1 else f"{node_name}_{i + 1}" + for node_name, node_ids in node_name_to_ids.items() + for i, node_id in enumerate(node_ids) + } def _get_node_id(node_id: str) -> str: - label = node_labels[node_id] - if is_uuid(node_id) and node_label_counts[label] == 1: + label = unique_labels[node_id] + if is_uuid(node_id): return label else: return node_id diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index ba9d742b37407..208bdbac45e7f 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -26,6 +26,20 @@ ''' # --- +# name: test_graph_mermaid_duplicate_nodes[mermaid] + ''' + graph TD; + PromptInput --> PromptTemplate_1; + Parallel_llm1_llm2_Input --> FakeListLLM_1; + FakeListLLM_1 --> Parallel_llm1_llm2_Output; + Parallel_llm1_llm2_Input --> FakeListLLM_2; + FakeListLLM_2 --> Parallel_llm1_llm2_Output; + PromptTemplate_1 --> Parallel_llm1_llm2_Input; + PromptTemplate_2 --> PromptTemplateOutput; + Parallel_llm1_llm2_Output --> PromptTemplate_2; + + ''' +# --- # name: test_graph_sequence[ascii] ''' +-------------+ diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 8898b64c01fae..39f2f2871a800 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -405,3 +405,17 @@ def test_graph_mermaid_escape_node_label() -> None: assert _escape_node_label("foo-bar") == "foo-bar" assert _escape_node_label("foo_1") == "foo_1" assert _escape_node_label("#foo*&!") == "_foo___" + + +def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None: + fake_llm = FakeListLLM(responses=["foo", "bar"]) + sequence: Runnable = ( + PromptTemplate.from_template("Hello, {input}") + | { + "llm1": fake_llm, + "llm2": fake_llm, + } + | PromptTemplate.from_template("{llm1} {llm2}") + ) + graph = sequence.get_graph() + assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")