Skip to content

Commit

Permalink
fix some while test (#60905)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi authored Jan 18, 2024
1 parent b8c2a16 commit d64ec15
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def _setitem_static(x, indices, values):
ends = paddle.utils.get_int_tensor_list(ends)
if isinstance(steps, (list, tuple)):
if paddle.utils._contain_var(steps):
ends = paddle.utils.get_int_tensor_list(steps)
steps = paddle.utils.get_int_tensor_list(steps)

if value_tensor is None:
output = paddle._C_ops.set_value_(
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,7 @@ def test_conv3d_transpose(self):
conv3d1.bias.numpy(), conv3d2.bias.numpy()
)

@test_with_pir_api
def test_while_loop(self):
with self.static_graph():
i = paddle.tensor.fill_constant(shape=[1], dtype='int64', value=0)
Expand Down
7 changes: 5 additions & 2 deletions test/legacy_test/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import paddle
from paddle.base import core
from paddle.base.layer_helper import LayerHelper
from paddle.pir_utils import test_with_pir_api


class TestSetValueBase(unittest.TestCase):
Expand Down Expand Up @@ -57,12 +58,13 @@ def _get_answer(self):
class TestSetValueApi(TestSetValueBase):
def _run_static(self):
paddle.enable_static()
with paddle.static.program_guard(self.program):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x = self._call_setitem_static_api(x)

exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(self.program, fetch_list=[x])
out = exe.run(main_program, fetch_list=[x])
paddle.disable_static()
return out

Expand All @@ -74,6 +76,7 @@ def _run_dynamic(self):
paddle.enable_static()
return out

@test_with_pir_api
def test_api(self):
static_out = self._run_static()
dynamic_out = self._run_dynamic()
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_tensor_array_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.pir_utils import test_with_pir_api
from paddle.tensor.manipulation import tensor_array_to_tensor

paddle.enable_static()
Expand Down Expand Up @@ -256,6 +257,7 @@ def test_case(self):
for s, d in zip(outs_static, outs_dynamic):
np.testing.assert_array_equal(s, d.numpy())

@test_with_pir_api
def test_while_loop_case(self):
with base.dygraph.guard():
zero = paddle.tensor.fill_constant(
Expand Down

0 comments on commit d64ec15

Please sign in to comment.