Showing 27 changed files with 3,230 additions and 0 deletions.
@@ -0,0 +1,4 @@
from .low_light import *
from .low_light_test import *
from .mef import *
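The wildcard imports above simply re-export every dataset class from the package's submodules, so callers can import them from the package root. A hypothetical example follows; the package name `datasets` is an assumption about how this directory is imported, not something stated in the commit.

# Hypothetical import; 'datasets' is an assumed package name for this directory.
from datasets import LowLightFDataset, LowLightDatasetTest, MEFDataset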
@@ -0,0 +1,171 @@
import os
import random

import torch
import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image

class LowLightFDataset(data.Dataset):
    """Paired low-light training set: each ground-truth image has several augmented inputs."""

    def __init__(self, root, image_split='images_aug', targets_split='targets', training=True):
        self.root = root
        self.num_instances = 8
        self.img_root = os.path.join(root, image_split)
        self.target_root = os.path.join(root, targets_split)
        self.training = training
        print('----', image_split, targets_split, '----')
        self.imgs = list(sorted(os.listdir(self.img_root)))
        self.gts = list(sorted(os.listdir(self.target_root)))

        # Keep only inputs whose base name (before the '_<index>' suffix) has a
        # matching ground-truth file, and only ground truths with matching inputs.
        names = [img_name.split('_')[0] + '.' + img_name.split('.')[-1] for img_name in self.imgs]
        self.imgs = list(
            filter(lambda img_name: img_name.split('_')[0] + '.' + img_name.split('.')[-1] in self.gts, self.imgs))

        self.gts = list(filter(lambda gt: gt in names, self.gts))

        print(len(self.imgs), len(self.gts))
        self.preproc = T.Compose(
            [T.ToTensor()]
        )
        self.preproc_gt = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        fn, ext = self.gts[idx].split('.')
        imgs = []
        # Load all augmented instances that share this ground truth's base name.
        for i in range(self.num_instances):
            img_path = os.path.join(self.img_root, f"{fn}_{i}.{ext}")
            imgs += [self.preproc(Image.open(img_path).convert("RGB"))]

        if self.training:
            random.shuffle(imgs)
        gt_path = os.path.join(self.target_root, self.gts[idx])
        gt = Image.open(gt_path).convert("RGB")
        gt = self.preproc_gt(gt)

        # print(img_path, gt_path)
        return torch.stack(imgs, dim=0), gt, fn

    def __len__(self):
        return len(self.gts)


class LowLightFDatasetEval(data.Dataset):
    """Evaluation variant of LowLightFDataset: a single, non-augmented input per ground truth."""

    def __init__(self, root, targets_split='targets', training=True):
        self.root = root
        self.num_instances = 1
        self.img_root = os.path.join(root, 'images')
        self.target_root = os.path.join(root, targets_split)
        self.training = training

        self.imgs = list(sorted(os.listdir(self.img_root)))
        self.gts = list(sorted(os.listdir(self.target_root)))

        # Keep only files present in both the input and ground-truth folders.
        self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
        self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

        print(len(self.imgs), len(self.gts))
        self.preproc = T.Compose(
            [T.ToTensor()]
        )
        self.preproc_gt = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        fn, ext = self.gts[idx].split('.')
        imgs = []
        for i in range(self.num_instances):
            img_path = os.path.join(self.img_root, f"{fn}.{ext}")
            imgs += [self.preproc(Image.open(img_path).convert("RGB"))]

        gt_path = os.path.join(self.target_root, self.gts[idx])
        gt = Image.open(gt_path).convert("RGB")
        gt = self.preproc_gt(gt)

        # print(img_path, gt_path)
        return torch.stack(imgs, dim=0), gt, fn

    def __len__(self):
        return len(self.gts)


class LowLightDataset(data.Dataset):
    """Simple paired dataset: one low-light input per ground-truth image."""

    def __init__(self, root, targets_split='targets', color_tuning=False):
        self.root = root
        self.img_root = os.path.join(root, 'images')
        self.target_root = os.path.join(root, targets_split)
        self.color_tuning = color_tuning
        self.imgs = list(sorted(os.listdir(self.img_root)))
        self.gts = list(sorted(os.listdir(self.target_root)))

        # Keep only files present in both folders.
        self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
        self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

        print(len(self.imgs), len(self.gts))
        self.preproc = T.Compose(
            [T.ToTensor()]
        )
        self.preproc_gt = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        fn, ext = self.gts[idx].split('.')

        img_path = os.path.join(self.img_root, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        img = self.preproc(img)

        gt_path = os.path.join(self.target_root, self.gts[idx])
        gt = Image.open(gt_path).convert("RGB")
        gt = self.preproc_gt(gt)

        if self.color_tuning:
            return img, gt, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
        else:
            return img, gt, fn

    def __len__(self):
        return len(self.imgs)


class LowLightDatasetReverse(data.Dataset):
    """Same pairing as LowLightDataset, but returns (ground truth, input) in reversed order."""

    def __init__(self, root, targets_split='targets', color_tuning=False):
        self.root = root
        self.img_root = os.path.join(root, 'images')
        self.target_root = os.path.join(root, targets_split)
        self.color_tuning = color_tuning
        self.imgs = list(sorted(os.listdir(self.img_root)))
        self.gts = list(sorted(os.listdir(self.target_root)))

        self.imgs = list(filter(lambda img_name: img_name in self.gts, self.imgs))
        self.gts = list(filter(lambda gt: gt in self.imgs, self.gts))

        print(len(self.imgs), len(self.gts))
        self.preproc = T.Compose(
            [T.ToTensor()]
        )
        self.preproc_gt = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_root, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        img = self.preproc(img)

        gt_path = os.path.join(self.target_root, self.gts[idx])
        gt = Image.open(gt_path).convert("RGB")
        gt = self.preproc_gt(gt)

        if self.color_tuning:
            return gt, img, 'a' + self.imgs[idx], 'a' + self.imgs[idx]
        else:
            # Zero-pad the numeric file name to three digits, e.g. '1.png' -> '001.png'.
            fn, ext = os.path.splitext(self.imgs[idx])
            return gt, img, '%03d' % int(fn) + ext

    def __len__(self):
        return len(self.imgs)
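To make the return shapes concrete, here is a rough, hypothetical sketch of how the augmented training set might be wrapped in a standard PyTorch DataLoader. The root path, batch size, and worker count are placeholders rather than values from this commit, and default collation assumes all images in the split share the same resolution.

# Hypothetical usage sketch; './data/train' and the loader settings are placeholders.
from torch.utils.data import DataLoader

train_set = LowLightFDataset('./data/train', image_split='images_aug',
                             targets_split='targets', training=True)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

for imgs, gt, names in train_loader:
    # imgs: (B, 8, C, H, W) stack of shuffled augmented inputs; gt: (B, C, H, W); names: tuple of str.
    print(imgs.shape, gt.shape, names[0])
    break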
@@ -0,0 +1,41 @@
import os

import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image


class LowLightDatasetTest(data.Dataset):
    """Unpaired test set: walks every subfolder of `root` and returns (image, subset, name)."""

    def __init__(self, root, reside=False):
        self.root = root
        self.items = []

        subsets = os.listdir(root)
        for subset in subsets:
            img_root = os.path.join(root, subset)
            img_names = list(sorted(os.listdir(img_root)))

            for img_name in img_names:
                self.items.append((
                    os.path.join(img_root, img_name),
                    subset,
                    img_name
                ))

        self.preproc = T.Compose(
            [T.ToTensor()]
        )
        self.preproc_raw = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        img_path, subset, img_name = self.items[idx]
        img = Image.open(img_path).convert("RGB")
        # Resize so both dimensions are multiples of 8. Image.ANTIALIAS was removed in
        # Pillow 10; Image.LANCZOS is the equivalent filter.
        img = img.resize((img.width // 8 * 8, img.height // 8 * 8), Image.LANCZOS)
        img_raw = self.preproc_raw(img)

        return img_raw, subset, img_name

    def __len__(self):
        return len(self.items)
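Because each test image keeps its own (rounded-down) resolution, batches larger than one would generally fail to collate; a plausible way to iterate the test set, with the root path as a placeholder, is sketched below.

# Hypothetical sketch; './data/test' is a placeholder root containing one subfolder per test subset.
from torch.utils.data import DataLoader

test_set = LowLightDatasetTest('./data/test')
test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

for img, subset, name in test_loader:
    # img: (1, C, H, W) with H and W already multiples of 8; subset and name are 1-tuples of str.
    print(subset[0], name[0], tuple(img.shape))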
@@ -0,0 +1,36 @@
import os
import random

import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image


class MEFDataset(data.Dataset):
    """Multi-exposure fusion set: one subfolder per scene, two random exposures sampled per item."""

    def __init__(self, root):
        self.img_root = root

        self.numbers = list(sorted(os.listdir(self.img_root)))
        print(len(self.numbers))

        self.preproc = T.Compose(
            [T.ToTensor()]
        )

    def __getitem__(self, idx):
        number = self.numbers[idx]
        im_dir = os.path.join(self.img_root, number)
        # Sample two distinct exposures of the same scene.
        fn1, fn2 = tuple(random.sample(os.listdir(im_dir), k=2))
        fp1 = os.path.join(im_dir, fn1)
        fp2 = os.path.join(im_dir, fn2)
        img1 = Image.open(fp1).convert("RGB")
        img2 = Image.open(fp2).convert("RGB")
        img1 = self.preproc(img1)
        img2 = self.preproc(img2)

        fn1 = f'{number}_{fn1}'
        fn2 = f'{number}_{fn2}'
        return img1, img2, fn1, fn2

    def __len__(self):
        return len(self.numbers)
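Note that MEFDataset re-samples a random pair of exposures from the scene folder on every access, so repeated reads of the same index usually return different pairs. A minimal sketch, assuming one subfolder of exposures per scene under a placeholder root path:

# Hypothetical sketch; './data/MEF' is a placeholder for a root with one folder of exposures per scene.
mef_set = MEFDataset('./data/MEF')
img1, img2, fn1, fn2 = mef_set[0]
# img1, img2: (C, H, W) tensors of two exposures of the first scene; fn1/fn2 are '<scene>_<filename>'.
print(fn1, fn2, tuple(img1.shape), tuple(img2.shape))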