Skip to content

Commit

Permalink
fix no_grad argspec (#23790)
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
songyouwei authored Apr 14, 2020
1 parent 9549b78 commit 8f63a3e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/paddle/fluid/dygraph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib
import functools
import sys
Expand Down Expand Up @@ -196,12 +197,12 @@ def test_layer():
return _switch_tracer_mode_guard_(is_train=False)
else:

@functools.wraps(func)
def __impl__(*args, **kwargs):
@decorator.decorator
def __impl__(func, *args, **kwargs):
with _switch_tracer_mode_guard_(is_train=False):
return func(*args, **kwargs)

return __impl__
return __impl__(func)


@signature_safe_contextmanager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import unittest
import inspect

from test_imperative_base import new_program_scope

Expand Down Expand Up @@ -51,6 +52,14 @@ def test_main(self):
self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")

def need_no_grad_func(a, b=1):
return a + b

decorated_func = fluid.dygraph.no_grad(need_no_grad_func)
self.assertTrue(
str(inspect.getargspec(decorated_func)) ==
str(inspect.getargspec(need_no_grad_func)))

self.assertEqual(self.tracer._train_mode, self.init_mode)

with fluid.dygraph.guard():
Expand Down

0 comments on commit 8f63a3e

Please sign in to comment.