Skip to content

Commit

Permalink
[checkpoint] Switch to script, and test ADTs
Browse files Browse the repository at this point in the history
This need some changes to support attributes and metadata in parsing.

Added a test of ADTs which revealed (suprise!) the handling for constructors was too unconstrained.

We'll need to peel off another preparation PR.
  • Loading branch information
mbs-octoml committed Sep 25, 2021
1 parent f64e372 commit 65d2377
Show file tree
Hide file tree
Showing 17 changed files with 955 additions and 1,072 deletions.
21 changes: 21 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target";
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The device type which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<Integer> (but interpreted as Array<DLDeviceType>)
*/
constexpr const char* kParamDeviceTypes = "param_device_types";

/*!
* \brief The device type which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Integer (but interpreted as DLDeviceType)
*/
constexpr const char* kResultDeviceType = "result_device_type";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
8 changes: 6 additions & 2 deletions include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \file parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -32,8 +33,11 @@
namespace tvm {
namespace parser {

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module = Optional<IRModule>());
using MetaTable = Map<String, Array<ObjectRef>>;

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

} // namespace parser
} // namespace tvm
Expand Down
66 changes: 0 additions & 66 deletions include/tvm/relay/attrs/function.h

This file was deleted.

6 changes: 4 additions & 2 deletions python/tvm/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def add(self, name, content):
return _ffi.get_global_func("SourceMapAdd")(self, name, content)


def parse(source, source_name="from_string"):
return _ffi_api.ParseModule(source_name, source)
def parse(source, source_name="from_string", init_module=None, init_meta_table=None):
if init_meta_table is None:
init_meta_table = {}
return _ffi_api.ParseModule(source_name, source, init_module, init_meta_table)


def parse_expr(source):
Expand Down
6 changes: 3 additions & 3 deletions src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s
}

auto source = (*it).second;
DLOG(INFO) << "Source: " << std::endl << source->source;
VLOG(1) << "Source: " << std::endl << source->source;

DLOG(INFO) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;
VLOG(1) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;

auto line_text = source.GetLine(span->line);

Expand Down
3 changes: 1 addition & 2 deletions src/parser/meta_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_PARSER_META_REF_H_

#include <tvm/ir/attrs.h>
#include <tvm/parser/parser.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>

Expand All @@ -36,8 +37,6 @@ namespace parser {

using namespace relay;

using MetaTable = Map<String, Array<ObjectRef>>;

/*!
* \brief Options for allocating storage.
*/
Expand Down
49 changes: 34 additions & 15 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ class Parser {
VLOG(0) << "Parser::ParseFunctionDef";
return WithSpan<Function>([&]() {
PushScope();
PushTypeScope();
PushTypeScope(); // TODO(mbs): BUG?

Array<TypeVar> generics;
if (Peek()->token_type == TokenType::kLSquare) {
Expand Down Expand Up @@ -1444,6 +1444,10 @@ class Parser {
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
} else {
this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
<< "unable to determine the 'attrs_type_key' with which "
"to represent the call attributes for this operator");
}
}
return true;
Expand Down Expand Up @@ -1867,7 +1871,7 @@ class Parser {
};

Parser InitParser(const std::string& file_name, const std::string& file_content,
Optional<IRModule> init_module) {
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
SourceName src_name = SourceName::Get(file_name);
Source source(src_name, file_content);
Expand All @@ -1886,43 +1890,58 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,
auto tokens_and_table = Tokenize(diag_ctx, source);

auto tokens = tokens_and_table.first;
auto meta_data_table = tokens_and_table.second;
MetaTable meta_data_table = tokens_and_table.second.ToMetadata();

// Merge any entries in init_meta_table into anything captured in the #[metadata] section
// of the file_content. Metadata references within file_content must use indexes which account
// for this ordering.
for (const auto& pair : init_meta_table) {
Array<ObjectRef> items;
if (meta_data_table.count(pair.first)) {
items = meta_data_table[pair.first];
}
for (const auto& obj : pair.second) {
items.push_back(obj);
}
meta_data_table.Set(pair.first, items);
}

return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata());
return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
}

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module) {
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "ParseModule";
auto parser = InitParser(file_name, file_content, init_module);
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
auto infer_type = tvm::relay::transform::InferType();
ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
return infer_type(mod);
}

Expr ParseExpr(std::string file_name, std::string file_content) {
Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
VLOG(0) << "ParseExpr";
auto parser = InitParser(file_name, file_content, Optional<IRModule>());
auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
parser.ParseSemVer(false);
parser.PushScope();
auto expr = parser.ParseExpr();
parser.Match(TokenType::kEndOfFile);
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
return expr;
}

TVM_REGISTER_GLOBAL("parser.ParseModule")
.set_body_typed([](tvm::String file_name, tvm::String file_content) {
return ParseModule(file_name, file_content);
.set_body_typed([](const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
return ParseModule(file_name, file_content, init_module, init_meta_table);
});

TVM_REGISTER_GLOBAL("parser.ParseExpr")
Expand Down
6 changes: 3 additions & 3 deletions src/parser/source_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) {
}

tvm::String Source::GetLine(int line) {
DLOG(INFO) << "Source::GetLine: line=" << line;
VLOG(1) << "Source::GetLine: line=" << line;
ICHECK(line - 1 < static_cast<int64_t>((*this)->line_map.size()))
<< "requested line: " << line << "at index: " << (line - 1)
<< "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source;
Expand All @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) {
auto range = (*this)->line_map.at(line - 1);
int line_start = range.first;
int line_length = range.second;
DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
// TODO(@jroesch): expose substring on tvm::String.
auto line_text = std::string((*this)->source).substr(line_start, line_length);
DLOG(INFO) << "Source::GetLine: line_text=" << line_text;
VLOG(1) << "Source::GetLine: line_text=" << line_text;
return line_text;
}

Expand Down
47 changes: 24 additions & 23 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include "./annotation.h"

#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/function.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
Expand Down Expand Up @@ -92,6 +91,7 @@ RELAY_REGISTER_OP("on_device")
.add_argument("data", "Tensor", "The input data.")
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attrs_type_key("relay.attrs.OnDeviceAttrs")
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
Expand Down Expand Up @@ -128,24 +128,31 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
return {};
}

TVM_REGISTER_NODE_TYPE(FunctionOnDeviceAttrs);

Function FunctionOnDevice(Function function, Array<Integer> param_device_types,
DLDeviceType result_device_type) {
auto attrs = make_object<FunctionOnDeviceAttrs>();
attrs->param_device_types = std::move(param_device_types);
attrs->result_device_type = result_device_type;
return WithAttr(std::move(function), attr::kFunctionAttrsKey, Attrs(std::move(attrs)));
Integer result_device_type) {
return WithAttr(WithAttr(std::move(function), tvm::attr::kParamDeviceTypes, param_device_types),
tvm::attr::kResultDeviceType, result_device_type);
}

Function FunctionOnDevice(Function function, const std::vector<DLDeviceType>& param_device_types,
DLDeviceType result_device_type) {
Array<Integer> arr;
arr.reserve(param_device_types.size());
for (const auto device_type : param_device_types) {
arr.push_back(static_cast<int64_t>(device_type));
arr.push_back(static_cast<int>(device_type));
}
return FunctionOnDevice(std::move(function), std::move(arr),
static_cast<int>(result_device_type));
}

Function OptFunctionOnDevice(Function function, const std::vector<DLDeviceType>& param_device_types,
DLDeviceType result_device_type) {
if (std::all_of(param_device_types.begin(), param_device_types.end(),
[](DLDeviceType type) { return type == kInvalidDeviceType; }) &&
result_device_type == kInvalidDeviceType) {
return function;
}
return FunctionOnDevice(function, arr, result_device_type);
return FunctionOnDevice(function, param_device_types, result_device_type);
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device")
Expand All @@ -156,32 +163,26 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device")
});

DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) {
auto opt_attrs = function_node->GetAttr<Attrs>(attr::kFunctionAttrsKey);
if (!opt_attrs) {
auto opt_integer = function_node->GetAttr<Integer>(tvm::attr::kResultDeviceType);
if (!opt_integer) {
// No annotation.
return kInvalidDeviceType;
}
const auto* opt_function_on_device_attrs = opt_attrs.value().as<FunctionOnDeviceAttrs>();
ICHECK(opt_function_on_device_attrs != nullptr)
<< "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs";
return static_cast<DLDeviceType>(opt_function_on_device_attrs->result_device_type);
return static_cast<DLDeviceType>(opt_integer.value()->value);
}

DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) {
ICHECK_LT(i, function_node->params.size())
<< "param index " << i << " out of range for function of arity "
<< function_node->params.size();
auto opt_attrs = function_node->GetAttr<Attrs>(attr::kFunctionAttrsKey);
if (!opt_attrs) {
auto opt_array = function_node->GetAttr<Array<Integer>>(tvm::attr::kParamDeviceTypes);
if (!opt_array) {
// No annotation.
return kInvalidDeviceType;
}
const auto* opt_function_on_device_attrs = opt_attrs.value().as<FunctionOnDeviceAttrs>();
ICHECK(opt_function_on_device_attrs != nullptr)
<< "function '" << attr::kFunctionAttrsKey << "' annotation must be a FunctionOnDeviceAttrs";
ICHECK_EQ(opt_function_on_device_attrs->param_device_types.size(), function_node->params.size())
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
<< "annotation parameters do not match function arity";
return static_cast<DLDeviceType>(opt_function_on_device_attrs->param_device_types[i]->value);
return static_cast<DLDeviceType>(opt_array.value()[i]->value);
}

Expr StopFusion(Expr data) {
Expand Down
Loading

0 comments on commit 65d2377

Please sign in to comment.