Support latte #964

Merged · 15 commits · Jul 1, 2024
273 changes: 273 additions & 0 deletions benchmarks/text_to_video_latte.py
@@ -0,0 +1,273 @@
MODEL = "maxin-cn/Latte-1"
CKPT = "t2v_v20240523.pt"
VARIANT = None
CUSTOM_PIPELINE = None
# SAMPLE_METHOD = "DDIM"
BETA_START = 0.0001
BETA_END = 0.02
BETA_SCHEDULE = "linear"
VARIANCE_TYPE = "learned_range"
STEPS = 50
SEED = 25
WARMUPS = 1
BATCH = 1
HEIGHT = 512
WIDTH = 512
VIDEO_LENGTH = 16
FPS = 8
GUIDANCE_SCALE = 7.5
ENABLE_TEMPORAL_ATTENTIONS = "true"
ENABLE_VAE_TEMPORAL_DECODER = "true"
OUTPUT_VIDEO = "output.mp4"

PROMPT = "An epic tornado attacking above a glowing city at night."

EXTRA_CALL_KWARGS = None
ATTENTION_FP16_SCORE_ACCUM_MAX_M = 0

COMPILER_CONFIG = None


import os
import importlib
import inspect
import argparse
import time
import json
import random
from PIL import Image, ImageDraw

import torch
import oneflow as flow
from onediffx import compile_pipe, OneflowCompileOptions
from diffusers.utils import load_image, export_to_video
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from transformers import T5EncoderModel, T5Tokenizer
import imageio


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=MODEL)
    parser.add_argument("--ckpt", type=str, default=CKPT)
    parser.add_argument("--prompt", type=str, default=PROMPT)
    parser.add_argument("--save_graph", action="store_true")
    parser.add_argument("--load_graph", action="store_true")
    parser.add_argument("--variant", type=str, default=VARIANT)
    parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
    # parser.add_argument("--sample-method", type=str, default=SAMPLE_METHOD)
    parser.add_argument("--beta-start", type=float, default=BETA_START)
    parser.add_argument("--beta-end", type=float, default=BETA_END)
    parser.add_argument("--beta-schedule", type=str, default=BETA_SCHEDULE)
    # argparse applies `type` to string defaults, so the "true"/"false"
    # defaults above are converted to bools by the lambdas below.
    parser.add_argument(
        "--enable_temporal_attentions",
        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=ENABLE_TEMPORAL_ATTENTIONS,
    )
    parser.add_argument(
        "--enable_vae_temporal_decoder",
        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=ENABLE_VAE_TEMPORAL_DECODER,
    )
    parser.add_argument("--guidance-scale", type=float, default=GUIDANCE_SCALE)
    parser.add_argument("--variance-type", type=str, default=VARIANCE_TYPE)
    parser.add_argument("--steps", type=int, default=STEPS)
    parser.add_argument("--seed", type=int, default=SEED)
    parser.add_argument("--warmups", type=int, default=WARMUPS)
    parser.add_argument("--batch", type=int, default=BATCH)
    parser.add_argument("--height", type=int, default=HEIGHT)
    parser.add_argument("--width", type=int, default=WIDTH)
    parser.add_argument("--video-length", type=int, default=VIDEO_LENGTH)
    parser.add_argument("--fps", type=int, default=FPS)
    parser.add_argument("--extra-call-kwargs", type=str, default=EXTRA_CALL_KWARGS)
    parser.add_argument("--output-video", type=str, default=OUTPUT_VIDEO)
    parser.add_argument(
        "--compiler",
        type=str,
        default="nexfort",
        choices=["none", "oneflow", "nexfort"],
    )
    parser.add_argument(
        "--compiler-config", type=str, default=COMPILER_CONFIG,
    )
    parser.add_argument(
        "--attention-fp16-score-accum-max-m",
        type=int,
        default=ATTENTION_FP16_SCORE_ACCUM_MAX_M,
    )
    return parser.parse_args()


class IterationProfiler:
    """Measures denoising iterations per second via CUDA events."""

    def __init__(self):
        self.begin = None
        self.end = None
        self.num_iterations = 0

    def get_iter_per_sec(self):
        if self.begin is None or self.end is None:
            return None
        self.end.synchronize()
        # Event.elapsed_time() returns milliseconds, hence the * 1000.0.
        dur = self.begin.elapsed_time(self.end)
        return self.num_iterations / dur * 1000.0

    def callback_on_step_end(self, pipe, i, t, callback_kwargs={}):
        if self.begin is None:
            event = torch.cuda.Event(enable_timing=True)
            event.record()
            self.begin = event
        else:
            event = torch.cuda.Event(enable_timing=True)
            event.record()
            self.end = event
            self.num_iterations += 1
        return callback_kwargs


def main():
    args = parse_args()

    if os.path.exists(args.model):
        model_path = args.model
    else:
        from huggingface_hub import snapshot_download

        model_path = snapshot_download(repo_id=args.model)

    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    from onediffx.text_to_video.latte import LatteT2V, LattePipeline

    transformer_model = LatteT2V.from_pretrained(
        model_path, subfolder="transformer", video_length=args.video_length
    ).to(device, dtype=torch.float16)

    if args.enable_vae_temporal_decoder:
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            args.model, subfolder="vae_temporal_decoder", torch_dtype=torch.float16
        ).to(device)
    else:
        vae = AutoencoderKL.from_pretrained(
            args.model, subfolder="vae", torch_dtype=torch.float16
        ).to(device)
    tokenizer = T5Tokenizer.from_pretrained(args.model, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(
        args.model, subfolder="text_encoder", torch_dtype=torch.float16
    ).to(device)

    # set eval mode
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = DDIMScheduler.from_pretrained(
        model_path,
        subfolder="scheduler",
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
        clip_sample=False,
    )

    pipe = LattePipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer_model,
    ).to(device)

    if args.compiler == "none":
        pass
    elif args.compiler == "oneflow":
print("Oneflow backend is now active...")
compile_options = OneflowCompileOptions()
compile_options.attention_allow_half_precision_score_accumulation_max_m = (
args.attention_fp16_score_accum_max_m
)
pipe = compile_pipe(pipe, options=compile_options)
elif args.compiler == "nexfort":
print("Nexfort backend is now active...")
if args.compiler_config is not None:
# config with dict
options = json.loads(args.compiler_config)
else:
# config with string
options = '{"mode": "max-optimize:max-autotune:freezing:benchmark:low-precision", \
"memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, \
"triton.fuse_attention_allow_fp16_reduction": false}}'
pipe = compile_pipe(
pipe, backend="nexfort", options=options, fuse_qkv_projections=True
)
else:
raise ValueError(f"Unknown compiler: {args.compiler}")

    def get_kwarg_inputs():
        kwarg_inputs = dict(
            prompt=args.prompt,
            video_length=args.video_length,
            height=args.height,
            width=args.width,
            num_inference_steps=args.steps,
            guidance_scale=args.guidance_scale,
            enable_temporal_attentions=args.enable_temporal_attentions,
            num_images_per_prompt=1,
            mask_feature=True,
            enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
            **(
                dict()
                if args.extra_call_kwargs is None
                else json.loads(args.extra_call_kwargs)
            ),
        )
        return kwarg_inputs

    if args.warmups > 0:
        print("=======================================")
        print("Begin warmup")
        begin = time.time()
        for _ in range(args.warmups):
            pipe(**get_kwarg_inputs()).video
        end = time.time()
        print("End warmup")
        print(f"Warmup time: {end - begin:.3f}s")

    print("=======================================")

    kwarg_inputs = get_kwarg_inputs()
    iter_profiler = IterationProfiler()
    if "callback_on_step_end" in inspect.signature(pipe).parameters:
        kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
    elif "callback" in inspect.signature(pipe).parameters:
        kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
    torch.manual_seed(args.seed)
    begin = time.time()
    videos = pipe(**kwarg_inputs).video
    end = time.time()

    print(f"Inference time: {end - begin:.3f}s")
    iter_per_sec = iter_profiler.get_iter_per_sec()
    if iter_per_sec is not None:
        print(f"Iterations per second: {iter_per_sec:.3f}")
    cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed()
    host_mem_after_used = flow._oneflow_internal.GetCPUMemoryUsed()
    print(f"CUDA Mem after: {cuda_mem_after_used / 1024:.3f}GiB")
    print(f"Host Mem after: {host_mem_after_used / 1024:.3f}GiB")

    if args.output_video is not None:
        # export_to_video(output_frames[0], args.output_video, fps=args.fps)
        try:
            imageio.mimwrite(
                args.output_video, videos[0], fps=args.fps, quality=9
            )  # highest quality is 10, lowest is 0
        except Exception:
            print("Error when saving {}".format(args.output_video))
    else:
        print("Please set `--output-video` to save the output video")


if __name__ == "__main__":
    main()
Binary file added imgs/latte_nexfort.gif
86 changes: 86 additions & 0 deletions onediff_diffusers_extensions/examples/latte/README.md
@@ -0,0 +1,86 @@
# Run Latte with nexfort backend (Beta Release)


1. [Environment Setup](#environment-setup)
- [Set up onediff](#set-up-onediff)
- [Set up nexfort backend](#set-up-nexfort-backend)
- [Set up Latte](#set-up-latte)
2. [Run](#run)
- [Run without compile](#run-without-compile)
- [Run with compile](#run-with-compile)
3. [Performance Comparison](#performance-comparison)
4. [Quality](#quality)

## Environment setup
### Set up onediff
https://github.com/siliconflow/onediff?tab=readme-ov-file#installation

### Set up nexfort backend
https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort

### Set up Latte
HF model: https://huggingface.co/maxin-cn/Latte-1

Github source: https://github.com/Vchitect/Latte

## Run
The `--model` argument accepts either a Latte model ID or a local model path, such as `maxin-cn/Latte-1` or `/data/hf_models/Latte-1/`.

### Go to the onediff folder
```
cd onediff
```

### Run without compile (the original PyTorch HF diffusers pipeline)
```
python3 ./benchmarks/text_to_video_latte.py \
--model maxin-cn/Latte-1 \
--steps 50 \
--compiler none \
--output-video ./latte.mp4 \
--prompt "An epic tornado attacking above a glowing city at night."
```
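
The same script also accepts a local model path via `--model` (the path below is illustrative):
```
python3 ./benchmarks/text_to_video_latte.py \
--model /data/hf_models/Latte-1/ \
--steps 50 \
--compiler none \
--output-video ./latte.mp4 \
--prompt "An epic tornado attacking above a glowing city at night."
```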

### Run with compile
```
python3 ./benchmarks/text_to_video_latte.py \
--model maxin-cn/Latte-1 \
--steps 50 \
--compiler nexfort \
--output-video ./latte_compile.mp4 \
--prompt "An epic tornado attacking above a glowing city at night."
```

## Performance Comparison

### Metric

#### On A100
| Metric | NVIDIA A100-PCIE-40GB (512 * 512) |
| ------------------------------------------------ | ----------------------------------- |
| Data update date (yyyy-mm-dd)                    | 2024-06-19                          |
| PyTorch iteration speed                          | 1.60 it/s                           |
| OneDiff iteration speed                          | 2.27 it/s (+41.9%)                  |
| PyTorch E2E time                                 | 32.618 s                            |
| OneDiff E2E time                                 | 22.601 s (-30.7%)                   |
| PyTorch Max Mem Used                             | 28.208 GiB                          |
| OneDiff Max Mem Used                             | 24.753 GiB                          |
| PyTorch Warmup with Run time                     | 33.291 s                            |
| OneDiff Warmup with Compilation time<sup>1</sup> | 572.877 s                           |
| OneDiff Warmup with Cache time                   | 148.068 s                           |

<sup>1</sup> OneDiff warmup with compilation time is measured on an Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz. This is for reference only; it varies considerably across CPUs.

#### nexfort compile config and warmup cost
- compiler-config
  - setting `--compiler-config '{"mode": "max-optimize:max-autotune:freezing:benchmark:low-precision", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'` gives the best performance, but compilation takes about 572 seconds
  - setting `--compiler-config '{"mode": "max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'` reduces compilation time to about 236 seconds at only a slight cost in performance
- fuse_qkv_projections: True (see the sketch below for how these settings map to the `compile_pipe` call)
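
For reference, a minimal sketch of how these settings reach onediff, mirroring `benchmarks/text_to_video_latte.py` (the construction of `pipe` is elided; the config string is the `max-autotune` variant above):

```python
import json

from onediffx import compile_pipe

# Assumes `pipe` is the LattePipeline assembled as in the benchmark script.
compiler_config = (
    '{"mode": "max-autotune", "memory_format": "channels_last", '
    '"options": {"inductor.optimize_linear_epilogue": false, '
    '"triton.fuse_attention_allow_fp16_reduction": false}}'
)
options = json.loads(compiler_config)  # --compiler-config is parsed as JSON
pipe = compile_pipe(pipe, backend="nexfort", options=options, fuse_qkv_projections=True)
```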

## Quality

When using nexfort as the backend for onediff compilation acceleration (right video), the generated video is lossless.

<p align="center">
<img src="../../../imgs/latte_nexfort.gif">
</p>
@@ -57,7 +57,7 @@ python3 ./benchmarks/text_to_image.py \
--prompt "product photography, world of warcraft orc warrior, white background"
```

-## Performance comparation
+## Performance Comparison

### Metric

Empty file.
@@ -0,0 +1,26 @@
# Latte originates from https://github.com/Vchitect/Latte
# ```
# @article{ma2024latte,
# title={Latte: Latent Diffusion Transformer for Video Generation},
# author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Liu, Ziwei and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
# journal={arXiv preprint arXiv:2401.03048},
# year={2024}
# }
# ```

from packaging import version
import importlib.metadata

diffusers_0193_v = version.parse("0.19.3")
diffusers_0240_v = version.parse("0.24.0")
diffusers_version = version.parse(importlib.metadata.version("diffusers"))

if diffusers_version < diffusers_0193_v:
    raise ImportError(
        f"onediffx supports diffusers >= 0.19.3, but found {str(diffusers_version)}. Please upgrade diffusers."
    )

# Only expose the Latte modules when diffusers is new enough (>= 0.24.0).
if diffusers_version >= diffusers_0240_v:
    from .latte_t2v import LatteT2V
    from .pipeline_latte import LattePipeline
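
# A minimal usage sketch (an illustration, not part of this module; it mirrors
# benchmarks/text_to_video_latte.py and assumes diffusers >= 0.24.0):
#
#   from onediffx.text_to_video.latte import LatteT2V, LattePipeline
#
#   transformer = LatteT2V.from_pretrained(
#       "maxin-cn/Latte-1", subfolder="transformer", video_length=16
#   )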