[Auto Parallel] Rename methods of ProcessMesh
aoyulong committed Oct 31, 2022
1 parent 60e0c50 commit 0ce3633
Showing 33 changed files with 228 additions and 223 deletions.
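In short, this commit renames two ProcessMesh accessors and updates every call site: process_mesh.topology becomes process_mesh.shape, and process_mesh.processes becomes process_mesh.process_ids. A minimal sketch of the rename, using a hypothetical stand-in class so it runs without paddle installed:

from dataclasses import dataclass

# Hypothetical stand-in for auto_parallel's ProcessMesh, used only to
# illustrate the renamed accessors; it is not paddle's real class.
@dataclass
class ProcessMeshStub:
    shape: list         # formerly exposed as `topology`
    process_ids: list   # formerly exposed as `processes`

mesh = ProcessMeshStub(shape=[2, 2], process_ids=[0, 1, 2, 3])

# Call sites throughout auto_parallel now read:
#   process_mesh.shape        instead of  process_mesh.topology
#   process_mesh.process_ids  instead of  process_mesh.processes
assert len(mesh.process_ids) == mesh.shape[0] * mesh.shape[1]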
6 changes: 2 additions & 4 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -127,11 +127,9 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
for i in range(len(dims_mapping)):
- if dims_mapping[i] < -1 or dims_mapping[i] >= len(
- process_mesh.topology
- ):
+ if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
- for i in range(len(process_mesh.topology)):
+ for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
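For intuition, a standalone restatement of the same check (it takes the mesh shape directly instead of a ProcessMesh; the example inputs are illustrative, not taken from this diff):

def validate_dims_mapping(dims_mapping, mesh_shape):
    # Same rules as _validate_dims_mapping above: every entry must lie in
    # [-1, len(mesh_shape)), and no mesh axis may be used more than once.
    if dims_mapping is None:
        return False
    for d in dims_mapping:
        if d < -1 or d >= len(mesh_shape):
            return False
    for axis in range(len(mesh_shape)):
        if dims_mapping.count(axis) > 1:
            return False
    return True

print(validate_dims_mapping([0, -1], [2, 4]))  # True: dim 0 sharded along mesh axis 0
print(validate_dims_mapping([0, 0], [2, 4]))   # False: mesh axis 0 used twice
print(validate_dims_mapping([2, -1], [2, 4]))  # False: the mesh has only 2 axes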
24 changes: 12 additions & 12 deletions python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -82,7 +82,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."
- processes = process_mesh.processes
+ processes = process_mesh.process_ids
for process in processes:
desc = {}
desc["op"] = op.type
@@ -105,7 +105,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
global_sizes = var.shape
# NOTE: When support uneven partition, the shard_sizes will be got from dist_attr.
shard_sizes = None
- topology = process_mesh.topology
+ topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
@@ -131,7 +131,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
)
relative_idx = _get_idx_in_axis(
processes,
- dist_attr.process_mesh.topology,
+ dist_attr.process_mesh.shape,
embedding_row_dim_mapping,
process,
)
@@ -155,8 +155,8 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process_mesh = dist_attr.process_mesh
global_sizes = var.shape
shard_sizes = None
- processes = process_mesh.processes
- topology = process_mesh.topology
+ processes = process_mesh.process_ids
+ topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
@@ -172,7 +172,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
# Modify shape attr according to how output are partitioned
out_name = var_name_list[0]
dims_mapping = dist_attr.get_output_dims_mapping(out_name)
- process_mesh_shape = dist_attr.process_mesh.topology
+ process_mesh_shape = dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# Modify target shape
for idx, axis in enumerate(dims_mapping):
@@ -255,7 +255,7 @@ def build_comm_desc_from_dist_op(
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."

- processes = process_mesh.processes
+ processes = process_mesh.process_ids
op_descs = {}
for process in processes:
rank_id = process
@@ -297,7 +297,7 @@ def build_comm_desc_from_dist_op(
)
global_sizes = var.shape
shard_sizes = None
- topology = process_mesh.topology
+ topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
@@ -313,8 +313,8 @@ def build_comm_desc_from_dist_op(

# Get comm group by parallel_axis or the given group_ranks.
if parallel_axis is not None:
- process_mesh_shape = process_mesh.topology
- process_mesh_group = process_mesh.processes
+ process_mesh_shape = process_mesh.shape
+ process_mesh_group = process_mesh.process_ids
comm_group_ranks = _get_comm_group(
process_mesh_group,
process_mesh_shape,
@@ -386,7 +386,7 @@ def build_dp_costs(

dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
- processes = process_mesh.processes
+ processes = process_mesh.process_ids
assert len(var_names) == 1
vars = dist_op.serial_op.block.vars
var_name = var_names[0]
@@ -445,7 +445,7 @@ def build_dp_costs(
)
global_sizes = var.shape
shard_sizes = None
- topology = process_mesh.topology
+ topology = process_mesh.shape
shape = DistributedTensor.get_local_sizes(
global_sizes,
dims_mapping,
22 changes: 11 additions & 11 deletions python/paddle/distributed/auto_parallel/cost/estimate_cost.py
@@ -190,7 +190,7 @@ def _estimate_core(self, dist_context, resharder, block):
# Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
- processes = op_dist_attr.process_mesh.processes
+ processes = op_dist_attr.process_mesh.process_ids

container = get_distributed_operator_impl_container(
op_dist_attr.impl_type
@@ -273,8 +273,8 @@ def _estimate_max_memory_by_dist_op(self, dist_context):
# This estimation will be improved, now reshard and inplace are not considered.
# Persist var is not free.
def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
processes = ",".join([str(x) for x in process_mesh.processes])
topology = ",".join([str(x) for x in process_mesh.topology])
processes = ",".join([str(x) for x in process_mesh.process_ids])
topology = ",".join([str(x) for x in process_mesh.shape])
dims_mapping = ",".join([str(x) for x in dims_mapping])
result = processes + topology + dims_mapping
return result
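For intuition, a standalone sketch of the cache key this helper builds; note that the three comma-joined pieces are concatenated with no separator between them. The example values are illustrative, not taken from the diff:

def convert_pm_and_dm_to_str(process_ids, mesh_shape, dims_mapping):
    # Mirrors _convert_pm_and_dm_to_str above, but takes plain lists.
    processes = ",".join(str(x) for x in process_ids)
    topology = ",".join(str(x) for x in mesh_shape)
    mapping = ",".join(str(x) for x in dims_mapping)
    return processes + topology + mapping

# A 2x2 mesh over ranks 0-3, first tensor dim sharded along mesh axis 0:
print(convert_pm_and_dm_to_str([0, 1, 2, 3], [2, 2], [0, -1]))
# -> "0,1,2,32,20,-1"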
@@ -318,8 +318,8 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
sizes = DistributedTensor.get_local_sizes(
global_sizes,
input_dims_mapping,
- process_mesh.topology,
- process_mesh.processes,
+ process_mesh.shape,
+ process_mesh.process_ids,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
@@ -346,8 +346,8 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
sizes = DistributedTensor.get_local_sizes(
global_sizes,
output_dims_mapping,
- process_mesh.topology,
- process_mesh.processes,
+ process_mesh.shape,
+ process_mesh.process_ids,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
@@ -380,7 +380,7 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
- for process in process_mesh.processes:
+ for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
@@ -390,7 +390,7 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
- for process in process_mesh.processes:
+ for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
@@ -409,7 +409,7 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
- for process in process_mesh.processes:
+ for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
@@ -419,7 +419,7 @@ def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
- for process in process_mesh.processes:
+ for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
8 changes: 4 additions & 4 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -871,8 +871,8 @@ def amend_dist_attr_for_program(self):
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
- process_mesh_shape = dist_attr.process_mesh.topology
- process_mesh_processes = dist_attr.process_mesh.processes
+ process_mesh_shape = dist_attr.process_mesh.shape
+ process_mesh_processes = dist_attr.process_mesh.process_ids
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
@@ -888,8 +888,8 @@ def amend_dist_attr_for_program(self):
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
- process_mesh_shape = dist_attr.process_mesh.topology
- process_mesh_processes = dist_attr.process_mesh.processes
+ process_mesh_shape = dist_attr.process_mesh.shape
+ process_mesh_processes = dist_attr.process_mesh.process_ids
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
8 changes: 4 additions & 4 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -162,10 +162,10 @@ def validate_dist_attr(self):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
- self.dist_attr.process_mesh.topology
+ self.dist_attr.process_mesh.shape
):
return False
- for i in range(len(self.dist_attr.process_mesh.topology)):
+ for i in range(len(self.dist_attr.process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
@@ -187,10 +187,10 @@ def validate_dist_attr(self):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
- self.dist_attr.process_mesh.topology
+ self.dist_attr.process_mesh.shape
):
return False
- for i in range(len(self.dist_attr.process_mesh.topology)):
+ for i in range(len(self.dist_attr.process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
16 changes: 8 additions & 8 deletions python/paddle/distributed/auto_parallel/dist_tensor.py
@@ -234,10 +234,10 @@ def validate_dist_attr(self):
if self.dist_attr.dims_mapping[
i
] < -1 or self.dist_attr.dims_mapping[i] >= len(
- self.dist_attr.process_mesh.topology
+ self.dist_attr.process_mesh.shape
):
return False
- for i in range(len(self.dist_attr.process_mesh.topology)):
+ for i in range(len(self.dist_attr.process_mesh.shape)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
@@ -248,8 +248,8 @@ def local_sizes(self, rank=None):
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
- processes = self.dist_attr.process_mesh.processes
- topology = self.dist_attr.process_mesh.topology
+ processes = self.dist_attr.process_mesh.process_ids
+ topology = self.dist_attr.process_mesh.shape
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes
)
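As a rough sketch of what the even-partition case of DistributedTensor.get_local_sizes computes from these arguments (shard_sizes is None here); the exact rounding of non-divisible dimensions and the rank-specific handling in the real implementation may differ, so treat this as an assumption:

def local_sizes_sketch(global_sizes, dims_mapping, mesh_shape):
    # For each tensor dim: keep the global size if the dim is not sharded
    # (mapping == -1); otherwise split it evenly over the mapped mesh axis.
    local = []
    for size, mapping in zip(global_sizes, dims_mapping):
        if mapping == -1:
            local.append(size)
        else:
            n = mesh_shape[mapping]
            local.append(-(-size // n))  # ceiling division (assumed)
    return local

# An [8, 6] tensor on a 2x2 mesh, first dim sharded along mesh axis 0:
print(local_sizes_sketch([8, 6], [0, -1], [2, 2]))  # -> [4, 6]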
@@ -265,8 +265,8 @@ def local_offsets(self, rank=None):
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
- processes = self.dist_attr.process_mesh.processes
- topology = self.dist_attr.process_mesh.topology
+ processes = self.dist_attr.process_mesh.process_ids
+ topology = self.dist_attr.process_mesh.shape
local_offsets = DistributedTensor.get_local_offsets(
global_sizes,
dims_mapping,
@@ -291,8 +291,8 @@ def local_shard(self, rank=None):
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
- processes = self.dist_attr.process_mesh.processes
- topology = self.dist_attr.process_mesh.topology
+ processes = self.dist_attr.process_mesh.process_ids
+ topology = self.dist_attr.process_mesh.shape
local_shard = DistributedTensor.get_local_shard(
global_sizes,
dims_mapping,
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/helper.py
@@ -328,8 +328,8 @@ def init(self, main_program, place, dist_context):
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes,
"process_shape": var_dist_attr.process_mesh.shape,
"process_group": var_dist_attr.process_mesh.process_ids,
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
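For example, for a parameter whose first dimension is sharded along axis 0 of a 2x2 mesh, the dict built above would look roughly like this (illustrative values):

dist_attr = {
    "dims_mapping": [0, -1],        # first dim sharded along mesh axis 0
    "process_shape": [2, 2],        # process_mesh.shape
    "process_group": [0, 1, 2, 3],  # process_mesh.process_ids
}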
12 changes: 6 additions & 6 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -273,7 +273,7 @@ def is_parameter_related(varname, block):

def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
var_shape = block.var(src_var.name).shape
- var_topoloy = src_var_dist_attr.process_mesh.topology
+ var_topoloy = src_var_dist_attr.process_mesh.shape
var_dims_mapping = src_var_dist_attr.dims_mapping

complete_shape = []
@@ -285,7 +285,7 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
complete_shape.append(new_shape)

exact_shape = []
- input_topology = op_input_dist_attr.process_mesh.topology
+ input_topology = op_input_dist_attr.process_mesh.shape
input_dims_mapping = op_input_dist_attr.dims_mapping
for idx, shape in enumerate(complete_shape):
if input_dims_mapping[idx] == -1:
@@ -360,10 +360,10 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):

op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op)
process_mesh = op_dist_attr.process_mesh
- mesh_shape = process_mesh.topology
+ mesh_shape = process_mesh.shape
# FIXME Hack for Pipeline Parallelism where the current operator
# not belong to the mesh the current rank belong to.
- if rank not in process_mesh.processes:
+ if rank not in process_mesh.process_ids:
rank = _get_corresponding_rank(dist_ctx, process_mesh, rank)

for var_name in act_grad_names:
@@ -374,8 +374,8 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):

if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
group_ranks = _get_comm_group(
- process_mesh.processes,
- process_mesh.topology,
+ process_mesh.process_ids,
+ process_mesh.shape,
batch_size_axis,
rank,
)
(additional changed file; file name not shown)
@@ -87,7 +87,7 @@ def backward(ctx, *args, **kwargs):
str(backward_op)
)

- assert rank_id in dist_attr.process_mesh.processes
+ assert rank_id in dist_attr.process_mesh.process_ids

assert 'X' in kwargs, "input [{}] is not given".format('X')
assert 'Scale' in kwargs, "input [{}] is not given".format('Scale')
@@ -118,7 +118,7 @@ def backward(ctx, *args, **kwargs):
rank_id
in ctx.get_tensor_dist_attr_for_program(
main_block.var(varname)
- ).process_mesh.processes
+ ).process_mesh.process_ids
):
filter_vars.append(varname)

(diffs for the remaining changed files were not loaded)