diff --git a/examples/sd3_example.py b/examples/sd3_example.py
index 29dc7de7..91149e39 100644
--- a/examples/sd3_example.py
+++ b/examples/sd3_example.py
@@ -7,9 +7,10 @@
 from xfuser.core.distributed import (
     get_world_group,
     is_dp_last_group,
-    get_data_parallel_world_size,
+    get_data_parallel_rank,
     get_runtime_state,
 )
+from xfuser.core.distributed.parallel_state import get_data_parallel_world_size
 
 
 def main():
@@ -45,10 +46,8 @@ def main():
         f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
     )
     if input_config.output_type == "pil":
-        global_rank = get_world_group().rank
-        dp_group_world_size = get_data_parallel_world_size()
-        dp_group_index = global_rank // dp_group_world_size
-        num_dp_groups = engine_config.parallel_config.dp_degree
+        dp_group_index = get_data_parallel_rank()
+        num_dp_groups = get_data_parallel_world_size()
         dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
         if is_dp_last_group():
             if not os.path.exists("results"):
diff --git a/tests/parallel_test.py b/tests/parallel_test.py
index c2912b2c..7992c74f 100644
--- a/tests/parallel_test.py
+++ b/tests/parallel_test.py
@@ -1,9 +1,8 @@
-from xfuser.parallel import xdit_parallel
+from xfuser.parallel import xDiTParallel
 import time
 import os
 
 import torch
-import torch.distributed as dist
 from diffusers import StableDiffusion3Pipeline
 
 from xfuser import xFuserArgs
@@ -22,20 +21,17 @@ def main():
     engine_args = xFuserArgs.from_cli_args(args)
     engine_config, input_config = engine_args.create_config()
 
+    local_rank = get_world_group().local_rank
     pipe = StableDiffusion3Pipeline.from_pretrained(
         pretrained_model_name_or_path=engine_config.model_config.model,
         torch_dtype=torch.float16,
-    )
+    ).to(f"cuda:{local_rank}")
 
-    pipe = xdit_parallel(pipe, engine_config)
-    local_rank = get_world_group().local_rank
-    pipe.to(f"cuda:{local_rank}")
-
-    pipe.prepare_run(input_config)
+    paralleler = xDiTParallel(pipe, engine_config, input_config)
 
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
-    output = pipe(
+    paralleler(
         height=input_config.height,
         width=input_config.height,
         prompt=input_config.prompt,
@@ -47,33 +43,10 @@ def main():
     elapsed_time = end_time - start_time
     peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
 
-    parallel_info = (
-        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
-        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
-        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
-    )
-    if input_config.output_type == "pil":
-        global_rank = get_world_group().rank
-        dp_group_world_size = get_data_parallel_world_size()
-        dp_group_index = global_rank // dp_group_world_size
-        num_dp_groups = engine_config.parallel_config.dp_degree
-        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-        if is_dp_last_group():
-            if not os.path.exists("results"):
-                os.mkdir("results")
-            for i, image in enumerate(output.images):
-                image_rank = dp_group_index * dp_batch_size + i
-                image.save(
-                    f"./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
-                )
-                print(
-                    f"image {i} saved to ./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
-                )
+    paralleler.save("results/", "stable_diffusion_3")
 
     if get_world_group().rank == get_world_group().world_size - 1:
-        print(
-            f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
-        )
+        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
     get_runtime_state().destory_distributed_env()
 
 
diff --git a/xfuser/config/config.py b/xfuser/config/config.py
index 1a15fe65..e2d38edf 100644
--- a/xfuser/config/config.py
+++ b/xfuser/config/config.py
@@ -231,6 +231,9 @@ def __post_init__(self):
                 len(self.prompt) == len(self.negative_prompt)
                 or len(self.negative_prompt) == 0
             ), "prompts and negative_prompts must have the same quantities"
+            self.batch_size = self.batch_size or len(self.prompt)
+        else:
+            self.batch_size = self.batch_size or 1
         assert self.output_type in [
             "pil",
             "latent",
diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py
index ba68a256..a0990efa 100644
--- a/xfuser/model_executor/pipelines/base_pipeline.py
+++ b/xfuser/model_executor/pipelines/base_pipeline.py
@@ -117,7 +117,9 @@ def data_parallel_fn(self, *args, **kwargs):
             batch_size = len(prompt) if isinstance(prompt, list) else 1
             if batch_size > 1:
                 dp_degree = get_runtime_state().parallel_config.dp_degree
-                dp_group_rank = get_world_group().rank // get_data_parallel_world_size()
+                dp_group_rank = get_world_group().rank // (
+                    get_world_group().world_size // get_data_parallel_world_size()
+                )
                 dp_group_batch_size = (batch_size + dp_degree - 1) // dp_degree
                 start_batch_idx = dp_group_rank * dp_group_batch_size
                 end_batch_idx = min(
diff --git a/xfuser/parallel.py b/xfuser/parallel.py
index 9979ce25..4d2bbf89 100644
--- a/xfuser/parallel.py
+++ b/xfuser/parallel.py
@@ -1,14 +1,51 @@
+import os
+from typing import Any, Type, Union
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
+from xfuser.config.config import InputConfig
+from xfuser.core.distributed import (
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from xfuser.config import EngineConfig
+from xfuser.core.distributed.parallel_state import (
+    get_data_parallel_rank,
+    get_data_parallel_world_size,
+    is_dp_last_group,
+)
 from xfuser.logger import init_logger
+from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
 from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister
 
 logger = init_logger(__name__)
 
 
-def xdit_parallel(pipe, engine_config: EngineConfig):
-    if isinstance(pipe, type):
-        xfuser_pipe_class = xFuserPipelineWrapperRegister.get_class(pipe)
-        return xfuser_pipe_class
-    else:
+class xDiTParallel:
+    def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
         xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
-        return xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
+        self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
+        self.config = engine_config
+        self.pipe.prepare_run(input_config)
+
+    def __call__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        self.result = self.pipe(*args, **kwargs)
+        return self.result
+
+    def save(self, directory: str, prefix: str):
+        dp_rank = get_data_parallel_rank()
+        parallel_info = (
+            f"dp{self.config.parallel_config.dp_degree}_cfg{self.config.parallel_config.cfg_degree}_"
+            f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_"
+            f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}"
+        )
+        prefix = f"{directory}/{prefix}_result_{parallel_info}_dprank{dp_rank}"
+        if is_dp_last_group():
+            if not os.path.exists("results"):
+                os.mkdir("results")
+            for i, image in enumerate(self.result.images):
+                image.save(f"{prefix}_image{i}.png")
+                print(f"{prefix}_image{i}.png")