Skip to content

Commit

Permalink
[GPU] Fix Crop->Reshape (Squeeze/Unsqueeze modes) buffer optimization (
Browse files Browse the repository at this point in the history
…#25836)

These changes fix a significant accuracy issue (reducing perplexity from
120,000 to 17) for Llama models with precalculated constant sin/cos
values. However, there is still a problem with sin/cos representation in
FP16 precision, which will be addressed in a separate PR.

### Details:
 - Fixed Crop->Reshape (Squeeze/Unsqueeze modes) buffer optimization
 - Updated the rope_ref kernel to support dynamic paddings for cos/sin inputs
 - Fixed the propagate_padding() function and updated the shape-infer tests

### Tickets:
- [CVS-148220](https://jira.devtools.intel.com/browse/CVS-148220),
[CVS-146283](https://jira.devtools.intel.com/browse/CVS-146283)
  • Loading branch information
sshlyapn authored Aug 2, 2024
1 parent b2319a5 commit a33afe4
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,12 @@ bool crop_in_place_optimization::match(const program_node& node,
if (node.get_program().is_body_program() && node.get_dependency(0).is_type<lstm_elt>()) {
return false;
}

GPU_DEBUG_GET_INSTANCE(debug_config);
GPU_DEBUG_IF(debug_config->disable_runtime_buffer_fusing && node.is_dynamic()) {
return false;
}

// optimization is available for cropping across depth(features) or batch
// if output padding has defined padding across features already it wouldn't
// work because it expects to have zeros in the padded area.
Expand Down Expand Up @@ -553,18 +559,22 @@ bool crop_in_place_optimization::optimize(crop_node& node) {
node.get_primitive()->axis,
false);
} else if (can_crop_be_optimized_simple_data_format(crop_layout, input_layout)) {
std::vector<layout> reshape_layouts;
if (node.get_users().front()->is_type<reshape>() && node.get_users().front()->as<reshape>().is_runtime_propagatable_padding()) {
reshape_layouts.push_back(node.get_users().front()->get_output_layout());
std::pair<const program_node*, layout> user_info;
if (node.get_users().front()->is_type<reshape>()) {
auto& reshape_node = node.get_users().front()->as<reshape>();
if (reshape_node.is_runtime_propagatable_padding()) {
user_info.first = &reshape_node;
user_info.second = reshape_node.get_output_layout();
}
}
update_in_place_crop_padding_simple_data_format(crop_layout,
input_layout,
reshape_layouts,
user_info,
crop_params->input_offsets[0],
node.get_primitive()->axis,
false);
if (reshape_layouts.size() > 0) {
node.get_users().front()->set_output_layout(reshape_layouts[0]);
if (user_info.first) {
node.get_users().front()->set_output_layout(user_info.second);
}
}
node.set_output_layout(crop_layout);
Expand Down Expand Up @@ -632,24 +642,51 @@ void crop_in_place_optimization::update_in_place_crop_padding_along_feature(cons

void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(layout& crop_layout,
layout& input_layout,
std::vector<layout>& user_layouts,
std::pair<const program_node*, layout>& user_info,
const tensor offsets,
size_t crop_axis,
bool is_runtime) {
auto crop_axis_legacy = crop_axis;
if (crop_axis_legacy >= 2) {
auto spatial_axis = crop_axis_legacy - 2;
// Default and minimum number of dimensions is 4
auto spatial_size = std::max<size_t>(crop_layout.get_partial_shape().size(), 4) - 2;
crop_axis_legacy = spatial_size - spatial_axis - 1 + 2;
}
auto convert_axis_to_legacy = [](size_t axis, size_t rank) {
auto axis_legacy = axis;
if (axis_legacy >= 2) {
auto spatial_axis = axis_legacy - 2;
// Default and minimum number of dimensions is 4
auto spatial_size = std::max<size_t>(rank, 4) - 2;
axis_legacy = spatial_size - spatial_axis - 1 + 2;
}

return axis_legacy;
};

auto crop_axis_legacy = convert_axis_to_legacy(crop_axis, crop_layout.get_partial_shape().size());

// If it's build-time and node is dynamic, only dynamic padding is set first
if ((crop_layout.is_dynamic() || input_layout.is_dynamic()) && !is_runtime) {
auto dyn_pad_sizes = tensor(0).sizes();
dyn_pad_sizes[crop_axis_legacy] = 1;
crop_layout.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));
for (auto& user_layout : user_layouts) {
user_layout.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));

if (user_info.first && user_info.first->is_type<reshape>()) {
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
if (reshape_mode == reshape::reshape_mode::base) {
user_info.second.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));
} else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;

auto reshape_axis = crop_axis;
for (size_t i = 0; i < output_pattern.size(); i++) {
if (output_pattern[i] <= static_cast<int64_t>(reshape_axis)) {
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
}
}

auto dyn_pad_mask = tensor(0).sizes();
auto reshape_axis_legacy = convert_axis_to_legacy(reshape_axis, reshape_ps.size());
dyn_pad_mask[reshape_axis_legacy] = 1;
user_info.second.data_padding.set_dynamic_pad(tensor(dyn_pad_mask));
}
}
return;
}
Expand All @@ -673,14 +710,40 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
auto dyn_pad_sizes = lower_sizes;
dyn_pad_sizes[crop_axis_legacy] = 1;
crop_layout.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
for (auto& user_layout : user_layouts) {
auto reshape_rank = user_layout.get_partial_shape().size();
auto reshape_last_dim = user_layout.get_partial_shape().to_shape()[reshape_rank - 1];
if (lower_sizes[crop_axis_legacy])
lower_sizes[crop_axis_legacy] /= reshape_last_dim;
if (upper_sizes[crop_axis_legacy])
upper_sizes[crop_axis_legacy] /= reshape_last_dim;
user_layout.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
if (user_info.first) {
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
if (reshape_mode == reshape::reshape_mode::base) {
auto reshape_rank = user_info.second.get_partial_shape().size();
auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1];
if (lower_sizes[crop_axis_legacy])
lower_sizes[crop_axis_legacy] /= reshape_last_dim;
if (upper_sizes[crop_axis_legacy])
upper_sizes[crop_axis_legacy] /= reshape_last_dim;
user_info.second.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
} else {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;

auto reshape_axis = crop_axis;
for (size_t i = 0; i < output_pattern.size(); i++) {
if (output_pattern[i] <= static_cast<int64_t>(reshape_axis)) {
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
}
}

const auto output_rank = std::max(reshape_ps.size(), static_cast<size_t>(4));
std::vector<int32_t> reshape_lower_sizes(output_rank, 0);
std::vector<int32_t> reshape_upper_sizes(output_rank, 0);
std::vector<int32_t> reshape_dyn_pad_mask(output_rank, 0);

const auto reshape_axis_legacy = convert_axis_to_legacy(reshape_axis, reshape_ps.size());
reshape_lower_sizes[reshape_axis_legacy] = lower_sizes[crop_axis_legacy];
reshape_upper_sizes[reshape_axis_legacy] = upper_sizes[crop_axis_legacy];
reshape_dyn_pad_mask[reshape_axis_legacy] = 1;

user_info.second.data_padding = padding(reshape_lower_sizes, reshape_upper_sizes, 0.f, tensor(reshape_dyn_pad_mask));
}
}
} else {
crop_layout.data_padding = padding(lower_sizes, upper_sizes);
Expand Down Expand Up @@ -743,18 +806,23 @@ void prepare_buffer_fusing::run(program& p) {
node.get_primitive()->axis,
false);
} else if (crop_in_place_optimization::can_crop_be_optimized_simple_data_format(crop_layout, pred_layout)) {
std::pair<const program_node*, layout> user_info;
std::vector<layout> reshape_layouts;
if (node.get_users().front()->is_type<reshape>() && node.get_users().front()->as<reshape>().is_runtime_propagatable_padding()) {
reshape_layouts.push_back(node.get_users().front()->get_output_layout());
if (node.get_users().front()->is_type<reshape>()) {
auto& reshape_node = node.get_users().front()->as<reshape>();
if (reshape_node.is_runtime_propagatable_padding()) {
user_info.first = &reshape_node;
user_info.second = reshape_node.get_output_layout();
}
}
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout,
pred_layout,
reshape_layouts,
user_info,
crop_params->input_offsets[0],
node.get_primitive()->axis,
false);
if (reshape_layouts.size() > 0) {
node.get_users().front()->set_output_layout(reshape_layouts[0]);
if (user_info.first) {
node.get_users().front()->set_output_layout(user_info.second);
}
}
node.set_output_layout(crop_layout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct crop_in_place_optimization : pattern_match_optimization_typed<crop_in_pla
bool is_runtime);
static void update_in_place_crop_padding_simple_data_format(layout& crop_layout,
layout& pred_layout,
std::vector<layout>& user_layouts,
std::pair<const program_node*, layout>& user_info,
const tensor offsets,
size_t crop_axis,
bool is_runtime);
Expand Down
7 changes: 5 additions & 2 deletions src/plugins/intel_gpu/src/graph/include/reshape_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {

bool is_runtime_propagatable_padding() const {
auto prim = typed_desc();
if (prim->mode == reshape::reshape_mode::squeeze || prim->mode == reshape::reshape_mode::unsqueeze)
return true;
if (prim->mode == reshape::reshape_mode::squeeze || prim->mode == reshape::reshape_mode::unsqueeze) {
// For proper padding propagation we need to know output pattern at model loading stage
// in case of squeeze/unsqueeze mode
return prim->output_pattern.size() > 0;
}

// TODO: This function is to limit condition to a specific case (crop + reshape) among cases for the base mode
if (!input().is_type<crop>())
Expand Down
12 changes: 6 additions & 6 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1485,15 +1485,16 @@ void primitive_inst::do_runtime_in_place_crop() {
u->update_shape_done_by_other = true;

const auto& crop_users = u->get_user_insts();
std::vector<layout> reshape_layouts;
std::pair<const program_node*, layout> user_info;
if (crop_users.front()->get_node().is_type<reshape>()) {
OPENVINO_ASSERT(crop_users.size() == 1, "[GPU] Expected number of reshape users is 1, but it is ", crop_users.size());
auto reshape_inst = crop_users.front();
if (!reshape_inst->update_shape_done_by_other) {
GPU_DEBUG_TRACE_DETAIL << "[In place crop] update shape for " << reshape_inst->id() << std::endl;
reshape_inst->update_shape();
reshape_inst->update_shape_done_by_other = true;
reshape_layouts.push_back(reshape_inst->_impl_params->get_output_layout());
user_info.first = &reshape_inst->get_node();
user_info.second = reshape_inst->_impl_params->get_output_layout();
}
}

Expand All @@ -1510,11 +1511,10 @@ void primitive_inst::do_runtime_in_place_crop() {
if (crop_in_place_optimization::can_crop_be_optimized_along_feature(crop_layout, pred_layout)) {
crop_in_place_optimization::update_in_place_crop_padding_along_feature(u->get_node(), crop_layout, pred_layout, offsets, crop_axis, true);
} else if (crop_in_place_optimization::can_crop_be_optimized_simple_data_format(crop_layout, pred_layout)) {
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout, pred_layout, reshape_layouts,
offsets, crop_axis, true);
if (crop_users.front()->get_node().is_type<reshape>() && reshape_layouts.size() > 0) {
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout, pred_layout, user_info, offsets, crop_axis, true);
if (user_info.first) {
auto reshape_inst = crop_users.front();
reshape_inst->_impl_params->output_layouts[0] = reshape_layouts[0];
reshape_inst->_impl_params->output_layouts[0] = user_info.second;
reshape_inst->set_shape_change();
}
} else {
Expand Down
18 changes: 12 additions & 6 deletions src/plugins/intel_gpu/src/graph/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
update_pad_upper = pad_upper;
update_pad_mask = pad_mask;

// Truncate to the actual rank (for shapes with a rank less than 4)
update_pad_lower.resize(rank);
update_pad_upper.resize(rank);
update_pad_mask.resize(rank);

std::unordered_set<int64_t> tmp(axes.begin(), axes.end());
std::vector<int64_t> unique_axes;
const auto expanded_rank = rank + tmp.size();
Expand All @@ -61,13 +66,13 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
// Normalize then remove repeated axes after normalization.
for (const auto& axis : axes) {
if (static_cast<size_t>(axis) <= out_shape.size()) {
pad_lower.insert(std::next(std::begin(pad_lower), axis), 0);
pad_upper.insert(std::next(std::begin(pad_upper), axis), 0);
pad_mask.insert(std::next(std::begin(pad_mask), axis), 0);
update_pad_lower.insert(std::next(std::begin(update_pad_lower), axis), 0);
update_pad_upper.insert(std::next(std::begin(update_pad_upper), axis), 0);
update_pad_mask.insert(std::next(std::begin(update_pad_mask), axis), 0);
} else {
pad_lower.push_back(0);
pad_upper.push_back(0);
pad_mask.push_back(0);
update_pad_lower.push_back(0);
update_pad_upper.push_back(0);
update_pad_mask.push_back(0);
}
}
} else {
Expand Down Expand Up @@ -254,6 +259,7 @@ std::string reshape_inst::to_string(reshape_node const& node) {
reshape_info.add("output pshape", desc->output_partial_shape);
reshape_info.add("output pattern", desc->output_pattern);
reshape_info.add("special zero", desc->special_zero);
reshape_info.add("reshape mode", desc->mode);

node_info->add("reshape info", reshape_info);
node_info->dump(primitive_description);
Expand Down
30 changes: 24 additions & 6 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,26 @@ KERNEL(rope_ref)(
uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0;
uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0;

#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);

uint cos_idx = cos_sin_idx;
uint sin_idx = cos_sin_idx;
#else
uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);
uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);
#endif

uint output_idx = OUTPUT_GET_INDEX(b, p, h, 0);

INPUT0_TYPE in1 = input[input_idx + r];
INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r];

output[output_idx + r] = cos[cos_sin_idx + r] * in1 - sin[cos_sin_idx + r] * in2;
output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2;

output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in2 +
sin[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in1;
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 +
sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1;
}
#endif

Expand Down Expand Up @@ -128,16 +137,25 @@ KERNEL(rope_ref)(
cos_sin_p = gather[gather_idx];
#endif
cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? cos_sin_p : 0;

#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);

uint cos_idx = cos_sin_idx;
uint sin_idx = cos_sin_idx;
#else
uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);
uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);
#endif

uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0);

INPUT0_TYPE in1 = input[input_idx + r];
INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r];

output[output_idx + r] = cos[cos_sin_idx + r] * in1 - sin[cos_sin_idx + r] * in2;
output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2;

output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in2 +
sin[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in1;
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 +
sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1;
}
#endif
Loading

0 comments on commit a33afe4

Please sign in to comment.