Skip to content

Commit

Permalink
fix batchnorm infer_value error, add regression test and unit test (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored Jun 18, 2020
1 parent eacfe89 commit d85efa1
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,13 +1944,11 @@ def infer_value(self, input_val, params, mod=None):
self._infer_simulated = False
self._mod = mod
return self.visit(input_val).data
#return _infer_value(input_val, params, mod)

def infer_value_simulated(self, input_val, params):
self._tmp_params = params
self._infer_simulated = True
return self.visit(input_val).data
#return _infer_value_simulated(input_val, params)

def infer(self, expr):
if self._infer_simulated:
Expand Down Expand Up @@ -1978,7 +1976,10 @@ def visit_let(self, let):
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return self.infer(Call(new_fn, new_args, call.attrs))
call = Call(new_fn, new_args, call.attrs)
if new_fn == _op.get("nn.batch_norm"):
return call
return self.infer(call)

def visit_var(self, var):
return self.infer(var)
Expand Down
85 changes: 85 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,89 @@ def test_or():
verify_or(indata=[x, y], dtype=bool)


def test_batch_norm():
def verify_batch_norm(in_shape):
batchnorm = onnx.helper.make_node('BatchNormalization',
inputs=["x", "scale", "B", "mean", "var"],
outputs=['Y'])

graph = helper.make_graph([batchnorm],
"batchnorm_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("scale",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("B",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("mean",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("var",
TensorProto.FLOAT, [in_shape[1]]),
],
outputs=[helper.make_tensor_value_info("Y",
TensorProto.FLOAT, list(in_shape))])

model = helper.make_model(graph, producer_name='batchnorm_test')

for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
scale = np.random.uniform(size=in_shape[1]).astype('float32')
b = np.random.uniform(size=in_shape[1]).astype('float32')
mean = np.random.uniform(size=in_shape[1]).astype('float32')
var = np.random.uniform(size=in_shape[1]).astype('float32')
onnx_out = get_onnxruntime_output(model, [x, scale, b, mean, var], 'float32')[0]
tvm_out = get_tvm_output(model, [x, scale, b, mean, var], target, ctx, in_shape, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

verify_batch_norm([1, 3, 224, 224])
verify_batch_norm([1, 3, 24, 24])
verify_batch_norm([16, 3, 24, 24])
verify_batch_norm([16, 16, 24, 24])
verify_batch_norm([16, 16, 10, 10])


def test_batch_norm_dynamic_subgraph():
def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):
batchnorm = onnx.helper.make_node('BatchNormalization',
inputs=["x", "scale", "B", "mean", "var"],
outputs=['Y'])

shape_node = helper.make_node("Shape", ['Y'], ['shape'])
reshape_node = helper.make_node("Reshape", ["in", "shape"], ["out"])
graph = helper.make_graph([batchnorm, shape_node, reshape_node],
"batchnorm_test",
inputs=[helper.make_tensor_value_info("x",
TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in",
TensorProto.FLOAT, list(o_shape)),
helper.make_tensor_value_info("scale",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("B",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("mean",
TensorProto.FLOAT, [in_shape[1]]),
helper.make_tensor_value_info("var",
TensorProto.FLOAT, [in_shape[1]]),
],
outputs=[helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(in_shape))])

model = helper.make_model(graph, producer_name='batchnorm_test')

for target, ctx in ctx_list():
x = np.random.uniform(size=in_shape).astype('float32')
inp = np.random.uniform(size=o_shape).astype('float32')
scale = np.random.uniform(size=in_shape[1]).astype('float32')
b = np.random.uniform(size=in_shape[1]).astype('float32')
mean = np.random.uniform(size=in_shape[1]).astype('float32')
var = np.random.uniform(size=in_shape[1]).astype('float32')
onnx_out = get_onnxruntime_output(model, [x, inp, scale, b, mean, var], 'float32')[0]
tvm_out = get_tvm_output(model, [x, inp, scale, b, mean, var], target, ctx, in_shape, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])


def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False):
if unset_pad:
node = helper.make_node('Conv',
Expand Down Expand Up @@ -2892,6 +2975,8 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_or()
test_depth_to_space()
test_space_to_depth()
test_batch_norm()
test_batch_norm_dynamic_subgraph()
test_conv()
test_convtranspose()
test_unsqueeze_constant()
Expand Down

0 comments on commit d85efa1

Please sign in to comment.