Skip to content

Commit

Permalink
[Dy2St][PIR] Hold backward program in GradNode (#63694)
Browse files Browse the repository at this point in the history
Co-authored-by: xiongkun <[email protected]>
Co-authored-by: Nyakku Shigure <[email protected]>
  • Loading branch information
3 people authored May 6, 2024
1 parent 5570ef4 commit 4cd5384
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 58 deletions.
11 changes: 1 addition & 10 deletions paddle/fluid/eager/to_static/run_program_op_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,7 @@ inline void pir_run_program_ad_func(

grad_node->SetStepScope(step_scope); // just for set useable.

// Set Grad out rank as same as fwd input and set stop gradient to bwd
// NOTE(@xiongkun): Not every tensor in x(list of tensor) is required
// gradient. for example: x[1] is not used for output, the x[1] is ignored.

std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
}

grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(x, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);

// TODO(@xiongkun): rewrite by new ir representation.
Expand Down
70 changes: 35 additions & 35 deletions paddle/fluid/eager/to_static/run_program_op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,21 +467,16 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

auto *forward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));

auto *forward_program =
forward_global_block->GetParentOp()->GetParentProgram();
std::shared_ptr<::pir::Program> forward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("forward_program"));
std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("backward_program"));

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
Expand Down Expand Up @@ -509,12 +504,12 @@ inline void PirRunProgramAPI(
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto passed_kernel_program =
paddle::framework::ApplyIrPass(forward_program, place);
paddle::framework::ApplyIrPass(forward_program.get(), place);
if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "LoweredProgram( AfterPass ) is :\n";
Expand All @@ -535,22 +530,22 @@ inline void PirRunProgramAPI(

// update interpretercore skip_gc_var
auto skip_names = details::GetNameFromValue(
forward_global_block, middle_values, false, true);
forward_program->block(), middle_values, false, true);
auto skip_names_set =
std::set<std::string>(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names = details::GetNameFromValue(
forward_global_block, no_need_buffer_values, false, true);
forward_program->block(), no_need_buffer_values, false, true);
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Find no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
skip_names = details::GetNameFromValue(
forward_global_block, output_values, false, true);
forward_program->block(), output_values, false, true);
skip_names_set.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
forward_global_block, input_values, true, false);
forward_program->block(), input_values, true, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);
Expand All @@ -576,9 +571,9 @@ inline void PirRunProgramAPI(
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScopeByValue(
forward_global_block, x, input_values, global_inner_scope);
forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
forward_global_block, params, param_values, global_inner_scope);
forward_program->block(), params, param_values, global_inner_scope);
// TODO(xiongkun): new ir how to build scope.
// if (interpreter_core->GetVariableScope()->GetMutableScope() !=
// global_inner_scope) {
Expand All @@ -589,7 +584,7 @@ inline void PirRunProgramAPI(
}

// interpretercore run
if (!forward_global_block->empty()) {
if (!forward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
Expand All @@ -602,7 +597,7 @@ inline void PirRunProgramAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Get Output, and Middle Outputs
details::ShareTensorsFromScopeByValue(
forward_global_block, out, output_values, global_inner_scope);
forward_program->block(), out, output_values, global_inner_scope);

VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());

Expand Down Expand Up @@ -1041,10 +1036,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

auto *backward_global_block =
PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
auto *backward_program =
backward_global_block->GetParentOp()->GetParentProgram();
std::shared_ptr<::pir::Program> backward_program = PADDLE_GET_CONST(
std::shared_ptr<::pir::Program>, attrs.at("backward_program"));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
Expand All @@ -1064,8 +1057,10 @@ inline void PirRunProgramGradAPI(
details::Trans2ContiguousTensorsInplace(out_grad);

// share x, param, middles, output_grads, out into scope.
details::ShareTensorsIntoScopeByValue(
backward_global_block, out_grad, output_grad_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(backward_program->block(),
out_grad,
output_grad_values,
global_inner_scope);

auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
Expand All @@ -1082,7 +1077,7 @@ inline void PirRunProgramGradAPI(
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
// Step 1. share input_vars & parameters into scope
auto passed_kernel_program =
paddle::framework::ApplyIrPass(backward_program, place);
paddle::framework::ApplyIrPass(backward_program.get(), place);

const auto &new_block = passed_kernel_program->block();
passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(
Expand Down Expand Up @@ -1124,10 +1119,10 @@ inline void PirRunProgramGradAPI(
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
auto skip_names = details::GetNameFromValue(
backward_global_block, x_grad_values, false, true);
backward_program->block(), x_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
skip_names = details::GetNameFromValue(
backward_global_block, p_grad_values, false, true);
backward_program->block(), p_grad_values, false, true);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
cache.UpdateSkipEagerDeleteVars(program_id,
Expand Down Expand Up @@ -1160,7 +1155,7 @@ inline void PirRunProgramGradAPI(
}
}

if (!backward_global_block->empty()) {
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
Expand All @@ -1175,9 +1170,11 @@ inline void PirRunProgramGradAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
Expand Down Expand Up @@ -1316,8 +1313,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (x[i].is_selected_rows()) {
auto selected_row = std::make_shared<phi::SelectedRows>();
x_grad->emplace_back(selected_row);
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(x_grad_names[i]);
}
Expand Down Expand Up @@ -1446,6 +1442,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

*executed_ = true;
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
this->OutputMeta()[0]);
egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
this->OutputMeta()[1]);
return {x_grad, params_grad};
}

Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

Expand Down Expand Up @@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
void operator()(const std::vector<pir::Block *> &v) const {
// just do nothing.
}
void operator()(const std::shared_ptr<pir::Program> &v) const {
// just do nothing.
}
void operator()(const std::vector<VarDesc *> &v) const {
std::vector<std::string> var_names;
for (auto var : v) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/type_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ template class variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
} // namespace paddle
REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
4 changes: 3 additions & 1 deletion paddle/fluid/framework/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
Expand Down Expand Up @@ -67,7 +68,8 @@ using Attribute = paddle::variant<paddle::blank,
paddle::experimental::Scalar,
std::vector<paddle::experimental::Scalar>,
::pir::Block*,
std::vector<::pir::Value>>;
std::vector<::pir::Value>,
std::shared_ptr<::pir::Program>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

using OpCreator =
Expand Down
19 changes: 15 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,17 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
}

void CastPyArg2AttrIRProgram(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
VLOG(1) << "After Process pir::Program*";
const std::shared_ptr<::pir::Program> program =
::py::handle(obj).cast<std::shared_ptr<::pir::Program>>();
attrs[key] = program;
}

void CastPyArg2AttrValues(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
Expand Down Expand Up @@ -1020,11 +1031,11 @@ void ConstructAttrMapForRunProgram(

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"global_block",
"forward_global_block",
"backward_global_block"})
.count(key)) {
} else if (std::set<std::string>({"global_block"}).count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void BindProgram(py::module *m) {
)DOC");
program
.def(py::init([]() {
return std::make_unique<Program>(pir::IrContext::Instance());
return std::make_shared<Program>(pir::IrContext::Instance());
}))
.def("__str__",
[](const std::shared_ptr<Program> &self) {
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,10 +914,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
'forward_global_block',
self.program.forward_program.global_block(),
'backward_global_block',
self.program.backward_program.global_block(),
'forward_program',
self.program.forward_program,
'backward_program',
self.program.backward_program,
'is_test',
not self.training,
'program_id',
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/prim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(init_env_utils SRCS init_env_utils.cc)
target_compile_definitions(init_env_utils PUBLIC PADDLE_DLL_EXPORT)

paddle_test(test_comp_eager SRCS test_eager_prim.cc DEPS init_env_utils)
paddle_test(test_comp_eager SRCS test_eager_prim.cc init_env_utils.cc)
endif()

# skip win32 since wget is not installed by default on windows machine.
Expand Down
3 changes: 2 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir

import paddle

Expand All @@ -33,6 +33,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])
Expand Down

0 comments on commit 4cd5384

Please sign in to comment.