[TRANSFORMATIONS] Extend SDPAToPagedAttention transformation with the score output (#25621)

Extend the SDPAToPagedAttention transformation with a score output used
for cache eviction. Add the use_block_indices_inputs and use_score_outputs
flags to the transformation constructor to enable/disable the feature
(see the usage sketch below).

Tickets:
	* CVS-146959

Signed-off-by: Andrii Staikov <[email protected]>

---------

Signed-off-by: Andrii Staikov <[email protected]>
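
A minimal usage sketch of the new binding arguments: the import path and the model file below are assumptions, while the function name and flag names match the py::arg definitions in this diff.

    import openvino as ov
    # Assumed import location for the binding registered in regmodule_offline_transformations.
    from openvino._offline_transformations import paged_attention_transformation

    model = ov.Core().read_model("model.xml")  # hypothetical SDPA-based model

    # Both flags default to False. Enabling them makes the transformation add a
    # per-layer "block_indices.<N>" input and a "scores.<N>" output used for cache eviction.
    paged_attention_transformation(model,
                                   use_block_indices_inputs=True,
                                   use_score_outputs=True)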
CuriousPanCake authored Jul 22, 2024
1 parent 7e0c442 commit e2a4bbf
Showing 7 changed files with 119 additions and 29 deletions.
@@ -132,12 +132,14 @@ void regmodule_offline_transformations(py::module m) {

m_offline_transformations.def(
"paged_attention_transformation",
-[](std::shared_ptr<ov::Model> model) {
+[](std::shared_ptr<ov::Model> model, bool use_block_indices_inputs, bool use_score_outputs) {
ov::pass::Manager manager;
-manager.register_pass<ov::pass::SDPAToPagedAttention>();
+manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs);
manager.run_passes(model);
},
-py::arg("model"));
+py::arg("model"),
+py::arg("use_block_indices_inputs") = false,
+py::arg("use_score_outputs") = false);

m_offline_transformations.def(
"stateful_to_stateless_transformation",
@@ -19,9 +19,13 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("StateManagementPattern", "0");
StateManagementPattern(ParameterVector& kv_parameters,
-const ParameterVector& model_remaining_params,
+ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
ParameterVector& parameters_to_remove,
int& layer_index,
-ov::Output<Node> max_context_len);
+ov::Output<Node> max_context_len,
+ParameterVector& block_indices_inputs,
+ResultVector& score_results,
+bool use_block_indices,
+bool use_score_outputs);
};
@@ -35,8 +35,7 @@ static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> nod
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
// given single name)
node->set_friendly_name(name);
-OPENVINO_ASSERT(node->get_output_size() ==
-                1); // Should I use assert here? I heard using ASSERTS is not the best thing
+OPENVINO_ASSERT(node->get_output_size() == 1);
node->get_output_tensor(0).set_names({name});
return node;
}
@@ -64,11 +63,15 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
}

ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters,
-const ParameterVector& model_remaining_params,
+ParameterVector& model_remaining_params,
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
ParameterVector& parameters_to_remove,
int& layer_index,
-Output<Node> max_context_len) {
+Output<Node> max_context_len,
+ParameterVector& block_indices_inputs,
+ResultVector& score_results,
+bool use_block_indices_inputs,
+bool use_score_outputs) {
MATCHER_SCOPE(StateManagementPattern);

auto k_current = pattern::any_input();
@@ -163,6 +166,8 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
&model_remaining_params,
&sliding_window,
&parameters_to_remove,
&block_indices_inputs,
&score_results,
&layer_index](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto real_q = pattern_map.at(q);
@@ -236,7 +241,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
if (pattern_map.count(qkv_current_split_node)) {
// Fast track for merged K/V caches, based on the currently observed models topologies we don't need to
// change layout and there is no point in the graph where it is in 4D. So `else` branch below is not
-// applicable for this case.
+// applicable for this case. + std::to_string(layer_index - 1)
auto qkv_split = pattern_map.at(qkv_current_split_node).get_node_shared_ptr();
// TODO: Consider handling Q part as well as KV here, requires more changes in the code and sets
// VariadicSplit before Concat as essential part of the pattern
@@ -270,8 +275,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
auto real_k = take_4d(k_current, k_current_reshaped, k_current2);
auto real_v = take_4d(v_current, v_current_reshaped, v_current2);

-std::shared_ptr<Node> k_transpose_order =
-    kv_transpose_order; // eeeh, is it a right way to assign Constants? Maybe I should clone somehow?
+std::shared_ptr<Node> k_transpose_order = kv_transpose_order;
if (pattern_map.find(k_order) !=
pattern_map
.end()) { // reapply transpose found in the graph by manipulating of indices of our Transpose
@@ -280,8 +284,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
v0::Constant::create(element::i64, Shape{}, {0}));
}
k_target_layout = std::make_shared<v1::Transpose>(real_k, k_transpose_order);
-std::shared_ptr<Node> v_transpose_order =
-    kv_transpose_order; // eeeh, is it a right way to assign Constants? Maybe I should clone somehow?
+std::shared_ptr<Node> v_transpose_order = kv_transpose_order;
if (pattern_map.find(v_order) !=
pattern_map
.end()) { // reapply transpose found in the graph by manipulating of indices of our Transpose
@@ -317,24 +320,30 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
if (pattern_map.find(alibi) != pattern_map.end()) {
alibi_slopes = std::make_shared<v1::Reshape>(pattern_map.at(alibi),
v0::Constant::create(element::i64, Shape{1}, {-1}),
-false); // here {-1} is interesting in Python TODO: discuss
+false);
if (alibi_slopes->get_element_type() == element::f32) {
alibi_slopes = std::make_shared<v0::Convert>(alibi_slopes, element::f32);
}
} else {
-alibi_slopes = v0::Constant::create(element::f32, Shape{0}, {}); // correctly created?
+alibi_slopes = v0::Constant::create(element::f32, Shape{0}, {});
}

-OutputVector params = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
-params.insert(params.end(), model_remaining_params.begin(), model_remaining_params.end());
+OutputVector pa_arguments = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
+pa_arguments.insert(pa_arguments.end(), model_remaining_params.begin(), model_remaining_params.end());
std::initializer_list<std::shared_ptr<Node>> additional_params = {scale,
sliding_window,
alibi_slopes,
max_context_len.get_node_shared_ptr()};
-params.insert(params.end(), additional_params.begin(), additional_params.end());
+pa_arguments.insert(pa_arguments.end(), additional_params.begin(), additional_params.end());

-// Really not sure if I construct correctly because the Python code uses an additional function
-auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(params);
+if (use_block_indices_inputs) {
+auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
+"block_indices." + std::to_string(layer_index - 1));
+pa_arguments.insert(pa_arguments.begin() + 7, block_indices);
+block_indices_inputs.push_back(block_indices);
+}

+auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);

auto pa_shape = std::make_shared<v0::Concat>(
OutputVector{
@@ -344,8 +353,13 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
std::make_shared<v0::Unsqueeze>(hidden_dim, v0::Constant::create(element::i64, Shape{}, {0})),
},
0);
-auto pa_reshape = std::make_shared<v1::Reshape>(paged_attention, pa_shape, true);
+auto pa_reshape = std::make_shared<v1::Reshape>(paged_attention->output(0), pa_shape, true);
auto pa_transpose = std::make_shared<v1::Transpose>(pa_reshape, kv_transpose_order);
if (use_score_outputs) {
auto score_result = std::make_shared<v0::Result>(paged_attention->output(1));
score_result->get_output_tensor(0).set_names({"scores." + std::to_string(layer_index - 1)});
score_results.push_back(score_result);
}

// TODO: Complete this part to work with stateless models as well as will stateful
// def add_kv_parameter(past_node):
5 changes: 5 additions & 0 deletions src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
@@ -19,7 +19,12 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass {
public:
OPENVINO_RTTI("SDPAToPagedAttention");

SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false);
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

private:
bool m_use_block_indices_inputs;
bool m_use_score_outputs;
};
} // namespace pass
} // namespace ov
1 change: 1 addition & 0 deletions src/core/src/op/paged_attention.cpp
@@ -147,6 +147,7 @@ void PagedAttentionExtension::validate_and_infer_types() {
".");

set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
set_output_type(1, get_input_element_type(0), {Dimension::dynamic()});
}

std::shared_ptr<ov::Node> PagedAttentionExtension::clone_with_new_inputs(const ov::OutputVector& new_args) const {
28 changes: 25 additions & 3 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -20,6 +20,10 @@

using namespace ov::op;

ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inputs, bool use_score_outputs)
: m_use_block_indices_inputs(use_block_indices_inputs),
m_use_score_outputs(use_score_outputs) {}

static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const char* name) {
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
// given single name)
@@ -37,12 +41,16 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
"the SDPAToPagedAttention transformation.");

auto max_context_len = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{}), "max_context_len");
-ParameterVector model_remaining_params = {
+ParameterVector model_remaining_params{
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "past_lens"),
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "subsequence_begins"),
-setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices"),
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices_begins"),
};
if (!m_use_block_indices_inputs) {
auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices");
model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices);
}

auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

std::shared_ptr<v0::Parameter> input_ids_node =
@@ -72,6 +80,8 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
ParameterVector kv_parameters;
ParameterVector parameters_to_remove;
ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model
ParameterVector block_indices_inputs;
ResultVector score_results;

std::shared_ptr<v0::Parameter> position_ids;
if (!has_parameter(model, "position_ids")) {
@@ -98,7 +108,11 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
sliding_window,
parameters_to_remove,
layer_index,
-max_context_len->output(0));
+max_context_len->output(0),
+block_indices_inputs,
+score_results,
+m_use_block_indices_inputs,
+m_use_score_outputs);
manager.register_pass<PrevSequenceLengthPattern>(prev_max_seq_len, batch_dim);
manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
manager.register_pass<PositionIDsReplacer>(unsqueezed_position_ids->output(0));
@@ -154,6 +168,14 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
model->remove_parameter(parameter);
}

if (m_use_block_indices_inputs) {
model->add_parameters(block_indices_inputs);
}

if (m_use_score_outputs) {
model->add_results(score_results);
}

model->add_parameters(kv_parameters);
model->add_parameters(model_remaining_params);
model->add_parameters({std::move(max_context_len)});
50 changes: 46 additions & 4 deletions tests/model_hub_tests/pytorch/test_pa_transformation.py
@@ -7,14 +7,20 @@
import models_hub_common.utils as utils
import pytest
import os
import re

-def run_pa(tmp_path, model_id, model_link):
+def run_pa(tmp_path, model_id, model_link, use_block_indices_inputs, use_score_outputs):
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

-paged_attention_transformation(model.model)
+paged_attention_transformation(model.model, use_block_indices_inputs, use_score_outputs)

# Test that a _PagedAttentionExtension node appeared after the transformation.
-assert any(isinstance(op, _PagedAttentionExtension) for op in model.model.get_ordered_ops()), f"The model '{model_id}' has no _PagedAttentionExtension present."
+pa_counter = 0
+for op in model.model.get_ordered_ops():
+if isinstance(op, _PagedAttentionExtension):
+pa_counter += 1

+assert pa_counter > 0, f"The model '{model_id}' has no _PagedAttentionExtension present."

model_inputs = model.model.inputs
for input in model_inputs:
@@ -26,6 +32,31 @@ def run_pa(tmp_path, model_id, model_link):
assert shape[-1].is_static, f"Dimension {len(shape) - 1} of input '{name}' in '{model_id}' is not static: {shape}"
assert shape[-2].is_static, f"Dimension {len(shape) - 2} of input '{name}' in '{model_id}' is not static: {shape}"

# Test for block_indices inputs and scores outputs to appear in the model
if (use_block_indices_inputs):
block_indices_pattern = r'block_indices\.[0-9]+'
block_indices_counter = 0

model_inputs = model.model.inputs
for input in model_inputs:
for name in list(input.get_names()):
if re.search(block_indices_pattern, name):
block_indices_counter += 1

assert(block_indices_counter == pa_counter)

if (use_score_outputs):
score_pattern = r'scores\.[0-9]+'
score_outputs_counter = 0

model_outputs = model.model.outputs
for output in model_outputs:
for name in list(output.get_names()):
if re.search(score_pattern, name):
score_outputs_counter += 1

assert(score_outputs_counter == pa_counter)

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device):
@@ -35,4 +66,15 @@ def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device)
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
-run_pa(tmp_path, model_name, model_link)
+run_pa(tmp_path, model_name, model_link, False, False)

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit")))
def test_pa_precommit_use_cache_eviction(tmp_path, model_name, model_link, mark, reason, ie_device):
assert mark is None or mark == 'skip' or mark == 'xfail', \
"Incorrect test case: {}, {}".format(model_name, model_link)
if mark == 'skip':
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
run_pa(tmp_path, model_name, model_link, True, True)
