-
Notifications
You must be signed in to change notification settings - Fork 18
/
generate_images.py
73 lines (52 loc) · 2.31 KB
/
generate_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from semantic_aug.augmentations.textual_inversion import TextualInversion
from diffusers import StableDiffusionPipeline
from itertools import product
from torch import autocast
from PIL import Image
from tqdm import trange
import os
import torch
import argparse
import numpy as np
import random
DEFAULT_ERASURE_CKPT = (
"/projects/rsalakhugroup/btrabucc/esd-models/" +
"compvis-word_airplane-method_full-sg_3-ng_1-iter_1000-lr_1e-05/" +
"diffusers-word_airplane-method_full-sg_3-ng_1-iter_1000-lr_1e-05.pt")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Stable Diffusion inference script")
parser.add_argument("--model-path", type=str, default="CompVis/stable-diffusion-v1-4")
parser.add_argument("--embed-path", type=str, default=(
"erasure-tokens/fine-tuned/pascal-0-8/airplane/learned_embeds.bin"))
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num-generate", type=int, default=10)
parser.add_argument("--prompt", type=str, default="a photo of a <airplane>")
parser.add_argument("--out", type=str, default="erasure-tokens/fine-tuned/pascal-0-8/airplane/")
parser.add_argument("--guidance-scale", type=float, default=7.5)
parser.add_argument("--erasure-ckpt-name", type=str, default=DEFAULT_ERASURE_CKPT)
args = parser.parse_args()
os.makedirs(args.out, exist_ok=True)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
pipe = StableDiffusionPipeline.from_pretrained(
args.model_path, use_auth_token=True,
revision="fp16",
torch_dtype=torch.float16
).to('cuda')
aug = TextualInversion(args.embed_path, model_path=args.model_path)
pipe.tokenizer = aug.pipe.tokenizer
pipe.text_encoder = aug.pipe.text_encoder
pipe.set_progress_bar_config(disable=True)
pipe.safety_checker = None
if args.erasure_ckpt_name is not None:
pipe.unet.load_state_dict(torch.load(
args.erasure_ckpt_name, map_location='cuda'))
for idx in trange(args.num_generate,
desc="Generating Images"):
with autocast('cuda'):
image = pipe(
args.prompt,
guidance_scale=args.guidance_scale
).images[0]
image.save(os.path.join(args.out, f"{idx}.png"))