Skip to content

Commit

Permalink
core: use friendlier names for duplicated nodes in mermaid output (#2…
Browse files Browse the repository at this point in the history
…7747)

Thank you for contributing to LangChain!

- [x] **PR title**: "core: use friendlier names for duplicated nodes in
mermaid output"

- **Description:** When generating the Mermaid visualization of a chain,
if the chain had multiple nodes of the same type, the reid function
would replace their names with the UUID node_id. This made the generated
graph difficult to understand. This change deduplicates the nodes in a
chain by appending an index to their names.
- **Issue:** None
- **Discussion:**
#27714
- **Dependencies:** None

- [ ] **Add tests and docs**:  
- Currently this functionality is not covered by unit tests, happy to
add tests if you'd like


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

# Example Code:
```python
from langchain_core.runnables import RunnablePassthrough

def fake_llm(prompt: str) -> str: # Fake LLM for the example
    return "completion"

runnable = {
    'llm1':  fake_llm,
    'llm2':  fake_llm,
} | RunnablePassthrough.assign(
    total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2'])
)

print(runnable.get_graph().draw_mermaid(with_styles=False))
```

# Before
```mermaid
graph TD;
	Parallel_llm1_llm2_Input --> 0b01139db5ed4587ad37964e3a40c0ec;
	0b01139db5ed4587ad37964e3a40c0ec --> Parallel_llm1_llm2_Output;
	Parallel_llm1_llm2_Input --> a98d4b56bd294156a651230b9293347f;
	a98d4b56bd294156a651230b9293347f --> Parallel_llm1_llm2_Output;
	Parallel_total_chars_Input --> Lambda;
	Lambda --> Parallel_total_chars_Output;
	Parallel_total_chars_Input --> Passthrough;
	Passthrough --> Parallel_total_chars_Output;
	Parallel_llm1_llm2_Output --> Parallel_total_chars_Input;
```

# After
```mermaid
graph TD;
	Parallel_llm1_llm2_Input --> fake_llm_1;
	fake_llm_1 --> Parallel_llm1_llm2_Output;
	Parallel_llm1_llm2_Input --> fake_llm_2;
	fake_llm_2 --> Parallel_llm1_llm2_Output;
	Parallel_total_chars_Input --> Lambda;
	Lambda --> Parallel_total_chars_Output;
	Parallel_total_chars_Input --> Passthrough;
	Passthrough --> Parallel_total_chars_Output;
	Parallel_llm1_llm2_Output --> Parallel_total_chars_Input;
```
  • Loading branch information
antwhite authored Oct 31, 2024
1 parent 71f590d commit e3ea365
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
17 changes: 12 additions & 5 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from collections import Counter
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -423,12 +423,19 @@ def prefixed(id: str) -> str:
def reid(self) -> Graph:
"""Return a new graph with all nodes re-identified,
using their unique, readable names where possible."""
node_labels = {node.id: node.name for node in self.nodes.values()}
node_label_counts = Counter(node_labels.values())
node_name_to_ids = defaultdict(list)
for node in self.nodes.values():
node_name_to_ids[node.name].append(node.id)

unique_labels = {
node_id: node_name if len(node_ids) == 1 else f"{node_name}_{i + 1}"
for node_name, node_ids in node_name_to_ids.items()
for i, node_id in enumerate(node_ids)
}

def _get_node_id(node_id: str) -> str:
label = node_labels[node_id]
if is_uuid(node_id) and node_label_counts[label] == 1:
label = unique_labels[node_id]
if is_uuid(node_id):
return label
else:
return node_id
Expand Down
14 changes: 14 additions & 0 deletions libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@

'''
# ---
# name: test_graph_mermaid_duplicate_nodes[mermaid]
'''
graph TD;
PromptInput --> PromptTemplate_1;
Parallel_llm1_llm2_Input --> FakeListLLM_1;
FakeListLLM_1 --> Parallel_llm1_llm2_Output;
Parallel_llm1_llm2_Input --> FakeListLLM_2;
FakeListLLM_2 --> Parallel_llm1_llm2_Output;
PromptTemplate_1 --> Parallel_llm1_llm2_Input;
PromptTemplate_2 --> PromptTemplateOutput;
Parallel_llm1_llm2_Output --> PromptTemplate_2;

'''
# ---
# name: test_graph_sequence[ascii]
'''
+-------------+
Expand Down
14 changes: 14 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,17 @@ def test_graph_mermaid_escape_node_label() -> None:
assert _escape_node_label("foo-bar") == "foo-bar"
assert _escape_node_label("foo_1") == "foo_1"
assert _escape_node_label("#foo*&!") == "_foo___"


def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["foo", "bar"])
sequence: Runnable = (
PromptTemplate.from_template("Hello, {input}")
| {
"llm1": fake_llm,
"llm2": fake_llm,
}
| PromptTemplate.from_template("{llm1} {llm2}")
)
graph = sequence.get_graph()
assert graph.draw_mermaid(with_styles=False) == snapshot(name="mermaid")

0 comments on commit e3ea365

Please sign in to comment.