[AutoParallel] Fit allreduce_matmul_grad_overlapping when using master grad #61865

Merged
@@ -119,7 +119,6 @@ def _insert_reshape_op(self, block, index, x, shape, op_role, out=None):
             ctx=self.dist_context,
             chunk_id=x_dist_attr.chunk_id,
         )
-        block._sync_with_cpp()
 
         return out

@@ -134,6 +133,53 @@ def _split_matmul_grad_and_multi_streaming_allreduce(
             matmul_grad_op = ops[matmul_grad_id]
             allreduce_op = ops[allreduce_id]
 
+            # NOTE(Sonder): Why move those operations to the back of matmul_v2?
+            # When using amp_master_grad, the cast operation is inserted after matmul_grad.
+            # However, when employing allreduce_matmul_grad_overlapping, the matmul_grad is
+            # split into two matmul operations. In this case, some operations would access
+            # uninitialized tensors. Therefore, we move the cast operation to the back of the
+            # second matmul operation to avoid this problem.
+            skip_overlapping = False
+            moved_ops_idx = []
+            moved_ops_output = []
+            matmul_grad_output = matmul_grad_op.output('Y@GRAD')[0]
+
+            for idx in range(matmul_grad_id + 1, allreduce_id):
+                if matmul_grad_output in ops[idx].desc.input_arg_names():
+                    moved_ops_idx.append(idx)
+                    moved_ops_output.extend(ops[idx].desc.output_arg_names())
+                else:
+                    for input_name in ops[idx].desc.input_arg_names():
+                        if input_name in moved_ops_output:
+                            skip_overlapping = True
+
+            if skip_overlapping:
+                continue
+
+            for i, idx in enumerate(moved_ops_idx):
+                op = ops[idx]
+                dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
+
+                op_inputs = op.desc.input_names()
+                op_outputs = op.desc.output_names()
+
+                op_inputs = {name: op.input(name) for name in op_inputs}
+                op_outputs = {name: op.output(name) for name in op_outputs}
+
+                op = block._insert_op_without_sync(
+                    index=allreduce_id + 1 + i,
+                    type=op.type,
+                    inputs=op_inputs,
+                    outputs=op_outputs,
+                    attrs=op.all_attrs(),
+                )
+
+                self.dist_context.set_op_dist_attr_for_program(op, dist_attr)
+
+            for i, idx in enumerate(moved_ops_idx):
+                block._remove_op(idx - i, sync=False)
+                allreduce_id -= 1
+
             tran_x = matmul_grad_op.attr("trans_x")
             assert (
                 not tran_x
@@ -242,5 +288,5 @@ def _split_matmul_grad_and_multi_streaming_allreduce(
                 matmul_op, matmul_grad_dist_attr
             )
 
-            block._remove_op(matmul_grad_id)
-            block._sync_with_cpp()
+            block._remove_op(matmul_grad_id, sync=False)
+        block._sync_with_cpp()
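
The intent of the NOTE in the diff can be illustrated outside of Paddle. The sketch below is a minimal, self-contained model of the reordering step, assuming a hypothetical ToyOp record and a hypothetical reorder_master_grad_consumers helper (neither is a Paddle API). It mirrors the moved_ops_idx / skip_overlapping bookkeeping from the added code: ops that consume the weight gradient, such as the amp_master_grad cast, are moved behind the allreduce, and the transformation is skipped entirely if an op that stays in place reads an output of an op that would be moved.

    # Illustrative sketch only (not Paddle code): toy stand-ins for program ops.
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class ToyOp:
        type: str
        inputs: List[str] = field(default_factory=list)
        outputs: List[str] = field(default_factory=list)


    def reorder_master_grad_consumers(ops, matmul_grad_id, allreduce_id):
        # The weight gradient produced by matmul_grad ('Y@GRAD' in the pass).
        weight_grad = ops[matmul_grad_id].outputs[-1]

        moved, kept, moved_outputs = [], [], set()
        for op in ops[matmul_grad_id + 1 : allreduce_id]:
            if weight_grad in op.inputs:
                # Same role as moved_ops_idx / moved_ops_output in the pass.
                moved.append(op)
                moved_outputs.update(op.outputs)
            elif any(name in moved_outputs for name in op.inputs):
                # An op that stays in place would read a moved op's output:
                # the pass sets skip_overlapping and leaves this pair untouched.
                return ops
            else:
                kept.append(op)

        # Re-emit the region with the weight-grad consumers after the allreduce.
        return (
            ops[: matmul_grad_id + 1]
            + kept
            + [ops[allreduce_id]]
            + moved
            + ops[allreduce_id + 1 :]
        )


    if __name__ == "__main__":
        ops = [
            ToyOp("matmul_v2_grad", ["out@GRAD", "x", "w"], ["x@GRAD", "w@GRAD"]),
            ToyOp("cast", ["w@GRAD"], ["w@GRAD@master"]),  # amp_master_grad cast
            ToyOp("dropout_grad", ["mask", "out@GRAD"], ["in@GRAD"]),
            ToyOp("c_allreduce_sum", ["x@GRAD"], ["x@GRAD"]),
        ]
        print([op.type for op in reorder_master_grad_consumers(ops, 0, 3)])
        # ['matmul_v2_grad', 'dropout_grad', 'c_allreduce_sum', 'cast']

Bailing out when skip_overlapping is set matches the conservative choice in the pass: rather than also relocating downstream dependents, the overlapping optimization is simply not applied to that matmul_grad/allreduce pair.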