-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
einsum.cpp
32 lines (27 loc) · 936 Bytes
/
einsum.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op_table.hpp"
#include "openvino/opsets/opset8.hpp"
using namespace std;
using namespace ov::opset8;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
/// @brief Translates a TensorFlow `Einsum` node into an OpenVINO opset8 `Einsum` node.
///
/// Every input of the TensorFlow node is forwarded, in order, as an operand of the
/// OpenVINO `Einsum`, and the TensorFlow `equation` attribute is passed through unchanged.
///
/// @param node Context of the TensorFlow node being translated; its op type must be "Einsum"
///             and it must have at least one input (both are checked with
///             TENSORFLOW_OP_VALIDATION).
/// @return A single-element OutputVector holding the created `Einsum` node, which is named
///         after the original TensorFlow node via set_node_name.
OutputVector translate_einsum_op(const NodeContext& node) {
    const auto op_type = node.get_op_type();
    TENSORFLOW_OP_VALIDATION(node, op_type == "Einsum", "Internal error: incorrect usage of translate_einsum_op.");

    // TensorFlow's Einsum takes a variadic operand list that must contain at least one
    // tensor; reject an empty list up front with a clear frontend-level message.
    const size_t input_size = node.get_input_size();
    TENSORFLOW_OP_VALIDATION(node, input_size > 0, "Einsum must have at least one input.");

    const auto equation = node.get_attribute<std::string>("equation");

    // Collect all operands in their original order; the equation's comma-separated
    // subscripts refer to them positionally.
    OutputVector inputs;
    inputs.reserve(input_size);
    for (size_t input_ind = 0; input_ind < input_size; ++input_ind) {
        inputs.push_back(node.get_input(input_ind));
    }

    auto einsum = std::make_shared<Einsum>(inputs, equation);
    set_node_name(node.get_name(), einsum);
    return {einsum};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov