From 6983913b11bc71ff6520b021cea9ad284f6ef066 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Wed, 18 Apr 2018 14:25:31 -0700 Subject: [PATCH] Adds Span node into the AST (#59) --- relay/include/relay/ir/base.h | 35 +++++++++++++++++++++++++++++++++-- relay/python/relay/expr.py | 5 +++++ relay/src/relay/ir/expr.cc | 17 +++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/relay/include/relay/ir/base.h b/relay/include/relay/ir/base.h index 5ad95fd4b4866..f2523e873a497 100644 --- a/relay/include/relay/ir/base.h +++ b/relay/include/relay/ir/base.h @@ -41,7 +41,35 @@ namespace relay { typedef HalideIR::Type HType; -struct Node : public tvm::Node {}; +class Span; + +/*! \brief Stores locations in frontend source that generated a node. */ +class SpanNode : public tvm::Node { + public: + int start; + int end; + + SpanNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("start", &start); + v->Visit("end", &end); + } + + TVM_DLL static Span make(int start, int end); + + static constexpr const char* _type_key = "nnvm.Span"; + TVM_DECLARE_NODE_TYPE_INFO(SpanNode, ::tvm::Node); +}; + +RELAY_DEFINE_NODE_REF(Span, SpanNode, ::tvm::NodeRef); + +struct Node : public tvm::Node { + public: + Span span; + + Node() {} +}; /*! * \brief we always used NodeRef for referencing @@ -78,7 +106,10 @@ class LocalIdNode : public ExprNode { LocalIdNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("name", &name); } + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("span", &span); + v->Visit("name", &name); + } TVM_DLL static LocalId make(std::string name); diff --git a/relay/python/relay/expr.py b/relay/python/relay/expr.py index 258609bcb4272..9058d3888ed4a 100644 --- a/relay/python/relay/expr.py +++ b/relay/python/relay/expr.py @@ -8,6 +8,11 @@ # TODO(@jroesch): Add meta-programming dunder methods to a new Base class +@register_nnvm_node +class Span(NodeBase): + pass + + @register_nnvm_node class Environment(NodeBase): """Global Environment diff --git a/relay/src/relay/ir/expr.cc b/relay/src/relay/ir/expr.cc index 41d6a6a037725..eebe1e15b990f 100644 --- a/relay/src/relay/ir/expr.cc +++ b/relay/src/relay/ir/expr.cc @@ -142,6 +142,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "CastNode(" << node->target << ", " << node->node << ")"; }); +Span SpanNode::make(int start, int end) { + std::shared_ptr n = std::make_shared(); + n->start = start; + n->end = end; + return Span(n); +} + +TVM_REGISTER_API("nnvm.make.Span") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = SpanNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const SpanNode *node, tvm::IRPrinter *p) { + p->stream << "SpanNode(" << node->start << "," << node->end << ")"; + }); + LocalId LocalIdNode::make(std::string name) { std::shared_ptr n = std::make_shared(); n->name = name;