Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] pir dy2st unittest verification - Part 5 #58965

Merged
merged 2 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions test/dygraph_to_static/test_ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_api,
)
from ifelse_simple_func import (
dyfunc_with_if_else,
Expand Down Expand Up @@ -48,6 +48,7 @@ def _ast2func(self, func):
return transformed_func

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func(self):
def func(x, y):
return x + y
Expand All @@ -56,19 +57,19 @@ def func(x, y):
self.assertEqual(func(x, y), self._ast2func(func)(x, y))

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_dygraph(self):
paddle.disable_static()
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32')
for func in funcs:
with base.dygraph.guard():
x_v = base.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())
x_v = base.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())

@test_legacy_and_pir
@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_static(self):
paddle.enable_static()

Expand All @@ -83,11 +84,12 @@ def func(x):
x_v = paddle.assign(x_data)
true_ret = func(x_v)
test_ret = self._ast2func(func)(x_v)
exe = base.Executor(base.CPUPlace())
exe = base.Executor(paddle.CPUPlace())
ret = exe.run(main_program, fetch_list=[true_ret, test_ret])
self.assertTrue((ret[0] == ret[1]).all())

@test_ast_only
@test_legacy_and_pir_api
def test_ast2func_error(self):
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))
Expand Down
42 changes: 22 additions & 20 deletions test/dygraph_to_static/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,30 @@
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_exe_and_pir_api,
)

from paddle import base
from paddle.jit.api import to_static
import paddle
from paddle.base.dygraph import to_variable

SEED = 2020
np.random.seed(SEED)


def test_bool_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = bool(x)
return x


def test_int_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = int(x)
return x


def test_float_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = float(x)
return x

Expand All @@ -52,7 +53,7 @@ def test_not_var_cast(x):


def test_mix_cast(x):
x = base.dygraph.to_variable(x)
x = to_variable(x)
x = int(x)
x = float(x)
x = bool(x)
Expand All @@ -63,12 +64,11 @@ def test_mix_cast(x):
class TestCastBase(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.prepare()
self.set_func()

def prepare(self):
self.input_shape = (16, 32)
Expand All @@ -81,16 +81,16 @@ def prepare(self):
self.cast_dtype = 'bool'

def set_func(self):
self.func = to_static(full_graph=True)(test_bool_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_bool_cast)

def do_test(self):
with base.dygraph.guard():
res = self.func(self.input)
return res
res = self.func(self.input)
return res

@test_ast_only # TODO: add new sot only test.
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_cast_result(self):
self.set_func()
res = self.do_test().numpy()
self.assertTrue(
res.dtype == self.cast_dtype,
Expand Down Expand Up @@ -119,7 +119,7 @@ def prepare(self):
self.cast_dtype = 'int32'

def set_func(self):
self.func = to_static(full_graph=True)(test_int_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_int_cast)


class TestFloatCast(TestCastBase):
Expand All @@ -134,7 +134,7 @@ def prepare(self):
self.cast_dtype = 'float32'

def set_func(self):
self.func = to_static(full_graph=True)(test_float_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_float_cast)


class TestMixCast(TestCastBase):
Expand All @@ -152,11 +152,12 @@ def prepare(self):
self.cast_dtype = 'float32'

def set_func(self):
self.func = to_static(full_graph=True)(test_mix_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_mix_cast)

@test_ast_only # TODO: add new symbolic only test.
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_cast_result(self):
self.set_func()
res = self.do_test().numpy()
self.assertTrue(
res.dtype == self.cast_dtype,
Expand Down Expand Up @@ -184,11 +185,12 @@ def prepare(self):
self.cast_dtype = 'int'

def set_func(self):
self.func = to_static(full_graph=True)(test_not_var_cast)
self.func = paddle.jit.to_static(full_graph=True)(test_not_var_cast)

@test_ast_only
@test_legacy_and_pir
def test_cast_result(self):
self.set_func()
# breakpoint()
# print("run once!!!")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#58959 预计今天就能合已经合了,下一个 PR 可以尝试开启这个单测,看看能不能过

另外这两行注释下个 PR 也可以删掉

res = self.do_test()
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_cinn_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_cinn_prim(self):


class TestBackend(Dy2StTestBase):
@test_legacy_and_pir
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最终态尚未打通 CINN,这里能跑过应该是 CINN 没有生效,这个标记一下「编译器」未支持吧

这个单测可以先开着

@test_legacy_and_pir_exe_and_pir_api
def test_backend(self):
x = paddle.randn([2, 4])
out1 = self.forward(x, 'CINN')
Expand Down
6 changes: 5 additions & 1 deletion test/dygraph_to_static/test_tensor_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle
from paddle import nn
Expand Down Expand Up @@ -94,6 +97,7 @@ def h(g):
loss.backward()
np.testing.assert_allclose(x.grad.numpy(), x_jit.grad.numpy())

@test_legacy_and_pir_exe_and_pir_api
def test_hook_in_init_for_layer(self):
def hook(grad):
return grad * 2
Expand Down
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_variable_trans_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import unittest

from dygraph_to_static_utils_new import Dy2StTestBase
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir_api

from paddle.jit.dy2static.utils import ast_to_source_code
from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node


class TestVariableTransFunc(Dy2StTestBase):
@test_legacy_and_pir_api
def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0)
source = "a = paddle.full(shape=[1], dtype='float64', fill_value=1.0, name='a')"
Expand Down