Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed to different color space #45

Open
wants to merge 39 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3689e11
Update model.py
lujulia Dec 7, 2022
fd81105
Update lowlight_train.py
lujulia Dec 7, 2022
20d7b33
Update lowlight_test.py
lujulia Dec 7, 2022
87fabca
Update lowlight_train.py
lujulia Dec 7, 2022
e2382ae
Update lowlight_test.py
lujulia Dec 7, 2022
33db104
Update README.md
lujulia Dec 7, 2022
a843d39
Update README.md
lujulia Dec 7, 2022
3dacf9c
Update README.md
lujulia Dec 7, 2022
683204d
Update README.md
lujulia Dec 7, 2022
397e390
Update README.md
lujulia Dec 7, 2022
3f7eb2c
Update README.md
lujulia Dec 8, 2022
f573214
Update lowlight_train.py
lujulia Dec 8, 2022
c4fb302
Update dataloader.py
lujulia Dec 8, 2022
e34ccb8
Update dataloader.py
lujulia Dec 8, 2022
f13ef1f
Update model.py
lujulia Dec 8, 2022
38f7ade
Update dataloader.py
lujulia Dec 8, 2022
e6d0888
Update lowlight_train.py
lujulia Dec 8, 2022
a2ab5ef
Update lowlight_test.py
lujulia Dec 8, 2022
94f9d03
Update lowlight_train.py
lujulia Dec 8, 2022
3105510
Update lowlight_test.py
lujulia Dec 9, 2022
64c5a04
Add files via upload
lujulia Dec 9, 2022
0ed3a94
Update README.md
lujulia Dec 9, 2022
f7c21d5
Update README.md
lujulia Dec 9, 2022
9fc04ed
Update README.md
lujulia Dec 9, 2022
a7a10df
Update README.md
lujulia Dec 9, 2022
d211f58
Update lowlight_test.py
lujulia Dec 9, 2022
d80fc52
Update README.md
lujulia Sep 8, 2023
dd00cb3
Update README.md
lujulia Sep 8, 2023
d049387
Update README.md
lujulia Sep 8, 2023
86cfb82
Update README.md
lujulia Sep 8, 2023
8c47ede
Update README.md
lujulia Sep 8, 2023
e4e617a
Update README.md
lujulia Sep 8, 2023
81656b4
Update README.md
lujulia Sep 8, 2023
df23ae5
Update README.md
lujulia Sep 8, 2023
b319b0b
Update README.md
lujulia Sep 8, 2023
953ef65
Update README.md
lujulia Sep 8, 2023
3aa2f83
Update README.md
lujulia Sep 23, 2023
ad44289
Update lowlight_test.py
lujulia Sep 23, 2023
08516f9
Update README.md
lujulia Jan 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
# What is different between this fixed version and the original ZeroDCE?

(1) providing 7 different color spaces for training ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV").

```
cd Zero-DCE_code
```
```
python lowlight_train.py --channel ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV")
```
(2) providing 7 different color spaces of 200 epochs pretrained weight.

```
./Zero-DCE_code/snapshots/("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV").pth
```
(3) providing applications on videos.

```
cd Zero-DCE_code
```
```
python lowlight_test.py --mode (video/image) --channel ("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV")
```
(4) providing a tensorboard to display training loss.

```
tensorboard --logdir log/train_loss_("RGB", "HSV", "HLS", "YCbCr", "YUV", "LAB", and "LUV")
```

# Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement

You can find more details here: https://li-chongyi.github.io/Proj_Zero-DCE.html. Have fun!
Expand Down Expand Up @@ -76,8 +105,5 @@ The code is made available for academic research purpose only. Under Attribution
(Full paper: http://openaccess.thecvf.com/content_CVPR_2020/papers/Guo_Zero-Reference_Deep_Curve_Estimation_for_Low-Light_Image_Enhancement_CVPR_2020_paper.pdf)

## Contact
If you have any questions, please contact Chongyi Li at [email protected] or Chunle Guo at [email protected].

## TensorFlow Version
Thanks tuvovan ([email protected]) who re-produces our code by TF. The results of TF version look similar with our Pytorch version. But I do not have enough time to check the details.
https://github.com/tuvovan/Zero_DCE_TF
If you have any questions, please contact ICHEN LU at [email protected].
98 changes: 64 additions & 34 deletions Zero-DCE_code/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,75 @@

def populate_train_list(lowlight_images_path):




image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")
image_list_lowlight = glob.glob(lowlight_images_path + "*.JPG")

train_list = image_list_lowlight

random.shuffle(train_list)

return train_list



class lowlight_loader(data.Dataset):

def __init__(self, lowlight_images_path):

self.train_list = populate_train_list(lowlight_images_path)
self.size = 256

self.data_list = self.train_list
print("Total training examples:", len(self.train_list))




def __getitem__(self, index):

data_lowlight_path = self.data_list[index]

data_lowlight = Image.open(data_lowlight_path)

data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)

data_lowlight = (np.asarray(data_lowlight)/255.0)
data_lowlight = torch.from_numpy(data_lowlight).float()

return data_lowlight.permute(2,0,1)

def __len__(self):
return len(self.data_list)

def __init__(self, lowlight_images_path, channel):
self.train_list = populate_train_list(lowlight_images_path)
self.size = 256
self.channel = channel

self.data_list = self.train_list
print("Total training examples:", len(self.train_list))

def __getitem__(self, index):

data_lowlight_path = self.data_list[index]
data_lowlight = cv2.imread(data_lowlight_path)
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB)
data_lowlight = cv2.resize(data_lowlight,(self.size,self.size),interpolation = cv2.INTER_AREA)
if self.channel=="RGB":
data_lowlight = (np.asarray(data_lowlight)/255.0)
else:
if self.channel=="HSV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HSV)
H, S, V = cv2.split(data_lowlight)
data_lowlight = np.asarray(data_lowlight).copy()
data_lowlight_1 = ((H)/(180.0))
data_lowlight_2 = ((S)/(255.0))
data_lowlight_3 = ((V)/(255.0))
elif self.channel=="HLS":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HLS)
H, L, S = cv2.split(data_lowlight)
data_lowlight = np.asarray(data_lowlight).copy()
data_lowlight_1 = ((H)/(180.0))
data_lowlight_2 = ((L)/(255.0))
data_lowlight_3 = ((S)/(255.0))
elif self.channel=="YCbCr":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YCrCb)
Y, Cr, Cb = cv2.split(data_lowlight)
data_lowlight_1 = ((Y)/(255.0))
data_lowlight_2 = ((Cr)/(255.0))
data_lowlight_3 = ((Cb)/(255.0))
elif self.channel=="YUV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YUV)
Y, U, V = cv2.split(data_lowlight)
data_lowlight_1 = ((Y)/(255.0))
data_lowlight_2 = ((U)/(255.0))
data_lowlight_3 = ((V)/(255.0))
elif self.channel=="LAB":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Lab)
L, A, B = cv2.split(data_lowlight)
data_lowlight_1 = ((L-0.0)/(255.0-0.0))
data_lowlight_2 = ((A-1.0)/(255.0-1.0))
data_lowlight_3 = ((B-1.0)/(255.0-1.0))
elif self.channel=="LUV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Luv)
L, U, V = cv2.split(data_lowlight)
data_lowlight_1 = ((L)/(255.0))
data_lowlight_2 = ((U)/(255.0))
data_lowlight_3 = ((V)/(255.0))

data_lowlight = cv2.merge([data_lowlight_1,data_lowlight_2,data_lowlight_3])

data_lowlight = torch.from_numpy(data_lowlight).float()
data_lowlight = data_lowlight.permute(2,0,1)
return data_lowlight
def __len__(self):
return len(self.data_list)
186 changes: 151 additions & 35 deletions Zero-DCE_code/lowlight_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,165 @@
from PIL import Image
import glob
import time
import argparse
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
import cv2
import tensorflow as tf




def lowlight(image_path):
os.environ['CUDA_VISIBLE_DEVICES']='0'
data_lowlight = Image.open(image_path)



data_lowlight = (np.asarray(data_lowlight)/255.0)
def Color_Choice(color_space,data_lowlight):
#data_lowlight = Image.open(data_lowlight_path).convert(color)
#data_lowlight = cv2.imread(data_lowlight_path)
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_BGR2RGB)
if color_space == "RGB":
data_lowlight = (np.asarray(data_lowlight)/255.0)
n = [255,0,255,0,255,0]
back = cv2.COLOR_RGB2BGR
else:
if color_space == "HSV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HSV)
n = [180,0,255,0,255,0]
back = cv2.COLOR_HSV2BGR
elif color_space == "HLS":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2HLS)
n = [180,0,255,0,255,0]
back = cv2.COLOR_HLS2BGR
elif color_space == "YCbCr":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YCrCb)
n = [255,0,255,0,255,0]
back = cv2.COLOR_YCrCb2BGR
elif color_space == "YUV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2YUV)
n = [255,0,255,0,255,0]
back = cv2.COLOR_YUV2BGR
elif color_space == "LAB":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Lab)
n = [255,0,255-1,1,255-1,1]
back = cv2.COLOR_Lab2BGR
elif color_space == "LUV":
data_lowlight = cv2.cvtColor(data_lowlight, cv2.COLOR_RGB2Luv)
n = [255,0,255,0,255,0]
back = cv2.COLOR_Luv2BGR
c1,c2,c3 = cv2.split(data_lowlight)
data_lowlight_1 = ((c1-n[1])/(n[0]))
data_lowlight_2 = ((c2-n[3])/(n[2]))
data_lowlight_3 = ((c3-n[5])/(n[4]))
data_lowlight = cv2.merge([data_lowlight_1,data_lowlight_2,data_lowlight_3])

data_lowlight = torch.from_numpy(data_lowlight).float()
data_lowlight = data_lowlight.permute(2,0,1)
return data_lowlight,n,back


data_lowlight = torch.from_numpy(data_lowlight).float()
data_lowlight = data_lowlight.permute(2,0,1)
data_lowlight = data_lowlight.cuda().unsqueeze(0)

DCE_net = model.enhance_net_nopool().cuda()
DCE_net.load_state_dict(torch.load('snapshots/Epoch99.pth'))
start = time.time()
_,enhanced_image,_ = DCE_net(data_lowlight)
def lowlight(color_channel,lowlight_image):
os.environ['CUDA_VISIBLE_DEVICES']='0'

data_lowlight,con,inchan = Color_Choice(color_channel,lowlight_image)
data_lowlight = data_lowlight.cuda().unsqueeze(0)

if config.channel=="RGB":
DCE_net = model.enhance_net_nopool_3().cuda()
elif config.channel=="HSV":
DCE_net = model.enhance_net_nopool_1_3().cuda()
elif config.channel=="HLS":
DCE_net = model.enhance_net_nopool_1_2().cuda()
elif config.channel=="YCbCr" or config.channel=="YUV" or config.channel=="LAB" or config.channel=="LUV":
DCE_net = model.enhance_net_nopool_1_1().cuda()

DCE_net.load_state_dict(torch.load("snapshots/"+config.channel+".pth"))
#start = time.time()

_,enhanced_image,_ = DCE_net(data_lowlight)
data_lowlight = enhanced_image[0].permute(1,2,0).cpu().numpy()
temp1 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
temp2 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
temp3 = np.zeros((data_lowlight[:,:,0].shape[0],data_lowlight[:,:,0].shape[1]), dtype="uint8")
temp1[:,:] = (data_lowlight[:,:,0]*con[0]+con[1]).astype(dtype="uint8")
temp2[:,:] = (data_lowlight[:,:,1]*con[2]+con[3]).astype(dtype="uint8")
temp3[:,:] = (data_lowlight[:,:,2]*con[4]+con[5]).astype(dtype="uint8")

end_time = (time.time() - start)
print(end_time)
image_path = image_path.replace('test_data','result')
result_path = image_path
if not os.path.exists(image_path.replace('/'+image_path.split("/")[-1],'')):
os.makedirs(image_path.replace('/'+image_path.split("/")[-1],''))
data_enhanced = cv2.cvtColor(cv2.merge([temp1,temp2,temp3]), inchan)

#end_time = (time.time() - start)
#print(end_time)
return data_enhanced
"""
result_path = lowlight_images_path.replace('test_data','result')
if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))

torchvision.utils.save_image(enhanced_image, result_path)
"""

torchvision.utils.save_image(enhanced_image, result_path)

if __name__ == '__main__':
# test_images
with torch.no_grad():
filePath = 'data/test_data/'

file_list = os.listdir(filePath)

for file_name in file_list:
test_list = glob.glob(filePath+file_name+"/*")
for image in test_list:
parser = argparse.ArgumentParser()
parser.add_argument("--lowlight_images_path", type=str, default="data/test_data/")
parser.add_argument("--mode", type=str, default="image")
parser.add_argument("--channel", type=str, default="RGB")
parser.add_argument("--save_images_path", type=str, default="data/result")
config = parser.parse_args()
with torch.no_grad():
file_path = config.lowlight_images_path
file_list = os.listdir(config.lowlight_images_path)
if config.mode == "image":
start1 = time.time()
for file_name in file_list:
test_list = glob.glob(file_path+file_name+"/*")
for image_path in test_list:
# image = image
print(image)
lowlight(image)



print(image_path)
start2 = time.time()
data_lowlight = cv2.imread(image_path)
data_enhanced = lowlight(config.channel,data_lowlight)
result_path = image_path.replace('test_data','result')
if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))
cv2.imwrite(result_path,data_enhanced)
end_time2 = (time.time() - start2)
print("executive time of each frame: ", end_time2)
end_time1 = (time.time() - start1)
print("executive time of all images: ", end_time1)
elif config.mode == "video":
start1 = time.time()
for file_name in file_list:
test_list = glob.glob(file_path+file_name+"/*")
for video_path in test_list:
print(video_path)
start2 = time.time()
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
fps = cap.get(cv2.CAP_PROP_FPS)
result_path = video_path.replace('test_data','result')
if not os.path.exists(result_path.replace('/'+result_path.split("/")[-1],'')):
os.makedirs(result_path.replace('/'+result_path.split("/")[-1],''))
vid_writer = cv2.VideoWriter(
result_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
)
i = 0
while True:
ret_val, frame = cap.read()
if ret_val:
start3 = time.time()
frame = np.array(frame)
lowlight(config.channel,frame)
frame_enhanced = lowlight(config.channel,frame)
vid_writer.write(frame_enhanced)
end_time3 = (time.time() - start3)
print("executive time of "+str(i+1)+"-th frame: ", end_time3)
i = i+1
else:
break
end_time2 = (time.time() - start2)
print("original frame per second: ",fps)
print("executive time of full video: ", end_time2)
end_time1 = (time.time() - start1)
print("executive time of all videos: ", end_time1)
Loading