[BugFix] Fix cycle gan unittest for pir api mode. (PaddlePaddle#58999)
---------

Co-authored-by: SigureMo <[email protected]>
2 people authored and SecretXV committed Nov 28, 2023
1 parent 3a53059 commit e4e0a7f
Showing 2 changed files with 20 additions and 11 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -129,6 +129,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
 static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
                                   paddle::framework::Scope *scope) {
   for (size_t i = 0; i < tensors.size(); ++i) {
+    VLOG(4) << "Share Tensor Into Scope: " << i;
     auto name = tensors[i].name();
     if (name == paddle::framework::kFakeVarName ||
         name == paddle::framework::kEmptyVarName) {
@@ -511,18 +512,17 @@ inline void PirRunProgramAPI(
       details::GetNameFromValue(forward_global_block, middle_values, false);
   auto skip_names_set =
       std::set<std::string>(skip_names.begin(), skip_names.end());
-  skip_names =
-      details::GetNameFromValue(forward_global_block, output_values, false);
-  skip_names_set.insert(skip_names.begin(), skip_names.end());
   auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
                                                 attrs.at("no_need_buffers"));
   auto no_need_buffer_names = details::GetNameFromValue(
       forward_global_block, no_need_buffer_values, false);
+  VLOG(4) << "start skip no need buffer vars with name:";
   for (auto &name : no_need_buffer_names) {
-    VLOG(4) << "Skip no need buffer vars with name:" << name;
+    VLOG(4) << "Find no need buffer vars with name:" << name;
     skip_names_set.erase(name);
   }
+  skip_names =
+      details::GetNameFromValue(forward_global_block, output_values, false);
+  skip_names_set.insert(skip_names.begin(), skip_names.end());
   details::print_collection(skip_names_set);
   interpreter_core->SetSkipGcVars(skip_names_set);

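Why the reordering above matters: the two skip-GC set operations do not commute. If output names are inserted before the no-need-buffer names are erased, an output that is also marked no-need-buffer gets dropped from the skip set and freed too early. A minimal sketch of just that ordering, in plain Python with illustrative names (not Paddle's actual variable names):

# Hypothetical names; only the insert/erase ordering mirrors the fix above.
middles = {"mid_0", "mid_1"}
outputs = {"out_0", "mid_1"}   # an output can coincide with a middle value
no_need_buffer = {"mid_1"}     # ...and also be marked no-need-buffer

# Old order: insert outputs, then erase no-need-buffer names.
skip_gc = set(middles) | outputs
skip_gc -= no_need_buffer
assert "mid_1" not in skip_gc  # the output would be GC'd prematurely

# New order (this commit): erase first, then insert outputs.
skip_gc = set(middles) - no_need_buffer
skip_gc |= outputs
assert "mid_1" in skip_gc      # outputs always survive in the skip-GC set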
21 changes: 15 additions & 6 deletions test/dygraph_to_static/test_cycle_gan.py
@@ -26,7 +26,10 @@
 # Use GPU:0 to elimate the influence of other tasks.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1"

-from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
+from dygraph_to_static_utils_new import (
+    Dy2StTestBase,
+    test_legacy_and_pir_exe_and_pir_api,
+)

 import paddle
 from paddle.base.dygraph import to_variable
@@ -617,9 +620,12 @@ def train(args, to_static):
             fake_pool_A = to_variable(fake_pool_A)

             # optimize the d_A network
-            rec_B, fake_pool_rec_B = paddle.jit.to_static(
+            discriminatorA_to_static = paddle.jit.to_static(
                 cycle_gan.discriminatorA
-            )(data_B, fake_pool_B)
+            )
+            rec_B, fake_pool_rec_B = discriminatorA_to_static(
+                data_B, fake_pool_B
+            )
             d_loss_A = (
                 paddle.square(fake_pool_rec_B) + paddle.square(rec_B - 1)
             ) / 2.0
@@ -630,9 +636,12 @@ def train(args, to_static):
                 cycle_gan.clear_gradients()

             # optimize the d_B network
-            rec_A, fake_pool_rec_A = paddle.jit.to_static(
+            discriminatorB_to_static = paddle.jit.to_static(
                 cycle_gan.discriminatorB
-            )(data_A, fake_pool_A)
+            )
+            rec_A, fake_pool_rec_A = discriminatorB_to_static(
+                data_A, fake_pool_A
+            )
             d_loss_B = (
                 paddle.square(fake_pool_rec_A) + paddle.square(rec_A - 1)
             ) / 2.0
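Both discriminator call sites now split the dynamic-to-static conversion from the invocation. The snippet below is a standalone sketch of that two-step pattern; the discriminator body is a stand-in, not the test's real network:

import paddle

def discriminator(x, y):  # stand-in for cycle_gan.discriminatorA/B
    return paddle.square(x), paddle.square(y)

x = paddle.ones([2, 2])
y = paddle.zeros([2, 2])

# Chained form the commit removes: convert and call in one expression.
# rec, fake_rec = paddle.jit.to_static(discriminator)(x, y)

# Two-step form the commit introduces: bind the converted callable,
# then invoke it as a separate statement.
discriminator_static = paddle.jit.to_static(discriminator)
rec, fake_rec = discriminator_static(x, y)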
@@ -681,7 +690,7 @@ def train(self, to_static):
         out = train(self.args, to_static)
         return out

-    @test_legacy_and_pir
+    @test_legacy_and_pir_exe_and_pir_api
     def test_train(self):
         st_out = self.train(to_static=True)
         dy_out = self.train(to_static=False)
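The decorator swap widens the test matrix from legacy plus PIR executor to also cover the PIR API mode. The real decorator lives in dygraph_to_static_utils_new; the sketch below only illustrates the general shape of a mode-sweeping test decorator and is not Paddle's implementation:

import functools

MODES = ("legacy", "pir_exe", "pir_api")  # illustrative mode names

def run_in_all_modes(test_fn):
    # Re-run one test method once per execution mode, reporting each
    # mode as a separate subtest. `self` is assumed to be a
    # unittest.TestCase, which provides subTest.
    @functools.wraps(test_fn)
    def wrapper(self):
        for mode in MODES:
            with self.subTest(mode=mode):
                # A real harness would switch executor/IR flags here
                # before invoking the test body.
                test_fn(self)
    return wrapper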
