Skip to content

Commit

Permalink
[PT FE] Fix issue with aten.copy in FX graph
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Mar 27, 2024
1 parent 0307631 commit a7b5af7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/frontends/pytorch/src/op/copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ OutputVector translate_copy_(const NodeContext& context) {
return {res};
};

OutputVector translate_copy_fx(const NodeContext& context) {
    // FX graph form of copy, e.g.: copy = torch.ops.aten.copy.default(slice_4, clone);
    // Returns `self` unchanged when no source is given; otherwise the source is
    // cast to self's element type and broadcast to self's shape.
    num_inputs_check(context, 1, 2);
    auto self = context.get_input(0);
    if (!context.input_is_none(1)) {
        auto src = context.get_input(1);
        // Align the source element type with the destination tensor.
        auto src_cast = context.mark_node(std::make_shared<v1::ConvertLike>(src, self));
        // Broadcast the (possibly smaller) source up to the destination shape,
        // mirroring the writeback semantics of aten.copy.
        auto dst_shape = context.mark_node(std::make_shared<v3::ShapeOf>(self));
        Output<Node> result = context.mark_node(std::make_shared<v3::Broadcast>(src_cast, dst_shape));
        return {result};
    }
    // No source input: copy degenerates to a pass-through of self.
    return {self};
};

OutputVector translate_alias_copy(const NodeContext& context) {
// aten::alias_copy(Tensor self) -> Tensor
// aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Expand Down
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
OP_CONVERTER(translate_batch_norm_legit_no_stats_fx);
OP_CONVERTER(translate_cat_fx);
OP_CONVERTER(translate_constant_pad_nd_fx);
OP_CONVERTER(translate_copy_fx);
OP_CONVERTER(translate_cumsum_fx);
OP_CONVERTER(translate_chunk_fx);
OP_CONVERTER(translate_div_fx);
Expand Down Expand Up @@ -779,7 +780,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.clone.default", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
{"aten.constant_pad_nd.default", op::translate_constant_pad_nd_fx},
{"aten.convolution.default", op::translate_convolution},
{"aten.copy.default", op::skip_node},
{"aten.copy.default", op::translate_copy_fx},
{"aten.copy_.default", op::translate_copy_},
{"aten.cos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cos>},
{"aten.cosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>},
Expand Down
17 changes: 17 additions & 0 deletions tests/layer_tests/pytorch_tests/test_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ def forward(self, x):
return y


class aten_alias_tensor(torch.nn.Module):
    """Module whose forward clones the input and then overwrites all but the
    first channel with ones via slice assignment (exercises slice + in-place
    copy aliasing). Expects a 4-D input with exactly 3 channels so the
    [2, H, W] ones tensor broadcasts into the [:, 1:, :, :] slice."""

    def forward(self, x):
        out = x.clone()
        batch, channels, height, width = x.shape
        fill = torch.ones([2, height, width]).to(x.dtype)
        out[:, 1:, :, :] = fill
        return out


class aten_loop_alias(torch.nn.Module):
def forward(self, x):
y = x.clone()
Expand All @@ -36,6 +45,14 @@ def test_alias(self, ie_device, precision, ir_version):
"aten::copy_"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_alias_tensor(self, ie_device, precision, ir_version):
    # Slice assignment on a cloned tensor; freeze_model=False keeps the
    # in-place aten::copy_ visible to the frontend instead of constant-folding.
    expected_ops = ["aten::slice", "aten::copy_"]
    self._test(aten_alias_tensor(), None, expected_ops,
               ie_device, precision, ir_version, freeze_model=False)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
Expand Down

0 comments on commit a7b5af7

Please sign in to comment.