diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 8cf2677b9b31..74240d750694 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -25,6 +25,7 @@ #define TVM_RELAX_BLOCK_BUILDER_H_ #include +#include #include #include #include @@ -38,32 +39,6 @@ namespace relax { class BlockBuilder; -/*! - * \brief Utility data structure for generating unique names for IR construction. - */ -class NameTable { - public: - /*! - * \brief Generate a unique name with a specified prefix. - * \param prefix The name prefix. - * \return The generated name. - */ - inline std::string GetUniqueName(std::string prefix) { - std::replace(prefix.begin(), prefix.end(), '.', '_'); - std::string unique_prefix = prefix; - auto it = alloc_map_.find(prefix); - if (it != alloc_map_.end()) { - while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { - } - } - alloc_map_[unique_prefix] = 0; - return unique_prefix; - } - - private: - std::unordered_map alloc_map_; -}; - /*! * \brief A builder that provides APIs to build Relax binding blocks. */ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index c5c86dc4717c..23000fa5bbeb 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -54,7 +54,7 @@ class ShapeExprNode : public ExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("values", &values); v->Visit("shape_", &shape_); - v->Visit("checked_type_", &checked_type_); + v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -94,11 +94,11 @@ class VarNode : public ExprNode { const String& name_hint() const { return vid->name_hint; } void VisitAttrs(AttrVisitor* v) { + v->Visit("_checked_type_", &checked_type_); v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); v->Visit("span", &span); v->Visit("shape_", &shape_); - v->Visit("checked_type_", &checked_type_); } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { @@ -143,7 +143,7 @@ class DataflowVarNode : public VarNode { v->Visit("type_annotation", &type_annotation); v->Visit("span", &span); v->Visit("shape_", &shape_); - v->Visit("checked_type_", &checked_type_); + v->Visit("_checked_type_", &checked_type_); } bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { @@ -330,7 +330,7 @@ class SeqExprNode : public ExprNode { v->Visit("blocks", &blocks); v->Visit("body", &body); v->Visit("shape_", &shape_); - v->Visit("checked_type_", &checked_type_); + v->Visit("_checked_type_", &checked_type_); v->Visit("span", &span); } @@ -378,7 +378,7 @@ class FunctionNode : public BaseFuncNode { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); - v->Visit("checked_type_", &checked_type_); + v->Visit("_checked_type_", &checked_type_); v->Visit("shape_", &shape_); v->Visit("span", &span); } diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 97820bfce6ba..22a6c401f4ff 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -139,9 +139,6 @@ class ExprFunctor { /*! * \brief A simple visitor wrapper around ExprFunctor. * Recursively visit the content. - * - * ExprVisitor treats Expr as dataflow graph, - * and only visit each Expr node once. */ class ExprVisitor : public ExprFunctor { public: @@ -167,9 +164,6 @@ class ExprVisitor : public ExprFunctor { virtual void VisitMatchShape(const MatchShape& binding); virtual void VisitBindingBlock(const BindingBlock& block); virtual void VisitDataflowBlock(const DataflowBlock& block); - - protected: - std::unordered_map visit_counter_; }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -221,7 +215,7 @@ class ExprMutator : public ExprFunctor { virtual Type VisitType(const Type& t); virtual void VisitBinding(const Binding& binding); - virtual Var VisitVarBinding(const VarBinding& binding); + virtual void VisitVarBinding(const VarBinding& binding); virtual void VisitMatchShape(const MatchShape& binding); virtual BindingBlock VisitBindingBlock(const BindingBlock& block); @@ -229,11 +223,40 @@ class ExprMutator : public ExprFunctor { protected: Expr MutateWithPrologue(const Expr& expr, bool is_dataflow); - /*! \brief Look up the value binded to a var. */ + + /*! \brief Look up the value of a variable. If the variable is bound, then returns the bound + * value. Otherwise, returns the rewritten expression for the variable. + */ Expr LookupVar(Var var); - // A remapping table: pre var -> post var - std::unordered_map var_remap_; - std::unordered_map memo_; + + inline void UpdateMemo(Expr pre, Expr post) { + if (const VarNode* var = pre.as()) { + var_memo_[var->vid] = post; + } else { + expr_memo_[pre] = post; + } + } + + inline Optional LookupMemo(Expr pre) { + if (pre.as()) { + Id vid = Downcast(pre)->vid; + if (var_memo_.count(vid)) { + return var_memo_[vid]; + } + } else { + if (expr_memo_.count(pre)) { + return expr_memo_[pre]; + } + } + return NullOpt; + } + + /*! \brief Variable memoization table using Id equality */ + std::unordered_map var_memo_; + + /*! \brief Expr memoization table using pointer equality */ + std::unordered_map expr_memo_; + std::shared_ptr name_table_; BlockBuilder builder_; }; @@ -245,7 +268,7 @@ class DataflowMutator : public ExprMutator { public: void VisitBinding(const Binding& binding) final; - virtual Var VisitDataflowVarBinding(const VarBinding& binding); + virtual void VisitDataflowVarBinding(const VarBinding& binding); }; } // namespace relax diff --git a/include/tvm/relax/ir_functor.h b/include/tvm/relax/ir_functor.h index b9a17f19ef0e..be6e4537756c 100644 --- a/include/tvm/relax/ir_functor.h +++ b/include/tvm/relax/ir_functor.h @@ -19,7 +19,8 @@ /*! * \file tvm/relax/ir_functor.h - * \brief A generic visitor for traversing Relax IR nodes. + * \brief A generic functor for working with Relax IR nodes. + * \sa tvm/relax/expr_functor.h for common IR rewriting use-cases. */ #ifndef TVM_RELAX_IR_FUNCTOR_H_ #define TVM_RELAX_IR_FUNCTOR_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h new file mode 100644 index 000000000000..1b3815c81c8e --- /dev/null +++ b/include/tvm/relax/utils.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/utils.h + * \brief Utility classes and functions for working with the Relax IR. + */ +#ifndef TVM_RELAX_UTILS_H_ +#define TVM_RELAX_UTILS_H_ + +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Utility data structure for generating unique names for IR construction. + */ +class NameTable { + public: + /*! + * \brief Generate a unique name with a specified prefix. + * \param prefix The name prefix. + * \return The generated name. + */ + inline std::string GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = alloc_map_.find(prefix); + if (it != alloc_map_.end()) { + while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) { + } + } + alloc_map_[unique_prefix] = 0; + return unique_prefix; + } + + private: + std::unordered_map alloc_map_; +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_UTILS_H_ diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index ed78cd689ec7..1f1e4edcd969 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -111,7 +111,7 @@ def _traverse_expr(node): else: node_entry["inputs"].append([in_node_idx, 0, 0]) infer_out = _infer_type(node) - out_type = infer_out._checked_type_ + out_type = infer_out.checked_type_ if isinstance(out_type, TensorType): node_entry["types"].append(out_type) elif isinstance(out_type, TupleType): diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 272598f7c3be..b214a1b9dd50 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -85,7 +85,7 @@ def infer_type(node): def call_node_infer_type(node): """infer the output types of call node""" infer_out = infer_type(node) - out_type = infer_out._checked_type_ + out_type = infer_out.checked_type_ if isinstance(out_type, TensorType): types = [out_type] elif isinstance(out_type, TupleType): diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 0a06b4e89c99..f91a3c69face 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -45,7 +45,7 @@ def checked_type(self): checked_type : tvm.relay.Type The checked type. """ - ret = self.checked_type_ + ret = self._checked_type_ if ret is None: raise ValueError("The type checker has not populated the checked_type for this node") return ret diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index b2bd4a945afa..6447e030a97f 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include @@ -397,7 +397,7 @@ std::vector RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) { } } else { AttrPrinter attr_printer(&kwargs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&attr_printer); + const_cast(attrs.operator->())->VisitAttrs(&attr_printer); } return kwargs; } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 3c5a4202db60..7efa104c36bd 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -205,10 +205,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { } Expr ExprMutator::VisitExpr_(const VarNode* op) { - auto it = var_remap_.find(GetRef(op)); - if (it != var_remap_.end()) { - return it->second; - } if (op->type_annotation.defined()) { Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { @@ -220,10 +216,6 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { } Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { - auto it = var_remap_.find(GetRef(op)); - if (it != var_remap_.end()) { - return it->second; - } if (op->type_annotation.defined()) { Type type = this->VisitType(op->type_annotation.value()); if (!op->type_annotation.same_as(type)) { @@ -347,7 +339,7 @@ void ExprMutator::VisitBinding(const Binding& binding) { } } -Var ExprMutator::VisitVarBinding(const VarBinding& binding) { +void ExprMutator::VisitVarBinding(const VarBinding& binding) { Expr new_value = builder_->Normalize(this->Mutate(binding->value)); // TODO(@altanh): this probably shouldn't live here, all passes would have to make sure to do it @@ -369,7 +361,12 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) { Var new_var = Downcast(this->Mutate(binding->var)); if (!builder_->CanProveShapeEqual(new_var->shape(), new_value->shape()) || !StructuralEqual()(new_var->checked_type(), new_value->checked_type())) { - new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span); + // TODO(@altanh): use CopyOnWrite and/or type inference machinery here + if (new_var.as()) { + new_var = DataflowVar(new_var->vid, NullOpt, NullOpt, new_var->span); + } else { + new_var = Var(new_var->vid, NullOpt, NullOpt, new_var->span); + } if (new_value->shape_.defined()) { new_var->shape_ = new_value->shape_; } @@ -377,14 +374,14 @@ Var ExprMutator::VisitVarBinding(const VarBinding& binding) { if (new_value->checked_type_.defined()){ new_var->checked_type_ = new_value->checked_type_; } - } - this->var_remap_[binding->var] = new_var; + UpdateMemo(binding->var, new_var); + } - if (builder_->CurrentBlockIsDataFlow() && !binding->var.as()) { - return builder_->EmitOutput(VarBinding(new_var, new_value)); + if (builder_->CurrentBlockIsDataFlow() && !new_var.as()) { + builder_->EmitOutput(VarBinding(new_var, new_value)); } else { - return builder_->Emit(VarBinding(new_var, new_value)); + builder_->Emit(VarBinding(new_var, new_value)); } } @@ -394,6 +391,7 @@ void ExprMutator::VisitMatchShape(const MatchShape& binding) { Var new_var; if (binding->var.defined()){ new_var = Downcast(this->Mutate(binding->var)); + // TODO(@altanh, @yuchen): shape and type inference here too... } else { new_var = binding->var; } @@ -423,7 +421,16 @@ BindingBlock ExprMutator::VisitDataflowBlock(const DataflowBlock& block) { return builder_->EndBlock(); } -Expr ExprMutator::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } +Expr ExprMutator::VisitExpr(const Expr& expr) { + Optional post = LookupMemo(expr); + if (post) { + return post.value(); + } + + UpdateMemo(expr, ExprFunctor::VisitExpr(expr)); + + return LookupMemo(expr).value(); +} Expr ExprMutator::MutateWithPrologue(const Expr& expr, bool is_dataflow) { if (is_dataflow) { @@ -441,12 +448,18 @@ Expr ExprMutator::MutateWithPrologue(const Expr& expr, bool is_dataflow) { } Expr ExprMutator::LookupVar(Var var) { - auto it = var_remap_.find(var); - if (it != var_remap_.end()) { - return builder_->LookupVar(it->first); - } else { - return builder_->LookupVar(var); + // cases: + // 1. var has been rewritten to some expr (e.g. a constant) and is no longer bound + // 2. var remains bound to some expr + // 3. var is deleted, in which case this should never be called + Expr mutated_var = LookupMemo(var).value(); + if (mutated_var.as()) { + // lookup bound var in the builder + return builder_->LookupVar(Downcast(mutated_var)); } + + // return the rewritten var value + return mutated_var; } // ================== @@ -456,17 +469,17 @@ void DataflowMutator::VisitBinding(const Binding& binding) { if (binding.as()) { VarBinding var_binding = Downcast(binding); if (builder_->CurrentBlockIsDataFlow()) { - var_remap_[var_binding->var] = this->VisitDataflowVarBinding(var_binding); + this->VisitDataflowVarBinding(var_binding); } else { - var_remap_[var_binding->var] = ExprMutator::VisitVarBinding(var_binding); + ExprMutator::VisitVarBinding(var_binding); } } else { ExprMutator::VisitBinding(binding); } } -Var DataflowMutator::VisitDataflowVarBinding(const VarBinding& binding) { - return ExprMutator::VisitVarBinding(binding); +void DataflowMutator::VisitDataflowVarBinding(const VarBinding& binding) { + ExprMutator::VisitVarBinding(binding); } } // namespace relax diff --git a/src/relax/transform/shape_lower.cc b/src/relax/transform/shape_lower.cc index d2d15f05c7e1..6f58ab2df3c6 100644 --- a/src/relax/transform/shape_lower.cc +++ b/src/relax/transform/shape_lower.cc @@ -68,7 +68,7 @@ class ShapeLowerMutator : public ExprMutator { indices.push_back(idx); } builder_->Emit(Call(ExternFunc("vm.builtin.decode_shape"), - {shape, shape_heap_, ShapeExpr(indices)}), "gv"); + {shape, shape_heap_, ShapeExpr(indices)}), "_decode_shape"); } Expr VisitExpr_(const ShapeExprNode* node) override { @@ -80,7 +80,7 @@ class ShapeLowerMutator : public ExprMutator { func = WithAttr(std::move(func), "global_symbol", runtime::String(shape_func_name)); GlobalVar shape_func_var(shape_func_name); // TODO make sure shape_heap doesnt get redefined by local funcs? - builder_->Emit(Call(shape_func_var, {shape_heap_}), "gv"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_compute_shape"); ret_mod_->Add(shape_func_var, func); // construct shape diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index 9e5bc6caefae..63fea270dca6 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -45,12 +45,7 @@ class ToNonDFMutator : public ExprMutator { } Expr VisitExpr_(const DataflowVarNode* op) final { - auto it = var_remap_.find(GetRef(op)); - if (it != var_remap_.end()) { - return it->second; - } - Var new_var(op->vid, op->shape(), op->type_annotation, op->span); - return new_var; + return Var(op->vid, op->shape(), op->type_annotation, op->span); } BindingBlock VisitDataflowBlock(const DataflowBlock& block) final { diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 5c3e5ea424d1..2d850bf2383a 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -61,7 +61,7 @@ def test_match_shape() -> None: assert b0.pattern[0] == m assert b0.pattern[1] == n assert b0.var is not None - assert b0.var.checked_type_ == rx.ShapeType() + assert b0.var.checked_type == rx.ShapeType() # var1: Tensor[(m, n), "float32"] = # match_shape(var0: Tensor[_, "float32"], [m, n]) @@ -78,7 +78,7 @@ def test_match_shape() -> None: assert b1.var is not None for s0, s1 in zip(b1.var.shape, [m, n]): assert s0 == s1 - assert b1.var.checked_type_ == rx.DynTensorType(2, "float32") + assert b1.var.checked_type == rx.DynTensorType(2, "float32") def test_var_binding() -> None: diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 2c503ba488d9..aab2b4903237 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -426,7 +426,7 @@ def my_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: check_call( z_bind.value, "relax.call_dps", - [rx.ShapeExpr([B, tir.IntImm("int32", 128)]), mm_bind.var, rx.Tuple([x, y])], + [rx.ShapeExpr([B, tir.IntImm("int64", 128)]), mm_bind.var, rx.Tuple([x, y])], ) @@ -479,7 +479,7 @@ def f(x: Tensor): check_call( z_bind.value, "relax.call_dps", - [rx.ShapeExpr([tir.IntImm("int32", 10)]), rx.ExternFunc("my_extern"), rx.Tuple([x])], + [rx.ShapeExpr([tir.IntImm("int64", 10)]), rx.ExternFunc("my_extern"), rx.Tuple([x])], )