-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Playing around with stable diffusion
- Loading branch information
0 parents
commit aa19a6e
Showing
5 changed files
with
251 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# PyCharm IDE | ||
.idea | ||
__pycache__ | ||
|
||
# vscode | ||
.vscode | ||
|
||
# Jupyter notebook checkpoints | ||
.ipynb_checkpoints | ||
|
||
# Data directory | ||
output/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Aleksa Gordić | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
## Generate images using stable diffusion | ||
|
||
## Setup | ||
|
||
Follow the next steps to run this code: | ||
|
||
1. `git clone https://github.com/gordicaleksa/stable_diffusion_playground` | ||
2. Open Anaconda console and navigate into project directory `cd path_to_repo` | ||
3. Run `conda env create` from project directory (this will create a brand new conda environment). | ||
4. Run `activate sd_playground` (for running scripts from your console or setup the interpreter in your IDE) | ||
5. Run `huggingface-cli login` before the first time you try to use it to access model weights. | ||
|
||
That's it! It should work out-of-the-box executing environment.yml file which deals with dependencies. <br/> | ||
|
||
## Acknowledgements | ||
|
||
Took inspiration from [Karpathy's gist](https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355). | ||
|
||
## Licence | ||
|
||
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/gordicaleksa/stable_diffusion_playground/blob/master/LICENCE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: sd_playground | ||
channels: | ||
- defaults | ||
- pytorch | ||
dependencies: | ||
- python=3.8.5 | ||
- pip=20.3 | ||
- cudatoolkit=11.3 | ||
- pytorch=1.11.0 | ||
- numpy=1.19.2 | ||
- pip: | ||
- diffusers==0.2.4 | ||
- transformers==4.19.2 | ||
- scipy | ||
- matplotlib | ||
- fire==0.4.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# You'll have to run "huggingface-cli login" the first time so that you can access the model weights. | ||
|
||
import enum | ||
import os | ||
import json | ||
|
||
from diffusers import StableDiffusionPipeline | ||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler | ||
import fire | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
from torch import autocast | ||
|
||
|
||
class ExecutionMode(enum.Enum): | ||
GENERATE_DIVERSE = 0, | ||
REPRODUCE = 1, | ||
INTERPOLATE = 2 | ||
|
||
|
||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): | ||
""" helper function to spherically interpolate two arrays v1 v2 """ | ||
|
||
if not isinstance(v0, np.ndarray): | ||
inputs_are_torch = True | ||
input_device = v0.device | ||
v0 = v0.cpu().numpy() | ||
v1 = v1.cpu().numpy() | ||
|
||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) | ||
if np.abs(dot) > DOT_THRESHOLD: | ||
v2 = (1 - t) * v0 + t * v1 | ||
else: | ||
theta_0 = np.arccos(dot) | ||
sin_theta_0 = np.sin(theta_0) | ||
theta_t = theta_0 * t | ||
sin_theta_t = np.sin(theta_t) | ||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0 | ||
s1 = sin_theta_t / sin_theta_0 | ||
v2 = s0 * v0 + s1 * v1 | ||
|
||
if inputs_are_torch: | ||
v2 = torch.from_numpy(v2).to(input_device) | ||
|
||
return v2 | ||
|
||
|
||
def generate_name(output_dir_path, suffix='jpg'): | ||
prefix = str(len(os.listdir(output_dir_path))).zfill(6) | ||
return f'{prefix}.{suffix}' | ||
|
||
|
||
def save_metadata(meta_dir, prompt, num_inference_steps, guidance_scale): | ||
data = { # Feel free to add anything else you might need. | ||
'prompt': prompt, | ||
'num_steps': num_inference_steps, | ||
'scale': guidance_scale | ||
} | ||
with open(os.path.join(meta_dir, generate_name(meta_dir, suffix='json')), 'w') as f: | ||
json.dump(data, f) | ||
|
||
|
||
def run( | ||
# -------------------------------------- | ||
# args you probably want to change | ||
name='ai_epiphany', # name of the output directory | ||
execution_mode=ExecutionMode.INTERPOLATE, | ||
prompt="a painting of an ai robot having an epiphany moment", | ||
num_inference_steps=50, # More (e.g. 100, 200 etc) can create slightly better images. | ||
guidance_scale=7.5, # Can depend on the prompt. Usually somewhere between 3-10 is good. | ||
num_imgs=5, # How many images you want to generate in this run. | ||
# -------------------------------------- | ||
# args you probably don't want to change | ||
seed=23, # I love it more than 42 | ||
width=512, | ||
height=512, | ||
fp16=True, # Set to True unless you have ~16 GBs of VRAM. | ||
src_latent_path="T:\\YouTube_Code\\8_Stable_Diffusion\\stable-diffusion\\ai_epiphany\\latents\\000000.npy", | ||
trg_latent_path=None, | ||
metadata_path="T:\\YouTube_Code\\8_Stable_Diffusion\\stable-diffusion\\ai_epiphany\\meta\\000000.json", | ||
# -------------------------------------- | ||
): | ||
assert torch.cuda.is_available(), "You need a GPU to run this script." | ||
assert height % 8 == 0 and width % 8 == 0, f"Width and height need to be a multiple of 8, got (w,h)=({width},{height})." | ||
device = "cuda" | ||
if seed: # If you want to have consistent runs. | ||
torch.manual_seed(seed) | ||
|
||
# Initialize the output file structure. | ||
root_dir = os.path.join(os.getcwd(), 'output', name) | ||
imgs_dir = os.path.join(root_dir, "samples") | ||
latents_dir = os.path.join(root_dir, "latents") | ||
meta_dir = os.path.join(root_dir, "meta") | ||
os.makedirs(imgs_dir, exist_ok=True) | ||
os.makedirs(latents_dir, exist_ok=True) | ||
os.makedirs(meta_dir, exist_ok=True) | ||
|
||
# Hardcoded the recommended scheduler - feel free to play with it. | ||
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") | ||
|
||
# Create diffusion pipeline object. | ||
pipe = StableDiffusionPipeline.from_pretrained( | ||
"CompVis/stable-diffusion-v1-4", | ||
torch_dtype=torch.float16 if fp16 else None, | ||
revision="fp16" if fp16 else "main", | ||
scheduler=lms, | ||
use_auth_token=True | ||
).to(device) | ||
|
||
if execution_mode == execution_mode.GENERATE_DIVERSE: | ||
for i in range(num_imgs): | ||
init_latent = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=device) | ||
|
||
with autocast(device): | ||
image = pipe( | ||
prompt, | ||
num_inference_steps=num_inference_steps, | ||
latents=init_latent, | ||
guidance_scale=guidance_scale | ||
)["sample"][0] | ||
|
||
# Make sure generation is reproducible. | ||
image.save(os.path.join(imgs_dir, generate_name(imgs_dir, suffix='jpg'))) | ||
# TODO: is there some clever python mechanism that can enable me to log all input arg values? | ||
save_metadata(meta_dir, prompt, num_inference_steps, guidance_scale) | ||
np.save(os.path.join(latents_dir, generate_name(latents_dir, suffix='npy')), init_latent.cpu().numpy()) | ||
|
||
elif execution_mode == execution_mode.REPRODUCE: | ||
assert src_latent_path, 'You need to provide the latent path if you wish to reproduce an image.' | ||
assert metadata_path, 'You need to provide the metadata path if you wish to reproduce an image.' | ||
with open(metadata_path) as metadata_file: | ||
metadata = json.load(metadata_file) | ||
init = torch.from_numpy(np.load(src_latent_path)).to(device) | ||
with autocast(device): | ||
image = pipe( | ||
**metadata, | ||
latents=init, | ||
output_type='npy', | ||
# as long as it's not pil it'll return numpy with the current imp of StableDiffusionPipeline | ||
)["sample"][0] | ||
plt.imshow((image * 255).astype(np.uint8)); | ||
plt.show() | ||
|
||
elif execution_mode == execution_mode.INTERPOLATE: | ||
if src_latent_path and trg_latent_path: | ||
print('Loading existing source and target latents.') | ||
src_init = torch.from_numpy(np.load(src_latent_path)).to(device) | ||
trg_init = torch.from_numpy(np.load(trg_latent_path)).to(device) | ||
else: | ||
print('Generating random source and target latents.') | ||
src_init = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=device) | ||
trg_init = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=device) | ||
|
||
# Make sure generation is reproducible. | ||
save_metadata(meta_dir, prompt, num_inference_steps, guidance_scale) | ||
np.save(os.path.join(latents_dir, generate_name(latents_dir, suffix='npy')), src_init.cpu().numpy()) | ||
np.save(os.path.join(latents_dir, generate_name(latents_dir, suffix='npy')), trg_init.cpu().numpy()) | ||
|
||
for i, t in enumerate(np.concatenate([[0], np.linspace(0, 1, num_imgs)])): | ||
if i == 0: | ||
init_latent = trg_init # Make sure you're happy with the target image before you waste too much time. | ||
else: | ||
init_latent = slerp(float(t), src_init, trg_init) | ||
|
||
with autocast(device): | ||
image = pipe( | ||
prompt, | ||
num_inference_steps=num_inference_steps, | ||
latents=init_latent, | ||
guidance_scale=guidance_scale | ||
)["sample"][0] | ||
|
||
image.save(os.path.join(imgs_dir, generate_name(imgs_dir, suffix='jpg'))) | ||
else: | ||
print(f'Execution mode {execution_mode} not supported.') | ||
|
||
|
||
if __name__ == '__main__': | ||
fire.Fire(run) |