Skip to content

Commit

Permalink
[Relay] Make check stricter: disallow inserting function with free va…
Browse files Browse the repository at this point in the history
…rs into module. (apache#6313)

* save

lint

lint

fix test

fix test

* fix
  • Loading branch information
MarisaKirisame authored and Trevor Morris committed Aug 26, 2020
1 parent e4879e0 commit a828c78
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 185 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from . import op, transform

from .analysis import free_vars

def get_tensor_array_shape(expr, dtype, prelude):
"""Get the static shape of a tensor array if it has fixed rank shape.
Expand All @@ -51,7 +51,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
has dynamic shape.
"""
mod = prelude.mod
mod["main"] = Function([], expr)
mod["main"] = Function(free_vars(expr), expr)
mod = transform.InferType()(mod)
checked_type = mod["main"].body.checked_type
assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
Expand Down
14 changes: 4 additions & 10 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,10 @@ relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::F
// Type check the item before we add it to the module.
auto fv = relay::FreeVars(func);
auto ftv = relay::FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false)
<< std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING) << "There are free type variables: " << ftv
<< " in function: " << AsText(func, false) << std::endl;
}
func = relay::Function(concat(func->params, fv), func->body, func->ret_type,
concat(func->type_params, ftv), func->attrs);
CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
<< " in function: " << AsText(func, false);
CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
<< " in function: " << AsText(func, false);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
Expand Down
140 changes: 1 addition & 139 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3868,143 +3868,5 @@ def lstm_cell():
tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)


#######################################################################
# Main
# ----
if __name__ == '__main__':
# Transforms
test_forward_slice()
test_forward_transpose()
test_forward_reshape()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_squeeze()
test_forward_pack()
test_forward_size()
test_forward_broadcast_to()
test_forward_fill()
test_forward_crop()
test_forward_resize()
test_forward_crop_and_resize()
test_forward_pad()
test_forward_unpack()
test_forward_gather()
test_forward_gather_nd()
test_forward_stridedslice()
test_forward_split()
test_forward_unstack()
test_forward_tile()
test_forward_top_k_v2()
test_forward_clip_by_value()
test_forward_maximum()
test_forward_minimum()
test_forward_range()
test_forward_right_shift()
test_forward_left_shift()
test_forward_truncatemod()
test_forward_one_hot()
test_forward_atan2()
test_forward_nms()

# Activations
test_forward_sigmoid()
test_forward_relu()
test_forward_leaky_relu()
test_forward_elu()
test_forward_selu()
test_forward_tanh()

# Tensor
test_forward_round()
test_forward_reverse_v2()
test_forward_pow_exp()
test_forward_sign()
test_forward_negative()
test_forward_divide()
test_forward_abs()
test_forward_softplus()
test_forward_sqrt()
test_forward_rsqrt()
test_forward_expand_dims()
test_forward_square()
test_forward_softmax()
test_forward_log_softmax()
test_forward_bias_add()
test_forward_zeros_like()
test_forward_squared_difference()
test_forward_add_n()
test_forward_floormod()
test_forward_isfinite()
test_forward_isinf()
test_forward_unravel_index()
test_forward_unary()

# Reductions
test_forward_argminmax()
test_forward_reduce()
test_forward_mean()

# TensorArray
test_tensor_array_write_read()
test_tensor_array_concat()
test_tensor_array_scatter()
test_tensor_array_gather()
test_tensor_array_size()
test_tensor_array_split()
test_tensor_array_stack()
test_tensor_array_unstack()

# General
test_forward_multi_input()
test_forward_multi_output()
test_forward_variable()
test_placeholder()

# NN
test_forward_convolution()
test_forward_convolution3d()
test_forward_convolution3d_transpose()
test_forward_pooling()
test_forward_concat_v2()
test_forward_lrn()
test_forward_l2_normalize()
test_forward_space_to_batch_nd()
test_forward_batch_to_space_nd()
test_forward_dilation()

# End to End
test_forward_inception_v3()
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_resnetv2()
test_forward_ssd()
test_forward_placeholder()
test_forward_ptb()

# RNN
test_forward_lstm()

# Elementwise
test_forward_ceil()
test_forward_floor()

# Relational ops
test_forward_rel_ops()
test_forward_logical()
test_forward_where()
test_forward_matmul()
test_forward_batch_matmul()

# Internal misc. ops
test_read_variable_op()

# Sharing params case using Mean ops
test_sharing_node()

# StatefulPartitionedCall
test_forward_spop()

# Test dynamic input shape
test_forward_dynamic_input_shape()

test_forward_dynmaic_rnn_lstmblockcell()
pytest.main([__file__])
23 changes: 3 additions & 20 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"""Test that type checker correcly computes types
for expressions.
"""
import pytest
import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
from tvm.relay import Any


def run_infer_type(expr, mod=None):
if not mod:
mod = tvm.IRModule.from_expr(expr)
Expand Down Expand Up @@ -368,26 +368,9 @@ def test_if():
f = relay.Var('f', choice_t)
true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch))
top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
ft = run_infer_type(top)
tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))

if __name__ == "__main__":
test_free_expr()
test_dual_op()
test_single_op()
test_recursion()
test_monomorphic_let()
test_decl()
test_recursion()
test_tuple()
test_incomplete_call()
test_type_args()
test_global_var_recursion()
test_equal()
test_ref()
test_constructor_type()
test_constructor_call()
test_adt_match()
test_let_polymorphism()
test_if()
pytest.main([__file__])
16 changes: 3 additions & 13 deletions tests/python/relay/test_vm_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, missing-docstring, no-else-return
"""Unit tests for the Relay VM serialization and deserialization."""
import pytest
import numpy as np

import tvm
Expand Down Expand Up @@ -291,22 +292,11 @@ def test_vm_shape_of():

newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
args.append(np.array((1, -1), dtype='int64'))
main = relay.reshape(relu_x, newshape=newshape_var)
main = relay.Function([x, newshape_var], relay.reshape(relu_x, newshape=newshape_var))

res = get_serialized_output(main, *args).asnumpy()
tvm.testing.assert_allclose(res.flatten(), data.flatten())


if __name__ == "__main__":
test_serializer()
test_save_load()
test_const()
test_if()
test_loop()
test_tuple()
test_adt_list()
test_adt_compose()
test_closure()
test_synthetic()
test_mobilenet()
test_vm_shape_of()
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tutorials/dev/use_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def example():
z = relay.add(y, c)
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
return relay.Function([x], z2)
return relay.Function([x, weight], z2)

###############################################################################
# Let us register layout alteration for a conv2d op so that we can apply the
Expand Down

0 comments on commit a828c78

Please sign in to comment.