add bwd sub program completion
Caozhou1995 committed Mar 22, 2023
1 parent a2878f5 commit 9bd3242
Showing 5 changed files with 206 additions and 9 deletions.
199 changes: 199 additions & 0 deletions python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
@@ -1411,3 +1411,202 @@ def complete_sub_fwd_programs(self, process_mesh):
            if idx not in self.sub_programs_dist_context:
                self.sub_programs_dist_context[idx] = {}
            self._compelte_sub_fwd_program(idx, sub_fwd_program, process_mesh)

    def _complete_sub_bwd_program(self, sub_program_dist_context):
        """
        Complete the backward OPs according to the forward OPs.
        Most of the logic is the same as the backward completion in the completer;
        the difference is that here each backward OP is found from its forward OP,
        whereas the completer finds the forward OP from the backward OP.
        """

        def _is_grad_var_name(name):
            if "@GRAD" in name:
                return True
            return False

        sub_fwd_program = sub_program_dist_context.serial_main_program
        block = sub_fwd_program.global_block()
        vars = self.full_main_program.global_block().vars
        ops = self.full_main_program.global_block().ops
        grad_var_to_var = (
            self.full_main_program_dist_context.dist_op_context.grad_var_to_var[
                1
            ]
        )
        for forward_op in block.ops:
            if (
                forward_op.desc.original_id()
                not in self.op_original_id_to_grad_op_original_id
            ):
                continue
            grad_op_id = self.op_original_id_to_grad_op_original_id[
                forward_op.desc.original_id()
            ]
            # some forward ops have no grad op (e.g. unsqueeze2 in GPT)
            # or need no backward pass at all
            if grad_op_id not in self.op_original_id_to_op:
                continue
            grad_op = self.op_original_id_to_op[grad_op_id]
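            # Special case: the grad op of a split is a concat; reuse the dims
            # mapping of split's input for the concat's inputs and output.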
            if grad_op.type == "concat" and forward_op.type == "split":
                forward_op_dist_attr = (
                    sub_program_dist_context.get_op_dist_attr_for_program(
                        forward_op
                    )
                )
                output_var = vars[grad_op.desc.output('Out')[0]]
                split_input_var_name = forward_op.input("X")[0]
                ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                    split_input_var_name
                )
                ref_mesh = forward_op_dist_attr.process_mesh

                grad_op_dist_attr = OperatorDistAttr()
                for input_name in grad_op.input_arg_names:
                    grad_op_dist_attr.set_input_dims_mapping(
                        input_name, ref_dims_mapping
                    )

                output_var_dist_attr = TensorDistAttr()
                output_var_dist_attr.dims_mapping = ref_dims_mapping
                output_var_dist_attr.process_mesh = ref_mesh
                sub_program_dist_context.set_tensor_dist_attr_for_program(
                    output_var, output_var_dist_attr
                )

                grad_op_dist_attr.set_output_dims_mapping(
                    output_var.name, ref_dims_mapping
                )
                grad_op_dist_attr.process_mesh = ref_mesh
                sub_program_dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr
                )
                grad_op_dist_attr.impl_type = forward_op_dist_attr.impl_type
                grad_op_dist_attr.impl_idx = forward_op_dist_attr.impl_idx
                continue

            fwd_op_dist_attr = (
                sub_program_dist_context.get_op_dist_attr_for_program(
                    forward_op
                )
            )
            fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
            grad_op_dist_attr = OperatorDistAttr()
            grad_op_dist_attr.process_mesh = fwd_op_process_mesh

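            # Derive each grad-op input's dims mapping from the forward op's
            # annotation of the corresponding tensor, falling back to the
            # tensor's own dist attr when no forward counterpart is recorded.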
            for input_name in grad_op.input_arg_names:
                if (
                    input_name not in forward_op.input_arg_names
                    and input_name not in forward_op.output_arg_names
                ):
                    if input_name in grad_var_to_var.keys():
                        fwd_name = grad_var_to_var[input_name]
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_output_dims_mapping(fwd_name)
                        )
                    else:
                        input_var = vars[input_name]
                        ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            input_var
                        ).dims_mapping
                else:
                    if input_name in forward_op.input_arg_names:
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_input_dims_mapping(input_name)
                        )
                    else:
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_output_dims_mapping(input_name)
                        )
                assert (
                    ref_dims_mapping is not None
                ), "[{}]'s dims mapping is NONE".format(input_name)
                grad_op_dist_attr.set_input_dims_mapping(
                    input_name, ref_dims_mapping
                )

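            # Each grad-op output maps to a forward var; propagate the forward
            # op's dims mapping for that var to both the output tensor and the
            # grad op.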
            for output_name in grad_op.output_arg_names:
                assert output_name in grad_var_to_var
                fwd_name = grad_var_to_var[output_name]
                ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
                    fwd_name
                )
                # annotate the output var
                output_var = vars[output_name]
                tensor_dist_attr = TensorDistAttr()
                tensor_dist_attr.dims_mapping = ref_dims_mapping
                tensor_dist_attr.process_mesh = fwd_op_process_mesh
                sub_program_dist_context.set_tensor_dist_attr_for_program(
                    output_var, tensor_dist_attr
                )
                # annotate the grad op's output
                grad_op_dist_attr.set_output_dims_mapping(
                    output_name, ref_dims_mapping
                )

            grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
            grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
            sub_program_dist_context.set_op_dist_attr_for_program(
                grad_op, grad_op_dist_attr
            )

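            # If the op following this grad op in the full program is a sum that
            # accumulates gradients, annotate it with the dist attr of the
            # forward var it accumulates for.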
            grad_op_idx = self.op_original_id_to_idx[grad_op_id]
            if grad_op_idx + 1 < len(ops):
                grad_op_next_op = ops[grad_op_idx + 1]
                if grad_op_next_op.type == "sum":
                    assert all(
                        map(_is_grad_var_name, grad_op_next_op.input_arg_names)
                    )
                    output_name = grad_op_next_op.output_arg_names[0]
                    assert (
                        output_name in grad_var_to_var
                    ), "sum op's output '{}' has no corresponding var".format(
                        output_name
                    )
                    ref_fwd_var_name = grad_var_to_var[output_name]
                    ref_fwd_var = vars[ref_fwd_var_name]
                    ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        ref_fwd_var
                    )
                    ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
                    ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh

                    # annotate the accumulated output var
                    tensor_dist_attr = TensorDistAttr()
                    tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
                    tensor_dist_attr.process_mesh = ref_fwd_process_mesh
                    output_var = vars[output_name]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        output_var, tensor_dist_attr
                    )

                    # annotate the sum op itself
                    grad_op_dist_attr = OperatorDistAttr()
                    grad_op_dist_attr.process_mesh = ref_fwd_process_mesh

                    for var_name in grad_op_next_op.input_arg_names:
                        grad_op_dist_attr.set_input_dims_mapping(
                            var_name, ref_fwd_dims_mapping
                        )
                    grad_op_dist_attr.set_output_dims_mapping(
                        output_name, ref_fwd_dims_mapping
                    )
                    grad_op_dist_attr.impl_type = "default"
                    grad_op_dist_attr.impl_idx = 0

                    sub_program_dist_context.set_op_dist_attr_for_program(
                        grad_op_next_op, grad_op_dist_attr
                    )

    def complete_sub_bwd_programs(self):
        """Complete the backward part of every forward sub program."""
        for idx in self.sub_programs_dist_context:
            for parallelism in self.sub_programs_dist_context[idx]:
                for key in self.sub_programs_dist_context[idx][parallelism]:
                    sub_program_dist_context = self.sub_programs_dist_context[
                        idx
                    ][parallelism][key]
                    self._complete_sub_bwd_program(sub_program_dist_context)
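
For reference, the updated test in the last hunk of this diff drives the new backward completion through the tuner's public entry points. A minimal sketch of that call order, wrapped in a hypothetical helper (the serial train_program is assumed to be built elsewhere; import paths follow the ones visible in the test hunks, and the ProcessMesh path is an assumption):

def complete_layer_sub_programs(train_program):
    # Sketch only: mirrors the call sequence exercised by the updated unit tests.
    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
    from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
        RuleBasedTuner,
    )

    dist_context = DistributedContext(train_program)
    dist_context.initialize()
    tuner = RuleBasedTuner(dist_context)

    tuner.cluster_operators()              # group the serial ops into layers
    tuner.gen_full_program()               # build the full program with backward ops
    tuner.match_program(tuner._dist_context.serial_main_program)
    tuner.gen_fwd_sub_programs_by_clone()  # clone per-layer forward sub programs
    tuner.complete_sub_fwd_programs(ProcessMesh([0, 1]))
    tuner.complete_sub_bwd_programs()      # the method added in this commit
    return tuner.sub_programs_dist_context
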
@@ -178,6 +178,7 @@ def make_program():
[None, None],
)
tmp_out = paddle.matmul(out1, tmp_param)
tmp_out = paddle.scale(tmp_out, 0.5)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]

out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1]
@@ -286,6 +287,7 @@ def make_program():
)

tmp_out = paddle.matmul(out1, tmp_param)
tmp_out = paddle.scale(tmp_out, 0.5)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]

out8 = paddle.transpose(out2, [1, 0]) # [4, 8] [0, -1]
@@ -119,9 +119,10 @@ def test_gpt(self):
RuleBasedTuner,
)

dist_context = DistributedContext()
dist_context = DistributedContext(train_program)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
layers = tuner.cluster_operators()
op_types = []
for layer in layers:
tmp = []
@@ -112,18 +112,11 @@ def test_gpt(self):
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
RuleBasedTuner,
)

dist_context = DistributedContext()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
graph = GraphUtil.convert_to_graph(train_program.global_block())
print("graph: ", graph)
print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
@@ -134,7 +134,9 @@ def test_gpt(self):
        tuner.gen_full_program()
        tuner.match_program(tuner._dist_context.serial_main_program)
        process_mesh = ProcessMesh([0, 1])
        tuner.gen_fwd_sub_programs_by_clone()
        tuner.complete_sub_fwd_programs(process_mesh)
        tuner.complete_sub_bwd_programs()


if __name__ == "__main__":
Expand Down
