[GPU] Fixed stateful KV cache issues (openvinotoolkit#21618)
vladimir-paramuzov authored Dec 13, 2023
1 parent 3e2037c commit 2bcc940
Showing 9 changed files with 163 additions and 57 deletions.
@@ -4,6 +4,7 @@

#pragma once

#include "openvino/core/partial_shape.hpp"
#include "openvino/op/broadcast.hpp"

#include "primitive.hpp"
@@ -131,6 +132,8 @@ struct broadcast : public primitive_base<broadcast> {
/// along which broadcast should happen.
std::vector<uint16_t> broadcast_axes;

ov::PartialShape output_pshape = ov::PartialShape::dynamic();

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_range(seed, broadcast_axes.begin(), broadcast_axes.end());
@@ -146,7 +149,8 @@ struct broadcast : public primitive_base<broadcast> {

return axes_mapping == rhs_casted.axes_mapping &&
broadcast_mode == rhs_casted.broadcast_mode &&
broadcast_sizes == rhs_casted.broadcast_sizes;
broadcast_sizes == rhs_casted.broadcast_sizes &&
output_pshape == rhs_casted.output_pshape;
}

void save(BinaryOutputBuffer& ob) const override {
@@ -156,6 +160,7 @@ struct broadcast : public primitive_base<broadcast> {
ob << make_data(&broadcast_mode, sizeof(ov::op::BroadcastModeSpec));
ob << broadcast_sizes;
ob << broadcast_axes;
ob << output_pshape;
}

void load(BinaryInputBuffer& ib) override {
@@ -165,6 +170,7 @@ struct broadcast : public primitive_base<broadcast> {
ib >> make_data(&broadcast_mode, sizeof(ov::op::BroadcastModeSpec));
ib >> broadcast_sizes;
ib >> broadcast_axes;
ib >> output_pshape;
}
};
} // namespace cldnn
src/plugins/intel_gpu/src/graph/broadcast.cpp (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ std::vector<layout> broadcast_inst::calc_output_layouts(broadcast_node const& /*
if (input1.is_static()) {
output_rank = input1.get_dim(0); // target shape rank is set as second input.
}
output_shapes[0] = ShapeType::dynamic(std::max(static_cast<int>(output_rank), 1));
output_shapes[0] = desc->output_pshape.rank().is_static() ? desc->output_pshape : ShapeType::dynamic(std::max(static_cast<int>(output_rank), 1));
}

format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());
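The new output_pshape field matters because the broadcast rank can no longer always be deduced from the inputs: when the target shape is produced at runtime (as in the state initializer built by the tests further down), the second input is fully dynamic, and calc_output_layouts falls back to the partial shape recorded from the original ov::op::v3::Broadcast. A minimal standalone sketch of that situation, not part of the commit (the helper function and values are illustrative):

```cpp
#include <memory>

#include "openvino/core/partial_shape.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

// Build a Broadcast whose target shape is only known at runtime.
void rank_fallback_sketch() {
    auto fill_value   = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {0.0f});
    auto target_shape = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{4});
    auto bcast        = std::make_shared<ov::op::v3::Broadcast>(fill_value, target_shape);

    // The values of target_shape are unknown, but its element count (4) is,
    // so the output partial shape is [?,?,?,?]: dynamic dims with a static rank.
    // That rank is what the plugin now records in cldnn::broadcast::output_pshape.
    auto pshape = bcast->get_output_partial_shape(0);  // pshape.rank().is_static() == true
}
```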
@@ -3,6 +3,7 @@
//

#include "shape_of_inst.h"
#include "read_value_inst.h"
#include "reshape_inst.h"
#include "eltwise_inst.h"
#include "pass_manager.h"
@@ -43,6 +44,10 @@ bool mark_shape_of_subgraphs::can_mark_node(const program_node& node) {
if (node.has_fused_primitives())
return false;

// read_value may have an initializer which is a shape_of sub-graph, but read_value itself is not a part of such a sub-graph
if (node.is_type<read_value>())
return false;

if (node.is_type<reshape>())
return true;

src/plugins/intel_gpu/src/graph/include/read_value_inst.h (7 changes: 2 additions & 5 deletions)
@@ -34,11 +34,8 @@ class typed_primitive_inst<read_value> : public typed_primitive_inst_base<read_v
static std::vector<layout> calc_output_layouts(read_value_node const& /*node*/, const kernel_impl_params& impl_param) {
auto desc = impl_param.typed_desc<read_value>();
const auto default_layout = desc->output_layout;
auto out_layout = impl_param.state_layout.value_or(default_layout);
if (out_layout.is_dynamic() && desc->input_size() > 0) {
out_layout = impl_param.get_input_layout(0);
}
return { out_layout };

return { impl_param.state_layout.value_or(default_layout) };
}

static layout calc_output_layout(const read_value_node& node, kernel_impl_params const& impl_param);
src/plugins/intel_gpu/src/graph/primitive_inst.cpp (44 changes: 26 additions & 18 deletions)
@@ -265,8 +265,30 @@ void primitive_inst::update_shape() {
}

if (get_node().is_type<read_value>()) {
const auto& variable_id = get_node().as<read_value>().get_primitive()->variable_id;
auto new_layout = get_network().get_variable(variable_id).get_layout();
auto prim = get_node().as<read_value>().get_primitive();
const auto& variable_id = prim->variable_id;
auto& variable = get_network().get_variable(variable_id);
// Initial variable shape is taken from variable itself
auto new_layout = variable.get_layout();

// If the variable is not set and we have an initializer, use its shape as the variable's shape
if (!variable.is_set() && _impl_params->input_layouts.size() == 1) {
new_layout = _impl_params->get_input_layout(0);
}

// If we still have a dynamic dimension, which basically means that we don't have an initializer, then replace dynamic dims with 0
if (new_layout.is_dynamic()) {
auto pshape = new_layout.get_partial_shape();
for (auto& d : pshape) {
if (d.is_dynamic()) {
d = 0;
}
}
new_layout.set_partial_shape(pshape);
}

variable.set_layout(new_layout);

if (!_impl_params->state_layout.has_value() || _impl_params->state_layout.value() != new_layout) {
_impl_params->state_layout = new_layout;
input_shape_changed = true;
@@ -299,7 +321,7 @@ void primitive_inst::update_shape() {
}
}
if (!subgraph_input_changed) {
GPU_DEBUG_TRACE_DETAIL << id() << ": skip shape_update, because it is in shape_of_subgrap and input shape is not changed\n";
GPU_DEBUG_TRACE_DETAIL << id() << ": skip shape_update, because it is in shape_of_subgraph and input shape is not changed\n";
reset_shape_change();
return;
}
@@ -402,20 +424,6 @@ void primitive_inst::update_shape() {
get_network().get_variable(desc->variable_id).set_layout(_impl_params->get_output_layout());
_impl_params->state_layout = _impl_params->get_output_layout();
}

if (get_node().is_type<read_value>()) {
auto desc = get_node().as<read_value>().get_primitive();
if (_impl_params->output_layouts[0].is_dynamic()) {
auto pshape = _impl_params->output_layouts[0].get_partial_shape();
for (auto& d : pshape) {
if (d.is_dynamic()) {
d = 0;
}
}
_impl_params->output_layouts[0].set_partial_shape(pshape);
}
get_network().get_variable(desc->variable_id).set_layout(_impl_params->get_output_layout());
}
}

event::ptr primitive_inst::realloc_if_needed() {
@@ -448,7 +456,7 @@ event::ptr primitive_inst::realloc_if_needed() {
// read_value/assign nodes are supposed to always use variable memory
if (auto stateful_prim = dynamic_cast<memory_state::variable*>(this)) {
std::string variable_id = stateful_prim->variable_id();
auto variable = get_network().get_variable(variable_id);
auto& variable = get_network().get_variable(variable_id);
variable.set_layout(actual_layout);
GPU_DEBUG_TRACE_DETAIL << id() << ": use variable memory " << variable.get_memory()
<< " (size=" << variable.get_memory()->size() << ")" << std::endl;
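Taken together, update_shape() now seeds an unset variable's layout in three steps: start from the variable's current layout, prefer the initializer's input layout when one is attached, and finally clamp any remaining dynamic dimensions to 0 so an empty state buffer can be allocated before the first token is written. The hunk in realloc_if_needed() also switches `auto variable` to `auto& variable`, so set_layout() updates the network's stored state rather than a discarded copy. Below is a minimal sketch of the dimension-clamping step using the public ov::PartialShape API; it is a standalone illustration with made-up shapes, not plugin code.

```cpp
#include "openvino/core/partial_shape.hpp"

// Replace every dynamic dimension with 0, e.g. [?,32,?,128] -> [0,32,0,128],
// which describes an empty (zero-element) KV-cache buffer.
ov::PartialShape zero_out_dynamic_dims(ov::PartialShape pshape) {
    for (auto& d : pshape) {
        if (d.is_dynamic())
            d = 0;
    }
    return pshape;
}

// Hypothetical state shape [batch, n_heads, seq_len, head_size] before the first inference:
// auto empty = zero_out_dynamic_dims({ov::Dimension::dynamic(), 32, ov::Dimension::dynamic(), 128});
```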
src/plugins/intel_gpu/src/plugin/ops/broadcast.cpp (2 changes: 2 additions & 0 deletions)
@@ -90,6 +90,8 @@ static void CreateCommonBroadcastOp(ProgramBuilder& p, const std::shared_ptr<ov:
mode);
}

broadcast_prim->output_pshape = op->get_output_partial_shape(0);

p.add_primitive(*op, broadcast_prim);
}

src/plugins/intel_gpu/src/plugin/variable_state.cpp (2 changes: 2 additions & 0 deletions)
@@ -9,6 +9,7 @@
#include "intel_gpu/plugin/variable_state.hpp"
#include "intel_gpu/runtime/memory_caps.hpp"
#include "intel_gpu/runtime/layout.hpp"
#include "intel_gpu/runtime/debug_configuration.hpp"

#include <memory>

@@ -45,6 +46,7 @@

void VariableState::set_layout(const cldnn::layout& new_layout) {
m_layout = new_layout;
GPU_DEBUG_TRACE_DETAIL << "Update state layout to " << new_layout.to_short_string() << std::endl;
update_device_buffer();
}

src/plugins/intel_gpu/tests/common/subgraphs_builders.hpp (52 changes: 49 additions & 3 deletions)
@@ -7,13 +7,20 @@
#include <memory>
#include "openvino/core/dimension.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/node_vector.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/util/read_value_base.hpp"
#include "openvino/pass/make_stateful.hpp"

namespace tests {
@@ -22,7 +29,9 @@ inline std::shared_ptr<ov::Model> make_llm_kv_cache_pattern(ov::Dimension batch
ov::Dimension n_heads = ov::Dimension::dynamic(),
ov::Dimension n_features = ov::Dimension::dynamic(),
ov::element::Type_t element_type = ov::element::f32,
bool stateful = false) {
bool stateful = false,
bool fuse_cache_reorder = false,
bool build_state_initializer = false) {
ov::PartialShape kv_cache_size = {batch, n_heads, -1, n_features};
ov::PartialShape new_token_size = {batch, -1, n_heads, n_features};
ov::PartialShape matmul_in_size = {batch, n_heads, -1, -1};
@@ -34,23 +43,60 @@ inline std::shared_ptr<ov::Model> make_llm_kv_cache_pattern(ov::Dimension batch
auto in_matmul = std::make_shared<ov::op::v0::Parameter>(element_type, matmul_in_size);
in_matmul->set_friendly_name("in_matmul");

ov::ParameterVector params{in_kv_prev, in_new_token, in_matmul};
std::shared_ptr<ov::Node> concat_input = in_kv_prev;
if (fuse_cache_reorder) {
auto in_beam_idx = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{batch});
in_beam_idx->set_friendly_name("beam_idx");
params.push_back(in_beam_idx);
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto gather = std::make_shared<ov::op::v8::Gather>(in_kv_prev, in_beam_idx, axis, 0);
concat_input = gather;
}

std::shared_ptr<ov::Node> state_initializer = nullptr;
if (stateful && build_state_initializer) {
auto shapeof = std::make_shared<ov::op::v3::ShapeOf>(in_new_token, ov::element::i32);

auto indices = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, 0);
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto gather = std::make_shared<ov::op::v8::Gather>(shapeof, indices, axis, 0);

auto bcast_value = std::make_shared<ov::op::v0::Constant>(element_type, ov::Shape{}, 0.0f);
ov::NodeVector dims = {gather};
for (size_t i = 1; i < kv_cache_size.size(); i++) {
dims.push_back(std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(kv_cache_size[i].get_min_length())));
}
auto shape = std::make_shared<ov::op::v0::Concat>(dims, 0);
state_initializer = std::make_shared<ov::op::v3::Broadcast>(bcast_value, shape);
}

auto transpose_const = ov::op::v0::Constant::create(ov::element::i32, {new_token_size.size()}, {0, 2, 1, 3});
auto transpose = std::make_shared<ov::op::v1::Transpose>(in_new_token, transpose_const);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{in_kv_prev, transpose}, 2);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{concat_input, transpose}, 2);
auto convert = std::make_shared<ov::op::v0::Convert>(concat, element_type);
auto kv_present = std::make_shared<ov::op::v0::Result>(convert);
kv_present->set_friendly_name("present_key_values");
auto matmul = std::make_shared<ov::op::v0::MatMul>(in_matmul, concat, false, false);
auto matmul_out = std::make_shared<ov::op::v0::Result>(matmul);
matmul_out->set_friendly_name("matmul_out");

ov::ParameterVector params{in_kv_prev, in_new_token, in_matmul};
ov::ResultVector results{kv_present, matmul_out};
auto model = std::make_shared<ov::Model>(results, params, "LLM-KV-Cache");
if (stateful) {
ov::pass::MakeStateful({{in_kv_prev, kv_present}}).run_on_model(model);
}

if (state_initializer) {
for (auto op : model->get_ops()) {
if (auto read_value = std::dynamic_pointer_cast<ov::op::v6::ReadValue>(op)) {
read_value->set_arguments(ov::OutputVector{state_initializer});
break;
}
}
}
model->validate_nodes_and_infer_types();

return model;
}

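For reference, a hypothetical call site for the extended test builder; only the signature comes from this diff, while the concrete dimensions, precision, and surrounding test plumbing are assumptions:

```cpp
// Stateful KV-cache pattern with a beam-reorder Gather and a ShapeOf-based state initializer.
auto model = tests::make_llm_kv_cache_pattern(ov::Dimension::dynamic(),      // batch
                                              32,                            // n_heads
                                              128,                           // n_features
                                              ov::element::Type_t::f16,      // element_type
                                              /*stateful=*/true,
                                              /*fuse_cache_reorder=*/true,
                                              /*build_state_initializer=*/true);
```

With all three flags enabled, the resulting model contains a ReadValue whose initializer is the Broadcast-based sub-graph above, which is exactly the case exercised by the can_mark_node and read_value calc_output_layouts changes earlier in this commit.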