Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] dp issue #249

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/sd3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size


def main():
Expand Down Expand Up @@ -45,10 +46,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
Expand Down
41 changes: 7 additions & 34 deletions tests/parallel_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from xfuser.parallel import xdit_parallel
from xfuser.parallel import xDiTParallel

import time
import os
import torch
import torch.distributed as dist
from diffusers import StableDiffusion3Pipeline

from xfuser import xFuserArgs
Expand All @@ -22,20 +21,17 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()

local_rank = get_world_group().local_rank
pipe = StableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
torch_dtype=torch.float16,
)
).to(f"cuda:{local_rank}")

pipe = xdit_parallel(pipe, engine_config)
local_rank = get_world_group().local_rank
pipe.to(f"cuda:{local_rank}")

pipe.prepare_run(input_config)
paralleler = xDiTParallel(pipe, engine_config, input_config)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
output = pipe(
paralleler(
height=input_config.height,
width=input_config.height,
prompt=input_config.prompt,
Expand All @@ -47,33 +43,10 @@ def main():
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image.save(
f"./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
)
print(
f"image {i} saved to ./results/stable_diffusion_3_result_{parallel_info}_{image_rank}.png"
)
paralleler.save("results/", "stable_diffusion_3")

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
)
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
get_runtime_state().destory_distributed_env()


Expand Down
3 changes: 3 additions & 0 deletions xfuser/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ def __post_init__(self):
len(self.prompt) == len(self.negative_prompt)
or len(self.negative_prompt) == 0
), "prompts and negative_prompts must have the same quantities"
self.batch_size = self.batch_size or len(self.prompt)
else:
self.batch_size = self.batch_size or 1
assert self.output_type in [
"pil",
"latent",
Expand Down
4 changes: 3 additions & 1 deletion xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def data_parallel_fn(self, *args, **kwargs):
batch_size = len(prompt) if isinstance(prompt, list) else 1
if batch_size > 1:
dp_degree = get_runtime_state().parallel_config.dp_degree
dp_group_rank = get_world_group().rank // get_data_parallel_world_size()
dp_group_rank = get_world_group().rank // (
get_world_group().world_size // get_data_parallel_world_size()
)
dp_group_batch_size = (batch_size + dp_degree - 1) // dp_degree
start_batch_idx = dp_group_rank * dp_group_batch_size
end_batch_idx = min(
Expand Down
49 changes: 43 additions & 6 deletions xfuser/parallel.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
import os
from typing import Any, Type, Union
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from xfuser.config.config import InputConfig
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from xfuser.config import EngineConfig
from xfuser.core.distributed.parallel_state import (
get_data_parallel_rank,
get_data_parallel_world_size,
is_dp_last_group,
)
from xfuser.logger import init_logger
from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister

logger = init_logger(__name__)


def xdit_parallel(pipe, engine_config: EngineConfig):
if isinstance(pipe, type):
xfuser_pipe_class = xFuserPipelineWrapperRegister.get_class(pipe)
return xfuser_pipe_class
else:
class xDiTParallel:
def __init__(self, pipe, engine_config: EngineConfig, input_config: InputConfig):
xfuser_pipe_wrapper = xFuserPipelineWrapperRegister.get_class(pipe)
return xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
self.pipe = xfuser_pipe_wrapper(pipeline=pipe, engine_config=engine_config)
self.config = engine_config
self.pipe.prepare_run(input_config)

def __call__(
self,
*args,
**kwargs,
):
self.result = self.pipe(*args, **kwargs)
return self.result

def save(self, directory: str, prefix: str):
dp_rank = get_data_parallel_rank()
parallel_info = (
f"dp{self.config.parallel_config.dp_degree}_cfg{self.config.parallel_config.cfg_degree}_"
f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_"
f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}"
)
prefix = f"{directory}/{prefix}_result_{parallel_info}_dprank{dp_rank}"
if is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(self.result.images):
image.save(f"{prefix}_image{i}.png")
print(f"{prefix}_image{i}.png")