[GPU] Fix rms not to be fused if the size of target dimension is > WGS (openvinotoolkit#22768)

### Details:
 - Fix rms not to be fused if the size of the target dimension, divided by the kernel's vector size (8), exceeds the max work-group size (WGS)

### Tickets:
 - 131958
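In essence, the fusion pass now receives the device's maximum work-group size and rejects the RMS decomposition pattern when the normalized (innermost) dimension could not be handled by a single work-group of the bfyx-opt kernel. A minimal standalone sketch of that guard, assuming the same vec_size of 8 used by the kernel (gamma_shape and max_work_group_size mirror the names in the diff below; this is not the literal plugin code):

#include <cstdint>
#include <vector>

// Sketch: allow RMS fusion only when the innermost dimension fits into one
// work-group, with each work-item covering vec_size elements.
bool rms_fusion_allowed(const std::vector<size_t>& gamma_shape,
                        uint64_t max_work_group_size) {
    const size_t vec_size = 8;  // vector width of the bfyx-opt RMS kernel
    return (gamma_shape.back() / vec_size) <= max_work_group_size;
}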
yeonbok authored Feb 12, 2024
1 parent 3bdd728 commit a69fe5e
Showing 8 changed files with 72 additions and 16 deletions.
9 changes: 9 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/rms.cpp
@@ -39,13 +39,22 @@ struct rms_impl : typed_primitive_impl_ocl<rms> {

params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(1)));
params.epsilon = primitive->epsilon;
params.ov_input_rank = static_cast<int32_t>(impl_param.get_input_layout().get_partial_shape().size());
return {params, optional_params};
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto kernel_params = get_kernel_params(impl_param, true);
(_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
}

static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
return impl_params;
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}
};

namespace detail {
@@ -13,6 +13,7 @@ namespace kernel_selector {
struct rms_params : public base_params {
rms_params() : base_params(KernelType::RMS) {}
float epsilon = 0.0f;
int32_t ov_input_rank = -1;
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
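The new ov_input_rank field records the original OpenVINO rank because the kernel-side tensors are expressed in bfyx form, where the dimension holding the innermost (normalized) axis depends on that rank. A hypothetical helper illustrating the mapping applied by the switch statements in the bfyx-opt kernel below (the actual code switches inline rather than calling a helper):

#include <cstdint>
#include <string>

// rank 1 -> b, rank 2 -> f, rank 3 -> y, rank 4 and above -> x
std::string innermost_bfyx_dim(int32_t ov_input_rank) {
    switch (ov_input_rank) {
        case 1:  return "b";
        case 2:  return "f";
        case 3:  return "y";
        default: return "x";
    }
}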
@@ -31,7 +31,22 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc
if (params.has_dynamic_tensors()) {
const auto& input = params.inputs[0];
DimensionAccessHelper dims(input);
const std::string data_size = toVectorMulString({dims.x(), dims.y(), dims.z()});
std::string data_size;
switch (params.ov_input_rank) {
case 1 :
data_size = dims.b();
break;
case 2 :
data_size = dims.f();
break;
case 3 :
data_size = dims.y();
break;
default:
data_size = dims.x();
break;
}

const std::string lws_0 = "get_local_size(0)";
jit.AddConstants({
MakeJitConstant("DATA_SIZE", data_size),
@@ -47,7 +62,7 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc
});
}
jit.AddConstants({
MakeJitConstant("VEC_SIZE", 8),
MakeJitConstant("VEC_SIZE", vec_size),
MakeJitConstant("VLOAD", "CAT(vload, VEC_SIZE)"),
MakeJitConstant("VSTORE", "CAT(vstore, VEC_SIZE)"),
MakeJitConstant("INPUT_VEC_TYPE", "MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE)"),
@@ -71,10 +86,26 @@ RMSKernelBase::DispatchData RMSKernelBfyxOpt::SetDefault(const rms_params& param
dispatchData.maxSlmSize = max_lws;

if (!params.has_dynamic_tensors()) {
dispatchData.dataSize = input.X().v * input.Y().v * input.Z().v;
dispatchData.dataCount = input.Batch().v * input.Feature().v;
dispatchData.slmSize = dispatchData.dataSize / 8;
dispatchData.leftovers = dispatchData.dataSize % 8;
// data size to be processed within a LWG
switch (params.ov_input_rank) {
case 1:
dispatchData.dataSize = input.Batch().v;
dispatchData.dataCount = 1;
break;
case 2:
dispatchData.dataSize = input.Feature().v;
dispatchData.dataCount = input.Batch().v;
break;
case 3:
dispatchData.dataSize = input.Y().v;
dispatchData.dataCount = input.Batch().v * input.Feature().v;
break;
default:
dispatchData.dataSize = input.X().v;
dispatchData.dataCount = input.Batch().v * input.Feature().v * input.Z().v * input.Y().v;
break;
}

dispatchData.slmSize = dispatchData.dataSize / vec_size;
dispatchData.leftovers = dispatchData.dataSize % vec_size;

dispatchData.gws[0] = dispatchData.slmSize;
dispatchData.gws[1] = dispatchData.dataCount;
@@ -96,12 +127,12 @@ bool RMSKernelBfyxOpt::Validate(const Params& p, const optional_params& o) const

if (!gamma.is_dynamic()) {
size_t data_size = gamma.LogicalSize();
if (data_size < 8) {
if (data_size < vec_size) {
return false;
}
auto local_mem_per_wi = 2 * BytesPerElement(params.inputs[0].GetDType());
auto max_lws = std::min(params.engineInfo.maxWorkGroupSize, params.engineInfo.maxLocalMemSize / local_mem_per_wi);
auto slm_size = data_size / 8;
auto slm_size = data_size / vec_size;
if (slm_size > max_lws) {
return false;
}
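For static shapes, SetDefault above sizes the dispatch so that one work-group normalizes one row of dataSize elements: slmSize work-items each process vec_size elements, leftovers counts the tail that cannot be vectorized, and dataCount rows are spread over gws[1]. Validate enforces the same local-size limit that the fusion pass now checks up front. A rough sketch of that sizing, assuming the field names from the diff (not the literal kernel-selector code):

#include <cstddef>

struct DispatchSketch {
    size_t slm_size;   // work-items per group (local size along axis 0)
    size_t leftovers;  // tail elements handled without vector loads
    size_t gws0;       // global size along axis 0
    size_t gws1;       // one work-group per normalized row
};

DispatchSketch size_rms_dispatch(size_t data_size, size_t data_count, size_t vec_size) {
    DispatchSketch d{};
    d.slm_size  = data_size / vec_size;
    d.leftovers = data_size % vec_size;
    d.gws0 = d.slm_size;
    d.gws1 = data_count;
    return d;
}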
@@ -21,5 +21,6 @@ class RMSKernelBfyxOpt : public RMSKernelBase {
bool Validate(const Params&, const optional_params&) const override;
DispatchData SetDefault(const rms_params& params) const override;
JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override;
const size_t vec_size = 8;
};
} // namespace kernel_selector
16 changes: 15 additions & 1 deletion src/plugins/intel_gpu/src/plugin/transformations/rms_fusion.cpp
@@ -34,7 +34,7 @@ static std::function<bool(ov::Output<ov::Node>)> constant_value(const float targ
};
}

RMSFusion::RMSFusion() {
RMSFusion::RMSFusion(uint64_t max_work_group_size) {
using namespace ov::pass::pattern;

// Detect RMS decomposition pattern
@@ -82,6 +82,20 @@ RMSFusion::RMSFusion() {
}

const auto& gamma_node = pattern_map.at(gamma).get_node_shared_ptr();
const auto& gamma_shape = gamma_node->get_output_partial_shape(0).to_shape();

const auto& mean_node = pattern_map.at(mean).get_node_shared_ptr();
const auto& axes = pattern_map.at(mean_axes).get_node_shared_ptr();
auto axes_constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(axes);
auto axes_val = axes_constant->cast_vector<int64_t>();
// allow last dimension only
if ((axes_val[0] != -1) && (axes_val[0] != (static_cast<int64_t>(mean_node->get_input_partial_shape(0).size()) - 1)))
return false;

const int32_t vec_size = 8;
if (static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(max_work_group_size))
return false;

auto output_type = m.get_match_root()->get_output_element_type(0);

auto rms = std::make_shared<op::RMS>(x_output,
@@ -12,7 +12,7 @@ namespace intel_gpu {
class RMSFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RMSFusion", "0");
RMSFusion();
RMSFusion(uint64_t max_work_group_size);
};

} // namespace intel_gpu
@@ -702,7 +702,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
manager.register_pass<ov::intel_gpu::MoveFCReshapeToWeights>();
manager.register_pass<ov::intel_gpu::ConvertFullyConnectedToFullyConnectedCompressed>();
manager.register_pass<ov::intel_gpu::ConvertGatherToGatherCompressed>();
manager.register_pass<ov::intel_gpu::RMSFusion>();
manager.register_pass<ov::intel_gpu::RMSFusion>(device_info.max_work_group_size);
manager.register_pass<ov::intel_gpu::KVCacheFusion>();
manager.register_pass<ov::intel_gpu::FullyConnectedConvertFusion>();
if (!device_info.supports_immad)
@@ -37,7 +37,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest1) {
auto comp = std::make_shared<ov::opset10::Convert>(mul2, ov::element::f16);

model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
manager.register_pass<RMSFusion>(32);
}
{
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::Shape{1, 2, 6});
@@ -66,7 +66,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest2) {
auto comp = std::make_shared<ov::opset10::Convert>(mul2, ov::element::f16);

model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
manager.register_pass<RMSFusion>(32);
}
}

@@ -88,7 +88,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest3) {
auto comp = std::make_shared<ov::opset10::Convert>(mul2, ov::element::f16);

model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
manager.register_pass<RMSFusion>(32);
}
}

@@ -110,7 +110,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest4) {
auto comp = std::make_shared<ov::opset10::Convert>(mul2, ov::element::f16);

model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
manager.register_pass<RMSFusion>(32);
}
}

@@ -132,7 +132,7 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) {
auto comp = std::make_shared<ov::opset10::Convert>(mul2, ov::element::f16);

model = std::make_shared<ov::Model>(ov::NodeVector{comp}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
manager.register_pass<RMSFusion>(32);
}
{
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
