Skip to content

Commit

Permalink
remove copy_
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Sep 17, 2020
1 parent b32f5c4 commit 9a77478
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 44 deletions.
10 changes: 0 additions & 10 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,15 +1756,6 @@ def _impl(inputs, input_types):
return _impl


def _copy_():
def _impl(inputs, input_types):
# use add to help handle broadcasting
rel = _op.zeros_like(inputs[0])
return _op.add(rel, inputs[1])

return _impl


def _none():
def _impl(inputs, input_types):
return None
Expand Down Expand Up @@ -2641,7 +2632,6 @@ def _get_convert_map(prelude, default_dtype):
"aten::isnan": _unary("isnan"),
"aten::clamp": _clamp(),
"aten::clamp_": _clamp(),
"aten::copy_": _copy_(),
"aten::detach": _identity(),
"aten::upsample_bilinear2d": _upsample("bilinear", prelude),
"aten::upsample_nearest2d": _upsample("nearest_neighbor", prelude),
Expand Down
41 changes: 7 additions & 34 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2416,34 +2416,6 @@ def forward(self, *args):
verify_model(ClampInPlace(min, max).float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_copy_():
torch.set_grad_enabled(False)

class Copy(Module):
def __init__(self):
super(Copy, self).__init__()

def forward(self, *args):
return torch.Tensor.copy_(args[0], args[1])

class CopyInPlace(Module):
def __init__(self):
super(CopyInPlace, self).__init__()

def forward(self, *args):
a = args[0]
b = args[1]
c = torch.Tensor.copy_(a, b)
return a

src_tensor = torch.rand(5)
tgt_tensor = torch.rand((2, 3, 5))
for copy in [Copy, CopyInPlace]:
verify_model(copy().float().eval(), input_data=[tgt_tensor, src_tensor])
verify_model(copy().float().eval(), input_data=[tgt_tensor, src_tensor + tgt_tensor])


@tvm.testing.uses_gpu
def test_forward_ones():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -2940,7 +2912,7 @@ def forward(self, *args):
t2 = torch.rand([1, 3]).float()
verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])


@tvm.testing.uses_gpu
def test_forward_true_divide():
torch.set_grad_enabled(False)

Expand All @@ -2950,10 +2922,12 @@ def forward(self, *args):

dividend = torch.rand([5, 3]).float()
# divisor could be either tensor or scalar
divisor_tensor = torch.rand([5, 3]).float()
divisor_scalar = divisor = torch.tensor(1.0, dtype=torch.float32)
verify_model(TrueDivide().float().eval(), input_data=[dividend, divisor_tensor])
verify_model(TrueDivide().float().eval(), input_data=[dividend, divisor_scalar])
divisor_tensor = torch.rand([5, 3]).float() + 0.5
divisor_scalar = torch.tensor(1.0, dtype=torch.float32)
verify_model(TrueDivide().float().eval(),
input_data=[dividend, divisor_tensor], atol=1e-4, rtol=1e-4)
verify_model(TrueDivide().float().eval(),
input_data=[dividend, divisor_scalar], atol=1e-4, rtol=1e-4)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -3386,7 +3360,6 @@ def test_forward_pretrained_bert_base_uncased():
test_forward_unary()
test_forward_clamp()
test_forward_clamp_()
test_forward_copy_()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
Expand Down

0 comments on commit 9a77478

Please sign in to comment.