Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#6 from 0x45f/fix-ut
Browse files Browse the repository at this point in the history
Polish ut
  • Loading branch information
2742195759 authored Jun 15, 2023
2 parents d38a124 + fe39a99 commit 8b8623a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/dygraph_to_static/test_build_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from test_resnet import ResNetHelper

import paddle


@dy2static_unittest
class TestResnetWithPass(unittest.TestCase):
def setUp(self):
self.build_strategy = paddle.static.BuildStrategy()
Expand Down Expand Up @@ -64,6 +66,7 @@ def verify_predict(self):
),
)

@ast_only_test
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -77,6 +80,7 @@ def test_resnet(self):
)
self.verify_predict()

@ast_only_test
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand Down
8 changes: 8 additions & 0 deletions test/dygraph_to_static/test_convert_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
import paddle.jit.dy2static as _jst
Expand Down Expand Up @@ -252,6 +253,7 @@ def test_code(self):
)


@dy2static_unittest
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.net = NotToStaticHelper()
Expand All @@ -264,7 +266,9 @@ def test_conversion_options(self):
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)

@ast_only_test
def test_code(self):
self.dygraph_func = paddle.jit.to_static(self.net.sum)
# check 'if statement' is not converted
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)

Expand All @@ -277,19 +281,23 @@ def forward(self, x):
return x


@dy2static_unittest
class TestConvertPaddleAPI(unittest.TestCase):
@ast_only_test
def test_functional_api(self):
func = paddle.nn.functional.relu
func = paddle.jit.to_static(func)
self.assertNotIn("_jst.IfElse", func.code)
self.assertIn("if in_dynamic_mode()", func.code)

@ast_only_test
def test_class_api(self):
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if in_dynamic_mode()", bn.forward.code)

@ast_only_test
def test_class_patch_api(self):
paddle.nn.SyncBatchNorm.forward = forward
bn = paddle.nn.SyncBatchNorm(2)
Expand Down

0 comments on commit 8b8623a

Please sign in to comment.