data loading and reconstruction algo for new use case
lpaillet-laas committed Mar 6, 2024
1 parent 60ac205 commit c7392a3
Showing 3 changed files with 237 additions and 0 deletions.
1 change: 1 addition & 0 deletions MST
Submodule MST added at ae6ce9
205 changes: 205 additions & 0 deletions data_handler.py
@@ -0,0 +1,205 @@
import os
import random

import numpy as np
import scipy.io as sio
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split


class CubesDataset(Dataset):
def __init__(self, data_dir, augment=True):
self.data_dir = data_dir
self.augment_ = augment
self.data_file_names = sorted(os.listdir(self.data_dir))

def __len__(self):
return len(self.data_file_names)

    def __getitem__(self, idx):
        cube, wavelengths = self.load_hyperspectral_data(idx)  # H x W x lambda

        if self.augment_:
            cube = self.augment(cube)  # lambda x H x W
        else:
            # No augmentation: take the top-left 128 x 128 crop.
            cube = torch.from_numpy(np.transpose(cube, (2, 0, 1))).float()[:, :128, :128]  # lambda x H x W

        return cube, wavelengths

    def load_hyperspectral_data(self, idx):
        file_path = os.path.join(self.data_dir, self.data_file_names[idx])
        data = sio.loadmat(file_path)
        # Scenes are stored as 16-bit integers; normalize to [0, 1].
        if "img_expand" in data:
            cube = data['img_expand'] / 65536.
        elif "img" in data:
            cube = data['img'] / 65536.
        else:
            raise KeyError(f"No 'img' or 'img_expand' key in {file_path}")
        cube = cube.astype(np.float32)  # H x W x lambda
        wavelengths = np.linspace(450, 650, 28)  # 28 bands from 450 nm to 650 nm

        return cube, wavelengths

    def augment(self, img, crop_size=128):
        """Randomly crop a (crop_size x crop_size) patch, then randomly rotate/flip it."""
        h, w, _ = img.shape
        x_index = np.random.randint(0, h - crop_size)
        y_index = np.random.randint(0, w - crop_size)
        processed_data = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :]
        processed_data = torch.from_numpy(np.transpose(processed_data, (2, 0, 1))).float()  # lambda x H x W
        processed_data = augment_1(processed_data)
        return processed_data


class CubesDataModule(LightningDataModule):
def __init__(self, data_dir, batch_size, num_workers=1):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
        self.dataset = CubesDataset(self.data_dir, augment=True)

    def setup(self, stage=None):
        # 70 / 20 / 10 split. Note: the three splits share one underlying
        # CubesDataset, so random-crop augmentation also applies to val/test.
        dataset_size = len(self.dataset)
        train_size = int(0.7 * dataset_size)
        val_size = int(0.2 * dataset_size)
        test_size = dataset_size - train_size - val_size

        self.train_ds, self.val_ds, self.test_ds = random_split(self.dataset, [train_size, val_size, test_size])

def train_dataloader(self):
return DataLoader(self.train_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True)

def val_dataloader(self):
return DataLoader(self.val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False)

def test_dataloader(self):
return DataLoader(self.test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False)


def augment_1(x):
    """Randomly rotate (by multiples of 90 degrees) and flip a cube.

    :param x: tensor of shape (c, h, w)
    :return: tensor of shape (c, h, w)
    """
    rotTimes = random.randint(0, 3)
    vFlip = random.randint(0, 1)
    hFlip = random.randint(0, 1)
    # Random rotation
    for _ in range(rotTimes):
        x = torch.rot90(x, dims=(1, 2))
    # Random vertical flip
    if vFlip:
        x = torch.flip(x, dims=(2,))
    # Random horizontal flip
    if hFlip:
        x = torch.flip(x, dims=(1,))
    return x


def shuffle_crop(train_data, batch_size, crop_size=256, augment=True):
    if augment:
        gt_batch = []
        # The first half of the batch uses plain random crops.
        index = np.random.choice(range(len(train_data)), batch_size // 2)
        processed_data = np.zeros((batch_size // 2, crop_size, crop_size, 28), dtype=np.float32)
        for i in range(batch_size // 2):
            img = train_data[index[i]]
            h, w, _ = img.shape
            x_index = np.random.randint(0, h - crop_size)
            y_index = np.random.randint(0, w - crop_size)
            processed_data[i, :, :, :] = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :]
        processed_data = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda().float()
        for i in range(processed_data.shape[0]):
            gt_batch.append(augment_1(processed_data[i]))

        # The other half of the batch splices four half-size crops into one image.
        for i in range(batch_size - batch_size // 2):
            processed_data = np.zeros((4, crop_size // 2, crop_size // 2, 28), dtype=np.float32)
            sample_list = np.random.randint(0, len(train_data), 4)
            for j in range(4):
                h, w, _ = train_data[sample_list[j]].shape  # sizes may differ per scene
                x_index = np.random.randint(0, h - crop_size // 2)
                y_index = np.random.randint(0, w - crop_size // 2)
                processed_data[j] = train_data[sample_list[j]][x_index:x_index + crop_size // 2, y_index:y_index + crop_size // 2, :]
            gt_batch_2 = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda()  # [4, 28, crop_size//2, crop_size//2]
            gt_batch.append(augment_2(gt_batch_2))
        gt_batch = torch.stack(gt_batch, dim=0)
        return gt_batch
    else:
        index = np.random.choice(range(len(train_data)), batch_size)
        processed_data = np.zeros((batch_size, crop_size, crop_size, 28), dtype=np.float32)
        for i in range(batch_size):
            h, w, _ = train_data[index[i]].shape
            x_index = np.random.randint(0, h - crop_size)
            y_index = np.random.randint(0, w - crop_size)
            processed_data[i, :, :, :] = train_data[index[i]][x_index:x_index + crop_size, y_index:y_index + crop_size, :]
        gt_batch = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2)))
        return gt_batch

def augment_2(generate_gt):
    """Stitch four half-size crops (4, c, h, w) into one full-size cube (c, 2h, 2w)."""
    c = generate_gt.shape[1]
    divid_point_h = generate_gt.shape[2]
    divid_point_w = generate_gt.shape[3]
    # Output is twice the crop size in each spatial dimension, so the four
    # quadrants fit exactly.
    output_img = torch.zeros(c, 2 * divid_point_h, 2 * divid_point_w).cuda()
    output_img[:, :divid_point_h, :divid_point_w] = generate_gt[0]
    output_img[:, :divid_point_h, divid_point_w:] = generate_gt[1]
    output_img[:, divid_point_h:, :divid_point_w] = generate_gt[2]
    output_img[:, divid_point_h:, divid_point_w:] = generate_gt[3]
    return output_img
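
# Illustrative shape check for the splicing path (kept commented out so that
# importing this module stays side-effect free; assumes a CUDA device is
# available, since augment_2 allocates on the GPU):
# quarters = torch.rand(4, 28, 128, 128).cuda()
# assert augment_2(quarters).shape == (28, 256, 256)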

# class AcquisitionDataset(Dataset):
#     def __init__(self, input, hs_cubes, transform=None, target_transform=None):
#         """Dataset pairing 2D acquisitions with their source hyperspectral cubes.
#
#         Args:
#             input: list of size 2:
#                 - input[0]: list of n torch.Tensor acquisitions (2D)
#                 - input[1]: list of n int labels, indices into hs_cubes
#             hs_cubes: list of m hyperspectral cubes; hs_cubes[k] is the k-th cube
#             transform (optional): transform applied to the acquisition. Defaults to None.
#             target_transform (optional): transform applied to the cube. Defaults to None.
#         """
#         self.data = input  # [acquisitions, labels]
#         self.labels = self.data[1]
#
#         self.cubes = hs_cubes  # len(hs_cubes) must be > max(self.labels)
#
#         self.transform = transform
#         self.target_transform = target_transform
#
#     def __len__(self):
#         return len(self.data[1])
#
#     def __getitem__(self, index):
#         acq = self.data[0][index]              # torch tensor of size x*y
#         cube = self.cubes[self.labels[index]]  # torch tensor of size x*y*w
#
#         return acq, cube

if __name__ == "__main__":
data_dir = "/local/users/ademaio/lpaillet/mst_datasets"
datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2)
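    # Minimal smoke test (illustrative sketch): assumes the .mat scenes are at
    # least 128 x 128 with 28 spectral bands, as the loader above expects.
    datamodule.setup()
    cube, wavelengths = next(iter(datamodule.train_dataloader()))
    print(cube.shape)         # expected: torch.Size([5, 28, 128, 128])
    print(wavelengths.shape)  # expected: torch.Size([5, 28])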
31 changes: 31 additions & 0 deletions mask_optim_recon.py
@@ -0,0 +1,31 @@
from simca import load_yaml_config
from simca.CassiSystemOptim import CassiSystemOptim
from simca.CassiSystem import CassiSystem
from data_handler import CubesDataModule
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import torch
import time
import datetime
import os

config_dataset = load_yaml_config("simca/configs/dataset.yml")
config_patterns = load_yaml_config("simca/configs/pattern.yml")
config_acquisition = load_yaml_config("simca/configs/acquisition.yml")
config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml")

# Planned pipeline (a hedged sketch of this loop follows at the end of the file):
# 1. Load datacubes
# 2. Generate a random mask
# 3. Run SIMCA to make acquisition 1
# 4. ResNet -> mask
# 5. Run SIMCA to make acquisition 2
# 6. Reconstruction MST/CST -> out_cube
# 7. Compare out_cube with the datacube to compute the loss

data_dir = "/local/users/ademaio/lpaillet/mst_datasets"
datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2)

# train_dataloader is a method returning a DataLoader, so iterate over it:
# for cube, wavelengths in datamodule.train_dataloader():
#     cassi_system.dataset = cube
#     cassi_system.wavelengths = wavelengths
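
# --- Hedged sketch of the planned pipeline above (illustrative only) ---
# The three helpers below are dummy stand-ins, NOT the real SIMCA
# (CassiSystemOptim) or MST/CST entry points, whose APIs may differ; they
# only mimic plausible tensor shapes so the data flow of steps 1-7 can be
# exercised end-to-end.

def simulate_acquisition(cube, mask):
    # Stand-in for a SIMCA acquisition: mask-weighted sum over the spectral axis.
    # cube: [B, 28, H, W], mask: [H, W] -> acquisition: [B, H, W]
    return (cube * mask).sum(dim=1)

def mask_resnet(acq):
    # Stand-in for the ResNet mapping acquisition 1 to a new mask in [0, 1].
    return torch.sigmoid(acq.mean(dim=0))

def reconstruct(acq):
    # Stand-in for the MST/CST reconstruction: broadcast back to 28 bands.
    return acq.unsqueeze(1).repeat(1, 28, 1, 1)

datamodule.setup()
for cube, wavelengths in datamodule.train_dataloader():
    mask = torch.rand(cube.shape[-2:])                    # step 2: random mask
    acq1 = simulate_acquisition(cube, mask)               # step 3
    mask2 = mask_resnet(acq1)                             # step 4
    acq2 = simulate_acquisition(cube, mask2)              # step 5
    out_cube = reconstruct(acq2)                          # step 6
    loss = torch.nn.functional.mse_loss(out_cube, cube)   # step 7
    print(loss.item())
    break  # one batch is enough for a smoke test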
