Skip to content

Commit

Permalink
[PYTORCH]Minor bug fixes (apache#5683)
Browse files Browse the repository at this point in the history
* [PYTORCH]Minor bug fixes

* Review comment fix, testcase added

* Added testcase for bert model
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent 73f1470 commit e1ff439
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 14 deletions.
58 changes: 44 additions & 14 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import infer_type as _infer_type
from ..prelude import Prelude, StaticTensorArrayOps

Expand Down Expand Up @@ -152,19 +153,33 @@ def _impl(inputs, input_types):

def _arange():
def _impl(inputs, input_types):
def _get_value(val, dtype):
if isinstance(val, _expr.Expr):
return _op.cast(val, _convert_data_type(dtype))
return _create_typed_const(val, dtype)

def _get_type(val, inp_type):
if isinstance(val, _expr.Expr):
dtype = str(_infer_type(val).checked_type)
return dtype if dtype != "float32" else "float"
return inp_type

if len(inputs) == 5:
dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
start = _create_typed_const(0, dtype)
stop = _create_typed_const(inputs[0], dtype)
step = _create_typed_const(1, dtype)
dtype0 = _get_type(inputs[0], input_types[0])
dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1])
start = _get_value(0, dtype)
stop = _get_value(inputs[0], dtype)
step = _get_value(1, dtype)
elif len(inputs) == 7:
dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
start = _create_typed_const(inputs[0], dtype)
stop = _create_typed_const(inputs[1], dtype)
step = _create_typed_const(inputs[2], dtype)
types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
dtype = "float" if "float" in types else _convert_dtype_value(inputs[3])
start = _get_value(inputs[0], dtype)
stop = _get_value(inputs[1], dtype)
step = _get_value(inputs[2], dtype)
else:
msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
raise AssertionError(msg)

return _op.transform.arange(start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -235,12 +250,18 @@ def _impl(inputs, input_types):

begin = [0] * len(end)
dim = int(inputs[1])
begin[dim] = int(inputs[2])
if isinstance(inputs[2], _expr.Call):
begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
else:
begin[dim] = int(inputs[2])

if isinstance(inputs[3], str) and inputs[3].isdigit():
end[dim] = min(end[dim], int(inputs[3]))
else:
end[dim] = inputs[3]
if isinstance(inputs[3], _expr.Call):
end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
end[dim] = inputs[3]

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data, begin, end, strides)
Expand Down Expand Up @@ -997,7 +1018,10 @@ def _impl(inputs, input_types):
def _numtotensor():
def _impl(inputs, input_types):
val = inputs[0]
dtype = type(val)
dtype = input_types[0]

if isinstance(val, _expr.Expr):
return val

if isinstance(val, tvm.tir.IntImm):
val = val.__int__()
Expand All @@ -1019,16 +1043,22 @@ def _impl(inputs, input_types):
data = inputs[0]

if len(inputs) == 3:
new_shape = [inputs[1], _infer_shape(inputs[2])[0]]
shape_inp = [inputs[1], _infer_shape(inputs[2])[0]]
else:
if isinstance(inputs[1], list):
new_shape = inputs[1]
shape_inp = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
shape_inp = _infer_shape(inputs[1])
new_shape = shape_inp
for i, shape in enumerate(shape_inp):
if isinstance(shape, _expr.Expr):
val = _infer_value_simulated(shape, {})
new_shape[i] = np.asscalar(val.asnumpy())

return _op.transform.reshape(data, new_shape)
return _impl


def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down
183 changes: 183 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,28 +381,61 @@ def test_forward_arange():
class Arange1(Module):
def forward(self, *args):
return torch.arange(5)

class Arange2(Module):
def forward(self, *args):
return torch.arange(2.5)

class Arange3(Module):
def forward(self, *args):
return torch.arange(1, 4)

class Arange4(Module):
def forward(self, *args):
return torch.arange(1, 2.5, 0.5)

class Arange5(Module):
def forward(self, *args):
return torch.arange(1, 2, 1, dtype=torch.int32)

class Arange6(Module):
def forward(self, *args):
return torch.arange(start=1, end=6, step=2)

class Arange7(Module):
def forward(self, *args):
return torch.arange(1, 4, dtype=torch.float32)

class Arange8(Module):
def forward(self, *args):
return torch.arange(1, 2, 1, dtype=torch.int16)

class Arange9(Module):
def forward(self, *args):
end = torch.add(torch.tensor(4), 1)
return torch.arange(end) + torch.ones((5,), dtype=torch.int64)

class Arange10(Module):
def forward(self, *args):
end = torch.add(torch.tensor(4.0), torch.tensor(1.0))
return torch.arange(end) + torch.ones((5,), dtype=torch.float)

class Arange11(Module):
def forward(self, *args):
start = torch.add(torch.tensor(1), 1)
end = torch.add(torch.tensor(4), 1)
step = torch.add(torch.tensor(2), 1)
out = torch.arange(start, end, step)
return out + torch.ones((3,), dtype=torch.int64)

class Arange12(Module):
def forward(self, *args):
start = torch.add(torch.tensor(1), 1)
end = torch.add(torch.tensor(4), 1)
step = torch.add(torch.tensor(2.5), torch.tensor(4.1))
out = torch.arange(start, end, step)
return out + torch.ones((3,), dtype=torch.float)

verify_model(Arange1().float().eval())
verify_model(Arange2().float().eval())
verify_model(Arange3().float().eval())
Expand All @@ -411,6 +444,11 @@ def forward(self, *args):
verify_model(Arange6().float().eval())
verify_model(Arange7().float().eval())
verify_model(Arange8().float().eval())
verify_model(Arange9().float().eval())
verify_model(Arange10().float().eval())
verify_model(Arange11().float().eval())
verify_model(Arange12().float().eval())


def test_forward_abs():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -810,9 +848,15 @@ class View2(Module):
def forward(self, *args):
return args[0].view(args[0].shape[0], -1)

class View3(Module):
def forward(self, *args):
d1 = torch.tensor(3) * torch.tensor(10) * torch.tensor(10)
return args[0].view(args[0].shape[0], d1)

input_data = torch.rand(input_shape).float()
verify_model(View1().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
verify_model(View3().float().eval(), input_data=input_data)

def test_forward_select():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -896,9 +940,17 @@ class Slice2(Module):
def forward(self, *args):
return args[0][0, :, :, :]

class Slice3(Module):
def forward(self, *args):
x0 = torch.tensor(2) - torch.tensor(1)
x1 = torch.tensor(3) + torch.tensor(1)
return args[0][:, x0:, :x1, :]

input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
verify_model(Slice3().float().eval(), input_data=input_data)


def test_forward_mean():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -2157,6 +2209,134 @@ def forward(self, *args):
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
# ---------------------------------------------------
"""
Refer the bert example given in https://pypi.org/project/pytorch-pretrained-bert
# To get started, pretrained bert package needs to be installed as prerequisite.
.. code-block:: bash
# install bert package
pip install pytorch_pretrained_bert==0.6.2 --user
"""

try:
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
except:
print("Torch pretrained bert package must be installed to run this script.")
return

######################################################################
# Load the tokenizer and tokenize the input
# -----------------------------------------

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet',
'##eer', '[SEP]']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

######################################################################
# Load a pretrained PyTorch model bert-base-uncased
# -------------------------------------------------

# Bert Model with a language modeling
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

######################################################################
# Predict all tokens with pytorch
# -------------------------------

with torch.no_grad():
torch_preds = model(tokens_tensor, segments_tensors)

######################################################################
# Make TorchScripted model via jit trace
# --------------------------------------

scripted_model = torch.jit.trace(model, (tokens_tensor, segments_tensors)).eval()

######################################################################
# Import the graph to Relay
# -------------------------
# Convert PyTorch graph to Relay graph. The input name can be arbitrary.

input_1 = 'input_ids'
input_2 = 'input.2'
shape_list = [(input_1, list(tokens_tensor.shape)),
(input_2, list(segments_tensors.shape))]

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

######################################################################
# Compile the model with relay
# ----------------------------

target = 'llvm'
with relay.build_config(opt_level=3):
relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)

######################################################################
# Execute on TVM
# --------------

ctx = tvm.context(target, 0)
relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
relay_model.set_input(**relay_params)
relay_model.set_input(input_1, tokens_tensor)
relay_model.set_input(input_2, segments_tensors)
relay_model.run()
compiled_output = relay_model.get_output(0).asnumpy()

######################################################################
# Validate the outputs
# --------------------
# Compare the torch and tvm outputs

tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3)

######################################################################
# Process the output
# ------------------
# Process the model output to token.

# Torch output to token
torch_pred_idx = torch.argmax(torch_preds[0, masked_index]).item()
torch_pred_token = tokenizer.convert_ids_to_tokens([torch_pred_idx])[0]

# TVM output to token
tvm_pred_idx = compiled_output[0, masked_index].argmax()
tvm_pred_token = tokenizer.convert_ids_to_tokens([tvm_pred_idx])[0]

assert torch_pred_idx == tvm_pred_idx
assert torch_pred_token == tvm_pred_token

# Print the outputs
print('Torch top-1 id: {}, token: {}'.format(torch_pred_idx, torch_pred_token))
print('TVM top-1 id: {}, token: {}'.format(tvm_pred_idx, tvm_pred_token))


if __name__ == "__main__":
# Single operator tests
test_forward_add()
Expand Down Expand Up @@ -2284,3 +2464,6 @@ def forward(self, *args):
from lstm_test import custom_lstm_test

custom_lstm_test()

# Test bert model
test_forward_pretrained_bert_base_uncased()

0 comments on commit e1ff439

Please sign in to comment.