[Dy2St][PIR] Hold backward program in GradNode #63694

Merged
merged 25 commits on May 6, 2024
Changes from 14 commits
11 changes: 1 addition & 10 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -326,16 +326,7 @@ inline void pir_run_program_ad_func(

grad_node->SetStepScope(step_scope); // just for set useable.

// Set Grad out rank as same as fwd input and set stop gradient to bwd
// NOTE(@xiongkun): Not every tensor in x(list of tensor) is required
// gradient. for example: x[1] is not used for output, the x[1] is ignored.

std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
}

grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);

// TODO(@xiongkun): rewrite by new ir representation.
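The change above drops the hand-built x_require_grad pointer vector and passes x straight to SetGradOutMeta. A minimal sketch of the overload pair this appears to rely on (inferred from the two call forms in the diff; Tensor and GradNodeSketch below are placeholders, not Paddle's real types):

#include <cstddef>
#include <vector>

struct Tensor {};

struct GradNodeSketch {
  // Overload the old call site needed: a hand-built vector of tensor pointers.
  void SetGradOutMeta(const std::vector<const Tensor*>& fwd_in, size_t slot) {}
  // Overload the new call site uses: the tensor vector itself.
  void SetGradOutMeta(const std::vector<Tensor>& fwd_in, size_t slot) {}
};

int main() {
  std::vector<Tensor> x(3);
  GradNodeSketch node;
  node.SetGradOutMeta(x, /*slot id*/ 0);  // no pointer vector needed any more
  return 0;
}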
63 changes: 34 additions & 29 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -467,21 +467,18 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto forward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("forward_program"));
auto backward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("backward_program"));

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
::pir::Block *forward_global_block = forward_program->block();

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
@@ -514,7 +511,7 @@ inline void PirRunProgramAPI(
forward_global_block, params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto passed_kernel_program =
paddle::framework::ApplyIrPass(forward_program, place);
paddle::framework::ApplyIrPass(forward_program.get(), place);
if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "LoweredProgram( AfterPass ) is :\n";
@@ -1046,10 +1043,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
auto backward_program = PADDLE_GET_CONST(std::shared_ptr<::pir::Program>,
attrs.at("backward_program"));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
@@ -1069,18 +1064,22 @@
details::Trans2ContiguousTensorsInplace(out_grad);

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out_grad,
output_grad_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_global_block,
backward_program->block(), x, forward_input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
middles,
forward_middle_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out,
forward_output_values,
global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, out, forward_output_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
backward_global_block, params, parameter_values, global_inner_scope);
backward_program->block(), params, parameter_values, global_inner_scope);

// Clear out and middles to avoid hold memory until backward finish.
out.clear();
@@ -1089,6 +1088,7 @@
auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;

if (!cache.Has(program_id,
global_inner_scope,
place_hash_key,
@@ -1101,7 +1101,7 @@
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
// Step 1. share input_vars & parameters into scope
auto passed_kernel_program =
paddle::framework::ApplyIrPass(backward_program, place);
paddle::framework::ApplyIrPass(backward_program.get(), place);

const auto &new_block = passed_kernel_program->block();
passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(
@@ -1143,10 +1143,10 @@
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
auto skip_names = details::GetNameFromValue(
backward_global_block, x_grad_values, false, true);
backward_program->block(), x_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
backward_global_block, p_grad_values, false, true);
backward_program->block(), p_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
cache.UpdateSkipEagerDeleteVars(program_id,
@@ -1179,7 +1179,7 @@
}
}

if (!backward_global_block->empty()) {
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -1194,9 +1194,11 @@
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
@@ -1335,8 +1337,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (x[i].is_selected_rows()) {
auto selected_row = std::make_shared<phi::SelectedRows>();
x_grad->emplace_back(selected_row);
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(x_grad_names[i]);
}
@@ -1471,6 +1472,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

*executed_ = true;
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
this->OutputMeta()[0]);
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
this->OutputMeta()[1]);
return {x_grad, params_grad};
}

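Throughout this file the op attributes change from raw ::pir::Block* handles to std::shared_ptr<::pir::Program> objects, and every use of backward_global_block becomes backward_program->block(). A minimal sketch of the lifetime guarantee that shared ownership adds, using placeholder types rather than the real pir classes:

#include <cassert>
#include <memory>

struct Block {};
struct Program {
  Block body;
  Block* block() { return &body; }
};

int main() {
  std::shared_ptr<Program> held_by_grad_node;
  {
    auto prog = std::make_shared<Program>();
    held_by_grad_node = prog;  // the grad node shares ownership of the program
  }  // the scope that created the program ends here
  // Shared ownership keeps the program, and therefore its block, alive until
  // the backward pass actually runs.
  assert(held_by_grad_node != nullptr);
  Block* blk = held_by_grad_node->block();
  (void)blk;
  return 0;
}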
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

@@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
void operator()(const std::vector<pir::Block *> &v) const {
// just do nothing.
}
void operator()(const std::shared_ptr<pir::Program> &v) const {
// just do nothing.
}
void operator()(const std::vector<VarDesc *> &v) const {
std::vector<std::string> var_names;
for (auto var : v) {
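Because the attribute variant gains a std::shared_ptr<pir::Program> alternative (see the type_defs changes below), every visitor over that variant needs a matching operator(), even one that deliberately does nothing. A small self-contained sketch of the pattern, with a cut-down variant standing in for paddle::framework::Attribute:

#include <iostream>
#include <memory>
#include <variant>

namespace pir { struct Program {}; }

using MiniAttribute = std::variant<int, std::shared_ptr<pir::Program>>;

struct MiniVisitor {
  void operator()(int v) const { std::cout << "int attr: " << v << "\n"; }
  void operator()(const std::shared_ptr<pir::Program>&) const {
    // Just do nothing, mirroring SetAttrDescVisitor above.
  }
};

int main() {
  MiniAttribute a = 3;
  MiniAttribute b = std::make_shared<pir::Program>();
  std::visit(MiniVisitor{}, a);
  std::visit(MiniVisitor{}, b);  // compiles only because this overload exists
  return 0;
}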
3 changes: 2 additions & 1 deletion paddle/fluid/framework/type_defs.cc
@@ -39,7 +39,8 @@ template class variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
} // namespace paddle
REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
4 changes: 3 additions & 1 deletion paddle/fluid/framework/type_defs.h
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
@@ -67,7 +68,8 @@ using Attribute = paddle::variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

using OpCreator =
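With std::shared_ptr<::pir::Program> added to the variant, an AttributeMap entry can now own an entire program, which is what PirRunProgramAPI and PirRunProgramGradAPI read back with PADDLE_GET_CONST above. A minimal sketch of that store/fetch round trip, using std::variant and std::get in place of paddle::variant and PADDLE_GET_CONST:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>

namespace pir { struct Program {}; }

using MiniAttribute =
    std::variant<bool, std::string, std::shared_ptr<pir::Program>>;
using MiniAttributeMap = std::unordered_map<std::string, MiniAttribute>;

int main() {
  MiniAttributeMap attrs;
  // Stored once (e.g. by the Python-side attribute list) ...
  attrs["backward_program"] = std::make_shared<pir::Program>();
  // ... and read back later by the op, still pointing at the same program.
  auto backward_program =
      std::get<std::shared_ptr<pir::Program>>(attrs.at("backward_program"));
  assert(backward_program != nullptr);
  return 0;
}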
20 changes: 16 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
@@ -36,6 +36,7 @@
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/op_result.h"
#include "paddle/pir/include/core/region.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {
@@ -858,6 +859,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
}

void CastPyArg2AttrIRProgram(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
VLOG(1) << "After Process pir::Program*";
const std::shared_ptr<::pir::Program> program =
::py::handle(obj).cast<std::shared_ptr<::pir::Program>>();
attrs[key] = program;
}

void CastPyArg2AttrValues(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
@@ -1020,11 +1032,11 @@ void ConstructAttrMapForRunProgram(

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"global_block",
"forward_global_block",
"backward_global_block"})
.count(key)) {
} else if (std::set<std::string>({"global_block"}).count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
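CastPyArg2AttrIRProgram leans on pybind11: the incoming PyObject* is wrapped in a py::handle and cast to std::shared_ptr<::pir::Program>, which only works because Program is bound with a shared_ptr holder (see the pir.cc change below). A schematic version with a placeholder class, not Paddle's actual binding:

#include <pybind11/pybind11.h>

#include <memory>

namespace py = pybind11;

struct Program {};  // placeholder; the real type is ::pir::Program

// Mirrors the cast in CastPyArg2AttrIRProgram: a raw PyObject* wrapped in a
// py::handle is cast to the shared_ptr the bound class is held by.
std::shared_ptr<Program> ExtractProgram(PyObject* obj) {
  return py::handle(obj).cast<std::shared_ptr<Program>>();
}

PYBIND11_MODULE(cast_sketch, m) {
  py::class_<Program, std::shared_ptr<Program>>(m, "Program").def(py::init<>());
  m.def("roundtrip", [](py::object obj) {
    auto prog = ExtractProgram(obj.ptr());  // C++ now shares ownership with Python
    return prog != nullptr;
  });
}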
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -255,7 +255,7 @@ void BindProgram(py::module *m) {
)DOC");
program
.def(py::init([]() {
return std::make_unique<Program>(pir::IrContext::Instance());
return std::make_shared<Program>(pir::IrContext::Instance());
}))
.def("__str__",
[](const std::shared_ptr<Program> &self) {
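This one-line switch from make_unique to make_shared keeps the factory consistent with the shared_ptr holder used for Program, which is what allows the same Python-created program to be handed to C++ as a shared_ptr attribute. A minimal, hypothetical binding illustrating the convention (not Paddle's real module):

#include <pybind11/pybind11.h>

#include <memory>

namespace py = pybind11;

struct Program {};  // placeholder; Paddle binds the real pir::Program

PYBIND11_MODULE(holder_sketch, m) {
  // The class is held by std::shared_ptr and the init factory returns a
  // shared_ptr to match, so a later cast<std::shared_ptr<Program>>() on the
  // Python object hands back the same shared ownership.
  py::class_<Program, std::shared_ptr<Program>>(m, "Program")
      .def(py::init([]() { return std::make_shared<Program>(); }));
}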
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -899,10 +899,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
'forward_global_block',
self.program.forward_program.global_block(),
'backward_global_block',
self.program.backward_program.global_block(),
'forward_program',
self.program.forward_program,
'backward_program',
self.program.backward_program,
'is_test',
not self.training,
'program_id',
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,7 @@
import unittest

import numpy
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir

import paddle

@@ -33,6 +33,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])