[BugFix] Fix cycle gan unittest for pir api mode. #58999

Merged on Nov 16, 2023 (9 commits)
10 changes: 5 additions & 5 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -129,6 +129,7 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var,
static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
paddle::framework::Scope *scope) {
for (size_t i = 0; i < tensors.size(); ++i) {
VLOG(4) << "Share Tensor Into Scope: " << i;
auto name = tensors[i].name();
if (name == paddle::framework::kFakeVarName ||
name == paddle::framework::kEmptyVarName) {
@@ -511,18 +512,17 @@ inline void PirRunProgramAPI(
details::GetNameFromValue(forward_global_block, middle_values, false);
auto skip_names_set =
std::set<std::string>(skip_names.begin(), skip_names.end());
skip_names =
details::GetNameFromValue(forward_global_block, output_values, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names = details::GetNameFromValue(
forward_global_block, no_need_buffer_values, false);
VLOG(4) << "start skip no need buffer vars with name:";
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Skip no need buffer vars with name:" << name;
VLOG(4) << "Find no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
skip_names =
details::GetNameFromValue(forward_global_block, output_values, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);

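Note on the change above: the insertion of the forward output names into skip_names_set is moved to after the loop that erases the no-need-buffer names. With the old order, an output value that also appeared in no_need_buffers was erased from the skip-GC set and could be freed too early; with the new order the outputs always stay in the set. A minimal Python sketch of the set-ordering effect (names are hypothetical, not Paddle internals):

# Hypothetical names illustrating why insertion order matters
# for the skip-garbage-collection set.
middle_names = {"tmp_0", "tmp_1", "out_0"}
output_names = {"out_0", "out_1"}
no_need_buffer_names = {"tmp_1", "out_0"}

# Old order: add outputs first, then erase no-need-buffer names.
skip_old = (middle_names | output_names) - no_need_buffer_names
# "out_0" has been erased, so it may be garbage-collected too early.

# New order: erase no-need-buffer names first, then add outputs.
skip_new = (middle_names - no_need_buffer_names) | output_names
# "out_0" stays in the set and is kept alive.

assert "out_0" not in skip_old and "out_0" in skip_new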
21 changes: 15 additions & 6 deletions test/dygraph_to_static/test_cycle_gan.py
@@ -26,7 +26,10 @@
# Use GPU:0 to eliminate the influence of other tasks.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle
from paddle.base.dygraph import to_variable
@@ -617,9 +620,12 @@ def train(args, to_static):
fake_pool_A = to_variable(fake_pool_A)

# optimize the d_A network
rec_B, fake_pool_rec_B = paddle.jit.to_static(
discriminatorA_to_static = paddle.jit.to_static(
cycle_gan.discriminatorA
)(data_B, fake_pool_B)
)
rec_B, fake_pool_rec_B = discriminatorA_to_static(
data_B, fake_pool_B
)
d_loss_A = (
paddle.square(fake_pool_rec_B) + paddle.square(rec_B - 1)
) / 2.0
@@ -630,9 +636,12 @@ def train(args, to_static):
cycle_gan.clear_gradients()

# optimize the d_B network
rec_A, fake_pool_rec_A = paddle.jit.to_static(
discriminatorB_to_static = paddle.jit.to_static(
cycle_gan.discriminatorB
)(data_A, fake_pool_A)
)
rec_A, fake_pool_rec_A = discriminatorB_to_static(
data_A, fake_pool_A
)
d_loss_B = (
paddle.square(fake_pool_rec_A) + paddle.square(rec_A - 1)
) / 2.0
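Both discriminator blocks apply the same refactor: paddle.jit.to_static is first used to build a static-graph callable, which is then invoked with the inputs, rather than converting and calling in a single expression. A minimal sketch of the pattern, assuming only the public paddle.jit.to_static API (TinyNet below is a hypothetical stand-in for the discriminators):

import paddle

class TinyNet(paddle.nn.Layer):
    # Hypothetical stand-in for cycle_gan.discriminatorA / discriminatorB.
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()
# Build the static-graph callable once ...
net_to_static = paddle.jit.to_static(net)
# ... then call it like the original layer.
out = net_to_static(paddle.randn([8, 4]))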
@@ -681,7 +690,7 @@ def train(self, to_static):
out = train(self.args, to_static)
return out

@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_train(self):
st_out = self.train(to_static=True)
dy_out = self.train(to_static=False)
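Judging by the decorator names, the swap widens the test matrix: test_legacy_and_pir ran the case under the legacy program and the PIR executor, while test_legacy_and_pir_exe_and_pir_api additionally runs it under the PIR API mode that this PR fixes. A usage sketch, assuming the decorators exported by dygraph_to_static_utils_new in Paddle's test tree:

from dygraph_to_static_utils_new import (
    Dy2StTestBase,
    test_legacy_and_pir_exe_and_pir_api,
)

class TestExample(Dy2StTestBase):
    @test_legacy_and_pir_exe_and_pir_api
    def test_train(self):
        # The decorated case is executed once per enabled mode:
        # legacy program, PIR executor, and PIR API.
        ...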