Skip to content

Commit

Permalink
[dy2static] fix the speed problem introduced by #50883 (#51606)
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 authored Mar 14, 2023
1 parent dca81a4 commit 46d6080
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/paddle/jit/dy2static/function_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,14 @@ def check_type_and_len(input, spec, check_length=False):
)
real_spec.name = input_spec.name
if spec_greater(input_spec, real_spec):
return input_spec
# change shape but keep the others (stop_gradient / dtype) .
real_spec.shape = input_spec.shape
else:
logging_utils.warn(
"input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format(
input_spec=input_spec, real_spec=real_spec
)
)
return real_spec
else:
# NOTE(Aurelius84): Support non-Tensor type as input spec info
Expand Down Expand Up @@ -480,8 +487,4 @@ def _shape_greater(first_shape, second_shape):
return False
return True

return (
other.stop_gradient == first.stop_gradient
and other.dtype == first.dtype
and _shape_greater(first.shape, other.shape)
)
return _shape_greater(first.shape, other.shape)

0 comments on commit 46d6080

Please sign in to comment.