diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNN.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNN.scala index 34bf3aa7cf4..0a2966ee15f 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNN.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNN.scala @@ -60,22 +60,26 @@ class MaskRCNN(val inChannels: Int, val config: MaskRCNNParams = new MaskRCNNParams)(implicit ev: TensorNumeric[Float]) extends Container[Activity, Activity, Float] { - private val batchImgInfo : Tensor[Float] = Tensor[Float](2) - private val backbone = buildBackbone(inChannels, outChannels) - private val rpn = RegionProposal(inChannels, config.anchorSizes, config.aspectRatios, - config.anchorStride, config.preNmsTopNTest, config.postNmsTopNTest, config.preNmsTopNTrain, - config.postNmsTopNTrain, config.rpnNmsThread, config.minSize) - private val boxHead = BoxHead(inChannels, config.boxResolution, config.scales, - config.samplingRatio, config.boxScoreThresh, config.boxNmsThread, config.maxPerImage, - config.outputSize, numClasses) - private val maskHead = MaskHead(inChannels, config.maskResolution, config.scales, - config.samplingRatio, config.layers, config.dilation, numClasses) - - // add layer to modules - modules.append(backbone.asInstanceOf[Module[Float]]) - modules.append(rpn.asInstanceOf[Module[Float]]) - modules.append(boxHead.asInstanceOf[Module[Float]]) - modules.append(maskHead.asInstanceOf[Module[Float]]) + private val batchImgInfo : Tensor[Float] = Tensor[Float](2) + initModules() + // add layer to modules + private def initModules(): Unit = { + modules.clear() + val backbone = buildBackbone(inChannels, outChannels) + val rpn = RegionProposal(inChannels, config.anchorSizes, config.aspectRatios, + config.anchorStride, config.preNmsTopNTest, config.postNmsTopNTest, config.preNmsTopNTrain, + config.postNmsTopNTrain, config.rpnNmsThread, config.minSize) + val boxHead = BoxHead(inChannels, config.boxResolution, config.scales, + config.samplingRatio, config.boxScoreThresh, config.boxNmsThread, config.maxPerImage, + config.outputSize, numClasses) + val maskHead = MaskHead(inChannels, config.maskResolution, config.scales, + config.samplingRatio, config.layers, config.dilation, numClasses) + + modules.append(backbone.asInstanceOf[Module[Float]]) + modules.append(rpn.asInstanceOf[Module[Float]]) + modules.append(boxHead.asInstanceOf[Module[Float]]) + modules.append(maskHead.asInstanceOf[Module[Float]]) + } private def buildResNet50(): Module[Float] = { @@ -167,18 +171,24 @@ class MaskRCNN(val inChannels: Int, // contains all images info (height, width, original height, original width) val imageInfo = input.toTable[Tensor[Float]](2) + // get each layer from modules + val backbone = modules(0) + val rpn = modules(1) + val boxHead = modules(2) + val maskHead = modules(3) + batchImgInfo.setValue(1, inputFeatures.size(3)) batchImgInfo.setValue(2, inputFeatures.size(4)) - val features = this.backbone.forward(inputFeatures) - val proposals = this.rpn.forward(T(features, batchImgInfo)) - val boxOutput = this.boxHead.forward(T(features, proposals, batchImgInfo)).toTable + val features = backbone.forward(inputFeatures) + val proposals = rpn.forward(T(features, batchImgInfo)) + val boxOutput = boxHead.forward(T(features, proposals, batchImgInfo)).toTable val postProcessorBox = boxOutput[Table](2) val labelsBox = postProcessorBox[Tensor[Float]](1) val proposalsBox = postProcessorBox[Table](2) val scores = postProcessorBox[Tensor[Float]](3) if (labelsBox.size(1) > 0) { - val masks = this.maskHead.forward(T(features, proposalsBox, labelsBox)).toTable + val masks = maskHead.forward(T(features, proposalsBox, labelsBox)).toTable if (this.isTraining()) { output = T(proposalsBox, labelsBox, masks, scores) } else { @@ -340,8 +350,12 @@ object MaskRCNN extends ContainerSerializable { .getAttributeValue(context, attrMap.get("useGn")) .asInstanceOf[Boolean]) - MaskRCNN(inChannels, outChannels, numClasses, config) - .asInstanceOf[AbstractModule[Activity, Activity, T]] + val maskrcnn = MaskRCNN(inChannels, outChannels, numClasses, config) + .asInstanceOf[Container[Activity, Activity, T]] + maskrcnn.modules.clear() + loadSubModules(context, maskrcnn) + + maskrcnn } override def doSerializeModule[T: ClassTag](context: SerializeContext[T], @@ -461,5 +475,7 @@ object MaskRCNN extends ContainerSerializable { DataConverter.setAttributeValue(context, useGnBuilder, config.useGn, universe.typeOf[Boolean]) maskrcnnBuilder.putAttr("useGn", useGnBuilder.build) + + serializeSubModules(context, maskrcnnBuilder) } } diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNNSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNNSpec.scala index 0a8f147eb73..f65caa8e97e 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNNSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/maskrcnn/MaskRCNNSpec.scala @@ -18,7 +18,7 @@ package com.intel.analytics.bigdl.models.maskrcnn import com.intel.analytics.bigdl.Module import com.intel.analytics.bigdl.dataset.segmentation.RLEMasks -import com.intel.analytics.bigdl.nn.Nms +import com.intel.analytics.bigdl.nn.{Module, Nms} import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.transform.vision.image.RoiImageInfo import com.intel.analytics.bigdl.transform.vision.image.label.roi.RoiLabel @@ -26,6 +26,8 @@ import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest import com.intel.analytics.bigdl.utils.{RandomGenerator, T, Table} import org.scalatest.{FlatSpec, Matchers} +import scala.reflect.io.File + class MaskRCNNSpec extends FlatSpec with Matchers { "build maskrcnn" should "be ok" in { RandomGenerator.RNG.setSeed(100) @@ -465,6 +467,19 @@ class MaskRCNNSpec extends FlatSpec with Matchers { index(i) should be(expectedOut(i) + 1) } } + + "MaskRCNN model load" should "be ok" in { + val resNetOutChannels = 32 + val backboneOutChannels = 32 + val mask = new MaskRCNN(resNetOutChannels, backboneOutChannels) + mask.getExtraParameter().foreach(_.fill(0.1f)) + + val tempFile = "/tmp/maskrcnn.model" + mask.saveModule(tempFile, overWrite = true) + val maskLoad = Module.loadModule[Float](tempFile) + maskLoad.getExtraParameter().foreach(t => require(t.valueAt(1) == 0.1f)) + File(tempFile).delete() + } } class MaskRCNNSerialTest extends ModuleSerializationTest {