Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolaLancellotti committed Apr 12, 2022
1 parent 1648cce commit 78f05ea
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,9 @@ def CopyComputeReordering() -> tvm.IRModule:
"""
Reorders copy and compute nodes in such a way that independent DMA copies,
and computes happen in parallel.
Copies to buffers with local scope are not reordered, indeed they copy LUT
into the SHRAM which already happens in parallel with copying weights into
the weights encoder.
Returns
-------
Expand Down
13 changes: 11 additions & 2 deletions src/tir/contrib/ethosu/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAl
/*!
* \brief Reorders copy and compute nodes in such a way that independent DMA copies,
* and computes happen in parallel.
* Copies to buffers with local scope are not reordered, indeed they copy LUT
* into the SHRAM which already happens in parallel with copying weights into
* the weights encoder.
*/
class CopyComputeReorderingMutator : public StmtExprMutator {
public:
Expand All @@ -136,8 +139,14 @@ class CopyComputeReorderingMutator : public StmtExprMutator {
bool previous_stmt_is_copy{true}; // Do not move the first stmt if it is a copy

for (size_t i{}; i < seq_stmt->size(); ++i) {
auto stmt{seq_stmt[i]};
auto args{stmt.as<EvaluateNode>()->value.as<CallNode>()->args};
Stmt stmt{seq_stmt[i]};
auto eval_node{stmt.as<EvaluateNode>()};
ICHECK(eval_node) << "Expected statement to be an evaluate node, but was "
<< stmt->GetTypeKey();
auto call_node{eval_node->value.as<CallNode>()};
ICHECK(call_node) << "Expected expression to be a call node, but was "
<< eval_node->value->GetTypeKey();
auto args{call_node->args};
bool stmt_is_copy{args[0].as<StringImmNode>()->value == "ethosu_copy"};
bool stmt_is_global_copy{stmt_is_copy &&
args[3].as<BufferLoadNode>()->buffer.scope() == "global"};
Expand Down
130 changes: 129 additions & 1 deletion tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm.relay.backend.contrib.ethosu.tir.passes import CopyComputeReordering


def test_four_convolutions():
def test_all_operators_with_weights():
# fmt: off
@tvm.script.ir_module
class InputModule:
Expand Down Expand Up @@ -107,6 +107,96 @@ def main() -> None:
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


def test_all_operators_without_weights():
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([36], "int8")
buffer2 = T.buffer_decl([9], "int8")
# body
p1 = T.allocate([96], "int8", "global")
T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 4, 3, 3, 0, 4, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 12, 3, 1, "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "int8", 3, 1, 3, 3, 0, 1, buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 3, 1, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))

@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([36], "int8")
buffer2 = T.buffer_decl([9], "int8")
# body
p1 = T.allocate([96], "int8", "global")
T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 4, 3, 3, 0, 4, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 12, 3, 1, "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_pooling", "int8", 3, 2, 3, 3, 0, 2, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 32, 16, 1, "int8", 3, 1, 3, 3, 0, 1, buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 3, 1, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# fmt: on

test_mod = CopyComputeReordering()(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


def test_operators_with_and_without_weights():
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([97156], "int8")
buffer2 = T.buffer_decl([80], "uint8")
buffer3 = T.buffer_decl([64], "uint8")
buffer4 = T.buffer_decl([96], "uint8")
buffer5 = T.buffer_decl([32], "uint8")
# body
p1 = T.allocate([390336], "int8", "global")
p2 = T.allocate([80], "uint8", "global")
p3 = T.allocate([64], "uint8", "global")
p4 = T.allocate([390336], "int8", "global")
p5 = T.allocate([96], "uint8", "global")
p6 = T.allocate([32], "uint8", "global")
T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p6[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 3, 214, 0, 114, buffer3[0], 0, 0, 0, T.float32(0.104816), -128, "NHWC", 342, 3, 1, 3, 1, 1, 1, 1, 2, p5[0], 96, 0, p6[0], 32, 0, 1, 0, 1, "CLIP", -128, 127, "TFL", "NONE", 0, 0, 0, dtype="handle"))

@tvm.script.ir_module
class ReferenceModule:
@T.prim_func
def main() -> None:
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
buffer1 = T.buffer_decl([97156], "int8")
buffer2 = T.buffer_decl([80], "uint8")
buffer3 = T.buffer_decl([64], "uint8")
buffer4 = T.buffer_decl([96], "uint8")
buffer5 = T.buffer_decl([32], "uint8")
# body
p1 = T.allocate([390336], "int8", "global")
p2 = T.allocate([80], "uint8", "global")
p3 = T.allocate([64], "uint8", "global")
p4 = T.allocate([390336], "int8", "global")
p5 = T.allocate([96], "uint8", "global")
p6 = T.allocate([32], "uint8", "global")
T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 80, p2[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer3[0], 64, p3[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_pooling", "int8", 214, 227, 2, 214, 0, 227, buffer1[0], 0, 0, 0, T.float32(1), 0, "NHWC", 454, 2, 1, "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(1), 0, "NHCWB16", 1824, 16, 1, "MAX", 2, 1, 2, 1, 1, 1, 0, 0, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 96, p5[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", buffer5[0], 32, p6[0], dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 2, 214, 0, 114, p1[0], 0, 0, 0, T.float32(0.00392157), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, 3, 1, 1, 1, 1, 2, p2[0], 80, 0, p3[0], 64, 0, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 214, 114, 5, 214, 0, 114, p4[0], 0, 0, 0, T.float32(0.0174839), -128, "NHCWB16", 1824, 16, 1, "int8", 214, 114, 3, 214, 0, 114, buffer3[0], 0, 0, 0, T.float32(0.104816), -128, "NHWC", 342, 3, 1, 3, 1, 1, 1, 1, 2, p5[0], 96, 0, p6[0], 32, 0, 1, 0, 1, "CLIP", -128, 127, "TFL", "NONE", 0, 0, 0, dtype="handle"))
# fmt: on

test_mod = CopyComputeReordering()(InputModule)
reference_mod = ReferenceModule
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


def test_copy_to_buffer_with_local_scope():
# fmt: off
@tvm.script.ir_module
Expand Down Expand Up @@ -175,5 +265,43 @@ def main() -> None:
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)


def test_multiple_prim_funcs():
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def main():
T.evaluate(0)

@T.prim_func
def abc():
T.evaluate(0)
# fmt: on

err_rgx = (
r"Expected a single primitive function called 'main'. "
r"Please run the CopyComputeReordering pass in conjunction with the LowerToTIR\(\) pass."
)
with pytest.raises(tvm.TVMError, match=err_rgx):
CopyComputeReordering()(InputModule)


def test_no_main_prim_func():
# fmt: off
@tvm.script.ir_module
class InputModule:
@T.prim_func
def abs():
T.evaluate(0)
# fmt: on

err_rgx = (
r"Expected a single primitive function called 'main'. "
r"Please run the CopyComputeReordering pass in conjunction with the LowerToTIR\(\) pass."
)
with pytest.raises(tvm.TVMError, match=err_rgx):
CopyComputeReordering()(InputModule)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 78f05ea

Please sign in to comment.