Skip to content

Commit

Permalink
Add a method to merge nested StaticGraphs (intel#2985)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengceng15 authored Jan 2, 2020
1 parent 42eafe6 commit 639e07a
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,60 @@ class StaticGraph[T: ClassTag](
val model = IRGraph(inputsIR, outputsIR, variables, true, inFormats, outFormats)
model.build()
}

// Merge a nested StaticGraph into a non-nested one
private[bigdl] def toSingleGraph(): StaticGraph[T] = {
if (this.isNestedGraph()) {
val graph = this.cloneModule()
val fwdExecution = graph.getSortedForwardExecutions()
val dmOutput = fwdExecution(fwdExecution.length - 1).nextNodes(0)

var i = 0
while (i < fwdExecution.length) {
if (fwdExecution(i).element.isInstanceOf[StaticGraph[T]]) {
var g = fwdExecution(i).element.asInstanceOf[StaticGraph[T]].toSingleGraph()
fwdExecution(i).element = g

for (inputIndex <- 0 until fwdExecution(i).prevNodes.length) {
val inputNode = g.inputs(inputIndex)
inputNode.element = Identity()

while (fwdExecution(i).prevNodes.length != 0) {
val preNode = fwdExecution(i).prevNodes(0)
preNode.delete(fwdExecution(i))
preNode.add(inputNode)
}
}

for (outputIndex <- 0 until g.outputs.length) {
val outputNode = g.outputs(outputIndex)
outputNode.removeNextEdges()
while (fwdExecution(i).nextNodes.length != 0) {
val nextNode = fwdExecution(i).nextNodes(0)
fwdExecution(i).delete(nextNode)
outputNode.add(nextNode)
}
}
}
i += 1
}

val resultOutputNodes = dmOutput.prevNodes
resultOutputNodes.foreach(_.delete(dmOutput))
new StaticGraph[T](Array(graph.inputs(0)), resultOutputNodes,
enableExcludeChecking = this.enableExcludeChecking)
} else {
this
}
}

private def isNestedGraph(): Boolean = {
for (i <- 0 until forwardExecution.length) {
if (forwardExecution(i).element.isInstanceOf[StaticGraph[T]]) {
return true
}
}

false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,13 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
def toGraph(startNodes: ModuleNode[T]*): Graph[T] = {
val starts = if (startNodes.isEmpty) Array(Input[T]()) else startNodes.toArray
val endNodes = this.getEndNodes(starts)
Graph(starts, endNodes)
val graph = Graph(starts, endNodes)
if (graph.isInstanceOf[StaticGraph[T]]) {
// Merge nested graphs inside to make the whole graph non-nested
graph.asInstanceOf[StaticGraph[T]].toSingleGraph()
} else {
graph
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,39 @@ class StaticGraphSpec extends FlatSpec with Matchers {
val model = Graph(Array(n1, n2), Array(n3, n4))
}
}

"Graph toSingleGraph" should "work correctly" in {
val input = Input()

val linear1 = Linear[Float](2, 3).inputs(input)

val inputg1 = Input()
val l1 = Linear[Float](3, 5).inputs(inputg1)
val inputg1nested = Input()
val l1nested = Linear[Float](5, 5).inputs(inputg1nested)
val g1nested = Graph(inputg1nested, l1nested).inputs(l1)
val g1 = Graph(inputg1, g1nested).inputs(linear1)

val inputg2 = Input()
val l2 = Linear[Float](5, 3).inputs(inputg2)
val g2 = Graph(inputg2, l2).inputs(g1)

val linear3 = Linear(3, 6).inputs(g2)
val linear4 = Linear(3, 5).inputs(g2)

val graph = Graph(input, Array(linear3, linear4)).asInstanceOf[StaticGraph[Float]]
val toSingle = graph.toSingleGraph()

val tensor = Tensor[Float](Array(3, 2)).rand()
val graphOutput = graph.forward(tensor)
val toSingleOutput = toSingle.forward(tensor)
graphOutput should equal(toSingleOutput)

val fwdExecution = toSingle.asInstanceOf[StaticGraph[Float]].getForwardExecutions()
for (i <- 0 until fwdExecution.length) {
assert(!fwdExecution(i).element.isInstanceOf[StaticGraph[Float]])
}
}
}

object ModelUntils {
Expand Down

0 comments on commit 639e07a

Please sign in to comment.