Skip to content

Commit

Permalink
StaticLLMPipeline - align u4 zero points (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
TolyaTalamanov authored Jul 30, 2024
1 parent 06c57b7 commit e286469
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@

namespace {

void align_u4_zp_constants(const std::shared_ptr<ov::Model>& model) {
for (auto op : model->get_ops()) {
if (ov::op::util::is_constant(op)) {
auto cst_op = std::dynamic_pointer_cast<ov::op::v0::Constant>(op);
const auto cst_op_out = cst_op->output(0);
if (cst_op_out.get_element_type() == ov::element::u4 && ov::shape_size(cst_op_out.get_shape()) == 1u) {
ov::Tensor cst_tensor(ov::element::u4, cst_op_out.get_shape());
*static_cast<uint8_t*>(cst_tensor.data()) = cst_op->get_vector<uint8_t>()[0] & 0x0f;
auto new_cst_op = std::make_shared<ov::op::v0::Constant>(cst_tensor);
for (auto target_input : cst_op_out.get_target_inputs()) {
target_input.replace_source_output(new_cst_op);
}
}
}
}
}

std::shared_ptr<ov::Model> add_slices_to_kvcache_inputs(const std::shared_ptr<ov::Model>& model) {
const auto kvcache_name_pattern = "past_key_values";
std::vector<std::shared_ptr<ov::opset13::Parameter>> new_params;
Expand Down Expand Up @@ -147,6 +164,7 @@ StaticLLMPipeline::StaticLLMPipeline(
m_kvcache_model = core.read_model(path / "openvino_model.xml");
// (2) Expose KV-cache input and output layers from kvcache model
ov::pass::StatefulToStateless().run_on_model(m_kvcache_model);
align_u4_zp_constants(m_kvcache_model);
// (3) Clone the model - this will be prefill
m_prefill_model = m_kvcache_model->clone();
m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
Expand Down

0 comments on commit e286469

Please sign in to comment.