import os
import sys

sys.path.append(
    os.path.join(
        os.path.dirname(__file__),
        "..",
    )
)

import torch
from config import Args
from PIL import Image
from pydantic import BaseModel, Field

from live2diff.utils.config import load_config
from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper

default_prompt = "masterpiece, best quality, felted, 1man with glasses, glasses, play with his pen"

page_content = """<h1 class="text-3xl font-bold">Live2Diff: </h1>
<h2 class="text-xl font-bold">Live Stream Translation via Uni-directional Attention in Video Diffusion Models</h2>
<p class="text-sm">
    This demo showcases the
    <a
        href="https://github.com/open-mmlab/Live2Diff"
        target="_blank"
        class="text-blue-500 underline hover:no-underline">Live2Diff
    </a>
    pipeline using
    <a
        href="https://huggingface.co/latent-consistency/lcm-lora-sdv1-5"
        target="_blank"
        class="text-blue-500 underline hover:no-underline">LCM-LoRA</a
    > with an MJPEG stream server.
</p>
"""

WARMUP_FRAMES = 8  # number of frames buffered before the streaming pipeline is prepared
WINDOW_SIZE = 16  # temporal window size; defined for reference, not used directly in this file


class Pipeline:
    class Info(BaseModel):
        name: str = "Live2Diff"
        input_mode: str = "image"
        page_content: str = page_content

    def build_input_params(self, default_prompt: str = default_prompt, width=512, height=512):
        class InputParams(BaseModel):
            prompt: str = Field(
                default_prompt,
                title="Prompt",
                field="textarea",
                id="prompt",
            )
            width: int = Field(
                512,
                min=2,
                max=15,
                title="Width",
                disabled=True,
                hide=True,
                id="width",
            )
            height: int = Field(
                512,
                min=2,
                max=15,
                title="Height",
                disabled=True,
                hide=True,
                id="height",
            )

        return InputParams
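
    # Illustrative note (not part of the original demo flow): the returned
    # pydantic model can be instantiated directly, e.g.
    #   Params = self.build_input_params(default_prompt="a cat")
    #   Params().prompt  ->  "a cat"
    # In this file the model is built once in __init__; its fields are
    # presumably populated by the demo frontend at request time.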

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        config_path = args.config
        cfg = load_config(config_path)
        prompt = args.prompt or cfg.prompt or default_prompt

        self.InputParams = self.build_input_params(default_prompt=prompt)
        params = self.InputParams()

        num_inference_steps = args.num_inference_steps or cfg.get("num_inference_steps", None)
        strength = args.strength or cfg.get("strength", None)
        t_index_list = args.t_index_list or cfg.get("t_index_list", None)

        self.stream = StreamAnimateDiffusionDepthWrapper(
            few_step_model_type="lcm",
            config_path=config_path,
            cfg_type="none",
            strength=strength,
            num_inference_steps=num_inference_steps,
            t_index_list=t_index_list,
            frame_buffer_size=1,
            width=params.width,
            height=params.height,
            acceleration=args.acceleration,
            do_add_noise=True,
            output_type="pil",
            enable_similar_image_filter=True,
            similar_image_filter_threshold=0.98,
            use_denoising_batch=True,
            use_tiny_vae=True,
            seed=args.seed,
            engine_dir=args.engine_dir,
        )

        self.last_prompt = prompt
        self.warmup_frame_list = []
        self.has_prepared = False
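
    # Two-phase flow: the first WARMUP_FRAMES calls to predict() only buffer
    # preprocessed frames; once the buffer is full, a single prepare() call
    # initializes the streaming pipeline, and every subsequent call stylizes
    # the incoming frame.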
    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        prompt = params.prompt
        if prompt != self.last_prompt:
            # Reset the warmup buffer when the prompt changes.
            self.last_prompt = prompt
            self.warmup_frame_list.clear()

        if len(self.warmup_frame_list) < WARMUP_FRAMES:
            self.warmup_frame_list.append(self.stream.preprocess_image(params.image))
        elif len(self.warmup_frame_list) == WARMUP_FRAMES and not self.has_prepared:
            warmup_frames = torch.stack(self.warmup_frame_list)
            self.stream.prepare(
                warmup_frames=warmup_frames,
                prompt=prompt,
                guidance_scale=1,
            )
            self.has_prepared = True

        if self.has_prepared:
            image_tensor = self.stream.preprocess_image(params.image)
            output_image = self.stream(image=image_tensor)
            return output_image
        else:
            # Still warming up: return a blank placeholder frame.
            return Image.new("RGB", (params.width, params.height))
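

if __name__ == "__main__":
    # Minimal offline sketch, not part of the demo server: drive the pipeline
    # with frames read from disk. The `args` object below is a hypothetical
    # stand-in exposing only the attributes this file reads; in the real demo
    # it is a `config.Args` instance built from the CLI. Paths and the
    # acceleration value are placeholders.
    from types import SimpleNamespace

    args = SimpleNamespace(
        config="configs/example.yaml",  # placeholder path to a Live2Diff config
        prompt=None,
        num_inference_steps=None,
        strength=None,
        t_index_list=None,
        acceleration="xformers",  # assumed to be an accepted value
        seed=42,
        engine_dir="engines",
    )
    pipeline = Pipeline(args, torch.device("cuda"), torch.float16)

    frame_dir = "frames"  # placeholder directory of input frames
    for name in sorted(os.listdir(frame_dir)):
        frame = Image.open(os.path.join(frame_dir, name)).convert("RGB")
        # Mimic the params object the server would pass to predict();
        # the first WARMUP_FRAMES calls return a blank placeholder image.
        params = SimpleNamespace(prompt=default_prompt, width=512, height=512, image=frame)
        output = pipeline.predict(params)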