Skip to content

Commit

Permalink
[TF FE] Support StringToHashBucketFast operation (#115)
Browse files Browse the repository at this point in the history
* [TF FE] Support StringToHashBucketFast operation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix build issue

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix attribute type

Signed-off-by: Kazantsev, Roman <[email protected]>

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Apr 9, 2024
1 parent fd220db commit f1a34f0
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFindV2", translate_lookup_table_find_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringSplitV2", translate_string_split), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("RaggedTensorToTensor", translate_ragged_tensor_to_tensor), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("Equal", translate_equal)
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("Equal", translate_equal), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringToHashBucketFast", translate_string_to_hash_bucket_fast)
#else
#define OPENVINO_TOKENIZERS_TENSORFLOW_CONVERSION_EXTENSIONS
#endif
Expand All @@ -47,6 +48,7 @@ OPENVINO_CREATE_EXTENSIONS(
std::make_shared<ov::OpExtension<RaggedToDense>>(),
std::make_shared<ov::OpExtension<RaggedToSparse>>(),
std::make_shared<ov::OpExtension<RaggedToRagged>>(),
std::make_shared<ov::OpExtension<StringToHashBucket>>(),
std::make_shared<ov::OpExtension<VocabEncoder>>(),
std::make_shared<ov::OpExtension<VocabDecoder>>(),
std::make_shared<ov::OpExtension<CharsToBytes>>(),
Expand Down
218 changes: 218 additions & 0 deletions src/string_to_hash_bucket.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "string_to_hash_bucket.hpp"
#include "utils.hpp"

using namespace ov;

namespace {
static const uint64_t k0 = 0xc3a5c85c97cb3127ULL;
static const uint64_t k1 = 0xb492b66fbe98f273ULL;
static const uint64_t k2 = 0x9ae16a3b2f90404fULL;

uint64_t hash_len16(uint64_t u, uint64_t v, uint64_t mul) {
uint64_t a = (u ^ v) * mul;
a ^= (a >> 47);
uint64_t b = (v ^ a) * mul;
b ^= (b >> 47);
b *= mul;
return b;
}

inline uint64_t basic_rotate64(uint64_t val, int shift) {
return shift == 0 ? val : ((val >> shift) | (val << (64 - shift)));
}

inline uint64_t fetch(const char* p) {
uint64_t result;
std::memcpy(&result, p, sizeof(result));
return result;
}

#if defined(_MSC_VER)

uint64_t rotate(uint64_t val, int shift) {
return sizeof(unsigned long) == sizeof(val) ? _lrotr(val, shift) : basic_rotate64(val, shift);
}

#else

uint64_t rotate(uint64_t val, int shift) {
return basic_rotate64(val, shift);
}

#endif

uint64_t hash_len17_to_32(const char* s, size_t len) {
uint64_t mul = k2 + len * 2;
uint64_t a = fetch(s) * k1;
uint64_t b = fetch(s + 8);
uint64_t c = fetch(s + len - 8) * mul;
uint64_t d = fetch(s + len - 16) * k2;
return hash_len16(rotate(a + b, 43) + rotate(c, 30) + d, a + rotate(b + k2, 18) + c, mul);
}

inline uint64_t shift_mix(uint64_t val) {
return val ^ (val >> 47);
}

inline uint32_t fetch32(const char* p) {
uint32_t result;
memcpy(&result, p, sizeof(result));
return result;
}


uint64_t hash_len0_to_16(const char* s, size_t len) {
if (len >= 8) {
uint64_t mul = k2 + len * 2;
uint64_t a = fetch(s) + k2;
uint64_t b = fetch(s + len - 8);
uint64_t c = rotate(b, 37) * mul + a;
uint64_t d = (rotate(a, 25) + b) * mul;
return hash_len16(c, d, mul);
}
if (len >= 4) {
uint64_t mul = k2 + len * 2;
uint64_t a = fetch32(s);
return hash_len16(len + (a << 3), fetch32(s + len - 4), mul);
}
if (len > 0) {
uint8_t a = s[0];
uint8_t b = s[len >> 1];
uint8_t c = s[len - 1];
uint32_t y = static_cast<uint32_t>(a) + (static_cast<uint32_t>(b) << 8);
uint32_t z = len + (static_cast<uint32_t>(c) << 2);
return shift_mix(y * k2 ^ z * k0) * k2;
}
return k2;
}

uint64_t hash_len33_to_64(const char* s, size_t len) {
uint64_t mul = k2 + len * 2;
uint64_t a = fetch(s) * k2;
uint64_t b = fetch(s + 8);
uint64_t c = fetch(s + len - 8) * mul;
uint64_t d = fetch(s + len - 16) * k2;
uint64_t y = rotate(a + b, 43) + rotate(c, 30) + d;
uint64_t z = hash_len16(y, a + rotate(b + k2, 18) + c, mul);
uint64_t e = fetch(s + 16) * mul;
uint64_t f = fetch(s + 24);
uint64_t g = (y + fetch(s + len - 32)) * mul;
uint64_t h = (z + fetch(s + len - 24)) * mul;
return hash_len16(rotate(e + f, 43) + rotate(g, 30) + h, e + rotate(f + a, 18) + g, mul);
}

std::pair<uint64_t, uint64_t> weak_hash_len32_with_seeds(uint64_t w,
uint64_t x,
uint64_t y,
uint64_t z,
uint64_t a,
uint64_t b) {
a += w;
b = rotate(b + a + z, 21);
uint64_t c = a;
a += x;
a += y;
b += rotate(a, 44);
return std::make_pair(a + z, b + c);
}

std::pair<uint64_t, uint64_t> weak_hash_len32_with_seeds(const char* s, uint64_t a, uint64_t b) {
return weak_hash_len32_with_seeds(fetch(s), fetch(s + 8), fetch(s + 16), fetch(s + 24), a, b);
}

uint64_t hash64(const char* s, size_t len) {
const uint64_t seed = 81;
if (len <= 32) {
if (len <= 16) {
return hash_len0_to_16(s, len);
}
else {
return hash_len17_to_32(s, len);
}
}
else if (len <= 64) {
return hash_len33_to_64(s, len);
}

// For strings over 64 bytes we loop. Internal state consists of
// 56 bytes: v, w, x, y, and z.
uint64_t x = seed;
uint64_t y = seed * k1 + 113;
uint64_t z = shift_mix(y * k2 + 113) * k2;
std::pair<uint64_t, uint64_t> v = std::make_pair(0, 0);
std::pair<uint64_t, uint64_t> w = std::make_pair(0, 0);
x = x * k2 + fetch(s);

// Set end so that after the loop we have 1 to 64 bytes left to process.
const char* end = s + ((len - 1) / 64) * 64;
const char* last64 = end + ((len - 1) & 63) - 63;
do {
x = rotate(x + y + v.first + fetch(s + 8), 37) * k1;
y = rotate(y + v.second + fetch(s + 48), 42) * k1;
x ^= w.second;
y += v.first + fetch(s + 40);
z = rotate(z + w.first, 33) * k1;
v = weak_hash_len32_with_seeds(s, v.second * k1, x + w.first);
w = weak_hash_len32_with_seeds(s + 32, z + w.second, y + fetch(s + 16));
std::swap(z, x);
s += 64;
} while (s != end);
uint64_t mul = k1 + ((z & 0xff) << 1);
// Make s point to the last 64 bytes of input.
s = last64;
w.first += ((len - 1) & 63);
v.first += w.first;
w.first += v.first;
x = rotate(x + y + v.first + fetch(s + 8), 37) * mul;
y = rotate(y + v.second + fetch(s + 48), 42) * mul;
x ^= w.second * 9;
y += v.first * 9 + fetch(s + 40);
z = rotate(z + w.first, 33) * mul;
v = weak_hash_len32_with_seeds(s, v.second * mul, x + w.first);
w = weak_hash_len32_with_seeds(s + 32, z + w.second, y + fetch(s + 16));
std::swap(z, x);
return hash_len16(hash_len16(v.first, w.first, mul) + shift_mix(y) * k0 + z,
hash_len16(v.second, w.second, mul) + x,
mul);
}

uint64_t hash64(const std::string& str) {
return hash64(str.data(), str.size());
}
}

void StringToHashBucket::validate_and_infer_types() {
OPENVINO_ASSERT(get_input_size() == 3);

auto begins_type = this->get_input_element_type(0);
auto ends_type = this->get_input_element_type(1);
auto output_shape = this->get_input_partial_shape(0);

OPENVINO_ASSERT(begins_type == element::i32 && ends_type == element::i32,
"Expected an i32 begins and ends for string tensor representation.");
OPENVINO_ASSERT(m_num_buckets > 0, "num_buckets attribute must be positive");

set_output_type(0, ov::element::i64, output_shape);
}

bool StringToHashBucket::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
auto begins = inputs[0].data<const int32_t>();
auto ends = inputs[1].data<const int32_t>();
auto chars = inputs[2].data<const char>();

auto output_shape = inputs[0].get_shape();
outputs[0].set_shape(output_shape);
auto result = outputs[0].data<int64_t>();

auto num_elems = inputs[0].get_size();
for (size_t ind = 0; ind < num_elems; ++ind) {
OPENVINO_ASSERT(begins[ind] <= ends[ind]);
result[ind] = hash64(chars + begins[ind], static_cast<size_t>(ends[ind] - begins[ind])) % m_num_buckets;
}

return true;
}
41 changes: 41 additions & 0 deletions src/string_to_hash_bucket.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/op/op.hpp>

// StringToHashBucket computes a bucket index for each string element
class StringToHashBucket : public ov::op::Op {
public:
OPENVINO_OP("StringToHashBucket");

StringToHashBucket() = default;

StringToHashBucket(ov::OutputVector inputs, int64_t num_buckets)
: m_num_buckets(num_buckets), ov::op::Op(inputs) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override;

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& inputs) const override {
auto result = std::make_shared<StringToHashBucket>(inputs, m_num_buckets);
return result;
}

bool visit_attributes(ov::AttributeVisitor& visitor) override {
visitor.on_attribute("num_buckets", m_num_buckets);
return true;
}

bool has_evaluate() const override {
return true;
}

bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;

private:
int64_t m_num_buckets;
};
19 changes: 19 additions & 0 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "ragged_to_ragged.hpp"
#include "regex_normalization.hpp"
#include "regex_split.hpp"
#include "string_to_hash_bucket.hpp"
#include "vocab_encoder.hpp"

#include "wordpiece_tokenizer.hpp"
Expand Down Expand Up @@ -492,3 +493,21 @@ ov::OutputVector translate_equal(const ov::frontend::NodeContext& node) {

return { result };
}

ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeContext& node) {
auto node_name = node.get_name();
auto node_input_size = node.get_input_size();
TENSORFLOW_OP_VALIDATION(node, node_input_size == 1,
"[TensorFlow Frontend] inconsistent model: StringToHashBucketFast must have one input");
auto input = node.get_input(0);
auto num_buckets = node.get_attribute<int64_t>("num_buckets");
TENSORFLOW_OP_VALIDATION(node, num_buckets > 0,
"[TensorFlow Frontend] inconsistent model: num_buckets for StringToHashBucketFast must be positive");

ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input);
ov::Output<ov::Node> result = std::make_shared<StringToHashBucket>(unpacked_input, num_buckets);

result.get_node_shared_ptr()->set_friendly_name(node_name);
result.set_names({ node_name + ":0" });
return { result };
}
1 change: 1 addition & 0 deletions src/tensorflow_translators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ ov::OutputVector translate_wordpiece_tokenize_with_offsets(const ov::frontend::N
ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node);
ov::OutputVector translate_ragged_tensor_to_tensor(const ov::frontend::NodeContext& node);
ov::OutputVector translate_equal(const ov::frontend::NodeContext& node);
ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeContext& node);
1 change: 1 addition & 0 deletions src/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ragged_to_dense.hpp"
#include "ragged_to_sparse.hpp"
#include "ragged_to_ragged.hpp"
#include "string_to_hash_bucket.hpp"
#include "vocab_decoder.hpp"
#include "vocab_encoder.hpp"
#include "chars_to_bytes.hpp"
Expand Down

0 comments on commit f1a34f0

Please sign in to comment.