From ed40f3636c453cac93489f59217d040858bc78cd Mon Sep 17 00:00:00 2001
From: pranay-ar
Date: Fri, 22 Dec 2023 08:05:00 +0000
Subject: [PATCH] fix sampling during training function

---
 configs/default.yaml            |  4 +--
 ddpm_conditional.py             | 47 ++++++++--------------------
 generate.py                     | 17 ++++++-----
 noising_test.py                 | 18 ++++++-----
 ddpm_prune.py => pruning.py     |  4 +--
 quantize.py                     | 26 ++++++++++++++++
 run.sh                          | 10 ++++++
 run_rtx.sh                      | 25 ---------------
 test.py => single_generation.py |  5 +--
 test_quantize.py                | 54 ---------------------------------
 utils.py                        | 45 ++++++++++++++++++---------
 11 files changed, 105 insertions(+), 150 deletions(-)
 rename ddpm_prune.py => pruning.py (98%)
 delete mode 100755 run_rtx.sh
 rename test.py => single_generation.py (77%)
 delete mode 100644 test_quantize.py

diff --git a/configs/default.yaml b/configs/default.yaml
index 487d5ff..beaff8d 100644
--- a/configs/default.yaml
+++ b/configs/default.yaml
@@ -3,12 +3,12 @@ epochs: 300
 batch_size: 64
 image_size: 64
 num_classes: 10
-dataset_path: "data/cifar10-64/train"
+dataset_path: "data/sample"
 device: "cuda"
 learning_rate: 6.4e-4
 evaluation_interval: 10
 fid: "off"
-wandb: True
+wandb: False
 mixed_precision: False
 distillation: False
 compress: False
diff --git a/ddpm_conditional.py b/ddpm_conditional.py
index 354a703..6a13bba 100644
--- a/ddpm_conditional.py
+++ b/ddpm_conditional.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from torch import optim
 from utils import *
 from argparse import ArgumentParser
 from modules import UNet_conditional, EMA
@@ -52,7 +51,6 @@ def sample(self, model, total_images, batch_size, labels, cfg_scale=3):
 
                 x = torch.randn((batch_n, 3, self.img_size, self.img_size)).to(self.device)
                 for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
-                    logging.info(f"Step {i} of noise_steps in batch {batch_idx // batch_size + 1}")
                     t = (torch.ones(batch_n) * i).long().to(self.device)
                     predicted_noise = model(x, t, labels)
                     if cfg_scale > 0:
@@ -73,31 +71,6 @@ def sample(self, model, total_images, batch_size, labels, cfg_scale=3):
 
         model.train()
         return generated_images
-
-    def sample_train(self, model, n, labels, cfg_scale=3):
-        logging.info(f"Sampling {n} new images....")
-        model.eval()
-        with torch.no_grad():
-            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
-            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
-                t = (torch.ones(n) * i).long().to(self.device)
-                predicted_noise = model(x, t, labels)
-                if cfg_scale > 0:
-                    uncond_predicted_noise = model(x, t, None)
-                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
-                alpha = self.alpha[t][:, None, None, None]
-                alpha_hat = self.alpha_hat[t][:, None, None, None]
-                beta = self.beta[t][:, None, None, None]
-                if i > 1:
-                    noise = torch.randn_like(x)
-                else:
-                    noise = torch.zeros_like(x)
-                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
-            model.train()
-            x = (x.clamp(-1, 1) + 1) / 2
-            x = (x * 255).type(torch.uint8)
-            return x
-
 
 
 def train(configs):
@@ -133,7 +106,7 @@ def train(configs):
         teacher = UNet_conditional(num_classes=num_classes).to(device)
         ckpt = torch.load(
             args.teacher_path if args.teacher_path is not None else \
-            "/work/pi_adrozdov_umass_edu/pranayr_umass_edu/cs682/Diffusion-Models-pytorch/models/DDPM_conditional/ema_ckpt.pt"
+            "./models/DDPM_conditional/ema_ckpt.pt"
         )
         ckpt = fix_state_dict(ckpt)
         teacher.load_state_dict(ckpt)
@@ -143,7 +116,7 @@ def train(configs):
             param.requires_grad = False
 
-    optimizer = optim.AdamW(model.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     mse = nn.MSELoss()
     diffusion = Diffusion(img_size=image_size, device=device)
     l = len(dataloader)
 
@@ -201,17 +174,21 @@ def train(configs):
                 wandb.log({"MSE": loss.item(), "Epoch": epoch, "Batch": i})
 
         if epoch % 5 == 0:
-            labels = torch.arange(10).long().to(device)
-            sampled_images = diffusion.sample_train(model, n=len(labels), labels=labels)
-            ema_sampled_images = diffusion.sample_train(ema_model, n=len(labels), labels=labels)
-            plot_images(sampled_images)
-            make_grid(sampled_images, os.path.join(results_dir, f"{epoch}.jpg"))
-            make_grid(ema_sampled_images, os.path.join(results_dir, f"{epoch}_ema.jpg"))
             torch.save(model.module.state_dict(), os.path.join(model_dir, f"ckpt.pt"))
             torch.save(ema_model.state_dict(), os.path.join(model_dir, f"ema_ckpt.pt"))
             torch.save(optimizer.state_dict(), os.path.join(model_dir, f"optim.pt"))
             print("Saved model and optimizer states at epoch {}.".format(epoch))
+            saved_model = UNet_conditional(compress=1, num_classes=10).to(device)
+            ckpt = torch.load(os.path.join(model_dir, f"ema_ckpt.pt")) # load last checkpoint
+            ckpt = fix_state_dict(ckpt)
+            saved_model.load_state_dict(ckpt, strict=False)
+            diffusion = Diffusion(img_size=64, device=device)
+            labels = torch.arange(10).long().to(device)
+            ema_sampled_images = diffusion.sample(saved_model, total_images=len(labels), batch_size=10, labels=labels)
+            plot_images(ema_sampled_images)
+            make_grid(ema_sampled_images, os.path.join(results_dir, f"{epoch}_ema.jpg"))
+
 
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument(
diff --git a/generate.py b/generate.py
index 6a57131..4775304 100644
--- a/generate.py
+++ b/generate.py
@@ -11,30 +11,31 @@
 from ddpm_conditional import Diffusion
 
 device = "cuda"
-# model = UNet_conditional(compress=2, num_classes=10).to(device)
-model = torch.load("./models/pruned/ddpm_conditional_pruned/pruned/unet_pruned_0.16_0.01.pth")
-ckpt = torch.load("./models/Pruned_0.16_0.01_FT/ema_ckpt.pt")
+model = UNet_conditional(compress=1, num_classes=10).to(device)
+# model = torch.load("./models/pruned/ddpm_conditional_kd_pruned/pruned/unet_pruned_0.16_0.01.pth")
+ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt")
 ckpt = fix_state_dict(ckpt)
 model.load_state_dict(ckpt, strict=False)
 
 diffusion = Diffusion(img_size=64, device=device)
 
-total_images_per_class = 1024
+total_images_per_class =12
 batch_size = 256
 cfg_scale = 0
 
 for class_index in range(10): # Assuming 10 classes
-    class_folder = f"./fid_data/generated_images_pruned_0.16_0.01_FT/class_{class_index}"
+    class_folder = f"./fid_data/generated_images_fp16/class_{class_index}"
     os.makedirs(class_folder, exist_ok=True)
 
     y = torch.full((total_images_per_class,), class_index, dtype=torch.long).to(device)
+    print("Shape of y:", y)
     x = diffusion.sample(model, total_images_per_class, batch_size, y, cfg_scale=cfg_scale)
 
     # Assuming save_images function handles saving all generated images for a class
-    save_images(x, class_folder)
+    # save_images(x, class_folder)
     print(f"Images for class {class_index} have been saved in {class_folder}.")
 
-# results_dir = "./fid_data/combined_generated_images_KD" # Directory where generated images will be saved
-# dataset_path = './fid_data/combined_real_train' # Path to your CIFAR10 dataset
+# results_dir = "./fid_data/combined_generated_images_kd_pruned_0.16_0.01_ft" # Directory where generated images will be saved
+# dataset_path = './fid_data/combined_real_full_train' # Path to your CIFAR10 dataset
 # device = "cuda"
 
 # # Call the function to generate images and calculate FID
diff --git a/noising_test.py b/noising_test.py
index 849cf4f..db3f174 100644
--- a/noising_test.py
+++ b/noising_test.py
@@ -2,15 +2,19 @@
 from torchvision.utils import save_image
 from ddpm import Diffusion
 from utils import get_data
 
-import argparse
-parser = argparse.ArgumentParser()
-args = parser.parse_args()
-args.batch_size = 1 # 5
-args.image_size = 64
-args.dataset_path = r"/work/pi_adrozdov_umass_edu/pranayr_umass_edu/cs682/Diffusion-Models-pytorch/cifar10-64/test/"
-dataloader = get_data(args)
+batch_size = 1
+image_size = 64
+dataset_path = r"./data/cifar10/test/"
+configs = {
+    "batch_size": batch_size,
+    "image_size": image_size,
+    "dataset_path": dataset_path
+}
+
+
+dataloader = get_data(configs)
 
 
 diff = Diffusion(device="cpu")
diff --git a/ddpm_prune.py b/pruning.py
similarity index 98%
rename from ddpm_prune.py
rename to pruning.py
index d65eb1b..f1d004e 100644
--- a/ddpm_prune.py
+++ b/pruning.py
@@ -8,7 +8,7 @@
 import os
 from glob import glob
 from PIL import Image
-from utils import fix_state_dict, get_dataset
+from utils import fix_state_dict, get_data
 import numpy as np
 import torch.nn as nn
 
@@ -36,7 +36,7 @@
 
 # loading images for gradient-based pruning
 if args.pruner in ['taylor', 'diff-pruning']:
-    dataset = get_dataset(args.dataset)
+    dataset = get_data(args.dataset)
     print(f"Dataset size: {len(dataset)}")
     train_dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=True
diff --git a/quantize.py b/quantize.py
index 47a7273..4ef20ce 100644
--- a/quantize.py
+++ b/quantize.py
@@ -27,3 +27,29 @@
 print_size_of_model(model_quantized)
 
 torch.save(model_quantized.state_dict(), "./models/DDPM_conditional/ema_ckpt_quantized.pt")
+
+# Load the entire quantized model
+device = "cpu"
+quantized_model = torch.load("./models/DDPM_conditional/ema_ckpt_quantized.pth", map_location=device)
+
+# Prepare the model for inference
+quantized_model.eval()
+
+# Initialize the diffusion process
+diffusion = Diffusion(img_size=64, device=device)
+
+# Sample data for inference
+n = 1
+y = torch.Tensor([6] * n).long().to(device)
+
+# Generate sample using the quantized model
+with torch.no_grad():
+    outputs = diffusion.sample(quantized_model, n, 1,labels=y, cfg_scale=0)
+    # Assuming outputs is a list, select the first tensor
+    x = outputs[0] if isinstance(outputs, list) else outputs
+
+    # Normalize the output to [0, 1] range
+    x = (x - x.min()) / (x.max() - x.min())
+
+# Save the output image
+save_image(x, "./test_quantized.jpg")
diff --git a/run.sh b/run.sh
index 938051f..6eb1add 100755
--- a/run.sh
+++ b/run.sh
@@ -12,3 +12,13 @@ conda activate cdm
 python generate.py
 
 # python ddpm_conditional.py --config "./configs/fine_tune.yaml"
+
+# python ddpm_prune.py \
+# --dataset ./data/cifar10-64/train/ \
+# --model_path ./models/DDPM_conditional/ema_ckpt.pt \
+# --save_path ./models/pruned/ddpm_conditional_pruned \
+# --pruning_ratio 0.16 \
+# --batch_size 32 \
+# --pruner diff-pruning \
+# --thr 0.05 \
+# --device cuda \
\ No newline at end of file
diff --git a/run_rtx.sh b/run_rtx.sh
deleted file mode 100755
index 830177a..0000000
--- a/run_rtx.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/bash
-#SBATCH -c 12 # Number of Cores per Task
-#SBATCH --mem=200000 # Requested Memory
-#SBATCH -p gypsum-2080ti # Partition
-#SBATCH --gres=gpu:8 # Number of GPUs
-#SBATCH -t 72:00:00 # Job time limit
-#SBATCH -o slurm_jobs/slurm-%j.out # %j = job ID
-#SBATCH --mail-type=END
-
-# source /home/pranayr_umass_edu/miniconda3/etc/profile.d/conda.sh
-# conda activate cmd
-
-# python ddpm_conditional.py --config "./configs/rtx.yaml"
-# python ddpm_conditional.py --config "./configs/compressed_kd_MPT_RLT.yaml"
-python generate.py
-
-# python ddpm_prune.py \
-# --dataset ./data/cifar10-64/train/ \
-# --model_path ./models/DDPM_conditional/ema_ckpt.pt \
-# --save_path ./models/pruned/ddpm_conditional_pruned \
-# --pruning_ratio 0.16 \
-# --batch_size 32 \
-# --pruner diff-pruning \
-# --thr 0.05 \
-# --device cuda \
\ No newline at end of file
diff --git a/test.py b/single_generation.py
similarity index 77%
rename from test.py
rename to single_generation.py
index 4633997..d94577f 100644
--- a/test.py
+++ b/single_generation.py
@@ -14,6 +14,7 @@
 from ddpm_conditional import Diffusion
 
 device = "cpu"
+# model = torch.load("./models/pruned/ddpm_conditional_pruned/pruned/unet_pruned_0.16_0.01.pth")
 model = UNet_conditional(compress=1,num_classes=10).to(device)
 ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt", map_location=device)
 ckpt = fix_state_dict(ckpt)
@@ -21,5 +22,5 @@
 diffusion = Diffusion(img_size=64, device=device)
 n = 1
 y = torch.Tensor([6] * n).long().to(device)
-x = diffusion.sample(model, 1, 8, labels=y, cfg_scale=0)
-make_grid(x, "./test.jpg")
\ No newline at end of file
+x = diffusion.sample(model, 1, 1, labels=y, cfg_scale=0)
+# make_grid(x, "./test.jpg")
\ No newline at end of file
diff --git a/test_quantize.py b/test_quantize.py
deleted file mode 100644
index 94a0d26..0000000
--- a/test_quantize.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import torch
-import torch.nn as nn
-from utils import *
-from modules import UNet_conditional
-
-# # Initialize the model
-# model = UNet_conditional(compress=1, num_classes=10)
-# model.eval()
-
-# # Load the pre-trained model weights
-# ckpt = torch.load("./models/DDPM_conditional/ema_ckpt.pt")
-# ckpt = fix_state_dict(ckpt)
-# model.load_state_dict(ckpt)
-
-# # Specify the layers to be quantized
-# quantize_layers = {nn.Conv2d, nn.Linear}
-
-# # Apply dynamic quantization
-# model_quantized = torch.quantization.quantize_dynamic(
-#     model, quantize_layers, dtype=torch.qint8
-# )
-
-# # Save the entire quantized model
-# torch.save(model_quantized, "./models/DDPM_conditional/ema_ckpt_quantized.pth")
-
-import torch
-from ddpm_conditional import Diffusion
-from torchvision.utils import save_image
-
-# Load the entire quantized model
-device = "cpu"
-quantized_model = torch.load("./models/DDPM_conditional/ema_ckpt_quantized.pth", map_location=device)
-
-# Prepare the model for inference
-quantized_model.eval()
-
-# Initialize the diffusion process
-diffusion = Diffusion(img_size=64, device=device)
-
-# Sample data for inference
-n = 1
-y = torch.Tensor([6] * n).long().to(device)
-
-# Generate sample using the quantized model
-with torch.no_grad():
-    outputs = diffusion.sample(quantized_model, n, 1,labels=y, cfg_scale=0)
-    # Assuming outputs is a list, select the first tensor
-    x = outputs[0] if isinstance(outputs, list) else outputs
-
-    # Normalize the output to [0, 1] range
-    x = (x - x.min()) / (x.max() - x.min())
-
-# Save the output image
-save_image(x, "./test_quantized.jpg")
\ No newline at end of file
diff --git a/utils.py b/utils.py
index e627f4c..c19b4f1 100644
--- a/utils.py
+++ b/utils.py
@@ -7,6 +7,7 @@
 from ddpm_conditional import Diffusion
 from pytorch_fid import fid_score
 import time
+import random
 
 
 def plot_images(images):
@@ -77,18 +78,32 @@ def print_size_of_model(model):
     print('Size:', os.path.getsize("temp.p")/1e6, 'MB')
     os.remove('temp.p')
 
-def get_dataset(name_or_path, transform=None):
-
-    print(name_or_path)
-    if "cifar10-64" in name_or_path.lower():
-        if transform is None:
-            transform = torchvision.transforms.Compose([
-                torchvision.transforms.Resize(80), # configs.image_size + 1/4 *configs.image_size
-                torchvision.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-            ])
-    # print(name_or_path)
-    # print(transform)
-    dataset = torchvision.datasets.ImageFolder(name_or_path, transform=transform)
-    return dataset
\ No newline at end of file
+def create_row_collage(base_dir, class_folders, output_path):
+    images = []
+    # Load one random image from each class folder
+    for folder in class_folders:
+        class_dir = os.path.join(base_dir, folder)
+        files = [f for f in os.listdir(class_dir) if os.path.isfile(os.path.join(class_dir, f))]
+        if not files:
+            continue # Skip folders with no images
+        selected_file = random.choice(files)
+        img_path = os.path.join(class_dir, selected_file)
+        images.append(Image.open(img_path))
+
+    # Calculate total width and max height for the collage
+    total_width = sum(img.width for img in images)
+    max_height = max(img.height for img in images)
+
+    # Create a new blank image to paste the images into
+    collage = Image.new('RGB', (total_width, max_height))
+
+    # Paste images side by side
+    x_offset = 0
+    for img in images:
+        collage.paste(img, (x_offset, 0))
+        x_offset += img.width
+
+    # Save the collage
+    collage.save(output_path)
+
+    print("Collage saved to: ", output_path)
\ No newline at end of file