[Auto Parallel] Update rule based tuner (#51908)
* add patterns

* update rule based tuner

* add forward sub program completion

* add unittest

* add bwd sub program completion
Caozhou1995 authored Mar 23, 2023
1 parent 13b8b5e commit 325fdf1
Showing 12 changed files with 806 additions and 29 deletions.
11 changes: 9 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -64,6 +64,7 @@ def __init__(
         fetch_vars={},
         cluster=None,
         strategy=None,
+        json_config=None,
     ):
         # Data members related to original programs (unchanged)
         self._original_serial_main_program = serial_main_prog
@@ -129,6 +130,8 @@ def __init__(
         # A flag indicates whether the used parallelism is data parallel
         self._data_parallel = False
 
+        self._json_config = json_config
+
     @property
     def serial_main_program(self):
         return self._serial_main_program
@@ -181,6 +184,10 @@ def serial_ordered_nodes(self):
     def process_meshes(self):
         return self._process_meshes
 
+    @process_meshes.setter
+    def process_meshes(self, val):
+        self._process_meshes = val
+
     @property
     def pass_context(self):
         return self._pass_context
@@ -397,7 +404,7 @@ def _restore(
         if dist:
             self._restore_dist_info(dist_mode)
 
-    def initialize(self, with_graph=True, with_cpp=False):
+    def initialize(self, with_graph=True, with_cpp=False, no_default=False):
         if not self._is_initialized:
             if not self._serial_main_program:
                 if self._original_serial_main_program:
@@ -418,7 +425,7 @@ def initialize(self, with_graph=True, with_cpp=False):
             if not self._serial_fetch_vars:
                 self._restore_serial_fetch_vars()
 
-            self._init_dist_attr_for_program()
+            self._init_dist_attr_for_program(no_default)
             # Backup the original distributed information for later restore
             self._original_dist_tensors_for_program = copy.deepcopy(
                 self._dist_tensors_for_program
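The dist_context.py changes give the rule-based tuner three hooks: a json_config stored on the context, a writable process_meshes property, and a no_default flag on initialize() that skips the default dist-attr completion. A minimal, self-contained sketch of that interface follows; MockDistContext and its print statement are illustrative stand-ins, not Paddle's actual DistributedContext.

# Schematic mock of the hooks added above -- not the real DistributedContext.
class MockDistContext:
    def __init__(self, json_config=None):
        # New: optional JSON-style config handed in by the tuner.
        self._json_config = json_config
        self._process_meshes = []
        self._is_initialized = False

    @property
    def process_meshes(self):
        return self._process_meshes

    # New: the tuner can now overwrite the mesh list wholesale.
    @process_meshes.setter
    def process_meshes(self, val):
        self._process_meshes = val

    def initialize(self, with_graph=True, with_cpp=False, no_default=False):
        # New flag: when no_default is True, default distributed attributes
        # are not filled in, so the tuner's own annotations survive.
        if not self._is_initialized:
            self._init_dist_attr_for_program(no_default)
            self._is_initialized = True

    def _init_dist_attr_for_program(self, no_default=False):
        print("init dist attrs, no_default =", no_default)


ctx = MockDistContext(json_config={"tuner": "rule_based"})
ctx.process_meshes = [[0, 1], [2, 3]]  # uses the new setter
ctx.initialize(with_graph=False, no_default=True)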
@@ -174,7 +174,6 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
                             varname
                         )
                         mesh_shape = process_mesh.shape
-                        batch_size_axis = var_dim_mapping[0]
                         parallel_axis = batch_size_axis
                         attrs = {"use_calc_stream": True}
                         var_names = [varname + "@GRAD"]
@@ -278,6 +278,12 @@ def is_input_compatible(self, dist_op):
         for mapping in ids_dims_mapping[1:]:
             if is_dim_shard(mapping):
                 return False
+
+        if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard(
+            w_dims_mapping[-2]
+        ):
+            if ids_dims_mapping[0] == w_dims_mapping[-2]:
+                return False
         return True
 
     def is_output_compatible(self, dist_op):
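The added guard rejects layouts in which the row axis of ids and the row (vocabulary) axis of the embedding weight are sharded along the same mesh axis. Below is a standalone restatement of just that predicate, with invented dims mappings; is_dim_shard here mirrors the usual helper (an entry other than -1 means the axis is sharded).

# Standalone sketch of the compatibility rule added above.
def is_dim_shard(mapping):
    return mapping != -1


def embedding_inputs_compatible(ids_dims_mapping, w_dims_mapping):
    # Existing rule (unchanged): no ids axis after the first may be sharded.
    for mapping in ids_dims_mapping[1:]:
        if is_dim_shard(mapping):
            return False
    # New rule: the ids row axis and the weight row axis must not be
    # sharded along the same mesh axis.
    if is_dim_shard(ids_dims_mapping[0]) and is_dim_shard(w_dims_mapping[-2]):
        if ids_dims_mapping[0] == w_dims_mapping[-2]:
            return False
    return True


# Hypothetical dims mappings on a 2-D mesh: batch sharded on mesh axis 0.
print(embedding_inputs_compatible([0, -1], [-1, -1]))  # True
print(embedding_inputs_compatible([0, -1], [0, -1]))   # False: same mesh axis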
@@ -1507,7 +1507,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
         processes = process_mesh.process_ids
         # col parallel: matmul + allreduce
         if backward_op.attr("trans_y"):
-            Y_var_dim_mapping.reverse()
+            Y_var_dim_mapping = list(reversed(Y_var_dim_mapping))
         assert Y_var_dim_mapping[0] < 0
         parallel_axis = Y_var_dim_mapping[1]
 
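The one-line change replaces an in-place list.reverse() with a reversed copy, presumably so the dims mapping handed to the cost code is not mutated as a side effect of estimating backward cost. A tiny pure-Python illustration of the difference:

# reverse() mutates the list it is called on; list(reversed(...)) leaves it alone.
dims_mapping = [-1, 0]              # pretend this is held by the op's dist_attr

alias = dims_mapping
alias.reverse()                     # old code path
print(dims_mapping)                 # [0, -1] -- shared state silently changed

dims_mapping = [-1, 0]              # reset for comparison
reversed_copy = list(reversed(dims_mapping))   # new code path
print(reversed_copy)                # [0, -1]
print(dims_mapping)                 # [-1, 0] -- original left intact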
87 changes: 87 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_scale.py
@@ -12,10 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
+
+from ..cost import (
+    _g_op_cost_factory,
+    build_comp_costs_from_descs,
+    build_comp_desc_from_dist_op,
+    build_dp_costs,
+)
 from ..utils import compute_compatible_and_update_dim_mapping
 from .common import (
     DistributedOperatorImpl,
     DistributedOperatorImplContainer,
+    is_parameter_related,
     register_distributed_operator_impl,
     register_distributed_operator_impl_container,
 )
@@ -42,6 +51,84 @@ def __init__(self, name):
     def is_input_compatible(self, dist_op):
         return True
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(
+            dist_op=dist_op, dist_context=ctx
+        )
+        processes = dist_op.dist_attr.process_mesh.process_ids
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
+        )
+        res_cost = [cost_mapping]
+
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(
+            dist_op=dist_op, dist_context=ctx
+        )
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.process_ids
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster
+        )
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                    varname, main_block
+                ):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.shape
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                        varname, main_block
+                    ):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname
+                        )
+                        mesh_shape = process_mesh.shape
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(
+                            res,
+                            dist_op,
+                            ctx,
+                            var_names,
+                            attrs,
+                            parallel_axis,
+                            cluster,
+                        )
+        return res
+
     def is_output_compatible(self, dist_op):
         return True
 
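The calc_bwd_cost added above follows the pattern shared by other dist op implementations: scan the non-gradient, non-parameter inputs for a sharded batch axis, and if that axis maps to a mesh dimension larger than one, append data-parallel allreduce costs for every parameter gradient. The core decision is small enough to restate on its own; the mesh shape and dims mappings below are invented for illustration.

# Distilled sketch of the need_gradient_allreduce decision in calc_bwd_cost.
def needs_gradient_allreduce(activation_dims_mappings, mesh_shape):
    """Return True when any activation's batch axis (dims_mapping[0]) is
    sharded across a mesh dimension of size > 1, i.e. data parallelism is
    active and parameter gradients must be allreduced."""
    for dims_mapping in activation_dims_mappings:
        batch_size_axis = dims_mapping[0]
        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
            return True
    return False


mesh_shape = [2, 4]                                          # hypothetical 2x4 mesh
print(needs_gradient_allreduce([[0, -1, -1]], mesh_shape))   # True: batch sharded
print(needs_gradient_allreduce([[-1, -1, -1]], mesh_shape))  # False: batch replicated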