Skip to content

Commit

Permalink
fix: takeSample only works for dnn backend and get one batch (intel-a…
Browse files Browse the repository at this point in the history
…nalytics#2947)

* fix: takeSample only works for dnn backend and get one batch
  • Loading branch information
i8run authored Oct 28, 2019
1 parent 5abb1e2 commit 756643b
Showing 1 changed file with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,19 @@ object Predictor {
// because Evaluator will use it too, we extend the scope out of Predictor
private[optim] def getDummyData[T: ClassTag, R](dataSet: RDD[R],
batchSize: Int)(implicit ev: TensorNumeric[T]): Activity = {
// here has an assumption, batchSizePerPar is not very large.
val samples = dataSet.takeSample(withReplacement = false, num = batchSize)
.map {
case feature: ImageFeature => feature[Sample[T]](ImageFeature.sample)
case sample => sample.asInstanceOf[Sample[T]]
}
val sampleToMiniBatch = SampleToMiniBatch(batchSize)
sampleToMiniBatch(samples.toIterator).toSeq.head.getInput()
if (Engine.getEngineType() == MklDnn && Engine.isMultiModels) {
// here has an assumption, batchSizePerPar is not very large.
val samples = dataSet.takeSample(withReplacement = false, num = batchSize)
.map {
case feature: ImageFeature => feature[Sample[T]](ImageFeature.sample)
case sample => sample.asInstanceOf[Sample[T]]
}
val sampleToMiniBatch = SampleToMiniBatch(batchSize, partitionNum = Some(1))
val miniBatch = sampleToMiniBatch(samples.toIterator).toSeq
miniBatch.head.getInput()
} else {
Tensor()
}
}
}

Expand Down

0 comments on commit 756643b

Please sign in to comment.