pytorch doc and example update. (intel-analytics#1560)
* pytorch doc
* update example
* remove transformer
* doc and example update
* fix sample_input
* check none
* param update
* check trace
YY-OnCall authored Aug 15, 2019
1 parent 976fa31 commit cc1d954
Showing 2 changed files with 53 additions and 38 deletions.
42 changes: 22 additions & 20 deletions pyspark/bigdl/dllib/inference/net/torch_criterion.py
@@ -50,35 +50,37 @@ def __init__(self, path, bigdl_type="float"):
         super(TorchCriterion, self).__init__(None, bigdl_type, path)
 
     @staticmethod
-    def from_pytorch(loss, input_shape=None, label_shape=None,
-                     sample_input=None, sample_label=None):
+    def from_pytorch(loss, input, label=None):
         """
-        Create a TorchCriterion directly from PyTorch function. We need user to provide a sample
-        input and label to trace the loss function. User may just specify the input and label shape.
-        For specific data type or multiple input models, users can send sample_input and
-        sample_label.
+        Create a TorchCriterion directly from PyTorch function. We need users to provide example
+        input and label (or just their sizes) to trace the loss function.
         :param loss: this can be a torch loss (e.g. nn.MSELoss()) or
-               a function that take two Tensor parameters: input and label. E.g.
+               a function that takes two Tensor parameters: input and label. E.g.
                def lossFunc(input, target):
                    return nn.CrossEntropyLoss().forward(input, target.flatten().long())
-        :param input_shape: list of integers.
-        :param label_shape: list of integers. If not specified, it will be set equal to input_shape
-        :param sample_input: a sample of input.
-        :param sample_label: a sample of label.
+        :param input: example input. It can be:
+               1. a torch tensor, or tuple of torch tensors for multi-input models
+               2. list of integers, or tuple of int lists for multi-input models. E.g. for
+                  ResNet, this can be [1, 3, 224, 224]. A random tensor with the
+                  specified size will be used as the example input.
+        :param label: example label. It can be:
+               1. a torch tensor, or tuple of torch tensors for multi-input models
+               2. list of integers, or tuple of int lists for multi-input models. E.g. for
+                  ResNet, this can be [1, 3, 224, 224]. A random tensor with the
+                  specified size will be used as the example label.
+               When label is None, input will also be used as the label.
         """
-        if not input_shape and not label_shape and not sample_input and not sample_label:
-            raise Exception("please specify input_shape and label_shape, or sample_input"
-                            " and sample_label")
+        if input is None:
+            raise Exception("please specify input and label")
 
         temp = tempfile.mkdtemp()
 
-        # use input_shape as label shape when label_shape is not specified
-        if not label_shape:
-            label_shape = input_shape
+        if label is None:
+            label = input
 
-        sample_input = TorchNet.get_sample_input(input_shape, sample_input)
-        sample_label = TorchNet.get_sample_input(label_shape, sample_label)
+        sample_input = TorchNet.get_sample_input(input)
+        sample_label = TorchNet.get_sample_input(label)
 
         traced_script_loss = torch.jit.trace(LossWrapperModule(loss),
                                              (sample_input, sample_label))
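For readers of this change, a minimal usage sketch of the new TorchCriterion.from_pytorch signature. The import path is assumed from the file location in this diff, and the loss functions, shapes, and tensors are illustrative only; the sketch also omits the BigDL/Spark context that creating the criterion requires.

import torch
import torch.nn as nn

from bigdl.dllib.inference.net.torch_criterion import TorchCriterion  # assumed import path

# Trace a built-in torch loss from an example size; label defaults to the
# input when omitted, so [2, 1] is used for both here (illustrative shape).
mse_criterion = TorchCriterion.from_pytorch(nn.MSELoss(), [2, 1])

# Trace a custom loss function from concrete example tensors instead of sizes.
def cross_entropy_loss(input, target):
    return nn.CrossEntropyLoss().forward(input, target.flatten().long())

example_input = torch.rand(2, 10)                     # 2 samples, 10 class scores each
example_label = torch.randint(0, 10, (2, 1)).float()  # 2 class indices, stored as floats
ce_criterion = TorchCriterion.from_pytorch(cross_entropy_loss, example_input, example_label)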
49 changes: 31 additions & 18 deletions pyspark/bigdl/dllib/inference/net/torch_net.py
@@ -43,23 +43,31 @@ def __init__(self, path, bigdl_type="float"):
         super(TorchNet, self).__init__(None, bigdl_type, path)
 
     @staticmethod
-    def from_pytorch(module, input_shape=None, sample_input=None):
+    def from_pytorch(module, input, check_trace=True):
         """
         Create a TorchNet directly from PyTorch model, e.g. model in torchvision.models.
-        Users need to specify sample_input or input_shape.
+        Users need to provide an example input or the input tensor shape.
         :param module: a PyTorch model
-        :param input_shape: list of integers, or tuple of list for multiple inputs models. E.g.
-               for ResNet, this may be [1, 3, 224, 224]
-        :param sample_input: a sample of Torch Tensor or tuple to trace the model.
+        :param input: To trace the tensor operations, torch jit trace requires users to
+               provide an example input. Here the input parameter can be:
+               1. a torch tensor, or tuple of torch tensors for multi-input models
+               2. list of integers, or tuple of int lists for multi-input models. E.g. for
+                  ResNet, this can be [1, 3, 224, 224]. A random tensor with the
+                  specified size will be used as the example input.
+        :param check_trace: boolean value, optional. Check whether the same inputs run through
+               the traced module produce the same outputs. Default: ``True``. You
+               might want to disable this if, for example, your network contains
+               non-deterministic ops or if you are sure that the network is
+               correct despite a checker failure.
         """
-        if not input_shape and not sample_input:
-            raise Exception("please specify input_shape or sample_input")
+        if input is None:
+            raise Exception("please provide an example input or input Tensor size")
 
-        sample = TorchNet.get_sample_input(input_shape, sample_input)
+        sample = TorchNet.get_sample_input(input)
         temp = tempfile.mkdtemp()
 
         # save model
-        traced_script_module = torch.jit.trace(module, sample)
+        traced_script_module = torch.jit.trace(module, sample, check_trace=check_trace)
         path = os.path.join(temp, "model.pt")
         traced_script_module.save(path)
 
@@ -69,15 +77,20 @@ def from_pytorch(module, input_shape=None, sample_input=None):
         return net
 
     @staticmethod
-    def get_sample_input(shape, sample):
-        if sample:
-            return sample
-        elif isinstance(shape, list):
-            return torch.rand(shape)
-        elif isinstance(shape, tuple):
-            return tuple(map(lambda s: torch.rand(s), shape))
-        else:
-            raise Exception("please specify shape as list of ints or tuples of int lists")
+    def get_sample_input(input):
+        if isinstance(input, torch.Tensor):
+            return input
+
+        elif isinstance(input, (list, tuple)) and len(input) > 0:
+            if all(isinstance(x, torch.Tensor) for x in input):  # tensors
+                return tuple(input)
+            elif all(isinstance(x, int) for x in input):  # ints
+                return torch.rand(input)
+            elif all(isinstance(x, (list, tuple)) for x in input) and \
+                    all(isinstance(y, int) for x in input for y in x):  # nested int lists (tuple)
+                return tuple(map(lambda size: torch.rand(size), input))
+
+        raise Exception("Unsupported input type: " + str(input))
 
     def predict(self, x, batch_per_thread=1, distributed=True):
         """

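For reference, a minimal sketch of how the reworked TorchNet.from_pytorch might be called. The import path is assumed from the file location in this diff, the torchvision model and shapes are illustrative, and the BigDL/Spark context required to build the net is omitted.

import torch
import torchvision

from bigdl.dllib.inference.net.torch_net import TorchNet  # assumed import path

model = torchvision.models.resnet18()

# 1. Trace from an input size: a random [1, 3, 224, 224] tensor is generated internally.
net_from_shape = TorchNet.from_pytorch(model, [1, 3, 224, 224])

# 2. Trace from a concrete example tensor.
net_from_tensor = TorchNet.from_pytorch(model, torch.rand(1, 3, 224, 224))

# 3. Disable the trace check for models with non-deterministic ops.
net_unchecked = TorchNet.from_pytorch(model, [1, 3, 224, 224], check_trace=False)

# Multi-input models take a tuple of tensors or a tuple of int lists, e.g.
# TorchNet.from_pytorch(two_input_model, ([1, 3, 224, 224], [1, 10]))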