Skip to content

Commit

Permalink
[GPU] Allow softmax_bf kernel for axis=X 4d case
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Oct 26, 2023
1 parent 1e4f3f1 commit 84e3c7e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ bool SoftmaxKernelBaseBF::Validate(const Params& p, const optional_params& o) co

switch (params.dim) {
case SoftmaxDim::X:
return !input.Y().is_dynamic && input.Y().v == 1 &&
return ((!input.Y().is_dynamic && input.Y().v == 1) || input.GetLayout() == DataLayout::bfyx) &&
!input.Z().is_dynamic && input.Z().v == 1 &&
!input.Feature().is_dynamic && input.Feature().v == 1;
((!input.Feature().is_dynamic && input.Feature().v == 1) || input.GetLayout() == DataLayout::bfyx);
case SoftmaxDim::Y:
return !input.X().is_dynamic && input.X().v == 1 &&
!input.Z().is_dynamic && input.Z().v == 1 &&
Expand Down Expand Up @@ -122,6 +122,10 @@ SoftmaxKernelBase::DispatchData SoftmaxKernelBaseBF::SetDefault(const softmax_pa
OPENVINO_ASSERT(input.X().v == 1, "[GPU] SoftmaxKernelBaseBF: input.X() is expected to be 1 while actual value is ", input.X().v);
dispatchData.dataSetSize = input.Y().v;
dispatchData.dataSetsCount = input.Batch().v * input.Feature().v;
} else if (params.dim == SoftmaxDim::X && (input.Feature().v > 1 || input.Y().v > 1) && input.GetLayout() == DataLayout::bfyx) {
// Flatten BFY for such case
dispatchData.dataSetSize = input.X().v;
dispatchData.dataSetsCount = input.Batch().v * input.Feature().v * input.Y().v;
} else {
auto flatten_input = input.FlattenFeatureAndSpatials();
dispatchData.dataSetSize = flatten_input.Feature().v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,18 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(ov::AnyMap())),
SoftMax8LayerTest::getTestCaseName);

const std::vector<ov::Shape> stableDiffusionShapes = {
{16, 4096, 4096},
{2, 8, 4096, 4096}
};

INSTANTIATE_TEST_SUITE_P(
smoke_SoftMaxStableDiffusion,
SoftMax8LayerTest,
testing::Combine(testing::ValuesIn(netPrecisions),
::testing::Values(ov::element::undefined),
::testing::Values(ov::element::undefined),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({{16, 4096, 4096}})),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation(stableDiffusionShapes)),
testing::Values(-1),
testing::Values(ov::test::utils::DEVICE_GPU),
testing::Values(ov::AnyMap())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ TEST(softmax_gpu_dynamic_f32_test_upper_bound, input_same_values) {
format::bfyx);
auto config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
ov::intel_gpu::ImplementationDesc softmax_impl = { format::bfyx, "softmax_gpu_ref" };
config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "softmax", softmax_impl } }));
network network(engine, topology(input_layout("input", in_layout),
reorder("reorder", input_info("input"), format::bfyx, data_types::f16),
softmax("softmax", input_info("reorder"), 3),
Expand Down

0 comments on commit 84e3c7e

Please sign in to comment.