Skip to content

Commit

Permalink
Tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 29, 2019
1 parent b67c442 commit dbea7e4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 65 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ def while_loop(cond, loop_vars, loop_bodies):

func = Function(fresh_vars, sb.get())
let = Let(wl, func, wl)
print(let)
return let
return ir_pass.infer_type(let)

def foreach(iter, init, body):
i = var("i", shape=(), dtype='int32')
Expand Down
9 changes: 0 additions & 9 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
// type inferencer will populate it up
std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;

// stores type hints
std::unordered_map<Expr, std::vector<Type>, NodeHash, NodeEqual> hint_map_;

// The solver used by the inferencer.
TypeSolver solver_;
// relation function
Expand Down Expand Up @@ -600,12 +597,6 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
<< PrettyPrint(type_hint->type_hint);
auto ty = GetType(type_hint->expr);
solver_.AddHint(ty, type_hint->type_hint);
auto it = this->hint_map_.find(type_hint->expr);
if (it == this->hint_map_.end()) {
this->hint_map_[type_hint->expr] = { type_hint->type_hint };
} else {
it->second.push_back(type_hint->type_hint);
}
return TupleTypeNode::make({});
}

Expand Down
106 changes: 52 additions & 54 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,55 @@ def _body(i, st):
# result = ex.evaluate(mod.entry_func)()
# np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))

def test_dynamic_concat_with_wrong_annotation():
"""
v0.0.1
fn (%start: int32) {
%7 = {
let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
%0 = less(%i, 10)
%1 = min(%0)
if (%1) {
%2 = add(%i, 1)
%3 = reshape(%i, newshape=[1, 1])
%4 = (%st, %3)
/* The result of concat should be 1,1 but it is 2, 1. */
%5 = concatenate(%4)
%while_loop(%2, %5)
} else {
(%i, %st)
}
}
%6 = reshape(0, newshape=[1, 1])
%while_loop(%start, %6)
}
%7.1
}
"""
# Initial Values.
i = relay.var('i', shape=(), dtype='int32')
st = relay.var('st', shape=(1, 1), dtype='int32')

def _cond(i, st):
return relay.op.min(relay.op.less(i, int32(10)))

def _body(i, st):
i_vec = relay.op.reshape(i, (1,1))
ret = relay.op.concatenate([st, i_vec], axis=0)
return i + int32(1), ret

loop = while_loop(_cond, [i, st], _body)
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
try:
func = relay.ir_pass.infer_type(func)
except Exception as e:
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
# def test_dynamic_concat_with_wrong_annotation():
# """
# v0.0.1
# fn (%start: int32) {
# %7 = {
# let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
# %0 = less(%i, 10)
# %1 = min(%0)
# if (%1) {
# %2 = add(%i, 1)
# %3 = reshape(%i, newshape=[1, 1])
# %4 = (%st, %3)
# /* The result of concat should be 1,1 but it is 2, 1. */
# %5 = concatenate(%4)
# %while_loop(%2, %5)
# } else {
# (%i, %st)
# }
# }
# %6 = reshape(0, newshape=[1, 1])
# %while_loop(%start, %6)
# }
# %7.1
# }
# """
# # Initial Values.
# i = relay.var('i', shape=(), dtype='int32')
# st = relay.var('st', shape=(1, 1), dtype='int32')

# def _cond(i, st):
# return relay.op.min(relay.op.less(i, int32(10)))

# def _body(i, st):
# i_vec = relay.op.reshape(i, (1,1))
# ret = relay.op.concatenate([st, i_vec], axis=0)
# return i + int32(1), ret

# loop = while_loop(_cond, [i, st], _body)
# loop = relay.ir_pass.infer_type(loop)
# start = relay.var('start', shape=(), dtype='int32')
# body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
# func = relay.Function([start], relay.TupleGetItem(body, 1))
# try:
# # import pdb; pdb.set_trace()
# func = relay.ir_pass.infer_type(func)
# import pdb; pdb.set_trace()
# assert False
# except Exception as e:
# assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)

def test_dynamic_concat_with_hint():
"""
Expand Down Expand Up @@ -146,16 +150,10 @@ def _body(i, st):
return i + int32(1), ret

loop = while_loop(_cond, [i, st], _body)
print(loop)
loop = relay.ir_pass.infer_type(loop)
start = relay.var('start', shape=(), dtype='int32')
print(loop)
# body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
# print(body)
# body = relay.ir_pass.infer_type(body)
# func = relay.Function([start], relay.TupleGetItem(body, 1))
# func = relay.ir_pass.infer_type(func)
# import pdb; pdb.set_trace()
st_shape = loop.value.params[1].type_annotation.shape
ret_shape = loop.value.ret_type.fields[1].shape
assert st_shape == ret_shape

if __name__ == "__main__":
test_arange_with_dynamic_shape()
Expand Down

0 comments on commit dbea7e4

Please sign in to comment.