Skip to content

Commit

Permalink
handle node shape constant and string, issue #40
Browse files Browse the repository at this point in the history
  • Loading branch information
emicklei committed Dec 2, 2024
1 parent 04052ec commit 80a4663
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
58 changes: 55 additions & 3 deletions mermaid.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,28 @@ func diagramGraph(g *Graph, sb *strings.Builder) {
nodeShape := MermaidShapeRound
each := g.nodes[key]
if s := each.GetAttr("shape"); s != nil {
nodeShape = s.(shape)
// could be a shape or a string
shapeString, ok := s.(string)
if ok {
// see if we can map the string to a shape
mermaidShape, ok := lookupShape(shapeString)
if ok {
nodeShape = mermaidShape
}
}
// could be a shape
mermaidShape, ok := s.(shape)
if ok {
nodeShape = mermaidShape
}
}
txt := "?"
if label := each.GetAttr("label"); label != nil {
txt = label.(string)
// take string only
slabel, ok := label.(string)
if ok {
txt = slabel
}
}
fmt.Fprintf(sb, "\tn%d%s%s%s;\n", each.seq, nodeShape.open, escape(txt), nodeShape.close)
if style := each.GetAttr("style"); style != nil {
Expand All @@ -103,11 +120,16 @@ func diagramGraph(g *Graph, sb *strings.Builder) {
// The edge can override the link style
link := denoteEdge
if l := each.GetAttr("link"); l != nil {
link = l.(string)
// take string only
slink, ok := l.(string)
if ok {
link = slink
}
}
if label := each.GetAttr("label"); label != nil {
slabel, ok := label.(string)
if !ok {
// make it a string
slabel = fmt.Sprintf("%v", label)
}
if label != "" {
Expand All @@ -124,3 +146,33 @@ func diagramGraph(g *Graph, sb *strings.Builder) {
func writeEnd(sb *strings.Builder) {
sb.WriteString(";\n")
}

func lookupShape(shapeName string) (shape, bool) {
switch shapeName {
case "round", "box":
return MermaidShapeRound, true
case "asymmetric":
return MermaidShapeAsymmetric, true
case "circle":
return MermaidShapeCircle, true
case "cylinder":
return MermaidShapeCylinder, true
case "rhombux":
return MermaidShapeRhombus, true
case "stadium":
return MermaidShapeStadium, true
case "subroutine":
return MermaidShapeSubroutine, true
case "trapezoid":
return MermaidShapeTrapezoid, true
case "trapezoid-alt":
return MermaidShapeTrapezoidAlt, true
case "hexagon":
return MermaidShapeHexagon, true
case "parallelogram":
return MermaidShapeParallelogram, true
case "parallelogram-alt":
return MermaidShapeParallelogramAlt, true
}
return shape{}, false
}
42 changes: 42 additions & 0 deletions mermaid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,45 @@ func TestMermaidSubgraph(t *testing.T) {
t.Errorf("got [%[1]v:%[1]T] want [%[2]v:%[2]T]", got, want)
}
}

func TestMermaidFromBoxShape(t *testing.T) {
graph := NewGraph(Directed)
graph.Node("A").Box()
graph.Edge(graph.Node("A"), graph.Node("B"))

if got, want := flatten(MermaidGraph(graph, MermaidTopDown)), `graph TD;n1("A");n2("B");n1-->n2;`; got != want {
t.Errorf("got [%[1]v:%[1]T] want [%[2]v:%[2]T]", got, want)
}
}
func TestLookupShape(t *testing.T) {
tests := []struct {
name string
shapeName string
wantShape shape
wantOk bool
}{
{"round", "round", MermaidShapeRound, true},
{"box", "box", MermaidShapeRound, true},
{"asymmetric", "asymmetric", MermaidShapeAsymmetric, true},
{"circle", "circle", MermaidShapeCircle, true},
{"cylinder", "cylinder", MermaidShapeCylinder, true},
{"rhombux", "rhombux", MermaidShapeRhombus, true},
{"stadium", "stadium", MermaidShapeStadium, true},
{"subroutine", "subroutine", MermaidShapeSubroutine, true},
{"trapezoid", "trapezoid", MermaidShapeTrapezoid, true},
{"trapezoid-alt", "trapezoid-alt", MermaidShapeTrapezoidAlt, true},
{"hexagon", "hexagon", MermaidShapeHexagon, true},
{"parallelogram", "parallelogram", MermaidShapeParallelogram, true},
{"parallelogram-alt", "parallelogram-alt", MermaidShapeParallelogramAlt, true},
{"unknown", "unknown", shape{}, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotShape, gotOk := lookupShape(tt.shapeName)
if gotShape != tt.wantShape || gotOk != tt.wantOk {
t.Errorf("lookupShape(%q) = (%v, %v), want (%v, %v)", tt.shapeName, gotShape, gotOk, tt.wantShape, tt.wantOk)
}
})
}
}

0 comments on commit 80a4663

Please sign in to comment.