Skip to content

Commit

Permalink
Adds Span node into the AST (apache#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
weberlo authored and jroesch committed Aug 16, 2018
1 parent 9db7b3e commit 6983913
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
35 changes: 33 additions & 2 deletions relay/include/relay/ir/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,35 @@ namespace relay {

typedef HalideIR::Type HType;

struct Node : public tvm::Node {};
class Span;

/*! \brief Stores locations in frontend source that generated a node. */
class SpanNode : public tvm::Node {
public:
int start;
int end;

SpanNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("start", &start);
v->Visit("end", &end);
}

TVM_DLL static Span make(int start, int end);

static constexpr const char* _type_key = "nnvm.Span";
TVM_DECLARE_NODE_TYPE_INFO(SpanNode, ::tvm::Node);
};

RELAY_DEFINE_NODE_REF(Span, SpanNode, ::tvm::NodeRef);

struct Node : public tvm::Node {
public:
Span span;

Node() {}
};

/*!
* \brief we always used NodeRef for referencing
Expand Down Expand Up @@ -78,7 +106,10 @@ class LocalIdNode : public ExprNode {

LocalIdNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); }
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
v->Visit("name", &name);
}

TVM_DLL static LocalId make(std::string name);

Expand Down
5 changes: 5 additions & 0 deletions relay/python/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
# TODO(@jroesch): Add meta-programming dunder methods to a new Base class


@register_nnvm_node
class Span(NodeBase):
pass


@register_nnvm_node
class Environment(NodeBase):
"""Global Environment
Expand Down
17 changes: 17 additions & 0 deletions relay/src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "CastNode(" << node->target << ", " << node->node << ")";
});

Span SpanNode::make(int start, int end) {
std::shared_ptr<SpanNode> n = std::make_shared<SpanNode>();
n->start = start;
n->end = end;
return Span(n);
}

TVM_REGISTER_API("nnvm.make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SpanNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode *node, tvm::IRPrinter *p) {
p->stream << "SpanNode(" << node->start << "," << node->end << ")";
});

LocalId LocalIdNode::make(std::string name) {
std::shared_ptr<LocalIdNode> n = std::make_shared<LocalIdNode>();
n->name = name;
Expand Down

0 comments on commit 6983913

Please sign in to comment.