-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
lcm_gen.py
115 lines (104 loc) · 3.46 KB
/
lcm_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import dataclasses
import os
import random
import sys
from contextlib import contextmanager
from functools import partial
from typing import Iterable
from unittest.mock import patch
import torch
from PIL.Image import Image
from diffusers import DiffusionPipeline
def get_best_device() -> str:
    """Pick the torch device string to run on.

    Resolution order:

    1. the ``LCM_DEVICE`` environment variable, when set to a non-empty value;
    2. ``"mps"`` on macOS when torch reports the MPS backend as available;
    3. ``"cuda"`` when torch reports CUDA as available;
    4. ``"cpu"`` otherwise.
    """
    if dev := os.environ.get("LCM_DEVICE"):
        return dev
    # Use the documented availability probe instead of calling an MPS
    # function and swallowing whatever exception it raises.
    if sys.platform == "darwin" and torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
@dataclasses.dataclass()
class ResultImage:
    """One image produced by ``LCMGenerator.generate`` plus where it came from."""

    # Seed the image's batch was generated with (base seed + batch number).
    seed: int
    # Zero-based number of the batch this image belongs to.
    batch_num: int
    # Zero-based position of the image within its batch.
    batch_index: int
    # Total number of batches requested in the generate() call.
    batch_count: int
    # The generated PIL image.
    image: Image
@contextmanager
def default_progress_context(info: str):
    """Default progress reporter: print a line before and after a step.

    Args:
        info: description of the step, printed after "Starting" and
            "Finished".

    If the wrapped block raises, the exception propagates and the
    "Finished" line is not printed.
    """
    label = info
    print("Starting", label, sep=" ")
    yield None
    # Only reached when the wrapped block completed without raising.
    print("Finished", label, sep=" ")
class LCMGenerator:
    """Text-to-image generator built on the ``SimianLuo/LCM_Dreamshaper_v7`` pipeline.

    The diffusers pipeline is created lazily on first access to ``pipe`` and
    cached on the instance.
    """

    def __init__(
        self,
        *,
        device: str | None = None,
        # NOTE: this default is evaluated once, when the class body executes
        # (i.e. at import time), so changing LCM_FP16 later does not affect it.
        # Any non-empty value -- including "0" -- enables fp16.
        fp16: bool = bool(os.environ.get("LCM_FP16")),
        progress_context=default_progress_context,
    ):
        """Configure the generator without loading the model.

        Args:
            device: torch device string; when falsy, ``get_best_device()``
                chooses one.
            fp16: use ``torch.float16`` instead of ``torch.float32``.
            progress_context: callable taking an info string and returning a
                context manager that wraps each long-running step.
        """
        if not device:
            device = get_best_device()
        self.device = device
        self.dtype = torch.float16 if fp16 else torch.float32
        self._pipe = None
        self.progress_context = progress_context

    @property
    def pipe(self):
        """Return the diffusers pipeline, building it on first access.

        The first access loads the pipeline, moves it to ``self.device`` /
        ``self.dtype``, and disables its safety checker.
        """
        if self._pipe is None:
            with self.progress_context("Initializing pipeline..."):
                pipe = DiffusionPipeline.from_pretrained(
                    "SimianLuo/LCM_Dreamshaper_v7",
                    custom_pipeline="latent_consistency_txt2img",
                    custom_revision="main",
                )
                pipe.to(torch_device=self.device, torch_dtype=self.dtype)
                pipe.safety_checker = None  # we're all adults here
                self._pipe = pipe
        return self._pipe

    def generate(
        self,
        *,
        prompt: str,
        width: int,
        height: int,
        steps: int,
        cfg: float,
        batch_size: int = 1,
        batch_count: int = 1,
        seed: int | None = None,
    ) -> Iterable[ResultImage]:
        """Generate ``batch_count`` batches of ``batch_size`` images each.

        Args:
            prompt: text prompt passed to the pipeline.
            width: output width, passed to the pipeline unchanged.
            height: output height, passed to the pipeline unchanged.
            steps: pipeline ``num_inference_steps``.
            cfg: pipeline ``guidance_scale``.
            batch_size: images per pipeline call (``num_images_per_prompt``).
            batch_count: number of pipeline calls.
            seed: base seed; batch ``i`` uses ``seed + i``. ``None`` or any
                value ``<= 0`` selects a random base seed.

        Yields:
            One ``ResultImage`` per generated image, in batch order.
        """
        if not seed or seed <= 0:
            # randint's upper bound is inclusive; keep the seed within 32 bits.
            seed = random.randint(0, 2**32 - 1)
        pipe: DiffusionPipeline = self.pipe
        for batch_idx in range(batch_count):
            batch_seed = seed + batch_idx
            with self.progress_context(
                f"[{batch_idx + 1}/{batch_count}] Generating {batch_size} images ({self.device}, {self.dtype})...",
            ):
                gen = torch.Generator(device="cpu")
                gen.manual_seed(batch_seed)
                # Route torch.randn through a seeded CPU generator so a batch is
                # reproducible from its seed (presumably the custom pipeline draws
                # its noise with torch.randn and takes no generator argument --
                # confirm against the pipeline source). `partial` binds the real
                # torch.randn before the patch is applied, so there is no recursion.
                # This patches the global torch module, so it is not thread-safe;
                # keep it scoped to the pipeline call only.
                with patch("torch.randn", partial(torch.randn, generator=gen)):
                    images = pipe(
                        prompt=prompt,
                        width=width,
                        height=height,
                        num_inference_steps=steps,
                        guidance_scale=cfg,
                        lcm_origin_steps=50,
                        output_type="pil",
                        num_images_per_prompt=batch_size,
                    ).images
            # Yield only after both contexts have exited. Yielding inside them
            # would keep torch.randn patched (and the progress step open) while
            # the caller handles each image -- indefinitely if the caller stops
            # iterating early and the generator is not closed.
            for in_batch_index, image in enumerate(images):
                yield ResultImage(
                    seed=batch_seed,
                    batch_num=batch_idx,
                    batch_index=in_batch_index,
                    batch_count=batch_count,
                    image=image,
                )