Skip to content

Commit

Permalink
Slice -> StridedSlice
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Jan 31, 2024
1 parent be71494 commit 8c29306
Showing 1 changed file with 60 additions and 21 deletions.
81 changes: 60 additions & 21 deletions src/plugins/intel_gpu/src/plugin/ops/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
#include "intel_gpu/primitives/permute.hpp"
#include "intel_gpu/primitives/slice.hpp"
#include "intel_gpu/primitives/data.hpp"

// static size_t uniq_id = 0;
#include "intel_gpu/primitives/strided_slice.hpp"

namespace ov {
namespace op {
Expand All @@ -23,6 +22,8 @@ using RoPE = ov::intel_gpu::op::RoPE;
namespace ov {
namespace intel_gpu {

namespace {

template<typename T>
cldnn::data CreateScalarDataPrimitive(ProgramBuilder& p, const cldnn::primitive_id& name, T value) {
auto mem = p.get_engine().allocate_memory(
Expand All @@ -32,38 +33,74 @@ cldnn::data CreateScalarDataPrimitive(ProgramBuilder& p, const cldnn::primitive_
return {name, mem};
}

cldnn::data CreateDataPrimitive(ProgramBuilder& p, const cldnn::primitive_id& name, const std::vector<int64_t>& value) {
auto mem = p.get_engine().allocate_memory({{ 3 }, cldnn::data_types::i64, cldnn::format::bfyx });
cldnn::mem_lock<int64_t> host_mem(mem, p.get_engine().get_service_stream());

auto it = host_mem.begin();
for (auto x : value)
*it++ = x;

return {name, mem};
}

static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<op::RoPE>& op) {
validate_inputs_count(op, {3, 4});
auto inputs = p.GetInputInfo(op);
const auto& config = op->get_config();

// if (config.slice_stop - config.slice_start > 0) {
// const cldnn::primitive_id slice_start {layer_type_name_ID(op) + "_slice_start"};
// auto slice_start_prim = CreateScalarDataPrimitive(p, slice_start, static_cast<std::int64_t>(config.slice_start));
// p.add_primitive(*op, slice_start_prim);

// const cldnn::primitive_id slice_stop {layer_type_name_ID(op) + "_slice_stop"};
// auto slice_stop_prim = CreateScalarDataPrimitive(p, slice_stop, static_cast<std::int64_t>(config.slice_stop));
// p.add_primitive(*op, slice_stop_prim);

// const cldnn::primitive_id slice_step {layer_type_name_ID(op) + "_slice_step"};
// auto slice_step_prim = CreateScalarDataPrimitive(p, slice_step, static_cast<std::int64_t>(1));
// p.add_primitive(*op, slice_step_prim);

// const cldnn::primitive_id slice_axis {layer_type_name_ID(op) + "_slice_axis"};
// auto slice_axis_prim = CreateScalarDataPrimitive(p, slice_axis, static_cast<std::int64_t>(2));
// p.add_primitive(*op, slice_axis_prim);


// auto sliceName = op->get_friendly_name() + "_slice";
// auto slicePrim = cldnn::slice(sliceName,
// { cldnn::input_info(inputs[0].pid),
// slice_start,
// slice_stop,
// slice_step,
// slice_axis });
// p.add_primitive(*op, slicePrim);
// inputs[0] = cldnn::input_info(sliceName);
// }

if (config.slice_stop - config.slice_start > 0) {
const cldnn::primitive_id slice_start {layer_type_name_ID(op) + "_slice_start"};
auto slice_start_prim = CreateScalarDataPrimitive(p, slice_start, static_cast<std::int64_t>(config.slice_start));
const cldnn::primitive_id begin {layer_type_name_ID(op) + "_slice_begin"};
auto slice_start_prim = CreateDataPrimitive(p, begin, { 0, 0, static_cast<std::int64_t>(config.slice_start) });
p.add_primitive(*op, slice_start_prim);

const cldnn::primitive_id slice_stop {layer_type_name_ID(op) + "_slice_stop"};
auto slice_stop_prim = CreateScalarDataPrimitive(p, slice_stop, static_cast<std::int64_t>(config.slice_stop));
const cldnn::primitive_id end {layer_type_name_ID(op) + "_slice_end"};
auto slice_stop_prim = CreateDataPrimitive(p, end, { 1000000, 1000000, static_cast<std::int64_t>(config.slice_stop) });
p.add_primitive(*op, slice_stop_prim);

const cldnn::primitive_id slice_step {layer_type_name_ID(op) + "_slice_step"};
auto slice_step_prim = CreateScalarDataPrimitive(p, slice_step, static_cast<std::int64_t>(1));
const cldnn::primitive_id strides {layer_type_name_ID(op) + "_slice_strides"};
auto slice_step_prim = CreateDataPrimitive(p, strides, { 1, 1, 1 });
p.add_primitive(*op, slice_step_prim);

const cldnn::primitive_id slice_axis {layer_type_name_ID(op) + "_slice_axis"};
auto slice_axis_prim = CreateScalarDataPrimitive(p, slice_axis, static_cast<std::int64_t>(2));
p.add_primitive(*op, slice_axis_prim);


auto sliceName = op->get_friendly_name() + "_slice";
auto slicePrim = cldnn::slice(sliceName,
{ cldnn::input_info(inputs[0].pid),
slice_start,
slice_stop,
slice_step,
slice_axis });
p.add_primitive(*op, slicePrim);
inputs[0] = cldnn::input_info(sliceName);
auto strided_slice_name = op->get_friendly_name() + "_strided_slice";
auto strided_slice_prim = cldnn::strided_slice(strided_slice_name,
cldnn::input_info(inputs[0].pid),
begin,
end,
strides,
{}, {}, {}, {}, {}, {});
p.add_primitive(*op, strided_slice_prim);
inputs[0] = cldnn::input_info(strided_slice_name);
}

if (config.input_trans0213) {
Expand All @@ -90,6 +127,8 @@ static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<op::RoPE>& op)

p.add_primitive(*op, rope);
}
} // namespace


REGISTER_FACTORY_IMPL(internal, RoPE);

Expand Down

0 comments on commit 8c29306

Please sign in to comment.