From 355ee48af95156c1fae34b792b38c7fe9593a567 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Wed, 4 Jan 2023 06:18:57 +0000
Subject: [PATCH 01/13] Add 0d tensor test cases to test_cond.py

---
 .../paddle/fluid/tests/unittests/test_cond.py | 173 ++++++++++++++++++
 1 file changed, 173 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py
index 3176ace0a3813..b2af3112e4799 100644
--- a/python/paddle/fluid/tests/unittests/test_cond.py
+++ b/python/paddle/fluid/tests/unittests/test_cond.py
@@ -68,6 +68,115 @@ def false_func():
             np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
         )
 
+    def test_return_0d_tensor(self):
+        """
+        pseudocode:
+
+        if 0.23 >= 0.1:
+            return 2
+        else:
+            return -1
+        """
+
+        paddle.enable_static()
+
+        def true_func():
+            return paddle.full(shape=[], dtype='int32', fill_value=2)
+
+        def false_func():
+            return paddle.full(shape=[], dtype='int32', fill_value=-1)
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
+            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
+            pred = paddle.greater_equal(y, x)
+            out = paddle.static.nn.cond(pred, true_func, false_func)
+            # out is one tensor
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+        (ret,) = exe.run(main_program, fetch_list=[out.name])
+        np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)
+
+    def test_0d_tensor_as_cond(self):
+        """
+        pseudocode:
+
+        if 0.23 >= 0.1:
+            return 2
+        else:
+            return -1
+        """
+
+        paddle.enable_static()
+
+        def true_func():
+            return paddle.full(shape=[3, 3], dtype='int32', fill_value=2)
+
+        def false_func():
+            return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1)
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[], dtype='float32', fill_value=0.1)
+            y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
+            pred = paddle.greater_equal(y, x)
+            out = paddle.static.nn.cond(pred, true_func, false_func)
+            # out is one tensor
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+        (ret,) = exe.run(main_program, fetch_list=[out.name])
+        np.testing.assert_allclose(
+            np.asarray(ret), np.full((3, 3), 2, np.int32), rtol=1e-05
+        )
+
+    def test_0d_tensor_backward(self):
+        """
+        pseudocode:
+
+        a = -2.0
+        if a >= 0:
+            return a
+        else:
+            return -a
+        """
+
+        paddle.enable_static()
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
+            a.stop_gradient = False
+            out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
+            append_backward(out)
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+        ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])
+        np.testing.assert_allclose(
+            np.asarray(ret[0]), np.array(2.0), rtol=1e-05
+        )
+        np.testing.assert_allclose(
+            np.asarray(ret[1]), np.array(-1.0), rtol=1e-05
+        )
+
     def test_return_var_tuple(self):
         """
         pseudocode:
@@ -358,6 +467,70 @@ def greater_equal_branch(i, a):
         self.assertEqual(ret[0][0], expected_ret)
         self.assertEqual(ret[1][0], expected_a_grad)
 
+    def test_cond_inside_cond_0d_tensor(self):
+        """
+        pseudocode:
+        i = 3.0
+        a = 2 * i
+        if i < 5:
+            if i >= 3:
+                return a + a
+            else:
+                return a - a
+        else:
+            if i < 8:
+                return a * a
+            else:
+                return a / a
+        """
+
+        paddle.enable_static()
+
+        def less_than_branch(i, a):
+            return paddle.static.nn.cond(
+                i >= 3.0,
+                lambda: a + 1,
+                lambda: 1 - a,
+            )
+
+        def greater_equal_branch(i, a):
+            return paddle.static.nn.cond(
+                i < 8.0,
+                lambda: a * 2,
+                lambda: a / 2,
+            )
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            i = paddle.full(fill_value=3.0, shape=[], dtype='float32')
+            i.stop_gradient = False
+            a = 2.0 * i
+            out = paddle.static.nn.cond(
+                i < 5.0,
+                lambda: less_than_branch(i, a),
+                lambda: greater_equal_branch(i, a),
+            )
+            mean = paddle.mean(out)
+            append_backward(out)
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+        ret = exe.run(
+            main_program,
+            fetch_list=[out.name, i.grad_name],
+        )
+        np.testing.assert_allclose(
+            np.asarray(ret[0]), np.array(7.0), rtol=1e-05
+        )
+        np.testing.assert_allclose(
+            np.asarray(ret[1]), np.array(2.0), rtol=1e-05
+        )
+
     def test_cond_op_in_condition(self):
         paddle.enable_static()
         main_program = fluid.Program()
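
Patch 01 pins down the scalar semantics this series is after: a branch that returns
paddle.full(shape=[], ...) should come out of paddle.static.nn.cond with shape ()
rather than being promoted to (1,). As a quick reference, the sketch below condenses
what test_return_0d_tensor asserts once the whole series is applied. It is our
illustration, not part of the patch, and the helper name run_scalar_cond is invented.

    import numpy as np
    import paddle

    # Minimal sketch, assuming the full patch series is applied.
    # A 0-d tensor is built with shape=[]; cond should carry the empty
    # shape () through either branch instead of promoting it to (1,).
    def run_scalar_cond():
        paddle.enable_static()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.full(shape=[], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
            out = paddle.static.nn.cond(
                y >= x,  # 0.23 >= 0.1, so the true branch runs
                lambda: paddle.full(shape=[], dtype='int32', fill_value=2),
                lambda: paddle.full(shape=[], dtype='int32', fill_value=-1),
            )
        exe = paddle.static.Executor(paddle.CPUPlace())
        (ret,) = exe.run(main, fetch_list=[out])
        assert ret.shape == ()  # a true scalar, not (1,)
        assert ret == 2
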
From 4590ff1575a6e61a01ff37cd8e605cbc09d4676d Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Wed, 4 Jan 2023 09:13:09 +0000
Subject: [PATCH 02/13] Add 0d tensor test case for switch_case and case

---
 .../paddle/fluid/tests/unittests/test_case.py | 161 ++++++++++++++++
 .../fluid/tests/unittests/test_switch_case.py | 182 ++++++++++++++++++
 2 files changed, 343 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py
index e5980abea5d1e..9123b4b009d18 100644
--- a/python/paddle/fluid/tests/unittests/test_case.py
+++ b/python/paddle/fluid/tests/unittests/test_case.py
@@ -89,6 +89,67 @@ def fn_3():
         np.testing.assert_allclose(res[3], 2, rtol=1e-05)
         np.testing.assert_allclose(res[4], 2, rtol=1e-05)
 
+    def test_0d_tensor(self):
+        def fn_1():
+            return paddle.full(shape=[], dtype='int32', fill_value=1)
+
+        def fn_2():
+            return paddle.full(shape=[], dtype='int32', fill_value=2)
+
+        def fn_3():
+            return paddle.full(shape=[], dtype='int32', fill_value=3)
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
+            y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
+            z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
+            pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
+            pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
+
+            # call fn_1
+            out_0 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
+            )
+
+            # call fn_2
+            out_1 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
+            )
+
+            # call default fn_3
+            out_2 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
+            )
+
+            # no default, call fn_2
+            out_3 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_1, fn_2)]
+            )
+
+            # no default, call fn_2. but pred_2 is false
+            out_4 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_2, fn_2)]
+            )
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(
+            main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4]
+        )
+
+        np.testing.assert_allclose(res[0], 1, rtol=1e-05)
+        np.testing.assert_allclose(res[1], 2, rtol=1e-05)
+        np.testing.assert_allclose(res[2], 3, rtol=1e-05)
+        np.testing.assert_allclose(res[3], 2, rtol=1e-05)
+        np.testing.assert_allclose(res[4], 2, rtol=1e-05)
+
     def test_return_var_tuple(self):
         def fn_1():
             return layers.fill_constant(
@@ -236,6 +297,106 @@ def fn_3():
         np.testing.assert_allclose(res[1], 2, rtol=1e-05)
         np.testing.assert_allclose(res[2], 3, rtol=1e-05)
 
+    def test_nested_0d_tensor(self):
+        def fn_1(x=1):
+            var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
+            var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
+            out = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[
+                    (
+                        var_5 < var_6,
+                        partial(
+                            paddle.full,
+                            shape=[],
+                            dtype='int32',
+                            fill_value=x,
+                        ),
+                    ),
+                    (
+                        var_5 == var_6,
+                        partial(
+                            paddle.full,
+                            shape=[],
+                            dtype='int32',
+                            fill_value=x,
+                        ),
+                    ),
+                ]
+            )
+            return out
+
+        def fn_2(x=2):
+            var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
+            var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
+            out = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[
+                    (var_5 < var_6, partial(fn_1, x=x)),
+                    (
+                        var_5 == var_6,
+                        partial(
+                            paddle.full,
+                            shape=[],
+                            dtype='int32',
+                            fill_value=x,
+                        ),
+                    ),
+                ]
+            )
+            return out
+
+        def fn_3():
+            var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
+            var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
+            out = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[
+                    (var_5 < var_6, partial(fn_2, x=3)),
+                    (
+                        var_5 == var_6,
+                        partial(
+                            paddle.full,
+                            shape=[],
+                            dtype='int32',
+                            fill_value=7,
+                        ),
+                    ),
+                ]
+            )
+            return out
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
+            y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
+            z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
+            pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
+            pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
+
+            out_1 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
+            )
+
+            out_2 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
+            )
+
+            out_3 = paddle.static.nn.control_flow.case(
+                pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3
+            )
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
+
+        np.testing.assert_allclose(res[0], 1, rtol=1e-05)
+        np.testing.assert_allclose(res[1], 2, rtol=1e-05)
+        np.testing.assert_allclose(res[2], 3, rtol=1e-05)
+
 
 class TestAPICase_Error(unittest.TestCase):
     def test_error(self):
diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py
index 119b5ac285f73..2ddbd0f7ff051 100644
--- a/python/paddle/fluid/tests/unittests/test_switch_case.py
+++ b/python/paddle/fluid/tests/unittests/test_switch_case.py
@@ -114,6 +114,93 @@ def fn_3():
             err_msg='result is {} but answer is {}'.format(res[0], 2),
         )
 
+    def test_0d_tensor(self):
+        def fn_1():
+            return paddle.full(shape=[], dtype='int32', fill_value=1)
+
+        def fn_2():
+            return paddle.full(shape=[], dtype='int32', fill_value=2)
+
+        def fn_3():
+            return paddle.full(shape=[], dtype='int32', fill_value=3)
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            index_1 = paddle.full(shape=[], dtype='int32', fill_value=1)
+            index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
+            index_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
+
+            # call fn_1
+            out_0 = paddle.static.nn.switch_case(
+                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
+            )
+
+            # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
+            out_1 = paddle.static.nn.switch_case(
+                branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
+            )
+
+            # call default fn_3
+            out_2 = paddle.static.nn.switch_case(
+                branch_index=index_5,
+                branch_fns=((1, fn_1), (2, fn_2)),
+                default=fn_3,
+            )
+
+            # no default, call fn_2
+            out_3 = paddle.static.nn.switch_case(
+                branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
+            )
+
+            # no default, call fn_2 but branch_index is 5
+            out_4 = paddle.static.nn.switch_case(
+                branch_index=index_5,
+                branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
+            )
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(
+            main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4]
+        )
+
+        np.testing.assert_allclose(
+            res[0],
+            1,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 1),
+        )
+        np.testing.assert_allclose(
+            res[1],
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 2),
+        )
+        np.testing.assert_allclose(
+            res[2],
+            3,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 3),
+        )
+        np.testing.assert_allclose(
+            res[3],
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 2),
+        )
+        np.testing.assert_allclose(
+            res[4],
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 2),
+        )
+
     def test_return_var_tuple(self):
         def fn_1():
             return layers.fill_constant(
@@ -257,6 +344,101 @@ def fn_3():
             err_msg='result is {} but answer is {}'.format(res[2], 3),
         )
 
+    def test_nested_switch_0d_tensor(self):
+        def fn_1(x=1):
+            out = paddle.static.nn.switch_case(
+                branch_index=paddle.full(shape=[], dtype='int32', fill_value=x),
+                branch_fns={
+                    1: partial(
+                        paddle.full, shape=[], dtype='int32', fill_value=1
+                    ),
+                    x: partial(
+                        paddle.full, shape=[], dtype='int32', fill_value=x
+                    ),
+                },
+            )
+            return out
+
+        def fn_2(x=2):
+            out = paddle.static.nn.switch_case(
+                branch_index=paddle.full(shape=[], dtype='int32', fill_value=2),
+                branch_fns={
+                    1: partial(
+                        paddle.full,
+                        shape=[],
+                        dtype='int32',
+                        fill_value=1,
+                    ),
+                    2: partial(fn_1, x=x),
+                },
+            )
+            return out
+
+        def fn_3():
+            out = paddle.static.nn.switch_case(
+                branch_index=paddle.full(shape=[], dtype='int32', fill_value=3),
+                branch_fns={
+                    1: partial(
+                        paddle.full,
+                        shape=[],
+                        dtype='int32',
+                        fill_value=1,
+                    ),
+                    3: partial(fn_2, x=3),
+                },
+            )
+            return out
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
+            index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
+            index_3 = paddle.full(shape=[], dtype='int64', fill_value=3)
+
+            out_1 = paddle.static.nn.switch_case(
+                branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
+            )
+            out_2 = paddle.static.nn.switch_case(
+                branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
+            )
+
+            out_3 = paddle.static.nn.switch_case(
+                branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
+            )
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(
+            main_program,
+            feed={"index_1": np.array([1], dtype="uint8")},
+            fetch_list=[out_1, out_2, out_3],
+        )
+
+        np.testing.assert_allclose(
+            res[0],
+            1,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[0], 1),
+        )
+        np.testing.assert_allclose(
+            res[1],
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[1], 2),
+        )
+        np.testing.assert_allclose(
+            res[2],
+            3,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(res[2], 3),
+        )
+
 
 # test TypeError and ValueError of api switch_case
 class TestAPISwitchCase_Error(unittest.TestCase):
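
Patch 02 extends the same guarantee to case and switch_case: a 0-d predicate or
branch index must be accepted, and a branch returning a 0-d tensor must keep its
scalar shape. A minimal sketch of the switch_case usage these tests exercise, in
dygraph mode for brevity; this is our illustration under the assumption that the
full series is applied, not code from the patch.

    import paddle

    # Minimal sketch: switch_case with a 0-d branch index (dygraph).
    paddle.disable_static()

    index = paddle.full(shape=[], dtype='int32', fill_value=2)  # 0-d index
    out = paddle.static.nn.switch_case(
        branch_index=index,
        branch_fns=[
            (1, lambda: paddle.full(shape=[], dtype='int32', fill_value=1)),
            (2, lambda: paddle.full(shape=[], dtype='int32', fill_value=2)),
        ],
        default=lambda: paddle.full(shape=[], dtype='int32', fill_value=3),
    )
    assert out.shape == []  # dygraph reports a 0-d tensor's shape as []
    assert int(out) == 2
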
From 6c3f49a8cfa11e2203434b791acf99af3bbbf536 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Wed, 4 Jan 2023 09:18:41 +0000
Subject: [PATCH 03/13] Update comment

---
 python/paddle/fluid/tests/unittests/test_cond.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py
index b2af3112e4799..9769aa8df430e 100644
--- a/python/paddle/fluid/tests/unittests/test_cond.py
+++ b/python/paddle/fluid/tests/unittests/test_cond.py
@@ -474,14 +474,14 @@ def test_cond_inside_cond_0d_tensor(self):
         a = 2 * i
         if i < 5:
             if i >= 3:
-                return a + a
+                return a + 1
             else:
-                return a - a
+                return 1 - a
         else:
             if i < 8:
-                return a * a
+                return a * 2
             else:
-                return a / a
+                return a / 2
         """
 
         paddle.enable_static()

From 5b975d7714e07dedfd7fcbcea486c8d35857e33a Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Sun, 15 Jan 2023 13:13:00 +0000
Subject: [PATCH 04/13] Fix 0d Tensor in ConditionalBlockOp, Memcpy, and
 AddNKernel

---
 .../controlflow/conditional_block_op.cc       | 59 ++++++++++++++++++-
 .../operators/controlflow/fetch_v2_op.cc      |  1 +
 paddle/fluid/operators/memcpy_d2h_op.cc       |  3 +
 paddle/phi/core/tensor_utils.cc               |  4 +-
 paddle/phi/kernels/cpu/add_n_kernel.cc        |  8 +++
 paddle/phi/kernels/gpu/add_n_kernel.cu        |  9 +++
 paddle/phi/kernels/memcpy_kernel.cc           |  3 +
 python/paddle/fluid/backward.py               |  2 +-
 .../paddle/fluid/tests/unittests/test_cond.py | 32 +++++++++-
 python/paddle/static/nn/control_flow.py       |  2 +-
 10 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index f11bf6612c8ca..5bcbb00c2087e 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -300,8 +300,53 @@ class ConditionalBlockGradOp : public ConditionalOp {
       }
       platform::DeviceContext *dev_ctx =
           platform::DeviceContextPool::Instance().Get(place);
+      /*
+      const phi::DenseTensor& inside_tensor =
+          inside_var->Get<phi::DenseTensor>();
+      phi::DenseTensorMeta inside_meta = inside_tensor.meta();
+      if (inside_tensor.numel() == 1 && inside_tensor.dims().size() == 0) {
+        VLOG(6) << "Huihuang found " << inside_grad_name << " is_scalar";
+        inside_meta.is_scalar = true;
+        inside_var->GetMutable<phi::DenseTensor>()->set_meta(inside_meta);
+        outside_var->GetMutable<phi::DenseTensor>()->set_meta(inside_meta);
+        outside_var->GetMutable<phi::DenseTensor>()->mutable_data(place, 1);
+      }
+      */
+      VLOG(6) << "Huihuang debug, before assign: inside_var numel = "
+              << inside_var->Get<phi::DenseTensor>().numel()
+              << ", dims = "
+              << inside_var->Get<phi::DenseTensor>().dims().size()
+              << ", memory_size = "
+              << inside_var->Get<phi::DenseTensor>().memory_size()
+              << ", initailized = "
+              << inside_var->Get<phi::DenseTensor>().initialized()
+              << ", outside_var numel = "
+              << outside_var->Get<phi::DenseTensor>().numel()
+              << ", dims = "
+              << outside_var->Get<phi::DenseTensor>().dims().size()
+              << ", memory_size = "
+              << outside_var->Get<phi::DenseTensor>().memory_size()
+              << ", initailized = "
+              << outside_var->Get<phi::DenseTensor>().initialized();
+
       framework::VisitVarType(*inside_var,
                               AssignFunctor(outside_var, *dev_ctx));
+
+      VLOG(6) << "Huihuang debug, after assign: inside_var numel = "
+              << inside_var->Get<phi::DenseTensor>().numel()
+              << ", dims = "
+              << inside_var->Get<phi::DenseTensor>().dims().size()
+              << ", memory_size = "
+              << inside_var->Get<phi::DenseTensor>().memory_size()
+              << ", initailized = "
+              << inside_var->Get<phi::DenseTensor>().initialized()
+              << ", outside_var numel = "
+              << outside_var->Get<phi::DenseTensor>().numel()
+              << ", dims = "
+              << outside_var->Get<phi::DenseTensor>().dims().size()
+              << ", memory_size = "
+              << outside_var->Get<phi::DenseTensor>().memory_size()
+              << ", initailized = "
+              << outside_var->Get<phi::DenseTensor>().initialized();
     }
     // Assign zero to the grad_vars that are in outside_grads but not in
     // inside_grads
@@ -342,6 +387,16 @@ class ConditionalBlockGradOp : public ConditionalOp {
             scope,
             input_var->Get<phi::DenseTensor>(),
             outside_var->GetMutable<phi::DenseTensor>());
+
+        VLOG(6) << "Huihuang debug, after assign zero: input_var numel = "
+                << input_var->Get<phi::DenseTensor>().numel()
+                << ", dims = "
+                << input_var->Get<phi::DenseTensor>().dims().size()
+                << ", outside_var numel = "
+                << outside_var->Get<phi::DenseTensor>().numel()
+                << ", dims = "
+                << outside_var->Get<phi::DenseTensor>().dims().size();
+
       } else if (input_var->IsType<framework::LoDTensorArray>()) {
         PADDLE_ENFORCE_EQ(outside_var->IsType<framework::LoDTensorArray>(),
                           true,
@@ -384,8 +439,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
     if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) {
       return;
     }
+    if (!input_tensor.meta().is_scalar) {
+      outside_tensor->Resize(input_tensor.dims());
+    }
     VLOG(4) << "Assigning zero to " << outside_tensor;
-    outside_tensor->Resize(input_tensor.dims());
     outside_tensor->mutable_data(place, input_tensor.dtype());
     const platform::DeviceContext *dev_ctx =
         platform::DeviceContextPool::Instance().Get(place);
diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
index 5a99dd695c02b..29ca6b11c1827 100644
--- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc
+++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -121,6 +121,7 @@ class FetchV2Op : public framework::OperatorWithKernel {
 class FetchV2Kernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
+    VLOG(4) << "Huihuang debug in MemcpyD2HKernel";
     auto fetch_var_name = ctx.InputName("X");
     auto *fetch_var = ctx.InputVar("X");
     if (fetch_var == nullptr) {
diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc
index 06af45d48506a..c1f78f2277e48 100644
--- a/paddle/fluid/operators/memcpy_d2h_op.cc
+++ b/paddle/fluid/operators/memcpy_d2h_op.cc
@@ -63,6 +63,7 @@ class MemcpyD2HInferVarType : public framework::VarTypeInference {
 class MemcpyD2HKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
+    VLOG(4) << "Huihuang debug in MemcpyD2HKernel";
     auto *x = ctx.InputVar("X");
     if (x == nullptr) {
       return;
@@ -75,7 +76,9 @@ class MemcpyD2HKernel {
     // Get dev_ctx from ExecutionContext, it's D2H stream
     auto &dev_ctx = ctx.device_context();
     auto dst_place_type = ctx.Attr<int>("dst_place_type");
+    VLOG(4) << "Huihuang debug before MemcpyD2HFunctor";
     framework::VisitVarType(*x, MemcpyD2HFunctor(out, dev_ctx, dst_place_type));
+    VLOG(4) << "Huihuang debug after MemcpyD2HFunctor";
   }
 };
 
diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index 467552032f0ad..db81e38ddffa8 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -51,7 +51,9 @@ void Copy(const Context& dev_ctx,
   VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
           << dst_place;
 
-  dst->Resize(src.dims());
+  dst->set_meta(src.meta());
+  VLOG(3) << "src.numel() = " << src.numel() << ", dst->numel() = "
+          << dst->numel();
 
   void* dst_ptr = nullptr;
   if (paddle::platform::is_cpu_place(dst_place)) {
diff --git a/paddle/phi/kernels/cpu/add_n_kernel.cc b/paddle/phi/kernels/cpu/add_n_kernel.cc
index 54506ccd54f5b..6081d2f3133b1 100644
--- a/paddle/phi/kernels/cpu/add_n_kernel.cc
+++ b/paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -20,7 +20,15 @@ template <typename T, typename Context>
 void AddNKernel(const Context& dev_ctx,
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* out) {
+  VLOG(6) << "Huihuang debug CPU AddNKernel";
   size_t in_num = x.size();
+  for (const TensorBase *tb : x) {
+    if (tb->initialized() && DenseTensor::classof(tb)) {
+      auto* dt = static_cast<const DenseTensor *>(tb);
+      out->set_meta(dt->meta());
+      break;
+    }
+  }
   dev_ctx.template Alloc<T>(out);
 
   bool in_place = false;
diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu
index f32ba597f5b68..3dd9e22a5a069 100644
--- a/paddle/phi/kernels/gpu/add_n_kernel.cu
+++ b/paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -94,6 +94,15 @@ void AddNKernel(const Context &dev_ctx,
     grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
     blocks = dim3(tile_size, 1, 1);
   };
+
+  for (const TensorBase *tb : x) {
+    if (tb->initialized() && DenseTensor::classof(tb)) {
+      auto* dt = static_cast<const DenseTensor *>(tb);
+      out->set_meta(dt->meta());
+      break;
+    }
+  }
+
   auto *out_ptr = dev_ctx.template Alloc<T>(out);
   bool in_place = false;
   if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
diff --git a/paddle/phi/kernels/memcpy_kernel.cc b/paddle/phi/kernels/memcpy_kernel.cc
index 521edc26af320..d9326dfa4f3bf 100644
--- a/paddle/phi/kernels/memcpy_kernel.cc
+++ b/paddle/phi/kernels/memcpy_kernel.cc
@@ -30,6 +30,7 @@ void MemcpyH2DKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      int dst_place_type,
                      DenseTensor* out) {
+  VLOG(6) << "Huihuang debug in MemcpyH2DKernel";
   PADDLE_ENFORCE_GE(
       dst_place_type,
       0,
@@ -41,7 +42,9 @@ void MemcpyH2DKernel(const Context& dev_ctx,
       errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
                          dst_place_type));
 
+  VLOG(6) << "Huihuang debug in MemcpyH2DKernel before copy";
   Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+  VLOG(6) << "Huihuang debug in MemcpyH2DKernel after copy";
 }
 
 template <typename T, typename Context>
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index df975d06a45d4..0b2990aacb4a8 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -390,7 +390,7 @@ def _create_loss_op_desc_(loss):
         {},
         {"Out": [_append_grad_suffix_(loss.name)]},
         {
-            "shape": [1],
+            "shape": list(loss.shape),
             "value": 1.0,
             "dtype": loss.dtype,
             "force_cpu": False,
diff --git a/python/paddle/fluid/tests/unittests/test_cond.py b/python/paddle/fluid/tests/unittests/test_cond.py
index 9769aa8df430e..1b1c08e51e2cd 100644
--- a/python/paddle/fluid/tests/unittests/test_cond.py
+++ b/python/paddle/fluid/tests/unittests/test_cond.py
@@ -103,6 +103,7 @@ def false_func():
         exe = fluid.Executor(place)
         (ret,) = exe.run(main_program, fetch_list=[out.name])
         np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)
+        self.assertEqual(ret.shape, ())
 
     def test_0d_tensor_as_cond(self):
         """
@@ -129,7 +130,7 @@ def false_func():
             y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
             pred = paddle.greater_equal(y, x)
             out = paddle.static.nn.cond(pred, true_func, false_func)
-            # out is one tensor
+            # out is a tensor
 
         place = (
             fluid.CUDAPlace(0)
@@ -168,14 +169,41 @@ def test_0d_tensor_backward(self):
             if core.is_compiled_with_cuda()
             else fluid.CPUPlace()
         )
+
         exe = fluid.Executor(place)
         ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])
         np.testing.assert_allclose(
             np.asarray(ret[0]), np.array(2.0), rtol=1e-05
         )
+        self.assertEqual(ret[0].shape, ())
         np.testing.assert_allclose(
             np.asarray(ret[1]), np.array(-1.0), rtol=1e-05
         )
+        self.assertEqual(ret[1].shape, ())
+
+    def test_0d_tensor_dygraph(self):
+        """
+        pseudocode:
+
+        a = -2.0
+        if a >= 0:
+            return a
+        else:
+            return -a
+        """
+        paddle.disable_static()
+        a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
+        a.stop_gradient = False
+        out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
+        out.backward()
+
+        np.testing.assert_allclose(np.asarray(out), np.array(2.0), rtol=1e-05)
+        self.assertEqual(out.shape, [])
+
+        np.testing.assert_allclose(
+            np.asarray(a.grad), np.array(-1.0), rtol=1e-05
+        )
+        self.assertEqual(a.grad.shape, [])
 
     def test_return_var_tuple(self):
         """
@@ -527,9 +555,11 @@ def greater_equal_branch(i, a):
         np.testing.assert_allclose(
             np.asarray(ret[0]), np.array(7.0), rtol=1e-05
         )
+        self.assertEqual(ret[0].shape, ())
         np.testing.assert_allclose(
             np.asarray(ret[1]), np.array(2.0), rtol=1e-05
         )
+        self.assertEqual(ret[1].shape, ())
 
     def test_cond_op_in_condition(self):
         paddle.enable_static()
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index d21d95b097e3b..03381b424a7a9 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -969,7 +969,7 @@ def false_func():
     if _non_static_mode():
         assert isinstance(pred, Variable), "The pred in cond must be Variable"
         assert pred.size == 1, "condition input's numel should be 1"
-        pred = pred.numpy()[0]
+        pred = pred.numpy().item()
         if pred:
             if true_fn is not None:
                 if not callable(true_fn):
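
One line in patch 04 is easy to overlook: in control_flow.py, `pred.numpy()[0]`
becomes `pred.numpy().item()`. Indexing with `[0]` raises on a 0-d NumPy array,
while `.item()` extracts the scalar from both a 0-d and a shape-(1,) array. A small
NumPy-level sketch of the difference; this is our illustration, not part of the diff.

    import numpy as np

    scalar = np.array(True)    # 0-d, what a 0-d bool tensor converts to
    vector = np.array([True])  # shape (1,), the old assumption

    # .item() extracts the single element from either layout.
    assert scalar.item() is True and vector.item() is True

    # Plain indexing only works for the 1-d case.
    try:
        scalar[0]
    except IndexError:
        pass  # "too many indices for array": why [0] had to go
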
From 5649b0e1d6d1622b95253118fd7232d8824c4dfb Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Sun, 15 Jan 2023 13:24:07 +0000
Subject: [PATCH 05/13] Remove some debug vlog

---
 .../controlflow/conditional_block_op.cc       | 45 -------------------
 .../operators/controlflow/fetch_v2_op.cc      |  1 -
 paddle/fluid/operators/memcpy_d2h_op.cc       |  3 --
 paddle/phi/kernels/memcpy_kernel.cc           |  3 --
 4 files changed, 52 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index 5bcbb00c2087e..885d924fc64eb 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -300,53 +300,8 @@ class ConditionalBlockGradOp : public ConditionalOp {
       }
       platform::DeviceContext *dev_ctx =
           platform::DeviceContextPool::Instance().Get(place);
-      /*
-      const phi::DenseTensor& inside_tensor =
-          inside_var->Get<phi::DenseTensor>();
-      phi::DenseTensorMeta inside_meta = inside_tensor.meta();
-      if (inside_tensor.numel() == 1 && inside_tensor.dims().size() == 0) {
-        VLOG(6) << "Huihuang found " << inside_grad_name << " is_scalar";
-        inside_meta.is_scalar = true;
-        inside_var->GetMutable<phi::DenseTensor>()->set_meta(inside_meta);
-        outside_var->GetMutable<phi::DenseTensor>()->set_meta(inside_meta);
-        outside_var->GetMutable<phi::DenseTensor>()->mutable_data(place, 1);
-      }
-      */
-      VLOG(6) << "Huihuang debug, before assign: inside_var numel = "
-              << inside_var->Get<phi::DenseTensor>().numel()
-              << ", dims = "
-              << inside_var->Get<phi::DenseTensor>().dims().size()
-              << ", memory_size = "
-              << inside_var->Get<phi::DenseTensor>().memory_size()
-              << ", initailized = "
-              << inside_var->Get<phi::DenseTensor>().initialized()
-              << ", outside_var numel = "
-              << outside_var->Get<phi::DenseTensor>().numel()
-              << ", dims = "
-              << outside_var->Get<phi::DenseTensor>().dims().size()
-              << ", memory_size = "
-              << outside_var->Get<phi::DenseTensor>().memory_size()
-              << ", initailized = "
-              << outside_var->Get<phi::DenseTensor>().initialized();
-
       framework::VisitVarType(*inside_var,
                               AssignFunctor(outside_var, *dev_ctx));
-
-      VLOG(6) << "Huihuang debug, after assign: inside_var numel = "
-              << inside_var->Get<phi::DenseTensor>().numel()
-              << ", dims = "
-              << inside_var->Get<phi::DenseTensor>().dims().size()
-              << ", memory_size = "
-              << inside_var->Get<phi::DenseTensor>().memory_size()
-              << ", initailized = "
-              << inside_var->Get<phi::DenseTensor>().initialized()
-              << ", outside_var numel = "
-              << outside_var->Get<phi::DenseTensor>().numel()
-              << ", dims = "
-              << outside_var->Get<phi::DenseTensor>().dims().size()
-              << ", memory_size = "
-              << outside_var->Get<phi::DenseTensor>().memory_size()
-              << ", initailized = "
-              << outside_var->Get<phi::DenseTensor>().initialized();
     }
     // Assign zero to the grad_vars that are in outside_grads but not in
     // inside_grads
diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
index 29ca6b11c1827..5a99dd695c02b 100644
--- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc
+++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -121,7 +121,6 @@ class FetchV2Op : public framework::OperatorWithKernel {
 class FetchV2Kernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
-    VLOG(4) << "Huihuang debug in MemcpyD2HKernel";
     auto fetch_var_name = ctx.InputName("X");
     auto *fetch_var = ctx.InputVar("X");
     if (fetch_var == nullptr) {
diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc
index c1f78f2277e48..06af45d48506a 100644
--- a/paddle/fluid/operators/memcpy_d2h_op.cc
+++ b/paddle/fluid/operators/memcpy_d2h_op.cc
@@ -63,7 +63,6 @@ class MemcpyD2HInferVarType : public framework::VarTypeInference {
 class MemcpyD2HKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
-    VLOG(4) << "Huihuang debug in MemcpyD2HKernel";
     auto *x = ctx.InputVar("X");
     if (x == nullptr) {
       return;
@@ -76,9 +75,7 @@ class MemcpyD2HKernel {
     // Get dev_ctx from ExecutionContext, it's D2H stream
     auto &dev_ctx = ctx.device_context();
     auto dst_place_type = ctx.Attr<int>("dst_place_type");
-    VLOG(4) << "Huihuang debug before MemcpyD2HFunctor";
     framework::VisitVarType(*x, MemcpyD2HFunctor(out, dev_ctx, dst_place_type));
-    VLOG(4) << "Huihuang debug after MemcpyD2HFunctor";
   }
 };
 
diff --git a/paddle/phi/kernels/memcpy_kernel.cc b/paddle/phi/kernels/memcpy_kernel.cc
index d9326dfa4f3bf..521edc26af320 100644
--- a/paddle/phi/kernels/memcpy_kernel.cc
+++ b/paddle/phi/kernels/memcpy_kernel.cc
@@ -30,7 +30,6 @@ void MemcpyH2DKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      int dst_place_type,
                      DenseTensor* out) {
-  VLOG(6) << "Huihuang debug in MemcpyH2DKernel";
   PADDLE_ENFORCE_GE(
       dst_place_type,
       0,
@@ -42,9 +41,7 @@ void MemcpyH2DKernel(const Context& dev_ctx,
       errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
                          dst_place_type));
 
-  VLOG(6) << "Huihuang debug in MemcpyH2DKernel before copy";
   Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
-  VLOG(6) << "Huihuang debug in MemcpyH2DKernel after copy";
 }
 
 template <typename T, typename Context>

From ff144865da49708c17fe6e867c77f76ee1f0919f Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Sun, 15 Jan 2023 13:26:13 +0000
Subject: [PATCH 06/13] Remove some logs

---
 .../operators/controlflow/conditional_block_op.cc | 10 ----------
 paddle/phi/kernels/cpu/add_n_kernel.cc            |  1 -
 2 files changed, 11 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index 885d924fc64eb..676cb0f8328bc 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -342,16 +342,6 @@ class ConditionalBlockGradOp : public ConditionalOp {
             scope,
             input_var->Get<phi::DenseTensor>(),
             outside_var->GetMutable<phi::DenseTensor>());
-
-        VLOG(6) << "Huihuang debug, after assign zero: input_var numel = "
-                << input_var->Get<phi::DenseTensor>().numel()
-                << ", dims = "
-                << input_var->Get<phi::DenseTensor>().dims().size()
-                << ", outside_var numel = "
-                << outside_var->Get<phi::DenseTensor>().numel()
-                << ", dims = "
-                << outside_var->Get<phi::DenseTensor>().dims().size();
-
       } else if (input_var->IsType<framework::LoDTensorArray>()) {
         PADDLE_ENFORCE_EQ(outside_var->IsType<framework::LoDTensorArray>(),
                           true,
diff --git a/paddle/phi/kernels/cpu/add_n_kernel.cc b/paddle/phi/kernels/cpu/add_n_kernel.cc
index 6081d2f3133b1..df8c0b7a55b0f 100644
--- a/paddle/phi/kernels/cpu/add_n_kernel.cc
+++ b/paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -20,7 +20,6 @@ template <typename T, typename Context>
 void AddNKernel(const Context& dev_ctx,
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* out) {
-  VLOG(6) << "Huihuang debug CPU AddNKernel";
   size_t in_num = x.size();
   for (const TensorBase *tb : x) {
     if (tb->initialized() && DenseTensor::classof(tb)) {
From 526d56983b6effba7ee672d548fc9d7cc6dc4edd Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Sun, 15 Jan 2023 13:40:08 +0000
Subject: [PATCH 07/13] Format some code due to local clang-format differ to CI

---
 paddle/phi/core/tensor_utils.cc        | 4 ++--
 paddle/phi/kernels/cpu/add_n_kernel.cc | 4 ++--
 paddle/phi/kernels/gpu/add_n_kernel.cu | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index db81e38ddffa8..6c88aef43bcdd 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -52,8 +52,8 @@ void Copy(const Context& dev_ctx,
           << dst_place;
 
   dst->set_meta(src.meta());
-  VLOG(3) << "src.numel() = " << src.numel() << ", dst->numel() = "
-          << dst->numel();
+  VLOG(3) << "src.numel() = " << src.numel()
+          << ", dst->numel() = " << dst->numel();
 
   void* dst_ptr = nullptr;
   if (paddle::platform::is_cpu_place(dst_place)) {
diff --git a/paddle/phi/kernels/cpu/add_n_kernel.cc b/paddle/phi/kernels/cpu/add_n_kernel.cc
index df8c0b7a55b0f..f56c7f0fa3d0a 100644
--- a/paddle/phi/kernels/cpu/add_n_kernel.cc
+++ b/paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -21,9 +21,9 @@ void AddNKernel(const Context& dev_ctx,
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* out) {
   size_t in_num = x.size();
-  for (const TensorBase *tb : x) {
+  for (const TensorBase* tb : x) {
     if (tb->initialized() && DenseTensor::classof(tb)) {
-      auto* dt = static_cast<const DenseTensor *>(tb);
+      auto* dt = static_cast<const DenseTensor*>(tb);
       out->set_meta(dt->meta());
       break;
     }
diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu
index 3dd9e22a5a069..cee430bc0c0e3 100644
--- a/paddle/phi/kernels/gpu/add_n_kernel.cu
+++ b/paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -97,7 +97,7 @@ void AddNKernel(const Context &dev_ctx,
 
   for (const TensorBase *tb : x) {
     if (tb->initialized() && DenseTensor::classof(tb)) {
-      auto* dt = static_cast<const DenseTensor *>(tb);
+      auto *dt = static_cast<const DenseTensor *>(tb);
       out->set_meta(dt->meta());
       break;
     }

From 278fd87650e5be79412c3240e00d05e481826922 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Mon, 16 Jan 2023 03:28:29 +0000
Subject: [PATCH 08/13] Fix CI accuracy

---
 paddle/fluid/operators/controlflow/conditional_block_op.cc | 2 +-
 paddle/phi/core/tensor_utils.cc                            | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index 676cb0f8328bc..b961d321bfc80 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -384,7 +384,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
     if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) {
       return;
     }
-    if (!input_tensor.meta().is_scalar) {
+    if (input_tensor.dims().size() != 0) {
       outside_tensor->Resize(input_tensor.dims());
     }
     VLOG(4) << "Assigning zero to " << outside_tensor;
diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index 6c88aef43bcdd..467552032f0ad 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -51,9 +51,7 @@ void Copy(const Context& dev_ctx,
   VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
           << dst_place;
 
-  dst->set_meta(src.meta());
-  VLOG(3) << "src.numel() = " << src.numel()
-          << ", dst->numel() = " << dst->numel();
+  dst->Resize(src.dims());
 
   void* dst_ptr = nullptr;
   if (paddle::platform::is_cpu_place(dst_place)) {

From 6419c114c43c3c64de2a8921b185911da2f577e7 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Mon, 16 Jan 2023 07:07:27 +0000
Subject: [PATCH 09/13] Fix test_custom_op_setup

---
 python/paddle/fluid/backward.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 0b2990aacb4a8..19da1ccdff132 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -385,12 +385,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
 
 
 def _create_loss_op_desc_(loss):
+    create_shape = [] if len(loss.shape) == 0 else [1]
     op_desc = _create_op_desc_(
         "fill_constant",
         {},
         {"Out": [_append_grad_suffix_(loss.name)]},
         {
-            "shape": list(loss.shape),
+            "shape": create_shape,
             "value": 1.0,
             "dtype": loss.dtype,
             "force_cpu": False,

From 5dfd603ad9ac5be80ae820b9a454e50276536a45 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Mon, 16 Jan 2023 12:48:39 +0000
Subject: [PATCH 10/13] Merge with add_n_kernel PR
 https://github.com/PaddlePaddle/Paddle/pull/49854

---
 .../operators/controlflow/conditional_block_op.cc | 4 +---
 paddle/phi/kernels/cpu/add_n_kernel.cc            | 7 -------
 paddle/phi/kernels/gpu/add_n_kernel.cu            | 9 ---------
 3 files changed, 1 insertion(+), 19 deletions(-)

diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index b961d321bfc80..f11bf6612c8ca 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -384,10 +384,8 @@ class ConditionalBlockGradOp : public ConditionalOp {
     if (!input_tensor.IsInitialized() || input_tensor.numel() == 0) {
       return;
    }
-    if (input_tensor.dims().size() != 0) {
-      outside_tensor->Resize(input_tensor.dims());
-    }
     VLOG(4) << "Assigning zero to " << outside_tensor;
+    outside_tensor->Resize(input_tensor.dims());
     outside_tensor->mutable_data(place, input_tensor.dtype());
     const platform::DeviceContext *dev_ctx =
         platform::DeviceContextPool::Instance().Get(place);
diff --git a/paddle/phi/kernels/cpu/add_n_kernel.cc b/paddle/phi/kernels/cpu/add_n_kernel.cc
index f56c7f0fa3d0a..54506ccd54f5b 100644
--- a/paddle/phi/kernels/cpu/add_n_kernel.cc
+++ b/paddle/phi/kernels/cpu/add_n_kernel.cc
@@ -21,13 +21,6 @@ void AddNKernel(const Context& dev_ctx,
                 const std::vector<const TensorBase*>& x,
                 DenseTensor* out) {
   size_t in_num = x.size();
-  for (const TensorBase* tb : x) {
-    if (tb->initialized() && DenseTensor::classof(tb)) {
-      auto* dt = static_cast<const DenseTensor*>(tb);
-      out->set_meta(dt->meta());
-      break;
-    }
-  }
   dev_ctx.template Alloc<T>(out);
 
   bool in_place = false;
diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu
index cee430bc0c0e3..f32ba597f5b68 100644
--- a/paddle/phi/kernels/gpu/add_n_kernel.cu
+++ b/paddle/phi/kernels/gpu/add_n_kernel.cu
@@ -94,15 +94,6 @@ void AddNKernel(const Context &dev_ctx,
     grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
     blocks = dim3(tile_size, 1, 1);
   };
-
-  for (const TensorBase *tb : x) {
-    if (tb->initialized() && DenseTensor::classof(tb)) {
-      auto *dt = static_cast<const DenseTensor *>(tb);
-      out->set_meta(dt->meta());
-      break;
-    }
-  }
-
   auto *out_ptr = dev_ctx.template Alloc<T>(out);
   bool in_place = false;
   if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
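
Patch 09 is the subtle one in this stretch: patch 04's `"shape": list(loss.shape)`
broke test_custom_op_setup, so the seed gradient for a loss is now shaped by an
explicit guard, an empty shape for a 0-d loss and the legacy `[1]` for everything
else. The rule in isolation, as a hedged sketch with names of our own choosing:

    # The guard from patch 09, in isolation: the fill_constant op that seeds
    # d(loss)/d(loss) = 1.0 must match the loss's dimensionality.
    def seed_grad_shape(loss_shape):
        return [] if len(loss_shape) == 0 else [1]

    assert seed_grad_shape(()) == []     # 0-d loss -> 0-d ones
    assert seed_grad_shape((1,)) == [1]  # legacy scalar-as-[1] loss
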
From a9c5a40d18c7241a321b7c7009ffb142a623ea16 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Tue, 17 Jan 2023 07:28:48 +0000
Subject: [PATCH 11/13] Enrich Dygraph and Shape Unit Test for Case and Switch
 Case

---
 .../paddle/fluid/tests/unittests/test_case.py |  64 +++++++++++
 .../fluid/tests/unittests/test_switch_case.py | 105 ++++++++++++++++--
 2 files changed, 161 insertions(+), 8 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py
index 9123b4b009d18..50d37923112a0 100644
--- a/python/paddle/fluid/tests/unittests/test_case.py
+++ b/python/paddle/fluid/tests/unittests/test_case.py
@@ -145,10 +145,71 @@ def fn_3():
         )
 
         np.testing.assert_allclose(res[0], 1, rtol=1e-05)
+        self.assertEqual(res[0].shape, ())
         np.testing.assert_allclose(res[1], 2, rtol=1e-05)
+        self.assertEqual(res[1].shape, ())
         np.testing.assert_allclose(res[2], 3, rtol=1e-05)
+        self.assertEqual(res[2].shape, ())
         np.testing.assert_allclose(res[3], 2, rtol=1e-05)
+        self.assertEqual(res[3].shape, ())
         np.testing.assert_allclose(res[4], 2, rtol=1e-05)
+        self.assertEqual(res[4].shape, ())
+
+    def test_0d_tensor_dygraph(self):
+        paddle.disable_static()
+
+        def fn_1():
+            return paddle.full(shape=[], dtype='int32', fill_value=1)
+
+        def fn_2():
+            return paddle.full(shape=[], dtype='int32', fill_value=2)
+
+        def fn_3():
+            return paddle.full(shape=[], dtype='int32', fill_value=3)
+
+        x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
+        y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
+        z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
+        pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
+        pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
+
+        # call fn_1
+        out_0 = paddle.static.nn.control_flow.case(
+            pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
+        )
+
+        # call fn_2
+        out_1 = paddle.static.nn.control_flow.case(
+            pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
+        )
+
+        # call default fn_3
+        out_2 = paddle.static.nn.control_flow.case(
+            pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
+        )
+
+        # no default, call fn_2
+        out_3 = paddle.static.nn.control_flow.case(
+            pred_fn_pairs=[(pred_1, fn_2)]
+        )
+
+        # no default, call fn_2. but pred_2 is false
+        out_4 = paddle.static.nn.control_flow.case(
+            pred_fn_pairs=[(pred_2, fn_2)]
+        )
+
+        np.testing.assert_allclose(out_0, 1, rtol=1e-05)
+        self.assertEqual(out_0.shape, [])
+        np.testing.assert_allclose(out_1, 2, rtol=1e-05)
+        self.assertEqual(out_1.shape, [])
+        np.testing.assert_allclose(out_2, 3, rtol=1e-05)
+        self.assertEqual(out_2.shape, [])
+        np.testing.assert_allclose(out_3, 2, rtol=1e-05)
+        self.assertEqual(out_3.shape, [])
+        np.testing.assert_allclose(out_4, 2, rtol=1e-05)
+        self.assertEqual(out_4.shape, [])
+
+        paddle.enable_static()
 
     def test_return_var_tuple(self):
         def fn_1():
@@ -394,8 +455,11 @@ def fn_3():
         res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
 
         np.testing.assert_allclose(res[0], 1, rtol=1e-05)
+        self.assertEqual(res[0].shape, ())
         np.testing.assert_allclose(res[1], 2, rtol=1e-05)
+        self.assertEqual(res[1].shape, ())
         np.testing.assert_allclose(res[2], 3, rtol=1e-05)
+        self.assertEqual(res[2].shape, ())
 
 
 class TestAPICase_Error(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py
index 2ddbd0f7ff051..322e8c5d7c0c4 100644
--- a/python/paddle/fluid/tests/unittests/test_switch_case.py
+++ b/python/paddle/fluid/tests/unittests/test_switch_case.py
@@ -93,25 +93,25 @@ def fn_3():
             res[1],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[1], 2),
         )
         np.testing.assert_allclose(
             res[2],
             3,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 3),
+            err_msg='result is {} but answer is {}'.format(res[2], 3),
         )
         np.testing.assert_allclose(
             res[3],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[3], 2),
         )
         np.testing.assert_allclose(
             res[4],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[4], 2),
         )
 
     def test_0d_tensor(self):
@@ -176,30 +176,116 @@ def fn_3():
             rtol=1e-05,
             err_msg='result is {} but answer is {}'.format(res[0], 1),
         )
+        self.assertEqual(res[0].shape, ())
         np.testing.assert_allclose(
             res[1],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[1], 2),
         )
+        self.assertEqual(res[1].shape, ())
         np.testing.assert_allclose(
             res[2],
             3,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 3),
+            err_msg='result is {} but answer is {}'.format(res[2], 3),
        )
+        self.assertEqual(res[2].shape, ())
         np.testing.assert_allclose(
             res[3],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[3], 2),
         )
+        self.assertEqual(res[3].shape, ())
         np.testing.assert_allclose(
             res[4],
             2,
             rtol=1e-05,
-            err_msg='result is {} but answer is {}'.format(res[0], 2),
+            err_msg='result is {} but answer is {}'.format(res[4], 2),
         )
+        self.assertEqual(res[4].shape, ())
+
+    def test_0d_tensor_dygraph(self):
+        paddle.disable_static()
+
+        def fn_1():
+            return paddle.full(shape=[], dtype='int32', fill_value=1)
+
+        def fn_2():
+            return paddle.full(shape=[], dtype='int32', fill_value=2)
+
+        def fn_3():
+            return paddle.full(shape=[], dtype='int32', fill_value=3)
+
+        index_1 = paddle.full(shape=[], dtype='int32', fill_value=1)
+        index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
+        index_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
+
+        # call fn_1
+        out_0 = paddle.static.nn.switch_case(
+            branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
+        )
+
+        # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
+        out_1 = paddle.static.nn.switch_case(
+            branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
+        )
+
+        # call default fn_3
+        out_2 = paddle.static.nn.switch_case(
+            branch_index=index_5,
+            branch_fns=((1, fn_1), (2, fn_2)),
+            default=fn_3,
+        )
+
+        # no default, call fn_2
+        out_3 = paddle.static.nn.switch_case(
+            branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
+        )
+
+        # no default, call fn_2 but branch_index is 5
+        out_4 = paddle.static.nn.switch_case(
+            branch_index=index_5,
+            branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
+        )
+
+        np.testing.assert_allclose(
+            out_0,
+            1,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(out_0, 1),
+        )
+        self.assertEqual(out_0.shape, [])
+        np.testing.assert_allclose(
+            out_1,
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(out_1, 2),
+        )
+        self.assertEqual(out_1.shape, [])
+        np.testing.assert_allclose(
+            out_2,
+            3,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(out_2, 3),
+        )
+        self.assertEqual(out_2.shape, [])
+        np.testing.assert_allclose(
+            out_3,
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(out_3, 2),
+        )
+        self.assertEqual(out_3.shape, [])
+        np.testing.assert_allclose(
+            out_4,
+            2,
+            rtol=1e-05,
+            err_msg='result is {} but answer is {}'.format(out_4, 2),
+        )
+        self.assertEqual(out_4.shape, [])
+
+        paddle.enable_static()
 
     def test_return_var_tuple(self):
         def fn_1():
@@ -426,18 +512,21 @@ def fn_3():
             rtol=1e-05,
             err_msg='result is {} but answer is {}'.format(res[0], 1),
         )
+        self.assertEqual(res[0].shape, ())
         np.testing.assert_allclose(
             res[1],
             2,
             rtol=1e-05,
             err_msg='result is {} but answer is {}'.format(res[1], 2),
         )
+        self.assertEqual(res[1].shape, ())
         np.testing.assert_allclose(
             res[2],
             3,
             rtol=1e-05,
             err_msg='result is {} but answer is {}'.format(res[2], 3),
         )
+        self.assertEqual(res[2].shape, ())
 
 
 # test TypeError and ValueError of api switch_case

From 0db2b00880a7b4955d05d8a5b55bb3a1e7da91f9 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Tue, 17 Jan 2023 10:34:41 +0000
Subject: [PATCH 12/13] Add backward 0d tensor test

---
 .../paddle/fluid/tests/unittests/test_case.py | 31 +++++++++++
 .../fluid/tests/unittests/test_switch_case.py | 33 +++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py
index 50d37923112a0..675b51cf0a053 100644
--- a/python/paddle/fluid/tests/unittests/test_case.py
+++ b/python/paddle/fluid/tests/unittests/test_case.py
@@ -22,6 +22,7 @@
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
 import paddle.fluid.optimizer as optimizer
+from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import Program, program_guard
 
 paddle.enable_static()
@@ -156,6 +157,36 @@ def fn_3():
         np.testing.assert_allclose(res[4], 2, rtol=1e-05)
         self.assertEqual(res[4].shape, ())
 
+    def test_0d_tensor_backward(self):
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
+            x.stop_gradient = False
+            pred = paddle.full(shape=[], dtype='bool', fill_value=0)
+            # pred is False, so out = -x
+            out = paddle.static.nn.case(
+                pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x
+            )
+            append_backward(out)
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
+        np.testing.assert_allclose(
+            np.asarray(res[0]), np.array(2.0), rtol=1e-05
+        )
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_allclose(
+            np.asarray(res[1]), np.array(-1.0), rtol=1e-05
+        )
+        self.assertEqual(res[1].shape, ())
+
     def test_0d_tensor_dygraph(self):
         paddle.disable_static()
 
diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py
index 322e8c5d7c0c4..170b13998157d 100644
--- a/python/paddle/fluid/tests/unittests/test_switch_case.py
+++ b/python/paddle/fluid/tests/unittests/test_switch_case.py
@@ -21,6 +21,7 @@
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
+from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import Program, program_guard
 
 paddle.enable_static()
@@ -207,6 +208,38 @@ def fn_3():
         )
         self.assertEqual(res[4].shape, ())
 
+    def test_0d_tensor_backward(self):
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
+            x.stop_gradient = False
+            pred = paddle.full(shape=[], dtype='int32', fill_value=2)
+            # pred is 0, so out = 2 * x
+            out = paddle.static.nn.switch_case(
+                branch_index=pred,
+                branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
+                default=lambda: -x,
+            )
+            append_backward(out)
+
+        place = (
+            fluid.CUDAPlace(0)
+            if core.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
+        exe = fluid.Executor(place)
+
+        res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
+        np.testing.assert_allclose(
+            np.asarray(res[0]), np.array(-4.0), rtol=1e-05
+        )
+        self.assertEqual(res[0].shape, ())
+        np.testing.assert_allclose(
+            np.asarray(res[1]), np.array(2.0), rtol=1e-05
+        )
+        self.assertEqual(res[1].shape, ())
+
     def test_0d_tensor_dygraph(self):
         paddle.disable_static()
 

From b2d1d013553269a93d55e3094a2b50a6f4d6a044 Mon Sep 17 00:00:00 2001
From: zhhsplendid
Date: Tue, 17 Jan 2023 10:37:09 +0000
Subject: [PATCH 13/13] Fix comment error

---
 python/paddle/fluid/tests/unittests/test_switch_case.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py
index 170b13998157d..3fad3bdfd0c0d 100644
--- a/python/paddle/fluid/tests/unittests/test_switch_case.py
+++ b/python/paddle/fluid/tests/unittests/test_switch_case.py
@@ -214,7 +214,7 @@ def test_0d_tensor_backward(self):
             x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
             x.stop_gradient = False
             pred = paddle.full(shape=[], dtype='int32', fill_value=2)
-            # pred is 0, so out = 2 * x
+            # pred is 2, so out = 2 * x
             out = paddle.static.nn.switch_case(
                 branch_index=pred,
                 branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
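
Taken together, patches 11 through 13 extend the same scalar guarantees to case and
switch_case, forward and backward. The net effect, condensed into one self-contained
sketch that mirrors the switch_case backward test above; this is our summary under
the assumption that the full series is applied, not code from the patches.

    import paddle
    from paddle.fluid.backward import append_backward

    # Minimal sketch: a 0-d tensor flows through switch_case and its
    # gradient comes back 0-d as well.
    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main, paddle.static.Program()):
        x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
        x.stop_gradient = False
        pred = paddle.full(shape=[], dtype='int32', fill_value=2)
        # branch index is 2, so out = 2 * x and d(out)/d(x) = 2
        out = paddle.static.nn.switch_case(
            branch_index=pred,
            branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
            default=lambda: -x,
        )
        append_backward(out)

    exe = paddle.static.Executor(paddle.CPUPlace())
    res = exe.run(main, fetch_list=[out.name, x.grad_name])
    assert res[0].shape == () and res[0] == -4.0  # scalar forward value
    assert res[1].shape == () and res[1] == 2.0   # scalar gradient
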