Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Aug 7, 2020
1 parent 50b517c commit 6d1a38d
Show file tree
Hide file tree
Showing 12 changed files with 508 additions and 464 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class SpanNode : public Object {

class Span : public ObjectRef {
public:
TVM_DLL Span(SourceName source, int line, int column, int end_line, int end_column);
TVM_DLL Span(SourceName source, int line, int end_line, int column, int end_column);

/*! \brief Merge two spans into one which captures the combined regions. */
TVM_DLL Span Merge(const Span& other);
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/parser/source_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_PARSER_SOURCE_MAP_H_
#define TVM_PARSER_SOURCE_MAP_H_
/*!
* \file source_map.h
* \brief A map from source names to source code.
*/
#ifndef TVM_PARSER_SOURCE_MAP_H_
#define TVM_PARSER_SOURCE_MAP_H_

#include <tvm/ir/span.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down
12 changes: 6 additions & 6 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,30 @@ Span::Span(SourceName source, int line, int column, int end_line, int end_column
auto n = make_object<SpanNode>();
n->source = std::move(source);
n->line = line;
n->column = column;
n->end_line = end_line;
n->column = column;
n->end_column = end_column;
data_ = std::move(n);
}

Span Span::Merge(const Span& other) {
CHECK((*this)->source == other->source);
return Span((*this)->source, std::min((*this)->line, other->line),
std::min((*this)->column, other->column),
std::max((*this)->end_line, other->end_line),
std::min((*this)->column, other->column),
std::max((*this)->end_column, other->end_column));
}

TVM_REGISTER_NODE_TYPE(SpanNode);

TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int column,
int end_line, int end_column) {
return Span(source, line, column, end_line, end_column);
TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int line, int end_line, int column,
int end_column) {
return Span(source, line, end_line, column, end_column);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->line << ", " << node->column << ")";
p->stream << "Span(" << node->source << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")";
});
} // namespace tvm
11 changes: 4 additions & 7 deletions src/parser/meta_ref.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ namespace parser {
using tvm::relay::transform::CreateFunctionPass;
using tvm::transform::PassContext;

/* Set to arbitrary high number, since we should never schedule in normal pass manager flow. */
static int kMetaExpandOptLevel = 1337;

TVM_REGISTER_NODE_TYPE(MetaRefAttrs);

bool MetaRefRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand All @@ -60,12 +63,6 @@ Expr MetaRef(std::string type_key, uint64_t node_index) {
return Call(op, {}, Attrs(attrs), {});
}

// class MetaRefAttrExpander : AttrFunctor<ObjectRef(const ObjectRef& n)> {
// ObjectRef VisitAttrDefault_(const Object* node) final {

// }
// }

struct MetaRefExpander : public ExprMutator {
MetaTable table;

Expand Down Expand Up @@ -94,7 +91,7 @@ Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func
IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) {
auto pass = CreateFunctionPass([&](Function func, IRModule module,
PassContext ctx) { return ExpandMetaRefs(meta_table, func); },
1337, "ExpandMetaRefs", {});
kMetaExpandOptLevel, "ExpandMetaRefs", {});

return pass(mod, PassContext::Create());
}
Expand Down
4 changes: 2 additions & 2 deletions src/parser/meta_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ struct MetaRefAttrs : public tvm::AttrsNode<MetaRefAttrs> {
* of the program.
*
* \param type_key The type key of the object in the meta section.
* \param kind The index into that subfield.
* \param node_index The index into that subfield.
* \returns The meta table reference.
*/
Expr MetaRef(std::string type_key, uint64_t node_index);

relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& mod);
relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func);
IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod);

} // namespace parser
Expand Down
20 changes: 10 additions & 10 deletions src/parser/op_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ struct OperatorTable {

OperatorTable DefaultOpTable() {
return OperatorTable(
{Rule({TokenType::Star}, Op::Get("multiply"), 12, 2, true),
Rule({TokenType::Division}, Op::Get("divide"), 12, 2, true),
Rule({TokenType::Plus}, Op::Get("add"), 10, 2, true),
Rule({TokenType::Minus}, Op::Get("subtract"), 10, 2, true),
Rule({TokenType::LAngle}, Op::Get("less"), 8, 2, true),
Rule({TokenType::LAngle, TokenType::Equal}, Op::Get("less_equal"), 8, 2, true),
Rule({TokenType::RAngle}, Op::Get("greater"), 8, 2, true),
Rule({TokenType::RAngle, TokenType::Equal}, Op::Get("greater_equal"), 8, 2, true),
Rule({TokenType::Equal, TokenType::Equal}, Op::Get("equal"), 7, 2, true),
Rule({TokenType::Bang, TokenType::Equal}, Op::Get("not_equal"), 7, 2, true)});
{Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true),
Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true),
Rule({TokenType::kPlus}, Op::Get("add"), 10, 2, true),
Rule({TokenType::kMinus}, Op::Get("subtract"), 10, 2, true),
Rule({TokenType::kLAngle}, Op::Get("less"), 8, 2, true),
Rule({TokenType::kLAngle, TokenType::kEqual}, Op::Get("less_equal"), 8, 2, true),
Rule({TokenType::kRAngle}, Op::Get("greater"), 8, 2, true),
Rule({TokenType::kRAngle, TokenType::kEqual}, Op::Get("greater_equal"), 8, 2, true),
Rule({TokenType::kEqual, TokenType::kEqual}, Op::Get("equal"), 7, 2, true),
Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)});
}

} // namespace parser
Expand Down
Loading

0 comments on commit 6d1a38d

Please sign in to comment.