Skip to content

Commit

Permalink
[GPU] Added RoPE support for Llama2
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Jun 3, 2024
1 parent a1a2c4f commit aa5b5c9
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 5 deletions.
12 changes: 10 additions & 2 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ struct rope : public primitive_base<rope> {
/// @param id This primitive id
/// @param inputs Inputs primitive ids
/// @param config Specific RoPE config
/// @param gather_rank Required for correct processing of gather input (if applicable)
rope(const primitive_id& id,
const std::vector<input_info>& inputs,
const RoPE::Config& config,
size_t gather_rank = 0,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}),
config(config) {}
config(config),
gather_rank(gather_rank) {}

RoPE::Config config;
size_t gather_rank;

size_t hash() const override {
size_t seed = primitive::hash();
Expand All @@ -40,6 +44,7 @@ struct rope : public primitive_base<rope> {
seed = hash_combine(seed, config.rotary_ndims);
seed = hash_combine(seed, config.slice_start);
seed = hash_combine(seed, config.slice_stop);
seed = hash_combine(seed, gather_rank);
return seed;
}

Expand All @@ -58,7 +63,8 @@ struct rope : public primitive_base<rope> {
config.is_qwen == rhs_casted.config.is_qwen &&
config.rotary_ndims == rhs_casted.config.rotary_ndims &&
config.slice_start == rhs_casted.config.slice_start &&
config.slice_stop == rhs_casted.config.slice_stop;
config.slice_stop == rhs_casted.config.slice_stop &&
gather_rank == rhs_casted.gather_rank;
}

void save(BinaryOutputBuffer& ob) const override {
Expand All @@ -73,6 +79,7 @@ struct rope : public primitive_base<rope> {
ob << config.rotary_ndims;
ob << config.slice_start;
ob << config.slice_stop;
ob << gather_rank;
}

void load(BinaryInputBuffer& ib) override {
Expand All @@ -87,6 +94,7 @@ struct rope : public primitive_base<rope> {
ib >> config.rotary_ndims;
ib >> config.slice_start;
ib >> config.slice_stop;
ib >> gather_rank;
}
};
} // namespace cldnn
34 changes: 34 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ struct rope_impl : typed_primitive_impl_ocl<rope> {
params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3;
params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 2 : 3;

if (primitive->gather_rank > 0) {
params.gather_rank = primitive->gather_rank;
params.num_of_inputs++;
}

params.is_qwen = primitive->config.is_qwen;
params.is_chatglm = primitive->config.is_chatglm;
params.transposed_input = primitive->config.input_trans0213;
Expand All @@ -56,6 +61,35 @@ struct rope_impl : typed_primitive_impl_ocl<rope> {
return params;
}

static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
const auto& primitive = impl_params.typed_desc<rope>();

if (primitive->config.is_chatglm || primitive->config.is_qwen) {
return primitive_impl::static_canonicalize_shapes(impl_params);
} else {
auto updated_impl_params = canonicalize_fused_shapes(impl_params);

std::set<size_t> canonicalize_from_begin = { 1, 2 };
for (size_t i = 0; i < updated_impl_params.input_layouts.size(); ++i) {
auto& input_layout = updated_impl_params.input_layouts[i];
if (canonicalize_from_begin.count(i) != 0) {
input_layout.set_partial_shape(extend_shape_to_rank_from_begin(input_layout.get_partial_shape()));
} else {
input_layout.set_partial_shape(extend_shape_to_rank_from_end(input_layout.get_partial_shape()));
}
}

auto& output_layout = updated_impl_params.output_layouts[0];
output_layout.set_partial_shape(extend_shape_to_rank_from_end(output_layout.get_partial_shape()));

return updated_impl_params;
}
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto kernel_params = get_kernel_params(impl_param, true);
(_kernel_data.update_dispatch_data_func)(kernel_params, _kernel_data);
Expand Down
19 changes: 18 additions & 1 deletion src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ KERNEL(rope_ref)(
const __global INPUT0_TYPE* input,
const __global INPUT1_TYPE* cos,
const __global INPUT1_TYPE* sin,
#ifdef ENABLE_GATHER
const __global INPUT3_TYPE* gather,
#endif
__global OUTPUT_TYPE* output)
{
const uint b = get_global_id(0);
Expand All @@ -108,9 +111,23 @@ KERNEL(rope_ref)(
#else
uint input_idx = INPUT0_GET_INDEX(b, h, p, 0);
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0;
uint cos_sin_p = p < INPUT1_SIZE_Y ? p : 0;
uint cos_sin_p = p;
#ifdef ENABLE_GATHER
uint gather_b = b < INPUT3_BATCH_NUM ? b : 0;
#if GATHER_RANK == 4
uint gather_h = h < INPUT3_FEATURE_NUM ? h : 0;
uint gather_p = p < INPUT3_SIZE_Y ? p : 0;
uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_h, gather_p, 0);
#else
uint gather_p = p < INPUT3_FEATURE_NUM ? p : 0;
uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_p, 0, 0);
#endif
cos_sin_p = gather[gather_idx];
#endif
cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? cos_sin_p : 0;
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);

uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ JitConstants RoPEKernelBase::GetJitConstants(const rope_params& params, RoPEKern
jit.AddConstant(MakeJitConstant("ENABLE_IO_COPY", true));
}

if (params.gather_rank > 0) {
jit.AddConstant(MakeJitConstant("ENABLE_GATHER", true));
jit.AddConstant(MakeJitConstant("GATHER_RANK", params.gather_rank));
}

if (params.slice_stop - params.slice_start > 0) {
jit.AddConstant(MakeJitConstant("ENABLE_SLICE", true));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct rope_params : public base_params {
size_t slice_stop;
size_t axis;
size_t num_of_inputs;
size_t gather_rank;

bool is_qwen;
bool is_chatglm;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
namespace kernel_selector {
ParamsKey RoPEKernelRef::GetSupportedKey() const {
ParamsKey k;
k.EnableInputDataType(Datatype::INT32);
k.EnableInputDataType(Datatype::F16);
k.EnableInputDataType(Datatype::F32);
k.EnableOutputDataType(Datatype::F16);
Expand Down
8 changes: 7 additions & 1 deletion src/plugins/intel_gpu/src/plugin/ops/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptr<op::internal::
auto inputs = p.GetInputInfo(op);
const auto& config = op->get_config();

size_t gather_rank = 0;
if (config.gather_position_arg_id > 0) {
gather_rank = op->get_input_partial_shape(config.gather_position_arg_id).size();
}

auto rope = cldnn::rope(layer_type_name_ID(op),
inputs,
config);
config,
gather_rank);

p.add_primitive(*op, rope);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {

manager.register_pass<ov::pass::RoPEFusion>();
pass_config->disable<ov::pass::RoPEFusionGPTJ>();
pass_config->disable<ov::pass::RoPEFusionCosSinPreprocess>();
pass_config->disable<ov::pass::RoPEFusionIOSlicing>();
pass_config->disable<ov::pass::RoPEShareCosSin>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b,
::testing::Values(ov::test::utils::DEVICE_GPU)),
RoPETestQwen7b::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2,
RoPETestLlama2,
::testing::Values(ov::test::utils::DEVICE_GPU),
RoPETestLlama2::getTestCaseName);

} // namespace test
} // namespace ov

0 comments on commit aa5b5c9

Please sign in to comment.