Skip to content

Commit

Permalink
[Dy2St] pir dy2st unittest verification - Part -2 (#59370)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: SigureMo <[email protected]>
  • Loading branch information
DrRyanHuang and SigureMo authored Nov 29, 2023
1 parent 4f23e7e commit 3b40279
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 39 deletions.
31 changes: 18 additions & 13 deletions test/dygraph_to_static/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,17 @@
import numpy
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
)

import paddle
from paddle import base
from paddle.jit.api import to_static


@paddle.jit.to_static
def dyfunc_assert_variable(x):
x_v = base.dygraph.to_variable(x)
assert x_v


@to_static
def dyfunc_assert_non_variable(x=True):
assert x

Expand All @@ -51,31 +47,40 @@ def _run_dy_static(self, func, x, with_exception):
self._run(func, x, with_exception, True)
self._run(func, x, with_exception, False)

@test_ast_only
def test_non_variable(self):
self._run_dy_static(
dyfunc_assert_non_variable, x=False, with_exception=True
paddle.jit.to_static(dyfunc_assert_non_variable),
x=False,
with_exception=True,
)
self._run_dy_static(
dyfunc_assert_non_variable, x=True, with_exception=False
paddle.jit.to_static(dyfunc_assert_non_variable),
x=True,
with_exception=False,
)

@test_ast_only
def test_bool_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([False]), with_exception=True
paddle.jit.to_static(dyfunc_assert_variable),
x=numpy.array([False]),
with_exception=True,
)
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([True]), with_exception=False
paddle.jit.to_static(dyfunc_assert_variable),
x=numpy.array([True]),
with_exception=False,
)

@test_ast_only
def test_int_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([0]), with_exception=True
paddle.jit.to_static(dyfunc_assert_variable),
x=numpy.array([0]),
with_exception=True,
)
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([1]), with_exception=False
paddle.jit.to_static(dyfunc_assert_variable),
x=numpy.array([1]),
with_exception=False,
)


Expand Down
29 changes: 17 additions & 12 deletions test/dygraph_to_static/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)

import paddle
from paddle import base
from paddle.jit import to_static

PLACE = base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()

Expand Down Expand Up @@ -79,7 +81,6 @@ def __init__(self, batch_size=64, hidden_size=16, output_size=16):
self.output_size = output_size
self.sub_net = SubNetWithDict(hidden_size, output_size)

@to_static
def forward(self, input, max_len=4):
input = base.dygraph.to_variable(input)
cache = {
Expand Down Expand Up @@ -135,17 +136,19 @@ def _run_dygraph(self):
def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
with base.dygraph.guard(PLACE):
net = MainNetWithDict(batch_size=self.batch_size)
net = paddle.jit.to_static(
MainNetWithDict(batch_size=self.batch_size)
)
ret = net(self.x)
return ret.numpy()

@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
self.assertTrue((self._run_dygraph() == self._run_static()).all())


# Tests for dict pop
@paddle.jit.to_static
def test_dic_pop(x):
def test_dict_pop(x):
x = paddle.to_tensor(x)
dict_a = {"red": 0, "green": 1, "blue": 2}

Expand All @@ -156,8 +159,7 @@ def test_dic_pop(x):
return out


@paddle.jit.to_static
def test_dic_pop_2(x):
def test_dict_pop_2(x):
x = paddle.to_tensor(x)
dict_a = {"red": x, "green": x + 1, "blue": x + 3}

Expand All @@ -179,7 +181,7 @@ def setUp(self):
self._set_test_func()

def _set_test_func(self):
self.dygraph_func = test_dic_pop
self.dygraph_func = paddle.jit.to_static(test_dict_pop)

def _run_static(self):
return self._run(to_static=True)
Expand All @@ -194,6 +196,7 @@ def _run(self, to_static):

return result.numpy()

@test_legacy_and_pt_and_pir
def test_transformed_result(self):
dygraph_res = self._run_dygraph()
static_res = self._run_static()
Expand All @@ -207,14 +210,13 @@ def test_transformed_result(self):

class TestDictPop2(TestDictPop):
def _set_test_func(self):
self.dygraph_func = test_dic_pop_2
self.dygraph_func = paddle.jit.to_static(test_dict_pop_2)


class NetWithDictPop(paddle.nn.Layer):
def __init__(self):
super().__init__()

@to_static
def forward(self, x, **kwargs):
x = paddle.to_tensor(x)
y = kwargs.pop('y', None)
Expand All @@ -233,10 +235,11 @@ def setUp(self):
def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)
with base.dygraph.guard(PLACE):
net = NetWithDictPop()
net = paddle.jit.to_static(NetWithDictPop())
ret = net(z=0, x=self.x, y=True)
return ret.numpy()

@test_legacy_and_pt_and_pir
def test_ast_to_func(self):
dygraph_result = self._run_dygraph()
static_result = self._run_static()
Expand All @@ -248,6 +251,7 @@ def test_ast_to_func(self):


class TestDictCmpInFor(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_with_for(self):
def func():
pos = [1, 3]
Expand All @@ -264,6 +268,7 @@ def func():

self.assertEqual(paddle.jit.to_static(func)()['minus'], 8)

@test_legacy_and_pt_and_pir
def test_with_for_enumerate(self):
def func():
pos = [1, 3]
Expand Down
38 changes: 24 additions & 14 deletions test/dygraph_to_static/test_isinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,34 @@ def train(model, to_static):
class TestIsinstance(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_isinstance_simple_return_layer(self):
model = paddle.jit.to_static(IsInstanceLayer(SimpleReturnLayer()))
self._test_model(model)
model_creator = lambda: paddle.jit.to_static(
IsInstanceLayer(SimpleReturnLayer())
)
self._test_model(model_creator)

@test_legacy_and_pt_and_pir
def test_isinstance_add_attr_layer(self):
model = paddle.jit.to_static(IsInstanceLayer(AddAttrLayer()))
self._test_model(model)
model_creator = lambda: paddle.jit.to_static(
IsInstanceLayer(AddAttrLayer())
)
self._test_model(model_creator)

@test_legacy_and_pt_and_pir
def test_sequential_layer(self):
layers = []
for i in range(5):
layers.append(SimpleReturnLayer())
layers.append(AddAttrLayer())
model = paddle.jit.to_static(SequentialLayer(layers))
self._test_model(model)

def _test_model(self, model):
st_out = train(model, to_static=True)
dy_out = train(model, to_static=False)
def model_creator():
layers = []
for i in range(5):
layers.append(SimpleReturnLayer())
layers.append(AddAttrLayer())
return paddle.jit.to_static(SequentialLayer(layers))

self._test_model(model_creator)

def _test_model(self, model_creator):
st_model = model_creator()
st_out = train(st_model, to_static=True)
dy_model = model_creator()
dy_out = train(dy_model, to_static=False)
np.testing.assert_allclose(
dy_out,
st_out,
Expand Down

0 comments on commit 3b40279

Please sign in to comment.