Skip to content

Commit

Permalink
Implement Einsum reference in nGraph interpreter
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants committed May 31, 2021
1 parent c907fb9 commit 52fc92e
Show file tree
Hide file tree
Showing 7 changed files with 1,402 additions and 1 deletion.
24 changes: 24 additions & 0 deletions ngraph/core/reference/include/ngraph/runtime/reference/einsum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>

#include <ngraph/opsets/opset7.hpp>

namespace ngraph
{
namespace runtime
{
namespace reference
{
void einsum(const HostTensorVector& outputs, const HostTensorVector& inputs,
const std::string &equation,
const element::Type& input_type);
} // namespace reference

} // namespace runtime

} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace ngraph
}
}

std::vector<size_t> get_transpose_order(const Shape& input_shape)
static std::vector<size_t> get_transpose_order(const Shape& input_shape)
{
size_t rank = input_shape.size();
NGRAPH_CHECK(rank > 1, "Invalid input for transpose");
Expand Down
Loading

0 comments on commit 52fc92e

Please sign in to comment.