From e4e0a7f530df84260468b075240da8bf88d68e92 Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Thu, 16 Nov 2023 13:26:27 +0800
Subject: [PATCH] [BugFix] Fix cycle gan unittest for pir api mode. (#58999)

---------

Co-authored-by: SigureMo
---
 .../eager/to_static/run_program_op_node.h | 10 ++++-----
 test/dygraph_to_static/test_cycle_gan.py  | 21 +++++++++++++------
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index 1bfecab26e11f..bb398391ea37d 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -129,6 +129,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
 static void ShareTensorsIntoScope(const std::vector<paddle::Tensor> &tensors,
                                   paddle::framework::Scope *scope) {
   for (size_t i = 0; i < tensors.size(); ++i) {
+    VLOG(4) << "Share Tensor Into Scope: " << i;
     auto name = tensors[i].name();
     if (name == paddle::framework::kFakeVarName ||
         name == paddle::framework::kEmptyVarName) {
@@ -511,18 +512,17 @@ inline void PirRunProgramAPI(
       details::GetNameFromValue(forward_global_block, middle_values, false);
   auto skip_names_set =
       std::set<std::string>(skip_names.begin(), skip_names.end());
-  skip_names =
-      details::GetNameFromValue(forward_global_block, output_values, false);
-  skip_names_set.insert(skip_names.begin(), skip_names.end());
   auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
                                                 attrs.at("no_need_buffers"));
   auto no_need_buffer_names = details::GetNameFromValue(
       forward_global_block, no_need_buffer_values, false);
-  VLOG(4) << "start skip no need buffer vars with name:";
   for (auto &name : no_need_buffer_names) {
-    VLOG(4) << "Skip no need buffer vars with name:" << name;
+    VLOG(4) << "Find no need buffer vars with name:" << name;
     skip_names_set.erase(name);
   }
+  skip_names =
+      details::GetNameFromValue(forward_global_block, output_values, false);
+  skip_names_set.insert(skip_names.begin(), skip_names.end());
   details::print_collection(skip_names_set);
   interpreter_core->SetSkipGcVars(skip_names_set);
 
diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py
index 58560286b3020..46b70b7bcc0cc 100644
--- a/test/dygraph_to_static/test_cycle_gan.py
+++ b/test/dygraph_to_static/test_cycle_gan.py
@@ -26,7 +26,10 @@
 # Use GPU:0 to eliminate the influence of other tasks.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" -from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_legacy_and_pir_exe_and_pir_api, +) import paddle from paddle.base.dygraph import to_variable @@ -617,9 +620,12 @@ def train(args, to_static): fake_pool_A = to_variable(fake_pool_A) # optimize the d_A network - rec_B, fake_pool_rec_B = paddle.jit.to_static( + discriminatorA_to_static = paddle.jit.to_static( cycle_gan.discriminatorA - )(data_B, fake_pool_B) + ) + rec_B, fake_pool_rec_B = discriminatorA_to_static( + data_B, fake_pool_B + ) d_loss_A = ( paddle.square(fake_pool_rec_B) + paddle.square(rec_B - 1) ) / 2.0 @@ -630,9 +636,12 @@ def train(args, to_static): cycle_gan.clear_gradients() # optimize the d_B network - rec_A, fake_pool_rec_A = paddle.jit.to_static( + discriminatorB_to_static = paddle.jit.to_static( cycle_gan.discriminatorB - )(data_A, fake_pool_A) + ) + rec_A, fake_pool_rec_A = discriminatorB_to_static( + data_A, fake_pool_A + ) d_loss_B = ( paddle.square(fake_pool_rec_A) + paddle.square(rec_A - 1) ) / 2.0 @@ -681,7 +690,7 @@ def train(self, to_static): out = train(self.args, to_static) return out - @test_legacy_and_pir + @test_legacy_and_pir_exe_and_pir_api def test_train(self): st_out = self.train(to_static=True) dy_out = self.train(to_static=False)