Skip to content

Commit

Permalink
feat: enable global average pooling (intel-analytics#2823)
Browse files Browse the repository at this point in the history
* feat: enable global average pooling

* test: add more unit tests
  • Loading branch information
i8run authored May 24, 2019
1 parent 43170f6 commit e7ec66e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import com.intel.analytics.bigdl.nn.mkldnn.Phase.InferencePhase
import com.intel.analytics.bigdl.tensor.Tensor

class AvgPooling(
kW: Int,
kH: Int,
var kW: Int,
var kH: Int,
dW: Int = 1,
dH: Int = 1,
padW: Int = 0,
padH: Int = 0
padH: Int = 0,
globalPooling: Boolean = false
) extends MklDnnLayer {
@transient private var paddingTL: Array[Int] = _
@transient private var paddingBR: Array[Int] = _
Expand Down Expand Up @@ -63,12 +64,20 @@ class AvgPooling(

override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
_inputFormats = singleNativeData(inputs)
val strides = Array(dW, dH)
val kernel = Array(kH, kW)
val n = _inputFormats(0).shape(0)
val c = _inputFormats(0).shape(1)
val h = _inputFormats(0).shape(2)
val w = _inputFormats(0).shape(3)

// global average pooling reduces each feature map to a single average value
if (globalPooling) {
kH = h
kW = w
}

val strides = Array(dW, dH)
val kernel = Array(kH, kW)

val (pt, pb, pl, pr, oh, ow) = if (padH == -1 && padW == -1) {
val sizes = NNUtils.getSAMEOutSizeAndPadding(h, w, dH, dW, kH, kW)
(sizes(0), sizes(1), sizes(2), sizes(3), sizes(4), sizes(5))
Expand Down Expand Up @@ -127,6 +136,7 @@ object AvgPooling {
dW: Int = 1,
dH: Int = 1,
padW: Int = 0,
padH: Int = 0
): AvgPooling = new AvgPooling(kW, kH, dW, dH, padW, padH)
padH: Int = 0,
globalPooling: Boolean = false
): AvgPooling = new AvgPooling(kW, kH, dW, dH, padW, padH, globalPooling)
}
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,54 @@ class AvgPoolingSpec extends BigDLSpecHelper {

Equivalent.nearequals(seq.output.toTensor, seq2.output.toTensor, 1e-2) should be (true)
}

"global average pooling" should "work correctly" in {
  val inputShape = Array(4, 2, 3, 3)
  val input = Tensor[Float](inputShape).rand(-1, 1)

  // A plain 3x3 average pooling over a 3x3 feature map covers the whole map,
  // so its output must match global pooling (whose kernel size arguments are
  // ignored and replaced by the feature-map size).
  val fullKernelPool = AvgPooling(3, 3)
  val globalPool = AvgPooling(2, 2, globalPooling = true)

  val reference = Sequential()
    .add(Input(inputShape, Memory.Format.nchw))
    .add(fullKernelPool)
    .add(Output(Memory.Format.nchw))

  val candidate = Sequential()
    .add(Input(inputShape, Memory.Format.nchw))
    .add(globalPool)
    .add(Output(Memory.Format.nchw))

  reference.evaluate()
  candidate.evaluate()

  reference.compile(InferencePhase)
  candidate.compile(InferencePhase)

  reference.forward(input)
  candidate.forward(input)

  reference.output should be (candidate.output)
}

"global average pooling" should "has same behavior with nn" in {
  val inputShape = Array(4, 2, 3, 3)
  val input = Tensor[Float](inputShape).rand(-1, 1)

  // mkl-dnn global average pooling under test.
  val dnnPool = AvgPooling(2, 2, globalPooling = true)
  val dnnModel = Sequential()
    .add(Input(inputShape, Memory.Format.nchw))
    .add(dnnPool)
    .add(Output(Memory.Format.nchw))

  dnnModel.evaluate()
  dnnModel.compile(InferencePhase)
  dnnModel.forward(input)

  // Reference: the blas/nn implementation of global average pooling.
  val nnPool = SpatialAveragePooling[Float](2, 2, globalPooling = true)
  nnPool.forward(input)

  dnnModel.output should be (nnPool.output)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,30 @@ class IRconvertSpec extends BigDLSpecHelper {

Equivalent.nearequals(outDnn, outBlas, 1e-4) should be (true)
Equivalent.nearequals(gradInputDnn.toTensor, gradInputBlas.toTensor, 1e-4) should be (true)

System.clearProperty("bigdl.engineType")
}

"convert blas gap to dnn" should "work correctly" in {
  // Force the mkl-dnn engine so the conversion path is exercised.
  System.setProperty("bigdl.engineType", "mkldnn")

  val blasGraph = Sequential()
    .add(SpatialAveragePooling[Float](2, 2, globalPooling = true))
    .toGraph()
  blasGraph.asInstanceOf[StaticGraph[Float]].setOutputFormats(Seq(Memory.Format.nchw))

  // Convert a copy so the original blas graph stays usable as the reference.
  val dnnGraph = ConversionUtils.convert(blasGraph.cloneModule())

  blasGraph.evaluate()
  dnnGraph.evaluate()

  val input = Tensor[Float](4, 2, 3, 3).rand(-1, 1)

  blasGraph.forward(input)
  dnnGraph.forward(input)

  blasGraph.output should be (dnnGraph.output)

  dnnGraph.release()
  System.clearProperty("bigdl.engineType")
}
}

0 comments on commit e7ec66e

Please sign in to comment.