Skip to content

Commit

Permalink
Softmax decomposition moved to data flow pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 20, 2023
1 parent 2c08ec6 commit 257cb9d
Show file tree
Hide file tree
Showing 13 changed files with 417 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface ReduceDecomposition
* @brief Decomposes snippets::Reduce operations to a range of low-level operations on linear IR
* @attention Only Reduce by last dimension is supported
* @ingroup snippets
*/
class ReduceDecomposition : public Pass {
public:
    OPENVINO_RTTI("ReduceDecomposition", "Pass")
    explicit ReduceDecomposition(size_t vector_size);
    bool run(LinearIR& linear_ir) override;

private:
    // Number of elements processed per vector iteration; used as the increment
    // of the reduction loop created by this pass.
    size_t m_vector_size;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
59 changes: 59 additions & 0 deletions src/common/snippets/include/snippets/op/reduce.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_infer_instances.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface ReduceBase
* @brief Base class for reduce operations.
* @arg m_axis reduce axis.
* @ingroup snippets
*/
class ReduceBase : public ov::op::Op {
public:
    OPENVINO_OP("ReduceBase", "SnippetsOpset");

    // x: data input to reduce; axis: index of the reduced dimension
    // (kept in the output shape with extent 1 — see validate_and_infer_types).
    ReduceBase(const Output<Node>& x, size_t axis);
    ReduceBase() = default;

    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;
    size_t get_axis() const { return m_axis; }

protected:
    // Reduced axis; serialized via the "axis" attribute.
    size_t m_axis;
};

// Sum reduction along a single axis. Decomposed to low-level ops on linear IR
// (accumulation with Add, finalized by HorizonSum) by ReduceDecomposition.
class ReduceSum : public ReduceBase {
public:
    OPENVINO_OP("ReduceSum", "SnippetsOpset", ReduceBase);
    ReduceSum(const Output<Node>& x, size_t axis) : ReduceBase(x, axis) {}
    ReduceSum() = default;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    // Only f32 is supported for the single data input.
    static std::set<ov::element::TypeVector> get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
        return {{ov::element::f32}};
    }
};

// Max reduction along a single axis. Decomposed to low-level ops on linear IR
// (accumulation with Maximum, finalized by HorizonMax) by ReduceDecomposition.
class ReduceMax : public ReduceBase {
public:
    OPENVINO_OP("ReduceMax", "SnippetsOpset", ReduceBase);
    ReduceMax(const Output<Node>& x, size_t axis) : ReduceBase(x, axis) {}
    ReduceMax() = default;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    // Only f32 is supported for the single data input.
    static std::set<ov::element::TypeVector> get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
        return {{ov::element::f32}};
    }
};

} // namespace op
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface SoftmaxDecomposition
* @brief Decomposes Softmax to a range of low-level operations
* @ingroup snippets
*/
class SoftmaxDecomposition: public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SoftmaxDecomposition", "0");
    // Registers the Softmax matcher; the rewrite itself is defined in the .cpp.
    SoftmaxDecomposition();
};

} // namespace pass
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,12 @@ class BrgemmShapeInfer : public IShapeInferSnippets {
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

// Shape inference for snippets Reduce ops: the reduced axis is kept with extent 1.
class ReduceShapeInfer : public IShapeInferSnippets {
    // Reduced axis, captured from the Reduce node at construction.
    size_t m_axis;
public:
    explicit ReduceShapeInfer(const std::shared_ptr<Node>& n);
    Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};

} // namespace snippets
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "op/vector_buffer.hpp"
#include "op/rank_normalization.hpp"
#include "op/perf_count.hpp"
#include "op/reduce.hpp"

namespace ov {
namespace snippets {
Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/snippets_isa_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ OV_OP(BroadcastMove, ov::snippets::op)
OV_OP(Scalar, ov::snippets::op)
OV_OP(Nop, ov::snippets::op)
OV_OP(RankNormalization, ov::snippets::op)
OV_OP(ReduceMax, ov::snippets::op)
OV_OP(ReduceSum, ov::snippets::op)

#ifdef SNIPPETS_DEBUG_CAPS
OV_OP(PerfCountBegin, ov::snippets::op)
Expand Down
135 changes: 135 additions & 0 deletions src/common/snippets/src/lowered/pass/reduce_decomposition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/lowered/pass/reduce_decomposition.hpp"

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/lowered/pass/mark_loops.hpp"
#include "snippets/lowered/pass/iter_handler.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"

#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/matcher.hpp"


namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

namespace {
// Returns the identity element of the reduction as the byte representation of
// an f32 value: -FLT_MAX (0xff7fffff) for ReduceMax, 0.0f for ReduceSum.
// Asserts on any other type.
uint32_t get_initial_value(const ov::DiscreteTypeInfo& type_info) {
    static const std::map<ov::DiscreteTypeInfo, uint32_t> reduce_initial_values {
        {op::ReduceMax::get_type_info_static(), uint32_t(0xff7fffff)},
        {op::ReduceSum::get_type_info_static(), uint32_t(0x00000000)},
    };
    // Single lookup instead of count() + at() (which traverses the map twice).
    const auto it = reduce_initial_values.find(type_info);
    OPENVINO_ASSERT(it != reduce_initial_values.end(), "Unexpected ReduceType");
    return it->second;
}

// Builds the elementwise op that accumulates partial reduction results:
// Maximum for ReduceMax, Add for ReduceSum. Throws on any other type.
std::shared_ptr<ov::Node> get_accumulation_node(const ov::Output<ov::Node>& input0,
                                                const ov::Output<ov::Node>& input1,
                                                const ov::DiscreteTypeInfo& type_info) {
    if (type_info == op::ReduceSum::get_type_info_static())
        return std::make_shared<ov::op::v1::Add>(input0, input1);
    if (type_info == op::ReduceMax::get_type_info_static())
        return std::make_shared<ov::op::v1::Maximum>(input0, input1);
    OPENVINO_THROW("Unsupported reduce type: ", type_info);
}

// Builds the horizontal op that collapses the vector accumulator to the final
// result: HorizonMax for ReduceMax, HorizonSum for ReduceSum. Throws otherwise.
std::shared_ptr<ov::Node> get_horizon_node(const ov::Output<ov::Node>& input, const ov::DiscreteTypeInfo& type_info) {
    if (type_info == op::ReduceSum::get_type_info_static())
        return std::make_shared<op::HorizonSum>(input);
    if (type_info == op::ReduceMax::get_type_info_static())
        return std::make_shared<op::HorizonMax>(input);
    OPENVINO_THROW("Unsupported reduce type: ", type_info);
}
} // namespace

using LoopInfo = LinearIR::LoopManager::LoopInfo;

ReduceDecomposition::ReduceDecomposition(size_t vector_size) : m_vector_size{vector_size} {}

// Replaces every snippets::op::Reduce* expression in the linear IR with an explicit
// reduction sequence: VectorBuffer + Fill (accumulator initialized with the neutral
// value), an inner loop of Fill + accumulation (Add/Maximum), and a final Horizon op
// that collapses the vector accumulator to the result.
// Returns true if at least one Reduce expression was decomposed.
bool ReduceDecomposition::run(LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ReduceMaxDecompositionLowered")
    const auto& loop_manager = linear_ir.get_loop_manager();
    bool modified = false;
    for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
        const auto& reduce_expr = *expr_it;
        const auto& reduce = ov::as_type_ptr<ov::snippets::op::ReduceBase>(reduce_expr->get_node());
        if (!reduce)
            continue;

        const auto& reduce_type_info = reduce->get_type_info();
        const auto& input_shape = reduce_expr->get_input_port_descriptor(0)->get_shape();
        // The reduced (last) dimension defines the loop work amount
        const auto work_amount = *(input_shape.rbegin());
        // Clamp the increment so short rows do not over-read
        const auto increment = m_vector_size <= work_amount ? m_vector_size : work_amount;
        const bool is_dynamic = reduce->is_dynamic();
        OPENVINO_ASSERT(reduce->get_axis() == input_shape.size() - 1, "ReduceDecomposition supports only Reduce by last dimension.");

        // We need an iterator to the inserted element
        auto push_node = [&](const std::shared_ptr<Node>& n) {
            const auto expr = linear_ir.insert(expr_it, n);
            if (is_dynamic)
                expr->get()->updateShapes();
            return std::make_pair(expr, n);
        };
        // Float constant values in byte representation
        const auto fill_value = get_initial_value(reduce_type_info);
        // Note: VectorBuffer is a special case, since it should go before the initial Load.
        // The buffer must be initialized with fill_value before reduction
        const auto vector_buffer = push_node(std::make_shared<op::VectorBuffer>());
        const auto initial_fill = push_node(std::make_shared<op::Fill>(vector_buffer.second, 0, fill_value));

        // Reduce loop
        const auto fill = push_node(std::make_shared<op::Fill>(reduce->get_input_source_output(0), increment, fill_value));
        const auto accumulation = push_node(get_accumulation_node(fill.second, initial_fill.second, reduce_type_info));

        // Loop body spans [fill, accumulation]; its input ports are the loop's entry
        // points, the accumulation output is the exit point
        const auto reduce_loop_id = loop_manager->mark_loop(
            fill.first,
            expr_it,
            work_amount,
            increment,
            0,
            std::vector<ExpressionPort>{(*fill.first)->get_input_port(0), (*accumulation.first)->get_input_port(1)},
            std::vector<ExpressionPort>{(*accumulation.first)->get_output_port(0)});
        const auto loop_info = loop_manager->get_loop_info(reduce_loop_id);
        const auto tail_size = work_amount % increment;
        if (tail_size != 0) {
            // The last iteration processes fewer than `increment` elements, so the Fill
            // offset is shifted to pad the remainder lanes with the neutral fill_value
            loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
        }
        // Horizon op collapses the per-lane partial results into the final value
        const auto horizon = push_node(get_horizon_node(accumulation.second, reduce_type_info));

        // Transfer original ExpressionPorts
        linear_ir.replace_input((*fill.first)->get_input_port(0), reduce_expr->get_input_port_connector(0));
        linear_ir.replace_input(reduce_expr->get_output_port_connector(0)->get_consumers(), (*horizon.first)->get_output_port_connector(0));

        // Update Loop info for outer loops
        const std::vector<ExpressionPort> entry_points{(*fill.first)->get_input_port(0)};
        const std::vector<ExpressionPort> exit_points{(*horizon.first)->get_output_port(0)};
        for (auto loop_id : reduce_expr->get_loop_ids()) {
            loop_manager->expression_replacement(vector_buffer.first,
                                                expr_it,
                                                reduce_expr,
                                                loop_id,
                                                entry_points,
                                                exit_points);
        }

        // The original Reduce expression is fully replaced by the sequence above
        expr_it = linear_ir.erase(expr_it);
        modified = true;
    }
    return modified;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
43 changes: 43 additions & 0 deletions src/common/snippets/src/op/reduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"

#include "snippets/op/reduce.hpp"


namespace ov {
namespace snippets {
namespace op {

ReduceBase::ReduceBase(const Output<Node>& x, size_t axis) : Op({x}), m_axis(axis) {
constructor_validate_and_infer_types();
}

// Exposes the reduction axis as the "axis" attribute (serialization/dumping).
bool ReduceBase::visit_attributes(AttributeVisitor& visitor) {
    visitor.on_attribute("axis", m_axis);
    return true;
}

void ReduceBase::validate_and_infer_types() {
auto result_shape = get_input_partial_shape(0);
result_shape[m_axis] = 1;
set_output_type(0, get_input_element_type(0), result_shape);
}

// Standard ov::Node clone: new data input, same reduction axis.
std::shared_ptr<Node> ReduceSum::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(ReduceSum);
    check_new_args_count(this, new_args);
    return std::make_shared<ReduceSum>(new_args.at(0), m_axis);
}

// Standard ov::Node clone: new data input, same reduction axis.
std::shared_ptr<Node> ReduceMax::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(ReduceMax);
    check_new_args_count(this, new_args);
    return std::make_shared<ReduceMax>(new_args.at(0), m_axis);
}

} // namespace op
} // namespace snippets
} // namespace ov
10 changes: 9 additions & 1 deletion src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/softmax_decomposition.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
Expand Down Expand Up @@ -43,6 +44,7 @@
#include "snippets/lowered/pass/insert_perf_count.hpp"
#include "snippets/lowered/pass/validate_shapes.hpp"
#include "snippets/lowered/pass/pass_config.hpp"
#include "snippets/lowered/pass/reduce_decomposition.hpp"

#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -404,7 +406,11 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
manager.register_pass<snippets::pass::MatMulToBrgemm>();
manager.register_pass<snippets::pass::FuseTransposeBrgemm>();
manager.register_pass<snippets::pass::TransposeDecomposition>();
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
if (getenv("DISABLE_DATA_FLOW_DECOMPOSITION")) {
manager.register_pass<snippets::pass::SetSoftmaxPorts>();
} else {
manager.register_pass<snippets::pass::SoftmaxDecomposition>();
}
}
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
Expand Down Expand Up @@ -435,7 +441,9 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,

lowered::pass::PassPipeline pipeline(lowered_pass_config);
pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
// TODO: remove SoftmaxDecomposition pass
pipeline.register_pass<lowered::pass::SoftmaxDecomposition>(vector_size);
pipeline.register_pass<lowered::pass::ReduceDecomposition>(vector_size);
pipeline.register_pass<lowered::pass::FuseLoops>();
pipeline.register_pass<lowered::pass::SplitLoops>();
pipeline.register_pass<lowered::pass::MoveResultOutOfLoop>();
Expand Down
Loading

0 comments on commit 257cb9d

Please sign in to comment.