Skip to content

Commit

Permalink
fix pytorch featureset, total train sample is incorrect. (intel-analy…
Browse files Browse the repository at this point in the history
…tics#2421)

* fix pytorch featureset

* fix style check
  • Loading branch information
qiuxin2012 authored Jul 2, 2020
1 parent cc9cd2f commit 68deb3f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
8 changes: 4 additions & 4 deletions FeatureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ object PythonFeatureSet{
class PythonFeatureSet[T: ClassTag](
dataset: Array[Byte],
getLoader: (Int, Int, String) => String,
getIterator: (String, String) => String,
getIterator: (String, String, Boolean) => String,
getNext: (String) => String,
inputName: String,
targetName: String = "",
Expand All @@ -441,7 +441,7 @@ class PythonFeatureSet[T: ClassTag](
cachedRdd.mapPartitions{dataIter =>
val localLoaderName = getLocalLoader(loaderName)
val localIterName = getLocalIter(localLoaderName, train)
val getIteratorCode = getIterator(localIterName, localLoaderName)
val getIteratorCode = getIterator(localIterName, localLoaderName, train)

val nextCode = getNext(localIterName)
new Iterator[T] {
Expand Down Expand Up @@ -478,7 +478,7 @@ class PythonFeatureSet[T: ClassTag](
cachedRdd.mapPartitions{ dataIter =>
val localLoaderName = getLocalLoader(loaderName)
val localIterName = getLocalIter(localLoaderName, train)
PythonInterpreter.exec(getIterator(localIterName, localLoaderName))
PythonInterpreter.exec(getIterator(localIterName, localLoaderName, train))
new Iterator[T] {
val nextCode = getNext(localIterName)
var alreadyNext = false
Expand Down Expand Up @@ -639,7 +639,7 @@ object FeatureSet {
private[zoo] def python[T: ClassTag](
dataset: Array[Byte],
getLoader: (Int, Int, String) => String,
getIterator: (String, String) => String,
getIterator: (String, String, Boolean) => String,
getNext: (String) => String,
inputName: String,
targetName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho
|from zoo.util.nest import flatten
|sess = tf.Session()
|""".stripMargin
def getIterator(iterName: String, loaderName: String): String = {
def getIterator(iterName: String, loaderName: String, train: Boolean): String = {
s"""
|${iterName} = ${loaderName}.make_one_shot_iterator()
|""".stripMargin
Expand All @@ -116,22 +116,28 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho

def createFeatureSetFromPyTorch(
dataloader: Array[Byte]): FeatureSet[MiniBatch[Float]] = {
val trainPostfix = "_train"
val evalPostfix = "_eval"
val imports = s"""
|from zoo.util.nest import ptensor_to_numpy
|import torch
|from torch.utils.data import DataLoader
|
|""".stripMargin

def getIterator(iterName: String, loaderName: String): String = {
s"""
|if '${loaderName}_epoch' not in dir():
| ${loaderName}_epoch = 0
|else:
| ${loaderName}_epoch += 1
|${loaderName}_sampler.set_epoch(${loaderName}_epoch)
|${iterName} = enumerate(${loaderName})
|""".stripMargin
def getIterator(iterName: String, loaderName: String, train: Boolean): String = {
if (train) {
s"""
|if '${loaderName}_epoch' not in dir():
| ${loaderName}_epoch = 0
|else:
| ${loaderName}_epoch += 1
|${loaderName}_random_sampler.set_epoch(${loaderName}_epoch)
|${iterName} = enumerate(${loaderName}${trainPostfix})
|""".stripMargin
} else {
s"${iterName} = enumerate(${loaderName}${evalPostfix})"
}
}

def getNext(iterName: String): String = {
Expand All @@ -155,12 +161,10 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho
|from torch.utils.data import DataLoader
|import math
|
|if isinstance(${localLoaderName}.sampler, RandomSampler):
| ${localLoaderName}_sampler=DistributedSampler(${localLoaderName}.dataset,
| ${nodeNumber}, ${partId}, True)
|else:
| ${localLoaderName}_sampler=DistributedSequentialSampler(${localLoaderName}.dataset,
| ${nodeNumber}, ${partId})
|${localLoaderName}_rand_sampler=DistributedSampler(${localLoaderName}.dataset,
| ${nodeNumber}, ${partId}, True)
|${localLoaderName}_seq_sampler=DistributedSequentialSampler(${localLoaderName}.dataset,
| ${nodeNumber}, ${partId})
|
|bs_node = int(math.ceil(${localLoaderName}.batch_size / ${nodeNumber}))
|
Expand All @@ -173,9 +177,11 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho
| "drop_last": ${localLoaderName}.drop_last,
| "timeout": ${localLoaderName}.timeout,
| "worker_init_fn": ${localLoaderName}.worker_init_fn,
| "sampler": ${localLoaderName}_sampler
| "sampler": ${localLoaderName}_rand_sampler
| }
|${localLoaderName} = DataLoader(**data_loader_args)
|${localLoaderName}${trainPostfix} = DataLoader(**data_loader_args)
|data_loader_args["sampler"] = ${localLoaderName}_seq_sampler
|${localLoaderName}${evalPostfix} = DataLoader(**data_loader_args)
|""".stripMargin
}

Expand Down

0 comments on commit 68deb3f

Please sign in to comment.