Skip to content

Commit

Permalink
[PT FE]: support aten::broadcast_tensors (openvinotoolkit#19994)
Browse files Browse the repository at this point in the history
* broadcast tensors

* [PT FE]: support aten::broadcast_tensors

* apply review comments

* remove add
  • Loading branch information
eaidova authored Sep 22, 2023
1 parent 2151e5f commit 26d18c9
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,29 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
}
}

if (auto broadcast_tensors = cast_fw_node(input_node, "aten::broadcast_tensors")) {
auto tensors = cast_fw_node(broadcast_tensors->input_value(0).get_node_shared_ptr(), "prim::ListConstruct");
if (!tensors) {
add_exception_to_fw_node(input_node,
"aten::broadcast_tensors: only prim::ListConstruct supported as input.");
return false;
}
Output<Node> final_shape_t = opset10::Constant::create(element::i32, Shape{}, {0});
for (auto input : tensors->inputs()) {
auto tensor_shape = rg.make<opset10::ShapeOf>(input.get_source_output(), element::i32);
final_shape_t =
rg.make<opset10::Broadcast>(final_shape_t, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
}
auto final_shape = rg.make<opset10::ShapeOf>(final_shape_t, element::i32);
OutputVector outputs;
for (auto input : tensors->inputs()) {
outputs.push_back(rg.make<opset10::Broadcast>(input.get_source_output(), final_shape));
}
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs);
return true;
}

if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
const auto input = unbind->get_input_source_output(0);
const auto axis = unbind->get_input_source_output(1);
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestBroadcastTensors(PytorchLayerTest):
def _prepare_input(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype):
import numpy as np
return (
np.random.randn(*x_shape).astype(x_dtype),
np.random.randn(*y_shape).astype(y_dtype),
np.random.randn(*z_shape).astype(z_dtype))

def create_model(self):
import torch

class aten_broadcast_tensors(torch.nn.Module):
def __init__(self):
super(aten_broadcast_tensors, self).__init__()

def forward(self, x, y, z):
x1, y1, z1 = torch.broadcast_tensors(x, y, z)
return x1, y1, z1

ref_net = None

return aten_broadcast_tensors(), ref_net, ("prim::ListConstruct", "aten::broadcast_tensors", "prim::ListUnpack")

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("x_shape", [[1, ], [2, 1], [2, 2, 1]])
@pytest.mark.parametrize("y_shape", [[2, ], [1, 2], [1, 2, 1]])
@pytest.mark.parametrize("z_shape", [[1, 2], [2, 2], [1, 2, 1, 1]])
@pytest.mark.parametrize("x_dtype", ["float32", "int32"])
@pytest.mark.parametrize("y_dtype", ["float32", "int32"])
@pytest.mark.parametrize("z_dtype", ["float32", "int32"])
def test_broadcast_tensors(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
"x_shape": x_shape, "x_dtype": x_dtype,
"y_shape": y_shape, "y_dtype": y_dtype,
"z_shape": z_shape, "z_dtype": z_dtype,
})

0 comments on commit 26d18c9

Please sign in to comment.