Skip to content

Commit

Permalink
[AutoParallel] fix engine _build and cost method (#47263)
Browse files Browse the repository at this point in the history
* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest

Co-authored-by: caozhou <[email protected]>
  • Loading branch information
zhaoyinglia and Caozhou1995 authored Oct 28, 2022
1 parent 26c419c commit 315ef26
Show file tree
Hide file tree
Showing 7 changed files with 634 additions and 188 deletions.
38 changes: 38 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@ def calc_time(self):
return 0


@register_op_cost
class ArgsortOpCost(CompOpCost):
    """Computation cost entry registered for the ``argsort`` op type.

    Both ``calc_flops`` and ``calc_time`` are placeholders that currently
    return 0.
    """

    OP_TYPE = "argsort"

    def __init__(self, op=None, op_desc=None, cluster=None):
        # Forward every argument unchanged to the CompOpCost initializer.
        super().__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, calc_time and calc_flops need to be overridden.
    def calc_flops(self):
        """Return the FLOPs estimate for this op (placeholder: always 0)."""
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        """Return the time estimate for this op (placeholder: always 0)."""
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
Expand Down Expand Up @@ -338,6 +357,25 @@ def calc_time(self):
return 0


@register_op_cost
class EqualOpCost(CompOpCost):
    """Computation cost entry registered for the ``equal`` op type.

    Both ``calc_flops`` and ``calc_time`` are placeholders that currently
    return 0.
    """

    OP_TYPE = "equal"

    def __init__(self, op=None, op_desc=None, cluster=None):
        # Forward every argument unchanged to the CompOpCost initializer.
        super().__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, calc_time and calc_flops need to be overridden.
    def calc_flops(self):
        """Return the FLOPs estimate for this op (placeholder: always 0)."""
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        """Return the time estimate for this op (placeholder: always 0)."""
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class EmbeddingOpCost(CompOpCost):
OP_TYPE = "c_embedding"
Expand Down
28 changes: 14 additions & 14 deletions python/paddle/distributed/auto_parallel/cost/estimate_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,12 @@ def pretty_print_cost(self):

def get_cost_from_engine(engine, mode):
from ..utils import to_list
import copy

# Construct cost estimator by original main program
serial_main_prog = (
engine._serial_main_progs[mode].clone()
if mode in engine._serial_main_progs
engine._fwd_main_progs[mode].clone()
if mode in engine._fwd_main_progs
else engine._orig_main_prog.clone()
)

Expand All @@ -566,37 +567,36 @@ def get_cost_from_engine(engine, mode):
)
else engine._losses
)

if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
serial_optimizer = copy.deepcopy(engine._orig_optimizer)
if mode in engine._fwd_dist_contexts:
dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
else:
from ..completion import Completer
from ..dist_context import DistributedContext

dist_context = DistributedContext(
serial_main_prog,
serial_startup_prog,
engine._optimizer,
serial_optimizer,
losses,
{},
{"loss": losses},
engine._cluster,
engine._strategy,
)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program
)
from ..completion import Completer

completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program
)

if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer

# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
Expand Down
Loading

0 comments on commit 315ef26

Please sign in to comment.