Whisper pipeline: support stateful decoder (#1474)
Ticket: 159473
Optimum-intel PR: huggingface/optimum-intel#1078
This PR switches optimum-intel in tests to the stateful seq2seq branch. Tests check both stateful and with-past decoders. Once the optimum-intel PR is merged, I'll switch the version back to master.
as-suvorov authored Jan 13, 2025
1 parent 505abe8 commit 67d6cd3
Showing 15 changed files with 477 additions and 214 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/windows.yml
@@ -311,10 +311,9 @@ jobs:
python -m pip install . --verbose --find-links ${env:OV_INSTALL_DIR}/wheels
python -m pip install ./tools/who_what_benchmark --find-links ${env:OV_INSTALL_DIR}/wheels
# will install transformers 4.46.3 version
# transformers 4.46.3 will enable return_timestamps tests
# this check enabled for windows only. Ticket: 160205.
python -m pip install git+https://github.com/huggingface/optimum-intel.git@753f84db6e0966580eb9eaa74a808213be730631
python -m pip install transformers==4.46.3
python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k "not test_smoke"
2 changes: 1 addition & 1 deletion src/README.md
@@ -179,7 +179,7 @@ int main(int argc, char* argv[]) {

Streaming with a custom class:

-C++ template for a stremer.
+C++ template for a streamer.
```cpp
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/llm_pipeline.hpp"
17 changes: 17 additions & 0 deletions src/cpp/src/logger.hpp
@@ -0,0 +1,17 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include <iostream>
#include <string>

namespace ov::genai {

class Logger {
public:
static void warn(const std::string& message) {
std::cout << "[WARN] " << message << '\n';
}
};

} // namespace ov::genai
26 changes: 26 additions & 0 deletions src/cpp/src/whisper/models/decoder.cpp
@@ -0,0 +1,26 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decoder.hpp"

#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
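// The exported file set decides the implementation: legacy exports ship a separate
// with_past decoder, while newer optimum-intel exports produce a single stateful decoder.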
bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

if (has_decoder_with_past) {
return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
}

return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

WhisperDecoder::~WhisperDecoder() = default;
} // namespace ov::genai
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/decoder.hpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>
#include <memory>
#include <utility>
#include <vector>

#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {
class WhisperDecoder {
public:
static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) = 0;

virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) = 0;

virtual void reset_state() = 0;

virtual ~WhisperDecoder();
};
} // namespace ov::genai
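
A minimal caller sketch for the new interface (hypothetical wiring: the model path, device, loop bound, and hardcoded start-token id below are placeholder assumptions, not code from this PR):

```cpp
#include <cstdint>
#include <vector>

#include "decoder.hpp"

// Hypothetical greedy loop over the WhisperDecoder interface above.
// `encoder_hidden_state` is assumed to come from the Whisper encoder model.
void decode_sketch(const ov::Tensor& encoder_hidden_state) {
    auto decoder = ov::genai::WhisperDecoder::from_path("whisper-tiny", "CPU", {});

    // 50258 is <|startoftranscript|> for multilingual Whisper; production code
    // reads this id from the generation config instead of hardcoding it.
    const int64_t start_token = 50258;
    auto [language_token, lang_ms] = decoder->detect_language(encoder_hidden_state, start_token);

    // The first call passes the whole prompt; cache_position 0 fills a fresh KV cache.
    std::vector<int64_t> input_ids{start_token, language_token};
    size_t cache_position = 0;
    float total_infer_ms = lang_ms;
    for (size_t step = 0; step < 32; ++step) {
        auto [logits, infer_ms] = decoder->decode(encoder_hidden_state, input_ids, cache_position);
        total_infer_ms += infer_ms;
        cache_position += input_ids.size();
        // Greedy step: take the argmax over the last position in `logits`,
        // stop on <|endoftext|>, then feed only the new token next iteration:
        // input_ids = {next_token};
    }
    decoder->reset_state();
}
```
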
60 changes: 60 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include "utils.hpp"

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);

utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
m_request.set_tensor("input_ids", input_ids_tensor);

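// cache_position enumerates the absolute positions of the new tokens within the KV cache.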
ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
cache_position_tensor.set_shape({input_ids.size()});

auto cache_data = cache_position_tensor.data<int64_t>();
std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);

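// The stateful model exposes a beam_idx input for reordering cached states across beams;
// greedy decoding here always selects beam 0.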
m_request.get_tensor("beam_idx").set_shape({1});
m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;

const auto infer_start = std::chrono::steady_clock::now();
m_request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = m_request.get_tensor("logits");

return {output_tensor, infer_ms};
}

void WhisperStatefullDecoder::reset_state() {
m_request.reset_state();
}
} // namespace ov::genai
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.hpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperStatefullDecoder : public WhisperDecoder {
public:
WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request;
};
} // namespace ov::genai
107 changes: 107 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.cpp
@@ -0,0 +1,107 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <chrono>
#include <regex>

#include "logger.hpp"
#include "utils.hpp"

namespace {
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
// source outputs:
// present.0.decoder.key
// present.0.decoder.value
// present.0.encoder.key
// present.0.encoder.value

// dest inputs:
// past_key_values.0.decoder.key
// past_key_values.0.decoder.value
// past_key_values.0.encoder.key
// past_key_values.0.encoder.value

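// Copying an ov::Tensor shares the underlying buffer, so this wires each present
// output to the matching past input without copying the KV data itself.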
for (auto& source_output : source.get_compiled_model().outputs()) {
std::string source_output_name = source_output.get_any_name();
if (source_output_name.find("logits") != std::string::npos) {
continue;
}

std::string with_past_input_name =
std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

auto kv_tensor = source.get_tensor(source_output_name);
dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
}
}
} // namespace

namespace ov::genai {
WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
Logger::warn("Whisper decoder models with past is deprecated. Support will be removed in 2026.0.0 release.\n"
"To obtain stateful decoder model use latest `optimum-intel` package:\n"
"pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git\n"
"optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request_decoder = compiled_model.create_infer_request();

compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
m_request_decoder_with_past = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
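// The first step (empty cache) runs the full decoder; later steps run the
// with_past decoder that consumes the cached keys/values.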
const bool initial_step = cache_position == 0;
ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

request.set_tensor("encoder_hidden_states", encoder_hidden_state);

const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
request.set_tensor("input_ids", input_ids_tensor);

if (!initial_step) {
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
cache_position_tensor.set_shape({1});
cache_position_tensor.data<int64_t>()[0] = cache_position;
}

const auto infer_start = std::chrono::steady_clock::now();
request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = request.get_tensor("logits");

if (initial_step) {
set_past_key_value(m_request_decoder, m_request_decoder_with_past);
} else if (!m_decoder_with_past_kv_value_set) {
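// Self-link: after the first with_past inference, point the request's past_key_values
// inputs at its own present outputs so later steps reuse the cache in place.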
set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
m_decoder_with_past_kv_value_set = true;
}

return {output_tensor, infer_ms};
}

void WhisperWithPastDecoder::reset_state() {
m_request_decoder_with_past.reset_state();
m_decoder_with_past_kv_value_set = false;
}
} // namespace ov::genai
32 changes: 32 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.hpp
@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperWithPastDecoder : public WhisperDecoder {
public:
WhisperWithPastDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request_decoder;
ov::InferRequest m_request_decoder_with_past;
bool m_decoder_with_past_kv_value_set = false;
};

} // namespace ov::genai
