diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 5b2ecc27b9987..1b7ed77e9b576 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -25,7 +25,7 @@ from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard from . import op, transform - +from .analysis import free_vars def get_tensor_array_shape(expr, dtype, prelude): """Get the static shape of a tensor array if it has fixed rank shape. @@ -51,7 +51,7 @@ def get_tensor_array_shape(expr, dtype, prelude): has dynamic shape. """ mod = prelude.mod - mod["main"] = Function([], expr) + mod["main"] = Function(free_vars(expr), expr) mod = transform.InferType()(mod) checked_type = mod["main"].body.checked_type assert isinstance(checked_type, TypeCall), "Input must be a tensor array." diff --git a/src/ir/module.cc b/src/ir/module.cc index b34740865fc60..bcab39aabf32e 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -189,16 +189,10 @@ relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::F // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); - if (fv.size() != 0) { - LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false) - << std::endl; - } - if (ftv.size() != 0) { - LOG(WARNING) << "There are free type variables: " << ftv - << " in function: " << AsText(func, false) << std::endl; - } - func = relay::Function(concat(func->params, fv), func->body, func->ret_type, - concat(func->type_params, ftv), func->attrs); + CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv + << " in function: " << AsText(func, false); + CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv + << " in function: " << AsText(func, false); // Type check the item before we add it to the module. relay::Function checked_func = InferType(func, mod, var); return checked_func; diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index cee583fe74cfb..97657a4afd93d 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3868,143 +3868,5 @@ def lstm_cell(): tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) -####################################################################### -# Main -# ---- if __name__ == '__main__': - # Transforms - test_forward_slice() - test_forward_transpose() - test_forward_reshape() - test_forward_depthtospace() - test_forward_spacetodepth() - test_forward_squeeze() - test_forward_pack() - test_forward_size() - test_forward_broadcast_to() - test_forward_fill() - test_forward_crop() - test_forward_resize() - test_forward_crop_and_resize() - test_forward_pad() - test_forward_unpack() - test_forward_gather() - test_forward_gather_nd() - test_forward_stridedslice() - test_forward_split() - test_forward_unstack() - test_forward_tile() - test_forward_top_k_v2() - test_forward_clip_by_value() - test_forward_maximum() - test_forward_minimum() - test_forward_range() - test_forward_right_shift() - test_forward_left_shift() - test_forward_truncatemod() - test_forward_one_hot() - test_forward_atan2() - test_forward_nms() - - # Activations - test_forward_sigmoid() - test_forward_relu() - test_forward_leaky_relu() - test_forward_elu() - test_forward_selu() - test_forward_tanh() - - # Tensor - test_forward_round() - test_forward_reverse_v2() - test_forward_pow_exp() - test_forward_sign() - test_forward_negative() - test_forward_divide() - test_forward_abs() - test_forward_softplus() - test_forward_sqrt() - test_forward_rsqrt() - test_forward_expand_dims() - test_forward_square() - test_forward_softmax() - test_forward_log_softmax() - test_forward_bias_add() - test_forward_zeros_like() - test_forward_squared_difference() - test_forward_add_n() - test_forward_floormod() - test_forward_isfinite() - test_forward_isinf() - test_forward_unravel_index() - test_forward_unary() - - # Reductions - test_forward_argminmax() - test_forward_reduce() - test_forward_mean() - - # TensorArray - test_tensor_array_write_read() - test_tensor_array_concat() - test_tensor_array_scatter() - test_tensor_array_gather() - test_tensor_array_size() - test_tensor_array_split() - test_tensor_array_stack() - test_tensor_array_unstack() - - # General - test_forward_multi_input() - test_forward_multi_output() - test_forward_variable() - test_placeholder() - - # NN - test_forward_convolution() - test_forward_convolution3d() - test_forward_convolution3d_transpose() - test_forward_pooling() - test_forward_concat_v2() - test_forward_lrn() - test_forward_l2_normalize() - test_forward_space_to_batch_nd() - test_forward_batch_to_space_nd() - test_forward_dilation() - - # End to End - test_forward_inception_v3() - test_forward_inception_v1() - test_forward_mobilenet() - test_forward_resnetv2() - test_forward_ssd() - test_forward_placeholder() - test_forward_ptb() - - # RNN - test_forward_lstm() - - # Elementwise - test_forward_ceil() - test_forward_floor() - - # Relational ops - test_forward_rel_ops() - test_forward_logical() - test_forward_where() - test_forward_matmul() - test_forward_batch_matmul() - - # Internal misc. ops - test_read_variable_op() - - # Sharing params case using Mean ops - test_sharing_node() - - # StatefulPartitionedCall - test_forward_spop() - - # Test dynamic input shape - test_forward_dynamic_input_shape() - - test_forward_dynmaic_rnn_lstmblockcell() + pytest.main([__file__]) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index e5082db25cbca..cc4748c92b002 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -17,13 +17,13 @@ """Test that type checker correcly computes types for expressions. """ +import pytest import tvm from tvm import te from tvm import relay from tvm.relay import op, transform, analysis from tvm.relay import Any - def run_infer_type(expr, mod=None): if not mod: mod = tvm.IRModule.from_expr(expr) @@ -368,26 +368,9 @@ def test_if(): f = relay.Var('f', choice_t) true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32')) false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32')) - top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch)) + top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch)) ft = run_infer_type(top) tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32')) if __name__ == "__main__": - test_free_expr() - test_dual_op() - test_single_op() - test_recursion() - test_monomorphic_let() - test_decl() - test_recursion() - test_tuple() - test_incomplete_call() - test_type_args() - test_global_var_recursion() - test_equal() - test_ref() - test_constructor_type() - test_constructor_call() - test_adt_match() - test_let_polymorphism() - test_if() + pytest.main([__file__]) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 2aae431b2a944..df3bbc19cb58c 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, missing-docstring, no-else-return """Unit tests for the Relay VM serialization and deserialization.""" +import pytest import numpy as np import tvm @@ -291,22 +292,11 @@ def test_vm_shape_of(): newshape_var = relay.var('newshape', shape=(2,), dtype='int64') args.append(np.array((1, -1), dtype='int64')) - main = relay.reshape(relu_x, newshape=newshape_var) + main = relay.Function([x, newshape_var], relay.reshape(relu_x, newshape=newshape_var)) res = get_serialized_output(main, *args).asnumpy() tvm.testing.assert_allclose(res.flatten(), data.flatten()) if __name__ == "__main__": - test_serializer() - test_save_load() - test_const() - test_if() - test_loop() - test_tuple() - test_adt_list() - test_adt_compose() - test_closure() - test_synthetic() - test_mobilenet() - test_vm_shape_of() + pytest.main([__file__]) diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 821233446fb20..4b842b90995e7 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -65,7 +65,7 @@ def example(): z = relay.add(y, c) z1 = relay.add(y, c) z2 = relay.add(z, z1) - return relay.Function([x], z2) + return relay.Function([x, weight], z2) ############################################################################### # Let us register layout alteration for a conv2d op so that we can apply the