Skip to content

Commit

Permalink
Pytorchloader and some pytorch examples (intel-analytics#2318)
Browse files Browse the repository at this point in the history
* pytorch loader

* some fix

* update to distributed sampler

* add distributedseqsampler

* some clean up

* delete size

* update example

* Create README.md

* Update README.md

* update example

* update example

* some update

* some change

* fix python style check

* Update README.md

* some update

* meet code review

* clean up

* some update

* some fix

* update main.py

* Update README.md

* some update

* meet code review

* some fix

* fix unit test

* fix ut

* add toto

* fix rebase
  • Loading branch information
qiuxin2012 authored May 26, 2020
1 parent 3f8e48d commit 58c86cb
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion python/dllib/src/bigdl/dllib/feature/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from bigdl.util.common import *
from zoo.common.utils import callZooFunc
from bigdl.dataset.dataset import DataSet
import sys
from pyspark.serializers import CloudPickleSerializer
import sys
import math
import warnings

# Python 3 has no separate `long` type, so alias it to `int` for code that
# still refers to `long`. Compare the numeric major version rather than the
# `sys.version` string: string comparison is lexicographic and would give the
# wrong answer for a major version of 10 or more ('10' < '3').
if sys.version_info[0] >= 3:
    long = int
Expand Down Expand Up @@ -361,6 +363,25 @@ def tf_dataset(cls, func, total_size, bigdl_type="float"):
jvalue = callZooFunc(bigdl_type, "createFeatureSetFromTfDataset", func, total_size)
return cls(jvalue=jvalue)

@classmethod
def pytorch_dataloader(cls, dataloader, bigdl_type="float"):
    """
    Create FeatureSet from pytorch dataloader

    The dataloader is serialized with cloudpickle and handed to the
    "createFeatureSetFromPyTorch" backend function. If the dataloader's
    batch_size is not a multiple of the node number, a warning is emitted
    announcing the batch size rounded up to the next multiple. This method
    only warns; it does not modify the dataloader itself.
    NOTE(review): the actual adjustment presumably happens on the backend
    side -- confirm.

    :param dataloader: a pytorch dataloader
    :param bigdl_type: numeric type
    :return: A feature set
    """
    node_num, _ = get_node_and_core_number()
    batch_size = dataloader.batch_size
    if batch_size % node_num != 0:
        # Round up to the next multiple of node_num with integer arithmetic,
        # so the result is an int regardless of the interpreter's `/`
        # semantics (Python 2 floor division would defeat math.ceil).
        true_bs = (batch_size + node_num - 1) // node_num * node_num
        # Format the numbers into the message; concatenating int to str
        # raises TypeError and would turn this warning into a crash.
        warning_msg = ("Detect dataloader's batch_size is not divisible by node number({}), "
                       "will adjust batch_size to {} automatically").format(node_num, true_bs)
        warnings.warn(warning_msg)

    # dumps is invoked on the class itself, with the class passed as `self`.
    bys = CloudPickleSerializer.dumps(CloudPickleSerializer, dataloader)
    jvalue = callZooFunc(bigdl_type, "createFeatureSetFromPyTorch", bys)
    return cls(jvalue=jvalue)

def transform(self, transformer):
"""
Helper function to transform the data type in the data set.
Expand Down

0 comments on commit 58c86cb

Please sign in to comment.