Whisper pipeline: add perf metrics (#971)
This PR adds:
- [x] support perf metrics
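
A minimal end-to-end sketch of reading the new metrics (the model directory is a placeholder; the getters mirror the `PerfMetrics` bindings updated below, with durations reported in ms and throughput in tokens/s):

```cpp
#include <iostream>

#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    // Placeholder exported model directory; any Whisper model converted for OpenVINO GenAI works the same way.
    ov::genai::WhisperPipeline pipeline("whisper-tiny-ov-dir", "CPU");

    // Stand-in input: 1 s of silence. Real input must be 16 kHz floats near [-1, 1] (see limitations below).
    ov::genai::RawSpeechInput raw_speech(16000, 0.0f);

    auto result = pipeline.generate(raw_speech);
    auto& metrics = result.perf_metrics;  // populated by this PR

    std::cout << "Generate duration:  " << metrics.get_generate_duration().mean << " ms\n";
    std::cout << "Inference duration: " << metrics.get_inference_duration().mean << " ms\n";
    std::cout << "TTFT:               " << metrics.get_ttft().mean << " ms\n";
    std::cout << "TPOT:               " << metrics.get_tpot().mean << " ms\n";
    std::cout << "Throughput:         " << metrics.get_throughput().mean << " tokens/s\n";
}
```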

Common TODOs for Whisper support:
- [ ] Long-form audio support with [parallel
chunking](https://huggingface.co/blog/asr-chunking).
- [ ] update documentation
- [ ] add C++ and Python sample tests
- [ ] support timestamps streaming
- [ ] expose only meaningful parameters in `GenerationConfig` (`task`,
`language`, `return_timestamps`, etc.)
- [ ] Move all whisper pipeline files to a dedicated subfolder
- [ ] The Whisper pipeline doesn't need a tokenizer, it uses the detokenizer
only. Implement detokenizer-only initialization for `ov::genai::Tokenizer`
- [ ] Check discrete GPU. Integrated GPU works as expected.
- [ ] Investigate use of `RemoteTensor` for GPU
- [ ] Add batch support
- [ ] Add a sampler, inherit WhisperGenerationConfig from GenerationConfig
- [ ] Investigate language autodetection with a single decoder call (without
past)
- [ ] Update the Python bindings CMake to include the whole directory instead
of an explicit list of files
- [ ] Add samples with audio preparation examples
- [ ] Add links to audio files in samples so users can download them
- [ ] Move supported models list from samples README to common supported
models section
- [ ] Avoid building GenAI in each test job as it takes a lot of time
- [ ] Double-check FP32 support
- [ ] Fix sporadic test failures. Sometimes the whisper model cannot be
downloaded from HF due to network issues
- [ ] Fix the stop criteria. The current approach stops on `eos_token`, which
signals no more speech, but there could be more speech tokens further on that
are wrongly skipped now
- [ ] Fix Distil-Whisper accuracy, match with HF
- [ ] Fix English-only (`.en`) models accuracy with timestamps, match with HF
- [ ] Try to trim the input_ids cache between chunks for long-form audio to
match HF

Completed:
- [x] support different languages, language autodetection
- [x] support translation
- [x] support timestamps
- [x] Long-form audio support with sequential chunking (see the config sketch below).
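
A minimal sketch of the completed language/translation/timestamps support on long-form input (the model directory and the `"<|en|>"` language format are assumptions; `task`, `language` and `return_timestamps` are the `WhisperGenerationConfig` fields referenced above):

```cpp
#include <iostream>

#include "openvino/genai/whisper_pipeline.hpp"

int main() {
    ov::genai::WhisperPipeline pipeline("whisper-tiny-ov-dir", "CPU");  // placeholder model dir

    ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
    config.task = "transcribe";        // or "translate"
    config.language = "<|en|>";        // assumed token format; omit to autodetect the language
    config.return_timestamps = true;   // long-form audio enables timestamps implicitly anyway

    // Stand-in input: 35 s of silence at 16 kHz; anything longer than 30 s goes through sequential chunking.
    ov::genai::RawSpeechInput raw_speech(16000 * 35, 0.0f);
    auto result = pipeline.generate(raw_speech, config);

    if (result.chunks) {
        for (auto& chunk : *result.chunks) {
            std::cout << "[" << chunk.start_ts << ", " << chunk.end_ts << "] " << chunk.text << "\n";
        }
    }
}
```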

Current limitations:
- No resampling during preprocessing. Input raw speech should have a 16 kHz
sampling rate
- No normalization during preprocessing. Input raw speech should be
normalized to a range near [-1, 1] (see the preparation sketch below)
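
A minimal preparation sketch under these constraints (assumes the source audio is already 16 kHz, 16-bit PCM, so only the normalization step is shown):

```cpp
#include <cstdint>
#include <vector>

// Convert 16-bit PCM samples (already sampled at 16 kHz) into floats normalized
// near [-1, 1], which is what the pipeline expects as raw speech input.
std::vector<float> pcm16_to_raw_speech(const std::vector<int16_t>& pcm) {
    std::vector<float> speech(pcm.size());
    for (size_t i = 0; i < pcm.size(); ++i) {
        speech[i] = static_cast<float>(pcm[i]) / 32768.0f;  // int16 full scale -> [-1, 1)
    }
    return speech;
}
```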

Tickets: CVS-147994, CVS-146010, CVS-152523
as-suvorov authored Oct 15, 2024
1 parent 1fdf96e commit a907b5f
Showing 8 changed files with 151 additions and 53 deletions.
@@ -35,6 +35,7 @@ int main(int argc, char* argv[]) try {
for (auto& chunk : *result.chunks) {
std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
}

} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
2 changes: 1 addition & 1 deletion src/cpp/src/perf_metrics.cpp
@@ -97,7 +97,7 @@ void PerfMetrics::evaluate_statistics(std::optional<TimePoint> start_time) {
if (m_evaluated){
return;
}
// If start_tiem is specified then recalcualte durations according to start times and calculate statistics only after that.
// If start_time is specified then recalculate durations according to start times and calculate statistics only after that.
if (start_time.has_value()) {
auto start_time_val = *start_time;
auto& tok_times = raw_metrics.m_new_token_times;
5 changes: 3 additions & 2 deletions src/cpp/src/whisper/timestamps.cpp
@@ -72,8 +72,9 @@ ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens
tokens.end());
}

// last timestamps generated in pairs <ts><ts><eos> -> speech segment continuation to the next chunk -> token_start will have value
// single ending timestamp <ts><eos> -> no more speech till the end of current chunk -> set offset to the end of frame
// last timestamps generated in pairs <ts><ts><eos> -> speech segment continuation to the next chunk -> token_start
// will have value single ending timestamp <ts><eos> -> no more speech till the end of current chunk -> set offset
// to the end of frame
if (!token_start.has_value()) {
extracted_segments.last_offset = nb_max_frames;
}
78 changes: 61 additions & 17 deletions src/cpp/src/whisper/whisper.cpp
@@ -10,6 +10,7 @@

#include "../utils.hpp"
#include "logit_processor.hpp"
#include "openvino/genai/perf_metrics.hpp"
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/genai/whisper_pipeline.hpp"
@@ -18,12 +19,15 @@
#include "whisper_feature_extractor.hpp"
#include "whisper_models.hpp"

using ov::genai::MicroSeconds;

namespace {

ov::Tensor encode(ov::InferRequest& request,
std::vector<float>& mel_data,
const size_t feature_size,
const size_t nb_max_frames) {
const size_t nb_max_frames,
ov::genai::RawPerfMetrics& raw_metrics) {
OPENVINO_ASSERT(mel_data.size() == feature_size * nb_max_frames,
"Mel spectrogram required size: ",
feature_size,
@@ -37,7 +41,10 @@ ov::Tensor encode(ov::InferRequest& request,

request.set_tensor("input_features", input_tensor);

const auto infer_start = std::chrono::steady_clock::now();
request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);
raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms);

// reset input tensor
request.set_tensor("input_features", ov::Tensor(ov::element::f32, {0, feature_size, nb_max_frames}));
@@ -72,18 +79,30 @@ void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
}
}

void infer_with_perf_metrics(ov::InferRequest& request, ov::genai::RawPerfMetrics& raw_metrics) {
const auto infer_start = std::chrono::steady_clock::now();
request.infer();
const auto infer_end = std::chrono::steady_clock::now();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start);
raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms);
raw_metrics.m_token_infer_durations.emplace_back(infer_ms);
raw_metrics.m_new_token_times.emplace_back(infer_end);
raw_metrics.m_batch_sizes.emplace_back(1);
}

int64_t decode(ov::Tensor& encoder_hidden_state,
ov::InferRequest& decoder,
std::vector<int64_t>& input_ids,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics,
const bool apply_logit_processors = true,
const bool return_timestamps = false) {
decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, input_ids.data());
decoder.set_tensor("input_ids", input_ids_tensor);

decoder.infer();
infer_with_perf_metrics(decoder, raw_metrics);

auto output_tensor = decoder.get_tensor("logits");

@@ -106,6 +125,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,
int64_t input_id,
const size_t cache_position,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics,
const bool return_timestamps,
const std::vector<int64_t>& generated_tokens) {
decoder_with_past.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
@@ -118,7 +138,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,
cache_position_tensor.set_shape({1});
cache_position_tensor.data<int64_t>()[0] = cache_position;

decoder_with_past.infer();
infer_with_perf_metrics(decoder_with_past, raw_metrics);

auto output_tensor = decoder_with_past.get_tensor("logits");

@@ -137,7 +157,17 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
ov::InferRequest decoder,
const ov::genai::WhisperGenerationConfig& config) {
std::vector<int64_t> input_ids{config.decoder_start_token_id};
int64_t output_token = decode(encoder_hidden_state, decoder, input_ids, config, false, false);

decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, input_ids.data());
decoder.set_tensor("input_ids", input_ids_tensor);

decoder.infer();

auto output_tensor = decoder.get_tensor("logits");

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

return output_token;
}
@@ -181,8 +211,10 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
std::vector<int64_t> init_ids,
const size_t max_new_tokens,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics,
const std::shared_ptr<ov::genai::StreamerBase> streamer) {
int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, true, return_timestamps);
int64_t output_token =
decode(encoder_hidden_state, models.decoder, init_ids, config, raw_metrics, true, return_timestamps);

std::vector<int64_t> output_tokens{output_token};

@@ -203,6 +235,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
output_tokens.back(),
init_ids.size() + output_tokens.size() - 1,
config,
raw_metrics,
return_timestamps,
output_tokens);

@@ -230,23 +263,30 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
namespace ov {
namespace genai {

std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_generate(
const ov::genai::WhisperGenerationConfig& config,
const ov::genai::WhisperConfig& model_config,
const RawSpeechInput& raw_speech,
ov::genai::WhisperInitializedModels& models,
WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<StreamerBase> streamer) {
WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config,
const ov::genai::WhisperConfig& model_config,
const RawSpeechInput& raw_speech,
ov::genai::WhisperInitializedModels& models,
WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<StreamerBase> streamer) {
auto input_features = feature_extractor.extract(raw_speech);

const bool is_shortform = input_features.n_frames <= feature_extractor.nb_max_frames;
// long-form audio processing requires timestamps to be enabled
const bool return_timestamps = config.return_timestamps || !is_shortform;

std::vector<int64_t> init_ids;
std::vector<int64_t> output_tokens;
size_t max_new_tokens = config.get_max_new_tokens();

WhisperGenerateResult result;
RawPerfMetrics& raw_metrics = result.perf_metrics.raw_metrics;
result.perf_metrics.num_input_tokens = 0;
raw_metrics.m_new_token_times.reserve(max_new_tokens);
raw_metrics.m_batch_sizes.reserve(max_new_tokens);
raw_metrics.m_token_infer_durations.reserve(max_new_tokens);
raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}};

std::vector<int64_t> init_ids;
std::vector<int64_t>& output_tokens = result.output_tokens;
std::vector<Segment> segments;

// 0.02 by default
@@ -263,7 +303,8 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
ov::Tensor hidden_state_tensor = encode(models.encoder,
input_features_chunk,
feature_extractor.feature_size,
feature_extractor.nb_max_frames);
feature_extractor.nb_max_frames,
raw_metrics);

// prepare init_ids just once for whole input
if (init_ids.empty()) {
@@ -276,6 +317,7 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
init_ids,
max_new_tokens - output_tokens.size(),
return_timestamps,
raw_metrics,
streamer);

if (return_timestamps) {
@@ -310,10 +352,12 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen

// if return_timestamps wasn't enabled by user
if (!config.return_timestamps) {
return {output_tokens, std::nullopt};
return result;
}

return {output_tokens, segments};
result.segments = segments;

return result;
}
} // namespace genai
} // namespace ov
19 changes: 12 additions & 7 deletions src/cpp/src/whisper/whisper.hpp
@@ -20,13 +20,18 @@ struct Segment {
std::vector<int64_t> m_tokens;
};

std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_generate(
const ov::genai::WhisperGenerationConfig& config,
const ov::genai::WhisperConfig& model_config,
const ov::genai::RawSpeechInput& raw_speech,
ov::genai::WhisperInitializedModels& models,
ov::genai::WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<StreamerBase> streamer);
struct WhisperGenerateResult {
std::vector<int64_t> output_tokens;
std::optional<std::vector<Segment>> segments = std::nullopt;
PerfMetrics perf_metrics;
};

WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& config,
const ov::genai::WhisperConfig& model_config,
const ov::genai::RawSpeechInput& raw_speech,
ov::genai::WhisperInitializedModels& models,
ov::genai::WhisperFeatureExtractor& feature_extractor,
const std::shared_ptr<StreamerBase> streamer);

} // namespace genai
} // namespace ov
53 changes: 34 additions & 19 deletions src/cpp/src/whisper_pipeline.cpp
@@ -93,28 +93,43 @@ class WhisperPipeline::Impl {
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

auto [output_tokens, segments] = ov::genai::whisper_generate(config,
m_model_config,
raw_speech_input,
m_models,
m_feature_extractor,
streamer_ptr);

WhisperDecodedResults decoded_results{std::vector{m_tokenizer.decode(output_tokens)}, std::vector{1.f}};
if (!segments.has_value()) {
return decoded_results;
auto generate_result = ov::genai::whisper_generate(config,
m_model_config,
raw_speech_input,
m_models,
m_feature_extractor,
streamer_ptr);
auto decode_start_time = std::chrono::steady_clock::now();
WhisperDecodedResults result{std::vector{m_tokenizer.decode(generate_result.output_tokens)}, std::vector{1.f}};
generate_result.perf_metrics.raw_metrics.detokenization_durations.emplace_back(
PerfMetrics::get_microsec(std::chrono::steady_clock::now() - decode_start_time));

result.perf_metrics = generate_result.perf_metrics;
auto& segments = generate_result.segments;

if (segments.has_value()) {
std::vector<WhisperDecodedResultChunk> chunks;
chunks.reserve((*segments).size());

for (auto& segment : *segments) {
decode_start_time = std::chrono::steady_clock::now();
chunks.push_back(
WhisperDecodedResultChunk{segment.m_start, segment.m_end, m_tokenizer.decode(segment.m_tokens)});
result.perf_metrics.raw_metrics.detokenization_durations.emplace_back(
PerfMetrics::get_microsec(std::chrono::steady_clock::now() - decode_start_time));
}

result.chunks = chunks;
}

std::vector<WhisperDecodedResultChunk> chunks;
chunks.reserve((*segments).size());
auto& metrics = result.perf_metrics;
metrics.load_time = this->m_load_time_ms;
auto stop_time = std::chrono::steady_clock::now();
metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time));
result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSeconds(0.0f));
metrics.evaluate_statistics(start_time);

for (auto& segment : *segments) {
chunks.push_back(
WhisperDecodedResultChunk{segment.m_start, segment.m_end, m_tokenizer.decode(segment.m_tokens)});
}

decoded_results.chunks = chunks;
return decoded_results;
return result;
}
};

2 changes: 2 additions & 0 deletions src/python/py_generate_pipeline.cpp
@@ -637,8 +637,10 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def("get_num_input_tokens", &PerfMetrics::get_num_input_tokens)
.def("get_ttft", &PerfMetrics::get_ttft)
.def("get_tpot", &PerfMetrics::get_tpot)
.def("get_ipot", &PerfMetrics::get_ipot)
.def("get_throughput", &PerfMetrics::get_throughput)
.def("get_generate_duration", &PerfMetrics::get_generate_duration)
.def("get_inference_duration", &PerfMetrics::get_inference_duration)
.def("get_tokenization_duration", &PerfMetrics::get_tokenization_duration)
.def("get_detokenization_duration", &PerfMetrics::get_detokenization_duration)
.def("__add__", &PerfMetrics::operator+)
44 changes: 37 additions & 7 deletions tests/python_tests/test_whisper_generate_api.py
@@ -131,6 +131,25 @@ def test_whisper_on_hf_dataset(model_descr, dataset_id):
compare_genai_and_opt_pipelines(opt_pipe, genai_pipe, dataset_id)


@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize(
"test_sample",
get_samples_from_dataset(language="en", length=1),
)
@pytest.mark.precommit
def test_smoke(model_descr, test_sample):
model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)

expected = opt_pipe(test_sample)

genai_result = pipe.generate(test_sample)

assert genai_result.texts[0] == expected["text"]

assert "chunks" not in expected
assert genai_result.chunks == None


@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.precommit
def test_whisper_config_constructor(model_descr):
@@ -509,17 +528,28 @@ def test_longform_audio(model_descr, test_sample):
@pytest.mark.parametrize("model_descr", get_whisper_models_list(tiny_only=True))
@pytest.mark.parametrize(
"test_sample",
get_samples_from_dataset(language="en", length=1),
[
*get_samples_from_dataset(language="en", length=1),
],
)
@pytest.mark.precommit
def test_smoke(model_descr, test_sample):
def test_perf_metrics(model_descr, test_sample):
model_id, path, opt_pipe, pipe = read_whisper_model(model_descr)

expected = opt_pipe(test_sample)
result = pipe.generate(test_sample)

genai_result = pipe.generate(test_sample)
perf_metrics = result.perf_metrics

assert genai_result.texts[0] == expected["text"]
assert perf_metrics is not None

assert "chunks" not in expected
assert genai_result.chunks == None
assert perf_metrics.get_load_time() > 0
assert perf_metrics.get_num_generated_tokens() > 0
assert perf_metrics.get_num_input_tokens() == 0
assert perf_metrics.get_ttft().mean > 0
assert perf_metrics.get_tpot().mean > 0
assert perf_metrics.get_ipot().mean > 0
assert perf_metrics.get_throughput().mean > 0
assert perf_metrics.get_inference_duration().mean > 0
assert perf_metrics.get_generate_duration().mean > 0
assert perf_metrics.get_tokenization_duration().mean == 0
assert perf_metrics.get_detokenization_duration().mean > 0
