[Auto parallel] align infer accuracy for ernie generator mode (#40077)
* [Auto Parallel] Support the auto completion of while_op
* align infer accuracy
JZ-LIANG authored Mar 25, 2022
1 parent f027b2a commit 02146ba
Showing 3 changed files with 18 additions and 22 deletions.
@@ -65,7 +65,8 @@ def is_auto_compatible(self, dist_op):
         if (not self.is_input_compatible(dist_op)) or \
             (not self.is_output_compatible(dist_op)):
             return False
 
+        op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
         out_name = op_desc.output('Out')[0]
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
         in_name = op_desc.input('Input')[0]
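Note: a dims mapping records, for each tensor axis, the process-mesh dimension that axis is sharded along, with -1 meaning the axis is replicated. The compatibility checks above compare the input and output mappings axis by axis. A minimal sketch of that convention in plain Python (nothing Paddle-specific is assumed beyond the -1 convention):

def is_dim_shard(mapping_entry):
    # A tensor axis is sharded iff it is bound to some mesh axis (!= -1).
    return mapping_entry != -1

in_dims_mapping = [0, -1]   # axis 0 sharded along mesh axis 0, axis 1 replicated
out_dims_mapping = [0, -1]

# Compatible only when every axis is distributed the same way on both sides.
assert in_dims_mapping == out_dims_mapping
assert is_dim_shard(in_dims_mapping[0]) and not is_dim_shard(in_dims_mapping[1])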
@@ -78,7 +79,7 @@ def update_dims_mapping(self, dist_op):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
-        x_name = op_desc.input('X')[0]
+        x_name = op_desc.input('Input')[0]
         out_name = op_desc.output('Out')[0]
         x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
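Note: the one-line fix above matters because an operator's inputs are looked up by the slot names the op declares, and this op declares its input slot as 'Input', not 'X', so the old lookup could not return the tensor whose dims mapping needs updating. A small stand-in for the lookup (the slot table below is hypothetical, not Paddle's OpDesc):

example_op_inputs = {'Input': ['tmp_3']}  # slot name -> bound variable names

def op_input(op_inputs, slot):
    # Mirrors the shape of op_desc.input(slot): the variables bound to a slot.
    return op_inputs.get(slot, [])

assert op_input(example_op_inputs, 'Input')[0] == 'tmp_3'
assert op_input(example_op_inputs, 'X') == []  # [0] here would raise IndexError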
@@ -1837,6 +1837,8 @@ def forward(ctx, *args, **kwargs):
                                              out_var_dist_attr)
 
         intermediate_var_0 = main_block.create_var(
+            name=unique_name.generate_with_ignorable_key(".".join(
+                ["c_allreduce_sum", 'tmp'])),
             shape=Out_var.shape,
             dtype=Out_var.dtype,
             type=Out_var.type,
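Note: routing the intermediate buffer's name through unique_name.generate_with_ignorable_key gives every partitioned instance of this allreduce-producing pattern its own temporary, so lowering the same structure repeatedly cannot collide on variable names. A simplified stand-in for the generator (the real one lives in Paddle's unique_name utilities and formats names differently):

import itertools

_counter = itertools.count()

def generate_with_ignorable_key(key):
    # Simplified: append a process-wide counter to a readable prefix.
    return "%s_%d" % (key, next(_counter))

a = generate_with_ignorable_key(".".join(["c_allreduce_sum", "tmp"]))
b = generate_with_ignorable_key(".".join(["c_allreduce_sum", "tmp"]))
assert a != b  # two partitioned ops never share a buffer name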
@@ -1936,7 +1938,6 @@ def is_input_compatible(self, dist_op):
         if is_valid_list_index(x_dims_mapping,
                                -2) and is_dim_shard(x_dims_mapping[-2]):
             return False
-
         if is_dim_shard(y_dims_mapping[-1]):
             return False
         if is_valid_list_index(y_dims_mapping,
python/paddle/distributed/auto_parallel/partitioner.py (13 additions, 19 deletions)
@@ -329,13 +329,7 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
         belong_to_optimizer=src_var.belong_to_optimizer,
         **copied_kwargs)
 
-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # param.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
+    return param


@@ -352,13 +346,7 @@ def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
         is_data=src_var.is_data,
         belong_to_optimizer=src_var.belong_to_optimizer)
 
-    # set dist attr uid
-    # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
-    # var.desc.set_distributed_attr_uid(distributed_attr_uid)
-    dist_attr = copy.deepcopy(
-        dist_context.get_tensor_dist_attr_for_program(src_var))
-    assert dist_attr is not None
-    dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
+    return var


@@ -369,7 +357,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
     src_var = src_block.var(src_varname)
 
     if src_var.type in __not_shape_var_type__:
-        dst_block.create_var(
+        new_var = dst_block.create_var(
             type=src_var.type,
             name=dst_varname,
             persistable=True,
@@ -380,11 +368,17 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
         target_shape = _get_dist_shape(src_var, dist_attr)
 
         if isinstance(src_var, Parameter):
-            _partition_parameter(dist_context, src_var, dst_block, dst_varname,
-                                 target_shape)
+            new_var = _partition_parameter(dist_context, src_var, dst_block,
+                                           dst_varname, target_shape)
         else:
-            _partition_intermediate_var(dist_context, src_var, dst_block,
-                                        dst_varname, target_shape)
+            new_var = _partition_intermediate_var(
+                dist_context, src_var, dst_block, dst_varname, target_shape)
+
+    dist_attr = copy.deepcopy(
+        dist_context.get_tensor_dist_attr_for_program(src_var))
+    assert dist_attr is not None
+    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
 
     return target_shape
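Note: the partitioner edits are a consolidation. Previously _partition_parameter and _partition_intermediate_var each copied the source tensor's dist_attr onto the variable they created, while the non-shape branch of _partition_var never set one; the while_op support named in the commit message introduces exactly such non-shape variables. Now every branch returns the variable it creates and _partition_var stamps the dist_attr in one place. A condensed sketch of the resulting control flow (every helper except the dist_context calls is a placeholder, not a Paddle API):

import copy

def partition_var_sketch(dist_context, src_var, dst_block, dst_varname):
    if is_non_shape_type(src_var):       # placeholder predicate
        new_var = create_plain_copy(dst_block, src_var, dst_varname)
        target_shape = None              # assumed: non-shape vars have no dist shape
    else:
        target_shape = compute_sharded_shape(src_var, dist_context)
        new_var = create_sharded_copy(dst_block, src_var, dst_varname,
                                      target_shape)

    # Single exit point: every partitioned variable, parameter or not,
    # receives a deep copy of the source tensor's distributed attributes.
    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var))
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
    return target_shape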

