【AutoParallel】Add master grad in AMP-O2 of AutoParallel (#59987)
* add master_grad in auto-parallel
* reset third_party
* fix coverage
* support bf16 master_grad
* fix bug in master_grad
* change code according to review
* change the way to find optimizer op
1 parent 8ed3d18 · commit f84fbdd · 7 changed files with 325 additions and 13 deletions.
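For background on the change itself: AMP-O2 keeps gradients in fp16/bf16, and small values can be lost (or blow up to inf/nan) when those gradients are clipped, scaled, or fed to the optimizer, which is why a high-precision "master" copy of each gradient is kept. The toy NumPy example below only illustrates the precision issue; it is not part of this commit.

import numpy as np

# Toy illustration (not Paddle code): accumulate 10,000 small contributions.
# The fp16 running sum stops changing once the gap between neighbouring fp16
# values gets too large relative to the contribution, while the fp32 "master"
# copy stays close to the true total of about 1.0.
fp16_sum = np.float16(0.0)
fp32_sum = np.float32(0.0)
for _ in range(10000):
    step = np.float16(1e-4)
    fp16_sum = np.float16(fp16_sum + step)
    fp32_sum = fp32_sum + np.float32(step)

print(fp16_sum)  # stalls around 0.25
print(fp32_sum)  # roughly 1.0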
python/paddle/distributed/passes/auto_parallel_master_grad.py (239 additions, 0 deletions)
@@ -0,0 +1,239 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from collections import OrderedDict
from typing import List, Tuple

from paddle.base import Variable
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_gradient_clip_op,
    is_optimize_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import (
    OP_ROLE_KEY,
    OpRole,
)
from paddle.framework import core
from paddle.static import program_guard

from ..utils.log_utils import get_logger
from .auto_parallel_sharding import _supported_optimizer_type
from .pass_base import PassBase, register_pass

logger = get_logger(logging.INFO, "MasterGradPass")

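# Return the op's output names that appear in var_names, i.e. the gradient
# variables this op produces.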
def get_output_in_varlist(op, var_names) -> List[str]:
    grad_names = []
    for output_name in op.output_arg_names:
        if output_name in var_names:
            grad_names.append(output_name)
    return grad_names

@register_pass("auto_parallel_master_grad_pass")
class MasterGradPass(PassBase):
    """
    Replace the low-precision gradients consumed by the optimizer with high-precision copies to avoid the inf/nan values that low precision can produce.
    The high-precision copy, the 'master grad', is used by the communication operators, `update_loss_scaling`, `GradClip` and the `optimizer`.
    """

    def __init__(self):
        super().__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

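    # The pass runs in two steps: (1) insert cast ops so that every fp16/bf16
    # gradient gets an fp32 master copy, then (2) rebuild the optimizer ops so
    # that they consume the fp32 gradients.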
    def _apply_single_impl(self, main_program, startup_program, context):
        self._completer = self.get_attr("completer")
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        logger.debug(f"Origin main_program: {main_program}")
        self._add_master_grad(main_program, params_grads, dist_context)
        self._regenerate_optimizer(
            main_program, startup_program, params_grads, dist_context
        )
        logger.debug(f"After main program: {main_program}")

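    # For each low-precision gradient: find the op that first produces it,
    # rename that op's output to a temporary fp16/bf16 var, retype the original
    # grad var to fp32, and insert a cast from the temporary var to the fp32
    # grad right after the producer.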
    def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
        grad_first_ids = OrderedDict()
        for idx, op in enumerate(cur_block.ops):
            if is_optimize_op(op):
                break
            elif is_backward_op(op):
                var_names = get_output_in_varlist(op, grad_names)
                for var_name in var_names:
                    if var_name not in grad_first_ids:
                        grad_first_ids[var_name] = idx
                    # Communication operators such as 'allreduce_sum' use input var as output.
                    else:
                        pass

        # Insert the cast ops from back to front so that earlier insertion indices stay valid.
        for grad_name, idx in reversed(grad_first_ids.items()):
            grad_var = cur_block.var(grad_name)
            if (
                grad_var.dtype == core.VarDesc.VarType.FP16
                or grad_var.dtype == core.VarDesc.VarType.BF16
            ):
                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
                producer_op = cur_block.ops[idx]
                producer_op_dist_attr = (
                    dist_context.get_op_dist_attr_for_program(producer_op)
                )
                assert (
                    producer_op_dist_attr is not None
                ), f"The op: '{producer_op}' should be distributed"
                ref_output_dist_attr = (
                    producer_op_dist_attr.get_output_dist_attr(grad_name)
                )
                assert (
                    ref_output_dist_attr is not None
                ), f"The output: '{grad_name}' should be distributed"
                ref_mesh = ref_output_dist_attr.process_mesh
                ref_dims_mapping = ref_output_dist_attr.dims_mapping
                ref_chunk_id = producer_op_dist_attr.chunk_id
                grad_half_precision_name = (
                    grad_name + '@tmp_fp16'
                    if is_fp16
                    else grad_name + '@tmp_bf16'
                )
                grad_half_precision = cur_block.create_var(
                    name=grad_half_precision_name,
                    dtype=grad_var.dtype,
                    shape=grad_var.shape,
                    persistable=False,
                    stop_gradient=False,
                )
                set_var_dist_attr(
                    dist_context,
                    grad_half_precision,
                    ref_dims_mapping,
                    ref_mesh,
                    chunk_id=ref_chunk_id,
                )
                producer_op._rename_output(grad_name, grad_half_precision.name)
                grad_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                cast_op = cur_block._insert_op_without_sync(
                    idx + 1,
                    type="cast",
                    inputs={"X": grad_half_precision},
                    outputs={"Out": grad_var},
                    attrs={
                        "in_dtype": grad_half_precision.dtype,
                        "out_dtype": grad_var.dtype,
                    },
                )
                cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op,
                    ref_mesh,
                    ref_dims_mapping,
                    dist_context,
                    chunk_id=ref_chunk_id,
                )
        cur_block._sync_with_cpp()

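    # The optimizer ops were built against the old low-precision gradients, so
    # remove every op from the first gradient-clip op to the end of the block
    # (earlier ops such as check_nan_inf are kept), delete their now-unused
    # variables, and re-apply the serial optimizer so that the regenerated ops
    # consume the fp32 master grads.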
    def _regenerate_optimizer(
        self,
        main_program,
        startup_program,
        params_grads: List[Tuple[Variable, Variable]],
        dist_context,
    ):
        grad_names = [g.name for _, g in params_grads]
        # 1. delete the origin optimizer op
        # 1.1 delete the var and op associated with the optimizer op in main_program
        main_ops = main_program.global_block().ops
        main_ops_len = len(main_ops)
        first_optimize_idx = main_ops_len
        for idx, op in enumerate(main_ops):
            # We don't delete the operators for check_nan_inf
            if is_optimize_op(op) and is_gradient_clip_op(op):
                first_optimize_idx = idx
                break
        assert (
            first_optimize_idx < main_ops_len
        ), "The first optimizer op is not found!"
        deleted_temp_var_names = []
        deleted_persist_var_names = []
        reserved_var_names = []
        for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1):
            op = main_ops[idx]
            inout_arg_names = op.input_arg_names + op.output_arg_names
            if op.type in _supported_optimizer_type:
                param_names = op.input("Param")
                skip_update_names = op.input("SkipUpdate")
                for reserved_name in param_names + skip_update_names:
                    if reserved_name not in reserved_var_names:
                        reserved_var_names.append(reserved_name)
            for input_name in inout_arg_names:
                if input_name in grad_names:
                    continue
                var = main_program.global_block().var(input_name)
                if (
                    var.persistable
                    and input_name not in deleted_persist_var_names
                ):
                    deleted_persist_var_names.append(input_name)
                elif (
                    not var.persistable
                    and input_name not in deleted_temp_var_names
                ):
                    deleted_temp_var_names.append(input_name)
            main_program.global_block()._remove_op(idx)

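        # Drop the variables collected above unless they are parameters or
        # 'SkipUpdate' inputs, which the regenerated optimizer still needs.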
        for var_name in deleted_temp_var_names + deleted_persist_var_names:
            if var_name not in reserved_var_names:
                main_program.global_block()._remove_var(var_name)
        main_program.global_block()._sync_with_cpp()

        # 1.2 delete the var and op in startup_program
        for reserved_name in reserved_var_names:
            if reserved_name in deleted_persist_var_names:
                deleted_persist_var_names.remove(reserved_name)
        startup_global_block = startup_program.global_block()
        for var_name in deleted_persist_var_names:
            if startup_global_block.has_var(var_name):
                startup_global_block._remove_var(var_name)
        for idx, op in reversed(list(enumerate(startup_global_block.ops))):
            inout_arg_names = op.input_arg_names + op.output_arg_names
            for var_name in inout_arg_names:
                if var_name in deleted_persist_var_names:
                    startup_program.global_block()._remove_op(idx)
                    break

        # 2. re-generate new optimizer op
        serial_optimizer = copy.deepcopy(dist_context._serial_optimizer)
        serial_optimizer._learning_rate = (
            dist_context._serial_optimizer._learning_rate
        )
        serial_optimizer._sorted = False
        with program_guard(main_program, startup_program):
            with main_program.switch_name_generator_guard("opt_"):
                _ = serial_optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)

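    # Step (1): give every fp16/bf16 gradient in every block an fp32 master copy.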
    def _add_master_grad(self, main_program, params_grads, dist_context):
        grad_names = [g.name for _, g in params_grads]
        for sub_block in main_program.blocks:
            self._add_cast_op(sub_block, grad_names, dist_context)
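A pass registered as "auto_parallel_master_grad_pass" is normally constructed and applied through Paddle's pass framework. The sketch below is only an illustration of that wiring, assuming the caller already has dist_context, completer, params_grads and the main/startup programs from an auto-parallel build; the attribute names mirror the get_attr calls above, while the exact import path of PassContext and the apply call site are assumptions rather than the engine's actual code.

from paddle.distributed.passes import PassContext, new_pass


def apply_master_grad_pass(
    main_program, startup_program, dist_context, completer, params_grads
):
    # Minimal sketch: the arguments are assumed to be produced by the
    # auto-parallel engine before the optimizer-related passes run.
    attrs = {
        "dist_context": dist_context,
        "completer": completer,
        "params_grads": params_grads,
    }
    master_grad_pass = new_pass("auto_parallel_master_grad_pass", attrs)
    # The pass framework applies a pass to lists of programs.
    master_grad_pass.apply([main_program], [startup_program], PassContext())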