Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microNPU] Add a pass to reorder copy and compute nodes #10959

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
mod = ethosu_passes.HoistAllocates()(mod)
mod = ethosu_passes.CopyComputeReordering()(mod)
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False)
if not disable_storage_rewrite:
mod = tvm.tir.transform.StorageRewrite()(mod)
Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements, too-many-nested-blocks
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
from collections import namedtuple
from typing import Optional
import numpy as np # type: ignore

import tvm
Expand Down Expand Up @@ -913,3 +914,27 @@ def HoistAllocates() -> tvm.IRModule:
The new module with hoisted allocate nodes.
"""
return _ffi_api.HoistAllocates()


def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.transform.Pass:
    """
    Create a pass that reorders copy and compute nodes in such a way that
    independent DMA copies and computes happen in parallel.

    Copies to buffers with local scope are not reordered: they copy the LUT
    into the SHRAM, which already happens in parallel with copying weights
    into the weights encoder.

    Parameters
    ----------
    max_copy_movements: Optional[int]
        The maximum number of movements allowed for a copy.
        If None, the pass context option
        tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements
        is used if provided, otherwise the default value will be 1.

    Returns
    -------
    tvm.transform.Pass
        The pass which, when applied to a module, returns a new module with
        copy and compute nodes reordered.
    """
    return _ffi_api.CopyComputeReordering(max_copy_movements)
108 changes: 108 additions & 0 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>

namespace tvm {

/*!
 * \brief Name of the pass context option holding the maximum number of movements
 * allowed for a copy in the CopyComputeReordering pass.
 *
 * The option is registered below with an Integer value. CopyComputeReordering reads it
 * as the fallback when the pass is created without an explicit max_copy_movements
 * argument; if the option is not set either, a default of 1 is used.
 */
constexpr const char* kCopyComputeReorderingMaxCopyMovements =
"tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements";
TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingMaxCopyMovements, Integer);

namespace tir {
namespace contrib {
namespace ethosu {
Expand Down Expand Up @@ -110,6 +120,104 @@ tvm::transform::Pass HoistAllocates() {

// Register the HoistAllocates pass constructor as the global function
// "tir.contrib.ethos-u.HoistAllocates".
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);

/*!
* \brief Reorders copy and compute nodes in such a way that independent DMA copies,
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved
* and computes happen in parallel.
* Copies to buffers with local scope are not reordered, indeed they copy LUT
* into the SHRAM which already happens in parallel with copying weights into
* the weights encoder.
*/
class CopyComputeReorderingMutator : public StmtExprMutator {
public:
explicit CopyComputeReorderingMutator(int max_copy_movements)
: _max_copy_movements{max_copy_movements} {}

PrimFunc operator()(PrimFunc main_func) {
if (_max_copy_movements > 0) {
auto prim_func_node{main_func.CopyOnWrite()};
prim_func_node->body = this->VisitStmt(main_func->body);
return GetRef<PrimFunc>(prim_func_node);
}
return main_func;
}

private:
Stmt VisitStmt_(const SeqStmtNode* op) override {
if (op->size() <= 1) {
return StmtExprMutator::VisitStmt_(op);
}

auto seq_stmt{GetRef<SeqStmt>(op)};
std::vector<Stmt> new_seq(seq_stmt->size());
std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());

// Each copy statement to a buffer with global scope is moved up
// at most `_max_copy_movements` times.
for (size_t index = 0; index < new_seq.size(); ++index) {
if (stmt_is_global_copy(new_seq[index])) {
int lower = std::max(0, static_cast<int>(index) - _max_copy_movements);
for (int i = index; i > lower && !stmt_is_copy(new_seq[i - 1]); --i) {
std::swap(new_seq[i - 1], new_seq[i]);
}
}
}

auto seq_stmt_node{CopyOnWrite(op)};
seq_stmt_node->seq = std::move(new_seq);
return Stmt{seq_stmt_node};
}

tvm::runtime::Array<tvm::PrimExpr> get_stmt_args(const Stmt& stmt) {
auto eval_node{stmt.as<EvaluateNode>()};
ICHECK(eval_node) << "Expected statement to be an evaluate node, but was "
<< stmt->GetTypeKey();
auto call_node{eval_node->value.as<CallNode>()};
ICHECK(call_node) << "Expected expression to be a call node, but was "
<< eval_node->value->GetTypeKey();
return call_node->args;
}

bool stmt_is_copy(const Stmt& stmt) {
auto args{get_stmt_args(stmt)};
return args[0].as<StringImmNode>()->value == "ethosu_copy";
}

bool stmt_is_global_copy(const Stmt& stmt) {
auto args{get_stmt_args(stmt)};
return args[0].as<StringImmNode>()->value == "ethosu_copy" &&
args[3].as<BufferLoadNode>()->buffer.scope() == "global";
}

/*! The maximum number of movements allowed for a copy. */
int _max_copy_movements;
NicolaLancellotti marked this conversation as resolved.
Show resolved Hide resolved
};

/*!
 * \brief A pass to reorder copy and compute nodes in such a way that independent DMA copies
 * and computes happen in parallel.
 *
 * \param max_copy_movements The maximum number of movements allowed for a copy.
 * If not defined, the pass context option
 * tir.contrib.ethos-u.copy_compute_reordering_max_copy_movements
 * is used if provided, otherwise the default value will be 1.
 * \return tvm::transform::Pass
 */
tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements) {
  auto pass_func = [max_copy_movements](PrimFunc f, IRModule mod,
                                        tvm::transform::PassContext ctx) {
    ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
        << "Expected a single primitive function called 'main'. Please run the "
           "CopyComputeReordering "
           "pass in conjunction with the LowerToTIR() pass.";
    // An explicit argument takes precedence over the pass context option, which in
    // turn takes precedence over the built-in default of 1.
    const Integer movements = max_copy_movements.value_or(
        ctx->GetConfig(kCopyComputeReorderingMaxCopyMovements, Integer(1)).value());
    // The mutator takes a plain int; make the narrowing from int64_t explicit.
    return CopyComputeReorderingMutator(static_cast<int>(movements->value))(f);
  };
  return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0,
                                                 "tir.contrib.ethos-u.CopyComputeReordering", {});
}

// Register the CopyComputeReordering pass constructor as the global function
// "tir.contrib.ethos-u.CopyComputeReordering".
TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
.set_body_typed(CopyComputeReordering);

} // namespace ethosu
} // namespace contrib
} // namespace tir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def _get_ethosu_workspace_size(
@pytest.mark.parametrize(
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
[
("ethos-u55-256", 1067408, 14096),
("ethos-u55-128", 1067408, 3968),
("ethos-u55-64", 1067408, 3968),
("ethos-u55-32", 1067392, 3952),
("ethos-u55-256", 1067520, 14208),
("ethos-u55-128", 1067520, 4080),
("ethos-u55-64", 1067520, 4080),
("ethos-u55-32", 1067504, 4064),
],
)
def test_double_conv2d(
Expand Down Expand Up @@ -161,10 +161,10 @@ def tf_graph(x):
@pytest.mark.parametrize(
"accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping",
[
("ethos-u55-256", 180096, 15008),
("ethos-u55-128", 180096, 14240),
("ethos-u55-64", 180096, 14240),
("ethos-u55-32", 180096, 14240),
("ethos-u55-256", 180288, 15200),
("ethos-u55-128", 180288, 14432),
("ethos-u55-64", 180288, 14432),
("ethos-u55-32", 180272, 14416),
],
)
def test_depthwise2d_conv2d_pooling(
Expand Down
Loading