From 4db04b71592ae3095ad2c1976f5d7cd59b6a0781 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 14 Mar 2020 14:26:23 -0700 Subject: [PATCH] [TIR] Introduce tir::PrimFunc (#5070) This PR introduces tir::PrimFunc which will be used as the TIR function container in the unified IR. Also streamlined the function attributes a bit further. - All common attributes are under tvm::attr - TIR specific attributes are under tvm::tir::attr and comes with a tir prefix - Use stl_style for attributes for now --- include/tvm/ir/function.h | 99 ++++++++++ include/tvm/ir/type.h | 7 + include/tvm/relay/function.h | 27 --- include/tvm/tir/function.h | 177 ++++++++++++++++++ include/tvm/tir/op.h | 12 ++ python/tvm/ir/__init__.py | 3 +- python/tvm/ir/expr.py | 9 - python/tvm/ir/function.py | 28 +++ python/tvm/relay/expr.py | 3 +- python/tvm/tir/__init__.py | 2 + python/tvm/tir/function.py | 86 +++++++++ src/ir/function.cc | 1 + src/printer/relay_text_printer.cc | 13 +- src/relay/ir/function.cc | 18 +- src/tir/ir/function.cc | 91 +++++++++ src/tir/ir/op.cc | 12 ++ .../{test_lang_basic.py => test_tir_nodes.py} | 24 ++- 17 files changed, 553 insertions(+), 59 deletions(-) create mode 100644 include/tvm/tir/function.h create mode 100644 python/tvm/ir/function.py create mode 100644 python/tvm/tir/function.py create mode 100644 src/tir/ir/function.cc rename tests/python/unittest/{test_lang_basic.py => test_tir_nodes.py} (92%) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 4cb5d7018578..db7f4465f2a3 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -32,6 +32,36 @@ namespace tvm { +/*! + * \brief Possible Calling conventions. + * + * NOTE: The calling convention also implies + * the way we implement the function during lowering. + */ +enum class CallingConv : int { + /*! + * \brief Default calling convetion. + * + * - Uses the native calling convention of the target. + * - Implementation: specified by the native target. + */ + kDefault = 0, + /*! + * \brief Device kernel launch + * + * - Call by PackedFunc calling convention. + * - Implementation: defined by device runtime(e.g. runtime/cuda) + */ + kDeviceKernelLaunch = 2, + /*! + * \brief PackedFunc that exposes a CPackedFunc signature. + * + * - Calling by PackedFunc calling convention. + * - Implementation: Expose a function with the CPackedFunc signature. + */ + kCPackedFunc = 3, +}; + /*! * \brief Base node of all functions. * @@ -115,5 +145,74 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; +/*! + * \brief Create a new function that copies func, but overrides + * the attribute value key with the value. + * + * \param func The input function. + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function type. + * + * \returns The new function with updated attributes. + * + * \note This function performs copy on write optimization for func. + * If we move a uniquely referenced func into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template::value>::type> +inline TFunc WithAttr(TFunc func, + const std::string& attr_key, + ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = func.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return func; +} + +/*! + * \brief Generic attribute names that can be attached to any function. + * + * \sa tvm::tir::attr, tvm::relay::attr + */ +namespace attr { +/*! + * \brief Indicates the special calling convention. + * + * Type: Integer + * + * \sa tvm::CallingConv + */ +constexpr const char* kCallingConv = "calling_conv"; + +/*! + * \brief Compilation target of the function. + * + * Type: Target + * + * \sa tvm::Target + */ +constexpr const char* kTarget = "target"; +} // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 9e87731dae72..7fd224b7c4a9 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -276,6 +276,13 @@ class TupleType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); }; +/*! + * \return a type that represents void. + */ +inline Type VoidType() { + return TupleType::Empty(); +} + /*! * \brief Potential Constraints in a function. * \sa TypeConstraint diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 27aa2e830aea..f7514c7685e6 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -114,33 +114,6 @@ class Function : public BaseFunc { TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); }; -/*! - * \brief Create a new function that copies func, but overrides - * the attribute value key with the value. - * - * \param func The input function. - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \returns The new function with updated attributes. - * - * \note This function performs copy on write optimization for func. - * If we move a uniquely referenced func into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value); - /*! * \brief namespace of the attributes that can be attached to a relay::Function. */ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h new file mode 100644 index 000000000000..06802671db91 --- /dev/null +++ b/include/tvm/tir/function.h @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/function.h + * \brief TIR Function. + */ +#ifndef TVM_TIR_FUNCTION_H_ +#define TVM_TIR_FUNCTION_H_ + +#include +#include +#include +#include +#include + + +namespace tvm { +namespace tir { + +/*! + * \brief Primitive functions that contains TIR statements. + * + * The PrimFunc provides low-level code representation does not + * automatically manage + * + * \sa PrimFunc + */ +class PrimFuncNode : public BaseFuncNode { + public: + /*! \brief Function parameters */ + Array params; + /*! \brief The body of the function */ + tir::Stmt body; + /*! \brief The return type of the function. */ + Type ret_type; + /*! + * \brief Maps some parameters to specific Buffer data structures. + * + * buffer_map provides a way to express data structure's field and shape + * constraints. The provided information is used in the program analysis + * and the code generation. + * + * - It defines the vars in the Buffer (m, n) in the cases below when + * they appears in the buffer_map for the first time. + * - When a var appears multiple times, they translate into runtime + * assertion to check the field constraint. + * + * \code + * + * # The corresponding fields of f are as follows + * # + * # - f.params = [a, b] + * # - f.buffer_map = {a: A, b: B} + * # - A = decl_buffer(shape=[m, n]) + * # - B = decl_buffer(shape=[m, n]) + * + * def f(a, b): + * m, n = var(), var() + * A = bind_buffer(a, shape=[m, n]) + * B = bind_buffer(b, shape=[m, n]) + * # body + * + * \endcode + * + * buffer_map is a sugar to express: + * - Parameter unpacking: e.g. I can load a.shape[0] to get value of m + * - Constraint checking: a.shape[0] must equal b.shape[0] because they + * both corresponds to m. + + * While we could have express parameter unpacking and constraint using + * normal statements, making buffer_map as first class citizen of PrimFunc + * will make program analysis much easier. + * + * \note This field can be nullptr + */ + Map buffer_map; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + v->Visit("ret_type", &ret_type); + v->Visit("buffer_map", &buffer_map); + v->Visit("attrs", &attrs); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + /*! + * \brief Return the derived function annotation of this function. + * + * \return The function type annotation. + * \note The function type annotation of PrimExpr is + * directly derived from the Vars without the need of type inference. + */ + TVM_DLL FuncType func_type_annotation() const; + + static constexpr const char* _type_key = "tir.PrimFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); +}; + +/*! + * \brief Managed reference to PrimFuncNode. + * \sa PrimFuncNode + */ +class PrimFunc : public BaseFunc { + public: + /*! + * \brief Constructor + * \param params The parameters of the function. + * \param body The body of the function. + * \param ret_type The return type of the function. + * \param buffer_map The buffer map for parameter buffer unpacking. + * \param attrs Additional function attributes. + */ + TVM_DLL PrimFunc(Array params, + Stmt body, + Type ret_type = VoidType(), + Map buffer_map = NullValue>(), + DictAttrs attrs = NullValue()); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); +}; + +/*! + * \brief PrimFunc specific attribute names. + * + * \sa tvm::attr + */ +namespace attr { +/*! + * \brief List of thread IterVar that a DeviceLaunch function corresponds to. + * + * Type: Array + * + * We call a device kernel launch function f using the following convention: + * + * Call(f, + * [arg1, arg2, ..., arg_n, + * work_size_1, work_size_2, ... work_size_m]) + * + * Here n = len(arg), m = len(work_size) = len(device_thread_axis). + * + * The list of device_thread_axis indicates how can be bind the + * work_size arguments to the corresponding threads. + * + * \sa tvm::CallingConv::kDeviceKernelLaunch + */ +constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; + +/*! + * \brief Whether to set noalias rule on the function arguments. + * + * Type: Integer + */ +constexpr const char* kNoAlias = "tir.noalias"; +} // namespace attr +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_FUNCTION_H_ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index a30c3c989322..6ee506350ba7 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_OP_H_ #define TVM_TIR_OP_H_ +#include #include #include @@ -37,6 +38,7 @@ namespace tvm { + // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. // It is also necessary to overload operators for PrimExpr. @@ -44,6 +46,16 @@ namespace tvm { // We put more developer oriented APIs -- make_const and is_const under tir // as they are more specific to the tir namespace. +/*! + * \brief Get the type of the expression under the unified type system. + * + * This function could return a more refined type than + * the runtime type provided by expr->dtype + * + * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + */ +TVM_DLL Type GetType(const PrimExpr& expr); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index a718124d2116..416032634ce9 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -21,7 +21,8 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation -from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range +from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range +from .function import BaseFunc from .adt import Constructor, TypeData from .module import IRModule from .attrs import Attrs, DictAttrs, make_node diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 00ceb5bd4623..4e6bf16f7545 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -51,15 +51,6 @@ def checked_type(self): return ret -class BaseFunc(RelayExpr): - """Base class of all functions.""" - @property - def attrs(self): - """Return the attrs member of the function. - """ - return _ffi_api.BaseFunc_Attrs(self) - - @tvm._ffi.register_object("relay.GlobalVar") class GlobalVar(RelayExpr): """A global variable in the IR. diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py new file mode 100644 index 000000000000..70eb51a093d3 --- /dev/null +++ b/python/tvm/ir/function.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Function defintiions.""" +from .expr import RelayExpr +from . import _ffi_api + + +class BaseFunc(RelayExpr): + """Base class of all functions.""" + @property + def attrs(self): + """Return the attrs member of the function. + """ + return _ffi_api.BaseFunc_Attrs(self) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index a3c625173f4e..61a5fb7c63ba 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -282,7 +282,8 @@ def with_attr(self, attr_key, attr_value): func : Function A new copy of the function """ - return _expr.FunctionWithAttr(self, attr_key, attr_value) + return _expr.FunctionWithAttr( + self, attr_key, convert(attr_value)) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index fa244ac72103..b8a56f8ef7e8 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -31,6 +31,8 @@ from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list +from .function import PrimFunc + from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10 diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py new file mode 100644 index 000000000000..37946f66b1bb --- /dev/null +++ b/python/tvm/tir/function.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Function data types.""" + +import tvm._ffi +import tvm.runtime +from tvm.ir import BaseFunc +from .buffer import Buffer +from .expr import Var +from . import _ffi_api + + +@tvm._ffi.register_object("tir.PrimFunc") +class PrimFunc(BaseFunc): + """A function declaration expression. + + Parameters + ---------- + params: List[Union[tvm.tir.Var, tvm.tir.Buffer]] + List of input parameters to the function. + + body: tvm.tir.Stmt + The body of the function. + + ret_type: tvm.ir.Type + The return type annotation of the function. + + buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] + The buffer binding map. + + attrs: Optional[tvm.Attrs] + Attributes of the function, can be None + """ + def __init__(self, + params, + body, + ret_type=None, + buffer_map=None, + attrs=None): + param_list = [] + buffer_map = {} if buffer_map is None else buffer_map + for x in params: + if isinstance(x, Buffer): + var = Var(x.name, dtype="handle") + param_list.append(var) + buffer_map[var] = x + elif isinstance(x, Var): + param_list.append(x) + else: + raise TypeError("params can only contain Var or Buffer") + + self.__init_handle_by_constructor__( + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) + + def with_attr(self, attr_key, attr_value): + """Create a new copy of the function and update the attribute + + Parameters + ---------- + attr_key : str + The attribute key to use. + + attr_value : Object + The new attribute value. + + Returns + ------- + func : Function + A new copy of the function + """ + return _ffi_api.PrimFuncWithAttr( + self, attr_key, tvm.runtime.convert(attr_value)) diff --git a/src/ir/function.cc b/src/ir/function.cc index d3753d8ffb64..e7ccbbe73e7b 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -30,4 +30,5 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs") .set_body_typed([](BaseFunc func) { return func->attrs; }); + } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 2799be0896cf..56e77b72ed8e 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -99,7 +99,11 @@ class RelayTextPrinter : } Doc PrintFinal(const ObjectRef& node) { - if (node.as()) { + if (node->IsInstance() && + !node->IsInstance()) { + // Temporarily skip non-relay functions. + // TODO(tvm-team) enhance the code to work for all functions + } else if (node.as()) { Expr expr = Downcast(node); dg_ = DependencyGraph::Create(&arena_, expr); } @@ -122,7 +126,10 @@ class RelayTextPrinter : std::vector PrintFuncAttrs(const Attrs& attrs); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { - if (node.as()) { + bool is_non_relay_func = + node->IsInstance() && + !node->IsInstance(); + if (node.as() && !is_non_relay_func) { return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as()) { return PrintType(Downcast(node), meta); @@ -134,7 +141,7 @@ class RelayTextPrinter : // default module. std::ostringstream os; os << node; - return Doc() << os.str(); + return Doc::RawText(os.str()); } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index c1bd7101adfc..63ad4ddb26d5 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -60,18 +60,6 @@ bool FunctionNode::UseDefaultCompiler() const { return !val.defined() || val->value == "default"; } -Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) { - FunctionNode* node = func.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return func; -} - - TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay._make.Function") @@ -94,9 +82,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr") .set_body_typed( - [](Function func, std::string name, ObjectRef ref) { - return WithAttr(std::move(func), name, ref); -}); + [](Function func, std::string name, ObjectRef ref) { + return WithAttr(std::move(func), name, ref); + }); } // namespace relay } // namespace tvm diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc new file mode 100644 index 000000000000..7464e3ad4370 --- /dev/null +++ b/src/tir/ir/function.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/function.cc + * \brief The function data structure. + */ +#include +#include + +namespace tvm { +namespace tir { + +PrimFunc::PrimFunc(Array params, + Stmt body, + Type ret_type, + Map buffer_map, + DictAttrs attrs) { + // Assume void-return type for now + // TODO(tvm-team) consider type deduction from body. + if (!ret_type.defined()) { + ret_type = VoidType(); + } + auto n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + n->ret_type = std::move(ret_type); + n->buffer_map = std::move(buffer_map); + n->attrs = std::move(attrs); + data_ = std::move(n); +} + +FuncType PrimFuncNode::func_type_annotation() const { + Array param_types; + for (auto param : this->params) { + param_types.push_back(GetType(param)); + } + return FuncType(param_types, ret_type, {}, {}); +} + +TVM_REGISTER_NODE_TYPE(PrimFuncNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. + auto* node = static_cast(ref.get()); + p->stream << "PrimFunc(" << node->params << ") "; + if (node->attrs.defined()) { + p->stream << "attrs=" << node->attrs; + } + p->stream << " {\n"; + p->indent += 2; + p->Print(node->body); + p->indent -= 2; + p->stream << "}\n"; +}); + + +TVM_REGISTER_GLOBAL("tir.PrimFunc") +.set_body_typed([](Array params, + Stmt body, + Type ret_type, + Map buffer_map, + DictAttrs attrs) { + return PrimFunc(params, body, ret_type, buffer_map, attrs); +}); + + +TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr") +.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) { + return WithAttr(std::move(func), name, ref); +}); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 2882fea0693b..b0736435607d 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -32,6 +32,18 @@ namespace tvm { using namespace tir; + +Type GetType(const PrimExpr& expr) { + runtime::DataType dtype = expr.dtype(); + // These types already implies the specific type. + if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { + return PrimType(dtype); + } + // TODO(tqchen): add recursive type inference for Var and Call here + // once we introduced the corresponding fields to the IR. + return PrimType(dtype); +} + // simple cast that only checks if type matches and cast inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_tir_nodes.py similarity index 92% rename from tests/python/unittest/test_lang_basic.py rename to tests/python/unittest/test_tir_nodes.py index c279194ce522..3a7985dde3fe 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -19,6 +19,7 @@ import numpy as np + def test_const(): x = tvm.tir.const(1, "int32") print(x.dtype) @@ -46,8 +47,8 @@ def test_make(): x = tvm.tir.const(1, "int32") y = te.var("x") z = x + y - assert isinstance(tvm.te.max(x, y), tvm.tir.Max) - assert isinstance(tvm.te.min(x, y), tvm.tir.Min) + assert isinstance(tvm.tir.max(x, y), tvm.tir.Max) + assert isinstance(tvm.tir.min(x, y), tvm.tir.Min) def test_ir(): @@ -111,7 +112,6 @@ def test_stmt(): tvm.tir.For.Serial, 0, x) - def test_dir(): x = te.var('x') dir(x) @@ -247,8 +247,26 @@ def test_equality_string_imm(): x == y.value x == y +def test_prim_func(): + x = te.var('x') + y = te.var('y') + b = tvm.tir.decl_buffer((x,), "float32") + stmt = tvm.tir.LetStmt( + x, 10, tvm.tir.Evaluate(x + 1)); + + func = tvm.tir.PrimFunc( + [x, y, b], stmt) + + assert func.buffer_map[func.params[2]].same_as(b) + + assert len(func.buffer_map) == 1 + f2 = func.with_attr("calling_conv", 1) + assert f2.attrs["calling_conv"].value == 1 + assert func.attrs is None + if __name__ == "__main__": + test_prim_func() test_cast() test_attr() test_const()