Skip to content

Commit

Permalink
[ONNX ] Update Trilu to accept single value tensor (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#22286)

### Details:
 - update `trilu.cpp`

### Tickets:
 - Closes openvinotoolkit#21172

---------

Co-authored-by: Georgy Krivoruchko <[email protected]>
  • Loading branch information
2 people authored and bbielawx committed Apr 12, 2024
1 parent ec521d5 commit c5f7926
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/frontends/onnx/frontend/src/op/trilu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"

using namespace ov::op;
Expand All @@ -37,9 +38,19 @@ ov::OutputVector trilu(const ov::frontend::onnx::Node& node) {
if (rank.is_static()) {
CHECK_VALID_NODE(node, rank.get_length() >= 2, "Trilu first input's rank must be >= 2");
}

Output<ov::Node> k;
bool is_k_available = num_inputs == 2 && !ov::op::util::is_null(inputs[1]);
if (is_k_available) {
CHECK_VALID_NODE(node, inputs[1].get_partial_shape().compatible({}), "Trilu second input must be a scalar");
// Trilu-14 documentation allows only 0-D tensor (scalar),
// but we extend support to be able work with 1-D with length == 1
k = inputs[1];
auto axes = v0::Constant::create(ov::element::i64, ov::Shape{}, {0});
// Check if k is a tensor with a single value
if (k.get_shape().size() == 1 && k.get_shape()[0] == 1) {
k = std::make_shared<v0::Squeeze>(k, axes);
}
CHECK_VALID_NODE(node, k.get_partial_shape().compatible({}), "Trilu second input must be a scalar");
}

const auto shape = std::make_shared<v3::ShapeOf>(input);
Expand Down Expand Up @@ -83,8 +94,7 @@ ov::OutputVector trilu(const ov::frontend::onnx::Node& node) {
// create 2D tensor with shape [N, 1] and values [[k], [k + 1], ..., [N + k - 1]]
std::shared_ptr<ov::Node> vertical_range;
if (is_k_available) {
vertical_range =
std::make_shared<v4::Range>(inputs[1], std::make_shared<v1::Add>(N, inputs[1]), one, ov::element::i64);
vertical_range = std::make_shared<v4::Range>(k, std::make_shared<v1::Add>(N, k), one, ov::element::i64);
} else {
vertical_range = std::make_shared<v4::Range>(zero, N, one, ov::element::i64);
}
Expand Down

0 comments on commit c5f7926

Please sign in to comment.