Skip to content

Commit

Permalink
Improve pass
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaLancellotti committed May 3, 2022
1 parent e42b246 commit 017c855
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 21 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
from collections import namedtuple
from typing import Optional
import numpy as np # type: ignore

import tvm
Expand Down Expand Up @@ -811,7 +812,7 @@ def HoistAllocates() -> tvm.IRModule:
return _ffi_api.HoistAllocates()


def CopyComputeReordering(max_copy_movements: int) -> tvm.IRModule:
def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRModule:
"""
Reorders copy and compute nodes in such a way that independent DMA copies,
and computes happen in parallel.
Expand All @@ -821,8 +822,10 @@ def CopyComputeReordering(max_copy_movements: int) -> tvm.IRModule:
Parameters
----------
max_copy_movements: int
max_copy_movements: Optional[int]
The maximum number of movements allowed for a copy.
If None, the pass context option tir.copy_compute_reordering_max_copy_movements
is used if provided, otherwise the default value will be 1.
Returns
-------
Expand Down
54 changes: 35 additions & 19 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>

namespace tvm {

/*!
* \brief The maximum number of movements allowed for a copy in the CopyComputeReordering pass.
*/
constexpr const char* kCopyComputeReorderingMaxCopyMovements =
"tir.copy_compute_reordering_max_copy_movements";
TVM_REGISTER_PASS_CONFIG_OPTION(kCopyComputeReorderingMaxCopyMovements, Integer);

namespace tir {
namespace contrib {
namespace ethosu {
Expand Down Expand Up @@ -119,13 +129,14 @@ TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAl
*/
class CopyComputeReorderingMutator : public StmtExprMutator {
public:
CopyComputeReorderingMutator(int max_copy_movements) : _max_copy_movements{max_copy_movements} {}
explicit CopyComputeReorderingMutator(int max_copy_movements)
: _max_copy_movements{max_copy_movements} {}

PrimFunc operator()(PrimFunc main_func) {
if (_max_copy_movements > 0) {
auto n{main_func.CopyOnWrite()};
n->body = this->VisitStmt(main_func->body);
return GetRef<PrimFunc>(n);
auto prim_func_node{main_func.CopyOnWrite()};
prim_func_node->body = this->VisitStmt(main_func->body);
return GetRef<PrimFunc>(prim_func_node);
}
return main_func;
}
Expand All @@ -140,23 +151,23 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
std::vector<Stmt> new_seq(seq_stmt->size());
std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());

for (size_t index{}; index < new_seq.size(); ++index) {
for (int offset{}; offset < _max_copy_movements; ++offset) {
auto i{index - offset};
if (i > 0 && !stmt_is_copy(new_seq[i - 1]) && stmt_is_global_copy(new_seq[i])) {
std::swap(new_seq[i], new_seq[i - 1]);
} else {
break;
// Each copy statement to a buffer with global scope is moved up
// at most `_max_copy_movements` times.
for (size_t index = 0; index < new_seq.size(); ++index) {
if (stmt_is_global_copy(new_seq[index])) {
int lower = std::max(0, static_cast<int>(index) - _max_copy_movements);
for (int i = index; i > lower && !stmt_is_copy(new_seq[i - 1]); --i) {
std::swap(new_seq[i - 1], new_seq[i]);
}
}
}

auto n{CopyOnWrite(op)};
n->seq = std::move(new_seq);
return Stmt{n};
auto seq_stmt_node{CopyOnWrite(op)};
seq_stmt_node->seq = std::move(new_seq);
return Stmt{seq_stmt_node};
}

tvm::runtime::Array<tvm::PrimExpr> get_stmt_args(Stmt stmt) {
tvm::runtime::Array<tvm::PrimExpr> get_stmt_args(const Stmt& stmt) {
auto eval_node{stmt.as<EvaluateNode>()};
ICHECK(eval_node) << "Expected statement to be an evaluate node, but was "
<< stmt->GetTypeKey();
Expand All @@ -166,17 +177,18 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
return call_node->args;
}

bool stmt_is_copy(Stmt stmt) {
bool stmt_is_copy(const Stmt& stmt) {
auto args{get_stmt_args(stmt)};
return args[0].as<StringImmNode>()->value == "ethosu_copy";
}

bool stmt_is_global_copy(Stmt stmt) {
bool stmt_is_global_copy(const Stmt& stmt) {
auto args{get_stmt_args(stmt)};
return args[0].as<StringImmNode>()->value == "ethosu_copy" &&
args[3].as<BufferLoadNode>()->buffer.scope() == "global";
}

/*! The maximum number of movements allowed for a copy. */
int _max_copy_movements;
};

Expand All @@ -185,15 +197,19 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
* and computes happen in parallel.
*
* \param max_copy_movements: The maximum number of movements allowed for a copy.
* If None, the pass context option tir.copy_compute_reordering_max_copy_movements
* is used if provided, otherwise the default value will be 1.
* \return tvm::transform::Pass
*/
tvm::transform::Pass CopyComputeReordering(int max_copy_movements) {
tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements) {
auto pass_func = [=](PrimFunc f, IRModule mod, tvm::transform::PassContext ctx) {
ICHECK(mod->GetGlobalVars().size() == 1 && mod->ContainGlobalVar("main"))
<< "Expected a single primitive function called 'main'. Please run the "
"CopyComputeReordering "
"pass in conjunction with the LowerToTIR() pass.";
return CopyComputeReorderingMutator(max_copy_movements)(f);
auto value = max_copy_movements.value_or(
ctx->GetConfig(kCopyComputeReorderingMaxCopyMovements, Integer(1)).value());
return CopyComputeReorderingMutator(value)(f);
};
return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0,
"tir.contrib.ethos-u.CopyComputeReordering", {});
Expand Down
67 changes: 67 additions & 0 deletions tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,72 @@ def abs():
CopyComputeReordering(1)(InputModule)


def test_default_max_copy_movements():
# fmt: off
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([97156], "int8")
buffer2 = T.buffer_decl([80], "uint8")
buffer3 = T.buffer_decl([64], "uint8")
buffer4 = T.buffer_decl([96], "uint8")
buffer5 = T.buffer_decl([32], "uint8")
# body
p1 = T.allocate([390336], "int8", "global")
p2 = T.allocate([80], "uint8", "global")
p3 = T.allocate([64], "uint8", "global")
p4 = T.allocate([390336], "int8", "global")
p5 = T.allocate([96], "uint8", "global")
p6 = T.allocate([32], "uint8", "global")
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p6[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 3, 214, 0, 114, buffer3[0], 0, 0, 0, T.float32(0.104816), -128, "NHWC", 342, 3, 1, 3, 1, 1, 1, 1, 2, p5[0], 96, 0, p6[0], 32, 0, 1, 0, 1, "CLIP", -128, 127, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# fmt: on

test_mod = CopyComputeReordering()(OperatorsWithAndWithoutWeights)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


def test_pass_context_option_max_copy_movements():
# fmt: off
@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([97156], "int8")
buffer2 = T.buffer_decl([80], "uint8")
buffer3 = T.buffer_decl([64], "uint8")
buffer4 = T.buffer_decl([96], "uint8")
buffer5 = T.buffer_decl([32], "uint8")
# body
p1 = T.allocate([390336], "int8", "global")
p2 = T.allocate([80], "uint8", "global")
p3 = T.allocate([64], "uint8", "global")
p4 = T.allocate([390336], "int8", "global")
p5 = T.allocate([96], "uint8", "global")
p6 = T.allocate([32], "uint8", "global")
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p6[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 3, 214, 0, 114, buffer3[0], 0, 0, 0, T.float32(0.104816), -128, "NHWC", 342, 3, 1, 3, 1, 1, 1, 1, 2, p5[0], 96, 0, p6[0], 32, 0, 1, 0, 1, "CLIP", -128, 127, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# fmt: on

with tvm.transform.PassContext(config={"tir.copy_compute_reordering_max_copy_movements": 2}):
test_mod = CopyComputeReordering()(OperatorsWithAndWithoutWeights)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 017c855

Please sign in to comment.