From 9bd3242109f58b2fe2863b22e7b6d404a563d0db Mon Sep 17 00:00:00 2001
From: caozhou
Date: Tue, 21 Mar 2023 09:20:22 +0000
Subject: [PATCH] add bwd sub program completion

---
 .../auto_parallel/tuner/rule_based_tuner.py   | 199 ++++++++++++++++++
 .../auto_parallel/test_dist_op_cost.py        |   2 +
 .../auto_parallel/test_group_operators.py     |   5 +-
 .../unittests/auto_parallel/test_pattern.py   |   7 -
 .../auto_parallel/test_rule_based_tuner.py    |   2 +
 5 files changed, 206 insertions(+), 9 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
index 547e8c87a06ad..6c74aac842dbf 100644
--- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
+++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
@@ -1411,3 +1411,202 @@ def complete_sub_fwd_programs(self, process_mesh):
             if idx not in self.sub_programs_dist_context:
                 self.sub_programs_dist_context[idx] = {}
             self._compelte_sub_fwd_program(idx, sub_fwd_program, process_mesh)
+
+    def _complete_sub_bwd_program(self, sub_program_dist_context):
+        """
+        Complete the backward OPs according to the forward OPs.
+        Most of the logic is the same as the backward completion in the completer.
+        The difference is that here the backward OP is found from the forward OP,
+        whereas the completer finds the forward OP from the backward OP.
+        """
+
+        def _is_grad_var_name(name):
+            if "@GRAD" in name:
+                return True
+            return False
+
+        sub_fwd_program = sub_program_dist_context.serial_main_program
+        block = sub_fwd_program.global_block()
+        vars = self.full_main_program.global_block().vars
+        ops = self.full_main_program.global_block().ops
+        grad_var_to_var = (
+            self.full_main_program_dist_context.dist_op_context.grad_var_to_var[
+                1
+            ]
+        )
+        for forward_op in block.ops:
+            if (
+                forward_op.desc.original_id()
+                not in self.op_original_id_to_grad_op_original_id
+            ):
+                continue
+            grad_op_id = self.op_original_id_to_grad_op_original_id[
+                forward_op.desc.original_id()
+            ]
+            # Some ops have no grad op (e.g. unsqueeze2 in GPT),
+            # or do not need a backward pass at all.
+            if grad_op_id not in self.op_original_id_to_op:
+                continue
+            grad_op = self.op_original_id_to_op[grad_op_id]
+            if grad_op.type == "concat" and forward_op.type == "split":
+                forward_op_dist_attr = (
+                    sub_program_dist_context.get_op_dist_attr_for_program(
+                        forward_op
+                    )
+                )
+                output_var = vars[grad_op.desc.output('Out')[0]]
+                split_input_var_name = forward_op.input("X")[0]
+                ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
+                    split_input_var_name
+                )
+                ref_mesh = forward_op_dist_attr.process_mesh
+
+                grad_op_dist_attr = OperatorDistAttr()
+                for input_name in grad_op.input_arg_names:
+                    grad_op_dist_attr.set_input_dims_mapping(
+                        input_name, ref_dims_mapping
+                    )
+
+                output_var_dist_attr = TensorDistAttr()
+                output_var_dist_attr.dims_mapping = ref_dims_mapping
+                output_var_dist_attr.process_mesh = ref_mesh
+                sub_program_dist_context.set_tensor_dist_attr_for_program(
+                    output_var, output_var_dist_attr
+                )
+
+                grad_op_dist_attr.set_output_dims_mapping(
+                    output_var.name, ref_dims_mapping
+                )
+                grad_op_dist_attr.process_mesh = ref_mesh
+                grad_op_dist_attr.impl_type = (
+                    forward_op_dist_attr.impl_type
+                )
+                grad_op_dist_attr.impl_idx = (
+                    forward_op_dist_attr.impl_idx
+                )
+                sub_program_dist_context.set_op_dist_attr_for_program(
+                    grad_op, grad_op_dist_attr
+                )
+                continue
+
+            fwd_op_dist_attr = (
+                sub_program_dist_context.get_op_dist_attr_for_program(
+                    forward_op
+                )
+            )
+            fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
+            grad_op_dist_attr = OperatorDistAttr()
+            grad_op_dist_attr.process_mesh = fwd_op_process_mesh
+
+            for input_name in grad_op.input_arg_names:
+                if (
+                    input_name not in forward_op.input_arg_names
+                    and input_name not in forward_op.output_arg_names
+                ):
+                    if input_name in grad_var_to_var.keys():
+                        fwd_name = grad_var_to_var[input_name]
+                        ref_dims_mapping = (
+                            fwd_op_dist_attr.get_output_dims_mapping(fwd_name)
+                        )
+                    else:
+                        input_var = vars[input_name]
+                        ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
+                            input_var
+                        ).dims_mapping
+                else:
+                    if input_name in forward_op.input_arg_names:
+                        ref_dims_mapping = (
+                            fwd_op_dist_attr.get_input_dims_mapping(input_name)
+                        )
+                    else:
+                        ref_dims_mapping = (
+                            fwd_op_dist_attr.get_output_dims_mapping(input_name)
+                        )
+                assert (
+                    ref_dims_mapping is not None
+                ), "[{}] 's dims mapping is NONE".format(input_name)
+                grad_op_dist_attr.set_input_dims_mapping(
+                    input_name, ref_dims_mapping
+                )
+
+            for output_name in grad_op.output_arg_names:
+                assert output_name in grad_var_to_var
+                fwd_name = grad_var_to_var[output_name]
+                ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
+                    fwd_name
+                )
+                # var
+                output_var = vars[output_name]
+                tensor_dist_attr = TensorDistAttr()
+                tensor_dist_attr.dims_mapping = ref_dims_mapping
+                tensor_dist_attr.process_mesh = fwd_op_process_mesh
+                sub_program_dist_context.set_tensor_dist_attr_for_program(
+                    output_var, tensor_dist_attr
+                )
+                # op
+                grad_op_dist_attr.set_output_dims_mapping(
+                    output_name, ref_dims_mapping
+                )
+
+            grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
+            grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
+            sub_program_dist_context.set_op_dist_attr_for_program(
+                grad_op, grad_op_dist_attr
+            )
+
+            grad_op_idx = self.op_original_id_to_idx[grad_op_id]
+            if grad_op_idx + 1 < len(ops):
+                grad_op_next_op = ops[grad_op_idx + 1]
+                if grad_op_next_op.type == "sum":
+                    assert all(
+                        map(_is_grad_var_name, grad_op_next_op.input_arg_names)
+                    )
+                    output_name = grad_op_next_op.output_arg_names[0]
+                    assert (
+                        output_name in grad_var_to_var
+                    ), "sum op's output '{}' has no corresponding var".format(
+                        output_name
+                    )
+                    ref_fwd_var_name = grad_var_to_var[output_name]
+                    ref_fwd_var = vars[ref_fwd_var_name]
+                    ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
+                        ref_fwd_var
+                    )
+                    ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
+                    ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh
+
+                    # output
+                    tensor_dist_attr = TensorDistAttr()
+                    tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
+                    tensor_dist_attr.process_mesh = ref_fwd_process_mesh
+                    output_var = vars[output_name]
+                    sub_program_dist_context.set_tensor_dist_attr_for_program(
+                        output_var, tensor_dist_attr
+                    )
+
+                    # op
+                    grad_op_dist_attr = OperatorDistAttr()
+                    grad_op_dist_attr.process_mesh = ref_fwd_process_mesh
+
+                    for var_name in grad_op_next_op.input_arg_names:
+                        grad_op_dist_attr.set_input_dims_mapping(
+                            var_name, ref_fwd_dims_mapping
+                        )
+                    grad_op_dist_attr.set_output_dims_mapping(
+                        output_name, ref_fwd_dims_mapping
+                    )
+                    grad_op_dist_attr.impl_type = "default"
+                    grad_op_dist_attr.impl_idx = 0
+
+                    sub_program_dist_context.set_op_dist_attr_for_program(
+                        grad_op_next_op, grad_op_dist_attr
+                    )
+
+    def complete_sub_bwd_programs(self):
+        for idx in self.sub_programs_dist_context:
+            for parallelism in self.sub_programs_dist_context[idx]:
+                for key in self.sub_programs_dist_context[idx][parallelism]:
+                    sub_program_dist_context = self.sub_programs_dist_context[
+                        idx
+                    ][parallelism][key]
+                    self._complete_sub_bwd_program(sub_program_dist_context)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
index 4eb0408976aba..ecff2bbf8935b 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
@@ -178,6 +178,7 @@ def make_program():
                 [None, None],
             )
             tmp_out = paddle.matmul(out1, tmp_param)
+            tmp_out = paddle.scale(tmp_out, 0.5)
             out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]
 
             out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]
@@ -286,6 +287,7 @@ def make_program():
             )
 
             tmp_out = paddle.matmul(out1, tmp_param)
+            tmp_out = paddle.scale(tmp_out, 0.5)
             out2 = paddle.matmul(tmp_out, param2)  # [8, 4] [-1, 0]
 
             out8 = paddle.transpose(out2, [1, 0])  # [4, 8] [0, -1]
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py
index 2823d4d9a318c..e1d8eb8d37903 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py
@@ -119,9 +119,10 @@ def test_gpt(self):
             RuleBasedTuner,
         )
 
-        dist_context = DistributedContext()
+        dist_context = DistributedContext(train_program)
+        dist_context.initialize()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
+        layers = tuner.cluster_operators()
         op_types = []
         for layer in layers:
             tmp = []
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
index 047b9c7507fbf..b239de918a251 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
@@ -112,18 +112,11 @@ def test_gpt(self):
             sequence_len,
             vocab_size,
         )
-        from paddle.distributed.auto_parallel.dist_context import (
-            DistributedContext,
-        )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
             GraphUtil,
-            RuleBasedTuner,
         )
 
-        dist_context = DistributedContext()
-        tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
         graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
index 808a4427d978a..d1285b7895e17 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
@@ -134,7 +134,9 @@ def test_gpt(self):
         tuner.gen_full_program()
         tuner.match_program(tuner._dist_context.serial_main_program)
         process_mesh = ProcessMesh([0, 1])
+        tuner.gen_fwd_sub_programs_by_clone()
         tuner.complete_sub_fwd_programs(process_mesh)
+        tuner.complete_sub_bwd_programs()
 
 
 if __name__ == "__main__":
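
For reference, the dims-mapping propagation rule that _complete_sub_bwd_program applies can be summarized as: a grad tensor reuses the mapping of the forward tensor it corresponds to (looked up through grad_var_to_var), a forward tensor reused by the grad op keeps its forward mapping, and grad outputs follow the forward inputs. Below is a minimal, framework-free sketch of that rule; the dict-based stand-ins for OperatorDistAttr/TensorDistAttr and the toy matmul_grad example are hypothetical illustrations, not Paddle's actual API.

# Minimal sketch of the dims-mapping propagation rule used by
# _complete_sub_bwd_program (hypothetical stand-ins, not Paddle's API).

# Forward op (e.g. matmul): dims mapping per input/output var name.
fwd_dist_attr = {
    "inputs": {"X": [0, -1], "Y": [-1, -1]},
    "outputs": {"Out": [0, -1]},
}

# Maps a grad var back to its forward var, mirroring
# dist_op_context.grad_var_to_var in the patch.
grad_var_to_var = {"X@GRAD": "X", "Y@GRAD": "Y", "Out@GRAD": "Out"}


def complete_grad_op(grad_input_names, grad_output_names):
    """Derive the grad op's dims mappings from the forward op's dist attr."""
    grad_dist_attr = {"inputs": {}, "outputs": {}}
    for name in grad_input_names:
        if name in grad_var_to_var:
            # A grad tensor reuses the mapping of the forward tensor it shadows.
            fwd_name = grad_var_to_var[name]
            grad_dist_attr["inputs"][name] = fwd_dist_attr["outputs"].get(
                fwd_name, fwd_dist_attr["inputs"].get(fwd_name)
            )
        else:
            # A forward tensor reused by the grad op keeps its forward mapping.
            grad_dist_attr["inputs"][name] = fwd_dist_attr["inputs"].get(
                name, fwd_dist_attr["outputs"].get(name)
            )
    for name in grad_output_names:
        # Grad outputs (input/parameter grads) follow the forward inputs.
        fwd_name = grad_var_to_var[name]
        grad_dist_attr["outputs"][name] = fwd_dist_attr["inputs"][fwd_name]
    return grad_dist_attr


# Example: matmul_grad(Out@GRAD, X, Y) -> (X@GRAD, Y@GRAD)
print(complete_grad_op(["Out@GRAD", "X", "Y"], ["X@GRAD", "Y@GRAD"]))
# {'inputs': {'Out@GRAD': [0, -1], 'X': [0, -1], 'Y': [-1, -1]},
#  'outputs': {'X@GRAD': [0, -1], 'Y@GRAD': [-1, -1]}}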