Skip to content

Commit

Permalink
Bug fixing and hacking for Beacon
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and icemelon committed Mar 15, 2019
1 parent af77668 commit 56a8927
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def optimize(self, expr):
ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod)
simp_expr = ir_pass.simplify_inference(ck_expr)
ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_simp)
fused_expr = ir_pass.fuse_ops(ck_simp, 0, mod=self.mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused if isinstance(expr, Function) else Call(ck_fused, [])

Expand Down
30 changes: 28 additions & 2 deletions python/tvm/relay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,26 @@
from . import _make
from . import _module
from . import expr as _expr

from . import ty as _ty
from . import _ir_pass

# Refactor w/ VM code.
def eta_expand(expr, mod):
if isinstance(expr, _expr.GlobalVar):
ck_type = mod[expr].checked_type
else:
prev = mod[mod.entry_func]
mod[mod.entry_func]
expr = _ir_pass.infer_type(expr, mod)
ck_type = expr.checked_type

assert isinstance(ck_type, FuncType)

eta_args = []
for arg_type in ck_type.arg_types:
eta_args.append(var('a', type_annotation=arg_type))

return Function(eta_args, Call(expr, eta_args))

@register_relay_node
class Module(RelayNode):
Expand Down Expand Up @@ -61,9 +79,17 @@ def __setitem__(self, var, val):
return self._add(var, val)

def _add(self, var, val, update=False):
if isinstance(val, _expr.Function):
if isinstance(val, _expr.Expr):
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)

if not isinstance(val, _expr.Function):
if (val, _expr.GlobalVar):
eta_expand(val, self)
else:
val = _expr.Function([], val)


_make.Module_Add(self, var, val, update)
else:
assert isinstance(val, _ty.Type)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def eta_expand(expr, mod):
ck_type = expr.checked_type

assert isinstance(ck_type, FuncType)

eta_args = []
for arg_type in ck_type.arg_types:
eta_args.append(var('a', type_annotation=arg_type))
Expand Down
1 change: 1 addition & 0 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class RelayHashHandler:
}

for (auto t : call->type_args) {
CHECK(t.defined());
hash = Combine(hash, TypeHash(t));
}

Expand Down
12 changes: 12 additions & 0 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

void VisitExpr_(const CallNode* call) final {
// // The global var case should cause a split in the graph.
// if (auto call_node = call->op.as<GlobalVarNode>()) {
// // do not fuse through this call.
// this->Update(call->op, nullptr, kOpaque);
// for (auto arg : call->args) {
// this->Update(arg, nullptr, kOpaque);
// }
// ExprVisitor::VisitExpr_(call);
// this->AddNode(call);
// return;
// }

CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
static auto fpattern =
Expand Down
11 changes: 10 additions & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,16 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
return func->body;
}
} else {
auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr);
auto tmp = mod_ref->GetGlobalVar("temp");
Function body;
if (auto func_node = expr.as<FunctionNode>()) {
body = GetRef<Function>(func_node);
} else {
body = FunctionNode::make({}, body, Type(), {}, {});
}
mod_ref->AddUnchecked(tmp, body);
auto e = TypeInferencer(mod_ref, tmp).Infer(expr);
mod_ref->Remove(tmp);
CHECK(WellFormed(e));
auto free_tvars = FreeTypeVars(e, mod_ref);
CHECK(free_tvars.size() == 0)
Expand Down

0 comments on commit 56a8927

Please sign in to comment.