Skip to content

Commit

Permalink
initial submit
Browse files Browse the repository at this point in the history
  • Loading branch information
mingcv committed Nov 30, 2021
1 parent 249df8c commit ed55371
Show file tree
Hide file tree
Showing 27 changed files with 3,230 additions and 0 deletions.
4 changes: 4 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .low_light import *
from .low_light_test import *
from .mef import *

171 changes: 171 additions & 0 deletions datasets/low_light.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import os
import random

import torch
import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image


class LowLightFDataset(data.Dataset):
def __init__(self, root, image_split='images_aug', targets_split='targets', training=True):
self.root = root
self.num_instances = 8
self.img_root = os.path.join(root, image_split)
self.target_root = os.path.join(root, targets_split)
self.training = training
print('----', image_split, targets_split, '----')
self.imgs = list(sorted(os.listdir(self.img_root)))
self.gts = list(sorted(os.listdir(self.target_root)))

names = [img_name.split('_')[0] + '.' + img_name.split('.')[-1] for img_name in self.imgs]
self.imgs = list(
filter(lambda img_name: img_name.split('_')[0] + '.' + img_name.split('.')[-1] in self.gts, self.imgs))

self.gts = list(filter(lambda gt: gt in names, self.gts))

print(len(self.imgs), len(self.gts))
self.preproc = T.Compose(
[T.ToTensor()]
)
self.preproc_gt = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
fn, ext = self.gts[idx].split('.')
imgs = []
for i in range(self.num_instances):
img_path = os.path.join(self.img_root, f"{fn}_{i}.{ext}")
imgs += [self.preproc(Image.open(img_path).convert("RGB"))]

if self.training:
random.shuffle(imgs)
gt_path = os.path.join(self.target_root, self.gts[idx])
gt = Image.open(gt_path).convert("RGB")
gt = self.preproc_gt(gt)

# print(img_path, gt_path)
return torch.stack(imgs, dim=0), gt, fn

def __len__(self):
return len(self.gts)


class LowLightFDatasetEval(data.Dataset):
def __init__(self, root, targets_split='targets', training=True):
self.root = root
self.num_instances = 1
self.img_root = os.path.join(root, 'images')
self.target_root = os.path.join(root, targets_split)
self.training = training

self.imgs = list(sorted(os.listdir(self.img_root)))
self.gts = list(sorted(os.listdir(self.target_root)))

self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

print(len(self.imgs), len(self.gts))
self.preproc = T.Compose(
[T.ToTensor()]
)
self.preproc_gt = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
fn, ext = self.gts[idx].split('.')
imgs = []
for i in range(self.num_instances):
img_path = os.path.join(self.img_root, f"{fn}.{ext}")
imgs += [self.preproc(Image.open(img_path).convert("RGB"))]

gt_path = os.path.join(self.target_root, self.gts[idx])
gt = Image.open(gt_path).convert("RGB")
gt = self.preproc_gt(gt)

# print(img_path, gt_path)
return torch.stack(imgs, dim=0), gt, fn

def __len__(self):
return len(self.gts)


class LowLightDataset(data.Dataset):
def __init__(self, root, targets_split='targets', color_tuning=False):
self.root = root
self.img_root = os.path.join(root, 'images')
self.target_root = os.path.join(root, targets_split)
self.color_tuning = color_tuning
self.imgs = list(sorted(os.listdir(self.img_root)))
self.gts = list(sorted(os.listdir(self.target_root)))

self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

print(len(self.imgs), len(self.gts))
self.preproc = T.Compose(
[T.ToTensor()]
)
self.preproc_gt = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
fn, ext = self.gts[idx].split('.')

img_path = os.path.join(self.img_root, self.imgs[idx])
img = Image.open(img_path).convert("RGB")
img = self.preproc(img)

gt_path = os.path.join(self.target_root, self.gts[idx])
gt = Image.open(gt_path).convert("RGB")
gt = self.preproc_gt(gt)

if self.color_tuning:
return img, gt, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
else:
return img, gt, fn

def __len__(self):
return len(self.imgs)


class LowLightDatasetReverse(data.Dataset):
def __init__(self, root, targets_split='targets', color_tuning=False):
self.root = root
self.img_root = os.path.join(root, 'images')
self.target_root = os.path.join(root, targets_split)
self.color_tuning = color_tuning
self.imgs = list(sorted(os.listdir(self.img_root)))
self.gts = list(sorted(os.listdir(self.target_root)))

self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

print(len(self.imgs), len(self.gts))
self.preproc = T.Compose(
[T.ToTensor()]
)
self.preproc_gt = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
img_path = os.path.join(self.img_root, self.imgs[idx])
img = Image.open(img_path).convert("RGB")
img = self.preproc(img)

gt_path = os.path.join(self.target_root, self.gts[idx])
gt = Image.open(gt_path).convert("RGB")
gt = self.preproc_gt(gt)

if self.color_tuning:
return gt, img, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
else:
fn, ext = os.path.splitext(self.imgs[idx])
return gt, img, '%03d' % int(fn) + ext

def __len__(self):
return len(self.imgs)
41 changes: 41 additions & 0 deletions datasets/low_light_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os

import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image


class LowLightDatasetTest(data.Dataset):
def __init__(self, root, reside=False):
self.root = root
self.items = []

subsets = os.listdir(root)
for subset in subsets:
img_root = os.path.join(root, subset)
img_names = list(sorted(os.listdir(img_root)))

for img_name in img_names:
self.items.append((
os.path.join(img_root, img_name),
subset,
img_name
))

self.preproc = T.Compose(
[T.ToTensor()]
)
self.preproc_raw = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
img_path, subset, img_name = self.items[idx]
img = Image.open(img_path).convert("RGB")
img = img.resize((img.width // 8 * 8, img.height // 8 * 8), Image.ANTIALIAS)
img_raw = self.preproc_raw(img)

return img_raw, subset, img_name

def __len__(self):
return len(self.items)
36 changes: 36 additions & 0 deletions datasets/mef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import random

import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image


class MEFDataset(data.Dataset):
def __init__(self, root):
self.img_root = root

self.numbers = list(sorted(os.listdir(self.img_root)))
print(len(self.numbers))

self.preproc = T.Compose(
[T.ToTensor()]
)

def __getitem__(self, idx):
number = self.numbers[idx]
im_dir = os.path.join(self.img_root, number)
fn1, fn2 = tuple(random.sample(os.listdir(im_dir), k=2))
fp1 = os.path.join(im_dir, fn1)
fp2 = os.path.join(im_dir, fn2)
img1 = Image.open(fp1).convert("RGB")
img2 = Image.open(fp2).convert("RGB")
img1 = self.preproc(img1)
img2 = self.preproc(img2)

fn1 = f'{number}_{fn1}'
fn2 = f'{number}_{fn2}'
return img1, img2, fn1, fn2

def __len__(self):
return len(self.numbers)
Loading

0 comments on commit ed55371

Please sign in to comment.