Skip to content

Commit

Permalink
Update test_static_prim.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoguoguo626807 authored Feb 28, 2023
1 parent 17717c2 commit 8cd6842
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions paddle/fluid/prim/tests/test_static_prim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,35 +212,28 @@ TEST(StaticPrim, TanhBackwardComposite) {
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow");
ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b");
ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[2]->Type(), "fill_constant");
ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
grad_ops[2]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0],
ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
grad_ops[2]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));

ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
grad_ops[3]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}

TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
Expand Down

0 comments on commit 8cd6842

Please sign in to comment.