From 3fc8133929aedd29bb4a623f148f13c6d65117bd Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Tue, 16 Nov 2021 08:45:53 -0800 Subject: [PATCH] Update Shape lowering pass (#38) * Update shape lowering pass. * Rebase. --- src/relax/backend/vm/vm_shape_lower.cc | 77 ++++++++++++++++++++------ tests/python/relax/test_transform.py | 66 +++++++++++++++++++++- tests/python/relax/test_vm.py | 69 ++++++++++++++++++----- 3 files changed, 176 insertions(+), 36 deletions(-) diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 14ed499180b8..d921d3a137e6 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -32,6 +32,38 @@ namespace tvm { namespace relax { +/*! + * \brief Visitor to apply a function to every Expr it visits. Also applies the function + * to the shape field of the var definition site if the var's shape is a ShapeExpr. + */ +class ExprApplyVisitWithShape : public ExprVisitor { + public: + explicit ExprApplyVisitWithShape(std::function f) : f_(f) {} + + void VisitVarDef(const Var& var) { + if (var.as()) { + this->VisitExpr(Downcast(var)); + } else { + this->VisitExpr(var); + } + if (var->shape_.operator bool() && var->shape_.value().as()) { + f_(Downcast(var->shape_.value())); + } + } + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisitWithShape(const Expr& e, std::function fvisit) { + ExprApplyVisitWithShape(fvisit).VisitExpr(e); +} + class VMShapeLowerMutator : public ExprMutator { public: static DataType ShapeDType() { return DataType::Int(64); }; @@ -58,18 +90,11 @@ class VMShapeLowerMutator : public ExprMutator { } void VisitBinding_(const MatchShapeNode* binding) override { - Expr shape = ExprMutator::VisitExpr(binding->value); - static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); - auto store_shape_attr = make_object(); - - Array pattern = binding->pattern; 
- Array indices; - for (size_t i = 0; i < pattern.size(); ++i) { - int idx = expr2slot_.at(pattern[i]); - indices.push_back(idx); - } - store_shape_attr->indices = indices; - builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv"); + Expr value = ExprMutator::VisitExpr(binding->value); + + // TODO(@yuchen): match_shape overloaded semantic: value is ShapeType + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh"); + StoreShape(shape, binding->pattern); } Expr VisitExpr_(const ShapeExprNode* node) override { @@ -97,16 +122,18 @@ class VMShapeLowerMutator : public ExprMutator { } Expr VisitExpr_(const FunctionNode* node) override { + builder_->BeginBindingBlock(); + builder_->Emit(VarBinding( + shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); Array params; for (Var param : node->params) { params.push_back(this->VisitVarDef(param)); + if (param->shape_.operator bool() && param->shape_.value().as()) { + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh"); + StoreShape(shape, Downcast(param->shape_.value())->values); + } } Type ret_type = this->VisitType(node->ret_type); - - builder_->BeginBindingBlock(); - builder_->Emit(VarBinding( - shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); - Expr new_body = this->VisitExpr(node->body); Array blocks; @@ -174,10 +201,24 @@ class VMShapeLowerMutator : public ExprMutator { } } }; - PostOrderVisit(expr, func); + PostOrderVisitWithShape(expr, func); return ret; } + /*! \brief Store symbolic shape into indices of the VM shape heap. 
*/ + void StoreShape(Expr shape, Array pattern) { + static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); + auto store_shape_attr = make_object(); + + Array indices; + for (size_t i = 0; i < pattern.size(); ++i) { + int idx = expr2slot_.at(pattern[i]); + indices.push_back(idx); + } + store_shape_attr->indices = indices; + builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv"); + } + bool IsConstantShape(ShapeExpr shape) const { for (PrimExpr e : shape->values) { if (!e->IsInstance()) { diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index ad6ee597e31f..cafca7246334 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -23,7 +23,7 @@ from tvm.ir.module import IRModule import tvm.script -from tvm.script import relax as R +from tvm.script import tir as T, relax as R def test_fma_rewrite(): @@ -179,8 +179,7 @@ def test_vm_shape_lowering(): class TestVMShapeLower: @R.function def foo(x: Tensor[_, "float32"]) -> Shape: - sh = relax.call_packed("vm.builtin.shape_of", x) - relax.match_shape(sh, (n, m)) + relax.match_shape(x, (n, m)) return (n * 2, m * 3) mod = TestVMShapeLower @@ -196,6 +195,7 @@ def foo(x: Tensor[_, "float32"]) -> Shape: s1 = func.body.blocks[0].bindings[0].value assert isinstance(s1.op, relax.ExternFunc) assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap" + assert s1.args[0].values[0] == 4 s2 = func.body.blocks[1].bindings[0].value assert isinstance(s2.op, relax.ExternFunc) assert s2.op.global_symbol == "vm.builtin.shape_of" @@ -209,6 +209,65 @@ def foo(x: Tensor[_, "float32"]) -> Shape: assert isinstance(s5, tvm.relay.Call) assert s5.op.name == "relax.vm.builtin.load_shape" + +def test_vm_shape_lowering_func_param_with_shape(): + src = """@tvm.script.ir_module +class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m 
= T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (n,k)) + C = T.match_buffer(z, (m,k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + @R.function + def foo(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: + gv0 = R.call_dps((m, k), tir_matmul, (x, w)) + return gv0 +""" + mod = tvm.script.relax.parser.from_source(src) + + # after vm shape lowering + new_mod = relax.transform.VMShapeLower()(mod) + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) + assert isinstance(new_mod["tir_matmul"], tvm.tir.function.PrimFunc) + func = new_mod["foo"] + assert isinstance(func, tvm.relax.expr.Function) + + x, w = func.params + s1 = func.body.blocks[0].bindings[0].value + assert isinstance(s1.op, relax.ExternFunc) + assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap" + assert s1.args[0].values[0] == 3 + + s2 = func.body.blocks[0].bindings[1].value + assert isinstance(s2.op, relax.ExternFunc) + assert s2.op.global_symbol == "vm.builtin.shape_of" + assert s2.args[0] == x + s3 = func.body.blocks[0].bindings[2].value + assert isinstance(s3, tvm.relay.Call) + assert s3.op.name == "relax.vm.builtin.store_shape" + + s4 = func.body.blocks[0].bindings[3].value + assert isinstance(s4.op, relax.ExternFunc) + assert s4.op.global_symbol == "vm.builtin.shape_of" + assert s4.args[0] == w + s5 = func.body.blocks[0].bindings[4].value + assert isinstance(s5, tvm.relay.Call) + assert s5.op.name == "relax.vm.builtin.store_shape" + + def test_to_anf(): x = relax.Var("x", type_annotation=relax.DynTensorType()) gv = relax.op.add(x, x) @@ -241,4 +300,5 @@ def f(x: Tensor[_, "float32"]): test_call_dps_rewrite() test_vm_memory_lower() test_vm_shape_lowering() +
test_vm_shape_lowering_func_param_with_shape() test_to_anf() diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 66526fbf28eb..052cb32744a4 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -232,7 +232,7 @@ def test_vm_compile_stage0(): class TestVMCompileStage0: @R.function def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): - z = relax.call_packed("test.vm.identity", x, y) + z = R.call_packed("test.vm.identity", x, y) return y mod = TestVMCompileStage0 @@ -272,13 +272,13 @@ def shape_func0(heap: T.handle) -> None: @R.function def foo(x: Tensor[_, "float32"]) -> Shape: - shape_heap: Tensor[(4,), "int64"] = relax.call_packed( + shape_heap: Tensor[(4,), "int64"] = R.call_packed( "vm.builtin.alloc_shape_heap", (4,) ) - gv0 = relax.call_packed("vm.builtin.shape_of", x) - gv1 = relax.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1)) + gv0 = R.call_packed("vm.builtin.shape_of", x) + gv1 = R.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1)) gv2 = shape_func0(shape_heap) - gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3)) + gv3 = R.call_packed("vm.builtin.load_shape", shape_heap, (2, 3)) return gv3 """ @@ -301,8 +301,7 @@ def test_vm_compile_stage2(): class TestVMCompileStage2: @R.function def foo(x: Tensor[_, "float32"]) -> Shape: - sh = relax.call_packed("vm.builtin.shape_of", x) - relax.match_shape(sh, (n, m)) + R.match_shape(x, (n, m)) return (n * 2, m * 3) mod = TestVMCompileStage2 @@ -323,9 +322,9 @@ def test_vm_compile_stage3(): class TestVMCompileStage3: @R.function def foo(x: Tensor[(32, 16), "float32"]) -> Tensor: - with relax.dataflow(): - y = relax.call_dps((32, 16), "test.vm.identity", (x)) - relax.output(y) + with R.dataflow(): + y = R.call_dps((32, 16), "test.vm.identity", (x)) + R.output(y) return y mod = TestVMCompileStage3 @@ -345,11 +344,10 @@ def test_vm_compile_e2e(): class TestVMCompileE2E: @R.function def foo(x: Tensor[_, 
"float32"]) -> Tensor: - with relax.dataflow(): - sh = relax.call_packed("vm.builtin.shape_of", x) - x0 = relax.match_shape(sh, (n, m)) - y = relax.call_dps((n, m * 2), "test.vm.tile", (x)) - relax.output(y) + with R.dataflow(): + R.match_shape(x, (n, m)) + y = R.call_dps((n, m * 2), "test.vm.tile", (x)) + R.output(y) return y mod = TestVMCompileE2E @@ -364,6 +362,46 @@ def foo(x: Tensor[_, "float32"]) -> Tensor: res = vm["foo"](inp) np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), res.asnumpy()) +def test_vm_compile_e2e_func_param_with_shape(): + src = """@tvm.script.ir_module +class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (n,k)) + C = T.match_buffer(z, (m,k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: + gv0 = R.call_dps((m, k), tir_matmul, (x, w)) + return gv0 +""" + + mod = tvm.script.relax.parser.from_source(src) + + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = relax.vm.build(mod, target, target_host) + vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) + + import numpy as np + data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + res = vm["func"](data, weight) + expected = np.dot(data.asnumpy(), weight.asnumpy()) + np.testing.assert_allclose(expected, res.asnumpy(), rtol=1e-4, atol=1e-4) + if __name__ == "__main__": test_vm_execute() @@ -380,3 +418,4 @@ def foo(x: Tensor[_, "float32"]) -> Tensor: test_vm_compile_stage2() test_vm_compile_stage3() 
test_vm_compile_e2e() + test_vm_compile_e2e_func_param_with_shape()