Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Streamline Function Attr interface. #5045

Merged
merged 4 commits into from
Mar 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,26 +277,13 @@ class BaseAttrsNode : public Object {
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};

/*! \brief Base attribute container for all attributes */
/*!
* \brief Managed reference to BaseAttrsNode.
* \sa AttrsNode, BaseAttrsNode
*/
class Attrs : public ObjectRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}

/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
return ptr();
}
/*! \brief specify container node */
using ContainerType = BaseAttrsNode;

private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(get());
}
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};

/*!
Expand All @@ -309,12 +296,7 @@ class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);

// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand All @@ -327,6 +309,23 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};

/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
*/
class DictAttrs : public Attrs {
public:
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/Consruct/Construct

* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);


TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};

// Namespace containing detail implementations
namespace detail {
Expand Down
24 changes: 0 additions & 24 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,30 +211,6 @@ class GlobalVar : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};

/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

// PrimExprs that are useful as runtime containers.
//
/*!
Expand Down
119 changes: 119 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/function.h
* \brief Function nodes.
*/
#ifndef TVM_IR_FUNCTION_H_
#define TVM_IR_FUNCTION_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <type_traits>
#include <string>


namespace tvm {

/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions share the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;

/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
} else {
return default_value;
}
}

/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
}

static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
1 change: 1 addition & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>

#include <string>
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
Expand Down
131 changes: 0 additions & 131 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,113 +164,6 @@ class Var : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};

/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
/*! \brief Function container */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;

/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;

/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function is marked as inline.
*
* \return Whether the function should be inlined or not.
*/
bool IsMarkedInline() const;

/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());

/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;

/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};

class Function : public BaseFunc {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
};


TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);

/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
Expand Down Expand Up @@ -550,30 +443,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};


/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
Loading