From 68deb3fbe8b53365954111d728615055cb7874d5 Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Thu, 2 Jul 2020 10:51:36 +0800 Subject: [PATCH] fix pytorch featureset, total train sample is incorrect. (#2421) * fix pytorch featureset * fix style check --- FeatureSet.scala | 8 ++-- .../feature/python/PythonFeatureSet.scala | 42 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/FeatureSet.scala b/FeatureSet.scala index e4401ce266b..405393b73c6 100644 --- a/FeatureSet.scala +++ b/FeatureSet.scala @@ -415,7 +415,7 @@ object PythonFeatureSet{ class PythonFeatureSet[T: ClassTag]( dataset: Array[Byte], getLoader: (Int, Int, String) => String, - getIterator: (String, String) => String, + getIterator: (String, String, Boolean) => String, getNext: (String) => String, inputName: String, targetName: String = "", @@ -441,7 +441,7 @@ class PythonFeatureSet[T: ClassTag]( cachedRdd.mapPartitions{dataIter => val localLoaderName = getLocalLoader(loaderName) val localIterName = getLocalIter(localLoaderName, train) - val getIteratorCode = getIterator(localIterName, localLoaderName) + val getIteratorCode = getIterator(localIterName, localLoaderName, train) val nextCode = getNext(localIterName) new Iterator[T] { @@ -478,7 +478,7 @@ class PythonFeatureSet[T: ClassTag]( cachedRdd.mapPartitions{ dataIter => val localLoaderName = getLocalLoader(loaderName) val localIterName = getLocalIter(localLoaderName, train) - PythonInterpreter.exec(getIterator(localIterName, localLoaderName)) + PythonInterpreter.exec(getIterator(localIterName, localLoaderName, train)) new Iterator[T] { val nextCode = getNext(localIterName) var alreadyNext = false @@ -639,7 +639,7 @@ object FeatureSet { private[zoo] def python[T: ClassTag]( dataset: Array[Byte], getLoader: (Int, Int, String) => String, - getIterator: (String, String) => String, + getIterator: (String, String, Boolean) => String, getNext: (String) => String, inputName: String, targetName: String, diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/feature/python/PythonFeatureSet.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/feature/python/PythonFeatureSet.scala index da450c6adfe..78a57fca7df 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/feature/python/PythonFeatureSet.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/feature/python/PythonFeatureSet.scala @@ -91,7 +91,7 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho |from zoo.util.nest import flatten |sess = tf.Session() |""".stripMargin - def getIterator(iterName: String, loaderName: String): String = { + def getIterator(iterName: String, loaderName: String, train: Boolean): String = { s""" |${iterName} = ${loaderName}.make_one_shot_iterator() |""".stripMargin @@ -116,6 +116,8 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho def createFeatureSetFromPyTorch( dataloader: Array[Byte]): FeatureSet[MiniBatch[Float]] = { + val trainPostfix = "_train" + val evalPostfix = "_eval" val imports = s""" |from zoo.util.nest import ptensor_to_numpy |import torch @@ -123,15 +125,19 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho | |""".stripMargin - def getIterator(iterName: String, loaderName: String): String = { - s""" - |if '${loaderName}_epoch' not in dir(): - | ${loaderName}_epoch = 0 - |else: - | ${loaderName}_epoch += 1 - |${loaderName}_sampler.set_epoch(${loaderName}_epoch) - |${iterName} = enumerate(${loaderName}) - |""".stripMargin + def getIterator(iterName: String, loaderName: String, train: Boolean): String = { + if (train) { + s""" + |if '${loaderName}_epoch' not in dir(): + | ${loaderName}_epoch = 0 + |else: + | ${loaderName}_epoch += 1 + |${loaderName}_random_sampler.set_epoch(${loaderName}_epoch) + |${iterName} = enumerate(${loaderName}${trainPostfix}) + |""".stripMargin + } else { + s"${iterName} = enumerate(${loaderName}${evalPostfix})" + } } def getNext(iterName: String): String = { @@ -155,12 +161,10 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho |from torch.utils.data import DataLoader |import math | - |if isinstance(${localLoaderName}.sampler, RandomSampler): - | ${localLoaderName}_sampler=DistributedSampler(${localLoaderName}.dataset, - | ${nodeNumber}, ${partId}, True) - |else: - | ${localLoaderName}_sampler=DistributedSequentialSampler(${localLoaderName}.dataset, - | ${nodeNumber}, ${partId}) + |${localLoaderName}_rand_sampler=DistributedSampler(${localLoaderName}.dataset, + | ${nodeNumber}, ${partId}, True) + |${localLoaderName}_seq_sampler=DistributedSequentialSampler(${localLoaderName}.dataset, + | ${nodeNumber}, ${partId}) | |bs_node = int(math.ceil(${localLoaderName}.batch_size / ${nodeNumber})) | @@ -173,9 +177,11 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho | "drop_last": ${localLoaderName}.drop_last, | "timeout": ${localLoaderName}.timeout, | "worker_init_fn": ${localLoaderName}.worker_init_fn, - | "sampler": ${localLoaderName}_sampler + | "sampler": ${localLoaderName}_rand_sampler | } - |${localLoaderName} = DataLoader(**data_loader_args) + |${localLoaderName}${trainPostfix} = DataLoader(**data_loader_args) + |data_loader_args["sampler"] = ${localLoaderName}_seq_sampler + |${localLoaderName}${evalPostfix} = DataLoader(**data_loader_args) |""".stripMargin }