Commit

Copy-pasted Unroll SDPA optimization from GenAI into ov::npuw::LLMCompiledModel

AsyaPronina committed Dec 12, 2024
1 parent 5d2317d commit a4b0b81
Showing 3 changed files with 297 additions and 29 deletions.
276 changes: 251 additions & 25 deletions src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp
@@ -7,12 +7,54 @@
#include "logging.hpp"
#include "openvino/pass/stateful_to_stateless.hpp"
#include "openvino/runtime/iasync_infer_request.hpp"
#include "openvino/openvino.hpp"
#include "openvino/pass/validate.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/opsets/opset13.hpp"

namespace {
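// Rounds `value` up to the nearest multiple of `alignment` (assumes `alignment` is a power of two).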
uint32_t align_to(uint32_t value, uint32_t alignment) {
return (value + alignment - 1) & ~(alignment - 1);
}

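// Declares the past_key* inputs and present* outputs of the KV-cache as FP16 via PrePostProcessor
// (presumably to halve the cache footprint relative to FP32).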
std::shared_ptr<ov::Model> cvt_kvcache_to_fp16(const std::shared_ptr<ov::Model>& model) {
ov::preprocess::PrePostProcessor ppp(model);

for (auto tensor : model->inputs()) {
if (tensor.get_any_name().find("past_key") != std::string::npos) {
ppp.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
}
}

for (auto tensor : model->outputs()) {
if (tensor.get_any_name().find("present") != std::string::npos) {
ppp.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
}
}

return ppp.build();
}

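// Rebuilds single-element u4 zero-point constants so that only the low nibble is kept;
// the unused upper nibble could otherwise carry stale bits.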
void align_u4_zp_constants(const std::shared_ptr<ov::Model>& model) {
for (auto op : model->get_ops()) {
if (ov::op::util::is_constant(op)) {
auto cst_op = std::dynamic_pointer_cast<ov::op::v0::Constant>(op);
const auto cst_op_out = cst_op->output(0);
if (cst_op_out.get_element_type() == ov::element::u4 && ov::shape_size(cst_op_out.get_shape()) == 1u) {
ov::Tensor cst_tensor(ov::element::u4, cst_op_out.get_shape());
*static_cast<uint8_t*>(cst_tensor.data()) = cst_op->get_vector<uint8_t>()[0] & 0x0f;
auto new_cst_op = std::make_shared<ov::op::v0::Constant>(cst_tensor);
for (auto target_input : cst_op_out.get_target_inputs()) {
target_input.replace_source_output(new_cst_op);
}
}
}
}
}

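// Re-points the model outputs so the kvcache model emits only the key/values computed
// for the new token (see LOG step 8 in the constructor below).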
std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::Model>& model) {
const auto kStartOutputKVCacheLayers = 1u;
for (std::size_t i = kStartOutputKVCacheLayers; i < model->outputs().size(); ++i) {
@@ -27,24 +69,200 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
return model;
}

namespace opp = ov::pass::pattern;
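// Matcher pass: finds the Parameter -> Concat -> MatMul(Softmax, ...) chain that feeds the value (V)
// cache, stores the V-cache in a transposed layout, and switches the final MatMul to transpose_b=true.
// New/old Parameters are recorded in Context so the caller can patch the model's parameter list.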
class TransposeValueTensors : public ov::pass::MatcherPass {
public:
struct Context {
std::vector<std::shared_ptr<ov::opset13::Parameter>> new_params;
std::vector<std::shared_ptr<ov::opset13::Parameter>> old_params;
using Ref = std::reference_wrapper<Context>;
};

TransposeValueTensors(Context::Ref ctx) {
auto param = opp::wrap_type<ov::op::v0::Parameter>();
auto transpose = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
auto concat = opp::wrap_type<ov::op::v0::Concat>({param, transpose});
auto softmax = opp::wrap_type<ov::op::v8::Softmax>({opp::any_input()});
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({softmax, concat});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_node_param = node_to_output.at(param).get_node_shared_ptr();
auto matched_node_concat = node_to_output.at(concat).get_node_shared_ptr();
auto matched_node_transpose = node_to_output.at(transpose).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(matmul).get_node_shared_ptr();

auto matched_param = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_param);
auto matched_concat = std::static_pointer_cast<ov::op::v0::Concat>(matched_node_concat);
auto matched_transpose = std::static_pointer_cast<ov::op::v1::Transpose>(matched_node_transpose);
auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);

auto shape = matched_param->get_partial_shape();
OPENVINO_ASSERT(shape.size() == 4u);
// NB: Transpose the Parameter that corresponds to the V-tensor; this
// speeds up its multiplication with the attention scores
std::swap(shape[2], shape[3]);
auto new_param = std::make_shared<ov::opset13::Parameter>(matched_param->get_element_type(), shape);
new_param->set_friendly_name(matched_param->get_friendly_name());
new_param->outputs().begin()->get_tensor().set_names(matched_param->outputs().begin()->get_tensor().get_names());
ov::replace_node(matched_param, new_param);
// NB: Save these in order to add/remove them to/from the model later on
ctx.get().new_params.push_back(new_param);
ctx.get().old_params.push_back(matched_param);

auto order_cst = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 3, 1});
auto new_transpose = std::make_shared<ov::opset13::Transpose>(matched_transpose->input_value(0),
order_cst->output(0));
new_transpose->set_friendly_name(matched_transpose->get_friendly_name());
ov::replace_node(matched_transpose, new_transpose);

auto new_concat = std::make_shared<ov::opset13::Concat>(
ov::OutputVector{new_param->output(0), new_transpose->output(0)}, 3u);
new_concat->set_friendly_name(matched_concat->get_friendly_name());
ov::replace_node(matched_concat, new_concat);

matched_matmul->set_transpose_b(true);

return true;
};
register_matcher(std::make_shared<opp::Matcher>(matmul, "TransposeValueTensors"), std::move(callback));
}
};

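// Copy of the common SDPA decomposition pass: unrolls ScaledDotProductAttention into explicit
// Q*K^T, scaling, masking, Softmax and attention*V nodes so that TransposeValueTensors above
// can match and rewrite the resulting subgraph.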
class ScaledDotProductAttentionDecomposition : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ScaledDotProductAttentionDecomposition", "0");
ScaledDotProductAttentionDecomposition() {
auto pattern_node = ov::pass::pattern::wrap_type<ov::op::v13::ScaledDotProductAttention>();

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto node = ov::as_type_ptr<ov::op::v13::ScaledDotProductAttention>(
pattern_to_output.at(pattern_node).get_node_shared_ptr());

if (node == nullptr || transformation_callback(node)) {
return false;
}

auto new_output_node = decompose(node);
ov::replace_node(node, new_output_node);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, "ScaledDotProductAttentionDecomposition");
register_matcher(m, std::move(callback));
}
std::shared_ptr<ov::Node> decompose(std::shared_ptr<ov::op::v13::ScaledDotProductAttention> node) {
using namespace ov::op;
using namespace ov;
auto query = node->input_value(0);
auto key = node->input_value(1);
auto value = node->input_value(2);
auto q_shape = register_new_node<v3::ShapeOf>(query, element::i32);
auto k_shape = register_new_node<v3::ShapeOf>(key, element::i32);
auto minus_one = register_new_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minus_two = register_new_node(v0::Constant::create(element::i32, Shape{}, {-2}));
auto zero_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one_i = register_new_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto one_f = register_new_node<v1::ConvertLike>(one_i, query);
auto zero_f = register_new_node<v1::ConvertLike>(zero_i, query);

Output<Node> scale;
if (node->get_input_size() < 5) {
scale = register_new_node<v8::Gather>(q_shape, minus_one, zero_i)->output(0);
scale = register_new_node<v1::ConvertLike>(scale, query);
auto sqrt_scale = register_new_node<v0::Sqrt>(scale);
scale = register_new_node<v1::Divide>(one_f, sqrt_scale);
} else {
scale = node->input_value(4);
}

auto q_scaled = register_new_node<v1::Multiply>(query, scale);
auto k_rank = register_new_node<v3::ShapeOf>(k_shape, element::i32)->output(0);
auto k_last_dim = register_new_node<v1::Add>(k_rank, minus_one);
auto k_next_dim = register_new_node<v1::Add>(k_rank, minus_two)->output(0);
k_rank = register_new_node<v0::Squeeze>(k_rank, zero_i);
auto minus_inf =
register_new_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}))
->output(0);
auto keep_dim_last = register_new_node<v0::Squeeze>(k_next_dim, zero_i);
auto k_dims_before_transpose = register_new_node<v4::Range>(zero_i, keep_dim_last, one_i, element::i32);

auto scaled_atten = register_new_node<v0::MatMul>(q_scaled, key, false, true)->output(0);
minus_inf = register_new_node<v1::ConvertLike>(minus_inf, scaled_atten);

if (node->get_causal() || node->get_input_size() > 3) {
Output<Node> mask;
Output<Node> atten_mask;
if (!node->get_causal()) {
mask = node->input_value(3);

// Two types of masks are supported: a boolean mask, where a value of True indicates that the
// element should take part in attention, and a float mask of the same type as query/key/value,
// which is added to the attention score.
if (mask.get_element_type() == element::boolean) {
atten_mask = register_new_node<v1::ConvertLike>(mask, scaled_atten);
auto inv_mask = register_new_node<v1::LogicalNot>(mask);
atten_mask = register_new_node<v1::Select>(inv_mask, atten_mask, minus_inf);
} else {
atten_mask = mask;
}
} else {
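// Causal branch: horizontal_range is [0..S_k) laid out as a row, vertical_range is [1..S_q]
// as a column; GreaterEqual marks positions where the key index exceeds the query index,
// which Select fills with -inf (future tokens) and with 0 elsewhere.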
auto target_s_len = register_new_node<v8::Gather>(q_shape, minus_two, zero_i);
auto source_s_len = register_new_node<v8::Gather>(k_shape, minus_two, zero_i);
auto ssl = register_new_node<v0::Unsqueeze>(source_s_len, zero_i);
auto tsl = register_new_node<v0::Unsqueeze>(target_s_len, zero_i);
auto mask_shape = register_new_node<v0::Concat>(OutputVector{tsl, ssl}, 0);
mask = register_new_node<v1::Broadcast>(minus_inf, mask_shape);
auto horizontal_range = register_new_node<v4::Range>(zero_i, source_s_len, one_i, element::i32)->output(0);
horizontal_range = register_new_node<v0::Unsqueeze>(horizontal_range, zero_i);
auto stop = register_new_node<v1::Add>(target_s_len, one_i);
auto vertical_range = register_new_node<v4::Range>(one_i, stop, one_i, element::i32)->output(0);
vertical_range = register_new_node<v0::Unsqueeze>(vertical_range, one_i);
auto triu = register_new_node<v1::GreaterEqual>(horizontal_range, vertical_range);
atten_mask = register_new_node<v1::Select>(triu, mask, zero_f);
}
scaled_atten = register_new_node<v1::Add>(scaled_atten, atten_mask);
}

scaled_atten = register_new_node<v8::Softmax>(scaled_atten, -1);
auto result = register_new_node<v0::MatMul>(scaled_atten, value);
result->set_friendly_name(node->get_friendly_name());
copy_runtime_info(node, get_new_nodes());
return result;
}
};

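// For the prefill model: expose the present*value outputs in the transposed [B, H, E, S] layout so
// they can be fed directly into a kvcache model rewritten by TransposeValueTensors.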
std::shared_ptr<ov::Model> cvt_value_tensors_layout(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor ppp(model);
for (auto tensor : model->outputs()) {
if (tensor.get_any_name().find("present") != std::string::npos) {
ppp.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
if (tensor.get_any_name().find("value") != std::string::npos) {
// NB: [batch, num_heads, seq_len, emb_size] -> [batch, num_heads, emb_size, seq_len]
ppp.output(tensor.get_any_name()).model().set_layout(ov::Layout("BHSE"));
ppp.output(tensor.get_any_name()).tensor().set_layout(ov::Layout("BHES"));
}
}

return ppp.build();
}

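// Runs the SDPA decomposition followed by the V-tensor transpose rewrite; returns true
// iff at least one V-tensor Parameter was replaced (i.e. the optimization took effect).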
bool optimize_value_tensors(std::shared_ptr<ov::Model> model) {
ov::pass::GraphRewrite rewr;
rewr.add_matcher<ScaledDotProductAttentionDecomposition>();
TransposeValueTensors::Context ctx;
rewr.add_matcher<TransposeValueTensors>(std::ref(ctx));
rewr.run_on_model(model);

model->add_parameters(ctx.new_params);
for (auto old_param : ctx.old_params) {
model->remove_parameter(old_param);
}
ov::pass::Validate().run_on_model(model);

// NB: if new_params is not empty, the pass has been applied
return !ctx.new_params.empty();
}

struct KVAxesPosition {
uint32_t batch;
uint32_t seq_len;
@@ -251,41 +469,49 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
auto kvcache_model = model->clone();
LOG_DEBUG("2. Transform kvcache model from stateful to stateless.");
ov::pass::StatefulToStateless().run_on_model(kvcache_model);

LOG_DEBUG("3. Creating prefill model as clone of transformed kvcache one.");
LOG_DEBUG("3. Align u4 ZP constants.");
align_u4_zp_constants(kvcache_model);
LOG_DEBUG("4. Creating prefill model as clone of transformed kvcache one.");
auto prefill_model = kvcache_model->clone();
prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
LOG_DEBUG("4. Converting KV-cache in prefill model to FP16.");
prefill_model = cvt_kvcache_to_fp16(prefill_model);

LOG_DEBUG("5. Optimize kvcache kvcache model to output key/values for new token.");
kvcache_model = redirect_new_kv_to_output(kvcache_model);
LOG_DEBUG("6. Converting KV-cache in kvcache model to FP16.");
kvcache_model = cvt_kvcache_to_fp16(kvcache_model);

const uint32_t kMaxPromptLen = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u);
const uint32_t kMinResponseLen = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
const ::intel_npu::npuw::llm::ModelDesc model_desc = m_cfg.get<::intel_npu::NPUW_LLM_MODEL_DESC>();
KVAxesPosition axes = get_kv_axes(model_desc.type);
m_kvcache_desc = KVCacheDesc{kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len};
LOG_DEBUG("7. Make prefill model with static shapes");
LOG_DEBUG("5. Make prefill model with static shapes");
reshape_to_static(prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes);
LOG_DEBUG("8. Make kvcache model with static shapes");
LOG_DEBUG("6. Make kvcache model with static shapes");
reshape_to_static(kvcache_model, 1u, m_kvcache_desc.total_size, axes);
LOG_DEBUG("7.Check and apply opt layout if applicable.");
// NB: Try to apply opt transpose only for Llama-2-7b-chat-hf model
if ( model_desc.name_or_path == "meta-llama/Llama-2-7b-chat-hf" ||
(model_desc.type == "llama" && model_desc.num_key_value_heads == 32)) {
if (optimize_value_tensors(kvcache_model)) {
// NB: Check if TransposeValueTensors transformation was applied
m_kvcache_desc.v_tensors_transposed = true;
prefill_model = cvt_value_tensors_layout(prefill_model);
}
}
LOG_DEBUG("8. Optimize kvcache model to output key/values for new token.");
kvcache_model = redirect_new_kv_to_output(kvcache_model);
LOG_DEBUG("9. Converting KV-cache in kvcache model to FP16.");
kvcache_model = cvt_kvcache_to_fp16(kvcache_model);
LOG_DEBUG("10. Converting KV-cache in prefill model to FP16.");
prefill_model = cvt_kvcache_to_fp16(prefill_model);

auto npudesc = extract_npu_descriptor(plugin);

ov::AnyMap properties_copy = other_props;
auto prefill_config = get_default_prefill_config(model, npudesc);

// NB: GENERATE_HINT is only applicable for default generate config!
const ::intel_npu::npuw::llm::GenerateHint generate_hint = m_cfg.get<::intel_npu::NPUW_LLM_GENERATE_HINT>();
LOG_DEBUG("9. Passed GENERATE_HINT: " << std::string(::intel_npu::NPUW_LLM_GENERATE_HINT::toString(generate_hint)));
LOG_DEBUG("11. Passed GENERATE_HINT: " << std::string(::intel_npu::NPUW_LLM_GENERATE_HINT::toString(generate_hint)));
auto generate_config = get_default_generate_config(model, npudesc, generate_hint);

merge_config_with(prefill_config, properties_copy);
merge_config_with(generate_config, properties_copy);
// FIXME: Drop CACHE_DIR option if NPUW is enabled
drop_cache_dir(prefill_config);
drop_cache_dir(generate_config);

m_kvcache_compiled = std::make_shared<ov::npuw::CompiledModel>(kvcache_model, plugin, generate_config);
m_prefill_compiled = std::make_shared<ov::npuw::CompiledModel>(prefill_model, plugin, prefill_config);
@@ -22,6 +22,7 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel {
uint32_t total_size = 0u;
uint32_t num_stored_tokens = 0u;
uint32_t dim = 0u;
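// NB: true when the value tensors are kept in the transposed [B, H, E, S] layout
// (see TransposeValueTensors / cvt_value_tensors_layout in llm_compiled_model.cpp).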
bool v_tensors_transposed = false;
};

LLMCompiledModel(const std::shared_ptr<ov::Model>& model,