Skip to content

Commit

Permalink
StridedSlice -> inside kernel calc
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Feb 1, 2024
1 parent 8c29306 commit fa8776b
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 78 deletions.
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ struct rope_impl : typed_primitive_impl_ocl<rope> {
params.head_size = primitive->config.head_size;
params.rotary_ndims = primitive->config.rotary_ndims;

params.slice_start = primitive->config.slice_start;
params.slice_stop = primitive->config.slice_stop;

params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3;

for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ KERNEL(rope_ref)(
const uint h = get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = get_global_id(2) % HALF_ROTARY_NDIMS * 2;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0);

input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1)
+ SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b);
#else
uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0);
#endif
uint cos_sin_idx = INPUT1_GET_INDEX(p, b, 0, 0);
uint output_idx = OUTPUT_GET_INDEX(p, b, h, 0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ JitConstants RoPEKernelBase::GetJitConstants(const rope_params& params, RoPEKern
jit.AddConstant(MakeJitConstant("HEAD_SIZE", params.head_size));
jit.AddConstant(MakeJitConstant("HALF_ROTARY_NDIMS", params.rotary_ndims / 2));

if (params.slice_stop - params.slice_start > 0) {
jit.AddConstant(MakeJitConstant("ENABLE_SLICE", true));

auto f = toCodeString(params.inputs[0].Feature(), 1);
auto x = toCodeString(params.inputs[0].X(), 2);
auto y = toCodeString(params.inputs[0].Y(), 3);
auto sliced_y = toCodeString(params.slice_stop - params.slice_start);

jit.AddConstant(MakeJitConstant("SLICED_INPUT0_X_PITCH", 1));
jit.AddConstant(MakeJitConstant("SLICED_INPUT0_Y_PITCH", x));
jit.AddConstant(MakeJitConstant("SLICED_INPUT0_FEATURE_PITCH", x + "*" + sliced_y));
jit.AddConstant(MakeJitConstant("SLICED_INPUT0_BATCH_PITCH", x + "*" + sliced_y + "*" + f));
jit.AddConstant(MakeJitConstant("SLICED_INPUT0_OFFSET", 0));

jit.AddConstant(MakeJitConstant("SLICED_FROM_START", toCodeString(params.slice_start)));
jit.AddConstant(MakeJitConstant("SLICED_FROM_END", "(" + y + "-" + toCodeString(params.slice_stop) + ")"));
}

return jit;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ struct rope_params : public base_params {
size_t head_cnt;
size_t head_size;
size_t rotary_ndims;

size_t slice_start;
size_t slice_stop;
size_t axis;
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
78 changes: 0 additions & 78 deletions src/plugins/intel_gpu/src/plugin/ops/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,87 +22,11 @@ using RoPE = ov::intel_gpu::op::RoPE;
namespace ov {
namespace intel_gpu {

namespace {

// Wraps a single value of type T in a cldnn::data primitive called `name`.
// A {1, 1, 1, 1} bfyx buffer whose element type is element::from<T>() is allocated
// on the program's engine and the raw bytes of `value` are copied into it.
template<typename T>
cldnn::data CreateScalarDataPrimitive(ProgramBuilder& p, const cldnn::primitive_id& name, T value) {
    const cldnn::layout scalar_layout{element::from<T>(), cldnn::format::bfyx, {1, 1, 1, 1}};
    auto scalar_mem = p.get_engine().allocate_memory(scalar_layout, false);

    // Byte-wise view of the buffer; the lock stays alive until the function returns.
    cldnn::mem_lock<int8_t> raw_bytes{scalar_mem, p.get_engine().get_service_stream()};
    std::memcpy(raw_bytes.data(), &value, sizeof(T));

    return {name, scalar_mem};
}

// Wraps the i64 elements of `value` in a cldnn::data primitive called `name`.
// The buffer is a 1-D i64 allocation on the program's engine. Its extent is taken
// from value.size() rather than a fixed 3, so the copy loop below can never write
// past the end of the allocation (longer input) or leave unwritten trailing
// elements (shorter input). For 3-element inputs the result is unchanged.
cldnn::data CreateDataPrimitive(ProgramBuilder& p, const cldnn::primitive_id& name, const std::vector<int64_t>& value) {
    auto mem = p.get_engine().allocate_memory(
        {{ static_cast<int64_t>(value.size()) }, cldnn::data_types::i64, cldnn::format::bfyx });
    cldnn::mem_lock<int64_t> host_mem(mem, p.get_engine().get_service_stream());

    // Element-wise copy into the locked host view of the buffer.
    auto it = host_mem.begin();
    for (const auto x : value)
        *it++ = x;

    return {name, mem};
}

static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<op::RoPE>& op) {
validate_inputs_count(op, {3, 4});
auto inputs = p.GetInputInfo(op);
const auto& config = op->get_config();

// if (config.slice_stop - config.slice_start > 0) {
// const cldnn::primitive_id slice_start {layer_type_name_ID(op) + "_slice_start"};
// auto slice_start_prim = CreateScalarDataPrimitive(p, slice_start, static_cast<std::int64_t>(config.slice_start));
// p.add_primitive(*op, slice_start_prim);

// const cldnn::primitive_id slice_stop {layer_type_name_ID(op) + "_slice_stop"};
// auto slice_stop_prim = CreateScalarDataPrimitive(p, slice_stop, static_cast<std::int64_t>(config.slice_stop));
// p.add_primitive(*op, slice_stop_prim);

// const cldnn::primitive_id slice_step {layer_type_name_ID(op) + "_slice_step"};
// auto slice_step_prim = CreateScalarDataPrimitive(p, slice_step, static_cast<std::int64_t>(1));
// p.add_primitive(*op, slice_step_prim);

// const cldnn::primitive_id slice_axis {layer_type_name_ID(op) + "_slice_axis"};
// auto slice_axis_prim = CreateScalarDataPrimitive(p, slice_axis, static_cast<std::int64_t>(2));
// p.add_primitive(*op, slice_axis_prim);


// auto sliceName = op->get_friendly_name() + "_slice";
// auto slicePrim = cldnn::slice(sliceName,
// { cldnn::input_info(inputs[0].pid),
// slice_start,
// slice_stop,
// slice_step,
// slice_axis });
// p.add_primitive(*op, slicePrim);
// inputs[0] = cldnn::input_info(sliceName);
// }

if (config.slice_stop - config.slice_start > 0) {
const cldnn::primitive_id begin {layer_type_name_ID(op) + "_slice_begin"};
auto slice_start_prim = CreateDataPrimitive(p, begin, { 0, 0, static_cast<std::int64_t>(config.slice_start) });
p.add_primitive(*op, slice_start_prim);

const cldnn::primitive_id end {layer_type_name_ID(op) + "_slice_end"};
auto slice_stop_prim = CreateDataPrimitive(p, end, { 1000000, 1000000, static_cast<std::int64_t>(config.slice_stop) });
p.add_primitive(*op, slice_stop_prim);

const cldnn::primitive_id strides {layer_type_name_ID(op) + "_slice_strides"};
auto slice_step_prim = CreateDataPrimitive(p, strides, { 1, 1, 1 });
p.add_primitive(*op, slice_step_prim);


auto strided_slice_name = op->get_friendly_name() + "_strided_slice";
auto strided_slice_prim = cldnn::strided_slice(strided_slice_name,
cldnn::input_info(inputs[0].pid),
begin,
end,
strides,
{}, {}, {}, {}, {}, {});
p.add_primitive(*op, strided_slice_prim);
inputs[0] = cldnn::input_info(strided_slice_name);
}

if (config.input_trans0213) {
auto& input_pshape = op->get_input_partial_shape(0);
std::vector<uint16_t> transposeOrder(input_pshape.size());
Expand All @@ -127,8 +51,6 @@ static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<op::RoPE>& op)

p.add_primitive(*op, rope);
}
} // namespace


REGISTER_FACTORY_IMPL(internal, RoPE);

Expand Down

0 comments on commit fa8776b

Please sign in to comment.