# ========================================================================
# Perceptual Quality Assessment of Smartphone Photography
# PyTorch Version 1.0 by Hanwei Zhu
# Copyright(c) 2020 Yuming Fang, Hanwei Zhu, Yan Zeng, Kede Ma, and Zhou Wang
# All Rights Reserved.
#
# ----------------------------------------------------------------------
# Permission to use, copy, or modify this software and its documentation
# for educational and research purposes only and without fee is hereby
# granted, provided that this copyright notice and the original authors'
# names appear on all copies and supporting documentation. This program
# shall not be used, rewritten, or adapted as the basis of a commercial
# software or hardware product without first obtaining permission of the
# authors. The authors make no representations about the suitability of
# this software for any purpose. It is provided "as is" without express
# or implied warranty.
# ----------------------------------------------------------------------
# This is an implementation of Multi-Task learning from image Attributes (MT-A)
# for blind image quality assessment.
# Please refer to the following paper:
#
# Y. Fang et al., "Perceptual Quality Assessment of Smartphone Photography"
# in IEEE Conference on Computer Vision and Pattern Recognition, 2020
#
# Kindly report any suggestions or corrections to [email protected]
# ========================================================================
import torch
import torch.nn as nn
import torchvision
from Prepare_image import Image_load
from PIL import Image
import argparse
import os


class MTA(nn.Module):
    """ResNet-50 backbone whose final fully connected layer is replaced with a
    6-output linear head; the demo below reads the quality score from column 0."""

    def __init__(self):
        super(MTA, self).__init__()
        self.backbone = torchvision.models.resnet50(pretrained=False)
        fc_feature = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(fc_feature, 6, bias=True)

    def forward(self, x):
        return self.backbone(x)
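
# A minimal shape-check sketch (illustrative only, not part of the demo flow;
# the 224x224 dummy input is an assumption for a quick sanity check):
#
#   model = MTA()
#   dummy = torch.randn(1, 3, 224, 224)   # one fake RGB crop
#   out = model(dummy)                    # tensor of shape [1, 6]
#   quality = out[:, 0]                   # the demo reads quality from column 0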


class Demo(object):
    def __init__(self, config, load_weights=True, checkpoint_dir='./weights/MT-A_release.pt'):
        self.config = config
        self.load_weights = load_weights
        self.checkpoint_dir = checkpoint_dir
        self.prepare_image = Image_load(size=512, stride=224)
        self.model = MTA()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model_name = type(self.model).__name__
        if self.load_weights:
            self.initialize()

    def predict_quality(self):
        image_1 = self.prepare_image(Image.open(self.config.image_1).convert("RGB"))
        image_2 = self.prepare_image(Image.open(self.config.image_2).convert("RGB"))

        self.model.eval()
        with torch.no_grad():  # inference only; no gradients needed
            image_1 = image_1.to(self.device)
            score_1 = self.model(image_1)[:, 0].mean()  # quality column, averaged over crops
            print(score_1.item())

            image_2 = image_2.to(self.device)
            score_2 = self.model(image_2)[:, 0].mean()
            print(score_2.item())

    def initialize(self):
        ckpt_path = self.checkpoint_dir
        could_load = self._load_checkpoint(ckpt_path)
        if could_load:
            print('Checkpoint loaded successfully!')
        else:
            raise IOError('Failed to load the pretrained model')

    def _load_checkpoint(self, ckpt):
        if os.path.isfile(ckpt):
            print("[*] loading checkpoint '{}'".format(ckpt))
            # map_location keeps the demo runnable on CPU-only machines
            checkpoint = torch.load(ckpt, map_location=self.device)
            self.model.load_state_dict(checkpoint['state_dict'])
            return True
        else:
            return False
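
# A minimal programmatic-use sketch (assumes the default checkpoint at
# ./weights/MT-A_release.pt and the sample images below exist; Namespace merely
# stands in for the parsed command-line config):
#
#   from argparse import Namespace
#   cfg = Namespace(image_1='./images/05293.png', image_2='./images/00914.png')
#   Demo(config=cfg).predict_quality()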


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_1', type=str, default='./images/05293.png')
    parser.add_argument('--image_2', type=str, default='./images/00914.png')
    return parser.parse_args()


def main():
    cfg = parse_config()
    t = Demo(config=cfg)
    t.predict_quality()


if __name__ == '__main__':
    main()
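
# Example invocation (assumes the pretrained weights and the sample images
# shipped with the repository are in place):
#
#   python MT-A_demo.py --image_1 ./images/05293.png --image_2 ./images/00914.png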