From b309a7246e852b6b7fa4b1c1dcf4a88aa3842669 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Wed, 19 Oct 2022 12:01:25 +0800
Subject: [PATCH 01/36] partition param by order

---
 .../passes/auto_parallel_sharding.py          | 95 ++++++++++++++++++-
 1 file changed, 91 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 171188618efe2..05fbfe1e5888e 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -21,6 +21,7 @@
 from paddle.distributed.auto_parallel.process_group import new_process_group
 from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op
 from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr
+from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size

 OpRole = core.op_proto_and_checker_maker.OpRole
 OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -99,6 +100,14 @@ def _apply_single_impl(self, main_program, startup_program, context):
         main_block, startup_block = main_program.global_block(
         ), startup_program.global_block()

+        # NOTE Multi / Sub-Block Support
+        # we assume that only parameters are present and partitioned in main_block,
+        # there is NO new param in sub_block, and all params in sub_block follow the same
+        # partition as main_block. the above constraint fulfills the 3 most common use-cases of Paddle sub_block:
+        # 1. sub-block for lr scheduler
+        # 2. sub-block uses the same or partial network of main-block, e.g. GPT3 generation model
+        # 3. sub-block used for double backward
+
         self._build_sharding_groups(main_block, params_grads)
         self._shard_optimizer(main_block, startup_block, params_grads, context)
         self._shard_gradient_synchronization(main_block)
@@ -108,7 +117,7 @@ def _apply_single_impl(self, main_program, startup_program, context):

     def _build_sharding_groups(self, main_block, params_grads):
         self._collective_data_parallel_groups(main_block)
-        self._build_sharding_infos(params_grads)
+        self._build_sharding_infos(main_block, params_grads)

     def _collective_data_parallel_groups(self, main_block):
         for op in main_block.ops:
@@ -130,8 +139,12 @@ def _collective_data_parallel_groups(self, main_block):
                 "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups"
                 .format(len(self.dp_groups)))

-    def _build_sharding_infos(self, params_grads):
+    def _build_sharding_infos(self, main_block, params_grads):
+
+        # order params
+        params_grads = _order_param_grads(main_block, params_grads)
+
+        # partition
         for dp_group in self.dp_groups:

             assert dp_group.nranks >= self.sharding_world_size, "sharding world size [{}] should not larger than dp world size [{}]".format(
@@ -705,7 +718,40 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
     return dp_group


-def shard_parameters(params, group_size):
+def partition_by_use_order(params, group_size):
+    """
+    shard contiguous params into the same rank and divide the forward & backward computation into segments,
+    which will favor the fuse passes later on.
+
+    we assume that the params are already sorted by usage order.
+ """ + mapping = {} + total_param_mem = 0.0 + param2mem = [] + for param in params: + mem = get_var_size(param) + total_param_mem += mem + param2mem.append((param, mem)) + mapping = {x: [] for x in range(group_size)} + cur_rank = 0 + mem_accu = 0.0 + for param, mem in param2mem: + if mem_accu > total_param_mem * 1.0 * (cur_rank + 1) / group_size: + cur_rank += 1 + mapping[cur_rank].append(param) + mem_accu += mem + print() + print("######" * 6) + for k, v in mapping: + print("rank:{}, size:{}.".format(k, + sum([get_var_size(var) for var in v]))) + print([var.name for var in v]) + print("######" * 6) + print() + return mapping + + +def partition_by_greedy_even(params, group_size): # TODO(JZ-LIANG) support multiple partition methods # method1: greedy even but unorder # method2: roughly even with oreder @@ -721,9 +767,25 @@ def shard_parameters(params, group_size): param.name, numel) sizes[rank] += numel + print() + print("######" * 6) + for k, v in mapping: + print("rank:{}, size:{}.".format(k, + sum([get_var_size(var) for var in v]))) + print([var.name for var in v]) + print("######" * 6) + print() + return mapping +def partition_parameters(params, group_size, algor="greedy_even"): + if algor == "greedy_even": + return partition_by_greedy_even(params, group_size) + else: + return partition_by_use_order(params, group_size) + + class ShardingInfo(object): def __init__(self, group, rank, params_grads): @@ -738,7 +800,9 @@ def __init__(self, group, rank, params_grads): self.global_rank = rank self.local_rank = group.ranks.index(self.global_rank) # rank in below mapping are local rank in this sharding group - self.rank_to_params = shard_parameters(self.params, self.group_size) + self.rank_to_params = partition_parameters(self.params, + self.group_size, + algor="use_order") # include fp32 and fp16 param self.param_to_rank = dict() self._map_param_to_rank() @@ -800,3 +864,26 @@ def get_param_grad(self, param_name): if param_name not in self.params_grads: raise ValueError('param[{}] not in params_grads'.format(param_name)) return self.params_grads.get(param_name, None) + + +def _order_param_grads(block, param_grads): + print() + print("######" * 6) + print("the parameter order before sort: ") + print([p.name for p, g in param_grads]) + pname_to_pg_pairs = {} + for p, g in param_grads: + pname_to_pg_pairs[p.name] = (p, g) + + use_order = [] + for op in block.ops: + for input_name in op.input_arg_names(): + if (input_name in pname_to_pg_pairs) and (input_name not in order): + use_order.append(input_name) + if len(order) == len(pname_to_pg_pairs): + break + print("the parameter order after sort: ") + print(use_order) + print("######" * 6) + print() + return [pname_to_pg_pairs[p] for p in use_order] From c8c9fb74fde62bf945d276519cc42c4588c4329a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 21 Oct 2022 11:35:36 +0800 Subject: [PATCH 02/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 05fbfe1e5888e..b0c6d6fd60fe7 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -452,7 +452,7 @@ def _shard_parameter(self, main_block, startup_block): if is_optimizer_op(op): continue - for input_name in op.desc.input_arg_names(): + for input_name in op.input_arg_names(): # NOTE hack for embedding op when AMP 
02-3 # paddle amp force embedding (lookup table) to be run on fp32 if _is_param_fp16_cast_op(main_block, op, @@ -837,7 +837,7 @@ def get_broadcast_vars_and_param_usage(self, block): for op in block.ops: if is_optimizer_op(op): continue - for input_name in op.desc.input_arg_names(): + for input_name in op.input_arg_names: if input_name in self.param_names: param_usage[input_name] += 1 @@ -877,7 +877,7 @@ def _order_param_grads(block, param_grads): use_order = [] for op in block.ops: - for input_name in op.input_arg_names(): + for input_name in op.input_arg_names: if (input_name in pname_to_pg_pairs) and (input_name not in order): use_order.append(input_name) if len(order) == len(pname_to_pg_pairs): From a88af49a04e39c0166f07530ca013b20523c5fda Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 21 Oct 2022 11:40:45 +0800 Subject: [PATCH 03/36] bugfix --- .../distributed/passes/auto_parallel_sharding.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index b0c6d6fd60fe7..e89bd16ded073 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -284,7 +284,7 @@ def _shard_gradient_clip(self, main_block): reserved_vars.append(input_name) op.desc.set_input("X", reserved_vars) - sum_op_output = op.desc.output_arg_names()[0] + sum_op_output = op.output_arg_names[0] for i, sharding_info in enumerate(self.sharding_infos): new_op = main_block._insert_op( idx + i + 1, @@ -452,7 +452,7 @@ def _shard_parameter(self, main_block, startup_block): if is_optimizer_op(op): continue - for input_name in op.input_arg_names(): + for input_name in op.input_arg_names: # NOTE hack for embedding op when AMP 02-3 # paddle amp force embedding (lookup table) to be run on fp32 if _is_param_fp16_cast_op(main_block, op, @@ -617,7 +617,7 @@ def _is_param_grad_fp32_cast_op(block, op): if not _is_desired_cast_op(block, op, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32): return False - output_name = op.desc.output_arg_names()[0] + output_name = op.output_arg_names[0] base_name = output_name[:output_name.find("@")] if not block.has_var(base_name): return False @@ -630,7 +630,7 @@ def _is_param_fp16_cast_op(block, op, params): return False if not _is_desired_cast_op(block, op): return False - input_name = op.desc.input_arg_names()[0] + input_name = op.input_arg_names[0] if input_name not in params: return False return True @@ -642,10 +642,10 @@ def _is_desired_cast_op(block, dst_var_type=core.VarDesc.VarType.FP16): if op.type != "cast": return False - assert (len(op.desc.input_arg_names()) == 1) - assert (len(op.desc.output_arg_names()) == 1) - input_var = block.var(op.desc.input_arg_names()[0]) - output_var = block.var(op.desc.output_arg_names()[0]) + assert (len(op.input_arg_names) == 1) + assert (len(op.output_arg_names) == 1) + input_var = block.var(op.input_arg_names[0]) + output_var = block.var(op.output_arg_names[0]) if input_var.dtype != src_var_type or \ output_var.dtype != dst_var_type: From 4121b55f26b6e1943cc895095e6e1d57491a1d16 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 21 Oct 2022 11:45:09 +0800 Subject: [PATCH 04/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py 
b/python/paddle/distributed/passes/auto_parallel_sharding.py index e89bd16ded073..f2eb52d4daadc 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -878,9 +878,10 @@ def _order_param_grads(block, param_grads): use_order = [] for op in block.ops: for input_name in op.input_arg_names: - if (input_name in pname_to_pg_pairs) and (input_name not in order): + if (input_name in pname_to_pg_pairs) and (input_name + not in use_order): use_order.append(input_name) - if len(order) == len(pname_to_pg_pairs): + if len(use_order) == len(pname_to_pg_pairs): break print("the parameter order after sort: ") print(use_order) From 0335bce2a71acc75eb62eedce680a5d34259c861 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 21 Oct 2022 11:47:41 +0800 Subject: [PATCH 05/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f2eb52d4daadc..0590ee22ebdd4 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -742,7 +742,7 @@ def partition_by_use_order(params, group_size): mem_accu += mem print() print("######" * 6) - for k, v in mapping: + for k, v in mapping.items(): print("rank:{}, size:{}.".format(k, sum([get_var_size(var) for var in v]))) print([var.name for var in v]) From 55395962843e93b18663d9d889cf71aa5a9ec3be Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 25 Oct 2022 16:17:51 +0800 Subject: [PATCH 06/36] add logging --- .../passes/auto_parallel_sharding.py | 42 +++++++------------ 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 0590ee22ebdd4..8bcb27d5ef45e 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -13,6 +13,7 @@ # limitations under the License. 
from functools import reduce +import logging from paddle.framework import core from paddle.fluid import unique_name @@ -740,14 +741,7 @@ def partition_by_use_order(params, group_size): cur_rank += 1 mapping[cur_rank].append(param) mem_accu += mem - print() - print("######" * 6) - for k, v in mapping.items(): - print("rank:{}, size:{}.".format(k, - sum([get_var_size(var) for var in v]))) - print([var.name for var in v]) - print("######" * 6) - print() + return mapping @@ -767,23 +761,22 @@ def partition_by_greedy_even(params, group_size): param.name, numel) sizes[rank] += numel - print() - print("######" * 6) - for k, v in mapping: - print("rank:{}, size:{}.".format(k, - sum([get_var_size(var) for var in v]))) - print([var.name for var in v]) - print("######" * 6) - print() - return mapping def partition_parameters(params, group_size, algor="greedy_even"): if algor == "greedy_even": - return partition_by_greedy_even(params, group_size) + rank_to_params = partition_by_greedy_even(params, group_size) else: - return partition_by_use_order(params, group_size) + rank_to_params = partition_by_use_order(params, group_size) + + logging.info("Sharding Parameter Partition:") + for k, v in rank_to_params.items(): + logging.info("Rank:{}, Parameter Size:{} MB.".format( + k, sum([get_var_size(var) for var in v]))) + logging.info("Params in this rank: {}.".format([var.name for var in v])) + + return rank_to_params class ShardingInfo(object): @@ -867,10 +860,6 @@ def get_param_grad(self, param_name): def _order_param_grads(block, param_grads): - print() - print("######" * 6) - print("the parameter order before sort: ") - print([p.name for p, g in param_grads]) pname_to_pg_pairs = {} for p, g in param_grads: pname_to_pg_pairs[p.name] = (p, g) @@ -883,8 +872,7 @@ def _order_param_grads(block, param_grads): use_order.append(input_name) if len(use_order) == len(pname_to_pg_pairs): break - print("the parameter order after sort: ") - print(use_order) - print("######" * 6) - print() + + logging.info( + "Sharding the Order of param being used: {}.".format(use_order)) return [pname_to_pg_pairs[p] for p in use_order] From e93def0364764d5c4ff28ffdaa90672f64a9e4e1 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 11:20:37 +0800 Subject: [PATCH 07/36] reorder opt --- .../passes/auto_parallel_sharding.py | 69 ++++++++++++++++--- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 8bcb27d5ef45e..85dc3194a1fc8 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -110,12 +110,15 @@ def _apply_single_impl(self, main_program, startup_program, context): # 3. 
sub-block used for double backward self._build_sharding_groups(main_block, params_grads) - self._shard_optimizer(main_block, startup_block, params_grads, context) - self._shard_gradient_synchronization(main_block) - self._shard_parameter(main_block, startup_block) + for block in main_program.blocks: + self._shard_optimizer(block, startup_block, params_grads, context) + self._shard_gradient_synchronization(block) + self._shard_parameter(block, startup_block) context.set_attr("params_grads", self.shared_params_grads) + self._optimization_pass(main_program, startup_program) + def _build_sharding_groups(self, main_block, params_grads): self._collective_data_parallel_groups(main_block) self._build_sharding_infos(main_block, params_grads) @@ -143,7 +146,7 @@ def _collective_data_parallel_groups(self, main_block): def _build_sharding_infos(self, main_block, params_grads): # order params - params_grads = _order_param_grads(main_block, params_grads) + params_grads = re_order_program(main_block, params_grads) # partition for dp_group in self.dp_groups: @@ -523,6 +526,16 @@ def _shard_parameter(self, main_block, startup_block): main_block._sync_with_cpp() startup_block._sync_with_cpp() + def _optimization_pass(self, main_program, startup_program): + + with paddle.static.program_guard(main_program, startup_program): + _fuse_overlap_gradient_comm() + # TODO support multiple sub_blocks + if self.stage == 2: + _fuse_overlap_parameter_comm_stage_two() + elif self.stage == 3: + _fuse_overlap_parameter_comm_stage_three() + def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, root_rank, ring_id, op_role, dist_context): @@ -746,9 +759,9 @@ def partition_by_use_order(params, group_size): def partition_by_greedy_even(params, group_size): - # TODO(JZ-LIANG) support multiple partition methods - # method1: greedy even but unorder - # method2: roughly even with oreder + """ + use greedy alogrithm to partition parameter as even as possible. 
+ """ mapping = {} for rank_ in range(group_size): mapping[rank_] = [] @@ -859,7 +872,9 @@ def get_param_grad(self, param_name): return self.params_grads.get(param_name, None) -def _order_param_grads(block, param_grads): +def re_order_program(block, param_grads): + + # record order pname_to_pg_pairs = {} for p, g in param_grads: pname_to_pg_pairs[p.name] = (p, g) @@ -873,6 +888,44 @@ def _order_param_grads(block, param_grads): if len(use_order) == len(pname_to_pg_pairs): break + # reorder optimzier + last_op = block.ops + pname_to_op = {} + num_ops = len(block.ops) + # TODO support case when optimizer is not the last op + if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: + # record and remove optimizer + for idx, op in reversed(list(enumerate(block.ops))): + if op.type not in _supported_optimizer_type: + break + + assert len(op.input("Param")) == 1 + block.desc._remove_op(idx, idx + 1) + pname_to_op[op.input("Param")] = block.ops.pop(idx) + assert len(use_order) == len(pname_to_op) + + # re-append + for pname in use_order: + new_op_desc = block.append_op(type='nop').desc + new_op_desc.copy_from(pname_to_op[pname].desc) + + assert len(block.ops) == num_ops + + # TODO reorder gradient clip order + logging.info( "Sharding the Order of param being used: {}.".format(use_order)) return [pname_to_pg_pairs[p] for p in use_order] + + +def _fuse_overlap_gradient_comm(): + pass + + +def _fuse_overlap_parameter_comm_stage_two(fuse_size): + main_program = default_main_program() + startup_program = default_startup_program() + + +def _fuse_overlap_parameter_comm_stage_three(fuse_size): + pass From faaea3c25b22cf524b64fff96944b51e74ddb7a3 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 14:34:21 +0800 Subject: [PATCH 08/36] config --- .../passes/auto_parallel_sharding.py | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 85dc3194a1fc8..f960f999714b7 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -97,6 +97,11 @@ def _apply_single_impl(self, main_program, startup_program, context): self.get_attr("sharding_degree") or self.get_attr("degree")) self.stage = int(self.get_attr("stage")) self.global_rank = int(self.get_attr("global_rank")) + self.fuse_overlap_optimization = True + if self.fuse_overlap_optimization: + self.partition_algor = "use_order" + else: + self.partition_algor = "greedy_even" params_grads = self.get_attr("params_grads") main_block, startup_block = main_program.global_block( ), startup_program.global_block() @@ -117,7 +122,8 @@ def _apply_single_impl(self, main_program, startup_program, context): context.set_attr("params_grads", self.shared_params_grads) - self._optimization_pass(main_program, startup_program) + if self.fuse_overlap_optimization: + self._optimization_pass(main_program, startup_program) def _build_sharding_groups(self, main_block, params_grads): self._collective_data_parallel_groups(main_block) @@ -178,7 +184,7 @@ def _build_sharding_infos(self, main_block, params_grads): self._dist_context._sharding_group = sharding_group # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group sharding_info = ShardingInfo(sharding_group, self.global_rank, - params_grads) + params_grads, self.partition_algor) 
self.sharding_infos.append(sharding_info) for param in sharding_info.params: self.varname_to_sharding_info[param.name] = sharding_info @@ -532,9 +538,9 @@ def _optimization_pass(self, main_program, startup_program): _fuse_overlap_gradient_comm() # TODO support multiple sub_blocks if self.stage == 2: - _fuse_overlap_parameter_comm_stage_two() + _fuse_overlap_parameter_comm_stage_two(self.sharding_infos) elif self.stage == 3: - _fuse_overlap_parameter_comm_stage_three() + _fuse_overlap_parameter_comm_stage_three(self.sharding_infos) def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, @@ -794,7 +800,7 @@ def partition_parameters(params, group_size, algor="greedy_even"): class ShardingInfo(object): - def __init__(self, group, rank, params_grads): + def __init__(self, group, rank, params_grads, partition_algor): self.group = group self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads]) assert len(self.params_grads) == len(set( @@ -805,10 +811,10 @@ def __init__(self, group, rank, params_grads): self.group_size = group.nranks self.global_rank = rank self.local_rank = group.ranks.index(self.global_rank) + self.partition_algor = partition_algor # rank in below mapping are local rank in this sharding group - self.rank_to_params = partition_parameters(self.params, - self.group_size, - algor="use_order") + self.rank_to_params = partition_parameters(self.params, self.group_size, + self.partition_algor) # include fp32 and fp16 param self.param_to_rank = dict() self._map_param_to_rank() @@ -922,10 +928,27 @@ def _fuse_overlap_gradient_comm(): pass -def _fuse_overlap_parameter_comm_stage_two(fuse_size): +def _fuse_overlap_parameter_comm_stage_two(sharding_infos, fuse_size): + + assert len( + sharding_infos + ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos)) + sharding_info = sharding_infos[0] + main_program = default_main_program() startup_program = default_startup_program() + # for param in sharding_info.params: + # n + + +def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size): + + assert len( + sharding_infos + ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos)) + sharding_info = sharding_infos[0] -def _fuse_overlap_parameter_comm_stage_three(fuse_size): pass From c4c4b54b1862e2d8673f2c3f9e2a22afb3966c7b Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 17:44:31 +0800 Subject: [PATCH 09/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f960f999714b7..d9cefc8ba1ce7 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -895,7 +895,7 @@ def re_order_program(block, param_grads): break # reorder optimzier - last_op = block.ops + last_op = block.ops[-1] pname_to_op = {} num_ops = len(block.ops) # TODO support case when optimizer is not the last op From 482a947ff81240ad25fa758d2c6a8315158a8221 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 17:54:32 +0800 Subject: [PATCH 10/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py 
b/python/paddle/distributed/passes/auto_parallel_sharding.py index d9cefc8ba1ce7..c7b8af4d15cf4 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -907,7 +907,7 @@ def re_order_program(block, param_grads): assert len(op.input("Param")) == 1 block.desc._remove_op(idx, idx + 1) - pname_to_op[op.input("Param")] = block.ops.pop(idx) + pname_to_op[op.input("Param")[0]] = block.ops.pop(idx) assert len(use_order) == len(pname_to_op) # re-append From 184bbc9f70630c83b6a85cd0a0eeae43cbe7f87b Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 18:01:03 +0800 Subject: [PATCH 11/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index c7b8af4d15cf4..63625443948f4 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -906,8 +906,8 @@ def re_order_program(block, param_grads): break assert len(op.input("Param")) == 1 - block.desc._remove_op(idx, idx + 1) pname_to_op[op.input("Param")[0]] = block.ops.pop(idx) + block.desc._remove_op(idx, idx + 1) assert len(use_order) == len(pname_to_op) # re-append From 2483f388682ae97434cdb944b3b3515edaa77f73 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 19:30:44 +0800 Subject: [PATCH 12/36] bugfix --- .../distributed/passes/auto_parallel_sharding.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 63625443948f4..a1ec9dc4de2aa 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -898,27 +898,31 @@ def re_order_program(block, param_grads): last_op = block.ops[-1] pname_to_op = {} num_ops = len(block.ops) + remove_op_indices = [] # TODO support case when optimizer is not the last op if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: - # record and remove optimizer + # record optimizer for idx, op in reversed(list(enumerate(block.ops))): if op.type not in _supported_optimizer_type: break - assert len(op.input("Param")) == 1 - pname_to_op[op.input("Param")[0]] = block.ops.pop(idx) - block.desc._remove_op(idx, idx + 1) + pname_to_op[op.input("Param")[0]] = block.ops[idx] + remove_op_indices.append(idx) assert len(use_order) == len(pname_to_op) - # re-append + # append new opts for pname in use_order: new_op_desc = block.append_op(type='nop').desc new_op_desc.copy_from(pname_to_op[pname].desc) + # remove old opts + for idx in remove_op_indices: + block._remove_op(idx, sync=False) + + block._sync_with_cpp() assert len(block.ops) == num_ops # TODO reorder gradient clip order - logging.info( "Sharding the Order of param being used: {}.".format(use_order)) return [pname_to_pg_pairs[p] for p in use_order] From 5b5a61aee01a3ac3b6f512c946f34a9fe061844f Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 19:32:51 +0800 Subject: [PATCH 13/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 
a1ec9dc4de2aa..7505b40e4cf2c 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -15,6 +15,7 @@ from functools import reduce import logging +import paddle from paddle.framework import core from paddle.fluid import unique_name from .pass_base import PassBase, register_pass From 5646f8ff22d7230a398bf0504bb801c2ce1cf7bc Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 19:35:20 +0800 Subject: [PATCH 14/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 7505b40e4cf2c..688c9eb508bd6 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -539,9 +539,11 @@ def _optimization_pass(self, main_program, startup_program): _fuse_overlap_gradient_comm() # TODO support multiple sub_blocks if self.stage == 2: - _fuse_overlap_parameter_comm_stage_two(self.sharding_infos) + _fuse_overlap_parameter_comm_stage_two(self.sharding_infos, + fuse_size=1024) elif self.stage == 3: - _fuse_overlap_parameter_comm_stage_three(self.sharding_infos) + _fuse_overlap_parameter_comm_stage_three(self.sharding_infos, + fuse_size=1024) def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, From c462abaa526e2c876e3d4ca5d44494fe2bba0c2a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 19:38:14 +0800 Subject: [PATCH 15/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 688c9eb508bd6..dcd2ce99c9f15 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -16,7 +16,9 @@ import logging import paddle + from paddle.framework import core +from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid import unique_name from .pass_base import PassBase, register_pass from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op From 139aa0f23a8683c1f93f5a1426835790b705a558 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 20:30:12 +0800 Subject: [PATCH 16/36] bugfix --- .../distributed/passes/auto_parallel_sharding.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index dcd2ce99c9f15..734d4004b8e25 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -155,7 +155,8 @@ def _collective_data_parallel_groups(self, main_block): def _build_sharding_infos(self, main_block, params_grads): # order params - params_grads = re_order_program(main_block, params_grads) + params_grads = re_order_program(main_block, params_grads, + self._dist_context) # partition for dp_group in self.dp_groups: @@ -883,7 +884,7 @@ def get_param_grad(self, param_name): return self.params_grads.get(param_name, None) -def re_order_program(block, param_grads): +def re_order_program(block, param_grads, dist_context): # record order pname_to_pg_pairs = {} 
@@ -911,14 +912,17 @@ def re_order_program(block, param_grads): if op.type not in _supported_optimizer_type: break assert len(op.input("Param")) == 1 - pname_to_op[op.input("Param")[0]] = block.ops[idx] + pname_to_op[op.input("Param")[0]] = op remove_op_indices.append(idx) assert len(use_order) == len(pname_to_op) # append new opts for pname in use_order: - new_op_desc = block.append_op(type='nop').desc - new_op_desc.copy_from(pname_to_op[pname].desc) + new_op = block.append_op(type='nop') + new_op.desc.copy_from(pname_to_op[pname].desc) + dist_context.set_op_dist_attr_for_program( + new_op, + dist_context.get_op_dist_attr_for_program(pname_to_op[pname])) # remove old opts for idx in remove_op_indices: From 36d97be919386be182af299755ef9245674818ec Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 26 Oct 2022 20:49:43 +0800 Subject: [PATCH 17/36] bugfix --- .../distributed/passes/auto_parallel_sharding.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 734d4004b8e25..f73e4795a602a 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op -from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr +from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, get_logger from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size OpRole = core.op_proto_and_checker_maker.OpRole @@ -39,6 +39,8 @@ "lars_momentum", "merged_momentum", "lamb", "sgd" ] +_logger = get_logger(logging.INFO) + def _is_reshard_op(op): return op.desc.has_attr("op_namescope") and \ @@ -795,11 +797,11 @@ def partition_parameters(params, group_size, algor="greedy_even"): else: rank_to_params = partition_by_use_order(params, group_size) - logging.info("Sharding Parameter Partition:") + _logger.info("Sharding Parameter Partition:") for k, v in rank_to_params.items(): - logging.info("Rank:{}, Parameter Size:{} MB.".format( + _logger.info("Rank:{}, Parameter Size:{} MB.".format( k, sum([get_var_size(var) for var in v]))) - logging.info("Params in this rank: {}.".format([var.name for var in v])) + _logger.info("Params in this rank: {}.".format([var.name for var in v])) return rank_to_params @@ -932,7 +934,7 @@ def re_order_program(block, param_grads, dist_context): assert len(block.ops) == num_ops # TODO reorder gradient clip order - logging.info( + _logger.info( "Sharding the Order of param being used: {}.".format(use_order)) return [pname_to_pg_pairs[p] for p in use_order] From 309b5d52981a6f1888283c56960af91c625c94be Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 11:15:29 +0800 Subject: [PATCH 18/36] stage2 bucket --- .../distributed/auto_parallel/constants.py | 4 +- .../paddle/distributed/auto_parallel/utils.py | 13 + ...uto_parallel_data_parallel_optimization.py | 12 +- .../passes/auto_parallel_sharding.py | 302 ++++++++++++------ 4 files changed, 228 insertions(+), 103 deletions(-) diff --git 
a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 82c5011faf0af..51afad94c535b 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -82,7 +82,9 @@ def set_field_default_config(category, field, default_value): set_field_default_config(SHARDING, "enable", False) set_field_default_config(SHARDING, "stage", 1) set_field_default_config(SHARDING, "degree", 8) -set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0) +set_field_default_config(SHARDING, "overlap_grad_comm", False) +set_field_default_config(SHARDING, "bucket_size_numel", -1) +set_field_default_config(SHARDING, "partition_algor", "greedy_even") set_field_default_config(SHARDING, "enable_tuning", False) set_field_default_config(SHARDING, "tuning_range", []) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 88b5a0842262d..375fc92ec1efa 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -22,6 +22,7 @@ from functools import reduce import paddle.fluid.core as core +from paddle.fluid.framework import Variable from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.auto_parallel.process_group import get_all_process_groups from paddle.fluid.io import is_parameter, is_belong_to_optimizer @@ -1587,3 +1588,15 @@ def find_higher_order_backward_op(program): return True return False + + +def get_var_numel(var): + """ + input: + - var: variable + return: + number of elemnet in var + """ + assert isinstance(var, Variable) + assert -1 not in var.shape + return reduce(lambda x, y: x * y, var.shape) diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 8470aa5109961..203e39884ee1b 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -20,7 +20,7 @@ from paddle.fluid.framework import default_main_program from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op -from paddle.distributed.auto_parallel.utils import find_higher_order_backward_op, is_loss_grad_op, is_optimize_op, ring_id_to_process_group +from paddle.distributed.auto_parallel.utils import find_higher_order_backward_op, is_loss_grad_op, is_optimize_op, ring_id_to_process_group, get_var_numel from .pass_base import PassBase, PassType, register_pass # add new optimizers supporting rescale_grad here @@ -33,10 +33,6 @@ __max_stream_num_allow__ = 16 -def numel(var): - return np.prod(list(var.shape)) - - @register_pass("auto_parallel_data_parallel_optimization") class DataParallelOptimizationPass(PassBase): """ @@ -397,7 +393,7 @@ def op_depend_on_group(op, group): ring_id = op.attr("ring_id") grad_name = op.output_arg_names[0] grad_var = block.var(grad_name) - grad_numel = numel(grad_var) + grad_numel = get_var_numel(grad_var) if cur_group.acceptable(grad_var, ring_id): assert grad_name not in grouped_grad_names @@ -541,7 +537,7 @@ def acceptable(self, grad_var, ring_id): return True if ring_id != self.ring_id: return False - if numel(grad_var) + self.numel > self.max_group_size: + if 
get_var_numel(grad_var) + self.numel > self.max_group_size: return False if grad_var.dtype != self.dtype: return False @@ -552,7 +548,7 @@ def add(self, grad_var, ring_id, i): self.gradients.append(grad_var) self.ring_id = ring_id self.dtype = grad_var.dtype - self.numel += numel(grad_var) + self.numel += get_var_numel(grad_var) # remove auxiliary ops in non-fuse dp allreduce self.remove_allreduce_op_indices.append(i) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f73e4795a602a..14b66035247cf 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, get_logger -from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size +from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_numel OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() @@ -60,6 +60,9 @@ def __init__(self): self.set_attr("stage", None) self.set_attr("sharding_degree", None) # for parallelizer self.set_attr("degree", None) # for parallelizer_v2 + self.set_attr("overlap_grad_comm", None) + self.set_attr("bucket_size_numel", None) + self.set_attr("partition_algor", None) self.set_attr("params_grads", []) self.set_attr("global_rank", -1) self.dp_groups = set() @@ -90,6 +93,12 @@ def _check_self(self): if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr("global_rank") < 0: return False + if self.get_attr("overlap_grad_comm") is None: + return False + if self.get_attr("bucket_size_numel") is None: + return False + if self.get_attr("partition_algor") is None: + return False return True @@ -102,11 +111,9 @@ def _apply_single_impl(self, main_program, startup_program, context): self.get_attr("sharding_degree") or self.get_attr("degree")) self.stage = int(self.get_attr("stage")) self.global_rank = int(self.get_attr("global_rank")) - self.fuse_overlap_optimization = True - if self.fuse_overlap_optimization: - self.partition_algor = "use_order" - else: - self.partition_algor = "greedy_even" + self.overlap_grad_comm = self.get_attr("overlap_grad_comm") + self.bucket_size_numel = int(self.get_attr("bucket_size_numel")) + self.partition_algor = self.get_attr("partition_algor") params_grads = self.get_attr("params_grads") main_block, startup_block = main_program.global_block( ), startup_program.global_block() @@ -126,9 +133,7 @@ def _apply_single_impl(self, main_program, startup_program, context): self._shard_parameter(block, startup_block) context.set_attr("params_grads", self.shared_params_grads) - - if self.fuse_overlap_optimization: - self._optimization_pass(main_program, startup_program) + self._optimization_pass(main_program, startup_program) def _build_sharding_groups(self, main_block, params_grads): self._collective_data_parallel_groups(main_block) @@ -373,7 +378,7 @@ def _shard_optimizer_ops_and_states(self, main_block, startup_block): def _insert_optimizer_broadcasts(self, main_block, startup_block): - if self.stage > 2: + if self.stage > 2 or self.bucket_size_numel > 1: return for sharding_info in 
self.sharding_infos: @@ -541,14 +546,16 @@ def _shard_parameter(self, main_block, startup_block): def _optimization_pass(self, main_program, startup_program): with paddle.static.program_guard(main_program, startup_program): - _fuse_overlap_gradient_comm() + if self.overlap_grad_comm: + _fuse_overlap_gradient_comm() # TODO support multiple sub_blocks - if self.stage == 2: - _fuse_overlap_parameter_comm_stage_two(self.sharding_infos, - fuse_size=1024) - elif self.stage == 3: - _fuse_overlap_parameter_comm_stage_three(self.sharding_infos, - fuse_size=1024) + if self.bucket_size_numel > 1: + if self.stage == 2: + _fuse_overlap_parameter_comm_stage_two(self.sharding_infos, + fuse_size=1024) + elif self.stage == 3: + _fuse_overlap_parameter_comm_stage_three( + self.sharding_infos, fuse_size=1024) def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, @@ -806,6 +813,164 @@ def partition_parameters(params, group_size, algor="greedy_even"): return rank_to_params +def re_order_program(block, param_grads, dist_context): + + # record order + pname_to_pg_pairs = {} + for p, g in param_grads: + pname_to_pg_pairs[p.name] = (p, g) + + use_order = [] + for op in block.ops: + for input_name in op.input_arg_names: + if (input_name in pname_to_pg_pairs) and (input_name + not in use_order): + use_order.append(input_name) + if len(use_order) == len(pname_to_pg_pairs): + break + + # reorder optimzier + last_op = block.ops[-1] + pname_to_op = {} + num_ops = len(block.ops) + remove_op_indices = [] + # TODO support case when optimizer is not the last op + if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: + # record optimizer + for idx, op in reversed(list(enumerate(block.ops))): + if op.type not in _supported_optimizer_type: + break + assert len(op.input("Param")) == 1 + pname_to_op[op.input("Param")[0]] = op + remove_op_indices.append(idx) + assert len(use_order) == len(pname_to_op) + + # append new opts + for pname in use_order: + new_op = block.append_op(type='nop') + new_op.desc.copy_from(pname_to_op[pname].desc) + dist_context.set_op_dist_attr_for_program( + new_op, + dist_context.get_op_dist_attr_for_program(pname_to_op[pname])) + + # remove old opts + for idx in remove_op_indices: + block._remove_op(idx, sync=False) + + block._sync_with_cpp() + assert len(block.ops) == num_ops + + # TODO reorder gradient clip order + _logger.info( + "Sharding the Order of param being used: {}.".format(use_order)) + return [pname_to_pg_pairs[p] for p in use_order] + + +def group_param(sharding_info, fuse_size): + """ + param are group by: + rank id + fuse_size + dtype + """ + group_to_param_map = {} + param_to_group_map = {} + bucket = [] + cur_group = ParameterGroup(fuse_size) + for param in sharding_info.params: + rank = sharding_info.get_var_rank(param.name) + + if cur_group.acceptable(param, rank): + cur_group.collect(param, rank) + else: + cur_group = ParameterGroup(fuse_size) + cur_group.collect(param, rank) + + if cur_group in group_to_param_map: + group_to_param_map[cur_group].append(param_name) + else: + group_to_param_map[cur_group] = [param_name] + + param_to_group_map[param_name] = cur_group + + return group_to_param_map, param_to_group_map + + +def _fuse_overlap_gradient_comm(): + pass + + +def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, + fuse_size): + + assert len( + sharding_infos + ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos)) + sharding_info = sharding_infos[0] + 
+ main_block = default_main_program().global_block() + startup_block = default_startup_program().global_block() + + group_to_param_map, param_to_group_map = group_param( + sharding_info, fuse_size) + + for group in group_to_param_map.keys(): + + assert len(group) >= 1 + if len(group) > 1: + coalesce_var_name = unique_name.generate( + 'coalecse_param_{}'.format(i)) + startup_block.create_var(name=coalesce_var_name, + dtype=group.dtype, + persistable=True, + stop_gradient=True) + group.coalesce_var = main_block.create_var(name=coalesce_var_name, + dtype=group.dtype, + persistable=True, + stop_gradient=True) + startup_block.append_op(type="coalesce_tensor", + inputs={"Input": group.params}, + outputs={ + "Output": group.params, + "FusedOutput": group.coalesce_var + }, + attrs={ + "copy_data": True, + "use_align": True, + "dtype": group.dtype, + OP_ROLE_KEY: OpRole.Forward + }) + else: + group.coalesce_var = group.params[0] + + # TODO Overlap broadcast with opt and next forward + new_op = main_block.append_op(type='c_broadcast', + inputs={'X': group.coalesce_var}, + outputs={'Out': group.coalesce_var}, + attrs={ + 'ring_id': sharding_info.group.id, + 'root': group.rank, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) + + # NOTE the current dist context lack the presentation for bucket tensor which + # composes many tensor with different dims_mapping. we assign a fake dist attr + # for it currently. + + +def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size): + + assert len( + sharding_infos + ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( + len(sharding_infos)) + sharding_info = sharding_infos[0] + + pass + + class ShardingInfo(object): def __init__(self, group, rank, params_grads, partition_algor): @@ -886,84 +1051,33 @@ def get_param_grad(self, param_name): return self.params_grads.get(param_name, None) -def re_order_program(block, param_grads, dist_context): +class ParameterGroup(object): - # record order - pname_to_pg_pairs = {} - for p, g in param_grads: - pname_to_pg_pairs[p.name] = (p, g) + def __init__(self, max_size): + self.max_siez = max_size + self.dtype = None + self.rank = -1 + self.numel = 0 + self.params = [] + self.coalesce_var = None - use_order = [] - for op in block.ops: - for input_name in op.input_arg_names: - if (input_name in pname_to_pg_pairs) and (input_name - not in use_order): - use_order.append(input_name) - if len(use_order) == len(pname_to_pg_pairs): - break - - # reorder optimzier - last_op = block.ops[-1] - pname_to_op = {} - num_ops = len(block.ops) - remove_op_indices = [] - # TODO support case when optimizer is not the last op - if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type: - # record optimizer - for idx, op in reversed(list(enumerate(block.ops))): - if op.type not in _supported_optimizer_type: - break - assert len(op.input("Param")) == 1 - pname_to_op[op.input("Param")[0]] = op - remove_op_indices.append(idx) - assert len(use_order) == len(pname_to_op) - - # append new opts - for pname in use_order: - new_op = block.append_op(type='nop') - new_op.desc.copy_from(pname_to_op[pname].desc) - dist_context.set_op_dist_attr_for_program( - new_op, - dist_context.get_op_dist_attr_for_program(pname_to_op[pname])) - - # remove old opts - for idx in remove_op_indices: - block._remove_op(idx, sync=False) - - block._sync_with_cpp() - assert len(block.ops) == num_ops - - # TODO reorder gradient clip order - _logger.info( - "Sharding the Order of param being 
used: {}.".format(use_order)) - return [pname_to_pg_pairs[p] for p in use_order] - - -def _fuse_overlap_gradient_comm(): - pass - - -def _fuse_overlap_parameter_comm_stage_two(sharding_infos, fuse_size): - - assert len( - sharding_infos - ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( - len(sharding_infos)) - sharding_info = sharding_infos[0] - - main_program = default_main_program() - startup_program = default_startup_program() - - # for param in sharding_info.params: - # n - - -def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size): + def acceptable(param, rank): + if self.numel == 0: + return True + else: + if param.dtype is not self.dtype: + return False + if rank != self.rank: + return False + if self.numel + get_var_numel(param) > self.max_siez: + return False + return True - assert len( - sharding_infos - ) == 1, "fuse overlap optimization only support one sharding group right now, but got [{}].".format( - len(sharding_infos)) - sharding_info = sharding_infos[0] + def collect(param, rank): + self.dtype = param.dtype + self.rank = rank + self.numel += get_var_numel(param) + self.params.append(param) - pass + def __len__(self): + return len(self.params) From d8bea004795ecc6158589c4b788968e53d4df370 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 11:23:56 +0800 Subject: [PATCH 19/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 14b66035247cf..9afac97f4371e 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -24,8 +24,7 @@ from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op -from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, get_logger -from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_numel +from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, get_var_numel, get_logger OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() From 6d8c7cf7c80dd6b967070132f0aadbbc1648cb55 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 11:28:22 +0800 Subject: [PATCH 20/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 9afac97f4371e..e2b1dcde056dc 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -25,6 +25,7 @@ from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.operators.common import is_parameter_related, is_data_parallel_reduce_op from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_dist_op_attr_for_program_by_mesh_and_mapping, set_var_dist_attr, 
get_var_numel, get_logger +from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() From 7ccaff1862024e07ac46ff677d5f4e7f4d308d11 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 11:59:38 +0800 Subject: [PATCH 21/36] logging --- python/paddle/distributed/passes/auto_parallel_sharding.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index e2b1dcde056dc..25b5b9ce35a5f 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -552,6 +552,7 @@ def _optimization_pass(self, main_program, startup_program): if self.bucket_size_numel > 1: if self.stage == 2: _fuse_overlap_parameter_comm_stage_two(self.sharding_infos, + self._dist_context, fuse_size=1024) elif self.stage == 3: _fuse_overlap_parameter_comm_stage_three( @@ -914,7 +915,9 @@ def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, group_to_param_map, param_to_group_map = group_param( sharding_info, fuse_size) - + _logger.info("Sharding Stage2 Optimization:") + _logger.info("[{}] Parameters are fused into [{}] Buckets".format( + len(param_to_group_map.keys()), len(group_to_param_map.keys()))) for group in group_to_param_map.keys(): assert len(group) >= 1 @@ -943,6 +946,7 @@ def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, }) else: group.coalesce_var = group.params[0] + _logger.info("Bucket: {}".format([p.name for p in group.params])) # TODO Overlap broadcast with opt and next forward new_op = main_block.append_op(type='c_broadcast', From 34415f373f3d5cfcff0f838bcbd312277a744aff Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 12:02:10 +0800 Subject: [PATCH 22/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 25b5b9ce35a5f..f946ec4cd2319 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1065,7 +1065,7 @@ def __init__(self, max_size): self.params = [] self.coalesce_var = None - def acceptable(param, rank): + def acceptable(self, param, rank): if self.numel == 0: return True else: @@ -1077,7 +1077,7 @@ def acceptable(param, rank): return False return True - def collect(param, rank): + def collect(self, param, rank): self.dtype = param.dtype self.rank = rank self.numel += get_var_numel(param) From 038204e732c093f6dc53b5e1f034c02fa6b66619 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 12:04:42 +0800 Subject: [PATCH 23/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f946ec4cd2319..c89856dd93325 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -888,11 +888,11 @@ def group_param(sharding_info, fuse_size): cur_group.collect(param, rank) if cur_group in group_to_param_map: - 
group_to_param_map[cur_group].append(param_name) + group_to_param_map[cur_group].append(param.name) else: - group_to_param_map[cur_group] = [param_name] + group_to_param_map[cur_group] = [param.name] - param_to_group_map[param_name] = cur_group + param_to_group_map[param.name] = cur_group return group_to_param_map, param_to_group_map From 6a2f678d976f2e915e1a39cc3def12930c429286 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 12:08:34 +0800 Subject: [PATCH 24/36] bugfix --- .../distributed/passes/auto_parallel_sharding.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index c89856dd93325..e34cafb658f8d 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -551,12 +551,13 @@ def _optimization_pass(self, main_program, startup_program): # TODO support multiple sub_blocks if self.bucket_size_numel > 1: if self.stage == 2: - _fuse_overlap_parameter_comm_stage_two(self.sharding_infos, - self._dist_context, - fuse_size=1024) + _fuse_overlap_parameter_comm_stage_two( + self.sharding_infos, + self._dist_context, + fuse_size=self.bucket_size_numel) elif self.stage == 3: _fuse_overlap_parameter_comm_stage_three( - self.sharding_infos, fuse_size=1024) + self.sharding_infos, fuse_size=self.bucket_size_numel) def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank, @@ -916,8 +917,10 @@ def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, group_to_param_map, param_to_group_map = group_param( sharding_info, fuse_size) _logger.info("Sharding Stage2 Optimization:") - _logger.info("[{}] Parameters are fused into [{}] Buckets".format( - len(param_to_group_map.keys()), len(group_to_param_map.keys()))) + _logger.info( + "Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets". 
+ format(fuse_size, len(param_to_group_map.keys()), + len(group_to_param_map.keys()))) for group in group_to_param_map.keys(): assert len(group) >= 1 From c5943bb2b531b62845dd53080e14939976a9ee28 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 12:15:38 +0800 Subject: [PATCH 25/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index e34cafb658f8d..d0c4b0752d7ba 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1072,7 +1072,7 @@ def acceptable(self, param, rank): if self.numel == 0: return True else: - if param.dtype is not self.dtype: + if param.dtype != self.dtype: return False if rank != self.rank: return False From 36fdfe22a1069c5bd663e053433759f4fac3084e Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 12:17:45 +0800 Subject: [PATCH 26/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index d0c4b0752d7ba..2fdd099d6b603 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -921,7 +921,7 @@ def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, "Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets". format(fuse_size, len(param_to_group_map.keys()), len(group_to_param_map.keys()))) - for group in group_to_param_map.keys(): + for i, group in enumerate(group_to_param_map.keys()): assert len(group) >= 1 if len(group) > 1: From 7eb8f63b61bb890839b74d8afc84001a9321033c Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 15:31:23 +0800 Subject: [PATCH 27/36] bugfix --- python/paddle/distributed/passes/auto_parallel_sharding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 2fdd099d6b603..e6054d4e42df9 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -949,7 +949,9 @@ def _fuse_overlap_parameter_comm_stage_two(sharding_infos, dist_context, }) else: group.coalesce_var = group.params[0] - _logger.info("Bucket: {}".format([p.name for p in group.params])) + _logger.info("Bucket[{}] size [{}]MB : {}".format( + i, sum([get_var_size(p) for p in group.params]), + [p.name for p in group.params])) # TODO Overlap broadcast with opt and next forward new_op = main_block.append_op(type='c_broadcast', From ce5531ae6a342608c57f5141752cb3d6e403db03 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 20:09:36 +0800 Subject: [PATCH 28/36] debug --- python/paddle/distributed/auto_parallel/engine.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 7c550ab57852c..199c8b6a7cac3 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -828,6 +828,8 @@ def fit(self, with profiler.Profiler(timer_only=True) as prof: for epoch in range(epochs): 
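# A minimal, self-contained sketch (not part of the patch) of the size-capped
# bucketing that the fuse/overlap patches above tune: parameters assigned to one
# sharding rank are packed into buckets of at most `fuse_size` elements, and a
# bucket is closed when the dtype changes, mirroring VarGroup.acceptable()/collect().
# `ParamStub` and `bucket_params` are illustrative names only.
from collections import namedtuple

ParamStub = namedtuple("ParamStub", ["name", "numel", "dtype"])

def bucket_params(params, fuse_size):
    buckets, cur, cur_numel, cur_dtype = [], [], 0, None
    for p in params:
        if cur and (p.dtype != cur_dtype or cur_numel + p.numel > fuse_size):
            buckets.append(cur)          # close the current bucket
            cur, cur_numel = [], 0
        cur.append(p.name)
        cur_numel += p.numel
        cur_dtype = p.dtype
    if cur:
        buckets.append(cur)
    return buckets

# Two fp16 parameters of 600 elements each do not fit one 1024-element bucket:
print(bucket_params([ParamStub("w0", 600, "float16"),
                     ParamStub("w1", 600, "float16")], 1024))  # [['w0'], ['w1']]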
for step, _ in enumerate(train_dataloader): + print_input(self.main_program, self._labels) + print_param(self.main_program) try: outs = self._executor.run( self.main_program, @@ -1629,3 +1631,13 @@ def inputs(self): @property def labels(self): return self._labels + + +def print_param(program): + for p in program.all_parameters(): + print(p) + + +def print_input(program, vars): + for v in vars: + print(p) From 8ed27034d17513520c70d44990c71aeb98ded707 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 20:24:22 +0800 Subject: [PATCH 29/36] debug --- python/paddle/distributed/auto_parallel/engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 199c8b6a7cac3..a5cc2954f76be 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -827,8 +827,8 @@ def fit(self, with profiler.Profiler(timer_only=True) as prof: for epoch in range(epochs): - for step, _ in enumerate(train_dataloader): - print_input(self.main_program, self._labels) + for step, data in enumerate(train_dataloader): + print_input(self.main_program, data) print_param(self.main_program) try: outs = self._executor.run( @@ -1635,9 +1635,9 @@ def labels(self): def print_param(program): for p in program.all_parameters(): - print(p) + print(p.name, p.get_value()) def print_input(program, vars): for v in vars: - print(p) + print(v.name, v.get_value()) From aeac1219a79bf0d1a1d574305b643312fcc31985 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 20:34:03 +0800 Subject: [PATCH 30/36] debug --- python/paddle/distributed/auto_parallel/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index a5cc2954f76be..d06c6c7463de1 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -828,14 +828,16 @@ def fit(self, with profiler.Profiler(timer_only=True) as prof: for epoch in range(epochs): for step, data in enumerate(train_dataloader): - print_input(self.main_program, data) + self._strategy.return_numpy = True + fetch_names.append('labels') print_param(self.main_program) try: - outs = self._executor.run( + outs, lables = self._executor.run( self.main_program, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy) + print("lables: {}".format(lables)) except core.EOFException: break if lr_scheduler and step % self._k_steps == 0: From 8c88d574e19ada52526ecf5c858bdea5897741af Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 20:37:18 +0800 Subject: [PATCH 31/36] debug --- python/paddle/distributed/auto_parallel/engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index d06c6c7463de1..dd080abd57025 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -837,7 +837,7 @@ def fit(self, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy) - print("lables: {}".format(lables)) + print("lables: {}".format(lables[:20])) except core.EOFException: break if lr_scheduler and step % self._k_steps == 0: @@ -1637,9 +1637,9 @@ def labels(self): 
def print_param(program): for p in program.all_parameters(): - print(p.name, p.get_value()) + print(p.name, p.get_value()[:20]) def print_input(program, vars): for v in vars: - print(v.name, v.get_value()) + print(v.name, v.get_value()[:20]) From 17d4f1bcf75f97e01a7ef8cae47e22ae1d7fd6ca Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 20:50:17 +0800 Subject: [PATCH 32/36] debug --- python/paddle/distributed/auto_parallel/engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index dd080abd57025..57bb04a84f77b 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -838,6 +838,7 @@ def fit(self, use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy) print("lables: {}".format(lables[:20])) + outs = [outs] except core.EOFException: break if lr_scheduler and step % self._k_steps == 0: @@ -1636,8 +1637,10 @@ def labels(self): def print_param(program): - for p in program.all_parameters(): - print(p.name, p.get_value()[:20]) + for i, p in enumerate(program.all_parameters()): + if i == 10: + break + print(p.name, p.get_value()) def print_input(program, vars): From 2fdd2acc3ae93c75d13f6cea26908f8b44fea6f5 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Tue, 1 Nov 2022 21:10:54 +0800 Subject: [PATCH 33/36] debug --- .../paddle/distributed/auto_parallel/engine.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 57bb04a84f77b..dd0798b27689b 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -829,6 +829,7 @@ def fit(self, for epoch in range(epochs): for step, data in enumerate(train_dataloader): self._strategy.return_numpy = True + fetch_names = [fetch_names[0]] fetch_names.append('labels') print_param(self.main_program) try: @@ -838,7 +839,8 @@ def fit(self, use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy) print("lables: {}".format(lables[:20])) - outs = [outs] + print("outs: {}".format(outs)) + except core.EOFException: break if lr_scheduler and step % self._k_steps == 0: @@ -847,11 +849,12 @@ def fit(self, prof.step() - self._prepare_logger(outs, epoch, step, lr, - fetch_names, fetch_indices, - prof.step_info(), self._mode) - history = self._prepare_history(outs, fetch_indices, - self._mode) + # self._prepare_logger(outs, epoch, step, lr, + # fetch_names, fetch_indices, + # prof.step_info(), self._mode) + # history = self._prepare_history(outs, fetch_indices, + # self._mode) + history = None if valid_data and epoch % valid_freq == 0: self.evaluate(valid_data, valid_sample_split, batch_size, @@ -1640,7 +1643,7 @@ def print_param(program): for i, p in enumerate(program.all_parameters()): if i == 10: break - print(p.name, p.get_value()) + print(p.name, np.array(p.get_value())[:20]) def print_input(program, vars): From fa7f049bc97720aa00c4bf42569919ef945c3f18 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Thu, 3 Nov 2022 10:44:07 +0800 Subject: [PATCH 34/36] old engine --- .../distributed/auto_parallel/engine.py | 34 +- .../distributed/auto_parallel/engine.py.acc | 1651 +++++++++++++++++ 2 files changed, 1658 insertions(+), 27 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/engine.py.acc diff --git 
a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index dd0798b27689b..7c550ab57852c 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -827,20 +827,13 @@ def fit(self, with profiler.Profiler(timer_only=True) as prof: for epoch in range(epochs): - for step, data in enumerate(train_dataloader): - self._strategy.return_numpy = True - fetch_names = [fetch_names[0]] - fetch_names.append('labels') - print_param(self.main_program) + for step, _ in enumerate(train_dataloader): try: - outs, lables = self._executor.run( + outs = self._executor.run( self.main_program, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy) - print("lables: {}".format(lables[:20])) - print("outs: {}".format(outs)) - except core.EOFException: break if lr_scheduler and step % self._k_steps == 0: @@ -849,12 +842,11 @@ def fit(self, prof.step() - # self._prepare_logger(outs, epoch, step, lr, - # fetch_names, fetch_indices, - # prof.step_info(), self._mode) - # history = self._prepare_history(outs, fetch_indices, - # self._mode) - history = None + self._prepare_logger(outs, epoch, step, lr, + fetch_names, fetch_indices, + prof.step_info(), self._mode) + history = self._prepare_history(outs, fetch_indices, + self._mode) if valid_data and epoch % valid_freq == 0: self.evaluate(valid_data, valid_sample_split, batch_size, @@ -1637,15 +1629,3 @@ def inputs(self): @property def labels(self): return self._labels - - -def print_param(program): - for i, p in enumerate(program.all_parameters()): - if i == 10: - break - print(p.name, np.array(p.get_value())[:20]) - - -def print_input(program, vars): - for v in vars: - print(v.name, v.get_value()[:20]) diff --git a/python/paddle/distributed/auto_parallel/engine.py.acc b/python/paddle/distributed/auto_parallel/engine.py.acc new file mode 100644 index 0000000000000..dd0798b27689b --- /dev/null +++ b/python/paddle/distributed/auto_parallel/engine.py.acc @@ -0,0 +1,1651 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
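# The debug commits reverted above printed parameter values between training steps.
# A hedged, standalone way to do the same kind of inspection in static-graph mode
# (assuming the program has been run at least once so the parameters are
# initialized in the global scope; `dump_params` is an illustrative helper, not an
# API of this file):
import numpy as np
import paddle

def dump_params(program, limit=5, head=20):
    scope = paddle.static.global_scope()
    for i, param in enumerate(program.all_parameters()):
        if i == limit:
            break
        var = scope.find_var(param.name)
        if var is None:  # e.g. the parameter lives on another rank
            continue
        # read the tensor from the scope instead of adding it to fetch_list
        print(param.name, np.array(var.get_tensor()).flatten()[:head])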
+ +import os +import logging +import random +import numpy as np +from collections import defaultdict + +import paddle +import paddle.utils as utils + +from paddle import fluid, profiler, static +from paddle.metric import Metric +from paddle.static import InputSpec +from paddle.fluid import core +from paddle.fluid import Variable +from paddle.fluid.layers.utils import flatten +from paddle.fluid.executor import global_scope, _to_name_str +from paddle.fluid.framework import Operator, _non_static_mode +from paddle.fluid.framework import _current_expected_place as _get_device +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed import fleet + +from .converter import Converter +from .helper import ProgramHelper +from .cluster import Cluster, get_default_cluster +from .planner_v2 import Planner +from .parallelizer_v2 import Parallelizer +from .dist_op import DistributedOperator +from .dist_saver import DistributedSaver +from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader +from .utils import to_list, get_logger, get_dist_attr +from .process_group import new_process_group, get_all_process_groups +from .dist_context import DistributedContext, get_default_distributed_context +from .strategy import Strategy +from .interface import CollectionNames, get_collection + + +class Engine: + """ + An Engine object can provide the full power of auto parallel to users. + With the help of it, users can easily obtain the abilities of the + distributed training and inference. It also support the dynamic graph and + static graph at the same time. + + Args: + model (paddle.nn.Layer, optional): The model is an instance of + paddle.nn.Layer. + loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer` + instance or any callable function taken the predicted values and + ground truth values as input. It can be None when there is no loss. + Default: None. + optimizer (Optimizer|None, optional): The optimizer need to be set in training + and should be None in eval and predict mode. Default: None. + metrics (Metric|list[Metric]|None, optional): If metrics is set, all + metrics will be calculated and output in train/eval mode. Default: None. + cluster (Cluster|None, optional): The cluster represents the topology information + about the used physical devices. Default: None. (Unused for now) + strategy (Strategy|None, optional): The strategy is used to configure the + parallelization and optimization behaviors. Default: None. + + Examples: + + .. 
code-block:: python + + import paddle + import paddle.vision.transforms as T + from paddle.distributed.fleet import auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + # fit + engine.fit(train_dataset, + epochs=2, + batch_size=64) + # evaluate + engine.evaluate(valid_dataset, + batch_size=64) + # predict + engine.predict(valid_dataset, + batch_size=64) + # save + engine.save("./my_model") + # load + engine.load("./my_model") + + """ + + def __init__(self, + model=None, + loss=None, + optimizer=None, + metrics=None, + cluster=None, + strategy=None): + + if model and not isinstance(model, + paddle.nn.Layer) and not callable(model): + raise TypeError( + "'model must be sub classes of `paddle.nn.Layer` or any callable function." + ) + self._model = model + + # if loss and not isinstance(loss, + # paddle.nn.Layer) and not callable(loss): + # raise TypeError( + # "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." + # ) + self._loss = loss + + if optimizer and not isinstance( + optimizer, + (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" + " or `paddle.fluid.optimizer.Optimizer`.") + self._optimizer = self._validate_opt(optimizer) + + metrics = metrics or [] + for metric in to_list(metrics): + assert isinstance(metric, Metric), \ + "{} is not sub class of Metric".format( + metric.__class__.__name__) + self._metrics = to_list(metrics) + + if cluster and not isinstance(cluster, Cluster): + raise TypeError( + "'cluster' must be the object or class `paddle.distributed.auto_parallel.Cluster`" + ) + self._cluster = cluster or get_default_cluster() + + if strategy and not isinstance(strategy, Strategy): + raise TypeError( + "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`" + ) + self._strategy = strategy or Strategy() + + if os.getenv("POD_NAME"): + print("Distribute training by paddle.distributed.launch", + flush=True) + fleet.init(is_collective=True) + + self._executor = None + self._cur_rank = paddle.distributed.get_rank() + self._nranks = paddle.distributed.get_world_size() + self._saver = DistributedSaver() + + self._logger = get_logger(logging.INFO) + + self._orig_main_prog = static.default_main_program() + self._orig_startup_prog = static.default_startup_program() + self._orig_dist_context = get_default_distributed_context() + self._dist_contexts = {} + self._serial_main_progs = {} + self._serial_startup_progs = {} + self._dist_main_progs = defaultdict(dict) # dist main programs + self._dist_startup_progs = defaultdict(dict) # dist startup programs + self._feed_vars = {} + self._fetch_vars = {} + self._planners = {} + self._has_prepared = {"train": False, "eval": False, "predict": False} + self._has_prepared_reader = { + "train": False, + "eval": False, + "predict": False + } + self._inputs_spec = [] + self._labels_spec = [] + self._inputs = [] + self._labels = [] + + self._skip_build = False + self._outside_dataloader = False + self._planned_mode = None + self._dygraph_mode 
= False + self._tuning = self._strategy.tuning + + def _prepare_data_spec(self, data, split, batch_size): + inputs_spec = [] + labels_spec = [] + if isinstance(data, paddle.io.IterableDataset): + if split is None: + inputs, labels = next(iter(data)) + else: + sample = next(iter(data)) + inputs = sample[:split] + labels = sample[split:] + elif isinstance(data, paddle.io.Dataset): + if split is None: + inputs, labels = data[0] + else: + sample = data[0] + inputs = sample[:split] + labels = sample[split:] + else: + raise ValueError( + "Data should be a Dataset or IterableDatset, but received {}.". + format(type(data).__name__)) + inputs = to_list(inputs) + labels = to_list(labels) + + num_shards = self._strategy.dataset.num_shards + + def _adjust_item_spec(num_shards, spec): + if num_shards > 1 and len(spec.shape) > 1: + spec.shape[0] = spec.shape[0] * num_shards + + def _infer_item_spec(item, name, batch_size, specs): + if isinstance(item, np.ndarray): + spec = InputSpec.from_numpy(item, name) + if batch_size is None: + _adjust_item_spec(num_shards, spec) + specs.append(spec) + else: + specs.append(spec.batch(batch_size)) + elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)): + _adjust_item_spec(num_shards, spec) + spec = InputSpec.from_tensor(item, name) + if batch_size is None: + specs.append(spec) + else: + specs.append(spec.batch(batch_size)) + else: + specs.append(InputSpec([batch_size], type(item), name)) + + if inputs is not None: + for i, item in enumerate(inputs): + assert item is not None, "Receive None input." + name = "input" + str(i) + _infer_item_spec(item, name, batch_size, inputs_spec) + if labels is not None: + for i, item in enumerate(labels): + assert item is not None, "Receive None input." + name = "label" + str(i) + _infer_item_spec(item, name, batch_size, labels_spec) + + inputs_spec = self._validate_spec(inputs_spec) + labels_spec = self._validate_spec(labels_spec) + return inputs_spec, labels_spec + + def _prepare_data_tensor(self, + inputs_spec, + labels_spec, + inputs=None, + labels=None): + if _non_static_mode() or self._dygraph_mode: + return None, None + inputs_spec = inputs_spec if inputs_spec else [] + labels_spec = labels_spec if labels_spec else [] + if inputs_spec: + assert isinstance(inputs_spec, list), \ + "inputs should be list, but received {}".format(type(inputs_spec)) + if inputs is None: + inputs = [s._create_feed_layer() for s in inputs_spec] + else: + assert isinstance(inputs, list), \ + "inputs should be list, but received {}".format(type(inputs)) + for input_spec, input in zip(inputs_spec, inputs): + if input_spec.shape != input.shape: + input.desc.set_shape(input_spec.shape) + if labels_spec: + assert isinstance(labels_spec, list), \ + "labels should be list, but received {}".format(type(labels_spec)) + if labels is None: + labels = [s._create_feed_layer() for s in labels_spec] + else: + assert isinstance(labels, list), \ + "labels should be list, but received {}".format(type(labels)) + for label_spec, label in zip(labels_spec, labels): + if label_spec.shape != label.shape: + label.desc.set_shape(label_spec.shape) + return inputs, labels + + def _prepare_reader(self): + dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + dist_main_block = dist_main_prog.global_block() + + # NOTE: this list may be changed if Paddle changes the existing rules. 
+ related_reader_ops = [ + "create_py_reader", "create_double_buffer_reader", "read" + ] + # remove the first three ops if multiple run fit/evaluate/predict + if dist_main_block.ops[0].type == 'create_py_reader': + for i in range(len(related_reader_ops)): + if dist_main_block.ops[0].type in related_reader_ops: + dist_main_block._remove_op(0, sync=False) + dist_main_block._sync_with_cpp() + # Step 1: find the reader ops + reader_op_indices = [] + for idx, op in enumerate(dist_main_block.ops): + if op.type in related_reader_ops: + reader_op_indices.append(idx) + # Step 2: insert the new reader ops to cpp + new_reader_ops = [] + for idx in reversed(reader_op_indices): + new_op_desc = dist_main_block.desc._prepend_op() + new_op_desc.copy_from(dist_main_block.ops[idx].desc) + new_op = Operator(dist_main_block, + new_op_desc, + type=new_op_desc.type()) + new_reader_ops.append(new_op) + dist_op = DistributedOperator(new_op) + dist_context.add_dist_op_for_program(dist_op) + # Step 3: insert the new reader ops to python + for new_op in new_reader_ops: + dist_main_block.ops.insert(0, new_op) + for i in range(len(reader_op_indices)): + reader_op_indices[i] += len(reader_op_indices) + # Step 4: remove the old reader ops from python and cpp + for idx in reversed(reader_op_indices): + op = dist_main_block.ops.pop(idx) + dist_main_block.desc._remove_op(idx, idx + 1) + dist_main_block._sync_with_cpp() + self._has_prepared_reader[self._mode] = True + + def _prepare_feed(self, data, user_feeds, mode): + feeds = {} + if data is not None: + if isinstance(data, (list, tuple)): + if len(data) == 1 and isinstance(data[0], dict): + for name, data in data[0].items(): + feeds[name] = data + else: + raise ValueError("Unsupported data {}".format(data)) + elif isinstance(data, dict): + for name, data in data.items(): + feeds[name] = data + else: + raise ValueError("Unsupported data {}".format(data)) + if user_feeds is not None: + assert isinstance(user_feeds, dict), \ + "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) + for name, data in user_feeds.items(): + feeds[name] = data + return feeds + + def _prepare_fetch(self, user_fetches, mode): + if user_fetches is not None: + assert isinstance(user_fetches, list), \ + "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) + fetch_names = [] + fetch_indices = [] + + def _process_fetch_group(group_name, var_list): + group_indices = [] + for var in var_list: + # Remove duplicate var_names + if self._is_local_var(var): + var_name = _to_name_str(var) + if var_name not in fetch_names: + fetch_names.append(var_name) + group_indices.append(fetch_names.index(var_name)) + if not group_indices: + fetch_names.append([]) + fetch_indices.append(group_indices) + + if mode != "predict": + _process_fetch_group("loss", self._fetch_vars[mode]["loss"]) + if mode != "predict": + metrics = self._fetch_vars[mode]["metrics"] + for i, var_list in enumerate(metrics): + _process_fetch_group("metrics_" + str(i), var_list) + if mode == "predict": + _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"]) + user_fetches_collection = [ + item[1] for item in get_collection(CollectionNames.FETCHES) + ] + var_list = (user_fetches_collection or []) + (user_fetches or []) + _process_fetch_group("fetches", var_list) + return fetch_names, fetch_indices + + def _prepare_logger(self, + outs, + epoch=None, + step=None, + lr=None, + fetch_names=None, + fetch_indices=None, + profiler_log="", + mode=None): + logs = "[{}] ".format(mode) + if epoch 
is not None: + logs += "epoch: {:d} ".format(epoch) + if step is not None: + logs += "step: {:d} ".format(step) + if lr is not None: + logs += "lr: {:5e} ".format(lr) + group_idx = 0 + # logging loss + if mode != "predict": + loss_indices = fetch_indices[group_idx] + for idx in loss_indices: + logs += "loss: {:8f} ".format(outs[idx][0]) + group_idx += 1 + # logging metrics + if mode != "predict": + metric_vars = self._fetch_vars[mode]["metrics"] + if metric_vars: + for metric in self._metrics: + metrics_indices = fetch_indices[group_idx] + metric_out = [] + for idx in metrics_indices: + metric_out.append(outs[idx]) + if metric_out: + metric.update(*metric_out) + results = metric.accumulate() + for i, res in enumerate(to_list(results)): + logs += "{}: {:8f} ".format(metric.name()[i], res) + group_idx += 1 + # Skip logging outputs + if mode == "predict": + group_idx += 1 + # logging user fetches + fetches_logging = get_collection(CollectionNames.LOGGING) + for name, var in fetches_logging: + if var.name in fetch_names: + idx = fetch_names.index(var.name) + # Use the user defined name for logging + logs += "{}: {} ".format(name, outs[idx]) + logs += profiler_log + self._logger.info(logs) + + def _prepare_history(self, outs, fetch_indices=None, mode=None): + history = {} + group_idx = 0 + # store loss + if mode != "predict": + loss_indices = fetch_indices[group_idx] + loss_values = [] + for idx in loss_indices: + loss_values.append(outs[idx][0]) + history["loss"] = loss_values + group_idx += 1 + # store metrics + if mode != "predict": + metric_vars = self._fetch_vars[mode]["metrics"] + if metric_vars: + for metric in self._metrics: + metrics_indices = fetch_indices[group_idx] + metric_out = [] + for idx in metrics_indices: + metric_out.append(outs[idx]) + if metric_out: + metric.update(*metric_out) + results = metric.accumulate() + history[tuple(metric.name())] = to_list(results) + group_idx += 1 + # store outputs + if mode == "predict": + outputs_indices = fetch_indices[group_idx] + outputs_values = [] + for idx in outputs_indices: + outputs_values.append(outs[idx]) + history["outputs"] = outputs_values + group_idx += 1 + # store user fetches + fetches_indices = fetch_indices[group_idx] + fetches_values = [] + for idx in fetches_indices: + fetches_values.append(outs[idx]) + history["fetches"] = fetches_values + return history + + def _prepare_program(self, mode): + # Do the build process + self._build(mode) + # Do the planning process + self._plan(mode) + # Do the parallel process + self._parallel(mode) + # Init comm and startup program + self._initialize(mode) + self._has_prepared[mode] = True + + def _build(self, mode): + if _non_static_mode() or self._dygraph_mode: + paddle.disable_static() + self._dygraph_mode = True + self._logger.info("Building model with 'to_static' method.") + + inputs_spec = self._inputs_spec + labels_spec = self._labels_spec if self._labels_spec else [] + self.program_helper = ProgramHelper(self._model, self._loss, + self._metrics, inputs_spec, + labels_spec) + # build forward main program + self.program_helper.build_program(mode) + + self.concrete_program = self.program_helper.concrete_program + serial_main_prog = self.program_helper.main_program + serial_startup_prog = self.program_helper.startup_program + + inputs = self.program_helper.input_vars + outputs = self.program_helper.output_vars + labels = self.program_helper.label_vars + losses = self.program_helper.loss_vars + metrics = self.program_helper.metric_vars + + self._inputs = inputs + self._labels = labels 
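# A small illustration (hypothetical names, same logic) of the grouped-fetch
# bookkeeping shared by _prepare_fetch, _prepare_logger and _prepare_history
# above: every logical group keeps the indices of its variables inside one flat,
# de-duplicated fetch list, so a single executor.run() result can be sliced back
# into per-group values.
def group_fetches(groups):
    fetch_names, fetch_indices = [], []
    for _, var_names in groups:
        indices = []
        for name in var_names:
            if name not in fetch_names:
                fetch_names.append(name)
            indices.append(fetch_names.index(name))
        fetch_indices.append(indices)
    return fetch_names, fetch_indices

names, idx = group_fetches([("loss", ["loss_0"]), ("fetches", ["acc_0", "loss_0"])])
# names == ['loss_0', 'acc_0'], idx == [[0], [1, 0]]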
+ + paddle.enable_static() + else: + # build program in static mode + serial_main_prog = self._serial_main_progs.get(mode, None) + if serial_main_prog is not None: + return + + outputs = [] + losses = [] + metrics = [] + inputs = self._inputs if self._inputs else [] + labels = self._labels if self._labels else [] + serial_main_prog = self._orig_main_prog.clone() + serial_startup_prog = self._orig_startup_prog.clone() + if not self._skip_build: + with static.program_guard(serial_main_prog, serial_startup_prog), \ + utils.unique_name.guard(): + outputs = to_list(self._model(*inputs)) + if mode != "predict" and self._loss: + losses = to_list(self._loss(*(outputs + labels))) + + if mode != "predict" and (outputs or labels): + for metric in self._metrics: + metrics.append( + to_list(metric.compute(*(outputs + labels)))) + else: + losses = to_list(self._loss) + + default_ctx = get_default_distributed_context() + if not default_ctx.has_annotation: + # We build the world process group because the data parallel + # needs all ranks by default. + new_process_group(list(range(self._nranks))) + default_ctx.data_parallel = True + + feed_vars = {"inputs": inputs, "labels": labels} + + fetch_vars = { + "outputs": flatten(outputs), + "loss": losses, + "metrics": metrics + } + + if mode != "train": + serial_main_prog = serial_main_prog.clone(for_test=True) + + self._set_recompute_ckpts() + self._dist_contexts[mode] = DistributedContext( + serial_main_prog, serial_startup_prog, self._optimizer, losses, + feed_vars, fetch_vars, self._cluster, self._strategy) + self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale + + def _optimization_tuning(self, mode, dataset, batch_size): + if not self._tuning.enable: + raise ValueError("Please set `tuning.enable=True`.") + + assert mode == "train" + # Do the build process + self._build(mode) + # Do the planning process + self._plan(mode) + + dataset.dp_world_size = self._dp_world_sizes + dataset.dp_rank = self._dp_ranks + + from .tuner.optimization_tuner import OptimizationTuner + self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(), + self._dist_contexts[mode], + dataset, + self._inputs_spec, + self._labels_spec, + batch_size=batch_size, + rank=self._cur_rank) + + self._optimization_tuner.tune() + + if self._tuning.run_after_tuning: + # update the strategy + self._dist_contexts[ + mode]._strategy = self._optimization_tuner.get_best_config() + + def _plan(self, mode): + if self._planned_mode is None: + self._planned_mode = mode + else: + self._init_dist_context(mode) + + self._planners[mode] = Planner(mode, self._dist_contexts[mode]) + self._planners[mode].plan() + + # infer data parallel info + inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"] + labels_var = self._dist_contexts[mode].serial_feed_vars["labels"] + block = self._dist_contexts[mode].serial_main_program.global_block() + # TODO: check this feed_list + feed_list = [] + for var in inputs_var + labels_var: + if var.name in block.vars: + feed_list.append(block.vars[var.name]) + + self._dp_world_sizes = [] + self._dp_ranks = [] + for feed_var in feed_list: + dp_world_size, dp_rank = self._get_input_split_info( + feed_var, self._dist_contexts[mode]) + self._dp_world_sizes.append(dp_world_size) + self._dp_ranks.append(dp_rank) + + def _parallel(self, mode, all_ranks=False): + # Parallelize program based on the planner's results + # For now, the completer has to be passed to the planner, + # because we may use it to complete the annotation of the backwarkward and 
update. + parallelizer = Parallelizer(mode, self._planners[mode].completer, + self._dist_contexts[mode]) + if not all_ranks: + parallelizer.parallel(self._cur_rank) + else: + parallelizer.parallel_all() + + def _init_dist_context(self, mode): + # Init dist_context['mode'] with the first planned dist_context + # to guarantee that train/eval/predict mode have same parallel strategy + dist_context = self._dist_contexts[mode] + origin_main_prog = dist_context._original_serial_main_program + ref_mode = self._planned_mode + ref_dist_context = self._dist_contexts[ref_mode] + ref_origin_main_prog = ref_dist_context._original_serial_main_program + ref_blocks = ref_origin_main_prog.blocks + for ib, block in enumerate(origin_main_prog.blocks): + for iop, op in enumerate(block.ops): + ref_op = ref_blocks[ib].ops[iop] + assert op.type == ref_op.type, \ + "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type) + ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program( + ref_op) + dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) + + def _initialize(self, mode): + # Get the current content from the distributed context + self._serial_main_progs[mode] = self._dist_contexts[ + mode].serial_main_program + self._serial_startup_progs[mode] = self._dist_contexts[ + mode].serial_startup_program + self._dist_main_progs[mode] = self._dist_contexts[ + mode].dist_main_programs + self._dist_startup_progs[mode] = self._dist_contexts[ + mode].dist_startup_programs + self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars + self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars + self._lr_optimizer = self._dist_contexts[mode]._lr_optimizer + + if self._nranks > 1: + # Traverse different rank programs and traverse each op of them, + # instantiate communication by process_mapping. 
+ all_process_groups = get_all_process_groups() + + # NOTE: add the comm init control in the future for auto search + for process_group in all_process_groups: + if self._cur_rank not in process_group.ranks: + continue + process_group.instantiate() + + place = _get_device() + if isinstance(place, fluid.CUDAPlace): + place = fluid.CUDAPlace(ParallelEnv().dev_id) + + if self._strategy.seed: + paddle.seed(self._strategy.seed + self._dp_ranks[0]) + np.random.seed(self._strategy.seed + self._dp_ranks[0]) + random.seed(self._strategy.seed + self._dp_ranks[0]) + + if self._dygraph_mode: + dist_context = self._dist_contexts[mode] + dist_main_program = self._dist_main_progs[mode][self._cur_rank] + self.program_helper.init(dist_main_program, place, dist_context) + + if self._executor is None: + self._executor = paddle.static.Executor(place) + uninitialized = [] + dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + for var in dist_startup_prog.list_vars(): + scope_var = global_scope().find_var(var.name) + if scope_var and scope_var.get_tensor()._is_initialized(): + continue + uninitialized.append(var) + if uninitialized: + prune_startup_prog = dist_startup_prog._prune(uninitialized) + self._executor.run(prune_startup_prog) + + if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"): + self._set_state_dict(mode, self._strict, self._state_dict, + self._dist_attr) + + if self._strategy.reinit: + self._logger.info("NOTE: parameters wiil be re-initialized.") + dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + self._executor.run(dist_startup_prog) + + def fit(self, + train_data, + train_sample_split=None, + batch_size=1, + epochs=1, + steps_per_epoch=None, + valid_data=None, + valid_sample_split=None, + valid_freq=1, + valid_steps=None, + collate_fn=None, + callbacks=None): + """ + Trains the model for a fixed number of epochs. If `valid_data` is set, + evaluation will be done at the end of each epoch. + + Args: + train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. + train_sample_split (int, optional): Each sample of the train dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, train_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of train_data and valid_data if provided. + The user's data will be used directly without batching if set to None. Default: 1. + epochs (int, optional): The number of epochs to train the model. Default: 1. + steps_per_epoch (int, optional): The total number of steps (batches of samples) + is executed in one epoch before stating the next one. If None, it is equal to + the number samples in your dataset divided by the batch size. Default: None. + valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for + evaluation at the end of epoch. No evaluation will be done if set to None. + Default: None. (Unsupported for now) + valid_freq (int, optional): Only relevant if valid_data is provided. This specifies + how many training epochs before a new evaluation is performed. Default: 1. + valid_sample_split (int, optional): Only relevant if valid_data is provided. + Each sample of the valid dataset is assumed to be a (input, label) pair + by default and has two items. If each sample has more than two items, + valid_sample_split specifies how to split these items into input and label. 
+ The items before it are input and the left are label. Default: None. + valid_steps (int, optional): Only relevant if valid_data is provided. + It is the total number of steps (batches of samples) to draw before + stopping validation at the end of every epoch. If None, validation will run until the + `valid_data` dataset is exhausted. The validation will start from the + beginning of the dataset at each epoch. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during training. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + from paddle.distributed.fleet import auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + engine.fit(train_dataset, + epochs=2, + batch_size=64) + """ + self._mode = 'train' + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + train_data, train_sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + train_dataloader = self._prepare_dataloader_from_generator( + dataset=train_data, + capacity=70, + # use_double_buffer=use_double_buffer, + iterable=False, + # return_list=return_list, + # use_multiprocess=use_multiprocess, + # drop_last=drop_last, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + collate_fn=collate_fn) + fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) + lr_scheduler = self._get_lr_scheduler(self.main_program) + + with profiler.Profiler(timer_only=True) as prof: + for epoch in range(epochs): + for step, data in enumerate(train_dataloader): + self._strategy.return_numpy = True + fetch_names = [fetch_names[0]] + fetch_names.append('labels') + print_param(self.main_program) + try: + outs, lables = self._executor.run( + self.main_program, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + print("lables: {}".format(lables[:20])) + print("outs: {}".format(outs)) + + except core.EOFException: + break + if lr_scheduler and step % self._k_steps == 0: + lr_scheduler.step() + lr = self._get_lr(self._lr_optimizer) + + prof.step() + + # self._prepare_logger(outs, epoch, step, lr, + # fetch_names, fetch_indices, + # prof.step_info(), self._mode) + # history = self._prepare_history(outs, fetch_indices, + # self._mode) + history = None + + if valid_data and epoch % valid_freq == 0: + self.evaluate(valid_data, valid_sample_split, batch_size, + valid_steps, collate_fn, callbacks) + self._switch_mode("train") + else: + self._reset_metrics() + return history + + def evaluate(self, + valid_data, + valid_sample_split=None, + batch_size=1, + steps=None, + collate_fn=None, + callbacks=None): + """ + Evaluate the loss and metrics of the model on 
evaluation data. + + Args: + valid_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. + valid_sample_split (int, optional): Each sample of the eval dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, valid_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of valid_data. The user's data will + be used directly without batching if set to None. Default: 1. + steps (int, optional): It is the total number of steps (batches of samples) to draw before + stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted. + The evaluation will start from the beginning of the dataset in each run. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during evaluating. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + from paddle.distributed.fleet import auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, metrics=metrics) + engine.evaluate(valid_dataset, batch_size=64) + + """ + self._mode = 'eval' + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + valid_data, valid_sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + assert self._mode in self._dist_main_progs, \ + "eval model is not ready, please call `engine._prepare_program('eval')` first." + valid_dataloader = self._prepare_dataloader_from_generator( + dataset=valid_data, + # feed_list=feed_list, + capacity=70, + # use_double_buffer=use_double_buffer, + iterable=False, + # return_list=return_list, + # use_multiprocess=use_multiprocess, + # drop_last=drop_last, + # places=places, + batch_size=batch_size, + # epochs=epochs, + steps_per_epoch=steps, + collate_fn=collate_fn) + fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) + + for step, _ in enumerate(valid_dataloader): + try: + outs = self._executor.run( + self.main_program, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: + break + self._prepare_logger(outs, None, step, None, fetch_names, + fetch_indices, "", self._mode) + history = self._prepare_history(outs, fetch_indices, self._mode) + self._reset_metrics() + return history + + def predict(self, + test_data, + test_sample_split=None, + batch_size=1, + steps=None, + collate_fn=None, + callbacks=None): + """ + Compute the output predictions on testing data. + + Args: + test_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. 
+ test_sample_split (int, optional): Each sample of the test dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, test_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of test_data. The user's data will + be used directly without batching if set to None. Default: 1. + steps (int, optional): It is the total number of steps (batches of samples) to draw before + stopping predict. If None, predict will run until the `test_data` dataset is exhausted. + The predict will start from the beginning of the dataset in each run. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during testing. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + from paddle.distributed.fleet import auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + + engine = auto.Engine(model) + engine.predict(valid_dataset, batch_size=64) + """ + self._mode = 'predict' + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + test_data, test_sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + assert self._mode in self._dist_main_progs, \ + "predict model is not ready, please call `engine._prepare_program('predict')` first." 
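# The collate_fn argument accepted by fit/evaluate/predict above is any callable
# that merges a list of samples into one mini-batch; a purely illustrative example
# that stacks each field along a new batch axis (roughly the default behaviour
# when collate_fn is None):
import numpy as np

def stack_collate(samples):
    images = np.stack([s[0] for s in samples], axis=0)
    labels = np.stack([np.asarray(s[1]) for s in samples], axis=0)
    return images, labels

# e.g. engine.predict(valid_dataset, batch_size=64, collate_fn=stack_collate)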
+ test_dataloader = self._prepare_dataloader_from_generator( + dataset=test_data, + # feed_list=feed_list, + capacity=70, + # use_double_buffer=use_double_buffer, + iterable=False, + # return_list=return_list, + # use_multiprocess=use_multiprocess, + # drop_last=drop_last, + # places=places, + batch_size=batch_size, + # epochs=epochs, + steps_per_epoch=steps, + collate_fn=collate_fn) + fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) + + for step, _ in enumerate(test_dataloader): + try: + outs = self._executor.run( + self.main_program, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: + break + self._prepare_logger(outs, None, step, None, fetch_names, + fetch_indices, "", self._mode) + history = self._prepare_history(outs, fetch_indices, self._mode) + + return history + + def dataloader( + self, + dataset, + # return_list=True, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + epochs=1, + steps_per_epoch=None, + sample_split=1, + mode=None): + if mode is not None: + self.to_mode(mode) + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + dataset, sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + dataloader = self._prepare_dataloader( + dataset, + return_list=False, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn, + num_workers=num_workers, + use_buffer_reader=use_buffer_reader, + use_shared_memory=use_shared_memory, + timeout=timeout, + worker_init_fn=worker_init_fn, + epochs=epochs, + steps_per_epoch=steps_per_epoch) + return dataloader + + def dataloader_from_generator( + self, + dataset, + capacity=70, + use_double_buffer=True, + iterable=True, + # return_list=False, + use_multiprocess=False, + drop_last=True, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + sample_split=1, + mode=None): + if mode is not None: + self.to_mode(mode) + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + dataset, sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + dataloader = self._prepare_dataloader_from_generator( + dataset=dataset, + # feed_list=feed_list, + capacity=capacity, + use_double_buffer=use_double_buffer, + iterable=iterable, + return_list=False, + use_multiprocess=use_multiprocess, + drop_last=drop_last, + # places=places, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + collate_fn=collate_fn) + return dataloader + + def prepare(self, + inputs_spec=None, + labels_spec=None, + inputs=None, + labels=None, + main_program=None, + startup_program=None, + mode=None): + if mode is not None: + self.to_mode(mode) + if inputs or labels: + self._skip_build = True + self._inputs_spec = inputs_spec + self._labels_spec = labels_spec + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec, inputs, labels) + self._orig_main_prog = main_program + if self._orig_main_prog is None: + self._orig_main_prog = 
static.default_main_program() + self._orig_startup_prog = startup_program + if self._orig_startup_prog is None: + self._orig_startup_prog = static.default_startup_program() + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + elif inputs_spec or labels_spec: + self._inputs_spec = inputs_spec + self._labels_spec = labels_spec + self._outside_dataloader = True + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + self._orig_main_prog = main_program + if self._orig_main_prog is None: + self._orig_main_prog = static.default_main_program() + self._orig_startup_prog = startup_program + if self._orig_startup_prog is None: + self._orig_startup_prog = static.default_startup_program() + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) + else: + assert self._inputs_spec and self._labels_spec, \ + "Please call the dataloader(...) before calling prepare(...)" + + def run( + self, + data=None, + # program=None, + feed=None, + fetch_list=None, + # feed_var_name='feed', + # fetch_var_name='fetch', + # scope=None, + # return_numpy=True, + # use_program_cache=False, + # return_merged=True, + # use_prune=False, + mode=None): + if mode is not None: + self.to_mode(mode) + feed_dict = self._prepare_feed(data, feed, self._mode) + fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode) + if self._outside_dataloader and not self._has_prepared_reader[ + self._mode]: + self._prepare_reader() + outs = self._executor.run(self.main_program, + feed=feed_dict, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + self._prepare_logger(outs, None, None, None, fetch_names, fetch_indices, + "", self._mode) + history = self._prepare_history(outs, fetch_indices, self._mode) + return history + + def _prepare_dataloader(self, + dataset, + return_list=True, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + epochs=1, + steps_per_epoch=None): + + if self._strategy.gradient_merge and batch_size is not None: + assert batch_size % self._k_steps == 0, \ + "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + batch_size //= self._k_steps + + dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] + dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + dist_main_block = dist_main_prog.global_block() + + # NOTE: Get feed_list, then insert dataloader op with sharded var shape. + # Cause predict_program does not contain labels var, + # then we will add labels var from serial_program to dist_program, + # that maintains the length of feed_list equal to the length of dataset's values. 
+ inputs_var = self._feed_vars[self._mode]["inputs"] + labels_var = self._feed_vars[self._mode]["labels"] + feed_list = [] + for var in inputs_var + labels_var: + if var.name in dist_main_block.vars: + feed_list.append(dist_main_block.vars[var.name]) + else: + copy_var = dist_main_block._clone_variable(var, var.persistable) + copy_var.desc.set_original_id(var.desc.original_id()) + feed_list.append(copy_var) + + # insert read op at the end of program + places = paddle.static.cuda_places() + with static.program_guard(dist_main_prog, dist_startup_prog): + dataloader = DistributedDataLoader( + dataset, + feed_list=feed_list, + places=places, + return_list=return_list, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn, + num_workers=num_workers, + use_buffer_reader=use_buffer_reader, + use_shared_memory=use_shared_memory, + timeout=timeout, + worker_init_fn=worker_init_fn, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + split_data=self._strategy.split_data, + data_parallel_world_size=self._dp_world_sizes, + data_parallel_rank=self._dp_ranks) + + return dataloader + + def _prepare_dataloader_from_generator(self, + dataset, + capacity=None, + use_double_buffer=True, + iterable=True, + return_list=False, + use_multiprocess=False, + drop_last=True, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None): + + if self._strategy.gradient_merge and batch_size is not None: + assert batch_size % self._k_steps == 0, \ + "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + batch_size //= self._k_steps + + dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] + dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + dist_main_block = dist_main_prog.global_block() + + # NOTE: Get feed_list, then insert dataloader op with sharded var shape. + # Cause predict_program does not contain labels var, + # then we will add labels var from serial_program to dist_program, + # that maintains the length of feed_list equal to the length of dataset's values. 
+ inputs_var = self._feed_vars[self._mode]["inputs"] + labels_var = self._feed_vars[self._mode]["labels"] + feed_list = [] + for var in inputs_var + labels_var: + if var.name in dist_main_block.vars: + feed_list.append(dist_main_block.vars[var.name]) + else: + copy_var = dist_main_block._clone_variable(var, var.persistable) + copy_var.desc.set_original_id(var.desc.original_id()) + feed_list.append(copy_var) + + # # remove the first three ops if multi run fit/evaluate/predict + # self._op_size = len(dist_main_block.ops) + # if dist_main_block.ops[0].type == 'create_py_reader': + # op_size -= 3 + # for _ in range(3): + # dist_main_block._remove_op(0, sync=False) + + places = paddle.static.cuda_places() + with static.program_guard(dist_main_prog, dist_startup_prog): + dataloader = DistributedDataLoaderFromGenerator( + dataset=dataset, + feed_list=feed_list, + capacity=capacity, + use_double_buffer=use_double_buffer, + iterable=iterable, + return_list=return_list, + use_multiprocess=use_multiprocess, + drop_last=drop_last, + places=places, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + collate_fn=collate_fn, + split_data=self._strategy.split_data, + data_parallel_world_size=self._dp_world_sizes, + data_parallel_rank=self._dp_ranks) + self._prepare_reader() + # # move read op from the end of program to the start of program + # new_op_size = len(dist_main_block.ops) + # for _ in range(new_op_size - 1, op_size - 1, -1): + # op = dist_main_block.ops[new_op_size - 1] + # new_op_desc = dist_main_block.desc._prepend_op() + # new_op_desc.copy_from(op.desc) + # new_op = Operator(dist_main_block, + # new_op_desc, + # type=new_op_desc.type()) + # dist_main_block.ops.insert(0, new_op) + # dist_op = DistributedOperator(new_op) + # dist_context.add_dist_op_for_program(dist_op) + # for _ in range(new_op_size - op_size): + # dist_main_block._remove_op(new_op_size, sync=False) + # dist_main_block._sync_with_cpp() + return dataloader + + def _tune(self, tune_data, tune_sample_split=None, batch_size=1): + self._mode = 'train' + self._inputs_spec, self._labels_spec = self._prepare_data_spec( + tune_data, tune_sample_split, batch_size) + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + self._optimization_tuning(self._mode, tune_data, batch_size) + + def _validate_spec(self, specs): + specs = to_list(specs) + self._k_steps = self._strategy.gradient_merge.k_steps + if specs is not None: + for i, spec in enumerate(specs): + assert isinstance(spec, InputSpec) + if spec.name is None: + raise ValueError( + "Requires Input[{}].name != None, but receive `None` with {}." 
+ .format(i, spec)) + if self._k_steps > 1: + shape = list(spec.shape) + assert shape[0] % self._k_steps == 0, \ + "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) + shape[0] //= self._k_steps + spec.shape = shape + return specs + + def _is_local_var(self, var): + var_name = _to_name_str(var) + return var_name in self.main_program.global_block().vars + + def _get_input_split_info(self, var, dist_context): + # deduce how the input data is split among the cluster + from .utils import _get_comm_group, _get_corresponding_rank + + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) + process_mesh = tensor_dist_attr.process_mesh + dims_mapping = tensor_dist_attr.dims_mapping + + if self._cur_rank not in process_mesh.processes: + rank_id = _get_corresponding_rank(dist_context, process_mesh, + self._cur_rank) + else: + rank_id = self._cur_rank + + batch_size_axis = dims_mapping[0] + if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: + group_ranks = _get_comm_group(process_mesh.processes, + process_mesh.topology, + batch_size_axis, rank_id) + return len(group_ranks), group_ranks.index(rank_id) + + return 1, 0 + + def _set_recompute_ckpts(self): + # NOTE hack to enable recompute in engine api for GPT-3 + # TODO support more PaddleNLP/CV models here + + recompute = self._strategy.recompute + + # extract ckpts by specific model + if isinstance(self._model, paddle.nn.Layer): + if hasattr(self._model, + "gpt") and self._model.__class__.__name__ in [ + 'GPTForPretraining', 'GPTForPretrainingAuto' + ]: + exact_ckpts = self._model.gpt.checkpoints + else: + exact_ckpts = recompute.checkpoints + else: + exact_ckpts = recompute.checkpoints + + # modify strategy + if recompute.enable: + recompute.checkpoints = exact_ckpts[:] + logs = { + 'Model Class': self._model.__class__.__name__, + 'Applied Recompute ckpts': exact_ckpts + } + self._logger.info(logs) + + def _validate_opt(self, optimizer): + if optimizer is not None: + optimizer._parameter_list = None + optimizer._param_groups = None + return optimizer + + def _reset_metrics(self): + for metric in self._metrics: + metric.reset() + + def _switch_mode(self, mode): + self.to_mode(mode) + self._initialize(mode) + + def to_mode(self, mode): + assert mode in ["train", "eval", "predict"], \ + "mode {} should be one of ['train', 'eval', 'predict']".format(mode) + self._mode = mode + + def _set_state_dict(self, mode, strict, state_dict, dist_attr): + program = self._dist_main_progs[mode][self._cur_rank] + dist_context = self._dist_contexts[mode] + cur_dist_attr = get_dist_attr(program, dist_context) + converter = Converter(state_dict, dist_attr, cur_dist_attr) + state_dict = converter.convert(strict=strict) + program.set_state_dict(state_dict) + + def save(self, path, training=True): + """ + Saves the model, parameters, optimizer state to path. + If `training` is set to False, only inference model will be saved. + + Args: + path (str): The file prefix to save model. The format + is 'dirname/file_prefix' or 'file_prefix'. if empty str. + A exception will be raised. + training (bool, optional): Whether to save for training. If not, save + for inference only. If `training` is set to True, the optimizer state + will be saved. Otherwise, only the model and parameters are saved. + This function will silently overwrite existing file at the target + location. Default: True. + + Returns: + None + + Examples: + + .. 
code-block:: python
+                import paddle
+                import paddle.vision.transforms as T
+                from paddle.distributed.fleet import auto
+                from paddle.vision.datasets import MNIST
+
+                transform = T.Compose([
+                    T.Transpose(),
+                    T.Normalize([127.5], [127.5])
+                ])
+                train_dataset = MNIST(mode='train', transform=transform)
+
+                model = paddle.vision.models.LeNet()
+                loss = paddle.nn.CrossEntropyLoss()
+                optimizer = paddle.optimizer.Adam(
+                    learning_rate=0.001, parameters=model.parameters())
+                metrics = paddle.metric.Accuracy(topk=(1, 2))
+
+                engine = auto.Engine(model, loss, optimizer, metrics)
+                engine.fit(train_dataset,
+                           epochs=1,
+                           batch_size=64)
+                engine.save("./my_model")
+
+        """
+        if training:
+            assert 'train' in self._serial_main_progs, \
+                "training model is not ready, please call `engine._prepare_program('train')` first."
+            serial_program = self._serial_main_progs["train"]
+            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
+            dist_context = self._dist_contexts["train"]
+            self._saver.save(path,
+                             serial_program=serial_program,
+                             dist_main_program=dist_main_prog,
+                             dist_context=dist_context)
+        else:
+            mode = "predict"
+            feed_vars = self._feed_vars[mode]['inputs']
+            fetch_vars = self._fetch_vars[mode]['outputs']
+            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
+            self._saver.save_inference_model(path,
+                                             feed_vars,
+                                             fetch_vars,
+                                             self._executor,
+                                             program=dist_main_prog)
+
+    def load(self, path, strict=True, load_optimizer=True):
+        """
+        Load the stored model, parameters and optimizer states.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states.
+            strict (bool, optional): Whether to skip loading a mismatched
+                parameter or raise an error when a mismatch happens (the
+                parameter is not found in the file storing the model states,
+                or its shape does not match). Default: True.
+            load_optimizer (bool, optional): If True, the stored optimizer
+                states are restored. Otherwise, the optimizer states are
+                initialized from scratch. Default: True.
+
+        Returns:
+            None
+
+        Examples:
+
+            .. 
code-block:: python + import paddle + import paddle.vision.transforms as T + from paddle.distributed.fleet import auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + engine.fit(train_dataset, + epochs=1, + batch_size=64) + engine.save("./my_model") + engine.load("./my_model") + + """ + self._strict = strict + self._state_dict, self._dist_attr = self._saver.load( + path, load_optimizer) + return self._state_dict, self._dist_attr + + @staticmethod + def _get_lr_scheduler(program): + lr_sheduler = None + if hasattr(program, 'lr_sheduler'): + from paddle.optimizer.lr import LRScheduler + lr_sheduler = program.lr_sheduler + assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" + return lr_sheduler + + def _get_lr(self, optimizer): + if isinstance(optimizer, paddle.optimizer.Optimizer): + return optimizer.get_lr() + elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer): + if isinstance(optimizer._learning_rate, float): + return optimizer._learning_rate + else: + return optimizer._learning_rate() + else: + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ + " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) + ) + + @property + def main_program(self): + return self._dist_main_progs[self._mode][self._cur_rank] + + @property + def startup_program(self): + return self._dist_startup_progs[self._mode][self._cur_rank] + + @property + def dist_context(self): + return self._dist_contexts[self._mode] + + @property + def serial_main_program(self): + return self._serial_main_progs[self._mode] + + @property + def serial_startup_program(self): + return self._serial_startup_progs[self._mode] + + @property + def fetch_vars(self): + return self._fetch_vars[self._mode] + + @property + def inputs(self): + return self._inputs + + @property + def labels(self): + return self._labels + + +def print_param(program): + for i, p in enumerate(program.all_parameters()): + if i == 10: + break + print(p.name, np.array(p.get_value())[:20]) + + +def print_input(program, vars): + for v in vars: + print(v.name, v.get_value()[:20]) From 9f2534f3638bec40c50398ab3a02af651e83631a Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Mon, 7 Nov 2022 12:13:41 +0800 Subject: [PATCH 35/36] rm engine --- .../distributed/auto_parallel/engine.py.acc | 1651 ----------------- 1 file changed, 1651 deletions(-) delete mode 100644 python/paddle/distributed/auto_parallel/engine.py.acc diff --git a/python/paddle/distributed/auto_parallel/engine.py.acc b/python/paddle/distributed/auto_parallel/engine.py.acc deleted file mode 100644 index dd0798b27689b..0000000000000 --- a/python/paddle/distributed/auto_parallel/engine.py.acc +++ /dev/null @@ -1,1651 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging -import random -import numpy as np -from collections import defaultdict - -import paddle -import paddle.utils as utils - -from paddle import fluid, profiler, static -from paddle.metric import Metric -from paddle.static import InputSpec -from paddle.fluid import core -from paddle.fluid import Variable -from paddle.fluid.layers.utils import flatten -from paddle.fluid.executor import global_scope, _to_name_str -from paddle.fluid.framework import Operator, _non_static_mode -from paddle.fluid.framework import _current_expected_place as _get_device -from paddle.fluid.dygraph.parallel import ParallelEnv -from paddle.distributed import fleet - -from .converter import Converter -from .helper import ProgramHelper -from .cluster import Cluster, get_default_cluster -from .planner_v2 import Planner -from .parallelizer_v2 import Parallelizer -from .dist_op import DistributedOperator -from .dist_saver import DistributedSaver -from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader -from .utils import to_list, get_logger, get_dist_attr -from .process_group import new_process_group, get_all_process_groups -from .dist_context import DistributedContext, get_default_distributed_context -from .strategy import Strategy -from .interface import CollectionNames, get_collection - - -class Engine: - """ - An Engine object can provide the full power of auto parallel to users. - With the help of it, users can easily obtain the abilities of the - distributed training and inference. It also support the dynamic graph and - static graph at the same time. - - Args: - model (paddle.nn.Layer, optional): The model is an instance of - paddle.nn.Layer. - loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer` - instance or any callable function taken the predicted values and - ground truth values as input. It can be None when there is no loss. - Default: None. - optimizer (Optimizer|None, optional): The optimizer need to be set in training - and should be None in eval and predict mode. Default: None. - metrics (Metric|list[Metric]|None, optional): If metrics is set, all - metrics will be calculated and output in train/eval mode. Default: None. - cluster (Cluster|None, optional): The cluster represents the topology information - about the used physical devices. Default: None. (Unused for now) - strategy (Strategy|None, optional): The strategy is used to configure the - parallelization and optimization behaviors. Default: None. - - Examples: - - .. 
code-block:: python - - import paddle - import paddle.vision.transforms as T - from paddle.distributed.fleet import auto - from paddle.vision.datasets import MNIST - - transform = T.Compose([ - T.Transpose(), - T.Normalize([127.5], [127.5]) - ]) - train_dataset = MNIST(mode='train', transform=transform) - valid_dataset = MNIST(mode='test', transform=transform) - - model = paddle.vision.models.LeNet() - loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam( - learning_rate=0.001, parameters=model.parameters()) - metrics = paddle.metric.Accuracy(topk=(1, 2)) - - engine = auto.Engine(model, loss, optimizer, metrics) - # fit - engine.fit(train_dataset, - epochs=2, - batch_size=64) - # evaluate - engine.evaluate(valid_dataset, - batch_size=64) - # predict - engine.predict(valid_dataset, - batch_size=64) - # save - engine.save("./my_model") - # load - engine.load("./my_model") - - """ - - def __init__(self, - model=None, - loss=None, - optimizer=None, - metrics=None, - cluster=None, - strategy=None): - - if model and not isinstance(model, - paddle.nn.Layer) and not callable(model): - raise TypeError( - "'model must be sub classes of `paddle.nn.Layer` or any callable function." - ) - self._model = model - - # if loss and not isinstance(loss, - # paddle.nn.Layer) and not callable(loss): - # raise TypeError( - # "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." - # ) - self._loss = loss - - if optimizer and not isinstance( - optimizer, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): - raise TypeError( - "'optimizer' must be object of class `paddle.optimizer.Optimizer`" - " or `paddle.fluid.optimizer.Optimizer`.") - self._optimizer = self._validate_opt(optimizer) - - metrics = metrics or [] - for metric in to_list(metrics): - assert isinstance(metric, Metric), \ - "{} is not sub class of Metric".format( - metric.__class__.__name__) - self._metrics = to_list(metrics) - - if cluster and not isinstance(cluster, Cluster): - raise TypeError( - "'cluster' must be the object or class `paddle.distributed.auto_parallel.Cluster`" - ) - self._cluster = cluster or get_default_cluster() - - if strategy and not isinstance(strategy, Strategy): - raise TypeError( - "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`" - ) - self._strategy = strategy or Strategy() - - if os.getenv("POD_NAME"): - print("Distribute training by paddle.distributed.launch", - flush=True) - fleet.init(is_collective=True) - - self._executor = None - self._cur_rank = paddle.distributed.get_rank() - self._nranks = paddle.distributed.get_world_size() - self._saver = DistributedSaver() - - self._logger = get_logger(logging.INFO) - - self._orig_main_prog = static.default_main_program() - self._orig_startup_prog = static.default_startup_program() - self._orig_dist_context = get_default_distributed_context() - self._dist_contexts = {} - self._serial_main_progs = {} - self._serial_startup_progs = {} - self._dist_main_progs = defaultdict(dict) # dist main programs - self._dist_startup_progs = defaultdict(dict) # dist startup programs - self._feed_vars = {} - self._fetch_vars = {} - self._planners = {} - self._has_prepared = {"train": False, "eval": False, "predict": False} - self._has_prepared_reader = { - "train": False, - "eval": False, - "predict": False - } - self._inputs_spec = [] - self._labels_spec = [] - self._inputs = [] - self._labels = [] - - self._skip_build = False - self._outside_dataloader = False - self._planned_mode = None - self._dygraph_mode 
= False - self._tuning = self._strategy.tuning - - def _prepare_data_spec(self, data, split, batch_size): - inputs_spec = [] - labels_spec = [] - if isinstance(data, paddle.io.IterableDataset): - if split is None: - inputs, labels = next(iter(data)) - else: - sample = next(iter(data)) - inputs = sample[:split] - labels = sample[split:] - elif isinstance(data, paddle.io.Dataset): - if split is None: - inputs, labels = data[0] - else: - sample = data[0] - inputs = sample[:split] - labels = sample[split:] - else: - raise ValueError( - "Data should be a Dataset or IterableDatset, but received {}.". - format(type(data).__name__)) - inputs = to_list(inputs) - labels = to_list(labels) - - num_shards = self._strategy.dataset.num_shards - - def _adjust_item_spec(num_shards, spec): - if num_shards > 1 and len(spec.shape) > 1: - spec.shape[0] = spec.shape[0] * num_shards - - def _infer_item_spec(item, name, batch_size, specs): - if isinstance(item, np.ndarray): - spec = InputSpec.from_numpy(item, name) - if batch_size is None: - _adjust_item_spec(num_shards, spec) - specs.append(spec) - else: - specs.append(spec.batch(batch_size)) - elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)): - _adjust_item_spec(num_shards, spec) - spec = InputSpec.from_tensor(item, name) - if batch_size is None: - specs.append(spec) - else: - specs.append(spec.batch(batch_size)) - else: - specs.append(InputSpec([batch_size], type(item), name)) - - if inputs is not None: - for i, item in enumerate(inputs): - assert item is not None, "Receive None input." - name = "input" + str(i) - _infer_item_spec(item, name, batch_size, inputs_spec) - if labels is not None: - for i, item in enumerate(labels): - assert item is not None, "Receive None input." - name = "label" + str(i) - _infer_item_spec(item, name, batch_size, labels_spec) - - inputs_spec = self._validate_spec(inputs_spec) - labels_spec = self._validate_spec(labels_spec) - return inputs_spec, labels_spec - - def _prepare_data_tensor(self, - inputs_spec, - labels_spec, - inputs=None, - labels=None): - if _non_static_mode() or self._dygraph_mode: - return None, None - inputs_spec = inputs_spec if inputs_spec else [] - labels_spec = labels_spec if labels_spec else [] - if inputs_spec: - assert isinstance(inputs_spec, list), \ - "inputs should be list, but received {}".format(type(inputs_spec)) - if inputs is None: - inputs = [s._create_feed_layer() for s in inputs_spec] - else: - assert isinstance(inputs, list), \ - "inputs should be list, but received {}".format(type(inputs)) - for input_spec, input in zip(inputs_spec, inputs): - if input_spec.shape != input.shape: - input.desc.set_shape(input_spec.shape) - if labels_spec: - assert isinstance(labels_spec, list), \ - "labels should be list, but received {}".format(type(labels_spec)) - if labels is None: - labels = [s._create_feed_layer() for s in labels_spec] - else: - assert isinstance(labels, list), \ - "labels should be list, but received {}".format(type(labels)) - for label_spec, label in zip(labels_spec, labels): - if label_spec.shape != label.shape: - label.desc.set_shape(label_spec.shape) - return inputs, labels - - def _prepare_reader(self): - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] - dist_context = self._dist_contexts[self._mode] - dist_main_block = dist_main_prog.global_block() - - # NOTE: this list may be changed if Paddle changes the existing rules. 
- related_reader_ops = [ - "create_py_reader", "create_double_buffer_reader", "read" - ] - # remove the first three ops if multiple run fit/evaluate/predict - if dist_main_block.ops[0].type == 'create_py_reader': - for i in range(len(related_reader_ops)): - if dist_main_block.ops[0].type in related_reader_ops: - dist_main_block._remove_op(0, sync=False) - dist_main_block._sync_with_cpp() - # Step 1: find the reader ops - reader_op_indices = [] - for idx, op in enumerate(dist_main_block.ops): - if op.type in related_reader_ops: - reader_op_indices.append(idx) - # Step 2: insert the new reader ops to cpp - new_reader_ops = [] - for idx in reversed(reader_op_indices): - new_op_desc = dist_main_block.desc._prepend_op() - new_op_desc.copy_from(dist_main_block.ops[idx].desc) - new_op = Operator(dist_main_block, - new_op_desc, - type=new_op_desc.type()) - new_reader_ops.append(new_op) - dist_op = DistributedOperator(new_op) - dist_context.add_dist_op_for_program(dist_op) - # Step 3: insert the new reader ops to python - for new_op in new_reader_ops: - dist_main_block.ops.insert(0, new_op) - for i in range(len(reader_op_indices)): - reader_op_indices[i] += len(reader_op_indices) - # Step 4: remove the old reader ops from python and cpp - for idx in reversed(reader_op_indices): - op = dist_main_block.ops.pop(idx) - dist_main_block.desc._remove_op(idx, idx + 1) - dist_main_block._sync_with_cpp() - self._has_prepared_reader[self._mode] = True - - def _prepare_feed(self, data, user_feeds, mode): - feeds = {} - if data is not None: - if isinstance(data, (list, tuple)): - if len(data) == 1 and isinstance(data[0], dict): - for name, data in data[0].items(): - feeds[name] = data - else: - raise ValueError("Unsupported data {}".format(data)) - elif isinstance(data, dict): - for name, data in data.items(): - feeds[name] = data - else: - raise ValueError("Unsupported data {}".format(data)) - if user_feeds is not None: - assert isinstance(user_feeds, dict), \ - "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) - for name, data in user_feeds.items(): - feeds[name] = data - return feeds - - def _prepare_fetch(self, user_fetches, mode): - if user_fetches is not None: - assert isinstance(user_fetches, list), \ - "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) - fetch_names = [] - fetch_indices = [] - - def _process_fetch_group(group_name, var_list): - group_indices = [] - for var in var_list: - # Remove duplicate var_names - if self._is_local_var(var): - var_name = _to_name_str(var) - if var_name not in fetch_names: - fetch_names.append(var_name) - group_indices.append(fetch_names.index(var_name)) - if not group_indices: - fetch_names.append([]) - fetch_indices.append(group_indices) - - if mode != "predict": - _process_fetch_group("loss", self._fetch_vars[mode]["loss"]) - if mode != "predict": - metrics = self._fetch_vars[mode]["metrics"] - for i, var_list in enumerate(metrics): - _process_fetch_group("metrics_" + str(i), var_list) - if mode == "predict": - _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"]) - user_fetches_collection = [ - item[1] for item in get_collection(CollectionNames.FETCHES) - ] - var_list = (user_fetches_collection or []) + (user_fetches or []) - _process_fetch_group("fetches", var_list) - return fetch_names, fetch_indices - - def _prepare_logger(self, - outs, - epoch=None, - step=None, - lr=None, - fetch_names=None, - fetch_indices=None, - profiler_log="", - mode=None): - logs = "[{}] ".format(mode) - if epoch 
is not None: - logs += "epoch: {:d} ".format(epoch) - if step is not None: - logs += "step: {:d} ".format(step) - if lr is not None: - logs += "lr: {:5e} ".format(lr) - group_idx = 0 - # logging loss - if mode != "predict": - loss_indices = fetch_indices[group_idx] - for idx in loss_indices: - logs += "loss: {:8f} ".format(outs[idx][0]) - group_idx += 1 - # logging metrics - if mode != "predict": - metric_vars = self._fetch_vars[mode]["metrics"] - if metric_vars: - for metric in self._metrics: - metrics_indices = fetch_indices[group_idx] - metric_out = [] - for idx in metrics_indices: - metric_out.append(outs[idx]) - if metric_out: - metric.update(*metric_out) - results = metric.accumulate() - for i, res in enumerate(to_list(results)): - logs += "{}: {:8f} ".format(metric.name()[i], res) - group_idx += 1 - # Skip logging outputs - if mode == "predict": - group_idx += 1 - # logging user fetches - fetches_logging = get_collection(CollectionNames.LOGGING) - for name, var in fetches_logging: - if var.name in fetch_names: - idx = fetch_names.index(var.name) - # Use the user defined name for logging - logs += "{}: {} ".format(name, outs[idx]) - logs += profiler_log - self._logger.info(logs) - - def _prepare_history(self, outs, fetch_indices=None, mode=None): - history = {} - group_idx = 0 - # store loss - if mode != "predict": - loss_indices = fetch_indices[group_idx] - loss_values = [] - for idx in loss_indices: - loss_values.append(outs[idx][0]) - history["loss"] = loss_values - group_idx += 1 - # store metrics - if mode != "predict": - metric_vars = self._fetch_vars[mode]["metrics"] - if metric_vars: - for metric in self._metrics: - metrics_indices = fetch_indices[group_idx] - metric_out = [] - for idx in metrics_indices: - metric_out.append(outs[idx]) - if metric_out: - metric.update(*metric_out) - results = metric.accumulate() - history[tuple(metric.name())] = to_list(results) - group_idx += 1 - # store outputs - if mode == "predict": - outputs_indices = fetch_indices[group_idx] - outputs_values = [] - for idx in outputs_indices: - outputs_values.append(outs[idx]) - history["outputs"] = outputs_values - group_idx += 1 - # store user fetches - fetches_indices = fetch_indices[group_idx] - fetches_values = [] - for idx in fetches_indices: - fetches_values.append(outs[idx]) - history["fetches"] = fetches_values - return history - - def _prepare_program(self, mode): - # Do the build process - self._build(mode) - # Do the planning process - self._plan(mode) - # Do the parallel process - self._parallel(mode) - # Init comm and startup program - self._initialize(mode) - self._has_prepared[mode] = True - - def _build(self, mode): - if _non_static_mode() or self._dygraph_mode: - paddle.disable_static() - self._dygraph_mode = True - self._logger.info("Building model with 'to_static' method.") - - inputs_spec = self._inputs_spec - labels_spec = self._labels_spec if self._labels_spec else [] - self.program_helper = ProgramHelper(self._model, self._loss, - self._metrics, inputs_spec, - labels_spec) - # build forward main program - self.program_helper.build_program(mode) - - self.concrete_program = self.program_helper.concrete_program - serial_main_prog = self.program_helper.main_program - serial_startup_prog = self.program_helper.startup_program - - inputs = self.program_helper.input_vars - outputs = self.program_helper.output_vars - labels = self.program_helper.label_vars - losses = self.program_helper.loss_vars - metrics = self.program_helper.metric_vars - - self._inputs = inputs - self._labels = labels 
- - paddle.enable_static() - else: - # build program in static mode - serial_main_prog = self._serial_main_progs.get(mode, None) - if serial_main_prog is not None: - return - - outputs = [] - losses = [] - metrics = [] - inputs = self._inputs if self._inputs else [] - labels = self._labels if self._labels else [] - serial_main_prog = self._orig_main_prog.clone() - serial_startup_prog = self._orig_startup_prog.clone() - if not self._skip_build: - with static.program_guard(serial_main_prog, serial_startup_prog), \ - utils.unique_name.guard(): - outputs = to_list(self._model(*inputs)) - if mode != "predict" and self._loss: - losses = to_list(self._loss(*(outputs + labels))) - - if mode != "predict" and (outputs or labels): - for metric in self._metrics: - metrics.append( - to_list(metric.compute(*(outputs + labels)))) - else: - losses = to_list(self._loss) - - default_ctx = get_default_distributed_context() - if not default_ctx.has_annotation: - # We build the world process group because the data parallel - # needs all ranks by default. - new_process_group(list(range(self._nranks))) - default_ctx.data_parallel = True - - feed_vars = {"inputs": inputs, "labels": labels} - - fetch_vars = { - "outputs": flatten(outputs), - "loss": losses, - "metrics": metrics - } - - if mode != "train": - serial_main_prog = serial_main_prog.clone(for_test=True) - - self._set_recompute_ckpts() - self._dist_contexts[mode] = DistributedContext( - serial_main_prog, serial_startup_prog, self._optimizer, losses, - feed_vars, fetch_vars, self._cluster, self._strategy) - self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale - - def _optimization_tuning(self, mode, dataset, batch_size): - if not self._tuning.enable: - raise ValueError("Please set `tuning.enable=True`.") - - assert mode == "train" - # Do the build process - self._build(mode) - # Do the planning process - self._plan(mode) - - dataset.dp_world_size = self._dp_world_sizes - dataset.dp_rank = self._dp_ranks - - from .tuner.optimization_tuner import OptimizationTuner - self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(), - self._dist_contexts[mode], - dataset, - self._inputs_spec, - self._labels_spec, - batch_size=batch_size, - rank=self._cur_rank) - - self._optimization_tuner.tune() - - if self._tuning.run_after_tuning: - # update the strategy - self._dist_contexts[ - mode]._strategy = self._optimization_tuner.get_best_config() - - def _plan(self, mode): - if self._planned_mode is None: - self._planned_mode = mode - else: - self._init_dist_context(mode) - - self._planners[mode] = Planner(mode, self._dist_contexts[mode]) - self._planners[mode].plan() - - # infer data parallel info - inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"] - labels_var = self._dist_contexts[mode].serial_feed_vars["labels"] - block = self._dist_contexts[mode].serial_main_program.global_block() - # TODO: check this feed_list - feed_list = [] - for var in inputs_var + labels_var: - if var.name in block.vars: - feed_list.append(block.vars[var.name]) - - self._dp_world_sizes = [] - self._dp_ranks = [] - for feed_var in feed_list: - dp_world_size, dp_rank = self._get_input_split_info( - feed_var, self._dist_contexts[mode]) - self._dp_world_sizes.append(dp_world_size) - self._dp_ranks.append(dp_rank) - - def _parallel(self, mode, all_ranks=False): - # Parallelize program based on the planner's results - # For now, the completer has to be passed to the planner, - # because we may use it to complete the annotation of the backwarkward and 
update. - parallelizer = Parallelizer(mode, self._planners[mode].completer, - self._dist_contexts[mode]) - if not all_ranks: - parallelizer.parallel(self._cur_rank) - else: - parallelizer.parallel_all() - - def _init_dist_context(self, mode): - # Init dist_context['mode'] with the first planned dist_context - # to guarantee that train/eval/predict mode have same parallel strategy - dist_context = self._dist_contexts[mode] - origin_main_prog = dist_context._original_serial_main_program - ref_mode = self._planned_mode - ref_dist_context = self._dist_contexts[ref_mode] - ref_origin_main_prog = ref_dist_context._original_serial_main_program - ref_blocks = ref_origin_main_prog.blocks - for ib, block in enumerate(origin_main_prog.blocks): - for iop, op in enumerate(block.ops): - ref_op = ref_blocks[ib].ops[iop] - assert op.type == ref_op.type, \ - "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type) - ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program( - ref_op) - dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) - - def _initialize(self, mode): - # Get the current content from the distributed context - self._serial_main_progs[mode] = self._dist_contexts[ - mode].serial_main_program - self._serial_startup_progs[mode] = self._dist_contexts[ - mode].serial_startup_program - self._dist_main_progs[mode] = self._dist_contexts[ - mode].dist_main_programs - self._dist_startup_progs[mode] = self._dist_contexts[ - mode].dist_startup_programs - self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars - self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars - self._lr_optimizer = self._dist_contexts[mode]._lr_optimizer - - if self._nranks > 1: - # Traverse different rank programs and traverse each op of them, - # instantiate communication by process_mapping. 
- all_process_groups = get_all_process_groups() - - # NOTE: add the comm init control in the future for auto search - for process_group in all_process_groups: - if self._cur_rank not in process_group.ranks: - continue - process_group.instantiate() - - place = _get_device() - if isinstance(place, fluid.CUDAPlace): - place = fluid.CUDAPlace(ParallelEnv().dev_id) - - if self._strategy.seed: - paddle.seed(self._strategy.seed + self._dp_ranks[0]) - np.random.seed(self._strategy.seed + self._dp_ranks[0]) - random.seed(self._strategy.seed + self._dp_ranks[0]) - - if self._dygraph_mode: - dist_context = self._dist_contexts[mode] - dist_main_program = self._dist_main_progs[mode][self._cur_rank] - self.program_helper.init(dist_main_program, place, dist_context) - - if self._executor is None: - self._executor = paddle.static.Executor(place) - uninitialized = [] - dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] - for var in dist_startup_prog.list_vars(): - scope_var = global_scope().find_var(var.name) - if scope_var and scope_var.get_tensor()._is_initialized(): - continue - uninitialized.append(var) - if uninitialized: - prune_startup_prog = dist_startup_prog._prune(uninitialized) - self._executor.run(prune_startup_prog) - - if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"): - self._set_state_dict(mode, self._strict, self._state_dict, - self._dist_attr) - - if self._strategy.reinit: - self._logger.info("NOTE: parameters wiil be re-initialized.") - dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] - self._executor.run(dist_startup_prog) - - def fit(self, - train_data, - train_sample_split=None, - batch_size=1, - epochs=1, - steps_per_epoch=None, - valid_data=None, - valid_sample_split=None, - valid_freq=1, - valid_steps=None, - collate_fn=None, - callbacks=None): - """ - Trains the model for a fixed number of epochs. If `valid_data` is set, - evaluation will be done at the end of each epoch. - - Args: - train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. - train_sample_split (int, optional): Each sample of the train dataset is assumed - to be a (input, label) pair by default and has two items. If each sample has - more than two items, train_sample_split specifies how to split these items into - input and label. The items before it are input and the left are label. Default: None. - batch_size (int, optional): The batch size of train_data and valid_data if provided. - The user's data will be used directly without batching if set to None. Default: 1. - epochs (int, optional): The number of epochs to train the model. Default: 1. - steps_per_epoch (int, optional): The total number of steps (batches of samples) - is executed in one epoch before stating the next one. If None, it is equal to - the number samples in your dataset divided by the batch size. Default: None. - valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for - evaluation at the end of epoch. No evaluation will be done if set to None. - Default: None. (Unsupported for now) - valid_freq (int, optional): Only relevant if valid_data is provided. This specifies - how many training epochs before a new evaluation is performed. Default: 1. - valid_sample_split (int, optional): Only relevant if valid_data is provided. - Each sample of the valid dataset is assumed to be a (input, label) pair - by default and has two items. If each sample has more than two items, - valid_sample_split specifies how to split these items into input and label. 
- The items before it are input and the left are label. Default: None. - valid_steps (int, optional): Only relevant if valid_data is provided. - It is the total number of steps (batches of samples) to draw before - stopping validation at the end of every epoch. If None, validation will run until the - `valid_data` dataset is exhausted. The validation will start from the - beginning of the dataset at each epoch. Default: None. - collate_fn(callable, optional): function to generate mini-batch data by merging - the sample list, None for only stack each fields of sample in axis - 0. Default None. - callbacks (Callback|None, optional): A list of `Callback` instances to apply - during training. Default: None. (Unused for now) - - Returns: - None - - Examples: - - .. code-block:: python - - import paddle - import paddle.vision.transforms as T - from paddle.distributed.fleet import auto - from paddle.vision.datasets import MNIST - - transform = T.Compose([ - T.Transpose(), - T.Normalize([127.5], [127.5]) - ]) - train_dataset = MNIST(mode='train', transform=transform) - - model = paddle.vision.models.LeNet() - loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam( - learning_rate=0.001, parameters=model.parameters()) - metrics = paddle.metric.Accuracy(topk=(1, 2)) - - engine = auto.Engine(model, loss, optimizer, metrics) - engine.fit(train_dataset, - epochs=2, - batch_size=64) - """ - self._mode = 'train' - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - train_data, train_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - train_dataloader = self._prepare_dataloader_from_generator( - dataset=train_data, - capacity=70, - # use_double_buffer=use_double_buffer, - iterable=False, - # return_list=return_list, - # use_multiprocess=use_multiprocess, - # drop_last=drop_last, - batch_size=batch_size, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - collate_fn=collate_fn) - fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) - lr_scheduler = self._get_lr_scheduler(self.main_program) - - with profiler.Profiler(timer_only=True) as prof: - for epoch in range(epochs): - for step, data in enumerate(train_dataloader): - self._strategy.return_numpy = True - fetch_names = [fetch_names[0]] - fetch_names.append('labels') - print_param(self.main_program) - try: - outs, lables = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) - print("lables: {}".format(lables[:20])) - print("outs: {}".format(outs)) - - except core.EOFException: - break - if lr_scheduler and step % self._k_steps == 0: - lr_scheduler.step() - lr = self._get_lr(self._lr_optimizer) - - prof.step() - - # self._prepare_logger(outs, epoch, step, lr, - # fetch_names, fetch_indices, - # prof.step_info(), self._mode) - # history = self._prepare_history(outs, fetch_indices, - # self._mode) - history = None - - if valid_data and epoch % valid_freq == 0: - self.evaluate(valid_data, valid_sample_split, batch_size, - valid_steps, collate_fn, callbacks) - self._switch_mode("train") - else: - self._reset_metrics() - return history - - def evaluate(self, - valid_data, - valid_sample_split=None, - batch_size=1, - steps=None, - collate_fn=None, - callbacks=None): - """ - Evaluate the loss and metrics of the model on 
evaluation data. - - Args: - valid_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. - valid_sample_split (int, optional): Each sample of the eval dataset is assumed - to be a (input, label) pair by default and has two items. If each sample has - more than two items, valid_sample_split specifies how to split these items into - input and label. The items before it are input and the left are label. Default: None. - batch_size (int, optional): The batch size of valid_data. The user's data will - be used directly without batching if set to None. Default: 1. - steps (int, optional): It is the total number of steps (batches of samples) to draw before - stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted. - The evaluation will start from the beginning of the dataset in each run. Default: None. - collate_fn(callable, optional): function to generate mini-batch data by merging - the sample list, None for only stack each fields of sample in axis - 0. Default None. - callbacks (Callback|None, optional): A list of `Callback` instances to apply - during evaluating. Default: None. (Unused for now) - - Returns: - None - - Examples: - - .. code-block:: python - - import paddle - import paddle.vision.transforms as T - from paddle.distributed.fleet import auto - from paddle.vision.datasets import MNIST - - transform = T.Compose([ - T.Transpose(), - T.Normalize([127.5], [127.5]) - ]) - valid_dataset = MNIST(mode='test', transform=transform) - - model = paddle.vision.models.LeNet() - loss = paddle.nn.CrossEntropyLoss() - metrics = paddle.metric.Accuracy(topk=(1, 2)) - - engine = auto.Engine(model, loss, metrics=metrics) - engine.evaluate(valid_dataset, batch_size=64) - - """ - self._mode = 'eval' - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - valid_data, valid_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - assert self._mode in self._dist_main_progs, \ - "eval model is not ready, please call `engine._prepare_program('eval')` first." - valid_dataloader = self._prepare_dataloader_from_generator( - dataset=valid_data, - # feed_list=feed_list, - capacity=70, - # use_double_buffer=use_double_buffer, - iterable=False, - # return_list=return_list, - # use_multiprocess=use_multiprocess, - # drop_last=drop_last, - # places=places, - batch_size=batch_size, - # epochs=epochs, - steps_per_epoch=steps, - collate_fn=collate_fn) - fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) - - for step, _ in enumerate(valid_dataloader): - try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) - except core.EOFException: - break - self._prepare_logger(outs, None, step, None, fetch_names, - fetch_indices, "", self._mode) - history = self._prepare_history(outs, fetch_indices, self._mode) - self._reset_metrics() - return history - - def predict(self, - test_data, - test_sample_split=None, - batch_size=1, - steps=None, - collate_fn=None, - callbacks=None): - """ - Compute the output predictions on testing data. - - Args: - test_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. 
- test_sample_split (int, optional): Each sample of the test dataset is assumed - to be a (input, label) pair by default and has two items. If each sample has - more than two items, test_sample_split specifies how to split these items into - input and label. The items before it are input and the left are label. Default: None. - batch_size (int, optional): The batch size of test_data. The user's data will - be used directly without batching if set to None. Default: 1. - steps (int, optional): It is the total number of steps (batches of samples) to draw before - stopping predict. If None, predict will run until the `test_data` dataset is exhausted. - The predict will start from the beginning of the dataset in each run. Default: None. - collate_fn(callable, optional): function to generate mini-batch data by merging - the sample list, None for only stack each fields of sample in axis - 0. Default None. - callbacks (Callback|None, optional): A list of `Callback` instances to apply - during testing. Default: None. (Unused for now) - - Returns: - None - - Examples: - - .. code-block:: python - - import paddle - import paddle.vision.transforms as T - from paddle.distributed.fleet import auto - from paddle.vision.datasets import MNIST - - transform = T.Compose([ - T.Transpose(), - T.Normalize([127.5], [127.5]) - ]) - valid_dataset = MNIST(mode='test', transform=transform) - - model = paddle.vision.models.LeNet() - - engine = auto.Engine(model) - engine.predict(valid_dataset, batch_size=64) - """ - self._mode = 'predict' - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - test_data, test_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - assert self._mode in self._dist_main_progs, \ - "predict model is not ready, please call `engine._prepare_program('predict')` first." 
- test_dataloader = self._prepare_dataloader_from_generator( - dataset=test_data, - # feed_list=feed_list, - capacity=70, - # use_double_buffer=use_double_buffer, - iterable=False, - # return_list=return_list, - # use_multiprocess=use_multiprocess, - # drop_last=drop_last, - # places=places, - batch_size=batch_size, - # epochs=epochs, - steps_per_epoch=steps, - collate_fn=collate_fn) - fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) - - for step, _ in enumerate(test_dataloader): - try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) - except core.EOFException: - break - self._prepare_logger(outs, None, step, None, fetch_names, - fetch_indices, "", self._mode) - history = self._prepare_history(outs, fetch_indices, self._mode) - - return history - - def dataloader( - self, - dataset, - # return_list=True, - batch_size=1, - shuffle=False, - drop_last=False, - collate_fn=None, - num_workers=0, - use_buffer_reader=True, - use_shared_memory=True, - timeout=0, - worker_init_fn=None, - epochs=1, - steps_per_epoch=None, - sample_split=1, - mode=None): - if mode is not None: - self.to_mode(mode) - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - dataset, sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - dataloader = self._prepare_dataloader( - dataset, - return_list=False, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn, - num_workers=num_workers, - use_buffer_reader=use_buffer_reader, - use_shared_memory=use_shared_memory, - timeout=timeout, - worker_init_fn=worker_init_fn, - epochs=epochs, - steps_per_epoch=steps_per_epoch) - return dataloader - - def dataloader_from_generator( - self, - dataset, - capacity=70, - use_double_buffer=True, - iterable=True, - # return_list=False, - use_multiprocess=False, - drop_last=True, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None, - sample_split=1, - mode=None): - if mode is not None: - self.to_mode(mode) - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - dataset, sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - dataloader = self._prepare_dataloader_from_generator( - dataset=dataset, - # feed_list=feed_list, - capacity=capacity, - use_double_buffer=use_double_buffer, - iterable=iterable, - return_list=False, - use_multiprocess=use_multiprocess, - drop_last=drop_last, - # places=places, - batch_size=batch_size, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - collate_fn=collate_fn) - return dataloader - - def prepare(self, - inputs_spec=None, - labels_spec=None, - inputs=None, - labels=None, - main_program=None, - startup_program=None, - mode=None): - if mode is not None: - self.to_mode(mode) - if inputs or labels: - self._skip_build = True - self._inputs_spec = inputs_spec - self._labels_spec = labels_spec - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec, inputs, labels) - self._orig_main_prog = main_program - if self._orig_main_prog is None: - self._orig_main_prog = 
static.default_main_program() - self._orig_startup_prog = startup_program - if self._orig_startup_prog is None: - self._orig_startup_prog = static.default_startup_program() - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - elif inputs_spec or labels_spec: - self._inputs_spec = inputs_spec - self._labels_spec = labels_spec - self._outside_dataloader = True - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - self._orig_main_prog = main_program - if self._orig_main_prog is None: - self._orig_main_prog = static.default_main_program() - self._orig_startup_prog = startup_program - if self._orig_startup_prog is None: - self._orig_startup_prog = static.default_startup_program() - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) - else: - assert self._inputs_spec and self._labels_spec, \ - "Please call the dataloader(...) before calling prepare(...)" - - def run( - self, - data=None, - # program=None, - feed=None, - fetch_list=None, - # feed_var_name='feed', - # fetch_var_name='fetch', - # scope=None, - # return_numpy=True, - # use_program_cache=False, - # return_merged=True, - # use_prune=False, - mode=None): - if mode is not None: - self.to_mode(mode) - feed_dict = self._prepare_feed(data, feed, self._mode) - fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode) - if self._outside_dataloader and not self._has_prepared_reader[ - self._mode]: - self._prepare_reader() - outs = self._executor.run(self.main_program, - feed=feed_dict, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) - self._prepare_logger(outs, None, None, None, fetch_names, fetch_indices, - "", self._mode) - history = self._prepare_history(outs, fetch_indices, self._mode) - return history - - def _prepare_dataloader(self, - dataset, - return_list=True, - batch_size=1, - shuffle=False, - drop_last=False, - collate_fn=None, - num_workers=0, - use_buffer_reader=True, - use_shared_memory=True, - timeout=0, - worker_init_fn=None, - epochs=1, - steps_per_epoch=None): - - if self._strategy.gradient_merge and batch_size is not None: - assert batch_size % self._k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) - batch_size //= self._k_steps - - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] - dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] - dist_context = self._dist_contexts[self._mode] - dist_main_block = dist_main_prog.global_block() - - # NOTE: Get feed_list, then insert dataloader op with sharded var shape. - # Cause predict_program does not contain labels var, - # then we will add labels var from serial_program to dist_program, - # that maintains the length of feed_list equal to the length of dataset's values. 
- inputs_var = self._feed_vars[self._mode]["inputs"] - labels_var = self._feed_vars[self._mode]["labels"] - feed_list = [] - for var in inputs_var + labels_var: - if var.name in dist_main_block.vars: - feed_list.append(dist_main_block.vars[var.name]) - else: - copy_var = dist_main_block._clone_variable(var, var.persistable) - copy_var.desc.set_original_id(var.desc.original_id()) - feed_list.append(copy_var) - - # insert read op at the end of program - places = paddle.static.cuda_places() - with static.program_guard(dist_main_prog, dist_startup_prog): - dataloader = DistributedDataLoader( - dataset, - feed_list=feed_list, - places=places, - return_list=return_list, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=collate_fn, - num_workers=num_workers, - use_buffer_reader=use_buffer_reader, - use_shared_memory=use_shared_memory, - timeout=timeout, - worker_init_fn=worker_init_fn, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - split_data=self._strategy.split_data, - data_parallel_world_size=self._dp_world_sizes, - data_parallel_rank=self._dp_ranks) - - return dataloader - - def _prepare_dataloader_from_generator(self, - dataset, - capacity=None, - use_double_buffer=True, - iterable=True, - return_list=False, - use_multiprocess=False, - drop_last=True, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None): - - if self._strategy.gradient_merge and batch_size is not None: - assert batch_size % self._k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) - batch_size //= self._k_steps - - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] - dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] - dist_context = self._dist_contexts[self._mode] - dist_main_block = dist_main_prog.global_block() - - # NOTE: Get feed_list, then insert dataloader op with sharded var shape. - # Cause predict_program does not contain labels var, - # then we will add labels var from serial_program to dist_program, - # that maintains the length of feed_list equal to the length of dataset's values. 
- inputs_var = self._feed_vars[self._mode]["inputs"] - labels_var = self._feed_vars[self._mode]["labels"] - feed_list = [] - for var in inputs_var + labels_var: - if var.name in dist_main_block.vars: - feed_list.append(dist_main_block.vars[var.name]) - else: - copy_var = dist_main_block._clone_variable(var, var.persistable) - copy_var.desc.set_original_id(var.desc.original_id()) - feed_list.append(copy_var) - - # # remove the first three ops if multi run fit/evaluate/predict - # self._op_size = len(dist_main_block.ops) - # if dist_main_block.ops[0].type == 'create_py_reader': - # op_size -= 3 - # for _ in range(3): - # dist_main_block._remove_op(0, sync=False) - - places = paddle.static.cuda_places() - with static.program_guard(dist_main_prog, dist_startup_prog): - dataloader = DistributedDataLoaderFromGenerator( - dataset=dataset, - feed_list=feed_list, - capacity=capacity, - use_double_buffer=use_double_buffer, - iterable=iterable, - return_list=return_list, - use_multiprocess=use_multiprocess, - drop_last=drop_last, - places=places, - batch_size=batch_size, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - collate_fn=collate_fn, - split_data=self._strategy.split_data, - data_parallel_world_size=self._dp_world_sizes, - data_parallel_rank=self._dp_ranks) - self._prepare_reader() - # # move read op from the end of program to the start of program - # new_op_size = len(dist_main_block.ops) - # for _ in range(new_op_size - 1, op_size - 1, -1): - # op = dist_main_block.ops[new_op_size - 1] - # new_op_desc = dist_main_block.desc._prepend_op() - # new_op_desc.copy_from(op.desc) - # new_op = Operator(dist_main_block, - # new_op_desc, - # type=new_op_desc.type()) - # dist_main_block.ops.insert(0, new_op) - # dist_op = DistributedOperator(new_op) - # dist_context.add_dist_op_for_program(dist_op) - # for _ in range(new_op_size - op_size): - # dist_main_block._remove_op(new_op_size, sync=False) - # dist_main_block._sync_with_cpp() - return dataloader - - def _tune(self, tune_data, tune_sample_split=None, batch_size=1): - self._mode = 'train' - self._inputs_spec, self._labels_spec = self._prepare_data_spec( - tune_data, tune_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - self._optimization_tuning(self._mode, tune_data, batch_size) - - def _validate_spec(self, specs): - specs = to_list(specs) - self._k_steps = self._strategy.gradient_merge.k_steps - if specs is not None: - for i, spec in enumerate(specs): - assert isinstance(spec, InputSpec) - if spec.name is None: - raise ValueError( - "Requires Input[{}].name != None, but receive `None` with {}." 
- .format(i, spec)) - if self._k_steps > 1: - shape = list(spec.shape) - assert shape[0] % self._k_steps == 0, \ - "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) - shape[0] //= self._k_steps - spec.shape = shape - return specs - - def _is_local_var(self, var): - var_name = _to_name_str(var) - return var_name in self.main_program.global_block().vars - - def _get_input_split_info(self, var, dist_context): - # deduce how the input data is split among the cluster - from .utils import _get_comm_group, _get_corresponding_rank - - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) - process_mesh = tensor_dist_attr.process_mesh - dims_mapping = tensor_dist_attr.dims_mapping - - if self._cur_rank not in process_mesh.processes: - rank_id = _get_corresponding_rank(dist_context, process_mesh, - self._cur_rank) - else: - rank_id = self._cur_rank - - batch_size_axis = dims_mapping[0] - if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - batch_size_axis, rank_id) - return len(group_ranks), group_ranks.index(rank_id) - - return 1, 0 - - def _set_recompute_ckpts(self): - # NOTE hack to enable recompute in engine api for GPT-3 - # TODO support more PaddleNLP/CV models here - - recompute = self._strategy.recompute - - # extract ckpts by specific model - if isinstance(self._model, paddle.nn.Layer): - if hasattr(self._model, - "gpt") and self._model.__class__.__name__ in [ - 'GPTForPretraining', 'GPTForPretrainingAuto' - ]: - exact_ckpts = self._model.gpt.checkpoints - else: - exact_ckpts = recompute.checkpoints - else: - exact_ckpts = recompute.checkpoints - - # modify strategy - if recompute.enable: - recompute.checkpoints = exact_ckpts[:] - logs = { - 'Model Class': self._model.__class__.__name__, - 'Applied Recompute ckpts': exact_ckpts - } - self._logger.info(logs) - - def _validate_opt(self, optimizer): - if optimizer is not None: - optimizer._parameter_list = None - optimizer._param_groups = None - return optimizer - - def _reset_metrics(self): - for metric in self._metrics: - metric.reset() - - def _switch_mode(self, mode): - self.to_mode(mode) - self._initialize(mode) - - def to_mode(self, mode): - assert mode in ["train", "eval", "predict"], \ - "mode {} should be one of ['train', 'eval', 'predict']".format(mode) - self._mode = mode - - def _set_state_dict(self, mode, strict, state_dict, dist_attr): - program = self._dist_main_progs[mode][self._cur_rank] - dist_context = self._dist_contexts[mode] - cur_dist_attr = get_dist_attr(program, dist_context) - converter = Converter(state_dict, dist_attr, cur_dist_attr) - state_dict = converter.convert(strict=strict) - program.set_state_dict(state_dict) - - def save(self, path, training=True): - """ - Saves the model, parameters, optimizer state to path. - If `training` is set to False, only inference model will be saved. - - Args: - path (str): The file prefix to save model. The format - is 'dirname/file_prefix' or 'file_prefix'. if empty str. - A exception will be raised. - training (bool, optional): Whether to save for training. If not, save - for inference only. If `training` is set to True, the optimizer state - will be saved. Otherwise, only the model and parameters are saved. - This function will silently overwrite existing file at the target - location. Default: True. - - Returns: - None - - Examples: - - .. 
-                import paddle
-                import paddle.vision.transforms as T
-                from paddle.distributed.fleet import auto
-                from paddle.vision.datasets import MNIST
-
-                transform = T.Compose([
-                    T.Transpose(),
-                    T.Normalize([127.5], [127.5])
-                ])
-                train_dataset = MNIST(mode='train', transform=transform)
-
-                model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss()
-                optimizer = paddle.optimizer.Adam(
-                    learning_rate=0.001, parameters=model.parameters())
-                metrics = paddle.metric.Accuracy(topk=(1, 2))
-
-                engine = auto.Engine(model, loss, optimizer, metrics)
-                engine.fit(train_dataset,
-                           epochs=1,
-                           batch_size=64)
-                engine.save("./my_model")
-
-        """
-        if training:
-            assert 'train' in self._serial_main_progs, \
-                "training model is not ready, please call `engine._prepare_program('train')` first."
-            serial_program = self._serial_main_progs["train"]
-            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
-            dist_context = self._dist_contexts["train"]
-            self._saver.save(path,
-                             serial_program=serial_program,
-                             dist_main_program=dist_main_prog,
-                             dist_context=dist_context)
-        else:
-            mode = "predict"
-            feed_vars = self._feed_vars[mode]['inputs']
-            fetch_vars = self._fetch_vars[mode]['outputs']
-            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
-            self._saver.save_inference_model(path,
-                                             feed_vars,
-                                             fetch_vars,
-                                             self._executor,
-                                             program=dist_main_prog)
-
-    def load(self, path, strict=True, load_optimizer=True):
-        """
-        Load the stored model, parameters and optimizer states.
-
-        Args:
-            path (str): The prefix of files storing the model states and
-                optimizer states.
-            strict (bool, optional): Whether to skip loading mismatched
-                parameters or to raise an error when a mismatch happens (a
-                parameter is not found in the stored model states, or its
-                shape does not match). Default: True.
-            load_optimizer (bool, optional): If True, the stored optimizer
-                states are restored. Otherwise, the optimizer states are
-                initialized from scratch. Default: True.
-
-        Returns:
-            None
-
-        Examples:
-
-            .. code-block:: python
-                import paddle
-                import paddle.vision.transforms as T
-                from paddle.distributed.fleet import auto
-                from paddle.vision.datasets import MNIST
-
-                transform = T.Compose([
-                    T.Transpose(),
-                    T.Normalize([127.5], [127.5])
-                ])
-                train_dataset = MNIST(mode='train', transform=transform)
-
-                model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss()
-                optimizer = paddle.optimizer.Adam(
-                    learning_rate=0.001, parameters=model.parameters())
-                metrics = paddle.metric.Accuracy(topk=(1, 2))
-
-                engine = auto.Engine(model, loss, optimizer, metrics)
-                engine.fit(train_dataset,
-                           epochs=1,
-                           batch_size=64)
-                engine.save("./my_model")
-                engine.load("./my_model")
-
-        """
-        self._strict = strict
-        self._state_dict, self._dist_attr = self._saver.load(
-            path, load_optimizer)
-        return self._state_dict, self._dist_attr
-
-    @staticmethod
-    def _get_lr_scheduler(program):
-        lr_sheduler = None
-        if hasattr(program, 'lr_sheduler'):
-            from paddle.optimizer.lr import LRScheduler
-            lr_sheduler = program.lr_sheduler
-            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
-        return lr_sheduler
-
-    def _get_lr(self, optimizer):
-        if isinstance(optimizer, paddle.optimizer.Optimizer):
-            return optimizer.get_lr()
-        elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
-            if isinstance(optimizer._learning_rate, float):
-                return optimizer._learning_rate
-            else:
-                return optimizer._learning_rate()
-        else:
-            raise TypeError(
-                "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
-                " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
-            )
-
-    @property
-    def main_program(self):
-        return self._dist_main_progs[self._mode][self._cur_rank]
-
-    @property
-    def startup_program(self):
-        return self._dist_startup_progs[self._mode][self._cur_rank]
-
-    @property
-    def dist_context(self):
-        return self._dist_contexts[self._mode]
-
-    @property
-    def serial_main_program(self):
-        return self._serial_main_progs[self._mode]
-
-    @property
-    def serial_startup_program(self):
-        return self._serial_startup_progs[self._mode]
-
-    @property
-    def fetch_vars(self):
-        return self._fetch_vars[self._mode]
-
-    @property
-    def inputs(self):
-        return self._inputs
-
-    @property
-    def labels(self):
-        return self._labels
-
-
-def print_param(program):
-    for i, p in enumerate(program.all_parameters()):
-        if i == 10:
-            break
-        print(p.name, np.array(p.get_value())[:20])
-
-
-def print_input(program, vars):
-    for v in vars:
-        print(v.name, v.get_value()[:20])

From 104e6740581ddfaae79021a446dd1a1714e72f46 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Mon, 7 Nov 2022 15:16:00 +0800
Subject: [PATCH 36/36] update unitest

---
 .../fluid/tests/unittests/auto_parallel/test_strategy.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
index cbe899a7e6eb2..58641a1ec3af2 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
@@ -44,7 +44,9 @@ def test_default_config(self):
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)
         self.assertEqual(sharding.degree, 8)
-        self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
+        self.assertEqual(sharding.overlap_grad_comm, False)
+        self.assertEqual(sharding.bucket_size_numel, -1)
+        self.assertEqual(sharding.partition_algor, "greedy_even")
         self.assertEqual(sharding.enable_tuning, False)
         self.assertEqual(sharding.tuning_range, [])
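
The three added assertions pin down the defaults of the new sharding options: overlap_grad_comm is False, bucket_size_numel is -1, and partition_algor is "greedy_even". For reference, a minimal usage sketch of overriding these options from user code follows. It assumes the `sharding` object checked above comes from an `auto.Strategy()` instance (the constructor sits outside this hunk), and the bucket size value is purely illustrative:

    # Hedged sketch: assumes auto.Strategy() exposes the same `sharding`
    # config object whose defaults test_default_config() asserts above.
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    sharding = strategy.sharding
    sharding.enable = True                    # the sharding pass is off by default (asserted above)
    sharding.degree = 8                       # default sharding degree asserted above
    sharding.overlap_grad_comm = True         # default asserted above is False
    sharding.bucket_size_numel = 50331648     # illustrative value; the default asserted above is -1
    sharding.partition_algor = "greedy_even"  # default partition algorithm asserted above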