diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 201c6bac59f0..50987f9209f3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1227,7 +1227,7 @@ def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") if attr['data_format'] == 'NCHW': axis = 1 - if 'U' in attr: + if 'U' in attr and attr['U'].name != attr['T'].name: need_cast = True inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name) # Check if mean and variance are empty