Skip to content

Commit

Permalink
fix f16 path
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Mar 21, 2024
1 parent c791693 commit 5117108
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,10 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
auto convert = std::make_shared<opset4::Convert>(param, to);
for (auto& input : param_consumers) {
const auto consumer = input.get_node();
if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer) ||
// TODO: refactor after ngraph op defined
// The fourth and fifth inputs are kvcache and should be directly connected to parameters
(consumer->get_type_name() == std::string("PagedAttentionExtension") && (input.get_index() == 3 || input.get_index() == 4))) {
continue;
}
input.replace_source_output(convert);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
if (block_idx < 0) return;
attn_copy(past_k_output.ptr<T2>(block_idx, h, m, 0),
k_input.ptr<T>(b, h, 0),
attn_copy(past_k_output.ptr<T2>(block_idx, h, 0),
k_input.ptr<T>(b, h, m, 0),
S);
attn_copy(past_v_output.ptr<T2>(block_idx, h, m, 0),
v_input.ptr<T>(b, h, 0),
attn_copy(past_v_output.ptr<T2>(block_idx, h, 0),
v_input.ptr<T>(b, h, m, 0),
S);
});
}
Expand Down

0 comments on commit 5117108

Please sign in to comment.