Skip to content

Commit

Permalink
fix model load of maskrcnn (intel-analytics#2961)
Browse files Browse the repository at this point in the history
* fix maskrcnn model load

* delete temp file

* fix maskrcnn tests
  • Loading branch information
zhangxiaoli73 authored Dec 2, 2019
1 parent df6b9eb commit 6eeb7a5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,26 @@ class MaskRCNN(val inChannels: Int,
val config: MaskRCNNParams = new MaskRCNNParams)(implicit ev: TensorNumeric[Float])
extends Container[Activity, Activity, Float] {

private val batchImgInfo : Tensor[Float] = Tensor[Float](2)
private val backbone = buildBackbone(inChannels, outChannels)
private val rpn = RegionProposal(inChannels, config.anchorSizes, config.aspectRatios,
config.anchorStride, config.preNmsTopNTest, config.postNmsTopNTest, config.preNmsTopNTrain,
config.postNmsTopNTrain, config.rpnNmsThread, config.minSize)
private val boxHead = BoxHead(inChannels, config.boxResolution, config.scales,
config.samplingRatio, config.boxScoreThresh, config.boxNmsThread, config.maxPerImage,
config.outputSize, numClasses)
private val maskHead = MaskHead(inChannels, config.maskResolution, config.scales,
config.samplingRatio, config.layers, config.dilation, numClasses)

// add layer to modules
modules.append(backbone.asInstanceOf[Module[Float]])
modules.append(rpn.asInstanceOf[Module[Float]])
modules.append(boxHead.asInstanceOf[Module[Float]])
modules.append(maskHead.asInstanceOf[Module[Float]])
private val batchImgInfo : Tensor[Float] = Tensor[Float](2)
initModules()
// add layer to modules
private def initModules(): Unit = {
modules.clear()
val backbone = buildBackbone(inChannels, outChannels)
val rpn = RegionProposal(inChannels, config.anchorSizes, config.aspectRatios,
config.anchorStride, config.preNmsTopNTest, config.postNmsTopNTest, config.preNmsTopNTrain,
config.postNmsTopNTrain, config.rpnNmsThread, config.minSize)
val boxHead = BoxHead(inChannels, config.boxResolution, config.scales,
config.samplingRatio, config.boxScoreThresh, config.boxNmsThread, config.maxPerImage,
config.outputSize, numClasses)
val maskHead = MaskHead(inChannels, config.maskResolution, config.scales,
config.samplingRatio, config.layers, config.dilation, numClasses)

modules.append(backbone.asInstanceOf[Module[Float]])
modules.append(rpn.asInstanceOf[Module[Float]])
modules.append(boxHead.asInstanceOf[Module[Float]])
modules.append(maskHead.asInstanceOf[Module[Float]])
}

private def buildResNet50(): Module[Float] = {

Expand Down Expand Up @@ -167,18 +171,24 @@ class MaskRCNN(val inChannels: Int,
// contains all images info (height, width, original height, original width)
val imageInfo = input.toTable[Tensor[Float]](2)

// get each layer from modules
val backbone = modules(0)
val rpn = modules(1)
val boxHead = modules(2)
val maskHead = modules(3)

batchImgInfo.setValue(1, inputFeatures.size(3))
batchImgInfo.setValue(2, inputFeatures.size(4))

val features = this.backbone.forward(inputFeatures)
val proposals = this.rpn.forward(T(features, batchImgInfo))
val boxOutput = this.boxHead.forward(T(features, proposals, batchImgInfo)).toTable
val features = backbone.forward(inputFeatures)
val proposals = rpn.forward(T(features, batchImgInfo))
val boxOutput = boxHead.forward(T(features, proposals, batchImgInfo)).toTable
val postProcessorBox = boxOutput[Table](2)
val labelsBox = postProcessorBox[Tensor[Float]](1)
val proposalsBox = postProcessorBox[Table](2)
val scores = postProcessorBox[Tensor[Float]](3)
if (labelsBox.size(1) > 0) {
val masks = this.maskHead.forward(T(features, proposalsBox, labelsBox)).toTable
val masks = maskHead.forward(T(features, proposalsBox, labelsBox)).toTable
if (this.isTraining()) {
output = T(proposalsBox, labelsBox, masks, scores)
} else {
Expand Down Expand Up @@ -340,8 +350,12 @@ object MaskRCNN extends ContainerSerializable {
.getAttributeValue(context, attrMap.get("useGn"))
.asInstanceOf[Boolean])

MaskRCNN(inChannels, outChannels, numClasses, config)
.asInstanceOf[AbstractModule[Activity, Activity, T]]
val maskrcnn = MaskRCNN(inChannels, outChannels, numClasses, config)
.asInstanceOf[Container[Activity, Activity, T]]
maskrcnn.modules.clear()
loadSubModules(context, maskrcnn)

maskrcnn
}

override def doSerializeModule[T: ClassTag](context: SerializeContext[T],
Expand Down Expand Up @@ -461,5 +475,7 @@ object MaskRCNN extends ContainerSerializable {
DataConverter.setAttributeValue(context, useGnBuilder,
config.useGn, universe.typeOf[Boolean])
maskrcnnBuilder.putAttr("useGn", useGnBuilder.build)

serializeSubModules(context, maskrcnnBuilder)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ package com.intel.analytics.bigdl.models.maskrcnn

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dataset.segmentation.RLEMasks
import com.intel.analytics.bigdl.nn.Nms
import com.intel.analytics.bigdl.nn.{Module, Nms}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.transform.vision.image.RoiImageInfo
import com.intel.analytics.bigdl.transform.vision.image.label.roi.RoiLabel
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.{RandomGenerator, T, Table}
import org.scalatest.{FlatSpec, Matchers}

import scala.reflect.io.File

class MaskRCNNSpec extends FlatSpec with Matchers {
"build maskrcnn" should "be ok" in {
RandomGenerator.RNG.setSeed(100)
Expand Down Expand Up @@ -465,6 +467,19 @@ class MaskRCNNSpec extends FlatSpec with Matchers {
index(i) should be(expectedOut(i) + 1)
}
}

"MaskRCNN model load" should "be ok" in {
val resNetOutChannels = 32
val backboneOutChannels = 32
val mask = new MaskRCNN(resNetOutChannels, backboneOutChannels)
mask.getExtraParameter().foreach(_.fill(0.1f))

val tempFile = "/tmp/maskrcnn.model"
mask.saveModule(tempFile, overWrite = true)
val maskLoad = Module.loadModule[Float](tempFile)
maskLoad.getExtraParameter().foreach(t => require(t.valueAt(1) == 0.1f))
File(tempFile).delete()
}
}

class MaskRCNNSerialTest extends ModuleSerializationTest {
Expand Down

0 comments on commit 6eeb7a5

Please sign in to comment.