From 325fdf1d87b512fb679aad12cfcd7cae478711b3 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Thu, 23 Mar 2023 11:18:46 +0800 Subject: [PATCH] [Auto Parallel] Update rule based tuner (#51908) * add patterns * update rule based tuner * add forward sub program completion * add unittest * add bwd sub program completion --- .../distributed/auto_parallel/dist_context.py | 11 +- .../auto_parallel/operators/dist_default.py | 1 - .../auto_parallel/operators/dist_embedding.py | 6 + .../auto_parallel/operators/dist_matmul.py | 2 +- .../auto_parallel/operators/dist_scale.py | 87 +++ .../auto_parallel/tuner/rule_based_tuner.py | 553 +++++++++++++++++- .../unittests/auto_parallel/CMakeLists.txt | 1 + .../auto_parallel/test_dist_op_cost.py | 2 + .../auto_parallel/test_group_operators.py | 5 +- .../unittests/auto_parallel/test_pattern.py | 7 - .../auto_parallel/test_rule_based_tuner.py | 143 +++++ python/paddle/utils/flops.py | 17 +- 12 files changed, 806 insertions(+), 29 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 0db133236016d..22a83ae341d62 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -64,6 +64,7 @@ def __init__( fetch_vars={}, cluster=None, strategy=None, + json_config=None, ): # Data members related to original programs (unchanged) self._original_serial_main_program = serial_main_prog @@ -129,6 +130,8 @@ def __init__( # A flag indicates whether the used parallelism is data parallel self._data_parallel = False + self._json_config = json_config + @property def serial_main_program(self): return self._serial_main_program @@ -181,6 +184,10 @@ def serial_ordered_nodes(self): def process_meshes(self): return self._process_meshes + @process_meshes.setter + def process_meshes(self, val): + self._process_meshes = val + @property def pass_context(self): return self._pass_context @@ -397,7 +404,7 @@ def _restore( if dist: self._restore_dist_info(dist_mode) - def initialize(self, with_graph=True, with_cpp=False): + def initialize(self, with_graph=True, with_cpp=False, no_default=False): if not self._is_initialized: if not self._serial_main_program: if self._original_serial_main_program: @@ -418,7 +425,7 @@ def initialize(self, with_graph=True, with_cpp=False): if not self._serial_fetch_vars: self._restore_serial_fetch_vars() - self._init_dist_attr_for_program() + self._init_dist_attr_for_program(no_default) # Backup the original distributed information for later restore self._original_dist_tensors_for_program = copy.deepcopy( self._dist_tensors_for_program diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 54a6c959939c4..11537dde06428 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -174,7 +174,6 @@ def calc_bwd_cost(self, dist_op, ctx, cluster): varname ) mesh_shape = process_mesh.shape - batch_size_axis = var_dim_mapping[0] parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [varname + "@GRAD"] diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 
08b00a5c7f63b..51e7f154f9deb 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -278,6 +278,12 @@ def is_input_compatible(self, dist_op): for mapping in ids_dims_mapping[1:]: if is_dim_shard(mapping): return False + + if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard( + w_dims_mapping[-2] + ): + if ids_dims_mapping[0] == w_dims_mapping[-2]: + return False return True def is_output_compatible(self, dist_op): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 8266036c4ec8b..ee3c680aa5681 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -1507,7 +1507,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster): processes = process_mesh.process_ids # col parallel: matmul + allreduce if backward_op.attr("trans_y"): - Y_var_dim_mapping.reverse() + Y_var_dim_mapping = list(reversed(Y_var_dim_mapping)) assert Y_var_dim_mapping[0] < 0 parallel_axis = Y_var_dim_mapping[1] diff --git a/python/paddle/distributed/auto_parallel/operators/dist_scale.py b/python/paddle/distributed/auto_parallel/operators/dist_scale.py index a1a79f6c3b64e..e95e001b89000 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_scale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_scale.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle.distributed.fleet.meta_optimizers.common import OpRole + +from ..cost import ( + _g_op_cost_factory, + build_comp_costs_from_descs, + build_comp_desc_from_dist_op, + build_dp_costs, +) from ..utils import compute_compatible_and_update_dim_mapping from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + is_parameter_related, register_distributed_operator_impl, register_distributed_operator_impl_container, ) @@ -42,6 +51,84 @@ def __init__(self, name): def is_input_compatible(self, dist_op): return True + def calc_cost(self, op_role, dist_op, ctx, cluster): + """Calculate the cost by the op role.""" + cost = None + if int(op_role) == int(OpRole.Backward): + cost = self.calc_bwd_cost(dist_op, ctx, cluster) + else: + cost = self.calc_fwd_cost(dist_op, ctx, cluster) + assert cost is not None + return cost + + def calc_fwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + processes = dist_op.dist_attr.process_mesh.process_ids + op_type = dist_op.serial_op.type + cost_mapping = build_comp_costs_from_descs( + _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster + ) + res_cost = [cost_mapping] + + return res_cost + + def calc_bwd_cost(self, dist_op, ctx, cluster): + # calc comp op cost + res = [] + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + processes = process_mesh.process_ids + backward_op = dist_op.serial_op + op_type = backward_op.type + cost_mapping = build_comp_costs_from_descs( + _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster + ) + res.append(cost_mapping) + + main_block = backward_op.block + need_gradient_allreduce = False + for input_name in backward_op.desc.input_names(): + for varname in 
backward_op.desc.input(input_name): + if "@GRAD" not in varname and not is_parameter_related( + varname, main_block + ): + var_dim_mapping = dist_attr.get_input_dims_mapping(varname) + mesh_shape = process_mesh.shape + batch_size_axis = var_dim_mapping[0] + if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: + need_gradient_allreduce = True + break + + if need_gradient_allreduce: + for input_name in backward_op.desc.input_names(): + for varname in backward_op.desc.input(input_name): + if "@GRAD" not in varname and is_parameter_related( + varname, main_block + ): + var_dim_mapping = dist_attr.get_input_dims_mapping( + varname + ) + mesh_shape = process_mesh.shape + parallel_axis = batch_size_axis + attrs = {"use_calc_stream": True} + var_names = [varname + "@GRAD"] + build_dp_costs( + res, + dist_op, + ctx, + var_names, + attrs, + parallel_axis, + cluster, + ) + return res + def is_output_compatible(self, dist_op): return True diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py index 86038d97d22c7..6c74aac842dbf 100644 --- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py @@ -12,9 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import logging import math +import os from abc import abstractmethod - +from collections import OrderedDict + +import paddle +from paddle.distributed.auto_parallel.completion import Completer +from paddle.distributed.auto_parallel.dist_attribute import ( + OperatorDistAttr, + TensorDistAttr, +) +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor +from paddle.fluid import program_guard +from paddle.fluid.backward import append_backward +from paddle.fluid.framework import Parameter, unique_name + +from ...utils.log_utils import get_logger from ..graph import Graph _PATTERNS = {} @@ -548,6 +565,7 @@ def _compare_var_node(src, tgt): def _match_core(src_node, tgt_node): nonlocal not_matched + # one input or output name corresponding to multiple vars is not supported if not_matched: return @@ -998,20 +1016,168 @@ def convert_to_process_meshes(device_mesh: list) -> list: class RuleBasedTuner: - def __init__(self, dist_context, mode="train"): + """ + A tuner that uses rules drawn from expert experience to search for a good parallel strategy. + Args: + dist_context (DistributedContext): The distributed context. + mode (str): The mode of the current task; it can be train or eval. Default: train. + level (str): The level of this tuner; it can be o1 or o2. + The o2 level may find a better strategy but needs more time than o1. + If level is o1, all layers share the same parallelism and are placed evenly across pipeline stages. + If level is o2, each layer may have its own parallelism and layers may not be placed evenly. + Default: o1. + """ + + def __init__(self, dist_context, mode="train", level="o1"): self._dist_context = dist_context + self._cluster = self._dist_context.cluster self._mode = mode + assert level in ["o1", "o2"] + self._level = level + self._logger = get_logger(logging.INFO) + self._use_dp = False - def cluster_operators(self, ops): - """ - Cluster operators to layers. + # forward sub program + self.fwd_sub_programs = OrderedDict() - Args: - ops (list): A operator list.
+ # dist_context of sub program + self.sub_programs_dist_context = OrderedDict() + + # graph of forward sub program + self.fwd_sub_program_graphs = OrderedDict() + + # full main program + self.full_main_program = None + + # full startup program + self.full_startup_program = None + + # full main program dist context + self.full_main_program_dist_context = None + + # tensor dist attribute from pattern setting + self.tensor_dist_attrs = {} + + # op original id to op mapping + self.op_original_id_to_op = {} + + # op original id to op idx in program + self.op_original_id_to_idx = {} + + # op original id to grad op original id mapping + self.op_original_id_to_grad_op_original_id = {} + + # all process meshes that the cluster can express + self.process_meshes = [] + + # all device meshes that the cluster can be partitioned + self.device_meshes_list = [] + + # the best cost of stage in a given device mesh + self.stage_best_cost_of_dm = {} + + # the best cost of stage in a given process mesh + self.stage_best_cost_of_pm = {} + + # the op clustering result + self.layers = [] + + self._is_run = True + if os.getenv("PADDLE_AUTO_PARALLEL_STAGE") != "tuner": + self._is_run = True + else: + self._is_run = False + self._strategy_path = None + if self._dist_context._json_config: + try: + self._strategy_path = self._dist_context._json_config[ + "tuner_save_path" + ] + except: + self._strategy_path = None + + @property + def dist_context(self): + return self._dist_context + + @property + def cluster(self): + return self._cluster + + @property + def mode(self): + return self._mode + + @property + def level(self): + return self._level + + def convert_process_mesh_to_key(self, process_mesh): + """Convert process mesh object to str.""" + processes = ",".join([str(x) for x in process_mesh._process_ids]) + topology = ",".join([str(x) for x in process_mesh._shape]) + key = processes + ";" + topology + return key + + def gen_full_program(self): + """Generate full program that contain backward and update phase program if mode is train.""" + self.full_main_program = self.dist_context.serial_main_program.clone() + if self.mode == "train": + self.full_startup_program = ( + self.dist_context.serial_startup_program.clone() + ) + loss = self.full_main_program.global_block().vars[ + self.dist_context.serial_loss.name + ] + serial_optimizer = self._dist_context.serial_optimizer + optimizer = copy.deepcopy(serial_optimizer) + self.full_main_program_dist_context = DistributedContext( + serial_main_prog=self.full_main_program, + serial_startup_prog=self.full_startup_program, + serial_loss=loss, + ) + # if in train mode, generate backward and update program. 
+ with program_guard( + self.full_main_program, self.full_startup_program + ): + params_grads = append_backward( + loss, + distop_context=self.full_main_program_dist_context.dist_op_context, + ) + + with program_guard( + self.full_main_program, self.full_startup_program + ): + with unique_name.guard("opt_"): + optimizer_ops = optimizer.apply_gradients(params_grads) + + # map op original id to op, op index and grad op original id + for idx, op in enumerate(self.full_main_program.global_block().ops): + self.op_original_id_to_op[op.desc.original_id()] = op + self.op_original_id_to_idx[op.desc.original_id()] = idx + + grad_op_id_to_op_id = ( + self.full_main_program_dist_context.dist_op_context.grad_op_id_to_op_id + ) + + for grad_op_original_id in grad_op_id_to_op_id: + op_id = grad_op_id_to_op_id[grad_op_original_id] + self.op_original_id_to_grad_op_original_id[ + op_id + ] = grad_op_original_id + + def cluster_operators(self): + """Group operators into layers.""" + ops = self._dist_context._serial_main_program.global_block().ops + + # clear op dist attr in case the user has sharded tensors or ops while the full auto parallel mode is used. + for op in ops: + op.dist_attr = OperatorDistAttr(op.desc) + + vars = self._dist_context._serial_main_program.global_block().vars + for var_name in vars: + vars[var_name].dist_attr = TensorDistAttr(vars[var_name].desc) - Returns: - List: The list contains the list of operators which belong to the same layer. - """ seq = [op.type for op in ops] while not OperatorClusteringUtil.stop_replace(seq): @@ -1061,6 +1227,7 @@ def cluster_operators(self, ops): to_replace_seq = OperatorClusteringUtil.replace_by_decomposed_seq( decomposed_sub_seq, to_replace_seq ) + result = seq[: to_replace_idxes[0]] if not has_merged: result.extend(to_replace_seq) @@ -1077,3 +1244,369 @@ def cluster_operators(self, ops): layers.append(layer) return layers + + def match_program(self, program): + """Use patterns to match the program and get the tensor shard spec when a pattern is matched.""" + graph = GraphUtil.convert_to_graph(program.global_block()) + results = GraphUtil.match_all_patterns(graph) + if results: + for pattern_name in results.keys(): + pattern = _PATTERNS[pattern_name] + for parallelism in pattern.attrs["shard_spec"].keys(): + shard_spec = pattern.attrs["shard_spec"][parallelism] + for pattern_node_id in shard_spec.keys(): + for item in results[pattern_name]: + var_id = item[pattern_node_id] + var_desc_id = graph.attrs["id_to_var_desc_id"][ + var_id + ] + if var_desc_id not in self.tensor_dist_attrs: + self.tensor_dist_attrs[var_desc_id] = {} + self.tensor_dist_attrs[var_desc_id][ + parallelism + ] = shard_spec[pattern_node_id] + tensor_name = graph.attrs["id_to_var_name"][var_id] + self._logger.info( + "{}'s shard_spec may be {} under {} parallelism.".format( + tensor_name, + shard_spec[pattern_node_id], + parallelism, + ) + ) + else: + self._logger.info( + "No pattern has been matched in this program. Currently, only transformer-based models are supported. Data parallelism will be used."
+ ) + self._use_dp = True + + def gen_fwd_sub_programs_by_clone(self): + """Generate all forward sub programs by cloned from the original program.""" + for idx, layer in enumerate(self.layers): + sub_fwd_program = self._gen_fwd_sub_program_by_clone(layer) + self.fwd_sub_programs[idx] = sub_fwd_program + + def _gen_fwd_sub_program_by_clone(self, ops): + """Generate the forward sub program of the given ops.""" + program = paddle.static.Program() + block = ops[0].block + vars = block.vars + target_block = program.global_block() + with paddle.static.program_guard(program): + has_cloned_vars = set() + for op in ops: + new_op_desc = target_block.desc.append_op() + new_op_desc.copy_from(op.desc) + for var_name in op.input_arg_names: + if var_name not in has_cloned_vars: + if vars[var_name].is_parameter: + src_var = vars[var_name] + copied_kwargs = {} + copied_kwargs['trainable'] = src_var.trainable + copied_kwargs[ + 'optimize_attr' + ] = src_var.optimize_attr + copied_kwargs['regularizer'] = src_var.regularizer + copied_kwargs[ + 'do_model_average' + ] = src_var.do_model_average + copied_kwargs['need_clip'] = src_var.need_clip + + param = Parameter( + block=target_block, + type=src_var.type, + name=src_var.name, + shape=src_var.shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + **copied_kwargs + ) + else: + target_block._clone_variable(vars[var_name]) + target_block.vars[var_name].persistable = vars[ + var_name + ].persistable + target_block.vars[var_name].desc.set_original_id( + vars[var_name].desc.original_id() + ) + has_cloned_vars.add(var_name) + + for var_name in op.output_arg_names: + if var_name not in has_cloned_vars: + target_block._clone_variable(vars[var_name]) + target_block.vars[var_name].persistable = vars[ + var_name + ].persistable + target_block.vars[var_name].desc.set_original_id( + vars[var_name].desc.original_id() + ) + has_cloned_vars.add(var_name) + + target_block._sync_with_cpp() + + return program + + def _compelte_sub_fwd_program(self, idx, sub_fwd_program, process_mesh): + """Compelete forward sub program.""" + selective_parallelisms = ( + ["dp", "mp"] if len(process_mesh.shape) == 1 else ["dp_mp", "mp_dp"] + ) + for parallelism in selective_parallelisms: + has_set_tensor_count = 0 + dist_context = DistributedContext(sub_fwd_program) + has_set_dist_attr_tensors = set() + dist_context.process_meshes = [] + dist_context.add_process_mesh(process_mesh) + vars = sub_fwd_program.global_block().vars + + # clear op dist attr + ops = sub_fwd_program.global_block().ops + for op in ops: + op.dist_attr = OperatorDistAttr(op.desc) + # clear tensor dist attr + for var_name in vars: + vars[var_name].dist_attr = TensorDistAttr(vars[var_name].desc) + + for var_name in vars: + var_id = vars[var_name].desc.original_id() + if var_id in self.tensor_dist_attrs: + if parallelism in self.tensor_dist_attrs[var_id]: + dims_mapping = self.tensor_dist_attrs[var_id][ + parallelism + ] + dist_tensor = DistributedTensor(vars[var_name]) + dist_tensor.dist_attr.process_mesh = process_mesh + dist_tensor.dist_attr.dims_mapping = dims_mapping + dist_tensor.dist_attr.mark_annotated("dims_mapping") + dist_tensor.dist_attr.mark_annotated("process_mesh") + dist_context.add_dist_tensor_for_program(dist_tensor) + has_set_tensor_count += 1 + has_set_dist_attr_tensors.add(var_id) + + # check whether no dist attr in dist context + if has_set_tensor_count 
> 0: + dist_context.initialize(no_default=True) + completer = Completer(dist_context) + completer.complete_forward_annotation() + if parallelism not in self.sub_programs_dist_context[idx]: + self.sub_programs_dist_context[idx][parallelism] = {} + key = self.convert_process_mesh_to_key(process_mesh) + self.sub_programs_dist_context[idx][parallelism][ + key + ] = dist_context + else: + self._logger.info( + "No pattern has been matched under {} parallelism when the sub program is {}.".format( + parallelism, sub_fwd_program + ) + ) + + def complete_sub_fwd_programs(self, process_mesh): + """Complete all forward sub programs.""" + for idx in self.fwd_sub_programs.keys(): + sub_fwd_program = self.fwd_sub_programs[idx] + if idx not in self.sub_programs_dist_context: + self.sub_programs_dist_context[idx] = {} + self._compelte_sub_fwd_program(idx, sub_fwd_program, process_mesh) + + def _complete_sub_bwd_program(self, sub_program_dist_context): + """ + Complete the backward ops according to the forward ops. + Most of the logic is the same as the backward completion in the completer. + The difference is that here the backward op is found from its forward op, + while the completer finds the forward op from the backward op. + """ + + def _is_grad_var_name(name): + if "@GRAD" in name: + return True + return False + + sub_fwd_program = sub_program_dist_context.serial_main_program + block = sub_fwd_program.global_block() + vars = self.full_main_program.global_block().vars + ops = self.full_main_program.global_block().ops + grad_var_to_var = ( + self.full_main_program_dist_context.dist_op_context.grad_var_to_var[ + 1 + ] + ) + for forward_op in block.ops: + if ( + forward_op.desc.original_id() + not in self.op_original_id_to_grad_op_original_id + ): + continue + grad_op_id = self.op_original_id_to_grad_op_original_id[ + forward_op.desc.original_id() + ] + # the unsqueeze2 op in gpt has no grad op, + # or no backward is needed for this op + if grad_op_id not in self.op_original_id_to_op: + continue + grad_op = self.op_original_id_to_op[grad_op_id] + if grad_op.type == "concat" and forward_op.type == "split": + forward_op_dist_attr = ( + sub_program_dist_context.get_op_dist_attr_for_program( + forward_op + ) + ) + output_var = vars[grad_op.desc.output('Out')[0]] + split_input_var_name = forward_op.input("X")[0] + ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( + split_input_var_name + ) + ref_mesh = forward_op_dist_attr.process_mesh + + grad_op_dist_attr = OperatorDistAttr() + for input_name in grad_op.input_arg_names: + grad_op_dist_attr.set_input_dims_mapping( + input_name, ref_dims_mapping + ) + + output_var_dist_attr = TensorDistAttr() + output_var_dist_attr.dims_mapping = ref_dims_mapping + output_var_dist_attr.process_mesh = ref_mesh + sub_program_dist_context.set_tensor_dist_attr_for_program( + output_var, output_var_dist_attr + ) + + grad_op_dist_attr.set_output_dims_mapping( + output_var.name, ref_dims_mapping + ) + grad_op_dist_attr.process_mesh = ref_mesh + sub_program_dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr + ) + grad_op_dist_attr.impl_type = ( + fwd_op_dist_attr.impl_type # noqa: F821 + ) + grad_op_dist_attr.impl_idx = ( + fwd_op_dist_attr.impl_idx # noqa: F821 + ) + continue + + fwd_op_dist_attr = ( + sub_program_dist_context.get_op_dist_attr_for_program( + forward_op + ) + ) + fwd_op_process_mesh = fwd_op_dist_attr.process_mesh + grad_op_dist_attr = OperatorDistAttr() + grad_op_dist_attr.process_mesh = fwd_op_process_mesh + + for input_name in
grad_op.input_arg_names: + if ( + input_name not in forward_op.input_arg_names + and input_name not in forward_op.output_arg_names + ): + if input_name in grad_var_to_var.keys(): + fwd_name = grad_var_to_var[input_name] + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping(fwd_name) + ) + else: + input_var = vars[input_name] + ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program( + input_var + ).dims_mapping + else: + if input_name in forward_op.input_arg_names: + ref_dims_mapping = ( + fwd_op_dist_attr.get_input_dims_mapping(input_name) + ) + else: + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping(input_name) + ) + assert ( + ref_dims_mapping is not None + ), "[{}] 's dims mapping is NONE".format(input_name) + grad_op_dist_attr.set_input_dims_mapping( + input_name, ref_dims_mapping + ) + + for output_name in grad_op.output_arg_names: + assert output_name in grad_var_to_var + fwd_name = grad_var_to_var[output_name] + ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( + fwd_name + ) + # var + output_var = vars[output_name] + tensor_dist_attr = TensorDistAttr() + tensor_dist_attr.dims_mapping = ref_dims_mapping + tensor_dist_attr.process_mesh = fwd_op_process_mesh + sub_program_dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr + ) + # op + grad_op_dist_attr.set_output_dims_mapping( + output_name, ref_dims_mapping + ) + + grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type + grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx + sub_program_dist_context.set_op_dist_attr_for_program( + grad_op, grad_op_dist_attr + ) + + grad_op_idx = self.op_original_id_to_idx[grad_op_id] + if grad_op_idx + 1 < len(ops): + grad_op_next_op = ops[grad_op_idx + 1] + if grad_op_next_op.type == "sum": + assert all( + map(_is_grad_var_name, grad_op_next_op.input_arg_names) + ) + output_name = grad_op_next_op.output_arg_names[0] + assert ( + output_name in grad_var_to_var + ), "sum op's output '{}' has no corresponding var".format( + output_name + ) + ref_fwd_var_name = grad_var_to_var[output_name] + ref_fwd_var = vars[ref_fwd_var_name] + ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + ref_fwd_var + ) + ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping + ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh + + # output + tensor_dist_attr = TensorDistAttr() + tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping + tensor_dist_attr.process_mesh = ref_fwd_process_mesh + output_var = vars[output_name] + sub_program_dist_context.set_tensor_dist_attr_for_program( + output_var, tensor_dist_attr + ) + + # op + grad_op_dist_attr = OperatorDistAttr() + grad_op_dist_attr.process_mesh = ref_fwd_process_mesh + + for var_name in grad_op_next_op.input_arg_names: + grad_op_dist_attr.set_input_dims_mapping( + var_name, ref_fwd_dims_mapping + ) + grad_op_dist_attr.set_output_dims_mapping( + output_name, ref_fwd_dims_mapping + ) + grad_op_dist_attr.impl_type = "default" + grad_op_dist_attr.impl_idx = 0 + + sub_program_dist_context.set_op_dist_attr_for_program( + grad_op_next_op, grad_op_dist_attr + ) + + def complete_sub_bwd_programs(self): + for idx in self.sub_programs_dist_context: + for parallelism in self.sub_programs_dist_context[idx]: + for key in self.sub_programs_dist_context[idx][parallelism]: + sub_program_dist_context = self.sub_programs_dist_context[ + idx + ][parallelism][key] + self._complete_sub_bwd_program(sub_program_dist_context) diff --git 
a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 02fb7175b289b..d64aaf35b4d46 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -127,6 +127,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_pass_bf16 MODULES test_pass_bf16) py_test_modules(test_dist_saver MODULES test_dist_saver) py_test_modules(test_engine_save_load MODULES test_engine_save_load) + py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py index 4eb0408976aba..ecff2bbf8935b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py @@ -178,6 +178,7 @@ def make_program(): [None, None], ) tmp_out = paddle.matmul(out1, tmp_param) + tmp_out = paddle.scale(tmp_out, 0.5) out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1] @@ -286,6 +287,7 @@ def make_program(): ) tmp_out = paddle.matmul(out1, tmp_param) + tmp_out = paddle.scale(tmp_out, 0.5) out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py index 2823d4d9a318c..e1d8eb8d37903 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py @@ -119,9 +119,10 @@ def test_gpt(self): RuleBasedTuner, ) - dist_context = DistributedContext() + dist_context = DistributedContext(train_program) + dist_context.initialize() tuner = RuleBasedTuner(dist_context) - layers = tuner.cluster_operators(train_program.global_block().ops) + layers = tuner.cluster_operators() op_types = [] for layer in layers: tmp = [] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py index 047b9c7507fbf..b239de918a251 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py @@ -112,18 +112,11 @@ def test_gpt(self): sequence_len, vocab_size, ) - from paddle.distributed.auto_parallel.dist_context import ( - DistributedContext, - ) from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( _PATTERNS, GraphUtil, - RuleBasedTuner, ) - dist_context = DistributedContext() - tuner = RuleBasedTuner(dist_context) - layers = tuner.cluster_operators(train_program.global_block().ops) graph = GraphUtil.convert_to_graph(train_program.global_block()) print("graph: ", graph) print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"]) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py new file mode 100644 index 0000000000000..d1285b7895e17 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +import paddle +import paddle.static as static + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + + +def get_gpt_model( + train_program, start_program, place, batch_size, sequence_len, vocab_size +): + with static.program_guard(train_program, start_program): + tokens = paddle.static.data( + name="tokens", shape=[batch_size, sequence_len], dtype='int64' + ) + position_ids = paddle.static.data( + name="position_ids", shape=[batch_size, sequence_len], dtype='int64' + ) + attention_mask = paddle.static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32', + ) + labels = paddle.static.data( + name="labels", shape=[batch_size, sequence_len], dtype='int64' + ) + loss_mask = paddle.static.data( + name="loss_mask", shape=[batch_size, sequence_len], dtype='float32' + ) + + gpt = GPTModel( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=256, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + ) + + model = GPTForPretraining( + gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02 + ) + preds = model(tokens, position_ids, attention_mask) + criterion = GPTPretrainingCriterion() + loss = criterion(preds, labels, loss_mask) + + def gen_data(): + np.random.seed(2021) + tokens = [] + position_ids = [] + attention_mask = [] + labels = [] + loss_mask = [] + for _ in range(batch_size): + tokens.append(np.random.randint(vocab_size, size=sequence_len)) + position_ids.append(np.arange(sequence_len)) + attention_mask.append([np.tril(np.ones(sequence_len))]) + labels.append(np.random.randint(vocab_size, size=sequence_len)) + loss_mask.append(np.ones(sequence_len)) + + return tokens, position_ids, attention_mask, labels, loss_mask + + return train_program, start_program, loss, gen_data + + +class TestRuleBasedTuner(unittest.TestCase): + def test_gpt(self): + modeling.init_global() + train_program = static.Program() + start_program = static.Program() + place = paddle.set_device("gpu") + batch_size = 8 + sequence_len = 512 + vocab_size = 1000 + train_program, start_program, loss, gen_data = get_gpt_model( + train_program, + start_program, + place, + batch_size, + sequence_len, + vocab_size, + ) + from paddle.distributed.auto_parallel.dist_context import ( + DistributedContext, + ) + from paddle.distributed.auto_parallel.process_mesh import ProcessMesh + from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( + RuleBasedTuner, + ) + + clip = paddle.nn.ClipGradByGlobalNorm(0.2) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + dist_context = DistributedContext( + 
serial_main_prog=train_program, + serial_startup_prog=start_program, + serial_optimizer=opt, + serial_loss=loss, + ) + dist_context.initialize() + tuner = RuleBasedTuner(dist_context) + tuner.cluster_operators() + tuner.gen_full_program() + tuner.match_program(tuner._dist_context.serial_main_program) + process_mesh = ProcessMesh([0, 1]) + tuner.gen_fwd_sub_programs_by_clone() + tuner.complete_sub_fwd_programs(process_mesh) + tuner.complete_sub_bwd_programs() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py index a659a57206be9..2ff1d582cfc7b 100644 --- a/python/paddle/utils/flops.py +++ b/python/paddle/utils/flops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy _FLOPS_COMPUTE_FUNC_MAP = {} @@ -244,8 +245,12 @@ def _matmul_flops(input_shapes, attrs): equation: flops = 2 * numel(output) * dim_n """ - x_shape = input_shapes.get("X", input_shapes.get("x", [[0]]))[0] - y_shape = input_shapes.get("Y", input_shapes.get("y", [[0]]))[0] + x_shape = copy.deepcopy( + input_shapes.get("X", input_shapes.get("x", [[0]]))[0] + ) + y_shape = copy.deepcopy( + input_shapes.get("Y", input_shapes.get("y", [[0]]))[0] + ) if attrs.get('transpose_X') or attrs.get('transpose_x'): x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1] @@ -276,11 +281,11 @@ def _matmul_v2_flops(input_shapes, attrs): shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m] equation: flops = 2 * numel(outputs) * dim_n """ - x_shape = input_shapes.get('X')[0] - y_shape = input_shapes.get('Y')[0] - if attrs.get('trans_x') is not None: + x_shape = copy.deepcopy(input_shapes.get('X')[0]) + y_shape = copy.deepcopy(input_shapes.get('Y')[0]) + if attrs.get('trans_x'): x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1] - if attrs.get('trans_y') is not None: + if attrs.get('trans_y'): y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1] dim_x = len(x_shape) dim_y = len(y_shape)
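Note on the flops.py hunk above: the copy.deepcopy guards were added because the in-place transpose swap would otherwise mutate the caller's input_shapes lists, and the old "is not None" checks in _matmul_v2_flops treated an explicit trans_x=False as a transpose. A minimal standalone sketch of the intended behaviour (a toy re-implementation for illustration, not the Paddle API itself):

import copy

def matmul_flops(input_shapes, attrs):
    # Deep-copy so the in-place swap below cannot mutate the caller's shape lists,
    # mirroring the guard added to _matmul_flops/_matmul_v2_flops in the hunk above.
    x_shape = copy.deepcopy(input_shapes.get("X", [[0]])[0])
    y_shape = copy.deepcopy(input_shapes.get("Y", [[0]])[0])
    # Truthiness check: an explicit trans_x=False must not trigger a transpose,
    # which is why the "is not None" comparison was dropped.
    if attrs.get("trans_x"):
        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
    if attrs.get("trans_y"):
        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
    # flops = 2 * numel(output) * shared dimension
    numel_out = 1
    for dim in x_shape[:-1]:
        numel_out *= dim
    numel_out *= y_shape[-1]
    return 2 * numel_out * x_shape[-1]

shapes = {"X": [[8, 4]], "Y": [[4, 16]]}
print(matmul_flops(shapes, {"trans_x": False, "trans_y": False}))  # 1024
print(shapes["X"][0])  # [8, 4]: the caller's shapes are left untouched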
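For completeness, a condensed sketch of passing the new json_config argument of DistributedContext through to the tuner. It reuses train_program, start_program, opt and loss from get_gpt_model() in the new unit test above; the "tuner_save_path" value is a hypothetical location, and only that key is read by RuleBasedTuner.__init__ in this patch.

from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import RuleBasedTuner

# Hypothetical config dict; "tuner_save_path" is the only key the tuner looks up.
json_config = {"tuner_save_path": "./rule_based_tuner_strategy.json"}

dist_context = DistributedContext(
    serial_main_prog=train_program,
    serial_startup_prog=start_program,
    serial_optimizer=opt,
    serial_loss=loss,
    json_config=json_config,
)
dist_context.initialize()

# level="o1" keeps one parallelism for all layers, as described in the class docstring.
tuner = RuleBasedTuner(dist_context, mode="train", level="o1")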