Skip to content

Commit

Permalink
fix bug (intel#1730)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangw1234 authored Oct 31, 2019
1 parent 22a97ff commit ffcf932
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
14 changes: 14 additions & 0 deletions net/NetUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ object NetUtils {
vec
}

def generateZeroGrad(input: Activity, grad: Activity): Unit = {
if (grad.isTable) {
var i = 0
while (i < grad.toTable.length()) {
grad.toTable[Tensor[Float]](i + 1)
.resizeAs(input.toTable[Tensor[Float]](i + 1))
i = i + 1
}
} else {
grad.toTensor[Float]
.resizeAs(input.toTensor[Float])
}
}

def tfenum2datatype(enum: Int): DataType = {
enum match {
case 1 => DataType.FLOAT
Expand Down
16 changes: 1 addition & 15 deletions net/TFNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class TFNet(private val graphDef: TFGraphHolder,
override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
try {
if (graphMeta.variables.isEmpty) {
generateZeroGrad(input)
NetUtils.generateZeroGrad(input, gradInput)
} else {

val runner = sess.runner()
Expand Down Expand Up @@ -568,20 +568,6 @@ class TFNet(private val graphDef: TFGraphHolder,
}
}

private def generateZeroGrad(input: Activity) = {
if (gradInput.isTable) {
var i = 0
while (i < gradInput.toTable.length()) {
gradInput.toTable[Tensor[Float]](i + 1)
.resizeAs(input.toTable[Tensor[Float]](i + 1))
i = i + 1
}
} else {
gradInput.toTensor[Float]
.resizeAs(input.toTensor[Float])
}
}

private def addGrad(name: String) = {
val parts = name.split(":")
parts(0) + "_grad:" + parts(1)
Expand Down
25 changes: 23 additions & 2 deletions net/TFNetForInference.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ private[zoo] class TFNetForInference(graphRunner: GraphRunner,
}
}

gradInput = {
if (inputs.length == 1) {
Tensor[Float]()
} else {
val t = T()
var i = 0
while (i < inputs.length) {
t.insert(Tensor[Float]())
i = i + 1
}
t
}
}

private def setVariableIntoTF(weights: Array[Tensor[Float]],
inputNames: Array[String],
variableTypes: Array[DataType],
Expand All @@ -85,12 +99,18 @@ private[zoo] class TFNetForInference(graphRunner: GraphRunner,
)
}

setVariableIntoTF(weights, variableAssignPlaceholders,
variableTypes.map(NetUtils.tfenum2datatype), assignVariableOps)
@transient
private lazy val variableInited = {
setVariableIntoTF(weights, variableAssignPlaceholders,
variableTypes.map(NetUtils.tfenum2datatype), assignVariableOps)
true
}

override def updateOutput(input: Activity): Activity = {
NetUtils.timeIt("updateOutput", TFNetForInference.logger) {

assert(variableInited)

val feeds = NetUtils.activity2VectorBuilder(input)

val types = inputTypes.toVector.map(NetUtils.tfenum2datatype)
Expand All @@ -110,6 +130,7 @@ private[zoo] class TFNetForInference(graphRunner: GraphRunner,
override def updateGradInput(
input: Activity,
gradOutput: Activity): Activity = {
NetUtils.generateZeroGrad(input, gradInput)
gradInput
}
}
Expand Down

0 comments on commit ffcf932

Please sign in to comment.