Skip to content

Commit

Permalink
Fixed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AsyaPronina committed Dec 23, 2024
1 parent a4b0b81 commit b06e640
Showing 1 changed file with 8 additions and 27 deletions.
35 changes: 8 additions & 27 deletions src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,6 @@ std::shared_ptr<ov::Model> cvt_kvcache_to_fp16(const std::shared_ptr<ov::Model>&
return ppp.build();
}

void align_u4_zp_constants(const std::shared_ptr<ov::Model>& model) {
for (auto op : model->get_ops()) {
if (ov::op::util::is_constant(op)) {
auto cst_op = std::dynamic_pointer_cast<ov::op::v0::Constant>(op);
const auto cst_op_out = cst_op->output(0);
if (cst_op_out.get_element_type() == ov::element::u4 && ov::shape_size(cst_op_out.get_shape()) == 1u) {
ov::Tensor cst_tensor(ov::element::u4, cst_op_out.get_shape());
*static_cast<uint8_t*>(cst_tensor.data()) = cst_op->get_vector<uint8_t>()[0] & 0x0f;
auto new_cst_op = std::make_shared<ov::op::v0::Constant>(cst_tensor);
for (auto target_input : cst_op_out.get_target_inputs()) {
target_input.replace_source_output(new_cst_op);
}
}
}
}
}

std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::Model>& model) {
const auto kStartOutputKVCacheLayers = 1u;
for (std::size_t i = kStartOutputKVCacheLayers; i < model->outputs().size(); ++i) {
Expand Down Expand Up @@ -469,9 +452,7 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
auto kvcache_model = model->clone();
LOG_DEBUG("2. Transform kvcache model from stateful to stateless.");
ov::pass::StatefulToStateless().run_on_model(kvcache_model);
LOG_DEBUG("3. Align u4 ZP constants.");
align_u4_zp_constants(kvcache_model);
LOG_DEBUG("4. Creating prefill model as clone of transformed kvcache one.");
LOG_DEBUG("3. Creating prefill model as clone of transformed kvcache one.");
auto prefill_model = kvcache_model->clone();
prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");

Expand All @@ -480,11 +461,11 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
const uint32_t kMinResponseLen = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
KVAxesPosition axes = get_kv_axes(model_desc.type);
m_kvcache_desc = KVCacheDesc{kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len};
LOG_DEBUG("5. Make prefill model with static shapes");
LOG_DEBUG("4. Make prefill model with static shapes");
reshape_to_static(prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes);
LOG_DEBUG("6. Make kvcache model with static shapes");
LOG_DEBUG("5. Make kvcache model with static shapes");
reshape_to_static(kvcache_model, 1u, m_kvcache_desc.total_size, axes);
LOG_DEBUG("7.Check and apply opt layout if applicable.");
LOG_DEBUG("6.Check and apply opt layout if applicable.");
// NB: Try to apply opt transpose only for Llama-2-7b-chat-hf model
if ( model_desc.name_or_path == "meta-llama/Llama-2-7b-chat-hf" ||
(model_desc.type == "llama" && model_desc.num_key_value_heads == 32)) {
Expand All @@ -494,11 +475,11 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
prefill_model = cvt_value_tensors_layout(prefill_model);
}
}
LOG_DEBUG("8. Optimize kvcache model to output key/values for new token.");
LOG_DEBUG("7. Optimize kvcache model to output key/values for new token.");
kvcache_model = redirect_new_kv_to_output(kvcache_model);
LOG_DEBUG("9. Converting KV-cache in kvcache model to FP16.");
LOG_DEBUG("8. Converting KV-cache in kvcache model to FP16.");
kvcache_model = cvt_kvcache_to_fp16(kvcache_model);
LOG_DEBUG("10. Converting KV-cache in prefill model to FP16.");
LOG_DEBUG("9. Converting KV-cache in prefill model to FP16.");
prefill_model = cvt_kvcache_to_fp16(prefill_model);

auto npudesc = extract_npu_descriptor(plugin);
Expand All @@ -507,7 +488,7 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m

// NB: GENERATE_HINT is only applicable for default generate config!
const ::intel_npu::npuw::llm::GenerateHint generate_hint = m_cfg.get<::intel_npu::NPUW_LLM_GENERATE_HINT>();
LOG_DEBUG("11. Passed GENERATE_HINT: " << std::string(::intel_npu::NPUW_LLM_GENERATE_HINT::toString(generate_hint)));
LOG_DEBUG("10. Passed GENERATE_HINT: " << std::string(::intel_npu::NPUW_LLM_GENERATE_HINT::toString(generate_hint)));
auto generate_config = get_default_generate_config(model, npudesc, generate_hint);

merge_config_with(prefill_config, properties_copy);
Expand Down

0 comments on commit b06e640

Please sign in to comment.