diff --git a/mermaid.go b/mermaid.go index 9b2d3a6..df4fa07 100644 --- a/mermaid.go +++ b/mermaid.go @@ -80,11 +80,28 @@ func diagramGraph(g *Graph, sb *strings.Builder) { nodeShape := MermaidShapeRound each := g.nodes[key] if s := each.GetAttr("shape"); s != nil { - nodeShape = s.(shape) + // could be a shape or a string + shapeString, ok := s.(string) + if ok { + // see if we can map the string to a shape + mermaidShape, ok := lookupShape(shapeString) + if ok { + nodeShape = mermaidShape + } + } + // could be a shape + mermaidShape, ok := s.(shape) + if ok { + nodeShape = mermaidShape + } } txt := "?" if label := each.GetAttr("label"); label != nil { - txt = label.(string) + // take string only + slabel, ok := label.(string) + if ok { + txt = slabel + } } fmt.Fprintf(sb, "\tn%d%s%s%s;\n", each.seq, nodeShape.open, escape(txt), nodeShape.close) if style := each.GetAttr("style"); style != nil { @@ -103,11 +120,16 @@ func diagramGraph(g *Graph, sb *strings.Builder) { // The edge can override the link style link := denoteEdge if l := each.GetAttr("link"); l != nil { - link = l.(string) + // take string only + slink, ok := l.(string) + if ok { + link = slink + } } if label := each.GetAttr("label"); label != nil { slabel, ok := label.(string) if !ok { + // make it a string slabel = fmt.Sprintf("%v", label) } if label != "" { @@ -124,3 +146,33 @@ func diagramGraph(g *Graph, sb *strings.Builder) { func writeEnd(sb *strings.Builder) { sb.WriteString(";\n") } + +func lookupShape(shapeName string) (shape, bool) { + switch shapeName { + case "round", "box": + return MermaidShapeRound, true + case "asymmetric": + return MermaidShapeAsymmetric, true + case "circle": + return MermaidShapeCircle, true + case "cylinder": + return MermaidShapeCylinder, true + case "rhombux": + return MermaidShapeRhombus, true + case "stadium": + return MermaidShapeStadium, true + case "subroutine": + return MermaidShapeSubroutine, true + case "trapezoid": + return MermaidShapeTrapezoid, true + case "trapezoid-alt": + return MermaidShapeTrapezoidAlt, true + case "hexagon": + return MermaidShapeHexagon, true + case "parallelogram": + return MermaidShapeParallelogram, true + case "parallelogram-alt": + return MermaidShapeParallelogramAlt, true + } + return shape{}, false +} diff --git a/mermaid_test.go b/mermaid_test.go index 16e5b7c..5fc7d6e 100644 --- a/mermaid_test.go +++ b/mermaid_test.go @@ -97,3 +97,45 @@ func TestMermaidSubgraph(t *testing.T) { t.Errorf("got [%[1]v:%[1]T] want [%[2]v:%[2]T]", got, want) } } + +func TestMermaidFromBoxShape(t *testing.T) { + graph := NewGraph(Directed) + graph.Node("A").Box() + graph.Edge(graph.Node("A"), graph.Node("B")) + + if got, want := flatten(MermaidGraph(graph, MermaidTopDown)), `graph TD;n1("A");n2("B");n1-->n2;`; got != want { + t.Errorf("got [%[1]v:%[1]T] want [%[2]v:%[2]T]", got, want) + } +} +func TestLookupShape(t *testing.T) { + tests := []struct { + name string + shapeName string + wantShape shape + wantOk bool + }{ + {"round", "round", MermaidShapeRound, true}, + {"box", "box", MermaidShapeRound, true}, + {"asymmetric", "asymmetric", MermaidShapeAsymmetric, true}, + {"circle", "circle", MermaidShapeCircle, true}, + {"cylinder", "cylinder", MermaidShapeCylinder, true}, + {"rhombux", "rhombux", MermaidShapeRhombus, true}, + {"stadium", "stadium", MermaidShapeStadium, true}, + {"subroutine", "subroutine", MermaidShapeSubroutine, true}, + {"trapezoid", "trapezoid", MermaidShapeTrapezoid, true}, + {"trapezoid-alt", "trapezoid-alt", MermaidShapeTrapezoidAlt, true}, + {"hexagon", "hexagon", MermaidShapeHexagon, true}, + {"parallelogram", "parallelogram", MermaidShapeParallelogram, true}, + {"parallelogram-alt", "parallelogram-alt", MermaidShapeParallelogramAlt, true}, + {"unknown", "unknown", shape{}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotShape, gotOk := lookupShape(tt.shapeName) + if gotShape != tt.wantShape || gotOk != tt.wantOk { + t.Errorf("lookupShape(%q) = (%v, %v), want (%v, %v)", tt.shapeName, gotShape, gotOk, tt.wantShape, tt.wantOk) + } + }) + } +}