diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 6956eb7f368c..97820bfce6ba 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -193,10 +193,7 @@ class ExprMutator : public ExprFunctor { * \return expr. */ Expr Mutate(const Expr& expr) { - if (memo_.count(expr) == 0) { - memo_[expr] = this->VisitExpr(expr); - } - return Downcast(memo_[expr]); + return this->VisitExpr(expr); } Expr VisitExpr(const Expr& expr) override; @@ -226,6 +223,7 @@ class ExprMutator : public ExprFunctor { virtual void VisitBinding(const Binding& binding); virtual Var VisitVarBinding(const VarBinding& binding); virtual void VisitMatchShape(const MatchShape& binding); + virtual BindingBlock VisitBindingBlock(const BindingBlock& block); virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block); diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index bb7c2458c797..53a621d0b29f 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -20,6 +20,7 @@ import tvm from tvm._ffi._ctypes.packed_func import TVMRetValueHandle from tvm.runtime import Object +from tvm.runtime.container import ShapeTuple from tvm._ffi.base import _LIB, check_call from . vm import Executable from . import _ffi_api @@ -89,7 +90,11 @@ def emit_call( dst = SpecialReg.VOID_ARG args_ = [] for arg in args: - if isinstance(arg, tvm.nd.NDArray) or isinstance(arg, tvm.DataType): + if isinstance(arg, tuple): + shape_tuple = ShapeTuple(arg) + new_arg = self.emit_constant(shape_tuple) + args_.append(new_arg) + elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)): new_arg = self.emit_constant(arg) args_.append(new_arg) else: diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c205a5f41214..12dccd26e6b2 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -16,9 +16,10 @@ # under the License. # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck -from tvm import IRModule +from tvm import IRModule from . import _ffi_api + def fma_rewrite(expr): """Perform fused multiply add rewriting in dataflow blocks. @@ -29,22 +30,45 @@ def fma_rewrite(expr): """ return _ffi_api.fma_rewrite(expr) -def explicit_memory_rewrite(expr): - """Perform explicit memory allocation for call_dps in dataflow blocks. +def to_non_dataflow(mod: IRModule) -> IRModule: + """Transform all dataflow structure to non-dataflow version. Parameters ---------- - expr : tvm.relay.Expr - The input expression. + mod : tvm.IRModule + The input module. """ - return _ffi_api.explicit_memory_rewrite(expr) + return _ffi_api.to_non_dataflow(mod) + + +def call_dps_rewrite(mod: IRModule) -> IRModule: + """Perform explicit memory allocation for call_dps. + + Parameters + ---------- + mod : tvm.IRModule + The input module. + """ + return _ffi_api.call_dps_rewrite(mod) + + +def memory_lower(mod: IRModule) -> IRModule: + """Perform memory lowering. Lower the relax.builtin.alloc_tensor op to VM builtin functions. + + Parameters + ---------- + mod : tvm.IRModule + The input module. + """ + return _ffi_api.memory_lower(mod) + def shape_lower(mod: IRModule) -> IRModule: """Lower the shape expression in relax to shape heap and TIR functions. Parameters ---------- - expr : tvm.IRModule + mod : tvm.IRModule The input module. 
""" return _ffi_api.shape_lower(mod) diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 210ea7735c8e..8e7258b3b1ed 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -20,6 +20,7 @@ from tvm.runtime import Object, Device, Module, PackedFunc from tvm._ffi.base import _LIB, check_call from . import _ffi_api +from . import transform from ..rpc.base import RPC_SESS_MASK @@ -164,5 +165,9 @@ def build(mod: tvm.IRModule, lib: tvm.runtime.Module A runtime module that contains generated code. """ - ex, lib = _ffi_api.VMBuild(mod, target, target_host) + new_mod = transform.to_non_dataflow(mod) + new_mod = transform.call_dps_rewrite(new_mod) + new_mod = transform.memory_lower(new_mod) + new_mod = transform.shape_lower(new_mod) + ex, lib = _ffi_api.VMBuild(new_mod, target, target_host) return ex, lib diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 2b43128feb9a..3c5a4202db60 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -120,13 +120,19 @@ void ExprVisitor::VisitBinding(const Binding& binding) { } } -void ExprVisitor::VisitVarBinding(const VarBinding& binding) { this->VisitExpr(binding->value); } +void ExprVisitor::VisitVarBinding(const VarBinding& binding) { + this->VisitExpr(binding->value); + this->VisitExpr(binding->var); +} void ExprVisitor::VisitMatchShape(const MatchShape& binding) { this->VisitExpr(binding->value); // TODO(ziheng): should we change pattern from // Array to ShapeExpr? this->VisitExpr(ShapeExpr(binding->pattern)); + if (binding->var.defined()) { + this->VisitExpr(binding->var); + } } void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { @@ -214,6 +220,10 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { } Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + auto it = var_remap_.find(GetRef(op)); + if (it != var_remap_.end()) { + return it->second; + } if (op->type_annotation.defined()) { Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { @@ -339,7 +349,7 @@ void ExprMutator::VisitBinding(const Binding& binding) { Var ExprMutator::VisitVarBinding(const VarBinding& binding) { Expr new_value = builder_->Normalize(this->Mutate(binding->value)); - Var new_var = Downcast(this->Mutate(binding->var)); + // TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it // in this method... 
// if (new_value->shape_.defined()) { @@ -356,6 +366,7 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) { // new_var->checked_type_ = new_value->checked_type_; // } + Var new_var = Downcast(this->Mutate(binding->var)); if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) || !StructuralEqual()(new_var->checked_type(), new_value->checked_type())) { new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span); @@ -380,7 +391,14 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) { void ExprMutator::VisitMatchShape(const MatchShape& binding) { Expr new_value = this->Mutate(binding->value); Expr new_pattern = this->Mutate(ShapeExpr(binding->pattern)); - Var new_var = Downcast(this->Mutate(binding->var)); + Var new_var; + if (binding->var.defined()){ + new_var = Downcast(this->Mutate(binding->var)); + } else { + new_var = binding->var; + } + + // TODO: when value's shape/type changed, create new var builder_->EmitMatchShape( MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } diff --git a/src/relax/transform/call_dps_rewrite.cc b/src/relax/transform/call_dps_rewrite.cc new file mode 100644 index 000000000000..453ce198a2d1 --- /dev/null +++ b/src/relax/transform/call_dps_rewrite.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file src/relax/transform/call_dps_rewrite.cc + * \brief + */ +#include +#include +#include +#include + +#include "../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// CallDPSMutator +// Example: +// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) +// --> +// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) +// rx.call_packed(op.identity, x, lv0) + +class CallDPSMutator : public ExprMutator { + public: + explicit CallDPSMutator(IRModule mod) { mod_ = mod; } + + IRModule Lower() { + IRModule ret_mod = IRModule(); + for (auto& p : mod_->functions) { + Expr func = p.second; + if (p.second->IsInstance()) { + func = this->Mutate(p.second); + } + ret_mod->Add(p.first, Downcast(func)); + } + return ret_mod; + } + + Expr VisitExpr_(const CallNode* call) override { + // post-order mutation + Expr expr = ExprMutator::VisitExpr_(call); + call = expr.as(); + // TODO(@yuchen, @altanh): using mutate cause infinite recursion + // Expr expr = ExprMutator::Mutate(GetRef(call)); + + static const Op& call_dps_op = Op::Get("relax.call_dps"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + + if (call->op == call_dps_op) { + ShapeExpr output_shape = Downcast(call->args[0]); + Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc"); + builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_"); + return tensor; + } + + return GetRef(call); + } + + private: + IRModule mod_; +}; + +TVM_REGISTER_GLOBAL("relax.transform.call_dps_rewrite").set_body_typed([](IRModule mod) { + return CallDPSMutator(mod).Lower(); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/memory_rewrite.cc b/src/relax/transform/memory_rewrite.cc index 39b4a56b3fd1..c80ecc088981 100644 --- a/src/relax/transform/memory_rewrite.cc +++ b/src/relax/transform/memory_rewrite.cc @@ -20,6 +20,7 @@ * \file src/relax/transform/memory_rewrite.cc * \brief */ +#include #include #include #include @@ -30,14 +31,31 @@ namespace tvm { namespace relax { // ================== -// ExplicitMemMutator +// MemLowerMutator +// Lower the relax.builtin.alloc_tensor op to VM builtin functions. 
// Example: -// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) +// x = relax.builtin.alloc_tensor((m, n)) // --> -// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) -// rx.call_packed(op.identity, x, lv0) +// gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type, +// relax.attrs.AllocStorageAttrs) gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset, +// (m, n), relax.attrs.AllocTensorAttrs) + +class MemLowerMutator : public ExprMutator { + public: + explicit MemLowerMutator(IRModule mod) { mod_ = mod; } + + IRModule Lower() { + IRModule ret_mod = IRModule(); + for (auto& p : mod_->functions) { + Expr func = p.second; + if (p.second->IsInstance()) { + func = this->Mutate(p.second); + } + ret_mod->Add(p.first, Downcast(func)); + } + return ret_mod; + } -class ExplicitMemMutator : public ExprMutator { Expr ComputeStorageSize(const Expr& shape, const Type& type) const { DynTensorType tensor_type = Downcast(type); DataType dtype = DataType(tensor_type->dtype); @@ -63,44 +81,47 @@ class ExplicitMemMutator : public ExprMutator { return ret; } - BindingBlock VisitBindingBlock(const BindingBlock& block) { - builder_->BeginBindingBlock(); - for (Binding binding : block->bindings) { - this->VisitBinding(binding); - } - return builder_->EndBlock(); - } - Expr VisitExpr_(const CallNode* call) override { // post-order mutation Expr expr = ExprMutator::VisitExpr_(call); call = expr.as(); - // TODO(@yuchen, @altanh): using mutate cause infinite recursion - // Expr expr = ExprMutator::Mutate(GetRef(call)); - static const Op& call_dps_op = Op::Get("relax.call_dps"); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); - if (call->op == call_dps_op) { - ShapeExpr output_shape = Downcast(call->args[0]); - Type arg_type = Downcast(call->args[2])->fields[0]->checked_type(); - Expr output_size = ComputeStorageSize(output_shape, arg_type); - Var tensor = builder_->Emit(Call(alloc_tensor_op, {call->args[0]}), "alloc"); - builder_->Emit(Call(call->args[1], {call->args[2], tensor}), "_"); - return tensor; + if (call->op == alloc_tensor_op) { + ShapeExpr tensor_shape = Downcast(call->args[0]); + // TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor + Type tensor_type = DynTensorType(tensor_shape->values.size(), DataType::Float(32)); + Expr storage_size = ComputeStorageSize(tensor_shape, tensor_type); + ShapeExpr alignment = ShapeExpr({IntImm(DataType::Int(64), 64)}); + ShapeExpr device_type = ShapeExpr({IntImm(DataType::Int(64), 1)}); + auto storage_attr = make_object(); + storage_attr->dtype = DataType::Float(32); + storage_attr->device_type = 1; + + Var storage = + builder_->Emit(Call(ExternFunc("vm.builtin.alloc_storage"), + {storage_size, alignment}, Attrs(storage_attr)), + "storage"); + + ShapeExpr offset = ShapeExpr({IntImm(DataType::Int(64), 0)}); + auto tensor_attr = make_object(); + tensor_attr->dtype = DataType::Float(32); + Expr shape = call->args[0]; + return builder_->Emit( + Call(ExternFunc("vm.builtin.alloc_tensor"), {storage, offset, shape}, Attrs(tensor_attr)), + "tensor"); } return GetRef(call); } -}; -Expr ExplicitMemRewrite(const Expr& e) { - return ExplicitMemMutator().Mutate(e); -} + private: + IRModule mod_; +}; -TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite") -.set_body_typed([](Expr expr) { - return ExplicitMemRewrite(expr); +TVM_REGISTER_GLOBAL("relax.transform.memory_lower").set_body_typed([](IRModule mod) { + return MemLowerMutator(mod).Lower(); }); } // 
namespace relax diff --git a/src/relax/transform/shape_lower.cc b/src/relax/transform/shape_lower.cc index c251f8ad883b..d2d15f05c7e1 100644 --- a/src/relax/transform/shape_lower.cc +++ b/src/relax/transform/shape_lower.cc @@ -42,19 +42,19 @@ class ShapeLowerMutator : public ExprMutator { IRModule Lower() { ret_mod_ = IRModule(); for (auto& p : mod_->functions) { - if (!p.second->IsInstance()) { - continue; + Expr func = p.second; + if (p.second->IsInstance()) { + // prepare mapping and heap var + expr2slot_ = PrepareExpr2Slot(Downcast(func)); + // LOG(INFO) << "mapping: " << expr2slot_; + heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); + DynTensorType heap_type(1, ShapeDType()); + shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); + + // mutate + func = this->Mutate(func); } - // prepare mapping and heap var - expr2slot_ = PrepareExpr2Slot(Downcast(p.second)); - // LOG(INFO) << "mapping: " << expr2slot_; - heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); - DynTensorType heap_type(1, ShapeDType()); - shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); - - // mutate - Expr new_func = this->Mutate(p.second); - ret_mod_->Add(p.first, Downcast(new_func)); + ret_mod_->Add(p.first, Downcast(func)); } return ret_mod_; } @@ -72,6 +72,9 @@ class ShapeLowerMutator : public ExprMutator { } Expr VisitExpr_(const ShapeExprNode* node) override { + if (IsConstantShape(GetRef(node))) { + return ExprMutator::VisitExpr_(node); + } tir::PrimFunc func = CalculateShape(GetRef(node)); std::string shape_func_name = name_table_->GetUniqueName("shape_func"); func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name)); @@ -111,7 +114,8 @@ class ShapeLowerMutator : public ExprMutator { new_body = seq->body; } - builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv"); + // FIXME(@yuchen): Implement vm.builtin.free_shape_heap. + // builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv"); blocks.push_back(builder_->EndBlock()); new_body = SeqExpr(blocks, new_body); @@ -171,6 +175,15 @@ class ShapeLowerMutator : public ExprMutator { return ret; } + bool IsConstantShape(ShapeExpr shape) const { + for (PrimExpr e : shape->values) { + if (!e->IsInstance()) { + return false; + } + } + return true; + } + private: IRModule mod_; IRModule ret_mod_; diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc new file mode 100644 index 000000000000..9e5bc6caefae --- /dev/null +++ b/src/relax/transform/to_non_dataflow.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file src/relax/transform/to_non_dataflow.cc + * \brief + */ +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ToNonDFMutator : public ExprMutator { + public: + explicit ToNonDFMutator(IRModule mod) { mod_ = mod; } + + IRModule Lower() { + IRModule ret_mod = IRModule(); + for (auto& p : mod_->functions) { + Expr func = p.second; + if (p.second->IsInstance()) { + func = this->Mutate(p.second); + } + ret_mod->Add(p.first, Downcast(func)); + } + return ret_mod; + } + + Expr VisitExpr_(const DataflowVarNode* op) final { + auto it = var_remap_.find(GetRef(op)); + if (it != var_remap_.end()) { + return it->second; + } + Var new_var(op->vid, op->shape(), op->type_annotation, op->span); + return new_var; + } + + BindingBlock VisitDataflowBlock(const DataflowBlock& block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + private: + IRModule mod_; +}; + +TVM_REGISTER_GLOBAL("relax.transform.to_non_dataflow").set_body_typed([](IRModule mod) { + return ToNonDFMutator(mod).Lower(); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index 2dcfa05f1916..97de0f2370f2 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -46,11 +46,6 @@ TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap") return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0}); }); -TVM_REGISTER_GLOBAL("vm.builtin.free_shape_heap") -.set_body_typed([](NDArray arr) { - return static_cast(const_cast(arr.get()))->DecRef(); -}); - TVM_REGISTER_GLOBAL("vm.builtin.decode_shape") .set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) { int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); @@ -74,10 +69,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_shape") }); TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") -.set_body_typed([](void* vm_state_ptr, Index size, Index alignment, Index device_type, +.set_body_typed([](void* vm_state_ptr, ShapeTuple buffer_size, ShapeTuple alignment, Index device_type, DLDataType dtype_hint) { + ICHECK_EQ(buffer_size.size(), 1); + ICHECK_EQ(alignment.size(), 1); VMState* vm_state = static_cast(vm_state_ptr); - DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment + int64_t size_imm = buffer_size[0]; + int64_t align_imm = alignment[0]; + DLOG(INFO) << "AllocStorage: allocation_size=" << size_imm << ", alignment=" << align_imm << ", dtype_hint=" << runtime::DLDataType2String(dtype_hint) << ", device_type=" << device_type; @@ -86,14 +85,16 @@ TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") << "Memory allocator for device " << device_type << " has not been initialized"; auto* alloc = vm_state->allocators[device_type]; ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; - storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint); + storage_obj->buffer = alloc->Alloc(size_imm, align_imm, dtype_hint); Storage storage(storage_obj); return storage; }); TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor") -.set_body_typed([](Storage storage, Index offset, DLDataType dtype, ShapeTuple shape) { - auto tensor = storage->AllocNDArray(offset, shape, dtype); +.set_body_typed([](Storage storage, ShapeTuple offset, ShapeTuple shape, DLDataType dtype) { + ICHECK_EQ(offset.size(), 1); + int64_t offset_imm = offset[0]; + auto tensor = storage->AllocNDArray(offset_imm, shape, dtype); return tensor; }); diff --git 
a/src/relax/vm/compiler.cc b/src/relax/vm/compiler.cc index 6dff8259094d..595ba9255a29 100644 --- a/src/relax/vm/compiler.cc +++ b/src/relax/vm/compiler.cc @@ -78,29 +78,40 @@ class VMCompilerImpl : public ExprVisitor { builder_->EmitRet(ret_reg->second); } + // TODO: visit call node void VisitVarBinding(const VarBinding& binding) { Var var = binding->var; // TODO(@yuchen): support other nodes than Call - Call call_node = Downcast(binding->value); - if (auto* extern_func = call_node->op.as()) { - String name = extern_func->global_symbol; - if (name == "vm.builtin.alloc_storage") { - EmitAllocStorage(call_node, var); - } else if (name == "vm.builtin.alloc_tensor") { - EmitAllocTensor(call_node, var); - } else { - // Normal packed function without attributes + if (binding->value.as()){ + Call call_node = Downcast(binding->value); + if (auto* extern_func = call_node->op.as()) { + String name = extern_func->global_symbol; + if (name == "vm.builtin.alloc_storage") { + EmitAllocStorage(call_node, var); + } else if (name == "vm.builtin.alloc_tensor") { + EmitAllocTensor(call_node, var); + } else { + // Normal packed function without attributes + std::vector args = ConvertArgs(call_node); + // TODO(@yuchen): what if the packed func has void return (no need to write to the dst + // register)? + builder_->EmitCall(name, args, NewRegister(var)); + } + } else if (auto* gvar = call_node->op.as()) { + String name = gvar->name_hint; std::vector args = ConvertArgs(call_node); - // TODO(@yuchen): what if the packed func has void return (no need to write to the dst - // register)? + // TODO: global_var mangling builder_->EmitCall(name, args, NewRegister(var)); + } else { + LOG(FATAL) << "TODO: support compiling everything other than extern functions."; } - } else if (auto* gvar = call_node->op.as()) { - String name = gvar->name_hint; - std::vector args = ConvertArgs(call_node); - builder_->EmitCall(name, args, NewRegister(var)); + } else if (const VarNode* var_node = binding->value.as()) { + const Var& rhs_var = GetRef(var_node); + auto rhs_var_reg = this->var_register_map_.find(rhs_var); + ICHECK(rhs_var_reg != this->var_register_map_.end()); + this->var_register_map_.insert({var, rhs_var_reg->second}); } else { - LOG(FATAL) << "TODO: support compiling everything other than extern functions."; + LOG(FATAL) << "TODO: support compiling everything other than Call and Var."; } } @@ -112,13 +123,12 @@ class VMCompilerImpl : public ExprVisitor { ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; DataType dtype = alloc_attrs->dtype; int device_type = alloc_attrs->device_type; - PrimExpr size = Downcast(call_node->args[0])->values[0]; - PrimExpr alignment = Downcast(call_node->args[1])->values[0]; std::vector args; args.push_back(Instruction::Arg(Instruction::kVMStateRegister)); - args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast(size)->value)); - args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast(alignment)->value)); + for (Expr arg: call_node->args) { + args.push_back(ConvertArg(arg)); + } args.push_back(Instruction::Arg(Instruction::kImmediate, device_type)); // store dtype in constant pool @@ -139,12 +149,9 @@ class VMCompilerImpl : public ExprVisitor { DataType dtype = alloc_attrs->dtype; std::vector args; - auto storage_reg = this->var_register_map_.find(Downcast(call_node->args[0])); - ICHECK(storage_reg != this->var_register_map_.end()); - args.push_back(Instruction::Arg(Instruction::kRegister, storage_reg->second)); - - PrimExpr offset = 
Downcast(call_node->args[1])->values[0]; - args.push_back(Instruction::Arg(Instruction::kImmediate, Downcast(offset)->value)); + for (Expr arg: call_node->args) { + args.push_back(ConvertArg(arg)); + } // store dtype in constant pool TVMRetValue data_type; @@ -152,19 +159,6 @@ class VMCompilerImpl : public ExprVisitor { Index index = builder_->EmitConstant(data_type); args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); - // TODO(@yuchen, @ziheng): support symbolic shape when connecting with shape lowering - // store shape in constant pool - std::vector shape; - auto shape_expr = Downcast(call_node->args[2])->values; - for (PrimExpr i : shape_expr) { - shape.push_back(Downcast(i)->value); - } - auto shape_tuple = ShapeTuple(shape); - TVMRetValue shape_tuple_value; - shape_tuple_value = shape_tuple; - index = builder_->EmitConstant(shape_tuple_value); - args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); - builder_->EmitCall("vm.builtin.alloc_tensor", args, NewRegister(var)); } @@ -174,27 +168,47 @@ class VMCompilerImpl : public ExprVisitor { return reg; } + bool IsConstantShape(ShapeExpr shape) const { + for (PrimExpr e : shape->values) { + if (!e->IsInstance()) { + return false; + } + } + return true; + } + + // TODO: recursive Expr -> instr::arg, ExprFunctor, like llvm builder + Instruction::Arg ConvertArg(Expr arg) { + if (arg->IsInstance()) { + Var var = Downcast(arg); + auto reg = this->var_register_map_.find(Downcast(arg)); + ICHECK(reg != this->var_register_map_.end()) + << var->name_hint() << "(" << var << ")" << " not in the register map."; + return Instruction::Arg(Instruction::kRegister, reg->second); + } else if (arg->IsInstance()) { + ShapeExpr sh = Downcast(arg); + ICHECK(IsConstantShape(sh)) + << "should only use constant shape after shape lowering: " + << sh->values; + std::vector shape; + for (PrimExpr e : sh->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } else { + LOG(FATAL) << "not supported argument type."; + } + return Instruction::Arg(); + } + std::vector ConvertArgs(const Call& call) { std::vector ret; - const auto& args = call->args; for (size_t i = 0; i < call->args.size(); ++i) { - if (args[i]->IsInstance()) { - auto reg = this->var_register_map_.find(Downcast(args[i])); - ICHECK(reg != this->var_register_map_.end()); - ret.push_back(Instruction::Arg(Instruction::kRegister, reg->second)); - } else if (args[i]->IsInstance()) { - std::vector shape; - for (PrimExpr e : Downcast(args[i])->values) { - shape.push_back(Downcast(e)->value); - } - auto shape_tuple = ShapeTuple(shape); - TVMRetValue shape_tuple_value; - shape_tuple_value = shape_tuple; - Index index = builder_->EmitConstant(shape_tuple_value); - ret.push_back(Instruction::Arg(Instruction::kConstIdx, index)); - } else { - LOG(FATAL) << "not supported argument type."; - } + ret.push_back(ConvertArg(call->args[i])); } return ret; } diff --git a/src/relax/vm/executable.cc b/src/relax/vm/executable.cc index 1154f43ab249..bbb92422aa45 100644 --- a/src/relax/vm/executable.cc +++ b/src/relax/vm/executable.cc @@ -42,6 +42,7 @@ constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; enum ConstantType : int { kNDArray = 0, kDLDataType = 1, + kShapeTuple = 2, }; #define STREAM_CHECK(val, section) \ @@ -239,11 +240,17 @@ void 
ExecutableNode::SaveGlobalSection(dmlc::Stream* strm) { void ExecutableNode::SaveConstantSection(dmlc::Stream* strm) { strm->Write(static_cast(this->constants.size())); - std::vector arrays; for (const auto& it : this->constants) { if (it.IsObjectRef()) { strm->Write(ConstantType::kNDArray); runtime::SaveDLTensor(strm, it.operator DLTensor*()); + } else if (it.IsObjectRef()){ + ShapeTuple shape = it.operator ShapeTuple(); + strm->Write(ConstantType::kShapeTuple); + strm->Write(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + strm->Write(shape.at(i)); + } } else { try { strm->Write(ConstantType::kDLDataType); @@ -293,6 +300,16 @@ void ExecutableNode::LoadConstantSection(dmlc::Stream* strm) { TVMRetValue cell; cell = ndarray; this->constants.push_back(cell); + } else if (constant_type == ConstantType::kShapeTuple) { + size_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = ShapeTuple(data); + this->constants.push_back(cell); } else if (constant_type == ConstantType::kDLDataType) { strm->Read(&dtype); TVMRetValue cell; diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 8479467f1aa5..dbbb6cbe6bac 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -62,27 +62,65 @@ def test_fma_rewrite(): assert type(func.body.blocks[0].bindings[1].var) == rx.Var -def test_explicit_memory_rewrite(): - m = tir.Var("m", "int32") - n = tir.Var("n", "int32") - shape_anno = [m, n] - type_anno = rx.DynTensorType(2, "float32") - x = rx.Var("x", shape_anno, type_anno) - ib = rx.BlockBuilder() - with ib.function(x): - with ib.dataflow() as df: - gv0 = ib.emit_output(rx.call_dps([m, n], rx.extern("test.op.identity"), [x])) - ib.emit_func_output(gv0) - expr = ib.get() +def test_to_non_dataflow(): + @rx.script + class TestToNoneDataflow: + def foo(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + gv0 = relax.call_dps((m, n), "test.op.identity", (x,)) + gv1 = relax.call_dps((m, n), "test.op.identity", (gv0,)) + relax.output(gv1) + return gv1 + + mod = TestToNoneDataflow() + + old_vars = [] + def fvisit(e): + if isinstance(e, rx.Var): + nonlocal old_vars + old_vars.append(e) + rx.analysis.post_order_visit(mod["foo"], fvisit) + _, x, _, gv0, _, gv1 = old_vars + + new_mod = rx.transform.to_non_dataflow(mod) + + new_vars = [] + def fvisit(e): + if isinstance(e, rx.Var): + nonlocal new_vars + new_vars.append(e) + rx.analysis.post_order_visit(new_mod["foo"], fvisit) + + assert x == new_vars[1] + assert gv0 != new_vars[3] + assert isinstance(gv0, rx.DataflowVar) + assert not isinstance(new_vars[3], rx.DataflowVar) + + assert isinstance(gv1, rx.Var) + assert isinstance(new_vars[5], rx.Var) + assert gv1 != new_vars[5] + + +def test_call_dps_rewrite(): + @rx.script + class TestCallDpsRewrite: + def foo(x: Tensor[(m, n), "float32"]): + gv0 = relax.call_dps((m, n), "test.op.identity", (x,)) + return gv0 + + mod = TestCallDpsRewrite() + code = rx.parser.astext(mod) # before rewrite - v0 = expr.body.blocks[0].bindings[0].var - s0 = expr.body.blocks[0].bindings[0].value + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value assert isinstance(s0, tvm.relay.Call) assert s0.op.name == "relax.call_dps" # after rewrite - func = rx.transform.explicit_memory_rewrite(expr) + new_mod = rx.transform.call_dps_rewrite(mod) + func = new_mod["foo"] + code = rx.parser.astext(new_mod) # the dataflow 
block has changed to binding block due to the rewriting block = func.body.blocks[0] @@ -92,20 +130,41 @@ def test_explicit_memory_rewrite(): assert isinstance(s1, tvm.relay.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], rx.ShapeExpr) - assert structural_equal(s1.args[0], rx.ShapeExpr(shape_anno)) + assert structural_equal(s1.args[0], s0.args[0]) s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" -@rx.script -class Mod: - def foo(x: Tensor[_, "float32"]) -> Shape: - sh = relax.call_packed("vm.builtin.shape_of", x) - relax.match_shape(sh, (n, m)) - return (n * 2, m * 3) +def test_memory_lower(): + @rx.script + class TestMemoryLower: + def foo(x: Tensor[(m, n), "float32"]): + alloc = relax.builtin.alloc_tensor((m, n)) + _ = relax.call_packed("test.op.identity", (x,), alloc) + gv0 = alloc + return gv0 + + mod = TestMemoryLower() + + # after memory lowering + new_mod = rx.transform.memory_lower(mod) + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(new_mod["foo"], tvm.relax.expr.Function) + code = rx.parser.astext(new_mod) + assert "vm.builtin.alloc_storage" in code + assert "vm.builtin.alloc_tensor" in code + def test_shape_lowering(): - mod = Mod() + @rx.script + class TestShapeLower: + def foo(x: Tensor[_, "float32"]) -> Shape: + sh = relax.call_packed("vm.builtin.shape_of", x) + relax.match_shape(sh, (n, m)) + return (n * 2, m * 3) + + mod = TestShapeLower() new_mod = rx.transform.shape_lower(mod) assert isinstance(new_mod, tvm.IRModule) assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) @@ -118,5 +177,7 @@ def test_shape_lowering(): if __name__ == "__main__": test_fma_rewrite() - test_explicit_memory_rewrite() + test_to_non_dataflow() + test_call_dps_rewrite() + test_memory_lower() test_shape_lowering() diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 52e5f9f0dc3b..0f83bfa73394 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from __future__ import annotations # must import to defer parsing of annotations +import os import numpy as np import tvm from tvm.relay import Call @@ -40,6 +41,10 @@ def mul(a, b): def identity_packed(a, b): b[:] = tvm.nd.array(a.asnumpy()) +@tvm.register_func("test.vm.tile") +def tile_packed(a, b): + b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2))) + def test_vm_execute(): ib = rx.ExecBuilder() with ib.function("func0", num_inputs=2): @@ -79,28 +84,29 @@ def test_vm_serialize(): ib.emit_call("test.vm.mul", args=[ib.r(0), arr], dst=ib.r(1)) ib.emit_ret(ib.r(1)) exec0 = ib.get() - exec0.save_to_file("exec.bin") - exec1 = rx.load_exec_from_file("exec.bin") + exec0.save_to_file("exec.tmp") + exec1 = rx.load_exec_from_file("exec.tmp") assert exec0.astext() == exec1.astext() + os.remove("exec.tmp") def test_vm_constant_serialize(): dtype = tvm.DataType('float32') - shape = (3, 4) - shape_tuple = container.ShapeTuple(shape) - input = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) + shape = (4, 6) + inp = tvm.nd.array(np.random.rand(4, 6).astype(np.float32)) ib = rx.ExecBuilder() with ib.function("main", num_inputs=1): - ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ib.imm(24), ib.imm(64), ib.imm(1), dtype], dst=ib.r(1)) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), dtype, ib.r(0)], dst=ib.r(2)) - ib.emit_call("test.vm.identity", args=[input, ib.r(2)]) + ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), (8,), ib.imm(1), dtype], dst=ib.r(1)) + ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), (0,), shape, dtype], dst=ib.r(2)) + ib.emit_call("test.vm.identity", args=[ib.r(0), ib.r(2)]) ib.emit_ret(ib.r(2)) exec0 = ib.get() - exec0.save_to_file("exec.bin") - exec1 = rx.load_exec_from_file("exec.bin") + exec0.save_to_file("exec.tmp") + exec1 = rx.load_exec_from_file("exec.tmp") assert exec0.astext() == exec1.astext() - vm = rx.VirtualMachine(exec1, tvm.cpu()) - res = vm["main"](shape_tuple) - np.testing.assert_allclose(input.asnumpy(), res.asnumpy()) + vm = rx.VirtualMachine(exec0, tvm.cpu()) + res = vm["main"](inp) + np.testing.assert_allclose(inp.asnumpy(), res.asnumpy()) + os.remove("exec.tmp") def test_vm_checker(): ib = rx.ExecBuilder() @@ -169,34 +175,30 @@ def test_vm_shapeof(): def test_vm_storage(): + dtype = tvm.DataType('float32') + shape = (4, 6) ib = rx.ExecBuilder() - with ib.function("main", num_inputs=7): - ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ib.r(0), ib.r(1), ib.r(2), ib.r(3)], dst=ib.r(7)) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(7), ib.r(4), ib.r(5), ib.r(6)], dst=ib.r(8)) - ib.emit_ret(ib.r(8)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), (8,), ib.imm(1), dtype], dst=ib.r(1)) + ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), (0,), shape, dtype], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) ex = ib.get() vm = rx.VirtualMachine(ex, tvm.cpu()) - dtype = tvm.DataType('float32') - cpu_dev = tvm.cpu().device_type - buffer_size = 24 - alignment = 8 - offset = 0 - shape = (32, 16) shape_tuple = container.ShapeTuple(shape) - res = vm["main"](buffer_size, alignment, cpu_dev, dtype, offset, dtype, shape_tuple) + res = vm["main"]() assert res.device == tvm.cpu() assert res.shape == shape def test_vm_compile_stage0(): @rx.script - class Mod: + class TestVMCompileStage0: def foo(x: Tensor[(3, 4), "float32"]): y = relax.call_packed("vm.builtin.alloc_storage", (12,), (64,), device_id=0, device_type=1, 
attrs_type_key="relax.attrs.AllocStorageAttrs") z = relax.call_packed("vm.builtin.alloc_tensor", y, (0,), (3, 4), attrs_type_key="relax.attrs.AllocTensorAttrs") w = relax.call_packed("test.vm.identity", x, z) return z - mod = Mod() + mod = TestVMCompileStage0() target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm") ex, lib = rx.vm.build(mod, target, target_host) @@ -204,11 +206,12 @@ def foo(x: Tensor[(3, 4), "float32"]): vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) res = vm["foo"](inp) np.testing.assert_allclose(inp.asnumpy(), res.asnumpy()) + res = vm["foo"](inp) def test_vm_compile_stage1(): @rx.script - class Mod1: + class TestVMCompileStage1: @tvm.script.tir def shape_func0(heap: ty.handle) -> None: # function attr dict @@ -226,7 +229,7 @@ def foo(x: Tensor[_, "float32"]) -> Shape: gv3 = relax.call_packed("vm.builtin.make_shape", shape_heap, (2, 3)) return gv3 - mod = Mod1() + mod = TestVMCompileStage1() code = rx.parser.astext(mod) target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm") @@ -242,19 +245,16 @@ def foo(x: Tensor[_, "float32"]) -> Shape: def test_vm_compile_stage2(): @rx.script - class Mod2: + class TestVMCompileStage2: def foo(x: Tensor[_, "float32"]) -> Shape: sh = relax.call_packed("vm.builtin.shape_of", x) relax.match_shape(sh, (n, m)) return (n * 2, m * 3) - mod = Mod2() - code = rx.parser.astext(mod) - new_mod = rx.transform.shape_lower(mod) - code = rx.parser.astext(new_mod) + mod = TestVMCompileStage2() target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm") - ex, lib = rx.vm.build(new_mod, target, target_host) + ex, lib = rx.vm.build(mod, target, target_host) vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) shape = (32, 16) @@ -264,6 +264,51 @@ def foo(x: Tensor[_, "float32"]) -> Shape: assert res[1] == shape[1] * 3 +def test_vm_compile_stage3(): + @rx.script + class TestVMCompileStage3: + def foo(x: Tensor[(32, 16), "float32"]) -> Tensor: + with relax.dataflow(): + y = relax.call_dps((32, 16), "test.vm.identity", (x)) + relax.output(y) + return y + + mod = TestVMCompileStage3() + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = rx.vm.build(mod, target, target_host) + vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + np.testing.assert_allclose(inp.asnumpy(), res.asnumpy()) + + +def test_vm_compile_e2e(): + @rx.script + class TestVMCompileE2E: + def foo(x: Tensor[_, "float32"]) -> Tensor: + with relax.dataflow(): + sh = relax.call_packed("vm.builtin.shape_of", x) + x0 = relax.match_shape(sh, (n, m)) + y = relax.call_dps((n, m * 2), "test.vm.tile", (x)) + relax.output(y) + return y + + mod = TestVMCompileE2E() + + target = tvm.target.Target("llvm") + target_host = tvm.target.Target("llvm") + ex, lib = rx.vm.build(mod, target, target_host) + vm = rx.VirtualMachine(ex, tvm.cpu(), mod=lib) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), res.asnumpy()) + + if __name__ == "__main__": test_vm_execute() test_vm_multiple_func() @@ -277,3 +322,5 @@ def foo(x: Tensor[_, "float32"]) -> Shape: test_vm_compile_stage0() test_vm_compile_stage1() test_vm_compile_stage2() + test_vm_compile_stage3() + test_vm_compile_e2e()