[Snippets][CPU] Enabled dynamic MHA FP32 tokenization on x64 (#25500)
### Details:
- *The PR enables dynamic FP32 MHA tokenization on x64 platforms 🎉*
- *`std::vector::resize()`, which was used for buffer scratchpad
allocation, is a very expensive operation because it default-constructs every
element. This PR replaces `std::vector::resize()` with CPU Node Scratchpad
memory, which can be shared between nodes. Also, since each thread must have
its own scratchpad memory, we used to allocate `size * threads_max` - however,
the thread count during execution can be lower (it depends on the amount of
parallel work). Now we allocate only `size * n_threads`, where `n_threads` is
the real count of working threads (see the sketch after this list).*
- *Fixed the K dimension validation in the `BrgemmBlocking` pass: one of the
inputs can have a dynamic value for this dimension*
- *Fixed `utils::broadcast_merge_dim()` and supported merging of dynamic
dimension values in IterHandlers. Added unit tests for
`utils::broadcast_merge_dim()`*
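
A minimal sketch of the per-thread scratchpad layout described above; all names (`run_kernel`, `scratch_base`, `size_per_thread`) are illustrative rather than actual OpenVINO symbols, and `ov::parallel_nt` is used only to show the per-thread slicing idea:

```cpp
#include "openvino/core/parallel.hpp"  // ov::parallel_nt

// Instead of std::vector<float>::resize(total), which default-constructs
// (zero-fills) every element on each reallocation, the kernel receives raw
// scratchpad memory shared via the CPU node context and sized for the real
// number of working threads.
void run_kernel(uint8_t* scratch_base, size_t size_per_thread, int n_threads) {
    ov::parallel_nt(n_threads, [&](const int ithr, const int nthr) {
        // Each thread owns a disjoint slice: no zero-initialization, no sharing.
        uint8_t* local = scratch_base + static_cast<size_t>(ithr) * size_per_thread;
        // ... the Subgraph kernel writes its temporaries into `local` ...
        (void)local;
        (void)nthr;
    });
}
```
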

### Tickets:
 - *149900*


### Prerequisites:
- [x] #25326
- [x] #25378
- [x] #25623
- [x] #25638
- [x] #25745
- [x] #25957
- [x] #25733
a-sidorova authored Aug 21, 2024
1 parent fdde9f1 commit 54f58b8
Showing 14 changed files with 201 additions and 82 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/include/snippets/lowered/pass/pass.hpp
@@ -20,7 +20,7 @@ namespace pass {
  * @brief Base class for transformations on linear IR
  * @ingroup snippets
  */
-class PassBase {
+class PassBase : public std::enable_shared_from_this<PassBase> {
 public:
     PassBase() = default;
     virtual ~PassBase() = default;
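
This change lets every lowered pass hand out an owning pointer to itself, which the reworked `merge()` implementations later in this diff rely on (`return shared_from_this();`). A minimal, self-contained illustration of the mechanism in plain C++, independent of the OpenVINO classes:

```cpp
#include <memory>

struct PassBase : std::enable_shared_from_this<PassBase> {
    virtual ~PassBase() = default;
};

struct MyPass : PassBase {};

int main() {
    auto pass = std::make_shared<MyPass>();
    // shared_from_this() aliases the existing control block: same object,
    // bumped reference count, no copy construction.
    std::shared_ptr<PassBase> same = pass->shared_from_this();
    return same == pass ? 0 : 1;  // 0: both pointers own the same pass
}
```
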
5 changes: 5 additions & 0 deletions src/common/snippets/include/snippets/utils/utils.hpp
@@ -124,6 +124,11 @@ std::string vector2str(const std::vector<T>& values) {
 
 bool broadcast_merge_dim(size_t& dst, const size_t& d1, const size_t& d2);
 
+// If one of the dims is dynamic, return the other dim (which might also be dynamic).
+// If both dims are static, they must be equal - this is the difference from the utility above.
+// Can be used in SpecificLoopIterationHandlers.
+bool merge_dynamic_dim(size_t& dst, const size_t& d1, const size_t& d2);
+
 VectorDims pshape_to_vdims(const PartialShape&);
 ov::PartialShape vdims_to_pshape(const VectorDims&);

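
A usage sketch (not from the PR) contrasting the two utilities; the expected results follow the header comments and the unit tests added later in this diff:

```cpp
#include "snippets/utils/utils.hpp"

using ov::snippets::utils::broadcast_merge_dim;
using ov::snippets::utils::merge_dynamic_dim;
using ov::snippets::utils::get_dynamic_value;

void example() {
    const size_t dyn = get_dynamic_value<size_t>();
    size_t dst = 0;
    // broadcast_merge_dim follows broadcasting rules: 1 broadcasts to the other dim.
    broadcast_merge_dim(dst, 1, 10);    // true, dst == 10
    broadcast_merge_dim(dst, dyn, 10);  // true, dst == 10 (dynamic refined by the static dim)
    // merge_dynamic_dim requires static dims to be strictly equal.
    merge_dynamic_dim(dst, dyn, 10);    // true, dst == 10
    merge_dynamic_dim(dst, 1, 10);      // false: 1 != 10 and neither is dynamic
}
```
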
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/expression.cpp
@@ -175,11 +175,11 @@ ExpressionPtr Expression::clone_with_new_inputs(const ExpressionMap& expr_map,
 }
 
 ExpressionPort Expression::get_input_port(size_t i) {
-    return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Input, i);
+    return ExpressionPort(shared_from_this(), ExpressionPort::Type::Input, i);
 }
 
 ExpressionPort Expression::get_output_port(size_t i) {
-    return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Output, i);
+    return ExpressionPort(shared_from_this(), ExpressionPort::Type::Output, i);
 }
 
 std::vector<ExpressionPort> Expression::get_input_ports() {
6 changes: 4 additions & 2 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
@@ -102,8 +102,10 @@ std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const
 
     const auto& m = *++out_preordered_dims.rbegin();
     const auto& n = *out_preordered_dims.rbegin();
-    const auto& k = *in_0_planar_dims.rbegin();
-    OPENVINO_ASSERT(k == *++in_1_planar_dims.rbegin(), "Brgemm input descriptors have different K dimension value.");
+    const auto& k0 = *in_0_planar_dims.rbegin();
+    const auto& k1 = *++in_1_planar_dims.rbegin();
+    size_t k = 0;
+    OPENVINO_ASSERT(utils::merge_dynamic_dim(k, k0, k1), "Brgemm input descriptors have incompatible K dimension value.");
 
     // Ticket: 113745
     // TODO: extend block size selection heuristics
20 changes: 10 additions & 10 deletions src/common/snippets/src/lowered/pass/iter_handler.cpp
@@ -49,13 +49,13 @@ bool UpdateMemoryAccessCounts::run(LinearIR& linear_ir, LinearIR::constExprIt be
 }
 
 std::shared_ptr<pass::PassBase> UpdateMemoryAccessCounts::merge(const std::shared_ptr<pass::PassBase>& other) {
-    const auto merged_pass = std::make_shared<UpdateMemoryAccessCounts>(m_count);
-    if (other == nullptr)
-        return merged_pass;
+    if (!other)
+        return shared_from_this();
     const auto casted_pass = ov::as_type_ptr<UpdateMemoryAccessCounts>(other);
-    if (!casted_pass || m_count != casted_pass->m_count)
+    size_t merged_count;
+    if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_count, m_count, casted_pass->m_count))
         return nullptr;
-    return merged_pass;
+    return std::make_shared<UpdateMemoryAccessCounts>(merged_count);
 }
 
 SetFillOffset::SetFillOffset(size_t offset) : RangedPass(), m_offset(offset) {}
@@ -71,13 +71,13 @@ bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear
 }
 
 std::shared_ptr<pass::PassBase> SetFillOffset::merge(const std::shared_ptr<pass::PassBase>& other) {
-    const auto merged_pass = std::make_shared<SetFillOffset>(m_offset);
-    if (other == nullptr)
-        return merged_pass;
+    if (!other)
+        return shared_from_this();
     const auto casted_pass = ov::as_type_ptr<SetFillOffset>(other);
-    if (!casted_pass || m_offset != casted_pass->m_offset)
+    size_t merged_offset;
+    if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_offset, m_offset, casted_pass->m_offset))
         return nullptr;
-    return merged_pass;
+    return std::make_shared<SetFillOffset>(merged_offset);
 }
 
 bool SetLoopIncrementOne::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
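
With these changes two handlers merge whenever their parameters are compatible under `merge_dynamic_dim`, rather than only when strictly equal. An illustrative sketch (the offsets are made up, and the include path and namespace are assumptions):

```cpp
#include <memory>

#include "snippets/lowered/pass/iter_handler.hpp"
#include "snippets/utils/utils.hpp"

using namespace ov::snippets::lowered::pass;

void merge_example() {
    auto static_handler  = std::make_shared<SetFillOffset>(64);
    auto dynamic_handler = std::make_shared<SetFillOffset>(ov::snippets::utils::get_dynamic_value<size_t>());

    auto merged = static_handler->merge(dynamic_handler);  // new SetFillOffset with offset 64
    auto same   = static_handler->merge(nullptr);          // shared_from_this(): the handler itself, no copy
    auto failed = static_handler->merge(std::make_shared<SetFillOffset>(32));  // nullptr: 64 vs 32 cannot merge
}
```
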
10 changes: 5 additions & 5 deletions src/common/snippets/src/lowered/pass/propagate_subtensors.cpp
@@ -175,13 +175,13 @@ bool UpdateSubtensors::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Lin
 }
 
 std::shared_ptr<pass::PassBase> UpdateSubtensors::merge(const std::shared_ptr<pass::PassBase>& other) {
-    const auto merged_pass = std::make_shared<UpdateSubtensors>(m_tail_size);
-    if (other == nullptr)
-        return merged_pass;
+    if (!other)
+        return shared_from_this();
     const auto casted_pass = ov::as_type_ptr<UpdateSubtensors>(other);
-    if (!casted_pass || m_tail_size != casted_pass->m_tail_size)
+    size_t merged_size;
+    if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_size, m_tail_size, casted_pass->m_tail_size))
         return nullptr;
-    return merged_pass;
+    return std::make_shared<UpdateSubtensors>(merged_size);
 }
 
 }  // namespace pass
20 changes: 18 additions & 2 deletions src/common/snippets/src/pass/mha_tokenization.cpp
@@ -472,9 +472,25 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
 
     // TODO [75567]: move this plugin-specific constraint to the plugin callback
     const auto last_node = ordered_ops.back();
-    if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + uniqie_buffer_reg_group_count > 11) {
+    const auto io_count = potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count;
+    const auto data_count = io_count + uniqie_buffer_reg_group_count;
+    auto available_regs = config.get_data_ptr_gpr_count();
+    // [150148, 150149] Snippets currently have no mechanism for spilling registers to the stack,
+    // so we have to skip tokenization of subgraphs that need more registers than the target machine provides.
+    // `config.get_data_ptr_gpr_count()` returns the count of available data registers
+    // (including parameters, results and buffers) after excluding 2 registers for work amounts.
+    // However, an MHA Subgraph has the `SplitLoops` optimization, which adds an outermost Loop blocked by M.
+    // This Loop also requires its own separate register for its `work_amount`,
+    // so we have to decrement the `available_regs` count in the MHA case.
+    // In general there are enough available registers, but in rare cases
+    // (when there are a lot of parameters/results; the heuristic threshold for their count is `5`)
+    // they may not suffice and the subgraph must not be tokenized.
+    // Only in these rare cases do we decrement `available_regs`.
+    if (io_count > 5)
+        available_regs--;
+
+    if (data_count > available_regs)
         return false;
-    }
 
     // If backend doesn't enable dynamic MHA tokenization, return false
     if (!config.is_dynamic_mha_token_enabled()) {
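
A worked example of the new register-budget check with made-up counts (the threshold `5` and the 2 reserved work-amount registers come from the comments above):

```cpp
#include <cstddef>

// Hypothetical subgraph: 6 body parameters, 1 result, 1 hidden virtual port,
// 2 buffer register groups, and 12 data-pointer GPRs reported by the config.
bool would_tokenize() {
    std::size_t io_count = 6 + 1 + 1;       // 8 > 5: many parameters/results
    std::size_t data_count = io_count + 2;  // 10 registers needed in total
    std::size_t available_regs = 12;
    if (io_count > 5)
        available_regs--;                   // reserve one GPR for the SplitLoops work amount
    return data_count <= available_regs;    // 10 <= 11 -> tokenization proceeds
}
```
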
15 changes: 13 additions & 2 deletions src/common/snippets/src/utils/utils.cpp
@@ -103,10 +103,21 @@ auto get_non_scalar_constant_count_for_fq(const std::shared_ptr<ov::op::v0::Fake
 }
 
 bool broadcast_merge_dim(size_t& dst, const size_t& d1, const size_t& d2) {
-    if (d1 == d2 || d1 == 1 || is_dynamic_value(d2)) {
+    if (d1 == d2 || d1 == 1 || (is_dynamic_value(d1) && d2 != 1)) {
         dst = d2;
         return true;
-    } else if (d2 == 1 || is_dynamic_value(d1)) {
+    } else if (d2 == 1 || is_dynamic_value(d2)) {
         dst = d1;
         return true;
     }
     return false;
 }
 
+bool merge_dynamic_dim(size_t& dst, const size_t& d1, const size_t& d2) {
+    if (d1 == d2 || is_dynamic_value(d1)) {
+        dst = d2;
+        return true;
+    } else if (is_dynamic_value(d2)) {
+        dst = d1;
+        return true;
+    }
28 changes: 28 additions & 0 deletions src/common/snippets/tests/include/utils/broadcast_dim_merge.hpp
@@ -0,0 +1,28 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <common_test_utils/ov_test_utils.hpp>

namespace ov {
namespace test {
namespace snippets {

// D1, D2, Result
using BroadcastMergeDimParams = std::tuple<size_t, size_t, size_t>;

class BroadcastMergeDimTest : public testing::TestWithParam<BroadcastMergeDimParams> {
public:
    static std::string getTestCaseName(testing::TestParamInfo<BroadcastMergeDimParams> obj);

protected:
    void SetUp() override;
    BroadcastMergeDimParams m_dims = {};
};

}  // namespace snippets
}  // namespace test
}  // namespace ov
56 changes: 56 additions & 0 deletions src/common/snippets/tests/src/utils/broadcast_merge_dim.cpp
@@ -0,0 +1,56 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "utils/broadcast_dim_merge.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "snippets/utils/utils.hpp"

namespace ov {
namespace test {
namespace snippets {

std::string BroadcastMergeDimTest::getTestCaseName(testing::TestParamInfo<BroadcastMergeDimParams> obj) {
    BroadcastMergeDimParams params = obj.param;
    std::ostringstream result;
    result << "D0=" << ov::snippets::utils::value2str(std::get<0>(params)) << "_";
    result << "D1=" << ov::snippets::utils::value2str(std::get<1>(params)) << "_";
    result << "DST=" << ov::snippets::utils::value2str(std::get<2>(params));
    return result.str();
}

void BroadcastMergeDimTest::SetUp() {
    m_dims = this->GetParam();
}

TEST_P(BroadcastMergeDimTest, BroadcastMergeDim) {
    size_t d1, d2, dst, result;
    std::tie(d1, d2, dst) = this->m_dims;
    ASSERT_TRUE(ov::snippets::utils::broadcast_merge_dim(result, d1, d2));
    ASSERT_EQ(result, dst);
}

namespace BroadcastMergeDimInstantiation {

constexpr size_t dynamic = ov::snippets::utils::get_dynamic_value<size_t>();

const std::vector<BroadcastMergeDimParams> dimension_cases = {
    {10, 10, 10},
    {10, 1, 10},
    {1, 10, 10},
    {dynamic, 10, 10},
    {10, dynamic, 10},
    {dynamic, dynamic, dynamic},
    {dynamic, 1, dynamic},
    {1, dynamic, dynamic},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BroadcastMergeDim, BroadcastMergeDimTest,
                         ::testing::ValuesIn(dimension_cases),
                         BroadcastMergeDimTest::getTestCaseName);

}  // namespace BroadcastMergeDimInstantiation
}  // namespace snippets
}  // namespace test
}  // namespace ov
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/node.h
@@ -773,7 +773,7 @@ class Node {
                       NameFromType(getType()));
     }
 
-    MemoryPtr getScratchPadMem(const DnnlMemoryDescPtr& desc) {
+    MemoryPtr getScratchPadMem(const MemoryDescPtr& desc) {
         if (!scratchpadMem || !scratchpadMem->getDesc().isCompatible(*desc)) {
             scratchpadMem = context->getScratchPad(curNumaNode)->createScratchPadMem(desc);
         }
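
With the parameter relaxed from `DnnlMemoryDescPtr` to the base `MemoryDescPtr`, a node can request scratchpad memory with any descriptor type, not only a DNNL-backed one. A hedged sketch of a call site inside a hypothetical Node subclass; the `CpuBlockedMemoryDesc` construction is an assumption for illustration, not taken from this PR:

```cpp
// Request one shared scratchpad region sized for the actual number of
// working threads (see the PR description).
MemoryPtr request_scratchpad(size_t size_per_thread, size_t n_threads) {
    // Any MemoryDescPtr now works; a plain blocked u8 descriptor is assumed here.
    auto desc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8,
                                                       Shape(VectorDims{size_per_thread * n_threads}));
    return getScratchPadMem(desc);  // cached and reused while the descriptor stays compatible
}
```
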