Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support export ADT value in Python #3299

Merged
merged 4 commits into from
Jun 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
Constructor constructor;
int tag;

tvm::Array<Value> fields;

/*! \brief Optional field tracking ADT constructor. */
Constructor constructor;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("tag", &tag);
v->Visit("fields", &fields);
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(Constructor constructor,
tvm::Array<Value> fields);
TVM_DLL static ConstructorValue make(int tag,
tvm::Array<Value> fields,
Constructor constructor = {});

static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class Closure(Value):

@register_relay_node
class ConstructorValue(Value):
def __init__(self, constructor, fields, types):
def __init__(self, tag, fields, constructor, types):
self.__init_handle_by_constructor__(
_make.ConstructorValue, constructor, fields, types)
_make.ConstructorValue, tag, fields, constructor, types)


@register_relay_node
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args):
args: List[tvm.NDArray, np.ndarray]
The arguments to evaluate.
"""

mod = optimize(mod)
args = list(args)
assert isinstance(args, list)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ def load_prelude(self):
def __init__(self, mod):
self.mod = mod
self.load_prelude()

self.define_list_adt()
self.define_list_hd()
self.define_list_tl()
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/relay/testing/nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,25 +151,25 @@ def add_nat_definitions(prelude):
# helper functions for working with nats


def count(n):
def count(prelude, n):
"""Takes a ConstructorValue corresponding to a nat ADT
and converts it into a Python integer. This is an example of
using an ADT value in Python.
"""
assert isinstance(n, ConstructorValue)
if n.constructor.name_hint == 'z':
if n.tag == prelude.z.tag:
return 0
assert n.constructor.name_hint == 's'
return 1 + count(n.fields[0])
assert n.tag == prelude.s.tag
return 1 + count(prelude, n.fields[0])


def make_nat_value(prelude, n):
"""The inverse of count(): Given a non-negative Python integer,
constructs a ConstructorValue representing that value as a nat.
"""
if n == 0:
return ConstructorValue(prelude.z, [], [])
return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
return ConstructorValue(prelude.z.tag, [], None, [])
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])


def make_nat_expr(prelude, n):
Expand Down
17 changes: 9 additions & 8 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")";
});

ConstructorValue ConstructorValueNode::make(Constructor constructor,
tvm::Array<Value> fields) {
ConstructorValue ConstructorValueNode::make(int tag,
tvm::Array<Value> fields,
Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
n->constructor = constructor;
n->tag = tag;
n->fields = fields;
n->constructor = constructor;
return ConstructorValue(n);
}

Expand All @@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) {
p->stream << "ConstructorValueNode(" << node->constructor
p->stream << "ConstructorValueNode(" << node->tag << ","
<< node->fields << ")";
});

Expand Down Expand Up @@ -448,7 +450,7 @@ class Interpreter :
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
Expand Down Expand Up @@ -544,9 +546,8 @@ class Interpreter :
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK_NE(cvn->tag, -1);
if (op->constructor->tag == cvn->tag) {
CHECK_EQ(op->patterns.size(), cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
Expand Down
41 changes: 12 additions & 29 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ struct VMCompilerContext {
ConstTensorShapeMap const_tensor_shape_map;
// List of lowered functions
std::vector<LoweredFunc> lowered_funcs;
// The functions that have been lowered.
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
};

// Compute the constant pool, i.e a mapping from Constant node to constant index.
Expand Down Expand Up @@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
size_t registers_num;
CompileEngine engine;

/*! \brief The functions that have been lowered. */
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;

/*! \brief Global shared meta data */
VMCompilerContext* context;

Expand Down Expand Up @@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {

void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node);
LOG(FATAL) << "translation of match nodes to the VM is"
LOG(FATAL) << "translation of match nodes to the VM is "
<< "currently unsupported" << std::endl;
}

Expand All @@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
}

void VisitExpr_(const GlobalVarNode* gvar) {
LOG(FATAL) << "Global variables should only appear in the call position";
// TODO(wweic): Support Load GlobalVar into a register
LOG(FATAL) << "Loading GlobalVar into register is not yet supported";
}

void VisitExpr_(const IfNode* if_node) {
Expand Down Expand Up @@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) {
op_index = this->context->lowered_funcs.size();
this->context->lowered_funcs.push_back(cfunc->funcs[0]);
seen_funcs[cfunc->funcs[0]] = op_index;
this->context->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = seen_funcs[cfunc->funcs[0]];
op_index = this->context->seen_funcs[cfunc->funcs[0]];
}

Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
Expand All @@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
std::vector<Index> args_registers;

for (auto arg : call_node->args) {
CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
this->VisitExpr(arg);
args_registers.push_back(last_register);
}
Expand All @@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
auto func = this->context->module->Lookup(global);
if (IsClosure(func)) {
auto arity = func->params.size();
std::vector<Index> free_var_registers;
for (size_t i = 0; i < arity; ++i) {
free_var_registers.push_back(var_register_map.at(func->params[i]));
}
Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
} else {
Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
auto tag = GetConstructorTag(constructor);
Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
Expand All @@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
}
}

size_t GetConstructorTag(tvm::relay::Constructor constructor) {
auto it = this->context->tag_map.find(constructor);
if (it != this->context->tag_map.end()) {
return it->second;
} else {
auto tag = this->context->tag_map.size();
this->context->tag_map[constructor] = tag;
this->context->tag_index_map[tag] = constructor;
return tag;
}
}

void VisitExpr_(const FunctionNode* func_node) {
if (!func_node->IsPrimitive()) {
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
Expand Down Expand Up @@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
}

VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl;
size_t params = func->params.size();
VMCompiler compiler(context);
compiler.Compile(func);
Expand Down
23 changes: 9 additions & 14 deletions src/relay/backend/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
return res;
}

Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
CHECK(module.defined() && type.defined());
Value VMToValue(const relay::Module& module, Object obj) {
CHECK(module.defined());
switch (obj->tag) {
case ObjectTag::kTensor: {
CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
return TensorValueNode::make(ToNDArray(obj));
}
case ObjectTag::kDatatype: {
// const auto* tuple_type
// const auto& data_type = obj.AsDatatype();
const auto& data_type = obj.AsDatatype();

// tvm::Array<Value> fields;
// for (size_t i = 0; i < data_type->fields.size(); ++i) {
// fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
// }
tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(module, data_type->fields[i]));
}

// return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
LOG(FATAL) << "fix me";
return ConstructorValueNode::make(data_type->tag, fields);
}
default:
LOG(FATAL) << "unsupported return value of type: " << obj->tag;
Expand Down Expand Up @@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue
LOG(FATAL) << "expected function or module";
}

auto return_type = module->Lookup(module->entry_func)->ret_type;

std::vector<Object> vm_args;
for (auto i = 3; i < args.size(); i++) {
Object obj = args[i];
Expand All @@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue

auto result = EvaluateModule(module, {ctx}, vm_args);
DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
*ret = VMToValue(module, return_type, result);
*ret = VMToValue(module, result);
});

} // namespace vm
Expand Down
3 changes: 2 additions & 1 deletion src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ Module FunctionPassNode::operator()(const Module& mod,
Module updated_mod = mod;
// Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : mod->functions) {
auto original = mod->functions;
for (const auto& it : original) {
auto updated_func = SkipFunction(it.second)
? it.second
: pass_func(it.second, updated_mod, pass_ctx);
Expand Down
11 changes: 7 additions & 4 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay import testing, create_executor
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr

mod = relay.Module()
p = Prelude(mod)
add_nat_definitions(p)

def count(e):
return count_(p, e)

ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

Expand Down Expand Up @@ -91,18 +94,18 @@ def to_list(l):
val = l
ret = []
while True:
if val.constructor.name_hint == 'cons':
if val.tag == p.cons.tag:
ret.append(val.fields[0])
val = val.fields[1]
else:
assert val.constructor.name_hint == 'nil'
assert val.tag == p.nil.tag
break
return ret

def tree_to_dict(t):
assert isinstance(t, ConstructorValue)
ret = {}
assert t.constructor.name_hint == 'rose'
assert t.tag == p.rose.tag
ret['member'] = t.fields[0]
ret['children'] = []
for subtree in to_list(t.fields[1]):
Expand Down
12 changes: 6 additions & 6 deletions tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)

nil_value = ConstructorValue(prelude.nil, [], [])
cons_value = ConstructorValue(prelude.cons, [
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
], [relay.TensorType((1, 10), 'float32')])
], prelude.cons, [relay.TensorType((1, 10), 'float32')])

ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
Expand All @@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple():
id_func = intrp.evaluate(prelude.id)

res_nil = id_func(nil_value)
assert res_nil.constructor == nil_value.constructor
assert res_nil.tag == nil_value.tag
assert len(res_nil.fields) == 0

res_cons = id_func(cons_value)
assert res_cons.constructor == cons_value.constructor
assert res_cons.tag == cons_value.tag
assert len(res_cons.fields) == len(cons_value.fields)
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
cons_value.fields[0].asnumpy())
assert isinstance(res_cons.fields[1], ConstructorValue)
assert res_cons.fields[1].constructor == prelude.nil
assert res_cons.fields[1].tag == prelude.nil.tag
assert len(res_cons.fields[1].fields) == 0

res_ref = id_func(ref_value)
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_to_a_normal_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def test_nat_add():
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()


Expand Down
Loading