diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 05a067d3ff14..224428bcc8bb 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1944,13 +1944,11 @@ def infer_value(self, input_val, params, mod=None): self._infer_simulated = False self._mod = mod return self.visit(input_val).data - #return _infer_value(input_val, params, mod) def infer_value_simulated(self, input_val, params): self._tmp_params = params self._infer_simulated = True return self.visit(input_val).data - #return _infer_value_simulated(input_val, params) def infer(self, expr): if self._infer_simulated: @@ -1978,7 +1976,10 @@ def visit_let(self, let): def visit_call(self, call): new_fn = self.visit(call.op) new_args = [self.visit(arg) for arg in call.args] - return self.infer(Call(new_fn, new_args, call.attrs)) + call = Call(new_fn, new_args, call.attrs) + if new_fn == _op.get("nn.batch_norm"): + return call + return self.infer(call) def visit_var(self, var): return self.infer(var) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a82f1a52f548..d9c481a20805 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2020,6 +2020,89 @@ def test_or(): verify_or(indata=[x, y], dtype=bool) +def test_batch_norm(): + def verify_batch_norm(in_shape): + batchnorm = onnx.helper.make_node('BatchNormalization', + inputs=["x", "scale", "B", "mean", "var"], + outputs=['Y']) + + graph = helper.make_graph([batchnorm], + "batchnorm_test", + inputs=[helper.make_tensor_value_info("x", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("scale", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("B", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("mean", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("var", + TensorProto.FLOAT, [in_shape[1]]), + ], + outputs=[helper.make_tensor_value_info("Y", + TensorProto.FLOAT, list(in_shape))]) + + model = helper.make_model(graph, producer_name='batchnorm_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('float32') + scale = np.random.uniform(size=in_shape[1]).astype('float32') + b = np.random.uniform(size=in_shape[1]).astype('float32') + mean = np.random.uniform(size=in_shape[1]).astype('float32') + var = np.random.uniform(size=in_shape[1]).astype('float32') + onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], 'float32')[0] + tvm_out = get_tvm_output(model, [x, scale, b, mean, var], target, ctx, in_shape, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + + verify_batch_norm([1, 3, 224, 224]) + verify_batch_norm([1, 3, 24, 24]) + verify_batch_norm([16, 3, 24, 24]) + verify_batch_norm([16, 16, 24, 24]) + verify_batch_norm([16, 16, 10, 10]) + + +def test_batch_norm_dynamic_subgraph(): + def verify_batch_norm_dynamic_subgraph(in_shape, o_shape): + batchnorm = onnx.helper.make_node('BatchNormalization', + inputs=["x", "scale", "B", "mean", "var"], + outputs=['Y']) + + shape_node = helper.make_node("Shape", ['Y'], ['shape']) + reshape_node = helper.make_node("Reshape", ["in", "shape"], ["out"]) + graph = helper.make_graph([batchnorm, shape_node, reshape_node], + "batchnorm_test", + inputs=[helper.make_tensor_value_info("x", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(o_shape)), + helper.make_tensor_value_info("scale", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("B", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("mean", + TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("var", + TensorProto.FLOAT, [in_shape[1]]), + ], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(in_shape))]) + + model = helper.make_model(graph, producer_name='batchnorm_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('float32') + inp = np.random.uniform(size=o_shape).astype('float32') + scale = np.random.uniform(size=in_shape[1]).astype('float32') + b = np.random.uniform(size=in_shape[1]).astype('float32') + mean = np.random.uniform(size=in_shape[1]).astype('float32') + var = np.random.uniform(size=in_shape[1]).astype('float32') + onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0] + tvm_out = get_tvm_output(model, [x, inp, scale, b, mean, var], target, ctx, in_shape, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + + verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160]) + + def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False): if unset_pad: node = helper.make_node('Conv', @@ -2892,6 +2975,8 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_or() test_depth_to_space() test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() test_conv() test_convtranspose() test_unsqueeze_constant()