Skip to content

Commit

Permalink
[Snippets][CPU] Added Brgemm FP32 blocking support by dynamic K, N di…
Browse files Browse the repository at this point in the history
…mensions (#25745)

### Details:
- *Added support for updating the `K` and `N` dimensions of the Brgemm block in
`BrgemmKernelExecutor::update_config`*

### Tickets:
 - *147852*

### Prerequisites:
- [x] #25378
  • Loading branch information
a-sidorova authored Aug 2, 2024
1 parent b625fcb commit b2319a5
Show file tree
Hide file tree
Showing 20 changed files with 259 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,30 @@ class BrgemmBlockingBase {
static bool blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
const ov::snippets::lowered::ExpressionPtr& brgemm_expr);

static void mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_m);
void mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_m);

static void mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_n);
void mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_n);

static void mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_k);
void mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
snippets::lowered::LinearIR::constExprIt loop_begin,
snippets::lowered::LinearIR::constExprIt loop_end,
const std::vector<snippets::lowered::LoopPort>& entries,
const std::vector<snippets::lowered::LoopPort>& exits,
size_t block_size_k);

virtual SpecificIterationHandlers get_m_loop_handlers(size_t work_amount, size_t block_size) const;
virtual SpecificIterationHandlers get_n_loop_handlers(size_t work_amount, size_t block_size) const;
virtual SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const;
};

/**
Expand Down
18 changes: 0 additions & 18 deletions src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,6 @@ class SetEvaluateOnce : public snippets::lowered::pass::RangedPass {
std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;
};

/**
* @interface SetBrgemmBeta
* @brief The pass updates all Brgemm nodes in the processed range with a new beta value
* @param m_beta - beta which must be set
* @ingroup snippets
*/
class SetBrgemmBeta : public snippets::lowered::pass::RangedPass {
public:
// Stores `beta` as the value that run() will assign.
SetBrgemmBeta(float beta);
OPENVINO_RTTI("SetBrgemmBeta", "RangedPass")
// Applies the stored beta to the expressions between `begin` and `end`.
bool run(snippets::lowered::LinearIR& linear_ir,
snippets::lowered::LinearIR::constExprIt begin,
snippets::lowered::LinearIR::constExprIt end) override;
// Combines this pass with `other`; see the definition for when merging is refused.
std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;

private:
// Beta value assigned by run(); defaults to 0 until the constructor overrides it.
float m_beta = 0;
};
} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
4 changes: 0 additions & 4 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

float get_beta() const { return m_beta; }
void set_beta(float beta) { m_beta = beta; }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
Expand All @@ -48,7 +45,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
ov::PartialShape infer_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
float m_beta = 0.f;

private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
Expand Down
37 changes: 20 additions & 17 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void BrgemmBlockingBase::mark_m_blocking(const snippets::lowered::LoopManagerPtr
const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[0].expr_port);
const auto m = *++planar_dims.rbegin();
const auto id = loop_manager->mark_loop(loop_begin, loop_end, m, block_size_m, 1, entries, exits, false);
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(get_default_blocking_loop_handlers(m, block_size_m));
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(get_m_loop_handlers(m, block_size_m));
}

void BrgemmBlockingBase::mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
Expand All @@ -67,7 +67,7 @@ void BrgemmBlockingBase::mark_n_blocking(const snippets::lowered::LoopManagerPtr
const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[1].expr_port);
const auto n = *planar_dims.rbegin();
const auto id = loop_manager->mark_loop(loop_begin, loop_end, n, block_size_n, 0, entries, exits, false);
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(get_default_blocking_loop_handlers(n, block_size_n));
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(get_n_loop_handlers(n, block_size_n));
}

void BrgemmBlockingBase::mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
Expand All @@ -79,10 +79,17 @@ void BrgemmBlockingBase::mark_k_blocking(const snippets::lowered::LoopManagerPtr
const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[0].expr_port);
const auto k = *planar_dims.rbegin();
const auto id = loop_manager->mark_loop(loop_begin, loop_end, k, block_size_k, entries, exits, false);
auto handlers = get_default_blocking_loop_handlers(k, block_size_k);
handlers.register_pass<ov::snippets::lowered::SpecificLoopIterType::FIRST_ITER,
ov::snippets::lowered::pass::SetBrgemmBeta>(0.f);
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(handlers);
loop_manager->get_loop_info<UnifiedLoopInfo>(id)->set_handlers(get_k_loop_handlers(k, block_size_k));
}

// Returns the specific-iteration handlers that mark_m_blocking attaches to the M blocking loop.
// This base implementation simply forwards `work_amount` and `block_size` to
// get_default_blocking_loop_handlers. The method is declared virtual in the class,
// so a derived class can override it to supply different handlers.
SpecificIterationHandlers BrgemmBlockingBase::get_m_loop_handlers(size_t work_amount, size_t block_size) const {
return get_default_blocking_loop_handlers(work_amount, block_size);
}
// Returns the specific-iteration handlers that mark_n_blocking attaches to the N blocking loop.
// This base implementation simply forwards `work_amount` and `block_size` to
// get_default_blocking_loop_handlers. The method is declared virtual in the class,
// so a derived class can override it to supply different handlers.
SpecificIterationHandlers BrgemmBlockingBase::get_n_loop_handlers(size_t work_amount, size_t block_size) const {
return get_default_blocking_loop_handlers(work_amount, block_size);
}
// Returns the specific-iteration handlers that mark_k_blocking attaches to the K blocking loop.
// This base implementation simply forwards `work_amount` and `block_size` to
// get_default_blocking_loop_handlers. The method is declared virtual in the class,
// so a derived class can override it to supply different handlers.
SpecificIterationHandlers BrgemmBlockingBase::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
return get_default_blocking_loop_handlers(work_amount, block_size);
}

std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
Expand All @@ -102,21 +109,19 @@ std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const
// Ticket: 113745
// TODO: extend block size selection heuristics
auto get_block_size_m = [](const size_t M) -> size_t {
if (snippets::utils::is_dynamic_value(M))
return 32;
return M <= 32 ? get_full_dim_value() : 32;
if (!snippets::utils::is_dynamic_value(M) && M <= 32)
return get_full_dim_value();
return 32;
};
auto get_block_size_n = [](const size_t N) -> size_t {
// N blocking is disabled in dynamism by default
if (ov::snippets::utils::is_dynamic_value(N) || N <= 64)
if (!snippets::utils::is_dynamic_value(N) && N <= 64)
return get_full_dim_value();
return 64;
};
auto get_block_size_k = [](const size_t K) -> size_t {
// K blocking is disabled in dynamism by default
if (ov::snippets::utils::is_dynamic_value(K) || K <= 512)
return get_full_dim_value();
return K > 1024 ? 1024 : 512;
if (ov::snippets::utils::is_dynamic_value(K))
return 512;
return K > 1024 ? 1024 : K > 512 ? 512 : get_full_dim_value();
};
return std::make_tuple(get_block_size_m(m), get_block_size_n(n), get_block_size_k(k));
}
Expand All @@ -137,8 +142,6 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block);
} else {
ov::as_type_ptr<ov::snippets::op::Brgemm>(brgemm_expr->get_node())->set_beta(0.f);
}
if (!ov::snippets::utils::is_full_dim_value(n_block)) {
const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
Expand Down
22 changes: 0 additions & 22 deletions src/common/snippets/src/lowered/pass/iter_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,28 +154,6 @@ std::shared_ptr<snippets::lowered::pass::PassBase> SetEvaluateOnce::merge(const
return !other || ov::is_type<SetEvaluateOnce>(other) ? std::make_shared<SetEvaluateOnce>() : nullptr;
}

// Default-constructs the RangedPass base and stores `beta` in m_beta for later use by run().
SetBrgemmBeta::SetBrgemmBeta(float beta) : snippets::lowered::pass::RangedPass(), m_beta(beta) {}

// Walks the expressions in [begin, end) and calls set_beta(m_beta) on every node
// that casts to ov::snippets::op::Brgemm. `linear_ir` is not used by this implementation.
// Always returns true, whether or not any Brgemm node was found.
bool SetBrgemmBeta::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
for (auto expr_it = begin; expr_it != end; ++expr_it) {
const auto& expr = expr_it->get();
// The cast result is tested in the condition, so expressions whose node is not a Brgemm are skipped.
if (const auto brgemm = ov::as_type_ptr<ov::snippets::op::Brgemm>(expr->get_node())) {
brgemm->set_beta(m_beta);
}
}
return true;
}

// Merges this pass with `other`.
// Returns a new SetBrgemmBeta carrying this pass's m_beta when `other` is null,
// or when `other` is a SetBrgemmBeta holding the same beta.
// Returns nullptr when `other` is not a SetBrgemmBeta or holds a different beta.
std::shared_ptr<snippets::lowered::pass::PassBase> SetBrgemmBeta::merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) {
// The result is built up front; it is returned only if the checks below pass.
const auto merged_pass = std::make_shared<SetBrgemmBeta>(m_beta);
if (other == nullptr)
return merged_pass;
const auto casted_pass = ov::as_type_ptr<SetBrgemmBeta>(other);
// Betas are compared with exact floating-point inequality.
if (!casted_pass || m_beta != casted_pass->m_beta)
return nullptr;
return merged_pass;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
Expand Down
1 change: 0 additions & 1 deletion src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
}

bool Brgemm::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("beta", m_beta);
return MemoryAccess::visit_attributes(visitor);
}

Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRPtr& linear_ir
// If the specific iteration is not needed, we skip loop evaluation - set zero as work amount is enough
if (!lowered::pass::InsertSpecificIterations::is_decomposed_loop_needed(current_unified_loop_info, decomposed_loop_type, current_work_amount)) {
expanded_loop_info->set_work_amount(0);
if (expanded_loop_info->is_evaluate_once())
expanded_loop_info->set_increment(0);
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc,
brgemm_node->get_beta(), with_amx(brgemm_type),
with_compensations(brgemm_type),
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_amx(brgemm_type), with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, with_amx(brgemm_type)));
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr,
compiled_kernel_cache,
Expand Down
Loading

0 comments on commit b2319a5

Please sign in to comment.