Skip to content

Commit

Permalink
[TF FE][Tokenizers] Avoid dependency from TF FE in tokenizers (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#26131)

**Details:** All required routines (Variable, HashTable) are moved to
common FE API

**Ticket:** 148101

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Aug 20, 2024
1 parent 595ce79 commit db22d0a
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 214 deletions.
135 changes: 135 additions & 0 deletions src/frontends/common/include/openvino/frontend/hash_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node_output.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/variable.hpp"
#include "openvino/frontend/visibility.hpp"

namespace ov {
namespace frontend {

/// \brief HashTable is a special type of Variable that has a complex value including keys and values.
/// Keys and values are represented with two separate graph at each time step
class FRONTEND_API HashTable : public Variable {
public:
using Ptr = std::shared_ptr<HashTable>;
OPENVINO_OP("HashTable", "ov::frontend", Variable);

HashTable(const std::string& name,
const ov::element::Type& key_type,
const ov::element::Type& value_type,
const std::shared_ptr<DecoderBase>& decoder = nullptr)
: Variable(name, decoder),
m_key_type(key_type),
m_value_type(value_type) {
validate_and_infer_types();
}

HashTable(const HashTable& other, const ov::Output<ov::Node>& keys, const ov::Output<ov::Node>& values)
: HashTable(other) {
m_keys = keys;
m_values = values;
m_is_initialized = true;
++m_init_counter;
}

// it must be used only for cloning
// other ways are illegal
HashTable(const std::string& name,
const ov::element::Type& key_type,
const ov::element::Type& value_type,
const ov::Output<ov::Node>& keys,
const ov::Output<ov::Node>& values,
bool is_initialized,
uint64_t init_counter,
const std::shared_ptr<DecoderBase>& decoder = nullptr)
: Variable(name, decoder),
m_key_type(key_type),
m_value_type(value_type),
m_keys(keys),
m_values(values) {
m_init_counter = init_counter;
m_is_initialized = is_initialized;
validate_and_infer_types();
}

void validate_and_infer_types() override {
// this is a type of resource so its shape and type is not applicable
// its output serves to store a reference to a resource
set_output_type(0, ov::element::dynamic, ov::PartialShape::dynamic());
// these two outputs serves to store keys and values of a resource
// keys and values are 1D tensors
set_output_type(1, m_key_type, ov::PartialShape::dynamic(1));
set_output_type(2, m_value_type, ov::PartialShape::dynamic(1));
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
auto hash_table_node = std::make_shared<HashTable>(m_name,
m_key_type,
m_value_type,
m_keys,
m_values,
m_is_initialized,
m_init_counter,
m_decoder);
hash_table_node->set_attrs(get_attrs());
return hash_table_node;
}

ov::Output<ov::Node> get_value() override {
return output(0);
}

/// \brief Returns a value corresponding keys of hash table
ov::Output<ov::Node> get_keys() {
if (m_is_initialized) {
return m_keys;
} else if (m_other_keys.size() > 0) {
return *(m_other_keys.begin());
}

return output(1);
}

/// \brief Returns a value corresponding values of hash table
ov::Output<ov::Node> get_values() {
if (m_is_initialized) {
return m_values;
} else if (m_other_values.size() > 0) {
return *(m_other_values.begin());
}

return output(2);
}

ov::element::Type get_key_type() const {
return m_key_type;
}

ov::element::Type get_value_type() const {
return m_value_type;
}

void add_other_keys_values(const ov::Output<ov::Node>& other_key, const ov::Output<ov::Node>& other_value) {
m_other_keys.insert(other_key);
m_other_values.insert(other_value);
}

virtual ~HashTable();

private:
ov::element::Type m_key_type;
ov::element::Type m_value_type;
ov::Output<ov::Node> m_keys;
ov::Output<ov::Node> m_values;

std::set<ov::Output<ov::Node>> m_other_keys;
std::set<ov::Output<ov::Node>> m_other_values;
};

} // namespace frontend
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ class FRONTEND_API NodeContext {
FRONT_END_NOT_IMPLEMENTED(get_input);
}

/// \brief Returns the input by reference. The reference value can be changed by consuming operation
virtual Output<Node> get_input_by_reference(int idx) const {
FRONT_END_NOT_IMPLEMENTED(get_input_by_reference);
}

/// \brief Returns values from Constant input with the given index as ov::Any.
/// Throws an exception if the input cannot be represented as Constant.
virtual Any get_values_from_const_input(int idx) const {
Expand Down
117 changes: 117 additions & 0 deletions src/frontends/common/include/openvino/frontend/variable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/frontend/decoder.hpp"
#include "openvino/frontend/visibility.hpp"
#include "openvino/op/util/framework_node.hpp"

namespace ov {
namespace frontend {

/// \brief Variable is a special node used in a conversion step
/// It can have several values (or states) during the conversion.
/// Variable value at some time step is represented with a graph.
class FRONTEND_API Variable : public ov::op::util::FrameworkNode {
public:
using Ptr = std::shared_ptr<Variable>;
OPENVINO_OP("Variable", "ov::frontend", ov::op::util::FrameworkNode);

Variable(const std::string& name, const std::shared_ptr<DecoderBase>& decoder)
: ov::op::util::FrameworkNode(ov::OutputVector{}, 1),
m_name(name),
m_shape(ov::Shape{}),
m_type(ov::element::dynamic),
m_decoder(decoder),
m_is_initialized(false),
m_init_counter(0) {
validate_and_infer_types();
}

Variable(const std::string& name,
const ov::Shape& shape,
const ov::element::Type& type,
const std::shared_ptr<DecoderBase>& decoder)
: ov::op::util::FrameworkNode(ov::OutputVector{}, 1),
m_name(name),
m_shape(shape),
m_type(type),
m_decoder(decoder),
m_is_initialized(false),
m_init_counter(0) {
validate_and_infer_types();
}

Variable(const std::string& name,
const ov::Shape& shape,
const ov::element::Type& type,
const ov::Output<ov::Node>& value,
const std::shared_ptr<DecoderBase>& decoder)
: Variable(name, shape, type, decoder) {
m_value = value;
// reset names of tensor corresponding to variable value
// that is because variable can have multiple values during inference
m_value.set_names({});
m_is_initialized = true;
++m_init_counter;
}

Variable(const Variable& other, const ov::Output<ov::Node>& value) : Variable(other) {
m_value = value;
// reset names of tensor corresponding to variable value
// that is because variable can have multiple values during inference
m_value.set_names({});
m_is_initialized = true;
++m_init_counter;
}

void validate_and_infer_types() override {
set_output_type(0, m_type, m_shape);
}

/// \brief Checks if variable is initialized with some value
bool is_initialized() const {
return m_is_initialized;
}

/// \brief Returns a value at the current step of conversion
virtual ov::Output<ov::Node> get_value() {
FRONT_END_GENERAL_CHECK(m_is_initialized, "internal error: get_value() is called for uninitialized variable");
return m_value;
}

std::string get_name() const {
return m_name;
}

/// \brief Returns a counter value (a number of values that have assigned to this variable)
uint64_t get_init_counter() const {
return m_init_counter;
}

std::shared_ptr<DecoderBase> get_decoder() const {
return m_decoder;
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override {
auto new_variable = std::make_shared<Variable>(*this);
new_variable->set_attrs(get_attrs());
return new_variable;
}

virtual ~Variable();

protected:
std::string m_name;
ov::Shape m_shape;
ov::element::Type m_type;
std::shared_ptr<DecoderBase> m_decoder;
bool m_is_initialized;
ov::Output<ov::Node> m_value;
uint64_t m_init_counter;
};

} // namespace frontend
} // namespace ov
9 changes: 9 additions & 0 deletions src/frontends/common/src/hash_table.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/hash_table.hpp"

using namespace ov::frontend;

HashTable::~HashTable(){};
9 changes: 9 additions & 0 deletions src/frontends/common/src/variable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/variable.hpp"

using namespace ov::frontend;

Variable::~Variable(){};
Loading

0 comments on commit db22d0a

Please sign in to comment.