From 8cd6842b06421b072a5fc197e7082eb11c0461c4 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Tue, 28 Feb 2023 09:30:09 +0800 Subject: [PATCH] Update test_static_prim.cc --- paddle/fluid/prim/tests/test_static_prim.cc | 31 ++++++++------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/prim/tests/test_static_prim.cc b/paddle/fluid/prim/tests/test_static_prim.cc index f751f2d65a302..20694d1e2a622 100644 --- a/paddle/fluid/prim/tests/test_static_prim.cc +++ b/paddle/fluid/prim/tests/test_static_prim.cc @@ -212,35 +212,28 @@ TEST(StaticPrim, TanhBackwardComposite) { ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(), static_cast(1)); - ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow"); - ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast(1)); - ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast(1)); - ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b"); - ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(), - static_cast(1)); - - ASSERT_EQ(grad_ops[2]->Type(), "fill_constant"); + ASSERT_EQ(grad_ops[1]->Type(), "fill_constant"); ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")), static_cast(5)); // ProtoDataType::FP32 + ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(), + static_cast(1)); + + ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub"); + ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast(1)); + ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast(1)); + ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0], + grad_ops[2]->Outputs().at("Out")[0]); ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(), static_cast(1)); - ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub"); + ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul"); ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast(1)); ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast(1)); - ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], + ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0], grad_ops[2]->Outputs().at("Out")[0]); + ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD"); ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(), static_cast(1)); - - ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul"); - ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast(1)); - ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast(1)); - ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0], - grad_ops[3]->Outputs().at("Out")[0]); - ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD"); - ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(), - static_cast(1)); } TEST(StaticCompositeGradMaker, TestMutiInputMethod) {