PyTorch loader and some PyTorch examples (intel-analytics#2318)
* pytorch loader

* some fix

* update to distributed sampler

* add distributedseqsampler

* some clean up

* delete size

* update example

* Create README.md

* Update README.md

* update example

* update example

* some update

* some change

* fix python style check

* Update README.md

* some update

* meet code review

* clean up

* some update

* some fix

* update main.py

* Update README.md

* some update

* meet code review

* some fix

* fix unit test

* fix ut

* add todo

* fix rebase
qiuxin2012 authored May 26, 2020
1 parent 94bf8d4 commit 3b887c6
Showing 2 changed files with 123 additions and 64 deletions.
@@ -329,7 +329,7 @@ class CachedDistributedFeatureSet[T: ClassTag]
}
}

- object PythonLoaderFeatureSet{
+ object PythonFeatureSet{
// One partition one loader
protected def getLocalLoader(loaderName: String): String = {
s"${loaderName}_${TaskContext.getPartitionId()}"
@@ -339,64 +339,46 @@ object PythonLoaderFeatureSet{
s"${loaderName}_iter_${train}"
}

- protected def loadPytorchLoader(
+ protected def loadPythonSet(
loaderName: String,
+ getLoader: (Int, Int, String) => String,
dataset: Array[Byte],
imports: String,
- interpRdd: RDD[SharedInterpreter]): Unit = {
+ interpRdd: RDD[Int]): Unit = {
val bcDataSet = interpRdd.sparkContext.broadcast(dataset)
val nodeNumber = EngineRef.getNodeNumber()
val preimports = s"""
|from pyspark.serializers import CloudPickleSerializer
|import numpy as np
|""".stripMargin + imports
interpRdd.mapPartitions{iter =>
- val interp = iter.next()
val partId = TaskContext.getPartitionId()
require(partId < nodeNumber, s"partId($partId) should be" +
s" smaller than nodeNumber(${nodeNumber})")
- interp.exec(preimports)
- interp.set("pyjarray", bcDataSet.value)
+ PythonInterpreter.exec(preimports)
+ PythonInterpreter.set("pyjarray", bcDataSet.value)

- val load = s"""
- |by${partId} = bytes(b % 256 for b in pyjarray)
- |func${partId} = CloudPickleSerializer.loads(CloudPickleSerializer, by${partId})
- |${getLocalLoader(loaderName)} = func${partId}().shard(${nodeNumber}, ${partId})
- |""".stripMargin
+ val localLoaderName = getLocalLoader(loaderName)

- interp.exec(load)
- Iterator.single(interp)
+ val load = getLoader(nodeNumber, partId, localLoaderName)
+ PythonInterpreter.exec(load)
+ Iterator.single(1)
}.count()
}

- private var jepRDD: RDD[SharedInterpreter] = null
- protected def getOrCreateInterpRdd(): RDD[SharedInterpreter] = {
- if (jepRDD == null) {
- this.synchronized {
- if (jepRDD == null) {
- val sc = SparkContext.getOrCreate()
- val nodeNumber = EngineRef.getNodeNumber()
- // TODO: make sure 1 executor 1 partition
- val originRdd = sc.parallelize(
- Array.tabulate(nodeNumber)(_ => "dummy123123"), nodeNumber * 10)
- .mapPartitions(_ => (0 until 20000000).toIterator)
- .coalesce(nodeNumber)
- .setName("PartitionRDD")
- .persist(StorageLevel.DISK_ONLY)
- originRdd.count()
- originRdd.mapPartitions{
- _ => TFNetNative.isLoaded
- Iterator.single(1)
- }.count()
- jepRDD = originRdd.mapPartitions { iter =>
- val interp = PythonInterpreter.getSharedInterpreter()
- Iterator.single(interp)
- }.setName("SharedInterpRDD").cache()
- jepRDD.count()
- }
- }
- }
- jepRDD
+ protected lazy val cachedRdd: RDD[Int] = createCachedRdd()
+ protected def createCachedRdd(): RDD[Int] = {
+ val sc = SparkContext.getOrCreate()
+ val nodeNumber = EngineRef.getNodeNumber()
+ // TODO: make sure 1 executor 1 partition
+ val originRdd = sc.parallelize(
+ Array.tabulate(nodeNumber)(_ => "dummy123123"), nodeNumber * 10)
+ .mapPartitions(_ => (0 until 20000000).toIterator)
+ .coalesce(nodeNumber)
+ .setName("PartitionRDD")
+ .persist(StorageLevel.DISK_ONLY)
+ originRdd.count()
+ originRdd
}

private[zoo] def toArrayTensor(
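
The rewritten loadPythonSet broadcasts the cloudpickled payload and has each partition rebuild a local loader from it, keyed by partition id. A minimal, self-contained Python sketch of that round trip (the create_loader payload and its toy shard method are hypothetical; on a real executor pyjarray is injected from the JVM via PythonInterpreter.set):

from pyspark.serializers import CloudPickleSerializer

def create_loader():
    # Hypothetical user payload: any callable whose result supports .shard(num, id),
    # such as a tf.data.Dataset in the TensorFlow path shown later in this diff.
    class ToyDataset:
        def shard(self, num_shards, index):
            return list(range(10))[index::num_shards]
    return ToyDataset()

ser = CloudPickleSerializer()
payload = ser.dumps(create_loader)  # becomes the `dataset: Array[Byte]` broadcast above

# Executor side: a JVM Array[Byte] surfaces as signed values, which is why the
# generated code runs `bytes(b % 256 for b in pyjarray)` before unpickling.
pyjarray = [b - 256 if b > 127 else b for b in payload]  # emulate signed JVM bytes
func0 = ser.loads(bytes(b % 256 for b in pyjarray))
loader_0 = func0().shard(2, 0)  # shard(nodeNumber, partId): this partition's slice
print(loader_0)                 # [0, 2, 4, 6, 8]

The `b % 256` normalization exists because a JVM Array[Byte] is signed while Python's bytes constructor wants values in 0..255.
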
@@ -430,23 +412,23 @@ object PythonLoaderFeatureSet{
}
}

- class PythonLoaderFeatureSet[T: ClassTag](
+ class PythonFeatureSet[T: ClassTag](
dataset: Array[Byte],
+ getLoader: (Int, Int, String) => String,
getIterator: (String, String) => String,
getNext: (String) => String,
inputName: String,
targetName: String = "",
totalSize: Int,
imports: String = "") extends DistributedFeatureSet[T] {
- import PythonLoaderFeatureSet._
+ import PythonFeatureSet._
protected val namePostfix = Integer.toHexString(java.util.UUID.randomUUID().hashCode())
protected val loaderName = s"loader${namePostfix}"

- protected val sharedInterp = getOrCreateInterpRdd()
- loadPytorchLoader(loaderName, dataset, imports, sharedInterp)
+ loadPythonSet(loaderName, getLoader, dataset, imports, cachedRdd)

override def originRDD(): RDD[_] = {
- sharedInterp
+ cachedRdd
}

override def data(train: Boolean): RDD[T] = {
@@ -456,8 +438,7 @@ class PythonLoaderFeatureSet[T: ClassTag](
val getNext = this.getNext
val getIterator = this.getIterator
if (train) {
- sharedInterp.mapPartitions{dataIter =>
- val interp = dataIter.next()
+ cachedRdd.mapPartitions{dataIter =>
val localLoaderName = getLocalLoader(loaderName)
val localIterName = getLocalIter(localLoaderName, train)
val getIteratorCode = getIterator(localIterName, localLoaderName)
@@ -470,20 +451,21 @@

override def next(): T = {
try {
- interp.exec(nextCode)
+ PythonInterpreter.exec(nextCode)
} catch {
case e: Exception =>
if (e.getMessage().contains("End of sequence") ||
e.getMessage().contains("StopIteration") ||
e.getMessage().contains("is not defined")) {
- interp.exec(getIteratorCode)
- interp.exec(nextCode)
+ PythonInterpreter.exec(getIteratorCode)
+ PythonInterpreter.exec(nextCode)
} else {
throw e
}
}
- val inputs = toArrayTensor(interp.getValue(inputName))
+ val inputs = toArrayTensor(PythonInterpreter.getValue[AnyRef](inputName))
val miniBatch = if (targetName != "") {
- val targets = toArrayTensor(interp.getValue(targetName))
+ val targets = toArrayTensor(PythonInterpreter.getValue(targetName))
MiniBatch[Float](inputs, targets)
} else {
MiniBatch[Float](inputs)
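
The training branch above deliberately never ends: when nextCode raises StopIteration ("End of sequence" for the TensorFlow path), or the iterator name is not yet defined, it re-runs getIteratorCode and retries, so exhausting the data simply rolls over into the next epoch. The same idea in plain Python, as a minimal sketch with a hypothetical make_iter factory:

def endless(make_iter):
    """Yield items forever, rebuilding the iterator whenever it is exhausted."""
    it = make_iter()
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = make_iter()  # start the next epoch, mirroring the retry above

# Usage: three epochs over [0, 1, 2] flattened into one endless stream.
stream = endless(lambda: iter(range(3)))
print([next(stream) for _ in range(9)])  # [0, 1, 2, 0, 1, 2, 0, 1, 2]
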
@@ -493,22 +475,22 @@
}
}
} else {
- sharedInterp.mapPartitions{ dataIter =>
- val interp = dataIter.next()
+ cachedRdd.mapPartitions{ dataIter =>
val localLoaderName = getLocalLoader(loaderName)
val localIterName = getLocalIter(localLoaderName, train)
- interp.exec(getIterator(localIterName, localLoaderName))
+ PythonInterpreter.exec(getIterator(localIterName, localLoaderName))
new Iterator[T] {
val nextCode = getNext(localIterName)
var alreadyNext = false

override def hasNext: Boolean = {
if (!alreadyNext) {
try {
- interp.exec(nextCode)
+ PythonInterpreter.exec(nextCode)
} catch {
case e: Exception =>
- if (e.getMessage().contains("End of sequence")) {
+ if (e.getMessage().contains("End of sequence") ||
+ e.getMessage().contains("StopIteration")) {
return false
} else {
throw e
@@ -521,11 +503,11 @@

override def next(): T = {
if (!alreadyNext) {
- interp.exec(nextCode)
+ PythonInterpreter.exec(nextCode)
}
- val inputs = toArrayTensor(interp.getValue(inputName))
+ val inputs = toArrayTensor(PythonInterpreter.getValue(inputName))
val miniBatch = if (targetName != "") {
- val targets = toArrayTensor(interp.getValue(targetName))
+ val targets = toArrayTensor(PythonInterpreter.getValue(targetName))
MiniBatch[Float](inputs, targets)
} else {
MiniBatch[Float](inputs)
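
Unlike the training branch, the validation iterator must terminate, and hasNext can only learn whether more data exists by actually running nextCode; the alreadyNext flag records that a value was already pre-fetched so next() does not advance twice. A small Python sketch of that look-ahead contract, with a hypothetical step callable standing in for PythonInterpreter.exec(nextCode):

class Lookahead:
    def __init__(self, step):
        self._step = step    # advances the source; raises StopIteration at the end
        self._ready = False  # plays the role of `alreadyNext` above
        self._value = None

    def has_next(self):
        if not self._ready:
            try:
                self._value = self._step()
            except StopIteration:
                return False
            self._ready = True
        return True

    def next(self):
        if not self._ready:
            self._value = self._step()  # no look-ahead happened; advance now
        self._ready = False
        return self._value

# Usage: wrap a plain iterator's __next__ as the step function.
it = Lookahead(iter(range(3)).__next__)
while it.has_next():
    print(it.next())  # 0, 1, 2
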
@@ -656,13 +638,14 @@ object FeatureSet {
val logger: Logger = LoggerFactory.getLogger(this.getClass)
private[zoo] def python[T: ClassTag](
dataset: Array[Byte],
+ getLoader: (Int, Int, String) => String,
getIterator: (String, String) => String,
getNext: (String) => String,
inputName: String,
targetName: String,
totalSize: Int,
- imports: String = ""): PythonLoaderFeatureSet[T] = {
- new PythonLoaderFeatureSet[T](dataset, getIterator, getNext,
+ imports: String = ""): PythonFeatureSet[T] = {
+ new PythonFeatureSet[T](dataset, getLoader, getIterator, getNext,
inputName, targetName, totalSize, imports)
}

@@ -96,15 +96,91 @@ class PythonFeatureSet[T: ClassTag](implicit ev: TensorNumeric[T]) extends Pytho
|${iterName} = ${loaderName}.make_one_shot_iterator()
|""".stripMargin
}
+ def getLoader(nodeNumber: Int, partId: Int, localLoaderName: String): String = {
+ s"""
+ |by${partId} = bytes(b % 256 for b in pyjarray)
+ |func${partId} = CloudPickleSerializer.loads(CloudPickleSerializer, by${partId})
+ |${localLoaderName} = func${partId}().shard(${nodeNumber}, ${partId})
+ |""".stripMargin
+ }
def getNext(iterName: String): String = {
s"""
|data = sess.run(${iterName}.get_next())
|data = flatten(data)
|""".stripMargin
}
FeatureSet.python[MiniBatch[Float]](dataset,
- getIterator, getNext,
+ getLoader, getIterator, getNext,
"data", "", totalSize, imports)
}

+ def createFeatureSetFromPyTorch(
+ dataloader: Array[Byte]): FeatureSet[MiniBatch[Float]] = {
+ val imports = s"""
+ |from zoo.util.nest import ptensor_to_numpy
+ |import torch
+ |from torch.utils.data import DataLoader
+ |
+ |""".stripMargin
+
+ def getIterator(iterName: String, loaderName: String): String = {
+ s"""
+ |if '${loaderName}_epoch' not in dir():
+ | ${loaderName}_epoch = 0
+ |else:
+ | ${loaderName}_epoch += 1
+ |${loaderName}_sampler.set_epoch(${loaderName}_epoch)
+ |${iterName} = enumerate(${loaderName})
+ |""".stripMargin
+ }
+
+ def getNext(iterName: String): String = {
+ // _index and _data will be used in TorchModel and TorchLoss
+ s"""
+ |_index, _data = next(${iterName})
+ |""".stripMargin
+ }
+
+ def getLoader(nodeNumber: Int, partId: Int, localLoaderName: String): String = {
+ val load = s"""
+ |by${partId} = bytes(b % 256 for b in pyjarray)
+ |func${partId} = CloudPickleSerializer.loads(CloudPickleSerializer, by${partId})
+ |${localLoaderName} = func${partId}
+ |""".stripMargin
+ load +
+ s"""
+ |from torch.utils.data.distributed import DistributedSampler
+ |from torch.utils.data.sampler import RandomSampler
+ |from zoo.pipeline.api.torch.utils import DistributedSequentialSampler
+ |from torch.utils.data import DataLoader
+ |import math
+ |
+ |if isinstance(${localLoaderName}.sampler, RandomSampler):
+ | ${localLoaderName}_sampler=DistributedSampler(${localLoaderName}.dataset,
+ | ${nodeNumber}, ${partId}, True)
+ |else:
+ | ${localLoaderName}_sampler=DistributedSequentialSampler(${localLoaderName}.dataset,
+ | ${nodeNumber}, ${partId})
+ |
+ |bs_node = int(math.ceil(${localLoaderName}.batch_size / ${nodeNumber}))
+ |
+ |data_loader_args = {
+ | "dataset": ${localLoaderName}.dataset,
+ | "batch_size": bs_node,
+ | "shuffle": False,
+ | "num_workers": 0,
+ | "collate_fn": ${localLoaderName}.collate_fn,
+ | "drop_last": ${localLoaderName}.drop_last,
+ | "timeout": ${localLoaderName}.timeout,
+ | "worker_init_fn": ${localLoaderName}.worker_init_fn,
+ | "sampler": ${localLoaderName}_sampler
+ | }
+ |${localLoaderName} = DataLoader(**data_loader_args)
+ |""".stripMargin
+ }
+
+ FeatureSet.python[MiniBatch[Float]](dataloader, getLoader, getIterator, getNext,
+ "ptensor_to_numpy(_data[0])", "ptensor_to_numpy(_data[1])", -1, imports)
+ }
+
}
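
To see concretely what the generated PyTorch loader code above does on each node, here is a runnable sketch of the resharding: DistributedSampler(dataset, nodeNumber, partId, shuffle) gives every partition a disjoint shard, set_epoch (issued from getIterator) reseeds the shuffle each epoch, and each node uses a batch size of ceil(batch_size / nodeNumber) so the global batch stays roughly constant. The toy dataset and node count are illustrative:

import math
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
node_number, global_batch = 2, 4
bs_node = int(math.ceil(global_batch / node_number))  # 2 per node, as in bs_node above

for part_id in range(node_number):
    # Same arguments the generated code passes: (dataset, nodeNumber, partId, shuffle)
    sampler = DistributedSampler(dataset, node_number, part_id, True)
    sampler.set_epoch(0)  # what `${loaderName}_sampler.set_epoch(...)` does each epoch
    loader = DataLoader(dataset, batch_size=bs_node, sampler=sampler)
    print(part_id, [batch[0].tolist() for batch in loader])
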
