Skip to content

Commit

Permalink
Review comment fix, testcase added
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed May 28, 2020
1 parent 7bd4309 commit 4d9bb47
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,28 +381,61 @@ def test_forward_arange():
class Arange1(Module):
def forward(self, *args):
return torch.arange(5)

class Arange2(Module):
def forward(self, *args):
return torch.arange(2.5)

class Arange3(Module):
def forward(self, *args):
return torch.arange(1, 4)

class Arange4(Module):
def forward(self, *args):
return torch.arange(1, 2.5, 0.5)

class Arange5(Module):
def forward(self, *args):
return torch.arange(1, 2, 1, dtype=torch.int32)

class Arange6(Module):
def forward(self, *args):
return torch.arange(start=1, end=6, step=2)

class Arange7(Module):
def forward(self, *args):
return torch.arange(1, 4, dtype=torch.float32)

class Arange8(Module):
def forward(self, *args):
return torch.arange(1, 2, 1, dtype=torch.int16)

class Arange9(Module):
def forward(self, *args):
end = torch.add(torch.tensor(4), 1)
return torch.arange(end) + torch.ones((5,), dtype=torch.int64)

class Arange10(Module):
def forward(self, *args):
end = torch.add(torch.tensor(4.0), torch.tensor(1.0))
return torch.arange(end) + torch.ones((5,), dtype=torch.float)

class Arange11(Module):
def forward(self, *args):
start = torch.add(torch.tensor(1), 1)
end = torch.add(torch.tensor(4), 1)
step = torch.add(torch.tensor(2), 1)
out = torch.arange(start, end, step)
return out + torch.ones((3,), dtype=torch.int64)

class Arange12(Module):
def forward(self, *args):
start = torch.add(torch.tensor(1), 1)
end = torch.add(torch.tensor(4), 1)
step = torch.add(torch.tensor(2.5), torch.tensor(4.1))
out = torch.arange(start, end, step)
return out + torch.ones((3,), dtype=torch.float)

verify_model(Arange1().float().eval())
verify_model(Arange2().float().eval())
verify_model(Arange3().float().eval())
Expand All @@ -411,6 +444,11 @@ def forward(self, *args):
verify_model(Arange6().float().eval())
verify_model(Arange7().float().eval())
verify_model(Arange8().float().eval())
verify_model(Arange9().float().eval())
verify_model(Arange10().float().eval())
verify_model(Arange11().float().eval())
verify_model(Arange12().float().eval())


def test_forward_abs():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -810,9 +848,15 @@ class View2(Module):
def forward(self, *args):
return args[0].view(args[0].shape[0], -1)

class View3(Module):
def forward(self, *args):
d1 = torch.tensor(3) * torch.tensor(10) * torch.tensor(10)
return args[0].view(args[0].shape[0], d1)

input_data = torch.rand(input_shape).float()
verify_model(View1().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
verify_model(View3().float().eval(), input_data=input_data)

def test_forward_select():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -896,9 +940,17 @@ class Slice2(Module):
def forward(self, *args):
return args[0][0, :, :, :]

class Slice3(Module):
def forward(self, *args):
x0 = torch.tensor(2) - torch.tensor(1)
x1 = torch.tensor(3) + torch.tensor(1)
return args[0][:, x0:, :x1, :]

input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
verify_model(Slice3().float().eval(), input_data=input_data)


def test_forward_mean():
torch.set_grad_enabled(False)
Expand Down

0 comments on commit 4d9bb47

Please sign in to comment.