Commit ed40f36

fix sampling during training function
pranay-ar committed Dec 22, 2023
1 parent b7340c8 commit ed40f36
Showing 11 changed files with 105 additions and 150 deletions.
4 changes: 2 additions & 2 deletions configs/default.yaml
@@ -3,12 +3,12 @@ epochs: 300
 batch_size: 64
 image_size: 64
 num_classes: 10
-dataset_path: "data/cifar10-64/train"
+dataset_path: "data/sample"
 device: "cuda"
 learning_rate: 6.4e-4
 evaluation_interval: 10
 fid: "off"
-wandb: True
+wandb: False
 mixed_precision: False
 distillation: False
 compress: False
47 changes: 12 additions & 35 deletions ddpm_conditional.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from torch import optim
 from utils import *
 from argparse import ArgumentParser
 from modules import UNet_conditional, EMA
@@ -52,7 +51,6 @@ def sample(self, model, total_images, batch_size, labels, cfg_scale=3):
             x = torch.randn((batch_n, 3, self.img_size, self.img_size)).to(self.device)

             for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
-                logging.info(f"Step {i} of noise_steps in batch {batch_idx // batch_size + 1}")
                 t = (torch.ones(batch_n) * i).long().to(self.device)
                 predicted_noise = model(x, t, labels)
                 if cfg_scale > 0:
@@ -73,31 +71,6 @@ def sample(self, model, total_images, batch_size, labels, cfg_scale=3):

         model.train()
         return generated_images
-
-    def sample_train(self, model, n, labels, cfg_scale=3):
-        logging.info(f"Sampling {n} new images....")
-        model.eval()
-        with torch.no_grad():
-            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
-            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
-                t = (torch.ones(n) * i).long().to(self.device)
-                predicted_noise = model(x, t, labels)
-                if cfg_scale > 0:
-                    uncond_predicted_noise = model(x, t, None)
-                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
-                alpha = self.alpha[t][:, None, None, None]
-                alpha_hat = self.alpha_hat[t][:, None, None, None]
-                beta = self.beta[t][:, None, None, None]
-                if i > 1:
-                    noise = torch.randn_like(x)
-                else:
-                    noise = torch.zeros_like(x)
-                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
-        model.train()
-        x = (x.clamp(-1, 1) + 1) / 2
-        x = (x * 255).type(torch.uint8)
-        return x


 def train(configs):
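For reference, the loop body shown above (shared by `sample` and the deleted `sample_train`) is the standard DDPM ancestral update with classifier-free guidance. `torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)` computes $\epsilon = \epsilon_\theta(x_t, \varnothing) + s\,(\epsilon_\theta(x_t, y) - \epsilon_\theta(x_t, \varnothing))$ for guidance scale $s$, and the last line of the loop implements

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon \Big) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 1,\ z = 0 \text{ at } t = 1,$$

with $\bar\alpha_t$ written as `alpha_hat` in the code.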

@@ -133,7 +106,7 @@ def train(configs):
         teacher = UNet_conditional(num_classes=num_classes).to(device)
         ckpt = torch.load(
             args.teacher_path if args.teacher_path is not None else \
-            "/work/pi_adrozdov_umass_edu/pranayr_umass_edu/cs682/Diffusion-Models-pytorch/models/DDPM_conditional/ema_ckpt.pt"
+            "./models/DDPM_conditional/ema_ckpt.pt"
         )
         ckpt = fix_state_dict(ckpt)
         teacher.load_state_dict(ckpt)
@@ -143,7 +116,7 @@ def train(configs):
             param.requires_grad = False


-    optimizer = optim.AdamW(model.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     mse = nn.MSELoss()
     diffusion = Diffusion(img_size=image_size, device=device)
     l = len(dataloader)
@@ -201,17 +174,21 @@ def train(configs):
                 wandb.log({"MSE": loss.item(), "Epoch": epoch, "Batch": i})

         if epoch % 5 == 0:
-            labels = torch.arange(10).long().to(device)
-            sampled_images = diffusion.sample_train(model, n=len(labels), labels=labels)
-            ema_sampled_images = diffusion.sample_train(ema_model, n=len(labels), labels=labels)
-            plot_images(sampled_images)
-            make_grid(sampled_images, os.path.join(results_dir, f"{epoch}.jpg"))
-            make_grid(ema_sampled_images, os.path.join(results_dir, f"{epoch}_ema.jpg"))
             torch.save(model.module.state_dict(), os.path.join(model_dir, f"ckpt.pt"))
             torch.save(ema_model.state_dict(), os.path.join(model_dir, f"ema_ckpt.pt"))
             torch.save(optimizer.state_dict(), os.path.join(model_dir, f"optim.pt"))
             print("Saved model and optimizer states at epoch {}.".format(epoch))

+    saved_model = UNet_conditional(compress=1, num_classes=10).to(device)
+    ckpt = torch.load(os.path.join(model_dir, f"ema_ckpt.pt"))  # load last checkpoint
+    ckpt = fix_state_dict(ckpt)
+    saved_model.load_state_dict(ckpt, strict=False)
+    diffusion = Diffusion(img_size=64, device=device)
+    labels = torch.arange(10).long().to(device)
+    ema_sampled_images = diffusion.sample(saved_model, total_images=len(labels), batch_size=10, labels=labels)
+    plot_images(ema_sampled_images)
+    make_grid(ema_sampled_images, os.path.join(results_dir, f"{epoch}_ema.jpg"))
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument(
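The training loop now calls the batched `sample` with an explicit `total_images`/`batch_size` split instead of the removed single-batch `sample_train`. A minimal sketch of the batching contract it appears to follow, written as a free function over the `Diffusion` attributes visible above (the per-chunk label slicing is an assumption, since the middle of `sample` is not shown in this diff):

import torch

def sample_in_batches(diffusion, model, total_images, batch_size, labels, cfg_scale=3):
    # Mirrors Diffusion.sample: split total_images into chunks of at most batch_size,
    # run the full reverse process per chunk, and concatenate the uint8 results.
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, total_images, batch_size):
            batch_n = min(batch_size, total_images - start)
            x = torch.randn((batch_n, 3, diffusion.img_size, diffusion.img_size)).to(diffusion.device)
            chunk_labels = labels[start:start + batch_n]  # assumption: labels has total_images entries
            for i in reversed(range(1, diffusion.noise_steps)):
                t = (torch.ones(batch_n) * i).long().to(diffusion.device)
                predicted_noise = model(x, t, chunk_labels)
                if cfg_scale > 0:
                    uncond = model(x, t, None)
                    predicted_noise = torch.lerp(uncond, predicted_noise, cfg_scale)
                alpha = diffusion.alpha[t][:, None, None, None]
                alpha_hat = diffusion.alpha_hat[t][:, None, None, None]
                beta = diffusion.beta[t][:, None, None, None]
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
            chunks.append(((x.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8))
    model.train()
    return torch.cat(chunks)

Under this contract, the end-of-training call above (total_images=10, batch_size=10) and the per-class calls in generate.py below both reduce to one or more full reverse passes.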
17 changes: 9 additions & 8 deletions generate.py
@@ -11,30 +11,31 @@
 from ddpm_conditional import Diffusion

 device = "cuda"
-# model = UNet_conditional(compress=2, num_classes=10).to(device)
-model = torch.load("./models/pruned/ddpm_conditional_pruned/pruned/unet_pruned_0.16_0.01.pth")
-ckpt = torch.load("./models/Pruned_0.16_0.01_FT/ema_ckpt.pt")
+model = UNet_conditional(compress=1, num_classes=10).to(device)
+# model = torch.load("./models/pruned/ddpm_conditional_kd_pruned/pruned/unet_pruned_0.16_0.01.pth")
+ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt")
 ckpt = fix_state_dict(ckpt)
 model.load_state_dict(ckpt, strict=False)
 diffusion = Diffusion(img_size=64, device=device)
-total_images_per_class = 1024
+total_images_per_class = 12
 batch_size = 256
 cfg_scale = 0

 for class_index in range(10):  # assuming 10 classes
-    class_folder = f"./fid_data/generated_images_pruned_0.16_0.01_FT/class_{class_index}"
+    class_folder = f"./fid_data/generated_images_fp16/class_{class_index}"
     os.makedirs(class_folder, exist_ok=True)

     y = torch.full((total_images_per_class,), class_index, dtype=torch.long).to(device)
+    print("Shape of y:", y.shape)
     x = diffusion.sample(model, total_images_per_class, batch_size, y, cfg_scale=cfg_scale)

     # save_images handles saving all generated images for a class
-    save_images(x, class_folder)
+    # save_images(x, class_folder)

     print(f"Images for class {class_index} have been saved in {class_folder}.")

-# results_dir = "./fid_data/combined_generated_images_KD"  # directory where generated images will be saved
-# dataset_path = './fid_data/combined_real_train'  # path to the CIFAR-10 dataset
+# results_dir = "./fid_data/combined_generated_images_kd_pruned_0.16_0.01_ft"  # directory where generated images will be saved
+# dataset_path = './fid_data/combined_real_full_train'  # path to the CIFAR-10 dataset
 # device = "cuda"

 # # Call the function to generate images and calculate FID
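The commented-out tail refers to an FID computation between a folder of real images and a folder of generated ones. A minimal sketch with pytorch-fid (already imported in utils.py); the directory names are the placeholders from the comments above, and calculate_fid_given_paths is pytorch-fid's standard entry point:

# Sketch: FID between a folder of real images and a folder of generated images.
# Directory names are the placeholders from the commented lines above.
import torch
from pytorch_fid import fid_score

real_dir = "./fid_data/combined_real_full_train"
gen_dir = "./fid_data/combined_generated_images_kd_pruned_0.16_0.01_ft"
fid = fid_score.calculate_fid_given_paths(
    [real_dir, gen_dir],
    batch_size=50,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dims=2048,  # InceptionV3 pool3 features (the pytorch-fid default)
)
print(f"FID: {fid:.2f}")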
18 changes: 11 additions & 7 deletions noising_test.py
@@ -2,15 +2,19 @@
 from torchvision.utils import save_image
 from ddpm import Diffusion
 from utils import get_data
-import argparse

-parser = argparse.ArgumentParser()
-args = parser.parse_args()
-args.batch_size = 1  # 5
-args.image_size = 64
-args.dataset_path = r"/work/pi_adrozdov_umass_edu/pranayr_umass_edu/cs682/Diffusion-Models-pytorch/cifar10-64/test/"

-dataloader = get_data(args)
+batch_size = 1
+image_size = 64
+dataset_path = r"./data/cifar10/test/"
+configs = {
+    "batch_size": batch_size,
+    "image_size": image_size,
+    "dataset_path": dataset_path
+}
+
+
+dataloader = get_data(configs)

 diff = Diffusion(device="cpu")
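For reference, the forward noising this script presumably visualizes has a closed form, using the same $\bar\alpha_t$ (`alpha_hat`) as the sampling code: $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so any timestep can be rendered directly without iterating through the chain.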
4 changes: 2 additions & 2 deletions ddpm_prune.py → pruning.py
@@ -8,7 +8,7 @@
 import os
 from glob import glob
 from PIL import Image
-from utils import fix_state_dict, get_dataset
+from utils import fix_state_dict, get_data

 import numpy as np
 import torch.nn as nn
@@ -36,7 +36,7 @@

 # loading images for gradient-based pruning
 if args.pruner in ['taylor', 'diff-pruning']:
-    dataset = get_dataset(args.dataset)
+    dataset = get_data(args.dataset)
     print(f"Dataset size: {len(dataset)}")
     train_dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=True
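The 'taylor' and 'diff-pruning' pruners are gradient-based, which is why this branch builds a dataloader. A minimal sketch of the gradient pass that would precede importance scoring, assuming the repo's usual denoising objective (`diffusion` and its helpers sample_timesteps/noise_images are assumptions based on the training code; the actual pruning call sits outside this hunk):

# Sketch: populate parameter gradients before Taylor-style importance scoring.
# `model`, `args`, and `train_dataloader` come from pruning.py above.
import torch
import torch.nn as nn

mse = nn.MSELoss()
model.zero_grad()
for images, class_labels in train_dataloader:
    images = images.to(args.device)
    t = diffusion.sample_timesteps(images.shape[0]).to(args.device)  # assumed helper
    x_t, noise = diffusion.noise_images(images, t)                   # assumed helper
    predicted_noise = model(x_t, t, class_labels.to(args.device))
    loss = mse(noise, predicted_noise)
    loss.backward()  # accumulated gradients feed the importance estimate
    break  # one batch (or a few) is typically enough for scoring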
26 changes: 26 additions & 0 deletions quantize.py
@@ -27,3 +27,29 @@
 print_size_of_model(model_quantized)

 torch.save(model_quantized.state_dict(), "./models/DDPM_conditional/ema_ckpt_quantized.pt")
+
+# Load the entire quantized model
+device = "cpu"
+quantized_model = torch.load("./models/DDPM_conditional/ema_ckpt_quantized.pth", map_location=device)
+
+# Prepare the model for inference
+quantized_model.eval()
+
+# Initialize the diffusion process
+diffusion = Diffusion(img_size=64, device=device)
+
+# Sample data for inference
+n = 1
+y = torch.Tensor([6] * n).long().to(device)
+
+# Generate a sample using the quantized model
+with torch.no_grad():
+    outputs = diffusion.sample(quantized_model, n, 1, labels=y, cfg_scale=0)
+    # Assuming outputs is a list, select the first tensor
+    x = outputs[0] if isinstance(outputs, list) else outputs
+
+# Normalize the output to [0, 1] range
+x = (x - x.min()) / (x.max() - x.min())
+
+# Save the output image
+save_image(x, "./test_quantized.jpg")
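Note the asymmetry above: the script saves a state_dict to ema_ckpt_quantized.pt but reloads ema_ckpt_quantized.pth as a whole module, which only works if the full model object was also saved separately. A sketch of the state-dict route, assuming the model was quantized with PyTorch dynamic quantization (the quantization call itself is outside this hunk):

# Sketch: rebuild the dynamically quantized module, then load the saved state_dict.
# Assumes quantize.py produced model_quantized with torch.quantization.quantize_dynamic.
import torch
import torch.nn as nn
from modules import UNet_conditional

device = "cpu"
skeleton = UNet_conditional(num_classes=10).to(device)  # fp32 architecture to quantize
quantized_model = torch.quantization.quantize_dynamic(skeleton, {nn.Linear}, dtype=torch.qint8)
quantized_model.load_state_dict(torch.load("./models/DDPM_conditional/ema_ckpt_quantized.pt", map_location=device))
quantized_model.eval()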
10 changes: 10 additions & 0 deletions run.sh
@@ -12,3 +12,13 @@ conda activate cdm

 python generate.py
 # python ddpm_conditional.py --config "./configs/fine_tune.yaml"
+
+# python ddpm_prune.py \
+#     --dataset ./data/cifar10-64/train/ \
+#     --model_path ./models/DDPM_conditional/ema_ckpt.pt \
+#     --save_path ./models/pruned/ddpm_conditional_pruned \
+#     --pruning_ratio 0.16 \
+#     --batch_size 32 \
+#     --pruner diff-pruning \
+#     --thr 0.05 \
+#     --device cuda \
25 changes: 0 additions & 25 deletions run_rtx.sh

This file was deleted.

5 changes: 3 additions & 2 deletions test.py → single_generation.py
@@ -14,12 +14,13 @@
 from ddpm_conditional import Diffusion

 device = "cpu"
+# model = torch.load("./models/pruned/ddpm_conditional_pruned/pruned/unet_pruned_0.16_0.01.pth")
 model = UNet_conditional(compress=1, num_classes=10).to(device)
 ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt", map_location=device)
 ckpt = fix_state_dict(ckpt)
 model.load_state_dict(ckpt)
 diffusion = Diffusion(img_size=64, device=device)
 n = 1
 y = torch.Tensor([6] * n).long().to(device)
-x = diffusion.sample(model, 1, 8, labels=y, cfg_scale=0)
-make_grid(x, "./test.jpg")
+x = diffusion.sample(model, 1, 1, labels=y, cfg_scale=0)
+# make_grid(x, "./test.jpg")
54 changes: 0 additions & 54 deletions test_quantize.py

This file was deleted.

45 changes: 30 additions & 15 deletions utils.py
@@ -7,6 +7,7 @@
 from ddpm_conditional import Diffusion
 from pytorch_fid import fid_score
 import time
+import random


 def plot_images(images):
@@ -77,18 +78,32 @@ def print_size_of_model(model):
     print('Size:', os.path.getsize("temp.p")/1e6, 'MB')
     os.remove('temp.p')

-def get_dataset(name_or_path, transform=None):
-
-    print(name_or_path)
-    if "cifar10-64" in name_or_path.lower():
-        if transform is None:
-            transform = torchvision.transforms.Compose([
-                torchvision.transforms.Resize(80),  # configs.image_size + 1/4 * configs.image_size
-                torchvision.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-            ])
-    # print(name_or_path)
-    # print(transform)
-    dataset = torchvision.datasets.ImageFolder(name_or_path, transform=transform)
-    return dataset
+def create_row_collage(base_dir, class_folders, output_path):
+    images = []
+    # Load one random image from each class folder
+    for folder in class_folders:
+        class_dir = os.path.join(base_dir, folder)
+        files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
+        if not files:
+            continue  # Skip folders with no images
+        selected_file = random.choice(files)
+        img_path = os.path.join(class_dir, selected_file)
+        images.append(Image.open(img_path))
+
+    # Calculate total width and max height for the collage
+    total_width = sum(img.width for img in images)
+    max_height = max(img.height for img in images)
+
+    # Create a new blank image to paste the images into
+    collage = Image.new('RGB', (total_width, max_height))
+
+    # Paste images side by side
+    x_offset = 0
+    for img in images:
+        collage.paste(img, (x_offset, 0))
+        x_offset += img.width
+
+    # Save the collage
+    collage.save(output_path)
+
+    print("Collage saved to: ", output_path)
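A usage sketch for the new helper, matching the per-class folders written by generate.py (the output path is illustrative):

# Build a 1x10 strip with one random sample per CIFAR-10 class.
create_row_collage(
    base_dir="./fid_data/generated_images_fp16",
    class_folders=[f"class_{i}" for i in range(10)],
    output_path="./collage_row.jpg",  # illustrative output path
)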
