Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sns matmul support #61

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/common/snippets/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ ie_faster_build(${TARGET_NAME}
)

target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime
PRIVATE ngraph_reference openvino::runtime::dev)
PRIVATE ngraph_reference ov_shape_inference openvino::runtime::dev)

target_include_directories(${TARGET_NAME} PUBLIC $<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>)
target_include_directories(${TARGET_NAME} PUBLIC $<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>
PRIVATE $<BUILD_INTERFACE:${SHAPE_INFER_INCLUDE_DIR}>)

add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})

Expand Down
14 changes: 13 additions & 1 deletion src/common/snippets/include/snippets/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,21 @@ class Generator {
* @brief Default destructor
*/
virtual ~Generator() = default;
/**
* @interface GeneratorConfig
* @brief Allows to tweak the lowering process.
*/
class GeneratorConfig {
public:
// True if the lowered Emitters need to be accessed during runtime. Normally they're destroyed after code emission.
bool m_save_lowered_code = false;
};
/**
* @brief virtual method any specific implementation should implement
* @param m model in canonical form for table-based code generation
* @return pointer to generated code
*/
code generate(std::shared_ptr<ov::Model>& m, const void* compile_params = nullptr) const;
code generate(std::shared_ptr<ov::Model>& m, const GeneratorConfig& config, const void* compile_params = nullptr);

/**
* @brief gets target machine
Expand All @@ -127,6 +136,9 @@ class Generator {

protected:
std::shared_ptr<TargetMachine> target;
// todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then).
// This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method.
std::vector<AllocatedEmitter> lowered_saved;
};

} // namespace snippets
Expand Down
33 changes: 33 additions & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/op/op.hpp"
#include "ngraph/op/matmul.hpp"

namespace ngraph {
namespace snippets {
namespace op {

/**
* @interface Brgemm
* @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
* @ingroup snippets
*/
class Brgemm : public ngraph::op::v0::MatMul {
public:
    OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul);
    // Constructs Brgemm over inputs A and B; validation/shape inference runs in the ctor.
    Brgemm(const Output<Node>& A, const Output<Node>& B);
    Brgemm() = default;

    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

    // Reference evaluation is deliberately unsupported for this op.
    bool has_evaluate() const override { return false; }
};

} // namespace op
} // namespace snippets
} // namespace ngraph
21 changes: 14 additions & 7 deletions src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class Subgraph : public ngraph::op::Op {
private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
void convert_to_snippet_dialect();
void init_config();
// Count of potential non-scalar Constants that will be created after some transformations
// At the moment it's relevant only for FakeQuantize decomposition
// NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_type_infer()),
Expand All @@ -144,23 +145,29 @@ class Subgraph : public ngraph::op::Op {
// TODO: Change logic of insert Converts. This exec element type can be different for plugins
const ov::element::Type execution_element_type = ov::element::f32;

// Config to know which transformations should be called.
// It helps to avoid overheads of extra transformation calls
struct {
ov::PartialShape master_shape;
size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call

/**
* @interface SubgraphConfig
* @brief Config to optimize IR transformation pipeline. It indicates which transformations are necessary
* so the irrelevant ones could be skipped.
*/
class SubgraphConfig {
public:
// True if Subgraph contains FakeQuantize -> FQ decomposition should be called
bool m_is_quantized = false;
// True if we should align element types inside body
bool m_is_needed_to_align_precision = false;
// True if Subgraph contains TypeRelaxed nodes -> for several streams in tp mode we should copy body using mutexes
// because TypeRelaxed::copy_with_new_inputs() isn't a thread-safe method
bool m_has_type_relaxed_ops = false;
// True if we should check runtime info for nodes to call specific needed transformations
bool m_need_fill_tail_register = false;
// True if body has operations that don't support plugin-side domain optimizations
// (e.g. Transpose in general doesn't support dimensions collapsing)
// (e.g. Transpose, Softmax, MatMul in general don't support dimensions collapsing)
bool m_has_domain_sensitive_ops = false;
} config;

ov::PartialShape master_shape;
size_t tileRank = 0; // set by plugin to specify the number of dimensions processed in a single kernel call
};

static inline std::ostream& operator<<(std::ostream& os, const op::Subgraph::BlockedShape& blocked_shape) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace snippets {
namespace pass {

/**
* @interface FuseTransposeBrgemm
* @brief Fuses Transpose with Brgemm node, fusing on both Brgemm inputs and output is supported. Applicable to
* Transposes that don't change the position of the last dimension (since Brgemm supports strided rows i/o),
* but only 0213 Transpose is currently supported.
* @ingroup snippets
*/
class FuseTransposeBrgemm: public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("FuseTransposeBrgemm", "0");
    // Registers the Transpose+Brgemm matchers and their fusing callbacks.
    FuseTransposeBrgemm();
    // Transpose orders the pass treats as fusible (per the class brief, currently only 0213).
    static const std::set<std::vector<int>> supported_cases;
};

} // namespace pass
} // namespace snippets
} // namespace ngraph
28 changes: 28 additions & 0 deletions src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace snippets {
namespace pass {

/**
* @interface MatMulToBrgemm
* @brief Replaces ngraph::MatMul with snippets::op::Brgemm operation (only non-transposing MatMuls are currently supported)
* @ingroup snippets
*/
class MatMulToBrgemm: public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("MatMulToBrgemm", "0");
    // Registers the MatMul matcher and the callback that replaces it with snippets::op::Brgemm.
    MatMulToBrgemm();
};


} // namespace pass
} // namespace snippets
} // namespace ngraph
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "op/powerstatic.hpp"
#include "op/store.hpp"
#include "op/loop.hpp"
#include "op/brgemm.hpp"

namespace ngraph {
namespace snippets {
Expand Down
4 changes: 4 additions & 0 deletions src/common/snippets/include/snippets/snippets_isa_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

// SnippetS dialect
NGRAPH_OP(Load, ngraph::snippets::op)
NGRAPH_OP(LoadReshape, ngraph::snippets::op)
NGRAPH_OP(LoopBegin, ngraph::snippets::op)
NGRAPH_OP(LoopEnd, ngraph::snippets::op)
NGRAPH_OP(Brgemm, ngraph::snippets::op)
dmitry-gorokhov marked this conversation as resolved.
Show resolved Hide resolved
NGRAPH_OP(BroadcastLoad, ngraph::snippets::op)

NGRAPH_OP(Store, ngraph::snippets::op)
Expand Down
6 changes: 6 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ inline auto is_scalar_constant(const std::shared_ptr<ngraph::Node>& source_outpu
return ngraph::is_type<ngraph::opset1::Constant>(source_output_node) && ngraph::shape_size(source_output_node->get_shape()) == 1;
}


ov::PartialShape get_port_planar_shape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
std::vector<size_t> get_node_output_layout(const std::shared_ptr<Node>& node);
std::vector<size_t> get_node_output_layout(const Node* node);

} // namespace utils
} // namespace snippets
} // namespace ngraph
10 changes: 9 additions & 1 deletion src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ auto getRegisters(const std::shared_ptr<ngraph::Node> &n) -> RegInfo {
if (it_rt != rt.end())
rin.push_back(it_rt->second.as<size_t>());
}

return std::make_pair(rin, rout);
}

ngraph::snippets::code ngraph::snippets::Generator::generate(std::shared_ptr<ov::Model>& m,
const void* compile_params) const {
const GeneratorConfig& config,
const void* compile_params) {
OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::Generator::generate")
if (!target->is_supported())
throw ngraph_error("unsupported architecture for code generation");
Expand Down Expand Up @@ -157,6 +159,12 @@ ngraph::snippets::code ngraph::snippets::Generator::generate(std::shared_ptr<ov:
op.first->emit_data();
}
OV_ITT_TASK_NEXT(GENERATE, "::GetSnippet")

// todo: we save lowered to access compiled brgemm kernels on execution time (normally lowered is destructed by then)
// remove this when kernel caching is implemented. Don't forget to make generate const method.
if (config.m_save_lowered_code)
lowered_saved = lowered;

return target->get_snippet();
}

Expand Down
55 changes: 55 additions & 0 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"
#include "snippets/op/brgemm.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/core/rt_info.hpp"
#include "snippets/utils.hpp"
#include "matmul_shape_inference.hpp"

namespace ngraph {
namespace snippets {
namespace op {

// Constructs a Brgemm over inputs A and B.
// Note: inputs are wired up manually (default-constructed MatMul base) and validation is
// triggered explicitly last, so the virtual Brgemm::validate_and_infer_types() is invoked
// only after both arguments and the output count are in place.
Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B) : MatMul() {
    set_arguments({A, B});
    set_output_size(1);
    constructor_validate_and_infer_types();
}

void Brgemm::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types);
element::Type result_et;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
"Arguments do not have the same element type (arg0 element type: ",
get_input_element_type(0),
", arg1 element type: ",
get_input_element_type(1),
").");
// If no leading dimensions are provided, assume dense row-major inputs-outputs
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(),
"Brgemm currently supports only static shapes.");

std::vector<ov::PartialShape> planar_input_shapes;
for (const auto& in : input_values())
planar_input_shapes.emplace_back(utils::get_port_planar_shape(in));

std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
ov::op::v0::shape_infer(this, planar_input_shapes, output_shapes);
const auto& output_layout = utils::get_node_output_layout(this);
output_shapes[0] = utils::get_reordered_planar_shape(output_shapes[0], output_layout);
set_output_type(0, result_et, output_shapes[0]);
}

// Creates a copy of this node bound to new inputs.
// @param new_args exactly two inputs (A, B); count is checked before cloning.
// @return a freshly constructed Brgemm over the new inputs.
std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    // Fixed: removed the stray double semicolon on the return statement.
    return std::make_shared<Brgemm>(new_args.at(0), new_args.at(1));
}

} // namespace op
} // namespace snippets
} // namespace ngraph
Loading