diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 5642c4122580..97b9d7a44997 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -913,9 +913,9 @@ def _mx_amp_multicast(inputs, attrs): has_float32 = any(x == "float32" for x in dtypes) dtype = dtypes[0] if cast_narrow and has_float16: - dtype = 'float16' + dtype = 'float16' if not cast_narrow and has_float32: - dtype = 'float32' + dtype = 'float32' return [_op.cast(x, dtype) for x in inputs] def _mx_grid_generator(inputs, attrs):