Skip to content

Commit

Permalink
[Eager] Fix edvr starganv2 (#43471)
Browse files Browse the repository at this point in the history
* fix starganv2

* fix starganv2 stop_gradient end error

* fix edvr_starganv2

* fix mul kernel to fix optional ddx

* fix typo
  • Loading branch information
JiabinYang authored Jun 14, 2022
1 parent 8cec127 commit c62a7e2
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 63 deletions.
32 changes: 16 additions & 16 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,8 @@ static std::string GenerateGradNodeCreationContent(
size_t bwd_in_slot_num = out_vars.size();
size_t bwd_out_slot_num = in_vars.size();
const char* GRAD_OP_NODE_TEMPLATE =
" auto grad_node = std::shared_ptr<GradNode%s>(new GradNode%s(%d, "
" auto grad_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(%d, "
"%d));\n";
grad_node_creation_str += " // Create GradOpNode\n";
grad_node_creation_str +=
Expand Down Expand Up @@ -2080,10 +2081,8 @@ static std::string GenerateSingleOpBase(
generated_grad_function_body +=
" paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> " +
hooked_grads +
" = "
"GradNode" +
fwd_op_type + "::ApplyGradientHooks(grads);\n";
hooked_grads + " = " + fwd_op_type +
"GradNodeCompat::ApplyGradientHooks(grads);\n";

// [Generation] Get Ins Map
std::unordered_set<std::string> dispensable_input_name_set;
Expand Down Expand Up @@ -2547,7 +2546,7 @@ static std::string GenerateGradNodeCCContents(
*/

const char* EAGER_LOG_TEMPLATE =
" VLOG(3) << \"Running Eager Backward Node: GradNode%s\";\n";
" VLOG(3) << \"Running Eager Backward Node: %sGradNodeCompat\";\n";
std::string generated_grad_function_body =
paddle::string::Sprintf(EAGER_LOG_TEMPLATE, fwd_op_type);

Expand Down Expand Up @@ -2616,7 +2615,7 @@ static std::string GenerateGradNodeCCContents(
const char* GRAD_FUNCTION_TEMPLATE =
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize> "
"GradNode%s::operator()("
"%sGradNodeCompat::operator()("
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
"egr::kSlotSmallVectorSize>& grads, bool "
"create_graph, bool is_new_grad) {\n"
Expand Down Expand Up @@ -2645,14 +2644,15 @@ static std::string GenerateGradNodeHeaderContents(
VLOG(6) << "Generating Grad Node Header";

const char* GRAD_NODE_TEMPLATE =
"class GradNode%s : public egr::GradNodeBase {\n"
"class %sGradNodeCompat : public egr::GradNodeBase {\n"
" public:\n"
" GradNode%s() : egr::GradNodeBase() { VLOG(7) << \" Construct "
"GradNode%s \"; }\n"
" GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
" %sGradNodeCompat() : egr::GradNodeBase() { VLOG(7) << \" Construct "
"%sGradNodeCompat \"; }\n"
" %sGradNodeCompat(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
"egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { VLOG(7) << \" "
"Construct GradNode%s \"; }\n"
" ~GradNode%s() override { VLOG(6) << \" Destruct GradNode%s \"; }\n"
"Construct %sGradNodeCompat \"; }\n"
" ~%sGradNodeCompat() override { VLOG(6) << \" Destruct "
"%sGradNodeCompat \"; }\n"
"\n"
" virtual "
"paddle::small_vector<std::vector<paddle::experimental::Tensor>, "
Expand All @@ -2667,11 +2667,11 @@ static std::string GenerateGradNodeHeaderContents(
"%s\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \"GradNode%sMid\"; } \n "
" std::string name() override { return \"%sGradNodeCompat\"; } \n "
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new "
"GradNode%s(*this));\n "
" auto copied_node = std::shared_ptr<%sGradNodeCompat>(new "
"%sGradNodeCompat(*this));\n "
" return copied_node;\n "
"}}\n "
"\n"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -147,7 +147,18 @@ def RemoveConstAndReference(string):


def GetGradNodeName(string):
return f"GradNode{string}Final"

def str2Hump(text):
arr = filter(None, text.split('_'))
res = ''
for i in arr:
res = res + i[0].upper() + i[1:]
return res

string = str2Hump(string)
if string.rfind("Grad") == (len(string) - 4):
string = string[:-4]
return f"{string}GradNodeFinal"


def GetDygraphForwardFunctionName(string):
Expand Down Expand Up @@ -335,6 +346,7 @@ def ParseYamlInplaceInfo(string):
### Generator Base ###
########################
class FunctionGeneratorBase:

def __init__(self, forward_api_contents, namespace):
self.forward_api_contents = forward_api_contents
self.namespace = namespace
Expand All @@ -357,7 +369,7 @@ def __init__(self, forward_api_contents, namespace):
# Special Op Attributes
self.optional_inputs = [] #[name, ...]
self.no_need_buffers = [] #[name, ...]
self.intermediate_outputs = [] #[name, ...]
self.intermediate_outputs = [] #[name, ...]
self.forward_inplace_map = {} #{name : name, ...}

def ParseForwardInplaceInfo(self):
Expand Down Expand Up @@ -423,20 +435,23 @@ def DetermineForwardPositionMap(self, forward_inputs_list,
input_type = forward_input[1]
input_pos = forward_input[2]

self.forward_inputs_position_map[
input_name] = [input_type, input_pos]
self.forward_inputs_position_map[input_name] = [
input_type, input_pos
]

for i in range(len(forward_returns_list)):
forward_return = forward_returns_list[i]
return_name = forward_return[0]
return_type = forward_return[1]
return_pos = forward_return[2]

self.forward_outputs_position_map[
return_name] = [return_type, return_pos]
self.forward_outputs_position_map[return_name] = [
return_type, return_pos
]


class GeneratorBase:

def __init__(self, api_yaml_path):
self.namespace = ""
self.api_yaml_path = api_yaml_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def GenerateCoreOpInfoDefinition():
## Generator Class ##
#####################
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):

def __init__(self, forward_api_contents, grad_api_contents, namespace):
self.forward_api_contents = forward_api_contents
# Members from Parent:
Expand Down Expand Up @@ -532,8 +533,8 @@ def ForwardsValidationCheck(self):
max_input_position = max(max_input_position, pos)

for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
assert pos > max_input_position, AssertMessage(
pos, max_input_position)

def BackwardValidationCheck(self):
backward_forward_inputs_map = self.backward_forward_inputs_map
Expand Down Expand Up @@ -678,7 +679,7 @@ def GenerateNodeCreationCodes(self):
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
num_backward_outputs = len(forward_inputs_position_map.keys())
grad_node_name = GetGradNodeName(forward_api_name)
grad_node_name = GetGradNodeName(self.backward_api_name)

# Helper
indent = GetIndent(2)
Expand Down Expand Up @@ -845,6 +846,7 @@ def run(self):


class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):

def __init__(self, forward_api_contents, grad_api_contents, namespace):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
Expand Down Expand Up @@ -947,12 +949,12 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
if is_inplaced and len(forward_outputs_position_map) == 1:
api_out_type = "auto&"
forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
num_outputs = len(
forward_outputs_position_map.keys()) - len(intermediate_outputs)

# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(function_name,
"api_result")
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
function_name, "api_result")

# Get Outputs
get_outputs_str = ""
Expand Down Expand Up @@ -1007,8 +1009,8 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
if pos == corresponding_pos:
has_corresponding_grad_output = True
if has_corresponding_grad_output or (
name in forward_inplace_map and
forward_api_name not in inplace_check_blacklist):
name in forward_inplace_map
and forward_api_name not in inplace_check_blacklist):
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
Expand Down Expand Up @@ -1116,17 +1118,20 @@ def UpdateCoreOpsInformation(self, is_inplaced):
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list

num_args = len(forward_inputs_position_map.keys()) + len(
forward_attrs_list)
num_args = len(
forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_returns = len(forward_outputs_position_map.keys())

final_state_fwd_api_name = "final_state_" + forward_api_name
core_ops_returns_info[
final_state_fwd_api_name] = ["" for i in range(num_returns)]
core_ops_args_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_args_type_info[
final_state_fwd_api_name] = ["" for i in range(num_args)]
core_ops_returns_info[final_state_fwd_api_name] = [
"" for i in range(num_returns)
]
core_ops_args_info[final_state_fwd_api_name] = [
"" for i in range(num_args)
]
core_ops_args_type_info[final_state_fwd_api_name] = [
"" for i in range(num_args)
]

for name, (ttype, pos) in forward_inputs_position_map.items():
core_ops_args_info[final_state_fwd_api_name][pos] = name
Expand Down Expand Up @@ -1159,6 +1164,7 @@ def run(self):


class DygraphNodeGenerator(DygraphFunctionGeneratorBase):

def __init__(self,
forward_api_contents,
grad_api_contents,
Expand All @@ -1167,7 +1173,7 @@ def __init__(self,
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)

# Record name mapping from forward_api_name to grad_api_names
# Record name mapping from forward_var_name to grad_var_names
self.to_next_grad_name_mapping = {} # {name : name}

# Generated Results
Expand Down Expand Up @@ -1281,7 +1287,7 @@ def GenerateNodeDeclaration(self):
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
RemoveConstAndReference(atype), saved_attr_name)

grad_node_name = GetGradNodeName(forward_op_name)
grad_node_name = GetGradNodeName(self.backward_api_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str, grad_node_name,
Expand Down Expand Up @@ -1447,8 +1453,8 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""

# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(backward_api_name,
"returns")
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
backward_api_name, "returns")

# Prepare for Node Creation if Necessary
inputs_autograd_meta_str = ""
Expand Down Expand Up @@ -1533,7 +1539,7 @@ def GenerateNodeDefinition(self, next_grad_node_creation_str,
returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"{indent}return returns;\n"

grad_node_name = GetGradNodeName(forward_api_name)
grad_node_name = GetGradNodeName(self.backward_api_name)

self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
Expand All @@ -1560,6 +1566,7 @@ def run(self):


class DygraphForwardAndNodesGenerator(GeneratorBase):

def __init__(self, api_yaml_path, backward_yaml_path):
# Parent members:
# self.namespace
Expand Down Expand Up @@ -1617,9 +1624,10 @@ def GenerateCode(self):
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)

node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace,
next_grad_api_contents)
node_generator = DygraphNodeGenerator(forward_api_contents,
backward_api_contents,
namespace,
next_grad_api_contents)
node_generator.run()
self.node_declaration_str += node_generator.node_declaration_str + "\n"
self.node_definition_str += node_generator.node_definition_str + "\n"
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/eager/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
const std::vector<paddle::experimental::Tensor>& inputs = {},
bool allow_unused = false,
const std::vector<paddle::experimental::Tensor>& no_grad_vars = {}) {
VLOG(6) << "Start Backward";
VLOG(3) << "Start Backward";

// *Gradient Hook should happen at node-level
// *Inplace version check should perform at node-level
Expand Down Expand Up @@ -634,7 +634,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
}

VLOG(6) << "Update In degree Map for backward";
VLOG(3) << "Update In degree Map for backward";
// 3. Compute in_degree for each node
std::unordered_map<GradNodeBase*, int> node_in_degree_map =
getInDegreeMap(queue);
Expand All @@ -654,7 +654,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// |- node(grads)
// |- Prepare for next node
// 3. Update queue
VLOG(6) << "Run Backward";
VLOG(3) << "Run Backward";
while (!queue.empty()) {
GradNodeBase* node = queue.front();
VLOG(6) << "Running GradNode:" << node->name();
Expand Down Expand Up @@ -739,7 +739,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Since we make edge has as same rank as bwd outputs, we indexing them
// with the same rank(i, j)
auto next_node_shared = edge.GetMutableGradNode();

VLOG(3) << "Found pending node: " << next_node_shared->name();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
Expand Down Expand Up @@ -826,7 +826,7 @@ void Backward(
const std::vector<paddle::experimental::Tensor>& tensors, // outputs
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph) {
VLOG(6) << "Run in Backward";
VLOG(3) << "Run in Backward";
paddle::platform::RecordEvent backward_record_event(
"backward", paddle::platform::TracerEventType::Operator, 1);
RunBackward(tensors, grad_tensors, retain_graph);
Expand All @@ -839,7 +839,7 @@ std::vector<paddle::experimental::Tensor> Grad(
const std::vector<paddle::experimental::Tensor>& grad_tensors,
bool retain_graph, bool create_graph, bool only_inputs, bool allow_unused,
const std::vector<paddle::experimental::Tensor>& no_grad_vars) {
VLOG(6) << "Run in Grad";
VLOG(3) << "Run in Grad";

DuplicateCheck(inputs, true /* is_input */);
DuplicateCheck(tensors, false /* is_input */);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
Expand Down Expand Up @@ -281,7 +281,7 @@ void GradNodeBase::SetGradOutMeta(
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(6) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
Expand Down
Loading

0 comments on commit c62a7e2

Please sign in to comment.