Skip to content

Commit

Permalink
Match and lower ov::Relu (#143)
Browse files Browse the repository at this point in the history
Adds ReLU op matcher and lowering to MLIR named Linalg ops.
Also, adds buffer deallocation passes to prevent memory leaks
when temporary buffers are created in larger graphs.
  • Loading branch information
adam-smnk authored Jul 19, 2024
1 parent f1a9a7f commit b548f56
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 50 deletions.
62 changes: 13 additions & 49 deletions src/common/transformations/src/transformations/mlir/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
#include <algorithm>
#include <functional>
#include <openvino/op/add.hpp>
#include <openvino/op/subtract.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/divide.hpp>

#include <openvino/op/multiply.hpp>
#include <openvino/op/relu.hpp>
#include <openvino/op/subtract.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <unordered_map>

// TODO: Prune unused headers -- it's hard to understand needed ones
#include "conversion_context.hpp"
#include "convert_common.hpp"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/InitLLVM.h"
Expand Down Expand Up @@ -55,20 +57,16 @@
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir_op.hpp"
#include "op/matmul.hpp"
#include "op/relu.hpp"
#include "openvino/core/dimension.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations_visibility.hpp"
#include "openvino/core/symbol.hpp"

#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"

#include "mlir_op.hpp"
#include "convert_common.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "subgraph_tracker.hpp"
#include "conversion_context.hpp"
#include "op/matmul.hpp"

#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"
#include "transformations_visibility.hpp"

namespace {

Expand Down Expand Up @@ -269,47 +267,12 @@ class Partitioner : public ov::pass::ModelPass {
}
};


bool elementwise_f32_binary_no_broadcast_predicate(const ov::Output<ov::Node>& output) {
if(output.get_element_type() != ov::element::f32) {
return false;
}
// Check if implicit broadcast is possible, reject in this case
// Relies on symbolic information -- register SymbolicPropagation before applying this pattern
auto input_shape_a = output.get_node_shared_ptr()->get_input_partial_shape(0);
auto input_shape_b = output.get_node_shared_ptr()->get_input_partial_shape(1);
auto output_shape = output.get_partial_shape();
if(output_shape.rank().is_dynamic() || input_shape_a.rank().is_dynamic() || input_shape_b.rank().is_dynamic()) {
return false;
}
if(output_shape.rank().get_length() != input_shape_a.rank().get_length() || output_shape.rank().get_length() != input_shape_b.rank().get_length()) {
return false;
}

for(size_t i = 0; i < output_shape.size(); ++i) {
if(output_shape[i] != input_shape_a[i] || output_shape[i] != input_shape_b[i]) {
return false;
}
// Continue if all shapes are static.
if (output_shape[i].is_static() && input_shape_a[i].is_static() &&
input_shape_b[i].is_static())
continue;
if(!ov::symbol::are_equal(output_shape[i].get_symbol(), input_shape_a[i].get_symbol()) || !ov::symbol::are_equal(output_shape[i].get_symbol(), input_shape_b[i].get_symbol())) {
return false;
}
}

return true;
}


template <typename Op>
NodePtr elementwise_f32_binary_no_broadcast() {
using namespace ov::pass::pattern;
return wrap_type<Op>({any_input(), any_input()}, elementwise_f32_binary_no_broadcast_predicate);
return wrap_type<Op>({any_input(), any_input()}, elementwise_no_broadcast_predicate<ov::element::f32>);
}


void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
ov::pass::Manager manager;
using namespace ov::op;
Expand All @@ -319,6 +282,7 @@ void injectMLIR(std::shared_ptr<ov::Model> model, MLIRContext* context) {
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Subtract>(), ConvertBinary<linalg::SubOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Multiply>(), ConvertBinary<linalg::MulOp>());
manager.register_pass<MarkPattern>(elementwise_f32_binary_no_broadcast<v1::Divide>(), ConvertBinary<linalg::DivOp>());
manager.register_pass<ReluPattern>();
manager.register_pass<MatMulPattern>();
manager.register_pass<Partitioner>(context);
manager.run_passes(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,42 @@ Location createLocation(MLIRContext* ctx, NodePtr node) {
return createLayerLocation(ctx, node->get_friendly_name(), node->get_type_name());
}

bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output, ov::element::Type type) {
if (output.get_element_type() != type) {
return false;
}
// Check if implicit broadcast is possible, reject in this case
// Relies on symbolic information -- register SymbolicPropagation before applying this pattern
auto inputs = output.get_node_shared_ptr()->inputs();
auto output_shape = output.get_partial_shape();
if (output_shape.rank().is_dynamic()) {
return false;
}
if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
auto input_shape = input.get_partial_shape();
return input_shape.rank().is_dynamic() ||
output_shape.rank().get_length() != input_shape.rank().get_length();
})) {
return false;
}

if (std::any_of(inputs.begin(), inputs.end(), [&](const ov::Input<ov::Node>& input) {
for (size_t i = 0; i < output_shape.size(); ++i) {
auto input_shape = input.get_partial_shape();
if (output_shape[i] != input_shape[i])
return true;
if (output_shape[i].is_static() && input_shape[i].is_static())
continue;
if (!ov::symbol::are_equal(output_shape[i].get_symbol(), input_shape[i].get_symbol()))
return true;
}
return false;
})) {
return false;
}

return true;
}

} // namespace mlir
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ RankedTensorType importTensor(MLIRContext* ctx,

Location createLocation(MLIRContext* ctx, NodePtr node);

bool elementwise_no_broadcast_predicate_impl(const ov::Output<ov::Node>& output, ov::element::Type type);

template <ov::element::Type_t type>
bool elementwise_no_broadcast_predicate(const ov::Output<ov::Node>& output) {
return elementwise_no_broadcast_predicate_impl(output, type);
}

// Borrowed it from TPP-MLIR. FIXME: Do we have a better upstreamed alternative?
template <typename T>
mlir::arith::ConstantOp getConstant(OpBuilder &builder, const ov::element::Type& precision, T value) {
Expand Down
14 changes: 13 additions & 1 deletion src/common/transformations/src/transformations/mlir/mlir_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,24 @@ void prepareMLIRKernelWithoutWrapper(mlir::OwningOpRef<mlir::ModuleOp>& module)
pm.addPass(bufferization::createEmptyTensorEliminationPass());

pm.addPass(bufferization::createOneShotBufferizePass());
// TODO: Add deallocation pass/pipeline to avoid memory leaks.
pm.addNestedPass<func::FuncOp>(bufferization::createFinalizingBufferizePass());

// Cleanup after bufferization - possibly remove redundant copies.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());

// Deallocation pipeline to avoid memory leaks from created temporary buffers.
pm.addPass(memref::createExpandReallocPass(/*emitDeallocs=*/false));
pm.addPass(createCanonicalizerPass());
bufferization::DeallocationOptions deallocOpts;
deallocOpts.privateFuncDynamicOwnership = false;
pm.addPass(bufferization::createOwnershipBasedBufferDeallocationPass(deallocOpts));
pm.addPass(createCanonicalizerPass());
pm.addPass(bufferization::createBufferDeallocationSimplificationPass());
pm.addPass(bufferization::createLowerDeallocationsPass());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());

// Blanket-convert any remaining high-level vector ops to loops if any remain.
pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
// pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
Expand Down
59 changes: 59 additions & 0 deletions src/common/transformations/src/transformations/mlir/op/relu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Linalg/Passes.h"

#include <openvino/op/relu.hpp>
#include "openvino/pass/pattern/op/wrap_type.hpp"

#include "relu.hpp"
#include "../convert_common.hpp"

namespace {

using namespace ov::mlir;

struct ConvertRelu {
void operator()(ConversionContext& context, NodePtr node) {
auto loc = createLocation(context.context, node);
auto& builder = context.builder();
// TODO: Support broadcasts
const auto input = context.getInputs(node)[0];
const auto ov_output_element_type = node->get_output_element_type(0);
const auto ov_output_shape = node->get_output_partial_shape(0);
auto outType = importTensor(context.context, ov_output_shape, ov_output_element_type);
// Named unary ops directly overwrite data in `outs` buffer so, there is no need to provide non-empty
// destination at the tensor-level.
// Use `tensor.empty` to avoid temporary buffer allocation and memcpy after bufferization.
llvm::SmallVector<Value> dynamicSizes;
for (auto [idx, dim] : llvm::enumerate(outType.getShape())) {
if (!mlir::ShapedType::isDynamic(dim))
continue;
auto dimSize = builder.create<tensor::DimOp>(loc, input, idx);
dynamicSizes.push_back(dimSize);
}
auto empty = builder.create<tensor::EmptyOp>(loc, outType, dynamicSizes);
auto zero = getConstant(builder, ov_output_element_type, 0);
auto fill = builder.create<linalg::FillOp>(loc, mlir::ValueRange{zero}, mlir::ValueRange{empty});
auto relu =
builder.create<linalg::MaxOp>(loc, mlir::ValueRange{input, fill.getResult(0)}, mlir::ValueRange{empty});
context.addOutputs(node, relu);
}
};

} // namespace

namespace ov {
namespace mlir {

using namespace ov::pass::pattern;
using namespace ov::op;

ReluPattern::ReluPattern()
: MarkPattern(wrap_type<v0::Relu>({any_input()}, elementwise_no_broadcast_predicate<ov::element::f32>),
ConvertRelu()) {}

} // namespace mlir
} // namespace ov
23 changes: 23 additions & 0 deletions src/common/transformations/src/transformations/mlir/op/relu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Value.h"

#include "../conversion_context.hpp"

namespace ov {
namespace mlir {

class ReluPattern : public MarkPattern {
public:
OPENVINO_RTTI("ReluPattern", "0");
ReluPattern();
};

} // namespace mlir
} // namespace ov

0 comments on commit b548f56

Please sign in to comment.