From b6b94ed25220a7c0cda567d8740dfba57d6845a4 Mon Sep 17 00:00:00 2001 From: zhangxiaoli73 <380761639@qq.com> Date: Fri, 6 Dec 2019 12:17:40 +0800 Subject: [PATCH] support roialign backward (#2975) * support roialign backward * fix sparselinear unit test --- .../analytics/bigdl/dllib/nn/RoiAlign.scala | 190 +++++++++++++++++- .../analytics/bigdl/nn/RoiAlignSpec.scala | 87 ++++++++ .../analytics/bigdl/nn/SparseLinearSpec.scala | 8 +- 3 files changed, 273 insertions(+), 12 deletions(-) diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/RoiAlign.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/RoiAlign.scala index a81ba967dd6..40a18757ec5 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/RoiAlign.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/RoiAlign.scala @@ -18,7 +18,7 @@ package com.intel.analytics.bigdl.nn import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.utils.Table -import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule +import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import scala.reflect._ @@ -49,11 +49,11 @@ class RoiAlign[T: ClassTag] ( val pooledW: Int, val mode: String = "avg", val aligned: Boolean = true -)(implicit ev: TensorNumeric[T]) extends AbstractModule[Table, Tensor[T], T]{ - override def updateOutput(input: Table): Tensor[T] = { +)(implicit ev: TensorNumeric[T]) extends AbstractModule[Activity, Tensor[T], T]{ + override def updateOutput(input: Activity): Tensor[T] = { if (classTag[T] == classTag[Float]) { - val data = input[Tensor[Float]](1) - val rois = input[Tensor[Float]](2) + val data = input.toTable[Tensor[Float]](1) + val rois = input.toTable[Tensor[Float]](2) val num_rois = rois.size(1) val channels = data.size(2) @@ -78,8 +78,8 @@ class RoiAlign[T: ClassTag] ( width, spatialScale) } else if (classTag[T] == classTag[Double]) { - val data = input[Tensor[Double]](1) - val rois = input[Tensor[Double]](2) + val data = input.toTable[Tensor[Double]](1) + val rois = input.toTable[Tensor[Double]](2) val num_rois = rois.size(1) val channels = data.size(2) @@ -110,8 +110,180 @@ class RoiAlign[T: ClassTag] ( output } - override def updateGradInput(input: Table, gradOutput: Tensor[T]): Table = { - throw new UnsupportedOperationException("Not support backward propagation") + + private def bilinearInterpolateGradient(height: Int, width: Int, y: Float, x: Float) + : (Float, Float, Float, Float, Int, Int, Int, Int) = { + var w1: Float = 0.0f + var w2: Float = 0.0f + var w3: Float = 0.0f + var w4: Float = 0.0f + var x_low : Int = 0 + var x_high: Int = 0 + var y_low: Int = 0 + var y_high: Int = 0 + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return (w1, w2, w3, w4, x_low, x_high, y_low, y_high) + } + + var realY = if (y <= 0) 0 else y + var realX = if (x <= 0) 0 else x + y_low = realY.toInt + x_low = realX.toInt + + if (y_low >= height - 1) { + y_high = height - 1 + y_low = height - 1 + realY = y_low + } else y_high = y_low + 1 + + if (x_low >= width - 1) { + x_high = width - 1 + x_low = width - 1 + realX = x_low + } else x_high = x_low + 1 + + val ly = realY - y_low + val lx = realX - x_low + val hy = 1.0 - ly + val hx = 1.0 - lx + + w1 = (hy * hx).toFloat + w2 = (hy * lx).toFloat + w3 = (ly * hx).toFloat + w4 = (ly * lx).toFloat + + return (w1, w2, w3, w4, x_low, x_high, y_low, y_high) + } + + private def roiAlignBackward( + nums: Int, + gradOutputArr: Array[T], + gradInputArr: Array[T], + gradInputOffset: Int, + rois: Array[T], + channels: Int, + height: Int, + width: Int, + pooled_height: Int, + pooled_width: Int, + sampling_ratio : Int, + n_stride : Int, + c_stride : Int, + h_stride : Int, + w_stride : Int, + spatial_scale: Float) { + val roi_cols = 4 + for (index <- 0 until nums) { + val pw = index % pooled_width + val ph = (index / pooled_width) % pooled_height + val c = (index / pooled_width / pooled_height) % channels + val n = index / pooled_width / pooled_height / channels + val offset_rois = n * roi_cols + + val offset = if (aligned) 0.5f else 0.0f + val roi_start_w = ev.toType[Float](rois(offset_rois)) * spatial_scale - offset + val roi_start_h = ev.toType[Float](rois(offset_rois + 1)) * spatial_scale - offset + val roi_end_w = ev.toType[Float](rois(offset_rois + 2)) * spatial_scale - offset + val roi_end_h = ev.toType[Float](rois(offset_rois + 3)) * spatial_scale - offset + + var roi_width = roi_end_w - roi_start_w + var roi_height = roi_end_h - roi_start_h + + if (aligned) { + require(roi_width >= 0 && roi_height >= 0, + s"ROIs in ROIAlign do not have non-negative size!" + + s"But get ${roi_height} ${roi_width}") + } else { + roi_width = math.max(roi_width, 1.0f) + roi_height = math.max(roi_height, 1.0f) + } + + val bin_size_h = roi_height / pooled_height + val bin_size_w = roi_width / pooled_width + val output_offset = n * n_stride + c * c_stride + val grad_output_value = gradOutputArr(output_offset + ph * h_stride + pw * w_stride) + + // We use roi_bin_grid to sample the grid and mimic integral + val roi_bin_grid_h = + if (sampling_ratio > 0) sampling_ratio else math.ceil(roi_height / pooled_height).toInt + val roi_bin_grid_w = + if (sampling_ratio > 0) sampling_ratio else math.ceil(roi_width / pooled_width).toInt + + // We do average (integral) pooling inside a bin + val count = roi_bin_grid_h * roi_bin_grid_w + + for (iy <- 0 until roi_bin_grid_h) { + val y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h + for (ix <- 0 until roi_bin_grid_w) { + val x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w + + val (w1, w2, w3, w4, x_low, x_high, y_low, y_high) = + bilinearInterpolateGradient(height, width, y.toFloat, x.toFloat) + + val g1 = ev.times(grad_output_value, ev.fromType(w1 / count)) + val g2 = ev.times(grad_output_value, ev.fromType(w2 / count)) + val g3 = ev.times(grad_output_value, ev.fromType(w3 / count)) + val g4 = ev.times(grad_output_value, ev.fromType(w4 / count)) + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + gradInputArr(gradInputOffset + y_low * width + x_low) = + ev.plus(gradInputArr(gradInputOffset + y_low * width + x_low), g1) + gradInputArr(gradInputOffset + y_low * width + x_high) = + ev.plus(gradInputArr(gradInputOffset + y_low * width + x_high), g2) + gradInputArr(gradInputOffset + y_high * width + x_low) = + ev.plus(gradInputArr(gradInputOffset + y_high * width + x_low), g3) + gradInputArr(gradInputOffset + y_high * width + x_high) = + ev.plus(gradInputArr(gradInputOffset + y_high * width + x_high), g4) + } + } + } + } + } + + override def updateGradInput(input: Activity, gradOutput: Tensor[T]): Activity = { + require(mode == "avg", s"Only support backward for average mode, but get ${mode}") + val data = input.toTable[Tensor[T]](1) + val rois = input.toTable[Tensor[T]](2) + val num_rois = rois.size(1) + val channels = data.size(2) + val height = data.size(3) + val width = data.size(4) + + require(gradOutput.isContiguous(), "gradOutput should be contiguous") + require(gradOutput.dim() == 4, s"gradOutput should be with 4 dims, but get ${gradOutput.dim()}") + + val n_stride = gradOutput.stride(1) + val c_stride = gradOutput.stride(2) + val h_stride = gradOutput.stride(3) + val w_stride = gradOutput.stride(4) + + if (gradInput == null) gradInput = Tensor[T]() + gradInput.toTensor[T].resize(channels, height, width) + val gradInputArr = gradInput.toTensor[T].storage().array() + val gradInputOffset = gradInput.toTensor[T].storageOffset() - 1 + + roiAlignBackward( + gradOutput.nElement(), + gradOutputArr = gradOutput.asInstanceOf[Tensor[T]].storage().array(), + gradInputArr = gradInputArr, + gradInputOffset = 0, + rois = rois.storage().array(), + channels = channels, + height = height, + width = width, + pooled_height = pooledH, + pooled_width = pooledW, + sampling_ratio = samplingRatio, + n_stride = n_stride, + c_stride = c_stride, + h_stride = h_stride, + w_stride = w_stride, + spatial_scale = spatialScale) + + gradInput } private def poolOneRoiFloat( diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/RoiAlignSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/RoiAlignSpec.scala index 219b9ef496c..158459fcaec 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/RoiAlignSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/RoiAlignSpec.scala @@ -16,6 +16,7 @@ package com.intel.analytics.bigdl.nn +import com.intel.analytics.bigdl.nn.mkldnn.Equivalent import com.intel.analytics.bigdl.tensor.{Storage, Tensor} import com.intel.analytics.bigdl.utils.RandomGenerator._ import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest @@ -209,6 +210,92 @@ class RoiAlignSpec extends FlatSpec with Matchers { out should be(expectedWithAlign) out2 should be(expected) } + + "backward" should "work correctly" in { + val input = Tensor[Float](T(T(T( + T(0.0611, 0.2246, 0.2343, 0.1771, 0.5561, 0.1094, 0.4609, 0.7084, + 0.5798, 0.4967), + T(0.5104, 0.3295, 0.7182, 0.3845, 0.0898, 0.1175, 0.6402, 0.1968, + 0.5124, 0.7118), + T(0.9249, 0.9997, 0.8927, 0.8767, 0.8450, 0.1544, 0.1705, 0.9842, + 0.8127, 0.4358), + T(0.4143, 0.4284, 0.7578, 0.9225, 0.9643, 0.1760, 0.9539, 0.3134, + 0.4544, 0.2956), + T(0.1875, 0.2433, 0.3493, 0.4441, 0.4069, 0.2859, 0.8036, 0.3218, + 0.3639, 0.2985), + T(0.6635, 0.2552, 0.4144, 0.8396, 0.7418, 0.2865, 0.7929, 0.5001, + 0.8977, 0.1051), + T(0.5809, 0.9867, 0.1315, 0.2391, 0.3047, 0.5158, 0.4514, 0.4929, + 0.5301, 0.2647), + T(0.1671, 0.5482, 0.2380, 0.5374, 0.4422, 0.6454, 0.5376, 0.2245, + 0.6632, 0.8439), + T(0.0109, 0.2807, 0.9301, 0.5438, 0.8123, 0.7750, 0.7308, 0.9924, + 0.7282, 0.2328), + T(0.9997, 0.5540, 0.4200, 0.5419, 0.8642, 0.4312, 0.1213, 0.8956, + 0.8784, 0.9128))))) + + val rois = Tensor[Float](T(T(0.0f, 0.0f, 9.0f, 9.0f), + T(0.0f, 5.0f, 4.0f, 9.0f), + T(5.0f, 5.0f, 9.0f, 9.0f))) + + val layer = RoiAlign[Float](spatialScale = 1, samplingRatio = 2, pooledH = 5, + pooledW = 5, aligned = true) + val out = layer.forward(T(input, rois)) + + val output = Tensor[Float](T(T(T( + T(0.2593, 0.3618, 0.2819, 0.3935, 0.5265), + T(0.7170, 0.8159, 0.6562, 0.4006, 0.6567), + T(0.3210, 0.4949, 0.5372, 0.5892, 0.4368), + T(0.6147, 0.3702, 0.4642, 0.5216, 0.5698), + T(0.2292, 0.5687, 0.6427, 0.6625, 0.6822))), + + T(T(T(0.5731, 0.3794, 0.3402, 0.4984, 0.7202), + T(0.6138, 0.7188, 0.4918, 0.2772, 0.4116), + T(0.3937, 0.6494, 0.4761, 0.2458, 0.3759), + T(0.1376, 0.3636, 0.4568, 0.4737, 0.5367), + T(0.1754, 0.2846, 0.5770, 0.7363, 0.5957))), + + T(T(T(0.3776, 0.6335, 0.6252, 0.5709, 0.6844), + T(0.4507, 0.5218, 0.5245, 0.5387, 0.5696), + T(0.5452, 0.5203, 0.4266, 0.4301, 0.5784), + T(0.6602, 0.6221, 0.5252, 0.5232, 0.6680), + T(0.7253, 0.6559, 0.7846, 0.8819, 0.6998))))) + + val gradOutput = Tensor[Float](T(T( + T(T(0.9688, 0.4150, 0.4094, 0.6885, 0.6800), + T(0.6415, 0.4019, 0.4875, 0.9569, 0.5172), + T(0.9534, 0.8540, 0.9555, 0.0836, 0.1684), + T(0.1883, 0.9384, 0.3543, 0.2027, 0.5069), + T(0.7145, 0.6801, 0.9717, 0.2403, 0.3372))), + T(T(T(0.5260, 0.1794, 0.4793, 0.3070, 0.7682), + T(0.6350, 0.7321, 0.9899, 0.1897, 0.6957), + T(0.1313, 0.9514, 0.3386, 0.5337, 0.1051), + T(0.1800, 0.4603, 0.7114, 0.5114, 0.2422), + T(0.1480, 0.2527, 0.2014, 0.3004, 0.7147))), + T(T(T(0.4033, 0.9819, 0.4697, 0.3446, 0.7631), + T(0.3554, 0.2396, 0.6231, 0.6009, 0.3054), + T(0.2082, 0.2404, 0.6693, 0.7529, 0.1088), + T(0.0441, 0.4054, 0.0348, 0.7627, 0.0077), + T(0.9582, 0.6859, 0.3182, 0.5291, 0.3420))))) + + Equivalent.nearequals(output, out, 1e-3) should be(true) + + val grad = layer.backward(T(input, rois), gradOutput).toTensor[Float] + + val expectedGrad = Tensor[Float](T(T( + T(0.3203, 0.2666, 0.1312, 0.1305, 0.1295, 0.1816, 0.2177, 0.2157, 0.2150, 0.0098), + T(0.2828, 0.2374, 0.1246, 0.1265, 0.1292, 0.1868, 0.2267, 0.2018, 0.1945, 0.0088), + T(0.2029, 0.1776, 0.1216, 0.1322, 0.1475, 0.2314, 0.2895, 0.1867, 0.1565, 0.0071), + T(0.2432, 0.2201, 0.1775, 0.1889, 0.2054, 0.1912, 0.1814, 0.1288, 0.1133, 0.0051), + T(0.3845, 0.3403, 0.3323, 0.3769, 0.3154, 0.2258, 0.1666, 0.1222, 0.1580, 0.0195), + T(0.8482, 0.8043, 0.8665, 0.9852, 0.3694, 0.7024, 0.9496, 0.7323, 0.8099, 0.1104), + T(0.8683, 1.2765, 1.0463, 0.7984, 0.2498, 0.4796, 0.7130, 1.1149, 0.6427, 0.0529), + T(0.6204, 1.1059, 1.0230, 0.6332, 0.3176, 0.4221, 0.5735, 0.9508, 0.4563, 0.0167), + T(0.4918, 0.6479, 0.7008, 0.8754, 0.5076, 0.9881, 0.7134, 0.6981, 0.5184, 0.0460), + T(0.0427, 0.0525, 0.0614, 0.1103, 0.0510, 0.1533, 0.1064, 0.0863, 0.0695, 0.0079)))) + + Equivalent.nearequals(grad, expectedGrad, 1e-3) should be(true) + } } class RoiAlignSerialTest extends ModuleSerializationTest { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SparseLinearSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SparseLinearSpec.scala index 4c1da650a73..da88ce1e365 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SparseLinearSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SparseLinearSpec.scala @@ -19,7 +19,7 @@ package com.intel.analytics.bigdl.nn import org.scalatest.{FlatSpec, Matchers} import com.intel.analytics.bigdl.numeric.NumericFloat import com.intel.analytics.bigdl.tensor.{SparseTensor, Tensor} -import com.intel.analytics.bigdl.utils.T +import com.intel.analytics.bigdl.utils.{RandomGenerator, T} import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest import scala.util.Random @@ -143,9 +143,11 @@ class SparseLinearSpec extends FlatSpec with Matchers { } "Sparse Linear" should "return the same result with Linear 7" in { + RandomGenerator.RNG.setSeed(10) + val rnd = new Random(10) val gradOutput = Tensor(4, 2).rand() - val input = Tensor(4, 1023213).apply1(_ => Random.nextInt(100000) / 99999 * Random.nextFloat()) - val input2 = Tensor(4, 50).apply1(_ => Random.nextInt(2) * Random.nextFloat()) + val input = Tensor(4, 1023213).apply1(_ => rnd.nextInt(100000) / 99999 * rnd.nextFloat()) + val input2 = Tensor(4, 50).apply1(_ => rnd.nextInt(2) * rnd.nextFloat()) val sl = SparseLinear(1023263, 2, backwardStart = 1, backwardLength = 1023263) val sj = SparseJoinTable(2) val sparseModel = Sequential().add(ParallelTable().add(Identity()).add(Identity()))