First version that works with hints, needs clean up
jroesch committed May 29, 2019
1 parent e0e7d4c commit b3fe5da
Showing 6 changed files with 362 additions and 233 deletions.
46 changes: 17 additions & 29 deletions python/tvm/relay/__init__.py
@@ -127,35 +127,26 @@
def int32(val):
return relay.const(val, 'int32')

# TODO(@jroesch): possibly move to their own file
def while_loop(cond, loop_vars, loop_bodies):
sb = ScopeBuilder()
wl = Var("while_loop")
with sb.if_scope(cond(*loop_vars)):
sb.ret(wl(*loop_bodies(*loop_vars)))
with sb.else_scope():
sb.ret(Tuple(loop_vars))

# return wl, Function(loop_vars, sb.get())
# Return this decision
def _while_loop(*args):
return Let(
wl, Function(loop_vars, sb.get()),
wl(*args))

return _while_loop
def type_of(sb, expr):
if isinstance(expr, Var):
return expr.type_annotation
else:
it = IncompleteType()
v = var("unify", it)
sb.let(v, expr)
return it

# TODO(@jroesch): possibly move to their own file
def while_loop2(cond, loop_vars, loop_bodies):
def while_loop(cond, loop_vars, loop_bodies):
sb = ScopeBuilder()
wl = Var("while_loop")
type_hints = []
fresh_vars = []
for lv in loop_vars:
assert isinstance(lv, Var)
v = var(lv.name_hint, type_annotation=IncompleteType())
for i, lv in enumerate(loop_vars):
name = lv.name_hint if isinstance(lv, Var) else "arg{}".format(i)
v = var(name, type_annotation=IncompleteType())
fresh_vars.append(v)
type_hints.append(TypeHint(v, lv.type_annotation))
type_hints.append(TypeHint(v, type_of(sb, lv)))

for i, hint in enumerate(type_hints):
var_hint = var("hint{}".format(i))
@@ -166,13 +157,10 @@ def while_loop2(cond, loop_vars, loop_bodies):
with sb.else_scope():
sb.ret(Tuple(fresh_vars))

# return wl, Function(fresh_vars, sb.get())
def _while_loop(*args):
return Let(
wl, Function(fresh_vars, sb.get()),
wl(*args))

return _while_loop
func = Function(fresh_vars, sb.get())
let = Let(wl, func, wl)
return let

def foreach(iter, init, body):
i = var("i", shape=(), dtype='int32')
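A minimal usage sketch of the new hint-based while_loop (illustrative only, not part of this commit; it assumes var, op, int32, and while_loop from this module are in scope): build a loop that increments i until it reaches 10, then apply it to an initial value.

# Hypothetical example: count i from 0 up to 10 with the hint-based while_loop.
i = var("i", shape=(), dtype='int32')

def _cond(i):
    return op.min(op.less(i, int32(10)))

loop = while_loop(_cond, [i], lambda i: [i + int32(1)])
# while_loop now returns the Let expression itself, so applying it
# builds the initial call of the recursive loop function.
result = loop(int32(0))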
72 changes: 72 additions & 0 deletions python/tvm/relay/scope_builder.py
@@ -43,6 +43,9 @@ def __exit__(self, ptype, value, trace):
self._exit_cb()


def int32(val):
return _expr.const(val, 'int32')

def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
@@ -176,6 +179,75 @@ def _on_exit():
false_branch)
return WithScope(None, _on_exit)


def type_of(self, expr):
if isinstance(expr, _expr.Var):
return expr.type_annotation
else:
it = _ty.IncompleteType()
v = _expr.var("unify", it)
self.let(v, expr)
return it

# def while_loop(self, cond, loop_vars, loop_bodies):
# self._enter_scope()
# wl = _expr_Var("while_loop")

# with self.if_scope(cond(*loop_vars)):
# self.ret(wl(*loop_bodies(*loop_vars)))
# with sb.else_scope():
# self.ret(Tuple(loop_vars))

# def _on_exit():
# bindings, ret_value = self._exit_scope()
# _make_lets(bindings,)
# self._ret_values[-1] = _expr.Let(
# wl,
# Function(loop_vars, sb.get()),
# wl(loop_vars)

# return WithScope(None, _on_exit)

# # TODO(@jroesch): possibly move to their own file
# def while_loop2(cond, loop_vars, loop_bodies):
# sb = ScopeBuilder()
# wl = Var("while_loop")
# type_hints = []
# fresh_vars = []
# for i, lv in enumerate(loop_vars):
# name = lv.name_hint if isinstance(lv, Var) else "arg{}".format(i)
# v = var(name, type_annotation=IncompleteType())
# fresh_vars.append(v)
# type_hints.append(TypeHint(v, type_of(sb, lv)))

# for i, hint in enumerate(type_hints):
# var_hint = var("hint{}".format(i))
# sb.let(var_hint, hint)

# with sb.if_scope(cond(*fresh_vars)):
# sb.ret(wl(*loop_bodies(*fresh_vars)))
# with sb.else_scope():
# sb.ret(Tuple(fresh_vars))

# # return wl, Function(fresh_vars, sb.get())
# def _while_loop(*args):
# return Let(
# wl, Function(fresh_vars, sb.get()),
# wl(*args))

# return _while_loop

# def foreach(iter, init, body):
# i = var("i", shape=(), dtype='int32')
# st = var("st", type_annotation=relay.IncompleteType())
# update = body(i, st)
# dim = take(op.shape_of(iter), indices=i, axis=0)
# def _cond(i, st):
# return op.min(op.less(i, dim))
# loop = while_loop(
# _cond, [i, st], [i + int32(1), update])
# return loop(int32(0), init)

def ret(self, value):
"""Set the return value of this scope.
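A sketch of what the new type_of helper provides (illustrative only, assuming the relay Python API used above): for a Var it simply reads the existing type annotation, while for any other expression it let-binds the value to a fresh variable carrying an IncompleteType, giving the hint machinery a type variable to unify against later.

# Hypothetical example of type_of inside a ScopeBuilder.
sb = ScopeBuilder()
x = var("x", shape=(), dtype='int32')
t1 = sb.type_of(x)      # x is a Var: returns its type_annotation directly
t2 = sb.type_of(x + x)  # not a Var: let-binds x + x, returns a fresh IncompleteType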
168 changes: 66 additions & 102 deletions src/relay/pass/type_infer.cc
@@ -231,7 +231,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
return op->tensor_type();
}

Type VisitExpr_(const TupleNode* op) final {
/* Type VisitExpr_(const TupleNode* op) final {
if (!make_tuple_rel_.defined()) {
make_tuple_rel_ = TypeRelationFn(
EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_);
@@ -245,6 +245,14 @@
solver_.AddConstraint(TypeRelationNode::make(
make_tuple_rel_, types, op->fields.size(), {}), GetRef<Tuple>(op));
return rtype;
} */

Type VisitExpr_(const TupleNode* op) final {
Array<Type> types;
for (Expr field : op->fields) {
types.push_back(GetType(field));
}
return TupleTypeNode::make(types);
}

Type VisitExpr_(const TupleGetItemNode* op) final {
@@ -335,6 +343,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type let_type = IncompleteTypeNode::make(Kind::kType);

if (is_functional_literal) {
let_type = GetType(let->var);
type_map_[let->var].checked_type = let_type;
}

@@ -589,6 +598,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
type_hint->type_hint.as<IncompleteTypeNode>() != nullptr)
<< "type hints can only be applied to tensor types: "
<< PrettyPrint(type_hint->type_hint);
auto ty = GetType(type_hint->expr);
solver_.AddHint(ty, type_hint->type_hint);
auto it = this->hint_map_.find(type_hint->expr);
if (it == this->hint_map_.end()) {
this->hint_map_[type_hint->expr] = { type_hint->type_hint };
@@ -598,95 +609,55 @@
return TupleTypeNode::make({});
}

Type GeneralizeShape(const std::vector<TensorType>& concrete_types) {
if (concrete_types.size() >= 1) {
CHECK(concrete_types[0].defined());
std::cout << "shape" << concrete_types[0];
Shape shape = concrete_types[0]->shape;
DataType dtype = concrete_types[0]->dtype;
for (auto ty : concrete_types) {
std::cout << ty;
CHECK_EQ(dtype, ty->dtype);
CHECK_EQ(shape.size(), ty->shape.size());
for (size_t i = 0; i < shape.size(); i++) {
auto ldim = shape[i];
auto rdim = ty->shape[i];

std::cout << "dim1: " << ldim << "dim2: " << rdim;

if (ldim.same_as(rdim)) {
continue;
}

if (ldim.same_as(Any()) || rdim.same_as(Any())) {
continue;
}

auto left_is_int = ldim.as<tvm::IntImm>();
auto right_is_int = rdim.as<tvm::IntImm>();
if (left_is_int && right_is_int && left_is_int->value != right_is_int->value) {
shape.Set(i, Any());
continue;
}

auto left_is_var = ldim.as<tvm::Variable>();
auto right_is_var = rdim.as<tvm::Variable>();
if (left_is_var && right_is_var &&
!GetRef<tvm::Var>(right_is_var).same_as(GetRef<tvm::Var>(left_is_var))) {
shape.Set(i, Any());
continue;
}
}
LOG(INFO) << "shape after: " << shape;
}
return TensorTypeNode::make(shape, dtype);
} else {
return IncompleteTypeNode::make(Kind::kType);
}
}

void ApplyTypeHints() {
for (auto hint : hint_map_) {
auto expr = hint.first;
auto hints = hint.second;
LOG(INFO) << "type hint processing: " << expr;
// Each hint represents a possible type
// for the expression.
//
// In order to unify types we must
// split all concrete type information
// to construct the most general type.
//
// We will then unify the result with
// any hint types which are non-concrete.
std::vector<TensorType> concrete_types;
std::vector<Type> to_unify;
for (auto ty : hints) {
auto rty = solver_.Resolve(ty);
if (rty.as<IncompleteTypeNode>() != nullptr) {
to_unify.push_back(rty);
} else {
concrete_types.push_back(Downcast<TensorType>(rty));
}
}

std::cout << "above generalize";
// Add the generalized shape.
Type general_shape = GeneralizeShape(concrete_types);
std::cout << "Generalized Shape: " << general_shape;
for (auto ty : to_unify) {
CHECK(general_shape.defined());
general_shape = Unify(general_shape, ty, expr);
}
// void ApplyTypeHints() {
// for (auto hint : hint_map_) {
// auto expr = hint.first;
// auto hints = hint.second;
// LOG(INFO) << "type hint processing: " << expr;
// // Each hint represents a possible type
// // for the expression.
// //
// // In order to unify types we must
// // split all concrete type information
// // to construct the most general type.
// //
// // We will then unify the result with
// // any hint types which are non-concrete.
// std::vector<TensorType> concrete_types;
// std::vector<Type> to_unify;
// for (auto ty : hints) {
// auto rty = solver_.Resolve(ty);
// if (rty.as<IncompleteTypeNode>() != nullptr) {
// to_unify.push_back(rty);
// } else {
// concrete_types.push_back(Downcast<TensorType>(rty));
// }
// }

// std::cout << "above generalize";
// // Add the generalized shape.
// Type general_shape = GeneralizeShape(concrete_types);
// std::cout << "Generalized Shape: " << general_shape;
// for (auto ty : to_unify) {
// CHECK(general_shape.defined());
// general_shape = Unify(general_shape, ty, expr);
// }

// auto it = type_map_.find(expr);
// CHECK(it != type_map_.end());
// auto rtype = solver_.Resolve(it->second.checked_type);
// if (rtype.as<IncompleteTypeNode>() == nullptr) {
// it->second.checked_type = general_shape;
// } else {
// it->second.checked_type = Unify(rtype, general_shape, expr);
// }
// }
// }
void Solve() {
solver_.Solve();

auto it = type_map_.find(expr);
CHECK(it != type_map_.end());
auto rtype = solver_.Resolve(it->second.checked_type);
if (rtype.as<IncompleteTypeNode>() == nullptr) {
it->second.checked_type = general_shape;
} else {
it->second.checked_type = Unify(rtype, general_shape, expr);
}
if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}
}
};
@@ -798,18 +769,18 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
// check if we need update the new_e
bool need_update_type = !checked_type.same_as(new_e->checked_type_);
bool need_update_call = (
std::is_base_of<CallNode, T>::value); /* &&
std::is_base_of<CallNode, T>::value &&
it->second.type_args.defined() &&
!it->second.type_args.same_as(new_call->type_args)) */
!it->second.type_args.same_as(new_call->type_args));
bool need_update_var = (
std::is_base_of<VarNode, T>::value &&
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());

bool need_update_fn = (
std::is_base_of<FunctionNode, T>::value &&
bool need_update_fn = (
std::is_base_of<FunctionNode, T>::value); /* &&
update_missing_type_annotation_ &&
!new_fn->ret_type.defined());
!new_fn->ret_type.defined()); */

if (!need_update_type &&
!need_update_var &&
@@ -875,14 +846,7 @@ Expr TypeInferencer::Infer(Expr expr) {
GetType(expr);

// Step 2: Solve the constraints.
solver_.Solve();

// Step 1: Apply type hints.
ApplyTypeHints();

if (err_reporter.AnyErrors()) {
err_reporter.RenderErrors(mod_);
}
Solve();

// Step 3: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
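For reference, the widening rule that GeneralizeShape implemented (and that the commented-out ApplyTypeHints relied on) is: dtypes and ranks must match, dimensions on which all concrete hint types agree are kept, and any disagreement (two different constants, or two different symbolic variables) widens that dimension to Any. A standalone Python sketch of the rule (illustrative only, not the C++ implementation):

# Illustrative sketch of the GeneralizeShape rule; ANY stands in for tvm's Any().
ANY = "?"

def generalize_shapes(shapes):
    # All shapes must share a rank; keep dimensions that agree, widen the rest.
    assert shapes and all(len(s) == len(shapes[0]) for s in shapes)
    result = list(shapes[0])
    for shape in shapes[1:]:
        for i, (left, right) in enumerate(zip(result, shape)):
            if left != right:
                result[i] = ANY
    return result

# generalize_shapes([(3, 4), (5, 4)]) == [ANY, 4]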
(Diffs for the remaining three changed files are not rendered in this view.)