Skip to content

Commit

Permalink
[Auto Parallel] Support Primitive operators with Data Parallel (#42709)
Browse files Browse the repository at this point in the history
* auto parallel support primitive op with data parallel

* add primitive change

* 5 loss 3D cylinder acc aligned

* add unitest
  • Loading branch information
JZ-LIANG authored May 19, 2022
1 parent a777893 commit 6b8efc4
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 9 deletions.
67 changes: 67 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,3 +1250,70 @@ def complete_update_annotation(self, serial_main_program=None):
self._dist_context.set_op_dist_attr_for_program(
op, op_dist_attr)
continue

def complete_prim_annotation(self, serial_main_program=None):
"""
fill default data parallel annotation for program with primitive operators.
Arguments:
serial_main_program: partial annotated serial_main_program.
Returns:
serial_main_program: completed annotated serial_main_program.
"""
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context.serial_main_program = serial_main_program

import time

start_time = time.time()
self._dist_context._is_initialized = True

start_time = time.time()
self._dist_context._init_dist_attr_for_program()

start_time = time.time()
self._init_global_mesh_for_program()

# Do the validation check and amend some completion
start_time = time.time()
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()

def _init_global_mesh_for_program(self):
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_ranks = get_world_process_group().ranks

for block in self._dist_context._serial_main_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
dist_tensor = self._dist_context.get_dist_tensor_for_program(
tensor)
assert dist_tensor is not None
dist_tensor.dist_attr.process_mesh = world_ranks
for op in block.ops:
# Copy the distributed operators in the default context
dist_op = self._dist_context.get_dist_op_for_program(op)
assert dist_op is not None
dist_op.dist_attr.process_mesh = world_ranks

# Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
if op_dist_impls is not None:
backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
for op_dist_impl in op_dist_impls:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
dist_op.dist_attr.impl_type = "default"
else:
dist_op.dist_attr.impl_type = op_dist_impl.type
# op_dist_attr.impl_type = op_dist_impl.type
dist_op.dist_attr.impl_idx = op_dist_impl.idx
break
else:
dist_op.dist_attr = backup_op_dist_attr
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_p
7 changes: 4 additions & 3 deletions python/paddle/distributed/auto_parallel/operators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@


def is_elementwise_op(op_type):
for eltwise_op in _g_elementwise_ops:
if eltwise_op in op_type:
return True
if op_type in _g_elementwise_ops:
return True
if "elementwise" in op_type:
return True
return False


Expand Down
63 changes: 57 additions & 6 deletions python/paddle/distributed/auto_parallel/operators/dist_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import is_valid_list_index, is_prim_op
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
Expand All @@ -35,6 +35,55 @@
__op_not_need_param_init__ = ["while", "cond"]


def prim_operator_data_parallel_functor(ctx, src_op):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block

var_name = src_op.output_arg_names[0]
if var_name in ctx.grads_params:
assert var_name not in ctx.synced_gradient, "in primtive mode, grad is already {} synced".format(
var_name)
ctx.synced_gradient.add(var_name)
sync_group = new_process_group(ctx.data_parallel_group)

allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [var_name]},
outputs={'Out': [var_name]},
attrs={
'ring_id': sync_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})

param = ctx.grads_params[var_name]
startup_block = dist_op_context.startup_block
new_op = startup_block.append_op(
type='c_broadcast',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})

grad_var = main_block.var(var_name)
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
dist_attr = ctx.get_op_dist_attr_for_program(src_op)
process_mesh = dist_attr.process_mesh
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
ctx.set_op_dist_attr_for_program(allreduce_op, op_attr)

return


class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedDefault, self).__init__(op_type)
Expand Down Expand Up @@ -292,7 +341,6 @@ def update_dims_mapping(self, dist_op):

@staticmethod
def forward(ctx, *args, **kwargs):

dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
Expand All @@ -315,15 +363,20 @@ def forward(ctx, *args, **kwargs):
output_name)

# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])

main_block._sync_with_cpp()
# data parallel synchronization for primtive operators
from paddle.incubate.autograd import prim_enabled
if prim_enabled():
assert is_prim_op(src_op)
prim_operator_data_parallel_functor(ctx, src_op)
return

# param initialization sync
if src_op.type in __op_not_need_param_init__:
Expand Down Expand Up @@ -373,8 +426,6 @@ def forward(ctx, *args, **kwargs):
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)

startup_block._sync_with_cpp()

@staticmethod
def backward(ctx, *args, **kwargs):

Expand Down
151 changes: 151 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank


class DistributedReducePrimtive(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedReducePrimtive, self).__init__(op_type)


register_distributed_operator_impl_container(
DistributedReducePrimtive("reduce_p"))


# Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReducePrimtiveImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True

def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr

return len(op_desc.input_arg_names()) == 1

def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
outputs = op_desc.output_arg_names()

if len(outputs) != 1:
return False

output_name = outputs[0]
output_var = dist_op.serial_op.block.var(output_name)
if output_var.shape != (1, ):
return False

return True

def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr

return self.is_input_compatible(dist_op) and self.is_output_compatible(
dist_op)

def update_dims_mapping(self, dist_op):
changed = False

return changed

@staticmethod
def forward(ctx, *args, **kwargs):

dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id

# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)

# replicate op in dist program
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])

# batch dimension synchronization
var_name = src_op.output_arg_names[0]
sync_group = new_process_group(ctx.data_parallel_group)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [var_name]},
outputs={'Out': [var_name]},
attrs={
'ring_id': sync_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})

# dist attr
var = main_block.var(var_name)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
new_op_attr = OperatorDistributedAttribute()
new_op_attr.process_mesh = op_dist_attr.process_mesh
new_op_attr.set_output_dims_mapping(var.name,
tensor_dist_attr.dims_mapping)
new_op_attr.set_input_dims_mapping(var.name,
tensor_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(allreduce_op, new_op_attr)

@staticmethod
def backward(ctx, *args, **kwargs):
raise RuntimeError(
"primitive operator does NOT have backward function, op type: {}".
format(str(op.type)))


register_distributed_operator_impl(
"reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p"))
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ def partition_block(self, ref_block, target_block):
dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
elif int(op.attr('op_role')) == 2:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container(
"default").get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
else:
raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}".
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,10 @@ def is_loss_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))


def is_prim_op(op):
return op.type.endswith("_p")


def get_loss_op(block):
loss_ops = []
for op in block.ops:
Expand All @@ -1118,6 +1122,9 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
tensor_dist_attr.dims_mapping = dims_mapping
# TODO get global mesh group
tensor_dist_attr.process_mesh = process_mesh
if "mark_annotated" in kwargs and kwargs["mark_annotated"]:
tensor_dist_attr.mark_annotated("dims_mapping")
tensor_dist_attr.mark_annotated("process_mesh")
dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
return tensor_dist_attr

Expand Down
Loading

0 comments on commit 6b8efc4

Please sign in to comment.