[AutoParallel] Unify VPP Parallel (#60868)
* vpp ready

* add unit test

* update unit test and corner cases

* fix typos

* fixed bug
JZ-LIANG authored Jan 19, 2024
1 parent 1034f0e commit 6a8c3c5
Showing 9 changed files with 503 additions and 78 deletions.
@@ -189,7 +189,8 @@ phi::DeviceContext* GetDistTensorDeviceContext(
phi::distributed::DistTensor* input) {
// TODO(GhostScreaming): pipeline parallel may create an undefined middle grad
// tensor. In such case, we need to get default place.
auto place = input && input->defined() ? input->place() : GetDefaultPlace();
auto place =
input && input->initialized() ? input->place() : GetDefaultPlace();
return phi::DeviceContextPool::Instance().Get(place);
}
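Note: the hunk above replaces defined() with initialized() because, under pipeline parallel, an intermediate grad tensor may exist without any storage, and only an initialized tensor has a readable place. A rough dygraph-side sketch of the same guard, assuming the private helper paddle.framework._current_expected_place() for the fallback (the helper and layout are assumptions, not part of this commit):

import paddle

def device_place_for(tensor):
    # An uninitialized (or missing) tensor has no usable place; fall back to the default.
    if tensor is not None and tensor._is_initialized():
        return tensor.place
    return paddle.framework._current_expected_place()

x = paddle.to_tensor([1.0, 2.0])
print(device_place_for(x))     # place of x
print(device_place_for(None))  # default place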

57 changes: 21 additions & 36 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1043,41 +1043,21 @@ def __init__(
# convert dygraph model to static model
if isinstance(loader, ShardDataloader):
(
inputs_spec,
labels_spec,
self._engine._inputs_spec,
self._engine._labels_spec,
) = self._engine._prepare_data_spec_from_dataloader(loader)
else:
batch_size = loader.batch_sampler.batch_size
inputs_spec, labels_spec = self._engine._prepare_data_spec(
(
self._engine._inputs_spec,
self._engine._labels_spec,
) = self._engine._prepare_data_spec(
loader.dataset, None, batch_size
)

if optimizer is not None and loss is not None:
# get the static graph in train mode
self._engine.prepare(
copy.deepcopy(inputs_spec),
copy.deepcopy(labels_spec),
mode="train",
init_parameters=False,
)
if loss is not None:
# get the static graph in eval mode
self._engine.prepare(
copy.deepcopy(inputs_spec),
copy.deepcopy(labels_spec),
mode="eval",
init_parameters=False,
)
# get the static graph in predict mode
self._engine.prepare(
copy.deepcopy(inputs_spec),
None,
mode="predict",
init_parameters=False,
)
# paddle.enable_static() will be called implicitly in self._engine.prepare.
# call paddle.disable_static to keep the outside of DistModel in dynamic graph mode
paddle.disable_static()

# set the default mode
if optimizer is not None and loss is not None:
self.train()
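Note: with this hunk the constructor only records the data specs on the engine (self._engine._inputs_spec / self._engine._labels_spec) instead of eagerly preparing the train, eval and predict graphs; preparation now happens lazily in train(), eval() and predict() below. For orientation, the stored specs are per-batch descriptions of the kind paddle.static.InputSpec expresses; the shapes and names in this sketch are placeholders, not taken from the commit:

from paddle.static import InputSpec

# Placeholder specs of the kind the engine keeps for later deep-copies.
inputs_spec = [InputSpec(shape=[None, 1024], dtype='int64', name='input_ids')]
labels_spec = [InputSpec(shape=[None, 1024], dtype='int64', name='labels')]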
@@ -1093,23 +1073,23 @@ def train(self):
parameters of the model and return the loss.
"""
if not self._engine._has_prepared["train"]:
raise RuntimeError(
"The model for training has not been prepared, please set 'loss' and 'optimizer' when constructing DistModel."
)
self._engine._prepare_program(mode="train", init_parameters=False)

self._mode = "train"
self._engine.to_mode("train")
paddle.disable_static()

def eval(self):
"""
Set the mode of DistModel to "eval". In "eval" mode,
executing ``__call__`` will return the loss.
"""
if not self._engine._has_prepared["eval"]:
raise RuntimeError(
"The model for evaluation has not been prepared, please set 'loss' when constructing DistModel."
)
self._engine._prepare_program(mode="eval", init_parameters=False)

self._mode = "eval"
self._engine.to_mode("eval")
paddle.disable_static()

def predict(self):
"""
@@ -1118,11 +1098,16 @@ def predict(self):
outputs of the model.
"""
if not self._engine._has_prepared["predict"]:
raise RuntimeError(
"The model for prediction has not been prepared."
self._engine.prepare(
copy.deepcopy(self._engine._inputs_spec),
None,
mode="predict",
init_parameters=False,
)

self._mode = "predict"
self._engine.to_mode("predict")
paddle.disable_static()
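Note: predict mode is now built on demand from the stored input spec (with labels dropped), so a DistModel can switch modes at run time without having prepared every graph up front. A hedged usage sketch of that flow, meant to be launched with paddle.distributed.launch on two ranks; the network, dataset, mesh and sharding choices below are assumptions for illustration, not the commit's own test setup:

import numpy as np
import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.io import DataLoader, Dataset

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class RandomDataset(Dataset):
    def __init__(self, n=8):
        self.x = np.random.rand(n, 8).astype('float32')
        self.y = np.random.rand(n, 8).astype('float32')
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    def __len__(self):
        return len(self.x)

class DemoNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        # Shard the weight column-wise across the 2-rank mesh (illustrative choice).
        self.fc.weight = dist.shard_tensor(
            self.fc.weight, mesh, [dist.Shard(1)], stop_gradient=False
        )
    def forward(self, x):
        return self.fc(x)

layer = DemoNet()
loader = DataLoader(RandomDataset(), batch_size=2)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
dist_model = dist.to_static(layer, loader, nn.MSELoss(), opt)

dist_model.train()                   # the "train" program is prepared on first use
for inputs, labels in loader:
    loss = dist_model(inputs, labels)

dist_model.predict()                 # with this commit, "predict" is prepared here on demand
outputs = dist_model(inputs)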

def __validate_mode(self, mode):
if mode is None and self._mode is None:
21 changes: 13 additions & 8 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -825,14 +825,19 @@ def _plan(self, mode):
if var.name in block.vars:
feed_list.append(block.vars[var.name])

self._dp_world_sizes = []
self._dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = auto_utils.get_input_split_info(
self._cur_rank, feed_var, self._dist_contexts[mode]
)
self._dp_world_sizes.append(dp_world_size)
self._dp_ranks.append(dp_rank)
self._dp_world_sizes = getattr(self, "_dp_world_sizes", [])
self._dp_ranks = getattr(self, "_dp_ranks", [])
if mode in ['eval', 'predict'] or (
not self._dp_world_sizes and not self._dp_ranks
):
self._dp_world_sizes = []
self._dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = auto_utils.get_input_split_info(
self._cur_rank, feed_var, self._dist_contexts[mode]
)
self._dp_world_sizes.append(dp_world_size)
self._dp_ranks.append(dp_rank)
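Note: the data-parallel split info is no longer recomputed and overwritten on every _plan call; it is rebuilt only for "eval"/"predict" or when nothing has been recorded yet, so the values derived for an earlier mode survive later planning passes. These (dp_world_size, dp_rank) pairs drive how each feed variable's global batch is sliced per data-parallel rank; a rough stand-alone illustration of that use (the helper below is invented for the example):

import numpy as np

def slice_batch_for_dp_rank(batch, dp_world_size, dp_rank):
    # Split the global batch (leading dim) evenly across data-parallel ranks;
    # even divisibility is assumed for this sketch.
    per_rank = batch.shape[0] // dp_world_size
    return batch[dp_rank * per_rank:(dp_rank + 1) * per_rank]

global_batch = np.arange(8).reshape(8, 1)
print(slice_batch_for_dp_rank(global_batch, dp_world_size=2, dp_rank=1))  # rows 4..7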

def _parallel(self, mode, all_ranks=False):
# Parallelize program based on the planner's results
37 changes: 36 additions & 1 deletion python/paddle/distributed/auto_parallel/static/helper.py
@@ -16,6 +16,7 @@
import logging
from collections import defaultdict

import paddle
from paddle.jit import not_to_static, to_static
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
@@ -30,6 +31,7 @@
)

from .converter import Converter
from .process_group import get_world_process_group
from .utils import get_logger, to_list


@@ -251,7 +253,9 @@ def build_program(self, mode):

# NOTE(dev): Because @to_static is a Lazy mechanism, so we explicitly call this to trigger
# generating Program IR immediately.
concrete_program = getattr(self.proxy_layer, func_name).concrete_program
concrete_program = getattr(
self.proxy_layer, func_name
).concrete_program # noqa: B018
prepare_op_amp_options(
concrete_program.main_program,
ProgramTranslator.get_instance()._amp_records,
@@ -335,8 +339,26 @@ def static_func(self):
def init(self, main_program, place, dist_context):
if self.lazy_init:
return

is_comm = False
for param in self.concrete_program.parameters:
if param.is_dist():
serial_main_program = self.concrete_program.main_program
var = serial_main_program.global_block().vars[param.name]
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
var
)
is_comm = True
tmp = paddle.base.core.reshard(param, var_dist_attr)
if tmp._is_initialized():
param.get_tensor()._share_data_with(tmp.get_tensor())
else:
param = None
paddle.device.synchronize()

# create var in scope and share parameters to scope
if param is None:
continue
if param.name not in main_program.global_block().vars:
continue
if param.is_dense():
@@ -361,6 +383,19 @@ def init(self, main_program, place, dist_context):
dense_tensor = global_scope().var(param.name).get_tensor()
dense_tensor._share_data_with(param.get_tensor().get_tensor())

world_group = get_world_process_group()
if (
is_comm
and world_group.nranks > 1
and paddle.distributed.get_world_size() > 1
):
paddle.disable_static()
barrier_tensor = paddle.full([1], 1, dtype="int32")
paddle._legacy_C_ops.barrier(
barrier_tensor, barrier_tensor, 'ring_id', 0
)
paddle.enable_static()

@property
def concrete_program(self):
return self.static_func().concrete_program
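Note: init() now reshards each distributed parameter to the dist_attr recorded for its variable in the serial program before sharing it into the scope, skips parameters that end up with no local storage on this rank, and, when any communication happened and the world size is greater than one, finishes with a barrier so all ranks stay in step. The sketch below shows the resharding step through the public API with a placeholder two-rank mesh; run it under paddle.distributed.launch with two processes, and treat the mesh, shape and placements as assumptions rather than the model's real layout:

import paddle
import paddle.distributed as dist

# Placeholder 2-rank mesh; launch with:
#   python -m paddle.distributed.launch --nproc_per_node=2 reshard_sketch.py
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

w = paddle.randn([8, 8])
# Shard the tensor along dim 0, then reshard it to a replicated layout --
# the public counterpart of the paddle.base.core.reshard call used above.
w_sharded = dist.shard_tensor(w, mesh, [dist.Shard(0)])
w_replicated = dist.reshard(w_sharded, mesh, [dist.Replicate()])
print(w_replicated.shape)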
8 changes: 8 additions & 0 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -42,3 +42,11 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_cross_mesh_reshard PROPERTIES TIMEOUT "120" LABELS
"RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_llama_model_vpp MODULES
test_semi_auto_parallel_llama_model_vpp ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_semi_auto_parallel_llama_model_vpp
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
