-
Notifications
You must be signed in to change notification settings - Fork 50
/
market1501.py
127 lines (101 loc) · 3.58 KB
/
market1501.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import collections
import os
import random
import re
from torch.utils.data import dataset, sampler
from torchvision.datasets.folder import default_loader
def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
return sorted([os.path.join(root, f)
for root, _, files in os.walk(directory) for f in files
if re.match(r'([\w]+\.(?:' + ext + '))', f)])
class Market1501(dataset.Dataset):
"""
Attributes:
imgs (list of str): dataset image file paths
_id2label (dict): mapping from person id to softmax continuous label
"""
@staticmethod
def id(file_path):
"""
:param file_path: unix style file path
:return: person id
"""
return int(file_path.split('/')[-1].split('_')[0])
@staticmethod
def camera(file_path):
"""
:param file_path: unix style file path
:return: camera id
"""
return int(file_path.split('/')[-1].split('_')[1][1])
@property
def ids(self):
"""
:return: person id list corresponding to dataset image paths
"""
return [self.id(path) for path in self.imgs]
@property
def unique_ids(self):
"""
:return: unique person ids in ascending order
"""
return sorted(set(self.ids))
@property
def cameras(self):
"""
:return: camera id list corresponding to dataset image paths
"""
return [self.camera(path) for path in self.imgs]
def __init__(self, root, transform=None, target_transform=None, loader=default_loader):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.imgs = [path for path in list_pictures(self.root) if self.id(path) != -1]
# convert person id to softmax continuous label
self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)}
def __getitem__(self, index):
path = self.imgs[index]
target = self._id2label[self.id(path)]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
class RandomIdSampler(sampler.Sampler):
"""
Sampler for triplet semihard sample mining.
Attributes:
_id2index (dict of list): mapping from person id to its image indexes in `data_source`
"""
@staticmethod
def _sample(population, k):
if len(population) < k:
population = population * k
return random.sample(population, k)
def __init__(self, data_source, batch_image):
"""
:param data_source: Market1501 dataset
:param batch_image: batch image size for one person id
"""
super(RandomIdSampler, self).__init__(data_source)
self.data_source = data_source
self.batch_image = batch_image
self._id2index = collections.defaultdict(list)
for idx, path in enumerate(data_source.imgs):
_id = data_source.id(path)
self._id2index[_id].append(idx)
def __iter__(self):
unique_ids = self.data_source.unique_ids
random.shuffle(unique_ids)
imgs = []
for _id in unique_ids:
imgs.extend(self._sample(self._id2index[_id], self.batch_image))
return iter(imgs)
def __len__(self):
return len(self._id2index) * self.batch_image