Skip to content

Commit

Permalink
Infernce Model string support of TFNet (intel-analytics#2452)
Browse files Browse the repository at this point in the history
Noted InferenceModel predict batchSize is removed
  • Loading branch information
Song Jiaming authored Jun 15, 2020
1 parent 1a073a2 commit c1cf09c
Showing 1 changed file with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ import java.util.{List => JList}

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.sun.xml.internal.bind.v2.TODO

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

class InferenceModel(private var autoScalingEnabled: Boolean = true,
private var concurrentNum: Int = 20,
Expand All @@ -37,7 +39,6 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,

require(concurrentNum > 0, "concurrentNum should > 0")

private var batchCnt: Int = 0
@transient var inferenceSummary: InferenceSummary = null
/**
* default constructor, will create a InferenceModel with auto-scaling enabled.
Expand Down Expand Up @@ -743,17 +744,17 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
val model: AbstractModel = retrieveModel()
try {
val begin = System.nanoTime()
val batchSize = if (inputActivity.isTensor) {
inputActivity.toTensor[Float].size(1)
} else {
val sampleKey = inputActivity.toTable.keySet.head
inputActivity.toTable(sampleKey).asInstanceOf[Tensor[Float]].size(1)
}
// val batchSize = if (inputActivity.isTensor) {
// inputActivity.toTensor[T].size(1)
// } else {
// val sampleKey = inputActivity.toTable.keySet.head
// inputActivity.toTable(sampleKey).asInstanceOf[Tensor[T]].size(1)
// }
val result = model.predict(inputActivity)
val end = System.nanoTime()

val latency = end - begin
val name = s"model predict for batch ${batchSize}"
val name = s"model predict for batch"
InferenceSupportive.logger.info(s"$name time elapsed [${latency/1e9} s, ${latency/1e6} ms].")

result
Expand Down

0 comments on commit c1cf09c

Please sign in to comment.