diff --git a/core b/core
index 3b376fe7b82..216d4a4922d 160000
--- a/core
+++ b/core
@@ -1 +1 @@
-Subproject commit 3b376fe7b82aaca7cdd7d0a1d89916471d17a0ca
+Subproject commit 216d4a4922d66a0443ab0a2943967fc077980b0f
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/DnnBase.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/DnnBase.scala
index b3314f11d68..5127d8f91ef 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/DnnBase.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/DnnBase.scala
@@ -77,11 +77,18 @@ trait MklDnnModuleHelper {
   }
   protected def singleNativeData(formats: Array[MemoryData]): Array[MemoryData] = {
     require(formats.length == 1, "Only accept one tensor as input")
-    formats(0) match {
-      case i: NativeData => Array(i)
-      case i: HeapData => Array(i.toNative())
-      case _ => throw new UnsupportedOperationException("Not support memory format")
-    }
+    nativeData(formats)
+  }
+  protected def nativeData(formats: Array[MemoryData]): Array[MemoryData] = {
+    formats.map(
+      f => {
+        f match {
+          case i: NativeData => i
+          case i: HeapData => i.toNative()
+          case _ => throw new UnsupportedOperationException("Not support memory format")
+        }
+      }
+    )
   }
 }
 
@@ -148,7 +155,8 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
       cachedInput = input
     }
     MklDnnOps.streamSubmit(
-      runtime.stream, 1, updateOutputPrimitives, 1, updateOutputMemoryPrimitives,
+      runtime.stream, 1, updateOutputPrimitives, updateOutputPrimitives.length,
+      updateOutputMemoryPrimitives,
       updateOutputTensors
     )
     output
@@ -187,7 +195,8 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
       cachedInput = input
       cachedGradOutput = gradOutput
     }
-    MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives, 1,
+    MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives,
+      updateGradInputPrimitives.length,
       updateGradInputMemoryPrimitives, updateGradInputTensors)
     gradInput
   }
@@ -214,7 +223,6 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
   }
 
   override private[mkldnn] def gradOutputWeightFormats() = {
-    require(_gradOutputFormatsForWeight != null, "You should call initGradPrimitives first")
     _gradOutputFormatsForWeight
   }
 }
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTable.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTable.scala
index ccd909a7682..b784a7950d3 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTable.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTable.scala
@@ -15,12 +15,19 @@
  */
 
 package com.intel.analytics.bigdl.nn.mkldnn
 
-import com.intel.analytics.bigdl.mkl.{Memory, MklDnn, Query}
+import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, Query}
+import com.intel.analytics.bigdl.nn.abstractnn.Activity
+import com.intel.analytics.bigdl.tensor.Tensor
+
+import scala.collection.mutable.ArrayBuffer
 
 class JoinTable(val dimension: Int) extends MklDnnLayer {
+  @transient
+  private var memoryPrims: Array[Array[Long]] = _
+
   override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
-    require(inputs.length > 0, "at least one tensor")
-    _inputFormats = inputs
+    require(inputs.length > 0, s"at least one tensor, but is ${inputs.length}")
+    _inputFormats = nativeData(inputs)
     val totalShape = inputs(0).shape.clone()
     val layout = inputs(0).layout
@@ -41,19 +48,64 @@ class JoinTable(val dimension: Int) extends MklDnnLayer {
       }
       i += 1
     }
-    _outputFormats = Array(NativeData(totalShape, layout))
     val primDesc = MklDnn.ConcatPrimitiveDescCreate(
-      _outputFormats(0).getMemoryDescription(),
+      MklDnn.MemoryDescInit(totalShape.length, totalShape, DataType.F32, Memory.Format.any),
       inputs.length, dimension - 1, _inputFormats.map(_.getPrimitiveDescription(runtime)))
+    _outputFormats = Array(MemoryData.primitiveOutput(primDesc))
     updateOutputPrimitives = Array(MklDnnOps.primitiveCreate2(primDesc,
       _inputFormats.map(_.getPrimitive(runtime)),
-      new Array[Int](inputs.length), inputs.length, _outputFormats.map(_.getPrimitive(runtime)), 1))
+      new Array[Int](inputs.length), inputs.length,
+      _outputFormats.map(_.getPrimitive(runtime)), 1)
+    )
 
     output = initTensor(_outputFormats(0))
     (_inputFormats, _outputFormats)
   }
 
   override private[mkldnn] def initBwdPrimitives(grads: Array[MemoryData], phase: Phase) = {
-    null
+    _gradOutputFormats = singleNativeData(grads)
+    _gradOutputFormatsForWeight = _gradOutputFormats
+    _gradInputFormats = _inputFormats.map(f => {
+      NativeData(f.shape, f.layout)
+    })
+    val prims = new ArrayBuffer[Long]()
+    val buffer = new ArrayBuffer[Array[Long]]()
+    val offset = new Array[Int](_gradOutputFormats(0).shape.length)
+    for(i <- 0 until _gradInputFormats.length) {
+      val viewPD = MklDnn.ViewPrimitiveDescCreate(
+        _gradOutputFormats(0).getPrimitiveDescription(runtime), _gradInputFormats(i).shape, offset)
+      val viewFormat = MemoryData.primitiveOutput(viewPD)
+      val reorderPD = MklDnn.ReorderPrimitiveDescCreate(
+        viewFormat.getPrimitiveDescription(runtime),
+        _gradInputFormats(i).getPrimitiveDescription(runtime))
+      val reorderPrim = MklDnnOps.primitiveCreate2(reorderPD,
+        Array(viewFormat.getPrimitive(runtime)), Array(0), 1,
+        Array(_gradInputFormats(i).getPrimitive(runtime)), 1)
+      prims.append(reorderPrim)
+      buffer.append(Array(viewFormat.getPrimitive(runtime),
+        _gradInputFormats(i).getPrimitive(runtime)))
+      offset(dimension - 1) += _gradInputFormats(i).shape(dimension - 1)
+    }
+    updateGradInputPrimitives = prims.toArray
+    gradInput = initActivity(_gradInputFormats)
+    memoryPrims = buffer.toArray
+
+    (_gradOutputFormats, _gradInputFormats)
+  }
+
+  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
+    require(gradOutput.isTensor, "gradOutput should be tensor")
+    require(gradInput.isTable, "gradInput should be table")
+    val _gradOutput = gradOutput.asInstanceOf[Tensor[Float]]
+    val _gradInput = gradInput.toTable
+    val length = _gradInput.length()
+    require(length == updateGradInputPrimitives.length, "gradOutput number not match")
+    var i = 0
+    while(i < length) {
+      MklDnnOps.streamSubmit(runtime.stream, 1, Array(updateGradInputPrimitives(i)),
+        1, memoryPrims(i), Array(_gradOutput, _gradInput[Tensor[Float]](i + 1)))
+      i += 1
+    }
+    gradInput
   }
 }
\ No newline at end of file
diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTableSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTableSpec.scala
index 02ac6bcd766..6e27387b921 100644
--- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTableSpec.scala
+++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/JoinTableSpec.scala
@@ -35,12 +35,22 @@ class JoinTableSpec extends BigDLSpecHelper {
     model.add(ReorderMemory(NativeData(Array(4, 2), Memory.Format.nc),
       HeapData(Array(4, 2), Memory.Format.nc),NativeData(Array(4, 2), Memory.Format.nc),
       HeapData(Array(4, 2), Memory.Format.nc)))
-    model.compile(Phase.InferencePhase, Array(HeapData(Array(2, 2), Memory.Format.nc)))
+    model.compile(Phase.TrainingPhase, Array(HeapData(Array(2, 2), Memory.Format.nc)))
     model.forward(Tensor[Float](T(T(1, 2), T(3, 4)))) should be(Tensor[Float](T(
       T(1, 2),
       T(3, 4),
       T(1, 2),
       T(3, 4)
     )))
+    model.backward(Tensor[Float](T(T(1, 2), T(3, 4))), T(
+      Tensor[Float](T(
+        T(4, 5),
+        T(6, 7),
+        T(1, 3),
+        T(4, 2)
+      ))
+    )) should be(
+      Tensor[Float](T(T(5, 8), T(10, 9)))
+    )
   }
 }
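
Note (editorial addition, not part of the patch): the new backward path in JoinTable.scala splits gradOutput along the concat dimension, building one view + reorder primitive pair per input so each gradInput slice is copied out as its own tensor. As a rough dense-tensor sketch of what those primitive pairs compute, assuming BigDL's Torch-style Tensor API (1-based narrow, contiguous) and a hypothetical helper name splitGradOutput that does not exist in the codebase:

import com.intel.analytics.bigdl.tensor.Tensor

// Illustration only: dense equivalent of the view + reorder pairs created in
// JoinTable.initBwdPrimitives. splitGradOutput is a made-up helper for this note.
def splitGradOutput(gradOutput: Tensor[Float],
                    sizesAlongDim: Seq[Int],
                    dimension: Int): Seq[Tensor[Float]] = {
  var offset = 1
  sizesAlongDim.map { len =>
    // take the slice of gradOutput covering this input's extent along the
    // concat dimension and copy it into a contiguous tensor of its own
    val slice = gradOutput.narrow(dimension, offset, len).contiguous()
    offset += len
    slice
  }
}

// With the shapes from JoinTableSpec: a 4x2 gradOutput split along dimension 1
// into two 2x2 gradients, one for each concatenated input.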