add JoinTable backward
yiheng-wang-intel authored and i8run committed Jun 23, 2018
1 parent 4e6c524 commit edbf737
Showing 4 changed files with 87 additions and 17 deletions.
core: 2 changes (1 addition & 1 deletion)
@@ -77,11 +77,18 @@ trait MklDnnModuleHelper {
   }
   protected def singleNativeData(formats: Array[MemoryData]): Array[MemoryData] = {
     require(formats.length == 1, "Only accept one tensor as input")
-    formats(0) match {
-      case i: NativeData => Array(i)
-      case i: HeapData => Array(i.toNative())
-      case _ => throw new UnsupportedOperationException("Not support memory format")
-    }
+    nativeData(formats)
+  }
+  protected def nativeData(formats: Array[MemoryData]): Array[MemoryData] = {
+    formats.map(
+      f => {
+        f match {
+          case i: NativeData => i
+          case i: HeapData => i.toNative()
+          case _ => throw new UnsupportedOperationException("Not support memory format")
+        }
+      }
+    )
   }
 }
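For orientation, here is a minimal, self-contained sketch of the pattern above. The trait and case classes are simplified stand-ins for BigDL's MemoryData hierarchy, not the real classes: nativeData normalizes every format to its native form, and singleNativeData becomes the one-tensor special case.

// Illustration only: simplified stand-ins for BigDL's MemoryData types.
sealed trait MemoryData
case class NativeData(shape: Seq[Int]) extends MemoryData
case class HeapData(shape: Seq[Int]) extends MemoryData {
  def toNative(): NativeData = NativeData(shape)
}

object FormatNormalization {
  // Mirrors the new nativeData: map each format to its native form.
  def nativeData(formats: Array[MemoryData]): Array[MemoryData] =
    formats.map {
      case i: NativeData => i
      case i: HeapData => i.toNative()
    }

  // singleNativeData is now just the one-tensor special case.
  def singleNativeData(formats: Array[MemoryData]): Array[MemoryData] = {
    require(formats.length == 1, "Only accept one tensor as input")
    nativeData(formats)
  }

  def main(args: Array[String]): Unit = {
    val mixed: Array[MemoryData] = Array(HeapData(Seq(4, 2)), NativeData(Seq(4, 2)))
    println(nativeData(mixed).mkString(", ")) // both elements are now NativeData
  }
}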

@@ -148,7 +155,8 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
       cachedInput = input
     }
     MklDnnOps.streamSubmit(
-      runtime.stream, 1, updateOutputPrimitives, 1, updateOutputMemoryPrimitives,
+      runtime.stream, 1, updateOutputPrimitives, updateOutputPrimitives.length,
+      updateOutputMemoryPrimitives,
       updateOutputTensors
     )
     output
@@ -187,7 +195,8 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
       cachedInput = input
       cachedGradOutput = gradOutput
     }
-    MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives, 1,
+    MklDnnOps.streamSubmit(runtime.stream, 1, updateGradInputPrimitives,
+      updateGradInputPrimitives.length,
       updateGradInputMemoryPrimitives, updateGradInputTensors)
     gradInput
   }
@@ -214,7 +223,6 @@ trait MklDnnLayer extends AbstractModule[Activity, Activity, Float] with MklDnnM
   }

   override private[mkldnn] def gradOutputWeightFormats() = {
-    require(_gradOutputFormatsForWeight != null, "You should call initGradPrimitives first")
     _gradOutputFormatsForWeight
   }
 }
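The two streamSubmit fixes above replace a hard-coded primitive count of 1 with the length of the primitive array. A sketch of why that matters, using a simplified stand-in for MklDnnOps.streamSubmit (not the real JNI binding): when the count is smaller than the array, the remaining primitives are presumably never handed to the stream.

// Illustration only: a stand-in that executes the first `count` primitives.
object StreamSubmitSketch {
  def streamSubmit(stream: Long, block: Int, primitives: Array[Long], count: Int): Unit = {
    primitives.take(count).foreach { p =>
      println(s"executing primitive $p on stream $stream")
    }
  }

  def main(args: Array[String]): Unit = {
    // JoinTable's backward creates one reorder primitive per input.
    val updateGradInputPrimitives = Array(101L, 102L, 103L)
    // With the old hard-coded count of 1, only primitive 101 would run.
    streamSubmit(0L, 1, updateGradInputPrimitives, updateGradInputPrimitives.length)
  }
}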
@@ -15,12 +15,19 @@
  */
 package com.intel.analytics.bigdl.nn.mkldnn

-import com.intel.analytics.bigdl.mkl.{Memory, MklDnn, Query}
+import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, Query}
 import com.intel.analytics.bigdl.nn.abstractnn.Activity
 import com.intel.analytics.bigdl.tensor.Tensor

+import scala.collection.mutable.ArrayBuffer
+
 class JoinTable(val dimension: Int) extends MklDnnLayer {
+  @transient
+  private var memoryPrims: Array[Array[Long]] = _
+
   override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
-    require(inputs.length > 0, "at least one tensor")
-    _inputFormats = inputs
+    require(inputs.length > 0, s"at least one tensor, but is ${inputs.length}")
+    _inputFormats = nativeData(inputs)

     val totalShape = inputs(0).shape.clone()
     val layout = inputs(0).layout
@@ -41,19 +48,64 @@ class JoinTable(val dimension: Int) extends MklDnnLayer {
       }
       i += 1
     }
-    _outputFormats = Array(NativeData(totalShape, layout))
     val primDesc = MklDnn.ConcatPrimitiveDescCreate(
-      _outputFormats(0).getMemoryDescription(),
+      MklDnn.MemoryDescInit(totalShape.length, totalShape, DataType.F32, Memory.Format.any),
       inputs.length, dimension - 1, _inputFormats.map(_.getPrimitiveDescription(runtime)))

+    _outputFormats = Array(MemoryData.primitiveOutput(primDesc))
     updateOutputPrimitives = Array(MklDnnOps.primitiveCreate2(primDesc,
       _inputFormats.map(_.getPrimitive(runtime)),
-      new Array[Int](inputs.length), inputs.length, _outputFormats.map(_.getPrimitive(runtime)), 1))
+      new Array[Int](inputs.length), inputs.length,
+      _outputFormats.map(_.getPrimitive(runtime)), 1)
+    )
     output = initTensor(_outputFormats(0))
     (_inputFormats, _outputFormats)
   }

   override private[mkldnn] def initBwdPrimitives(grads: Array[MemoryData], phase: Phase) = {
-    null
+    _gradOutputFormats = singleNativeData(grads)
+    _gradOutputFormatsForWeight = _gradOutputFormats
+    _gradInputFormats = _inputFormats.map(f => {
+      NativeData(f.shape, f.layout)
+    })
+    val prims = new ArrayBuffer[Long]()
+    val buffer = new ArrayBuffer[Array[Long]]()
+    val offset = new Array[Int](_gradOutputFormats(0).shape.length)
+    for (i <- 0 until _gradInputFormats.length) {
+      val viewPD = MklDnn.ViewPrimitiveDescCreate(
+        _gradOutputFormats(0).getPrimitiveDescription(runtime), _gradInputFormats(i).shape, offset)
+      val viewFormat = MemoryData.primitiveOutput(viewPD)
+      val reorderPD = MklDnn.ReorderPrimitiveDescCreate(
+        viewFormat.getPrimitiveDescription(runtime),
+        _gradInputFormats(i).getPrimitiveDescription(runtime))
+      val reorderPrim = MklDnnOps.primitiveCreate2(reorderPD,
+        Array(viewFormat.getPrimitive(runtime)), Array(0), 1,
+        Array(_gradInputFormats(i).getPrimitive(runtime)), 1)
+      prims.append(reorderPrim)
+      buffer.append(Array(viewFormat.getPrimitive(runtime),
+        _gradInputFormats(i).getPrimitive(runtime)))
+      offset(dimension - 1) += _gradInputFormats(i).shape(dimension - 1)
+    }
+    updateGradInputPrimitives = prims.toArray
+    gradInput = initActivity(_gradInputFormats)
+    memoryPrims = buffer.toArray
+
+    (_gradOutputFormats, _gradInputFormats)
+  }
+
+  override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
+    require(gradOutput.isTensor, "gradOutput should be tensor")
+    require(gradInput.isTable, "gradInput should be table")
+    val _gradOutput = gradOutput.asInstanceOf[Tensor[Float]]
+    val _gradInput = gradInput.toTable
+    val length = _gradInput.length()
+    require(length == updateGradInputPrimitives.length, "gradOutput number not match")
+    var i = 0
+    while (i < length) {
+      MklDnnOps.streamSubmit(runtime.stream, 1, Array(updateGradInputPrimitives(i)),
+        1, memoryPrims(i), Array(_gradOutput, _gradInput[Tensor[Float]](i + 1)))
+      i += 1
+    }
+    gradInput
+  }
 }
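For intuition, the backward pass above carves gradOutput into one view per input at a running offset along the join dimension, then reorders each view into the matching gradInput tensor. Below is a minimal sketch of that offset bookkeeping with plain 2D arrays instead of MKL-DNN view/reorder primitives, assuming the join dimension is the row dimension; splitRows and its names are illustrative, not BigDL API.

// Illustration only: the offset loop behind JoinTable's backward, on plain arrays.
object JoinTableBackwardSketch {
  // Split gradOutput along dimension 0 (rows) into slices whose row counts
  // match the original inputs, advancing an offset like the loop above does
  // via offset(dimension - 1) += shape(dimension - 1).
  def splitRows(gradOutput: Array[Array[Float]], rowCounts: Seq[Int]): Seq[Array[Array[Float]]] = {
    var offset = 0
    rowCounts.map { rows =>
      val slice = gradOutput.slice(offset, offset + rows)
      offset += rows
      slice
    }
  }

  def main(args: Array[String]): Unit = {
    val gradOutput = Array(
      Array(4f, 5f), Array(6f, 7f), // view for the first input
      Array(1f, 3f), Array(4f, 2f)  // view for the second input
    )
    val Seq(g1, g2) = splitRows(gradOutput, Seq(2, 2))
    println(g1.map(_.mkString(" ")).mkString("\n")) // 4 5 / 6 7
    println(g2.map(_.mkString(" ")).mkString("\n")) // 1 3 / 4 2
  }
}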
@@ -35,12 +35,22 @@ class JoinTableSpec extends BigDLSpecHelper {
     model.add(ReorderMemory(NativeData(Array(4, 2), Memory.Format.nc),
       HeapData(Array(4, 2), Memory.Format.nc), NativeData(Array(4, 2), Memory.Format.nc),
       HeapData(Array(4, 2), Memory.Format.nc)))
-    model.compile(Phase.InferencePhase, Array(HeapData(Array(2, 2), Memory.Format.nc)))
+    model.compile(Phase.TrainingPhase, Array(HeapData(Array(2, 2), Memory.Format.nc)))
     model.forward(Tensor[Float](T(T(1, 2), T(3, 4)))) should be(Tensor[Float](T(
       T(1, 2),
       T(3, 4),
       T(1, 2),
       T(3, 4)
     )))
+    model.backward(Tensor[Float](T(T(1, 2), T(3, 4))), T(
+      Tensor[Float](T(
+        T(4, 5),
+        T(6, 7),
+        T(1, 3),
+        T(4, 2)
+      ))
+    )) should be(
+      Tensor[Float](T(T(5, 8), T(10, 9)))
+    )
   }
 }
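As a sanity check on the expected values: judging by the forward assertion, the earlier (elided) part of this spec feeds the same 2x2 input into JoinTable twice, so the input's gradient is the elementwise sum of the two 2x2 slices of gradOutput. A quick standalone check of that arithmetic:

// Illustration only: reproduce the expected gradInput of the spec by hand.
object JoinTableSpecArithmetic {
  def main(args: Array[String]): Unit = {
    val slice1 = Array(Array(4f, 5f), Array(6f, 7f)) // rows 1-2 of gradOutput
    val slice2 = Array(Array(1f, 3f), Array(4f, 2f)) // rows 3-4 of gradOutput
    val sum = slice1.zip(slice2).map { case (r1, r2) =>
      r1.zip(r2).map { case (a, b) => a + b }
    }
    sum.foreach(r => println(r.mkString(" "))) // 5.0 8.0 / 10.0 9.0, matching T(T(5, 8), T(10, 9))
  }
}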
