diff --git a/python/dllib/src/bigdl/dllib/feature/common.py b/python/dllib/src/bigdl/dllib/feature/common.py
index e68dc52b2d5..b480fb8f735 100644
--- a/python/dllib/src/bigdl/dllib/feature/common.py
+++ b/python/dllib/src/bigdl/dllib/feature/common.py
@@ -17,8 +17,10 @@ from bigdl.util.common import *
 from zoo.common.utils import callZooFunc
 from bigdl.dataset.dataset import DataSet
-import sys
 from pyspark.serializers import CloudPickleSerializer
+import sys
+import math
+import warnings
 
 if sys.version >= '3':
     long = int
@@ -361,6 +363,26 @@ def tf_dataset(cls, func, total_size, bigdl_type="float"):
         jvalue = callZooFunc(bigdl_type, "createFeatureSetFromTfDataset", func, total_size)
         return cls(jvalue=jvalue)
 
+    @classmethod
+    def pytorch_dataloader(cls, dataloader, bigdl_type="float"):
+        """
+        Create FeatureSet from pytorch dataloader
+        :param dataloader: a pytorch dataloader
+        :param bigdl_type: numeric type
+        :return: A feature set
+        """
+        node_num, core_num = get_node_and_core_number()
+        if dataloader.batch_size % node_num != 0:
+            # float()/int(math.ceil(...)) keeps this correct on both py2 and py3
+            # (this module still supports py2 — see the sys.version check above).
+            true_bs = int(math.ceil(float(dataloader.batch_size) / node_num)) * node_num
+            warning_msg = "Detect dataloader's batch_size is not divisible by node number(" + \
+                          str(node_num) + "), will adjust batch_size to " + str(true_bs) + " automatically"
+            warnings.warn(warning_msg)
+
+        # NOTE(review): dumps is invoked with the class itself as `self`; this only
+        # works because pyspark's CloudPickleSerializer.dumps ignores self — confirm.
+        bys = CloudPickleSerializer.dumps(CloudPickleSerializer, dataloader)
+        jvalue = callZooFunc(bigdl_type, "createFeatureSetFromPyTorch", bys)
+        return cls(jvalue=jvalue)
+
     def transform(self, transformer):
         """
         Helper function to transform the data type in the data set.