Skip to content

Commit

Permalink
[Refactor] Relay Node::make to constructor (apache#5128)
Browse files Browse the repository at this point in the history
* relay Node::make to constructor

* patternwildcard

* Address comments
  • Loading branch information
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent 14aec51 commit 9b4eb10
Show file tree
Hide file tree
Showing 87 changed files with 782 additions and 621 deletions.
85 changes: 61 additions & 24 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@

#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"
#include <utility>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -69,10 +70,6 @@ class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}

TVM_DLL static PatternWildcard make();

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}
Expand All @@ -83,21 +80,39 @@ class PatternWildcardNode : public PatternNode {

class PatternWildcard : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
/* \brief Overload the default constructors. */
TVM_DLL PatternWildcard();
explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}
/* \brief Copy constructor. */
PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
/* \brief Move constructor. */
PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
/* \brief Copy assignment. */
PatternWildcard& operator=(const PatternWildcard& other) {
(*this).data_ = other.data_;
return *this;
}
/* \brief Move assignment. */
PatternWildcard& operator=(PatternWildcard&& other) {
(*this).data_ = std::move(other.data_);
return *this;
}

const PatternWildcardNode* operator->() const {
return static_cast<const PatternWildcardNode*>(get());
}

using ContainerType = PatternWildcardNode;
};

/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}

/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;

TVM_DLL static PatternVar make(tvm::relay::Var var);

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("span", &span);
Expand All @@ -109,6 +124,12 @@ class PatternVarNode : public PatternNode {

class PatternVar : public Pattern {
public:
/*!
* \brief Constructor
* \param var The var to construct a pattern
*/
TVM_DLL explicit PatternVar(tvm::relay::Var var);

TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};

Expand All @@ -122,10 +143,6 @@ class PatternConstructorNode : public PatternNode {
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;

PatternConstructorNode() {}

TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
Expand All @@ -138,6 +155,13 @@ class PatternConstructorNode : public PatternNode {

class PatternConstructor : public Pattern {
public:
/*!
* \brief Constructor
* \param constructor The constructor of a pattern
* \param patterns The sub-patterns for matching
*/
TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);

TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};

Expand All @@ -149,10 +173,6 @@ class PatternTupleNode : public PatternNode {
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;

PatternTupleNode() {}

TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
Expand All @@ -164,6 +184,12 @@ class PatternTupleNode : public PatternNode {

class PatternTuple : public Pattern {
public:
/*!
* \brief Constructor
* \param patterns The sub-patterns to match against each value of the tuple
*/
TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);

TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};

Expand All @@ -182,14 +208,19 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs);
}

TVM_DLL static Clause make(Pattern lhs, Expr rhs);

static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};

class Clause : public ObjectRef {
public:
/*!
* \brief Constructor
* \param lhs The pattern matched by the clause.
* \param rhs The resulting value
*/
TVM_DLL explicit Clause(Pattern lhs, Expr rhs);

TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
};

Expand Down Expand Up @@ -217,14 +248,20 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);

static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};

class Match : public Expr {
public:
/*!
* \brief Constructor
* \param data the input being deconstructed.
* \param clauses The clauses for matching.
* \param complete Indicate if this match is complete.
*/
TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true);

TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};

Expand Down
6 changes: 6 additions & 0 deletions include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class IdNode : public Object {

class Id : public ObjectRef {
public:
/*!
* \brief The constructor
* \param name_hint The name of the variable.
*/
TVM_DLL explicit Id(std::string name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};

Expand Down
Loading

0 comments on commit 9b4eb10

Please sign in to comment.