Skip to content

Commit

Permalink
Merge pull request #6 from jroesch/relay-rts-wweic-adt
Browse files Browse the repository at this point in the history
Basic ADT allocation in VM
  • Loading branch information
wweic authored Feb 6, 2019
2 parents 8661269 + fded961 commit 7eead36
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 15 deletions.
8 changes: 7 additions & 1 deletion include/tvm/relay/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct VMDatatypeCell : public VMObjectCell {
std::vector<VMObject> fields;

VMDatatypeCell(size_t tag, const std::vector<VMObject>& fields)
: VMObjectCell(VMObjectTag::kDatatype), fields(fields) {}
: VMObjectCell(VMObjectTag::kDatatype), tag(tag), fields(fields) {}
};


Expand Down Expand Up @@ -88,6 +88,7 @@ enum struct Opcode {
Invoke,
InvokePacked,
AllocTensor,
AllocDatatype,
GetField,
If,
LoadConst,
Expand Down Expand Up @@ -126,6 +127,10 @@ struct Instruction {
size_t object_offset;
size_t field_index;
};
struct {
size_t constructor_tag;
size_t num_fields;
};
};

Instruction();
Expand Down Expand Up @@ -187,6 +192,7 @@ struct VirtualMachine {

// Interface debugging.
std::unordered_map<GlobalVar, size_t, NodeHash, NodeEqual> global_map;
std::unordered_map<size_t, Constructor> tag_index_map;

void PushFrame(size_t arg_count, size_t ret_pc, const VMFunction& vm_func);
size_t PopFrame();
Expand Down
93 changes: 79 additions & 14 deletions src/relay/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::AllocTensor:
this->tensor_info = instr.tensor_info;
return;
case Opcode::AllocDatatype:
this->constructor_tag = instr.constructor_tag;
this->num_fields = instr.num_fields;
return;
case Opcode::InvokePacked:
this->packed_index = instr.packed_index;
this->arity = instr.arity;
Expand Down Expand Up @@ -108,6 +112,14 @@ Instruction AllocTensor(const std::vector<int64_t> shape, DLDataType dtype) {
return instr;
}

// Construct an AllocDatatype instruction: at execution time it pops
// `num_fields` values from the stack and boxes them into a datatype
// object carrying constructor tag `tag`.
Instruction AllocDatatype(size_t tag, size_t num_fields) {
  Instruction instr;
  instr.op = Opcode::AllocDatatype;
  instr.num_fields = num_fields;
  instr.constructor_tag = tag;
  return instr;
}

Instruction GetField(size_t object_offset, size_t field_index) {
Instruction instr;
instr.op = Opcode::GetField;
Expand Down Expand Up @@ -172,6 +184,13 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
os << TVMType2Type(instr.tensor_info.dtype);
break;
}
case Opcode::AllocDatatype: {
os << "alloc_block";
os << " ";
os << instr.constructor_tag << " ";
os << instr.num_fields;
break;
}
case Opcode::If: {
os << "if "
<< instr.true_offset << " "
Expand Down Expand Up @@ -210,11 +229,14 @@ void VMFunctionPrint(const VMFunction& vm_func) {
}
}


using TagMap = std::unordered_map<tvm::relay::Constructor, size_t, NodeHash, NodeEqual>;
using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
using GlobalMap = std::unordered_map<GlobalVar, size_t, NodeHash, NodeEqual>;
using ConstMap = std::unordered_map<Constant, size_t, NodeHash, NodeEqual>;

// Shared state threaded through compilation of every function in a module.
struct VMCompilerContext {
// Reverse map: constructor tag -> Constructor node. Copied into the
// VirtualMachine after compilation so datatype objects can be turned
// back into Relay values.
TagNameMap tag_index_map;
// Forward map: Constructor node -> tag assigned on first encounter.
TagMap tag_map;
// GlobalVar -> index of its compiled VMFunction in the VM.
GlobalMap global_map;
// Constant node -> index; presumably a constant-pool slot — TODO confirm.
ConstMap const_map;
};
Expand Down Expand Up @@ -285,6 +307,11 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
Emit(Push(it->second));
}

void VisitExpr_(const MatchNode* match_node) {
auto match = GetRef<Match>(match_node);
std::cout << "Ignore compiling match node\n";
}

void VisitExpr_(const LetNode* let_node) {
this->VisitExpr(let_node->value);
var_map.insert({ let_node->var, this->stack_index++ });
Expand Down Expand Up @@ -371,11 +398,27 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
auto it = this->context->global_map.find(global);
CHECK(it != this->context->global_map.end());
Emit(Invoke(it->second));
} else if (auto constructor_node = call_node->op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
auto tag = GetConstructorTag(constructor);
Emit(AllocDatatype(tag, call_node->args.size()));
} else {
LOG(FATAL) << "unsupported case in vm compiler: " << call_node->op;
}
}

// Return the tag already assigned to `constructor`, or assign the next
// sequential tag (current map size) on first encounter, recording both the
// forward (constructor -> tag) and reverse (tag -> constructor) mappings.
size_t GetConstructorTag(tvm::relay::Constructor constructor) {
  auto it = this->context->tag_map.find(constructor);
  if (it == this->context->tag_map.end()) {
    // First time we see this constructor: tags are dense, so the next
    // free tag is simply the current number of known constructors.
    size_t tag = this->context->tag_map.size();
    this->context->tag_map.emplace(constructor, tag);
    this->context->tag_index_map.emplace(tag, constructor);
    return tag;
  }
  return it->second;
}

void VisitExpr_(const FunctionNode* func_node) {
CHECK(!seen_func) << GetRef<Function>(func_node);
this->seen_func = true;
Expand Down Expand Up @@ -410,7 +453,7 @@ void PopulatePackedFuncMap(


CompiledFunc CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
std::cout << func << std::endl;
// std::cout << func << std::endl;
size_t params = func->params.size();
VMCompiler compiler(context);
// std::cout << "Compiling: " << func << std::endl;
Expand Down Expand Up @@ -449,6 +492,7 @@ VirtualMachine CompileModule(const Module& mod) {
for (auto named_func : mod->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
std::cout << "Compiling func " << gvar->name_hint << std::endl;
auto cfunc = CompileFunc(&context, gvar, func);
auto lfuncs = cfunc.first;
auto vm_func = cfunc.second;
Expand All @@ -464,12 +508,16 @@ VirtualMachine CompileModule(const Module& mod) {
}

for (auto vm_func : vm.functions) {
std::cout << "Function: " << std::endl;
std::cout << "Function: " << vm_func.name << std::endl;
VMFunctionPrint(vm_func);
std::cout << "-------------" << std::endl;
}

PopulatePackedFuncMap(lowered_funcs, &vm.packed_funcs);

vm.global_map = context.global_map;
vm.tag_index_map = context.tag_index_map;

return vm;
}

Expand Down Expand Up @@ -520,6 +568,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<VMOb
}

VMObject VirtualMachine::Invoke(const VMFunction& func, const std::vector<VMObject>& args) {
std::cout << "Executing function " << func.name << std::endl;
InvokeGlobal(func, args);
Run();
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
Expand Down Expand Up @@ -641,6 +690,16 @@ void VirtualMachine::Run() {
pc++;
goto main_loop;
}
case Opcode::AllocDatatype: {
std::vector<VMObject> fields;
size_t stack_size = stack.size();
for (size_t i = 0; i < instr.num_fields; ++i) {
fields.push_back(stack[stack_size - instr.num_fields + i]);
}
stack.push_back(VMDatatype(instr.constructor_tag, fields));
pc++;
goto main_loop;
}
case Opcode::Push: {
CHECK(bp + instr.stack_index < stack.size());
stack.push_back(stack[bp + instr.stack_index]);
Expand Down Expand Up @@ -699,28 +758,33 @@ VMObject ValueToVM(Value value) {
return out[0];
}

Value VMToValue(VMObject obj) {
Value VMToValue(TagNameMap& tag_index_map, VMObject obj) {
switch (obj->tag) {
case VMObjectTag::kTensor: {
return TensorValueNode::make(ToNDArray(obj));
}
case VMObjectTag::kDatatype: {
LOG(FATAL) << "unsupported return value: data type";
return Value();
auto data_type = std::dynamic_pointer_cast<VMDatatypeCell>(obj.ptr);

tvm::Array<Value> fields;
for (size_t i = 0; i < data_type->fields.size(); ++i) {
fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
}

return ConValueNode::make(tag_index_map[data_type->tag], fields);
}
default:
LOG(FATAL) << "unsupported return value";
return Value();
}
}

VMObject EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<VMObject>& vm_args) {
std::tuple<VMObject, TagNameMap>
EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
const std::vector<VMObject>& vm_args) {
VirtualMachine vm = VirtualMachine::FromModule(module, ctxs);
std::cout << "--------------------------" << std::endl;
VMFunctionPrint(vm.functions[0]);
std::cout << "--------------------------" << std::endl;
return vm.Invoke(module->entry_func, vm_args);
std::cout << "Entry function is " << module->entry_func << std::endl;
return std::make_tuple(vm.Invoke(module->entry_func, vm_args), vm.tag_index_map);
}

TVM_REGISTER_API("relay._vm._ValueToVM")
Expand All @@ -730,7 +794,8 @@ TVM_REGISTER_API("relay._vm._ValueToVM")

TVM_REGISTER_API("relay._vm._VMToValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VMToValue(args[0]);
TagNameMap tag_index_map{};
*ret = VMToValue(tag_index_map, args[0]);
});

TVM_REGISTER_API("relay._vm._Tensor")
Expand Down Expand Up @@ -779,7 +844,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm")
vm_args.push_back(obj);
}
auto result = EvaluateModule(module, {ctx}, vm_args);
*ret = result;
*ret = VMToValue(std::get<1>(result), std::get<0>(result));
});


Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tvm import relay
from tvm.relay.vm import eval_vm, eta_expand
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.prelude import Prelude

def test_id():
x = relay.var('x', shape=(10, 10))
Expand Down Expand Up @@ -129,6 +130,22 @@ def test_tuple_second():
result = eval_vm(f, tvm.cpu(), (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)

def test_list_constructor():
    """Check that the VM can allocate ADT values: builds cons(nil(), nil())
    from the Prelude list type and evaluates it.

    NOTE(review): this test only exercises the allocation path; it prints the
    result but does not assert on its structure — TODO add an assertion once
    datatype values round-trip to Python.
    """
    mod = relay.Module()
    p = Prelude(mod)

    nil = p.nil
    cons = p.cons
    # (removed unused binding `l = p.l`)

    f = relay.Function([], cons(nil(), nil()))

    mod[mod.entry_func] = f

    print("Entry func is {}".format(mod.entry_func))
    result = eval_vm(mod, tvm.cpu())
    print("Result is {}".format(result))

def test_let_tensor():
sb = relay.ScopeBuilder()
shape = (1,)
Expand Down

0 comments on commit 7eead36

Please sign in to comment.