-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathuda_dataset.py
126 lines (110 loc) · 4.71 KB
/
uda_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# ---------------------------------------------------------------
# Copyright (c) 2021-2022 ETH Zurich, Lukas Hoyer. All rights reserved.
# Licensed under the Apache License, Version 2.0
# ---------------------------------------------------------------
import json
import os.path as osp
import mmcv
import numpy as np
import torch
from . import CityscapesDataset
from .builder import DATASETS
def get_rcs_class_probs(data_root, temperature):
with open(osp.join(data_root, 'sample_class_stats.json'), 'r') as of:
sample_class_stats = json.load(of)
overall_class_stats = {}
for s in sample_class_stats:
s.pop('file')
for c, n in s.items():
c = int(c)
if c not in overall_class_stats:
overall_class_stats[c] = n
else:
overall_class_stats[c] += n
overall_class_stats = {
k: v
for k, v in sorted(
overall_class_stats.items(), key=lambda item: item[1])
}
freq = torch.tensor(list(overall_class_stats.values()))
freq = freq / torch.sum(freq)
freq = 1 - freq
freq = torch.softmax(freq / temperature, dim=-1)
return list(overall_class_stats.keys()), freq.numpy()
@DATASETS.register_module()
class UDADataset(object):
def __init__(self, source, target, cfg):
self.source = source
self.target = target
self.ignore_index = target.ignore_index
self.CLASSES = target.CLASSES
self.PALETTE = target.PALETTE
assert target.ignore_index == source.ignore_index
assert target.CLASSES == source.CLASSES
assert target.PALETTE == source.PALETTE
rcs_cfg = cfg.get('rare_class_sampling')
self.rcs_enabled = rcs_cfg is not None
if self.rcs_enabled:
self.rcs_class_temp = rcs_cfg['class_temp']
self.rcs_min_crop_ratio = rcs_cfg['min_crop_ratio']
self.rcs_min_pixels = rcs_cfg['min_pixels']
self.rcs_classes, self.rcs_classprob = get_rcs_class_probs(
cfg['source']['data_root'], self.rcs_class_temp)
mmcv.print_log(f'RCS Classes: {self.rcs_classes}', 'mmseg')
mmcv.print_log(f'RCS ClassProb: {self.rcs_classprob}', 'mmseg')
with open(
osp.join(cfg['source']['data_root'],
'samples_with_class.json'), 'r') as of:
samples_with_class_and_n = json.load(of)
samples_with_class_and_n = {
int(k): v
for k, v in samples_with_class_and_n.items()
if int(k) in self.rcs_classes
}
self.samples_with_class = {}
for c in self.rcs_classes:
self.samples_with_class[c] = []
for file, pixels in samples_with_class_and_n[c]:
if pixels > self.rcs_min_pixels:
self.samples_with_class[c].append(file.split('/')[-1])
assert len(self.samples_with_class[c]) > 0
self.file_to_idx = {}
for i, dic in enumerate(self.source.img_infos):
file = dic['ann']['seg_map']
if isinstance(self.source, CityscapesDataset):
file = file.split('/')[-1]
self.file_to_idx[file] = i
def get_rare_class_sample(self):
c = np.random.choice(self.rcs_classes, p=self.rcs_classprob)
f1 = np.random.choice(self.samples_with_class[c])
i1 = self.file_to_idx[f1]
s1 = self.source[i1]
if self.rcs_min_crop_ratio > 0:
for j in range(10):
n_class = torch.sum(s1['gt_semantic_seg'].data == c)
# mmcv.print_log(f'{j}: {n_class}', 'mmseg')
if n_class > self.rcs_min_pixels * self.rcs_min_crop_ratio:
break
# Sample a new random crop from source image i1.
# Please note, that self.source.__getitem__(idx) applies the
# preprocessing pipeline to the loaded image, which includes
# RandomCrop, and results in a new crop of the image.
s1 = self.source[i1]
i2 = np.random.choice(range(len(self.target)))
s2 = self.target[i2]
return {
**s1, 'target_img_metas': s2['img_metas'],
'target_img': s2['img']
}
def __getitem__(self, idx):
if self.rcs_enabled:
return self.get_rare_class_sample()
else:
s1 = self.source[idx // len(self.target)]
s2 = self.target[idx % len(self.target)]
return {
**s1, 'target_img_metas': s2['img_metas'],
'target_img': s2['img']
}
def __len__(self):
return len(self.source) * len(self.target)