diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/StaticGraph.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/StaticGraph.scala
index 5b62ecb8e2f..622cd1b3e39 100644
--- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/StaticGraph.scala
+++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/StaticGraph.scala
@@ -208,4 +208,69 @@ class StaticGraph[T: ClassTag](
     val model = IRGraph(inputsIR, outputsIR, variables, true, inFormats, outFormats)
     model.build()
   }
+
+  /**
+   * Merge a nested StaticGraph into a non-nested one.
+   *
+   * If no node of this graph holds a StaticGraph, this graph itself is returned. Otherwise a
+   * clone of this graph is rewired: every node holding a StaticGraph has that inner graph
+   * flattened recursively, its predecessors are connected to the inner graph's input nodes
+   * (whose elements are replaced by Identity) and its successors are moved onto the inner
+   * graph's output node.
+   *
+   * @return this graph when it is not nested, otherwise a new StaticGraph built from the
+   *         rewired clone
+   */
+  private[bigdl] def toSingleGraph(): StaticGraph[T] = {
+    if (!this.isNestedGraph()) {
+      this
+    } else {
+      // Rewire the result of cloneModule() (presumably a deep copy) instead of this graph
+      val graph = this.cloneModule()
+      val fwdExecution = graph.getSortedForwardExecutions()
+      // Successor of the last sorted node; its predecessors become the result's outputs
+      val dmOutput = fwdExecution(fwdExecution.length - 1).nextNodes(0)
+
+      fwdExecution.foreach { node =>
+        node.element match {
+          case nested: StaticGraph[T] =>
+            val flat = nested.toSingleGraph()
+            node.element = flat
+
+            // Connect the k-th predecessor to the k-th inner input. The predecessors are
+            // snapshotted first because delete/add change node.prevNodes while iterating;
+            // draining them all during the first pass would wire every one to input 0.
+            node.prevNodes.toList.zipWithIndex.foreach { case (preNode, inputIndex) =>
+              val inputNode = flat.inputs(inputIndex)
+              inputNode.element = Identity()
+              preNode.delete(node)
+              preNode.add(inputNode)
+            }
+
+            for (outputNode <- flat.outputs) {
+              outputNode.removeNextEdges()
+              // NOTE(review): every successor is moved on the first pass, so with several
+              // inner outputs only the first one gets consumers -- confirm this is intended
+              while (node.nextNodes.nonEmpty) {
+                val nextNode = node.nextNodes(0)
+                node.delete(nextNode)
+                outputNode.add(nextNode)
+              }
+            }
+          case _ => // not a nested graph, nothing to flatten
+        }
+      }
+
+      val resultOutputNodes = dmOutput.prevNodes
+      resultOutputNodes.foreach(_.delete(dmOutput))
+      // NOTE(review): only the first input of the clone is kept -- confirm that graphs
+      // with several inputs are not expected here
+      new StaticGraph[T](Array(graph.inputs(0)), resultOutputNodes,
+        enableExcludeChecking = this.enableExcludeChecking)
+    }
+  }
+
+  /** Whether any node of this graph holds a StaticGraph as its element. */
+  private def isNestedGraph(): Boolean =
+    forwardExecution.exists(_.element.isInstanceOf[StaticGraph[T]])
 }
diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala
index f218abb36cf..7a7207c9808 100644
--- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala
+++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala
@@ -828,7 +828,11 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
   def toGraph(startNodes: ModuleNode[T]*): Graph[T] = {
     val starts = if (startNodes.isEmpty) Array(Input[T]()) else startNodes.toArray
     val endNodes = this.getEndNodes(starts)
-    Graph(starts, endNodes)
+    Graph(starts, endNodes) match {
+      // Merge nested graphs inside to make the whole graph non-nested
+      case staticGraph: StaticGraph[T] => staticGraph.toSingleGraph()
+      case graph => graph
+    }
   }
 
   /**
diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/GraphSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/GraphSpec.scala
index ea887a86d09..a25242376d3 100644
--- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/GraphSpec.scala
+++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/GraphSpec.scala
@@ -1336,6 +1336,40 @@ class StaticGraphSpec extends FlatSpec with Matchers {
       val model = Graph(Array(n1, n2), Array(n3, n4))
     }
   }
+
+  "Graph toSingleGraph" should "work correctly" in {
+    val input = Input()
+
+    val linear1 = Linear[Float](2, 3).inputs(input)
+
+    // g1 wraps g1nested, so flattening has to recurse through two levels
+    val inputg1 = Input()
+    val l1 = Linear[Float](3, 5).inputs(inputg1)
+    val inputg1nested = Input()
+    val l1nested = Linear[Float](5, 5).inputs(inputg1nested)
+    val g1nested = Graph(inputg1nested, l1nested).inputs(l1)
+    val g1 = Graph(inputg1, g1nested).inputs(linear1)
+
+    val inputg2 = Input()
+    val l2 = Linear[Float](5, 3).inputs(inputg2)
+    val g2 = Graph(inputg2, l2).inputs(g1)
+
+    val linear3 = Linear[Float](3, 6).inputs(g2)
+    val linear4 = Linear[Float](3, 5).inputs(g2)
+
+    val graph = Graph(input, Array(linear3, linear4)).asInstanceOf[StaticGraph[Float]]
+    val toSingle = graph.toSingleGraph()
+
+    // The flattened graph must compute the same result as the nested one
+    val tensor = Tensor[Float](Array(3, 2)).rand()
+    val graphOutput = graph.forward(tensor)
+    val toSingleOutput = toSingle.forward(tensor)
+    graphOutput should equal(toSingleOutput)
+
+    // ... and must not contain any StaticGraph node any more
+    val fwdExecution = toSingle.getForwardExecutions()
+    fwdExecution.exists(_.element.isInstanceOf[StaticGraph[Float]]) shouldBe false
+  }
 }
 
 object ModelUntils {