flux-dev OOM with 2 GPUs (each GPU is 24576MiB) #345

Open
algorithmconquer opened this issue Nov 14, 2024 · 9 comments

Comments

@algorithmconquer

The command is:
torchrun --nproc_per_node=2 ./examples/flux_example.py --model ./FLUX.1-dev/ --pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --output_type latent --num_inference_steps 28 --warmup_steps 1 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --use_cfg_parallel --use_parallel_vae
How to solve the problem?

@feifeibear
Collaborator

--pipefusion_parallel_degree 2

Your command line is not valid. The total parallel degree should be 2, matching the two GPUs.
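
As a rough illustration of the rule above, the sketch below checks a set of flags against the number of processes. It is not xDiT's own validation code: the helper names are hypothetical, and it assumes that the degrees multiply together and that --use_cfg_parallel contributes a factor of 2, which is consistent with the error about a size of 4 reported further down this thread.

```python
from math import prod


def total_parallel_degree(pipefusion=1, ulysses=1, ring=1, cfg_parallel=False):
    """Product of the parallel degrees (hypothetical helper).

    Assumption: enabling CFG parallelism counts as a degree of 2.
    """
    cfg = 2 if cfg_parallel else 1
    return prod([pipefusion, ulysses, ring, cfg])


def check(world_size, **degrees):
    """Return (is_consistent, total_degree) for a given number of processes."""
    total = total_parallel_degree(**degrees)
    return total == world_size, total


# Two processes with pipefusion degree 2: the degrees match the world size.
print(check(2, pipefusion=2))
# Two processes with pipefusion degree 2 plus CFG parallel: 4 would be required.
print(check(2, pipefusion=2, cfg_parallel=True))
```

Under these assumptions, adding --use_cfg_parallel on top of --pipefusion_parallel_degree 2 asks for four processes while torchrun only starts two.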

@algorithmconquer
Author

@feifeibear When the command is:
torchrun --nproc_per_node=2 ./examples/flux_example.py --model ./FLUX.1-dev/ --pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1 --height 512 --width 512 --no_use_resolution_binning --output_type latent --num_inference_steps 28 --warmup_steps 1 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --use_cfg_parallel --use_parallel_vae
it fails with an error saying the world size is not equal to 4.
When the command is:
torchrun --nproc_per_node=2 ./examples/flux_example.py --model ./FLUX.1-dev/ --pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1 --height 512 --width 512 --no_use_resolution_binning --output_type latent --num_inference_steps 28 --warmup_steps 1 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --use_parallel_vae
it also fails, this time with an OOM error.

@feifeibear
Collaborator

You should not use --use_cfg_parallel.

@algorithmconquer
Author

@feifeibear The second command does not use --use_cfg_parallel, but the OOM error still occurs.
image

@feifeibear
Collaborator

I see, your GPU memory is quite small. There is a very simple optimization to avoid the OOM: we can use FSDP to load the text encoder. We will add a PR for this ASAP.

@algorithmconquer
Author

@feifeibear Thank you for your quick response. However, when I run inference with diffusers at height=width=512, the problem does not occur. The code is:
import torch
from diffusers import FluxPipeline

# modelId (the FLUX.1-dev checkpoint path or ID) and prompt are defined elsewhere.
pipe = FluxPipeline.from_pretrained(modelId, torch_dtype=torch.bfloat16, device_map="balanced")
image = pipe(prompt, num_inference_steps=28, height=512, width=512, guidance_scale=3.5).images[0]
image.save("out.png")

@Lay2000
Collaborator

Lay2000 commented Nov 19, 2024

@algorithmconquer Hello, could you provide the error log for the OOM error? We need to check whether it happened during model loading or during inference. If it happened during loading, you could simply quantize the text encoder to FP8, which could reduce the peak memory use to 17GB without any quality loss.

Firstly, install the dependencies by running the following command:
pip install optimum-quanto

Then, you could use the following code to replace the original examples/flux_example.py:

import logging
import time
import torch
import torch.distributed
import datetime
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
from xfuser import xFuserFluxPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_runtime_state,
    is_dp_last_group,
)


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank
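    # Load the T5 text encoder (text_encoder_2) on its own so it can be quantized before the pipeline is built.
    text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)

    # Quantize the encoder weights to float8, then freeze them so the quantized weights replace the originals.
    print(datetime.datetime.now(), "Quantizing text encoder 2")
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)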

    pipe = xFuserFluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
        text_encoder_2=text_encoder_2,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    pipe.prepare_run(input_config, steps=1)
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if pipe.is_dp_last_group():
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()
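
To give a rough sense of why quantizing the text encoder helps, the sketch below estimates the memory taken by the encoder weights alone in bfloat16 and in float8. The parameter count is an assumed, approximate figure for the T5-XXL encoder and is not taken from this thread; activations, the transformer, and the VAE are not included, so this does not reproduce the 17GB figure quoted above.

```python
GIB = 1024 ** 3


def weight_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed to hold the weights alone, in GiB."""
    return num_params * bytes_per_param / GIB


# Assumption: roughly 4.7 billion parameters for the T5-XXL text encoder.
T5_PARAMS = 4.7e9

bf16 = weight_gib(T5_PARAMS, 2)  # bfloat16 stores 2 bytes per weight
fp8 = weight_gib(T5_PARAMS, 1)   # float8 stores 1 byte per weight

print(
    f"text encoder 2 weights: {bf16:.1f} GiB in bf16, "
    f"{fp8:.1f} GiB in fp8, saving {bf16 - fp8:.1f} GiB"
)
```

Halving the bytes per weight halves the encoder's weight footprint, which matters most when every rank loads a full copy of the text encoder.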

@algorithmconquer
Author

@Lay2000 Thank you for sharing the code. I was able to run the flux-dev inference pipeline in bfloat16 by sharding the model across 2 GPUs (24576MiB each). I would like to measure the inference performance of xDiT on the same device and environment (dtype=bfloat16, height=width=1024, 2 GPUs with 24576MiB each).

@algorithmconquer
Author

@Lay2000
The running command is:
torchrun --nproc_per_node=2 flux_example_2.py --model ./flux.1-dev/ --use_cfg_parallel --height 1024 --width 1024 --prompt 'brown dog laying on the ground with a metal bowl in front of him.' --num_inference_steps 50 --no_use_resolution_binning

The error log is:
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/fluxProjects/xDiT_20241119/examples/flux_example_2.py", line 87, in <module>
[rank1]: main()
[rank1]: File "/home/fluxProjects/xDiT_20241119/examples/flux_example_2.py", line 44, in main
[rank1]: pipe = pipe.to(f"cuda:{local_rank}")
[rank1]: File "/home/fluxProjects/xDiT_20241119/xfuser/model_executor/pipelines/base_pipeline.py", line 116, in to
[rank1]: self.module = self.module.to(*args, **kwargs)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 454, in to
[rank1]: module.to(device, dtype)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
[rank1]: return self._apply(convert)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank1]: module._apply(fn)
[rank1]: [Previous line repeated 2 more times]
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
[rank1]: param_applied = fn(param)
[rank1]: File "/home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1326, in convert
[rank1]: return t.to(
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 54.00 MiB. GPU 1 has a total capacity of 23.50 GiB of which 29.69 MiB is free. Process 39408 has 23.46 GiB memory in use. Of the allocated memory 23.23 GiB is allocated by PyTorch, and 9.63 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W1119 14:08:31.719000 280957 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 281020 closing signal SIGTERM
E1119 14:08:31.985000 280957 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 281021) of binary: /home/.custom/root/img-tx8ku2jzhi/envs/OneDiffV0/bin/python3.10
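
One way to answer the earlier question of whether the OOM happens during loading or inference is to read the log mechanically. The sketch below is a heuristic only, with hypothetical helper names: it extracts the sizes from the out-of-memory message and treats a `pipe = pipe.to(` frame in the traceback as a sign that the failure occurred while moving the model to the GPU. The sample text is shortened from the log above.

```python
import re


def parse_oom(log: str) -> dict:
    """Extract the sizes reported by a 'CUDA out of memory' message."""
    pattern = (
        r"Tried to allocate (?P<request>[\d.]+ [KMG]iB).*?"
        r"total capacity of (?P<capacity>[\d.]+ [KMG]iB) of which "
        r"(?P<free>[\d.]+ [KMG]iB) is free"
    )
    match = re.search(pattern, log, re.DOTALL)
    return match.groupdict() if match else {}


def oom_stage(log: str) -> str:
    """Heuristic: a `pipe = pipe.to(` frame means the model was still being loaded."""
    return "model loading" if "pipe = pipe.to(" in log else "unknown"


sample = (
    '[rank1]: pipe = pipe.to(f"cuda:{local_rank}")\n'
    "[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 54.00 MiB. "
    "GPU 1 has a total capacity of 23.50 GiB of which 29.69 MiB is free."
)

print(oom_stage(sample), parse_oom(sample))
```

For the log above this points at the loading stage: the GPU is already full before any denoising step runs.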


3 participants