Skip to content

Commit

Permalink
support roialign backward (intel-analytics#2975)
Browse files Browse the repository at this point in the history
* support roialign backward

* fix sparselinear unit test
  • Loading branch information
zhangxiaoli73 authored Dec 6, 2019
1 parent 09f3106 commit da17001
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Table
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect._
Expand Down Expand Up @@ -49,11 +49,11 @@ class RoiAlign[T: ClassTag] (
val pooledW: Int,
val mode: String = "avg",
val aligned: Boolean = true
)(implicit ev: TensorNumeric[T]) extends AbstractModule[Table, Tensor[T], T]{
override def updateOutput(input: Table): Tensor[T] = {
)(implicit ev: TensorNumeric[T]) extends AbstractModule[Activity, Tensor[T], T]{
override def updateOutput(input: Activity): Tensor[T] = {
if (classTag[T] == classTag[Float]) {
val data = input[Tensor[Float]](1)
val rois = input[Tensor[Float]](2)
val data = input.toTable[Tensor[Float]](1)
val rois = input.toTable[Tensor[Float]](2)

val num_rois = rois.size(1)
val channels = data.size(2)
Expand All @@ -78,8 +78,8 @@ class RoiAlign[T: ClassTag] (
width,
spatialScale)
} else if (classTag[T] == classTag[Double]) {
val data = input[Tensor[Double]](1)
val rois = input[Tensor[Double]](2)
val data = input.toTable[Tensor[Double]](1)
val rois = input.toTable[Tensor[Double]](2)

val num_rois = rois.size(1)
val channels = data.size(2)
Expand Down Expand Up @@ -110,8 +110,180 @@ class RoiAlign[T: ClassTag] (
output
}

override def updateGradInput(input: Table, gradOutput: Tensor[T]): Table = {
throw new UnsupportedOperationException("Not support backward propagation")

private def bilinearInterpolateGradient(height: Int, width: Int, y: Float, x: Float)
: (Float, Float, Float, Float, Int, Int, Int, Int) = {
var w1: Float = 0.0f
var w2: Float = 0.0f
var w3: Float = 0.0f
var w4: Float = 0.0f
var x_low : Int = 0
var x_high: Int = 0
var y_low: Int = 0
var y_high: Int = 0

// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return (w1, w2, w3, w4, x_low, x_high, y_low, y_high)
}

var realY = if (y <= 0) 0 else y
var realX = if (x <= 0) 0 else x
y_low = realY.toInt
x_low = realX.toInt

if (y_low >= height - 1) {
y_high = height - 1
y_low = height - 1
realY = y_low
} else y_high = y_low + 1

if (x_low >= width - 1) {
x_high = width - 1
x_low = width - 1
realX = x_low
} else x_high = x_low + 1

val ly = realY - y_low
val lx = realX - x_low
val hy = 1.0 - ly
val hx = 1.0 - lx

w1 = (hy * hx).toFloat
w2 = (hy * lx).toFloat
w3 = (ly * hx).toFloat
w4 = (ly * lx).toFloat

return (w1, w2, w3, w4, x_low, x_high, y_low, y_high)
}

private def roiAlignBackward(
nums: Int,
gradOutputArr: Array[T],
gradInputArr: Array[T],
gradInputOffset: Int,
rois: Array[T],
channels: Int,
height: Int,
width: Int,
pooled_height: Int,
pooled_width: Int,
sampling_ratio : Int,
n_stride : Int,
c_stride : Int,
h_stride : Int,
w_stride : Int,
spatial_scale: Float) {
val roi_cols = 4
for (index <- 0 until nums) {
val pw = index % pooled_width
val ph = (index / pooled_width) % pooled_height
val c = (index / pooled_width / pooled_height) % channels
val n = index / pooled_width / pooled_height / channels
val offset_rois = n * roi_cols

val offset = if (aligned) 0.5f else 0.0f
val roi_start_w = ev.toType[Float](rois(offset_rois)) * spatial_scale - offset
val roi_start_h = ev.toType[Float](rois(offset_rois + 1)) * spatial_scale - offset
val roi_end_w = ev.toType[Float](rois(offset_rois + 2)) * spatial_scale - offset
val roi_end_h = ev.toType[Float](rois(offset_rois + 3)) * spatial_scale - offset

var roi_width = roi_end_w - roi_start_w
var roi_height = roi_end_h - roi_start_h

if (aligned) {
require(roi_width >= 0 && roi_height >= 0,
s"ROIs in ROIAlign do not have non-negative size!" +
s"But get ${roi_height} ${roi_width}")
} else {
roi_width = math.max(roi_width, 1.0f)
roi_height = math.max(roi_height, 1.0f)
}

val bin_size_h = roi_height / pooled_height
val bin_size_w = roi_width / pooled_width
val output_offset = n * n_stride + c * c_stride
val grad_output_value = gradOutputArr(output_offset + ph * h_stride + pw * w_stride)

// We use roi_bin_grid to sample the grid and mimic integral
val roi_bin_grid_h =
if (sampling_ratio > 0) sampling_ratio else math.ceil(roi_height / pooled_height).toInt
val roi_bin_grid_w =
if (sampling_ratio > 0) sampling_ratio else math.ceil(roi_width / pooled_width).toInt

// We do average (integral) pooling inside a bin
val count = roi_bin_grid_h * roi_bin_grid_w

for (iy <- 0 until roi_bin_grid_h) {
val y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h
for (ix <- 0 until roi_bin_grid_w) {
val x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w

val (w1, w2, w3, w4, x_low, x_high, y_low, y_high) =
bilinearInterpolateGradient(height, width, y.toFloat, x.toFloat)

val g1 = ev.times(grad_output_value, ev.fromType(w1 / count))
val g2 = ev.times(grad_output_value, ev.fromType(w2 / count))
val g3 = ev.times(grad_output_value, ev.fromType(w3 / count))
val g4 = ev.times(grad_output_value, ev.fromType(w4 / count))

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
gradInputArr(gradInputOffset + y_low * width + x_low) =
ev.plus(gradInputArr(gradInputOffset + y_low * width + x_low), g1)
gradInputArr(gradInputOffset + y_low * width + x_high) =
ev.plus(gradInputArr(gradInputOffset + y_low * width + x_high), g2)
gradInputArr(gradInputOffset + y_high * width + x_low) =
ev.plus(gradInputArr(gradInputOffset + y_high * width + x_low), g3)
gradInputArr(gradInputOffset + y_high * width + x_high) =
ev.plus(gradInputArr(gradInputOffset + y_high * width + x_high), g4)
}
}
}
}
}

override def updateGradInput(input: Activity, gradOutput: Tensor[T]): Activity = {
require(mode == "avg", s"Only support backward for average mode, but get ${mode}")
val data = input.toTable[Tensor[T]](1)
val rois = input.toTable[Tensor[T]](2)
val num_rois = rois.size(1)
val channels = data.size(2)
val height = data.size(3)
val width = data.size(4)

require(gradOutput.isContiguous(), "gradOutput should be contiguous")
require(gradOutput.dim() == 4, s"gradOutput should be with 4 dims, but get ${gradOutput.dim()}")

val n_stride = gradOutput.stride(1)
val c_stride = gradOutput.stride(2)
val h_stride = gradOutput.stride(3)
val w_stride = gradOutput.stride(4)

if (gradInput == null) gradInput = Tensor[T]()
gradInput.toTensor[T].resize(channels, height, width)
val gradInputArr = gradInput.toTensor[T].storage().array()
val gradInputOffset = gradInput.toTensor[T].storageOffset() - 1

roiAlignBackward(
gradOutput.nElement(),
gradOutputArr = gradOutput.asInstanceOf[Tensor[T]].storage().array(),
gradInputArr = gradInputArr,
gradInputOffset = 0,
rois = rois.storage().array(),
channels = channels,
height = height,
width = width,
pooled_height = pooledH,
pooled_width = pooledW,
sampling_ratio = samplingRatio,
n_stride = n_stride,
c_stride = c_stride,
h_stride = h_stride,
w_stride = w_stride,
spatial_scale = spatialScale)

gradInput
}

private def poolOneRoiFloat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.mkldnn.Equivalent
import com.intel.analytics.bigdl.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.utils.RandomGenerator._
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
Expand Down Expand Up @@ -209,6 +210,92 @@ class RoiAlignSpec extends FlatSpec with Matchers {
out should be(expectedWithAlign)
out2 should be(expected)
}

"backward" should "work correctly" in {
val input = Tensor[Float](T(T(T(
T(0.0611, 0.2246, 0.2343, 0.1771, 0.5561, 0.1094, 0.4609, 0.7084,
0.5798, 0.4967),
T(0.5104, 0.3295, 0.7182, 0.3845, 0.0898, 0.1175, 0.6402, 0.1968,
0.5124, 0.7118),
T(0.9249, 0.9997, 0.8927, 0.8767, 0.8450, 0.1544, 0.1705, 0.9842,
0.8127, 0.4358),
T(0.4143, 0.4284, 0.7578, 0.9225, 0.9643, 0.1760, 0.9539, 0.3134,
0.4544, 0.2956),
T(0.1875, 0.2433, 0.3493, 0.4441, 0.4069, 0.2859, 0.8036, 0.3218,
0.3639, 0.2985),
T(0.6635, 0.2552, 0.4144, 0.8396, 0.7418, 0.2865, 0.7929, 0.5001,
0.8977, 0.1051),
T(0.5809, 0.9867, 0.1315, 0.2391, 0.3047, 0.5158, 0.4514, 0.4929,
0.5301, 0.2647),
T(0.1671, 0.5482, 0.2380, 0.5374, 0.4422, 0.6454, 0.5376, 0.2245,
0.6632, 0.8439),
T(0.0109, 0.2807, 0.9301, 0.5438, 0.8123, 0.7750, 0.7308, 0.9924,
0.7282, 0.2328),
T(0.9997, 0.5540, 0.4200, 0.5419, 0.8642, 0.4312, 0.1213, 0.8956,
0.8784, 0.9128)))))

val rois = Tensor[Float](T(T(0.0f, 0.0f, 9.0f, 9.0f),
T(0.0f, 5.0f, 4.0f, 9.0f),
T(5.0f, 5.0f, 9.0f, 9.0f)))

val layer = RoiAlign[Float](spatialScale = 1, samplingRatio = 2, pooledH = 5,
pooledW = 5, aligned = true)
val out = layer.forward(T(input, rois))

val output = Tensor[Float](T(T(T(
T(0.2593, 0.3618, 0.2819, 0.3935, 0.5265),
T(0.7170, 0.8159, 0.6562, 0.4006, 0.6567),
T(0.3210, 0.4949, 0.5372, 0.5892, 0.4368),
T(0.6147, 0.3702, 0.4642, 0.5216, 0.5698),
T(0.2292, 0.5687, 0.6427, 0.6625, 0.6822))),

T(T(T(0.5731, 0.3794, 0.3402, 0.4984, 0.7202),
T(0.6138, 0.7188, 0.4918, 0.2772, 0.4116),
T(0.3937, 0.6494, 0.4761, 0.2458, 0.3759),
T(0.1376, 0.3636, 0.4568, 0.4737, 0.5367),
T(0.1754, 0.2846, 0.5770, 0.7363, 0.5957))),

T(T(T(0.3776, 0.6335, 0.6252, 0.5709, 0.6844),
T(0.4507, 0.5218, 0.5245, 0.5387, 0.5696),
T(0.5452, 0.5203, 0.4266, 0.4301, 0.5784),
T(0.6602, 0.6221, 0.5252, 0.5232, 0.6680),
T(0.7253, 0.6559, 0.7846, 0.8819, 0.6998)))))

val gradOutput = Tensor[Float](T(T(
T(T(0.9688, 0.4150, 0.4094, 0.6885, 0.6800),
T(0.6415, 0.4019, 0.4875, 0.9569, 0.5172),
T(0.9534, 0.8540, 0.9555, 0.0836, 0.1684),
T(0.1883, 0.9384, 0.3543, 0.2027, 0.5069),
T(0.7145, 0.6801, 0.9717, 0.2403, 0.3372))),
T(T(T(0.5260, 0.1794, 0.4793, 0.3070, 0.7682),
T(0.6350, 0.7321, 0.9899, 0.1897, 0.6957),
T(0.1313, 0.9514, 0.3386, 0.5337, 0.1051),
T(0.1800, 0.4603, 0.7114, 0.5114, 0.2422),
T(0.1480, 0.2527, 0.2014, 0.3004, 0.7147))),
T(T(T(0.4033, 0.9819, 0.4697, 0.3446, 0.7631),
T(0.3554, 0.2396, 0.6231, 0.6009, 0.3054),
T(0.2082, 0.2404, 0.6693, 0.7529, 0.1088),
T(0.0441, 0.4054, 0.0348, 0.7627, 0.0077),
T(0.9582, 0.6859, 0.3182, 0.5291, 0.3420)))))

Equivalent.nearequals(output, out, 1e-3) should be(true)

val grad = layer.backward(T(input, rois), gradOutput).toTensor[Float]

val expectedGrad = Tensor[Float](T(T(
T(0.3203, 0.2666, 0.1312, 0.1305, 0.1295, 0.1816, 0.2177, 0.2157, 0.2150, 0.0098),
T(0.2828, 0.2374, 0.1246, 0.1265, 0.1292, 0.1868, 0.2267, 0.2018, 0.1945, 0.0088),
T(0.2029, 0.1776, 0.1216, 0.1322, 0.1475, 0.2314, 0.2895, 0.1867, 0.1565, 0.0071),
T(0.2432, 0.2201, 0.1775, 0.1889, 0.2054, 0.1912, 0.1814, 0.1288, 0.1133, 0.0051),
T(0.3845, 0.3403, 0.3323, 0.3769, 0.3154, 0.2258, 0.1666, 0.1222, 0.1580, 0.0195),
T(0.8482, 0.8043, 0.8665, 0.9852, 0.3694, 0.7024, 0.9496, 0.7323, 0.8099, 0.1104),
T(0.8683, 1.2765, 1.0463, 0.7984, 0.2498, 0.4796, 0.7130, 1.1149, 0.6427, 0.0529),
T(0.6204, 1.1059, 1.0230, 0.6332, 0.3176, 0.4221, 0.5735, 0.9508, 0.4563, 0.0167),
T(0.4918, 0.6479, 0.7008, 0.8754, 0.5076, 0.9881, 0.7134, 0.6981, 0.5184, 0.0460),
T(0.0427, 0.0525, 0.0614, 0.1103, 0.0510, 0.1533, 0.1064, 0.0863, 0.0695, 0.0079))))

Equivalent.nearequals(grad, expectedGrad, 1e-3) should be(true)
}
}

class RoiAlignSerialTest extends ModuleSerializationTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.intel.analytics.bigdl.nn
import org.scalatest.{FlatSpec, Matchers}
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.tensor.{SparseTensor, Tensor}
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.{RandomGenerator, T}
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest

import scala.util.Random
Expand Down Expand Up @@ -143,9 +143,11 @@ class SparseLinearSpec extends FlatSpec with Matchers {
}

"Sparse Linear" should "return the same result with Linear 7" in {
RandomGenerator.RNG.setSeed(10)
val rnd = new Random(10)
val gradOutput = Tensor(4, 2).rand()
val input = Tensor(4, 1023213).apply1(_ => Random.nextInt(100000) / 99999 * Random.nextFloat())
val input2 = Tensor(4, 50).apply1(_ => Random.nextInt(2) * Random.nextFloat())
val input = Tensor(4, 1023213).apply1(_ => rnd.nextInt(100000) / 99999 * rnd.nextFloat())
val input2 = Tensor(4, 50).apply1(_ => rnd.nextInt(2) * rnd.nextFloat())
val sl = SparseLinear(1023263, 2, backwardStart = 1, backwardLength = 1023263)
val sj = SparseJoinTable(2)
val sparseModel = Sequential().add(ParallelTable().add(Identity()).add(Identity()))
Expand Down

0 comments on commit da17001

Please sign in to comment.