Skip to content

Commit

Permalink
start_async + wait
Browse files Browse the repository at this point in the history
  • Loading branch information
dkalinowski committed Jul 22, 2024
1 parent d43789c commit 801cec6
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/cpp/src/tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class Tokenizer::TokenizerImpl {
CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
size_t batch_size = 1;
infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
infer_request_guard.get().infer();
infer_request_guard.get().start_async();
infer_request_guard.get().wait();
return get_copied_results(
infer_request_guard.get().get_tensor("input_ids"),
infer_request_guard.get().get_tensor("attention_mask")
Expand All @@ -262,7 +263,8 @@ class Tokenizer::TokenizerImpl {
CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
auto size_ = infer_request_guard.get().get_input_tensor().get_shape();
infer_request_guard.get().infer();
infer_request_guard.get().start_async();
infer_request_guard.get().wait();

unpadded = get_copied_results(
infer_request_guard.get().get_tensor("input_ids"),
Expand All @@ -285,7 +287,8 @@ class Tokenizer::TokenizerImpl {
CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
size_t batch_size = 1;
infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::i64, {batch_size, tokens.size()}, tokens.data()});
infer_request_guard.get().infer();
infer_request_guard.get().start_async();
infer_request_guard.get().wait();
return infer_request_guard.get().get_output_tensor().data<std::string>()[0];
}

Expand All @@ -295,7 +298,8 @@ class Tokenizer::TokenizerImpl {

CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
infer_request_guard.get().set_input_tensor(tokens);
infer_request_guard.get().infer();
infer_request_guard.get().start_async();
infer_request_guard.get().wait();

auto res = infer_request_guard.get().get_output_tensor();
auto res_data = res.data<std::string>();
Expand All @@ -320,7 +324,8 @@ class Tokenizer::TokenizerImpl {

CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_detokenizer.get());
infer_request_guard.get().set_input_tensor(tokens);
infer_request_guard.get().infer();
infer_request_guard.get().start_async();
infer_request_guard.get().wait();
auto res = infer_request_guard.get().get_output_tensor();
auto res_data = res.data<std::string>();
return std::vector<std::string>(res_data, res_data + res.get_shape()[0]);
Expand Down

0 comments on commit 801cec6

Please sign in to comment.