Skip to content

Commit

Permalink
[TIR] Introduce tir::PrimFunc (#5070)
Browse files Browse the repository at this point in the history
This PR introduces tir::PrimFunc which will be used as the TIR function
container in the unified IR.

Also streamlined the function attributes a bit further.
- All common attributes are under tvm::attr
- TIR specific attributes are under tvm::tir::attr and comes with a tir prefix
- Use stl_style for attributes for now
  • Loading branch information
tqchen authored Mar 14, 2020
1 parent be4e9db commit e031641
Show file tree
Hide file tree
Showing 17 changed files with 553 additions and 59 deletions.
99 changes: 99 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,36 @@

namespace tvm {

/*!
* \brief Possible Calling conventions.
*
* NOTE: The calling convention also implies
* the way we implement the function during lowering.
*/
enum class CallingConv : int {
/*!
* \brief Default calling convetion.
*
* - Uses the native calling convention of the target.
* - Implementation: specified by the native target.
*/
kDefault = 0,
/*!
* \brief Device kernel launch
*
* - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 3,
};

/*!
* \brief Base node of all functions.
*
Expand Down Expand Up @@ -115,5 +145,74 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function type.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template<typename TFunc,
typename = typename std::enable_if<
std::is_base_of<BaseFunc, TFunc>::value>::type>
inline TFunc WithAttr(TFunc func,
const std::string& attr_key,
ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";
} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
7 changes: 7 additions & 0 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ class TupleType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode);
};

/*!
* \return a type that represents void.
*/
inline Type VoidType() {
return TupleType::Empty();
}

/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
Expand Down
27 changes: 0 additions & 27 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,33 +114,6 @@ class Function : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};

/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);

/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
Expand Down
177 changes: 177 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/function.h
* \brief TIR Function.
*/
#ifndef TVM_TIR_FUNCTION_H_
#define TVM_TIR_FUNCTION_H_

#include <tvm/ir/function.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <string>


namespace tvm {
namespace tir {

/*!
* \brief Primitive functions that contains TIR statements.
*
* The PrimFunc provides low-level code representation does not
* automatically manage
*
* \sa PrimFunc
*/
class PrimFuncNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
Array<tir::Var> params;
/*! \brief The body of the function */
tir::Stmt body;
/*! \brief The return type of the function. */
Type ret_type;
/*!
* \brief Maps some parameters to specific Buffer data structures.
*
* buffer_map provides a way to express data structure's field and shape
* constraints. The provided information is used in the program analysis
* and the code generation.
*
* - It defines the vars in the Buffer (m, n) in the cases below when
* they appears in the buffer_map for the first time.
* - When a var appears multiple times, they translate into runtime
* assertion to check the field constraint.
*
* \code
*
* # The corresponding fields of f are as follows
* #
* # - f.params = [a, b]
* # - f.buffer_map = {a: A, b: B}
* # - A = decl_buffer(shape=[m, n])
* # - B = decl_buffer(shape=[m, n])
*
* def f(a, b):
* m, n = var(), var()
* A = bind_buffer(a, shape=[m, n])
* B = bind_buffer(b, shape=[m, n])
* # body
*
* \endcode
*
* buffer_map is a sugar to express:
* - Parameter unpacking: e.g. I can load a.shape[0] to get value of m
* - Constraint checking: a.shape[0] must equal b.shape[0] because they
* both corresponds to m.
* While we could have express parameter unpacking and constraint using
* normal statements, making buffer_map as first class citizen of PrimFunc
* will make program analysis much easier.
*
* \note This field can be nullptr
*/
Map<tir::Var, Buffer> buffer_map;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

/*!
* \brief Return the derived function annotation of this function.
*
* \return The function type annotation.
* \note The function type annotation of PrimExpr is
* directly derived from the Vars without the need of type inference.
*/
TVM_DLL FuncType func_type_annotation() const;

static constexpr const char* _type_key = "tir.PrimFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};

/*!
* \brief Managed reference to PrimFuncNode.
* \sa PrimFuncNode
*/
class PrimFunc : public BaseFunc {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
* \param ret_type The return type of the function.
* \param buffer_map The buffer map for parameter buffer unpacking.
* \param attrs Additional function attributes.
*/
TVM_DLL PrimFunc(Array<tir::Var> params,
Stmt body,
Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
};

/*!
* \brief PrimFunc specific attribute names.
*
* \sa tvm::attr
*/
namespace attr {
/*!
* \brief List of thread IterVar that a DeviceLaunch function corresponds to.
*
* Type: Array<tir::IterVar>
*
* We call a device kernel launch function f using the following convention:
*
* Call(f,
* [arg1, arg2, ..., arg_n,
* work_size_1, work_size_2, ... work_size_m])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
*
* The list of device_thread_axis indicates how can be bind the
* work_size arguments to the corresponding threads.
*
* \sa tvm::CallingConv::kDeviceKernelLaunch
*/
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";

/*!
* \brief Whether to set noalias rule on the function arguments.
*
* Type: Integer
*/
constexpr const char* kNoAlias = "tir.noalias";
} // namespace attr
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_FUNCTION_H_
12 changes: 12 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#ifndef TVM_TIR_OP_H_
#define TVM_TIR_OP_H_

#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

Expand All @@ -37,13 +38,24 @@


namespace tvm {

// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
// It is also necessary to overload operators for PrimExpr.
//
// We put more developer oriented APIs -- make_const and is_const under tir
// as they are more specific to the tir namespace.

/*!
* \brief Get the type of the expression under the unified type system.
*
* This function could return a more refined type than
* the runtime type provided by expr->dtype
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL Type GetType(const PrimExpr& expr);

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .function import BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
Expand Down
9 changes: 0 additions & 9 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,6 @@ def checked_type(self):
return ret


class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)


@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
Expand Down
Loading

0 comments on commit e031641

Please sign in to comment.