enable ir pass for auto_parallel (PaddlePaddle#58614)
* enable ir pass

* add test

* fix pass_list

* refine code

* update flags

* update pir

* skip test if not a100

* hotfix ut
zhiqiu authored Nov 8, 2023
1 parent 1255d56 commit d659b58
Showing 8 changed files with 111 additions and 12 deletions.
12 changes: 10 additions & 2 deletions paddle/fluid/framework/new_executor/interpreter/plan.cc
@@ -66,6 +66,14 @@ const std::vector<std::shared_ptr<Job>>& Plan::JobList() const {
  return job_list_;
}

const std::vector<std::string> Plan::JobTypes() const {
  std::vector<std::string> res;
  for (auto kv : type_to_ir_program_) {
    res.emplace_back(kv.first);
  }
  return res;
}

const std::shared_ptr<ProgramDesc> Plan::Program(
    const std::string& job_type) const {
  return type_to_program_.at(job_type);
@@ -76,8 +84,8 @@ std::shared_ptr<::pir::Program> Plan::IrProgram(
  return type_to_ir_program_.at(job_type);
}

void Plan::UpdateIrProgram(const std::string& job_type,
                           std::shared_ptr<::pir::Program> ir_prog) {
void Plan::SetIrProgram(const std::string& job_type,
                        std::shared_ptr<::pir::Program> ir_prog) {
  type_to_ir_program_[job_type] = ir_prog;
}

5 changes: 3 additions & 2 deletions paddle/fluid/framework/new_executor/interpreter/plan.h
@@ -40,12 +40,13 @@ class Plan final {
  ~Plan() = default;

  const std::vector<std::shared_ptr<Job>>& JobList() const;
  const std::vector<std::string> JobTypes() const;

  const std::shared_ptr<ProgramDesc> Program(const std::string& job_type) const;
  std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const;

  void UpdateIrProgram(const std::string& job_type,
                       std::shared_ptr<::pir::Program> ir_prog);
  void SetIrProgram(const std::string& job_type,
                    std::shared_ptr<::pir::Program> ir_prog);

  int64_t MicroBatchNum() const;

2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/standalone_executor.cc
@@ -106,7 +106,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
      auto kernel_program =
          paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place);
      std::shared_ptr<pir::Program> shared_program = std::move(kernel_program);
      plan_.UpdateIrProgram("job_" + std::to_string(job_idx), shared_program);
      plan_.SetIrProgram("job_" + std::to_string(job_idx), shared_program);

      if (FLAGS_pir_apply_inplace_pass) {
        pir::PassManager pm(pir::IrContext::Instance(), 3);
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -2046,7 +2046,10 @@ All parameter, weight, gradient are variables in Paddle.
           py::arg("job_list"),
           py::arg("type_to_ir_program"))
      .def("job_list", &framework::interpreter::Plan::JobList)
      .def("job_types", &framework::interpreter::Plan::JobTypes)
      .def("micro_batch_num", &framework::interpreter::Plan::MicroBatchNum)
      .def("set_ir_program", &framework::interpreter::Plan::SetIrProgram)
      .def("ir_program", &framework::interpreter::Plan::IrProgram)
      .def("program", &framework::interpreter::Plan::Program);

  m.def("init_gflags", framework::InitGflags);
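With these bindings, a Plan can be inspected and rewritten from Python. A minimal sketch of how the new methods compose (assuming `plan` is a `core.Plan` already built by the executor; the loop body is illustrative, not part of this commit):

for job_type in plan.job_types():              # e.g. "job_0", "job_1", ...
    ir_program = plan.ir_program(job_type)     # PIR program for this job
    # ... transform ir_program with PIR passes ...
    plan.set_ir_program(job_type, ir_program)  # store the updated program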
14 changes: 14 additions & 0 deletions python/paddle/base/executor.py
@@ -21,6 +21,8 @@

import numpy as np

from paddle import pir

from ..pir import OpResult
from ..pir import Program as PirProgram
from ..pir import Value, translate_to_pir
@@ -1036,6 +1038,18 @@ def _get_program_and_executor(self, cached_data):
        type_to_program = {"default": new_program.desc}
        plan = core.Plan([default_job], type_to_program)

        if (
            new_program._pass_opt
            and "pass_list" in new_program._pass_opt
            and len(new_program._pass_opt['pass_list']) > 0
        ):
            pm = pir.PassManager()
            for p in new_program._pass_opt['pass_list']:
                pm.add_pass(p)
            for job_type in plan.job_types():
                ir_program = plan.ir_program(job_type)
                pm.run(ir_program)

        new_exe = _StandaloneExecutor(place, plan, scope)
        return new_program, new_exe

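The hunk above makes the executor honor a program-level pass list: when `_pass_opt['pass_list']` is non-empty, each named pass is added to a `pir.PassManager` and run over the IR program of every job in the plan. A hedged sketch of how a program opts in (normally the auto_parallel parallelizer below sets this attribute, not user code; the pass name here is illustrative):

import paddle

paddle.enable_static()
# Assumption: the PIR executor flag must be on for the pass list to take effect.
paddle.set_flags({'FLAGS_enable_pir_in_executor': True})

main_program = paddle.static.default_main_program()
# Executor._get_program_and_executor reads this attribute and applies the
# listed PIR passes before building the standalone executor.
main_program._pass_opt = {'pass_list': ['fused_gemm_epilogue_pass']}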
4 changes: 3 additions & 1 deletion python/paddle/base/framework.py
@@ -5676,6 +5676,7 @@ def __init__(self):

        # assigned if this program has been parsed by a pipeline optimizer
        self._pipeline_opt = None
        self._pass_opt = None

        # assigned if this program has been parsed by a heter pipeline parameter server optimizer
        self._heter_pipeline_opt = None
@@ -6313,7 +6314,8 @@ def clone(self, for_test=False):
        p.lr_scheduler = self.lr_scheduler
        if hasattr(self, '_pipeline_opt'):
            p._pipeline_opt = self._pipeline_opt

        if hasattr(self, '_pass_opt'):
            p._pass_opt = self._pass_opt
        # NOTE(zhiqiu): we sync the cloned program, to update its program by
        # its desc.
        p._sync_with_cpp()
22 changes: 19 additions & 3 deletions python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -18,6 +18,7 @@
import time

from paddle.distributed.passes import PassManager, new_pass
from paddle.framework import get_flags
from paddle.static import append_backward, program_guard
from paddle.utils import unique_name

@@ -28,6 +29,12 @@
from .reshard import Resharder
from .utils import get_pp_stage, is_sequential_run, use_new_executor

NEW_IR_PASS = [
    'fused_gemm_epilogue_pass',
    'fused_linear_param_grad_add_pass',
    'fused_dropout_add_pass',
]


class Parallelizer:
    def __init__(self, mode, completer, dist_context):
@@ -423,14 +430,24 @@ def _apply_post_optimization(
                [main_program], [startup_program], self._pass_context
            )

        enable_ir = get_flags("FLAGS_enable_pir_in_executor")[
            'FLAGS_enable_pir_in_executor'
        ]
        ir_pass_list = []
        if self.is_train and self._strategy.fused_passes.enable:
            if len(self._strategy.fused_passes.fused_passes_list) > 0:
                new_pass_list = []
                for op in self._strategy.fused_passes.fused_passes_list:
                    new_pass_list.append(new_pass(op))
                for p in self._strategy.fused_passes.fused_passes_list:
                    if p in NEW_IR_PASS and enable_ir:
                        ir_pass_list.append(p)
                    else:
                        new_pass_list.append(new_pass(p))
                pass_manager = PassManager(new_pass_list)
                pass_manager.apply([main_program], [startup_program])

        main_program._pass_opt = {}
        main_program._pass_opt['pass_list'] = ir_pass_list

        if (
            self.is_train
            and self._strategy.pipeline.enable
@@ -448,7 +465,6 @@ def _apply_post_optimization(
                "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance "
                "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance."
            )

            main_program._pipeline_opt = {}
            main_program._pipeline_opt["standalone_opt"] = {
                "enable_send_recv_overlap": enable_send_recv_overlap,
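Net effect of this hunk: fused passes whose names appear in NEW_IR_PASS are skipped by the legacy PassManager when FLAGS_enable_pir_in_executor is on, and are instead recorded in main_program._pass_opt['pass_list'] for the executor to apply as PIR passes. A hedged configuration sketch, mirroring the test below (import path and flag handling are assumptions, not taken from this diff):

import paddle
from paddle.distributed.fleet import auto  # assumed import path for the auto_parallel API

paddle.set_flags({'FLAGS_enable_pir_in_executor': True})

strategy = auto.Strategy()
strategy.auto_mode = "semi"
fused_passes = strategy.fused_passes
fused_passes.enable = True
# 'fused_gemm_epilogue_pass' is in NEW_IR_PASS, so it is deferred to the
# PIR PassManager inside the executor instead of the legacy pass path.
fused_passes.fused_passes_list = ['fused_gemm_epilogue_pass']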
61 changes: 58 additions & 3 deletions test/auto_parallel/gpt_with_pir.py
@@ -18,6 +18,7 @@

import numpy as np
from get_gpt_model import FakeDataset, generate_model
from test_sparse_addmm_op import get_cuda_version

import paddle
from paddle.distributed import ParallelEnv
@@ -26,7 +27,7 @@
paddle.enable_static()


def apply_pass(use_sharding=False, pipeline_mode=None):
def apply_pass(use_sharding=False, pipeline_mode=None, fuse_passes_list=None):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
@@ -57,6 +58,11 @@ def apply_pass(use_sharding=False, pipeline_mode=None):
        pipeline.schedule_mode = pipeline_mode
        pipeline.accumulate_steps = 2

    if fuse_passes_list:
        fused_passes = strategy.fused_passes
        fused_passes.enable = True
        fused_passes.fused_passes_list = fuse_passes_list

    return strategy


@@ -87,10 +93,19 @@ def init(self, engine, name):
        place = paddle.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, mode, name, use_sharding=False, pipeline_mode=None):
    def get_engine(
        self,
        mode,
        name,
        use_sharding=False,
        pipeline_mode=None,
        fuse_passes_list=None,
    ):
        reset_prog()

        strategy = apply_pass(use_sharding, pipeline_mode)
        paddle.set_default_dtype('float32')

        strategy = apply_pass(use_sharding, pipeline_mode, fuse_passes_list)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model(mode, dropout_prob=0.1)
@@ -131,6 +146,46 @@ def test_dp(self):
            out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0]
        )

    def test_dp_with_fused_linear(self):
        if not get_cuda_version() >= 11060:
            return

        self.enable_pir(False)
        engine_dp_prog = self.get_engine(
            "dp",
            name="dp_prog_fuse_linear",
            fuse_passes_list=['fuse_gemm_epilogue'],
        )
        out_dp_prog = engine_dp_prog.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.enable_pir(True)
        engine_dp_ir = self.get_engine(
            "dp",
            name="dp_pir_fuse_linear",
            use_sharding=True,
            fuse_passes_list=['fused_gemm_epilogue_pass'],
        )
        out_dp_ir = engine_dp_ir.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        # TODO(zhiqiu): fix accuracy problem and use array_equal to check it
        np.testing.assert_allclose(
            out_dp_prog.history["loss"][0],
            out_dp_ir.history["loss"][0],
            rtol=1e-5,
            err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
                __class__,
                out_dp_prog.history["loss"][0],
                out_dp_ir.history["loss"][0],
                out_dp_prog.history["loss"][0] - out_dp_ir.history["loss"][0],
            ),
        )
        # self.check_results(
        #     out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0]
        # )

    def test_mp(self):
        self.enable_pir(False)
        engine_mp_prog = self.get_engine("mp", name="mp_prog")
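The test flips executors with self.enable_pir(...), a helper defined elsewhere in this file and not shown in the hunk. A plausible sketch of such a helper, assuming the flag used throughout this commit is toggled via paddle.set_flags:

def enable_pir(self, flag):
    # Assumed implementation: switch between the legacy program executor
    # and the PIR-based executor exercised by test_dp_with_fused_linear.
    paddle.set_flags({'FLAGS_enable_pir_in_executor': flag})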
