[Fix] data parallel bugs (xdit-project#249)
Eigensystem authored and feifeibear committed Oct 25, 2024
1 parent 56b1975 commit 346b381
Showing 5 changed files with 60 additions and 46 deletions.
9 changes: 4 additions & 5 deletions examples/sd3_example.py
@@ -7,9 +7,10 @@
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size


def main():
@@ -45,10 +46,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
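
Note (illustration, not part of the commit): get_data_parallel_rank() returns the
data-parallel group index and get_data_parallel_world_size() the number of groups,
so the per-group slice of the batch reduces to a ceil division. A minimal sketch
with made-up sizes:

def dp_slice(batch_size: int, dp_group_index: int, num_dp_groups: int):
    # ceil division: every group handles at most dp_batch_size items
    dp_batch_size = (batch_size + num_dp_groups - 1) // num_dp_groups
    start = dp_group_index * dp_batch_size
    end = min(start + dp_batch_size, batch_size)
    return start, end

# batch_size=5 over 2 groups: group 0 saves images 0-2, group 1 saves images 3-4
assert dp_slice(5, 0, 2) == (0, 3)
assert dp_slice(5, 1, 2) == (3, 5)
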
41 changes: 7 additions & 34 deletions tests/parallel_test.py
@@ -1,9 +1,8 @@
from xfuser.parallel import xdit_parallel
from xfuser.parallel import xDiTParallel

import time
import os
import torch
import torch.distributed as dist
from diffusers import StableDiffusion3Pipeline

from xfuser import xFuserArgs
@@ -22,20 +21,17 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()

local_rank = get_world_group().local_rank
pipe = StableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
torch_dtype=torch.float16,
)
).to(f"cuda:{local_rank}")

pipe = xdit_parallel(pipe, engine_config)
local_rank = get_world_group().local_rank
pipe.to(f"cuda:{local_rank}")

pipe.prepare_run(input_config)
paralleler = xDiTParallel(pipe, engine_config, input_config)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
output = pipe(
paralleler(
height=input_config.height,
width=input_config.height,
prompt=input_config.prompt,
@@ -47,33 +43,10 @@ def main():
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image.save(
f"./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
)
print(
f"image {i} saved to ./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
)
paralleler.save("results/", "stable_diffusion_3")

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
)
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
get_runtime_state().destory_distributed_env()


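
For reference, the timing and peak-memory measurement the test keeps around the
parallel call, pulled out as a standalone sketch (run is a placeholder callable;
the synchronize call is an addition here for accuracy, not part of the test):

import time
import torch

def timed_run(run, device: str):
    # Reset the peak counter so max_memory_allocated reflects only this run.
    torch.cuda.reset_peak_memory_stats(device)
    start_time = time.time()
    run()
    torch.cuda.synchronize(device)  # wait for queued GPU work before stopping the clock
    elapsed_time = time.time() - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=device)
    return elapsed_time, peak_memory / 1e9  # seconds, GB
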
3 changes: 3 additions & 0 deletions xfuser/config/config.py
@@ -231,6 +231,9 @@ def __post_init__(self):
len(self.prompt) == len(self.negative_prompt)
or len(self.negative_prompt) == 0
), "prompts and negative_prompts must have the same quantities"
self.batch_size = self.batch_size or len(self.prompt)
else:
self.batch_size = self.batch_size or 1
assert self.output_type in [
"pil",
"latent",
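
The added default means batch_size no longer has to be set explicitly: it falls
back to the number of prompts when prompts are given, otherwise to 1. Illustratively
(values are made up):

prompt = ["a photo of a cat", "a photo of a dog"]

batch_size = None
batch_size = batch_size or len(prompt)  # unset -> number of prompts (2)
assert batch_size == 2

batch_size = 4
batch_size = batch_size or len(prompt)  # explicit value is kept (4)
assert batch_size == 4
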
4 changes: 3 additions & 1 deletion xfuser/model_executor/pipelines/base_pipeline.py
@@ -117,7 +117,9 @@ def data_parallel_fn(self, *args, **kwargs):
batch_size = len(prompt) if isinstance(prompt, list) else 1
if batch_size > 1:
dp_degree = get_runtime_state().parallel_config.dp_degree
dp_group_rank = get_world_group().rank // get_data_parallel_world_size()
dp_group_rank = get_world_group().rank // (
get_world_group().world_size // get_data_parallel_world_size()
)
dp_group_batch_size = (batch_size + dp_degree - 1) // dp_degree
start_batch_idx = dp_group_rank * dp_group_batch_size
end_batch_idx = min(
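
The extra division matters because get_data_parallel_world_size() returns the number
of data-parallel groups, not the number of ranks inside one group. A worked example,
assuming 8 ranks and dp_degree = 2:

world_size = 8       # total ranks
dp_world_size = 2    # returned by get_data_parallel_world_size()
rank = 5             # a rank belonging to the second data-parallel group

old_index = rank // dp_world_size                   # 5 // 2 = 2, out of range
new_index = rank // (world_size // dp_world_size)   # 5 // (8 // 2) = 1, correct
assert new_index == 1
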
49 changes: 43 additions & 6 deletions xfuser/parallel.py
@@ -1,14 +1,51 @@
import os
from typing import Any, Type, Union
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from xfuser.config.config import InputConfig
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from xfuser.config import EngineConfig
from xfuser.core.distributed.parallel_state import (
get_data_parallel_rank,
get_data_parallel_world_size,
is_dp_last_group,
)
from xfuser.logger import init_logger
from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister

logger = init_logger(__name__)


def xdit_parallel(pipe, engine_config: EngineConfig):
if isinstance(pipe, type):
xfuser_pipe_class = xFuserPipelineWrapperRegister.get_class(pipe)
return xfuser_pipe_class
else:
class xDiTParallel:
def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
return xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
self.config = engine_config
self.pipe.prepare_run(input_config)

def __call__(
self,
*args,
**kwargs,
):
self.result = self.pipe(*args, **kwargs)
return self.result

def save(self, directory: str, prefix: str):
dp_rank = get_data_parallel_rank()
parallel_info = (
f"dp{self.config.parallel_config.dp_degree}_cfg{self.config.parallel_config.cfg_degree}_"
f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_"
f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}"
)
prefix = f"{directory}/{prefix}_result_{parallel_info}_dprank{dp_rank}"
if is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(self.result.images):
image.save(f"{prefix}_image{i}.png")
print(f"{prefix}_image{i}.png")
