-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Snippets][CPU] Enabled dynamic MHA FP32 tokenization on x64 (#25500)
### Details: - *The PR enables dynamic FP32 MHA tokenization on x64 platforms 🎉* - *`std::vector.resize()`, which was used for buffer scratchpad allocation, is a very expensive operation because it default-constructs every element. This PR replaces `std::vector.resize()` with CPU Node Scratchpad memory, which can be shared between nodes. Also, since each thread must have its own scratchpad memory, we previously allocated `size * threads_max` - however, at execution time the thread count can be smaller (it depends on the parallel work amount). Now we allocate only `size * n_threads`, where `n_threads` is the real count of working threads.* - *Fixed dimension K validation in the `BrgemmBlocking` pass: one of the inputs can have a dynamic value of this dimension* - *Fixed `utils::broadcast_merge_dim()` and supported broadcasting of integer values in IterHandlers. Added unit tests for `utils::broadcast_merge_dim()`* ### Tickets: - *149900* ### Prerequisites: - [x] #25326 - [x] #25378 - [x] #25623 - [x] #25638 - [x] #25745 - [x] #25957 - [x] #25733
- Loading branch information
1 parent
fdde9f1
commit 54f58b8
Showing
14 changed files
with
201 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
src/common/snippets/tests/include/utils/broadcast_dim_merge.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <common_test_utils/ov_test_utils.hpp> | ||
|
||
|
||
namespace ov { | ||
namespace test { | ||
namespace snippets { | ||
|
||
// D1, D2, Result | ||
using BroadcastMergeDimParams = std::tuple<size_t, size_t, size_t>; | ||
|
||
class BroadcastMergeDimTest : public testing::TestWithParam<BroadcastMergeDimParams> { | ||
public: | ||
static std::string getTestCaseName(testing::TestParamInfo<BroadcastMergeDimParams> obj); | ||
|
||
protected: | ||
void SetUp() override; | ||
BroadcastMergeDimParams m_dims = {}; | ||
}; | ||
|
||
} // namespace snippets | ||
} // namespace test | ||
} // namespace ov |
56 changes: 56 additions & 0 deletions
56
src/common/snippets/tests/src/utils/broadcast_merge_dim.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "utils/broadcast_dim_merge.hpp" | ||
|
||
#include "common_test_utils/ov_test_utils.hpp" | ||
#include "snippets/utils/utils.hpp" | ||
|
||
namespace ov { | ||
namespace test { | ||
namespace snippets { | ||
|
||
// Builds the test-case name in the form "D0=<first>_D1=<second>_DST=<expected>",
// rendering every value through `value2str`.
// NOTE(review): the labels are D0/D1 here while the test body names the same
// inputs d1/d2 - a naming mismatch only, kept so generated test names stay stable.
std::string BroadcastMergeDimTest::getTestCaseName(testing::TestParamInfo<BroadcastMergeDimParams> obj) {
    using ov::snippets::utils::value2str;
    const auto& dims = obj.param;

    std::ostringstream name;
    name << "D0=" << value2str(std::get<0>(dims)) << "_"
         << "D1=" << value2str(std::get<1>(dims)) << "_"
         << "DST=" << value2str(std::get<2>(dims));
    return name.str();
}
|
||
void BroadcastMergeDimTest::SetUp() { | ||
m_dims = this->GetParam(); | ||
} | ||
|
||
// Merges the first two parameters with `broadcast_merge_dim` and expects the call
// to succeed and to yield the third parameter.
TEST_P(BroadcastMergeDimTest, BrodcastMergeDim) {
    size_t lhs = std::get<0>(m_dims);
    size_t rhs = std::get<1>(m_dims);
    size_t expected = std::get<2>(m_dims);

    // Passed as the first argument of broadcast_merge_dim and compared afterwards.
    size_t merged;
    ASSERT_TRUE(ov::snippets::utils::broadcast_merge_dim(merged, lhs, rhs));
    ASSERT_EQ(merged, expected);
}
|
||
namespace BrodcastMergeDimInstantiation { | ||
|
||
constexpr size_t dynamic = ov::snippets::utils::get_dynamic_value<size_t>(); | ||
|
||
const std::vector<BroadcastMergeDimParams> dimension_cases = { | ||
{10, 10, 10}, | ||
{10, 1, 10}, | ||
{1, 10, 10}, | ||
{dynamic, 10, 10}, | ||
{10, dynamic, 10}, | ||
{dynamic, dynamic, dynamic}, | ||
{dynamic, 1, dynamic}, | ||
{1, dynamic, dynamic}, | ||
}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BrodcastMergeDim, BroadcastMergeDimTest, | ||
::testing::ValuesIn(dimension_cases), | ||
BroadcastMergeDimTest::getTestCaseName); | ||
|
||
} // namespace BrodcastMergeDimInstantiation | ||
} // namespace snippets | ||
} // namespace test | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.