diff --git a/python/paddle/jit/dy2static/function_spec.py b/python/paddle/jit/dy2static/function_spec.py index da82295409fee..cb40c3ae7d43e 100644 --- a/python/paddle/jit/dy2static/function_spec.py +++ b/python/paddle/jit/dy2static/function_spec.py @@ -373,7 +373,14 @@ def check_type_and_len(input, spec, check_length=False): ) real_spec.name = input_spec.name if spec_greater(input_spec, real_spec): - return input_spec + # change shape but keep the others (stop_gradient / dtype) . + real_spec.shape = input_spec.shape + else: + logging_utils.warn( + "input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format( + input_spec=input_spec, real_spec=real_spec + ) + ) return real_spec else: # NOTE(Aurelius84): Support non-Tensor type as input spec info @@ -480,8 +487,4 @@ def _shape_greater(first_shape, second_shape): return False return True - return ( - other.stop_gradient == first.stop_gradient - and other.dtype == first.dtype - and _shape_greater(first.shape, other.shape) - ) + return _shape_greater(first.shape, other.shape)