diff --git a/mermaid.go b/mermaid.go index fc5e163..ea8940f 100644 --- a/mermaid.go +++ b/mermaid.go @@ -63,7 +63,7 @@ func diagram(g *Graph, diagramType string, orientation int) string { diagramGraph(g, sb) for _, id := range g.sortedSubgraphsKeys() { each := g.subgraphs[id] - fmt.Fprintf(sb, "subgraph %s;\n", id) + fmt.Fprintf(sb, "subgraph %s [%s];\n", id, each.attributes["label"]) diagramGraph(each, sb) fmt.Fprintln(sb, "end;") } diff --git a/mermaid_test.go b/mermaid_test.go index 7e4f1a0..16e5b7c 100644 --- a/mermaid_test.go +++ b/mermaid_test.go @@ -88,12 +88,12 @@ func TestMermaidSubgraph(t *testing.T) { sub1.Node("a1").Edge(sub1.Node("a2")) sub2 := di.Subgraph("two") sub2.Node("b1").Edge(sub2.Node("b2")) - sub3 := di.Subgraph("three") + sub3 := di.Subgraph("THREE").Label("three") sub3.Node("c1").Edge(sub3.Node("c2")) sub3.Node("c1").Edge(sub1.Node("a2")) mf := MermaidFlowchart(di, MermaidLeftToRight) - if got, want := flatten(mf), `flowchart LR;n8-->n3;subgraph one;n2("a1");n3("a2");n2-->n3;end;subgraph three;n8("c1");n9("c2");n8-->n9;end;subgraph two;n5("b1");n6("b2");n5-->n6;end;`; got != want { + if got, want := flatten(mf), `flowchart LR;n8-->n3;subgraph THREE [three];n8("c1");n9("c2");n8-->n9;end;subgraph one [one];n2("a1");n3("a2");n2-->n3;end;subgraph two [two];n5("b1");n6("b2");n5-->n6;end;`; got != want { t.Errorf("got [%[1]v:%[1]T] want [%[2]v:%[2]T]", got, want) } }