diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py index 39f4504375467..14e29b9ef4508 100644 --- a/test/dygraph_to_static/test_build_strategy.py +++ b/test/dygraph_to_static/test_build_strategy.py @@ -15,11 +15,13 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest from test_resnet import ResNetHelper import paddle +@dy2static_unittest class TestResnetWithPass(unittest.TestCase): def setUp(self): self.build_strategy = paddle.static.BuildStrategy() @@ -64,6 +66,7 @@ def verify_predict(self): ), ) + @ast_only_test def test_resnet(self): static_loss = self.train(to_static=True) dygraph_loss = self.train(to_static=False) @@ -77,6 +80,7 @@ def test_resnet(self): ) self.verify_predict() + @ast_only_test def test_in_static_mode_mkldnn(self): paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) try: diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 11f947d183243..7b54ea5956134 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle import paddle.jit.dy2static as _jst @@ -252,6 +253,7 @@ def test_code(self): ) +@dy2static_unittest class TestNotToConvert2(TestRecursiveCall2): def set_func(self): self.net = NotToStaticHelper() @@ -264,7 +266,9 @@ def test_conversion_options(self): self.assertIsNotNone(options) self.assertTrue(options.not_convert) + @ast_only_test def test_code(self): + self.dygraph_func = paddle.jit.to_static(self.net.sum) # check 'if statement' is not converted self.assertIn("if x.shape[0] > 1", self.dygraph_func.code) @@ -277,19 +281,23 @@ def forward(self, x): return x +@dy2static_unittest class TestConvertPaddleAPI(unittest.TestCase): + @ast_only_test def test_functional_api(self): func = paddle.nn.functional.relu func = paddle.jit.to_static(func) self.assertNotIn("_jst.IfElse", func.code) self.assertIn("if in_dynamic_mode()", func.code) + @ast_only_test def test_class_api(self): bn = paddle.nn.SyncBatchNorm(2) paddle.jit.to_static(bn) self.assertNotIn("_jst.IfElse", bn.forward.code) self.assertIn("if in_dynamic_mode()", bn.forward.code) + @ast_only_test def test_class_patch_api(self): paddle.nn.SyncBatchNorm.forward = forward bn = paddle.nn.SyncBatchNorm(2)