Fix DP parallel when using parallel vae
gty111 committed Sep 24, 2024
1 parent 989a0b7 commit 23f80af
Showing 12 changed files with 82 additions and 33 deletions.
5 changes: 3 additions & 2 deletions examples/flux_example.py
@@ -18,6 +18,7 @@ def main():
     args = xFuserArgs.add_cli_args(parser).parse_args()
     engine_args = xFuserArgs.from_cli_args(args)
     engine_config, input_config = engine_args.create_config()
+    engine_config.runtime_config.dtype = torch.bfloat16
     local_rank = get_world_group().local_rank

     pipe = xFuserFluxPipeline.from_pretrained(
@@ -32,7 +33,7 @@ def main():
     else:
         pipe = pipe.to(f"cuda:{local_rank}")

-    pipe.prepare_run(input_config)
+    pipe.prepare_run(input_config, steps=1)

     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
@@ -60,7 +61,7 @@ def main():
     dp_group_index = get_data_parallel_rank()
     num_dp_groups = get_data_parallel_world_size()
     dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-    if is_dp_last_group():
+    if pipe.is_dp_last_group():
         for i, image in enumerate(output.images):
             image_rank = dp_group_index * dp_batch_size + i
             image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
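A note on the index arithmetic in the hunk above: the ceiling division sizes each data-parallel batch, and image_rank maps a group-local image index back to a global one, so every DP group writes a disjoint set of files. A tiny standalone illustration (plain Python with made-up example values, not part of the commit):

batch_size = 5
num_dp_groups = 2
# ceiling division, exactly as in the example scripts
dp_batch_size = (batch_size + num_dp_groups - 1) // num_dp_groups  # -> 3

for dp_group_index in range(num_dp_groups):
    start = dp_group_index * dp_batch_size
    stop = min(start + dp_batch_size, batch_size)
    print(f"DP group {dp_group_index} writes image ranks {list(range(start, stop))}")
# DP group 0 writes image ranks [0, 1, 2]
# DP group 1 writes image ranks [3, 4]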
2 changes: 1 addition & 1 deletion examples/hunyuandit_example.py
@@ -50,7 +50,7 @@ def main():
     dp_group_index = get_data_parallel_rank()
     num_dp_groups = get_data_parallel_world_size()
     dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-    if is_dp_last_group():
+    if pipe.is_dp_last_group():
         if not os.path.exists("results"):
             os.mkdir("results")
         for i, image in enumerate(output.images):
2 changes: 1 addition & 1 deletion examples/pixartalpha_example.py
@@ -50,7 +50,7 @@ def main():
     dp_group_index = get_data_parallel_rank()
     num_dp_groups = get_data_parallel_world_size()
     dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-    if is_dp_last_group():
+    if pipe.is_dp_last_group():
         if not os.path.exists("results"):
             os.mkdir("results")
         for i, image in enumerate(output.images):
2 changes: 1 addition & 1 deletion examples/pixartsigma_example.py
@@ -51,7 +51,7 @@ def main():
     dp_group_index = get_data_parallel_rank()
     num_dp_groups = get_data_parallel_world_size()
     dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-    if is_dp_last_group():
+    if pipe.is_dp_last_group():
         if not os.path.exists("results"):
             os.mkdir("results")
         for i, image in enumerate(output.images):
2 changes: 1 addition & 1 deletion examples/sd3_example.py
@@ -49,7 +49,7 @@ def main():
     dp_group_index = get_data_parallel_rank()
     num_dp_groups = get_data_parallel_world_size()
     dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
-    if is_dp_last_group():
+    if pipe.is_dp_last_group():
         if not os.path.exists("results"):
             os.mkdir("results")
         for i, image in enumerate(output.images):
17 changes: 0 additions & 17 deletions xfuser/core/distributed/group_coordinator.py
@@ -294,23 +294,6 @@ def broadcast(self, input_: torch.Tensor, src: int = 0):
             input_, src=self.ranks[src], group=self.device_group
         )
         return input_
-
-    def broadcast_latent(self, input_: torch.Tensor, dtype: torch.dtype):
-        """Broadcast the final latent result
-        NOTE: we assume the last rank owns latents
-        """
-        latent_rank = self.world_size - 1
-        device = f"cuda:{self.rank}"
-        if self.rank == latent_rank:
-            input_shape = torch.tensor(input_.shape,dtype=torch.int).to(device)
-        else:
-            input_shape = torch.zeros(4,dtype=torch.int).cuda().to(device)
-        self.broadcast(input_shape,src=latent_rank)
-
-        if self.rank != latent_rank:
-            input_ = torch.zeros(torch.Size(input_shape),dtype=dtype).to(device)
-        self.broadcast(input_,src=latent_rank)
-        return input_

     def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
         """Broadcast the input object.
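The removed GroupCoordinator.broadcast_latent assumed the single global last rank owned the full latent batch. With data parallelism each DP group only holds its own shard of the batch, so the replacement in base_pipeline.py (next file) first gathers the shards from every DP group and concatenates them before broadcasting. A hypothetical shape example, illustration only:

import torch

# two DP groups, each holding only its shard of a 5-image batch
shard_dp0 = torch.randn(3, 16, 64, 64)
shard_dp1 = torch.randn(2, 16, 64, 64)
full_batch = torch.cat([shard_dp0, shard_dp1], dim=0)
print(full_batch.shape)  # torch.Size([5, 16, 64, 64])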
65 changes: 65 additions & 0 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -27,6 +27,8 @@
     get_world_group,
     get_runtime_state,
     initialize_runtime_state,
+    is_dp_last_group,
+    get_sequence_parallel_rank,
 )
 from xfuser.model_executor.base_wrapper import xFuserBaseWrapper

@@ -392,3 +394,66 @@ def _process_cfg_split_batch_latte(
         else:
             raise ValueError("Invalid classifier free guidance rank")
         return concat_group_0
+
+    def is_dp_last_group(self):
+        """Return True if in the last data parallel group, False otherwise.
+        This also covers the parallel VAE case.
+        """
+        if get_runtime_state().runtime_config.use_parallel_vae:
+            return get_world_group().rank == 0
+        else:
+            return is_dp_last_group()
+
+    def gather_broadcast_latents(self, latents: torch.Tensor):
+        """Gather latents from the dp last group and broadcast the final latents.
+        """
+
+        # --------- gather latents from the dp last group ---------
+        rank = get_world_group().rank
+        device = f"cuda:{rank}"
+
+        # all-gather the list of ranks that form the dp last group
+        dp_rank_list = [torch.zeros(1, dtype=int, device=device) for _ in range(get_world_group().world_size)]
+        if is_dp_last_group():
+            gather_rank = int(rank)
+        else:
+            gather_rank = -1
+        torch.distributed.all_gather(dp_rank_list, torch.tensor([gather_rank], dtype=int, device=device))
+
+        dp_rank_list = [int(dp_rank[0]) for dp_rank in dp_rank_list if int(dp_rank[0]) != -1]
+        dp_last_group = torch.distributed.new_group(dp_rank_list)
+
+        # gather latents from the dp last group
+        if rank == dp_rank_list[-1]:
+            latents_list = [torch.zeros_like(latents) for _ in dp_rank_list]
+        else:
+            latents_list = None
+        if rank in dp_rank_list:
+            torch.distributed.gather(latents, latents_list, dst=dp_rank_list[-1], group=dp_last_group)
+
+        if rank == dp_rank_list[-1]:
+            latents = torch.cat(latents_list, dim=0)
+
+        # --------- broadcast latents to all ranks ---------
+        src = dp_rank_list[-1]
+        latents_shape_len = torch.zeros(1, dtype=torch.int, device=device)
+
+        # broadcast the number of latent dimensions
+        if rank == src:
+            latents_shape_len[0] = len(latents.shape)
+        get_world_group().broadcast(latents_shape_len, src=src)
+
+        # broadcast the latent shape
+        if rank == src:
+            input_shape = torch.tensor(latents.shape, dtype=torch.int, device=device)
+        else:
+            input_shape = torch.zeros(latents_shape_len[0], dtype=torch.int, device=device)
+        get_world_group().broadcast(input_shape, src=src)

+        # broadcast the latents
+        if rank != src:
+            dtype = get_runtime_state().runtime_config.dtype
+            latents = torch.zeros(torch.Size(input_shape), dtype=dtype, device=device)
+        get_world_group().broadcast(latents, src=src)
+
+        return latents
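gather_broadcast_latents combines two steps: the last rank of every data-parallel group contributes its latents to a gather, and the gathering rank then broadcasts the concatenated batch, sending the dimension count and shape first so receivers can allocate a buffer of the right size. A minimal CPU-only sketch of that shape-then-data broadcast pattern (gloo backend, two processes, helper names invented for the example; this is not the xfuser implementation):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_with_shape(tensor, src, dtype=torch.float32):
    """Broadcast a tensor whose shape is only known on the source rank."""
    rank = dist.get_rank()
    # 1) broadcast the number of dimensions
    ndim = torch.zeros(1, dtype=torch.int)
    if rank == src:
        ndim[0] = tensor.dim()
    dist.broadcast(ndim, src=src)
    # 2) broadcast the shape itself
    if rank == src:
        shape = torch.tensor(tensor.shape, dtype=torch.int)
    else:
        shape = torch.zeros(int(ndim[0]), dtype=torch.int)
    dist.broadcast(shape, src=src)
    # 3) broadcast the data, allocating an empty buffer on the receivers
    if rank != src:
        tensor = torch.zeros(torch.Size(shape.tolist()), dtype=dtype)
    dist.broadcast(tensor, src=src)
    return tensor


def worker(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29517", rank=rank, world_size=world_size
    )
    # only the last rank owns the (already gathered) latents
    latents = torch.randn(2, 4, 8, 8) if rank == world_size - 1 else None
    latents = broadcast_with_shape(latents, src=world_size - 1)
    print(f"rank {rank} received latents of shape {tuple(latents.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)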
4 changes: 2 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_flux.py
@@ -323,13 +323,13 @@ def vae_decode(latents):

         if not output_type == "latent":
             if get_runtime_state().runtime_config.use_parallel_vae:
-                latents = get_world_group().broadcast_latent(latents,get_runtime_state().runtime_config.dtype)
+                latents = self.gather_broadcast_latents(latents)
                 image = vae_decode(latents)
             else:
                 if is_dp_last_group():
                     image = vae_decode(latents)

-        if is_dp_last_group():
+        if self.is_dp_last_group():
             if output_type == "latent":
                 image = latents

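The same decode dispatch appears in pipeline_flux.py above and in the remaining pipeline files below: with parallel VAE every rank decodes, so the full latent batch is first gathered and broadcast; otherwise only the last data-parallel group decodes, and the raw latents are returned there when output_type is "latent". A consolidated sketch of that logic with explicit parameters instead of pipeline state (function and argument names invented for illustration, not the exact xfuser code):

def decode_stage(latents, output_type, use_parallel_vae, is_dp_last_group,
                 gather_broadcast_latents, vae_decode):
    """Illustrative dispatch: decide where (and whether) the VAE decode runs."""
    image = None
    if output_type != "latent":
        if use_parallel_vae:
            # parallel VAE: every rank takes part in decoding, so every rank
            # first needs the full batch of latents from all DP groups
            latents = gather_broadcast_latents(latents)
            image = vae_decode(latents)
        elif is_dp_last_group:
            # serial VAE: only the last data-parallel group decodes
            image = vae_decode(latents)
    elif is_dp_last_group:
        # caller asked for raw latents; only the last DP group returns them
        image = latents
    return image


# toy usage with stand-in callables
print(decode_stage(
    latents="LATENTS", output_type="pil", use_parallel_vae=False,
    is_dp_last_group=True,
    gather_broadcast_latents=lambda x: x,
    vae_decode=lambda x: f"decoded({x})",
))  # -> decoded(LATENTS)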
4 changes: 2 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_hunyuandit.py
@@ -463,12 +463,12 @@ def vae_decode(latents):

         if not output_type == "latent":
             if get_runtime_state().runtime_config.use_parallel_vae:
-                latents = get_world_group().broadcast_latent(latents,get_runtime_state().runtime_config.dtype)
+                latents = self.gather_broadcast_latents(latents)
                 vae_decode(latents)
             else:
                 if is_dp_last_group():
                     vae_decode(latents)
-        if is_dp_last_group():
+        if self.is_dp_last_group():
         #! ---------------------------------------- ADD ABOVE ----------------------------------------
             if not output_type == "latent":
                 image, has_nsfw_concept = self.run_safety_checker(
4 changes: 2 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
@@ -368,12 +368,12 @@ def vae_decode(latents):

         if not output_type == "latent":
             if get_runtime_state().runtime_config.use_parallel_vae:
-                latents = get_world_group().broadcast_latent(latents,get_runtime_state().runtime_config.dtype)
+                latents = self.gather_broadcast_latents(latents)
                 image = vae_decode(latents)
             else:
                 if is_dp_last_group():
                     image = vae_decode(latents)
-        if is_dp_last_group():
+        if self.is_dp_last_group():
         #! ---------------------------------------- ADD ABOVE ----------------------------------------
             if not output_type == "latent":
                 if use_resolution_binning:
4 changes: 2 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
@@ -333,13 +333,13 @@ def vae_decode(latents):

         if not output_type == "latent":
             if get_runtime_state().runtime_config.use_parallel_vae:
-                latents = get_world_group().broadcast_latent(latents,get_runtime_state().runtime_config.dtype)
+                latents = self.gather_broadcast_latents(latents)
                 image = vae_decode(latents)
             else:
                 if is_dp_last_group():
                     image = vae_decode(latents)

-        if is_dp_last_group():
+        if self.is_dp_last_group():
             if not output_type == "latent":
                 if use_resolution_binning:
                     image = self.image_processor.resize_and_crop_tensor(
4 changes: 2 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -390,13 +390,13 @@ def vae_decode(latents):

         if not output_type == "latent":
             if get_runtime_state().runtime_config.use_parallel_vae:
-                latents = get_world_group().broadcast_latent(latents,get_runtime_state().runtime_config.dtype)
+                latents = self.gather_broadcast_latents(latents)
                 image = vae_decode(latents)
             else:
                 if is_dp_last_group():
                     image = vae_decode(latents)

-        if is_dp_last_group():
+        if self.is_dp_last_group():
             if output_type == "latent":
                 image = latents
             else:
