eval_hooks.py
import os.path as osp
import warnings

from mmcv.runner import Hook
from torch.utils.data import DataLoader

class EvalHook(Hook):
    """Evaluation hook.

    Notes:
        If new arguments are added for EvalHook, tools/test.py may be
        affected.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        start (int, optional): Evaluation starting epoch. It enables evaluation
            before the training starts if ``start`` <= the resuming epoch.
            If None, whether to evaluate is merely decided by ``interval``.
            Default: None.
        interval (int): Evaluation interval (by epochs). Default: 1.
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.
    """

    def __init__(self, dataloader, start=None, interval=1, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got'
                            f' {type(dataloader)}')
        if not interval > 0:
            raise ValueError(f'interval must be positive, but got {interval}')
        if start is not None and start < 0:
            warnings.warn(
                f'The evaluation start epoch {start} is smaller than 0, '
                f'use 0 instead', UserWarning)
            start = 0
        self.dataloader = dataloader
        self.interval = interval
        self.start = start
        self.eval_kwargs = eval_kwargs
        self.initial_epoch_flag = True

    def before_train_epoch(self, runner):
        """Evaluate the model only at the start of training."""
        if not self.initial_epoch_flag:
            return
        if self.start is not None and runner.epoch >= self.start:
            self.after_train_epoch(runner)
        self.initial_epoch_flag = False

    def evaluation_flag(self, runner):
        """Judge whether to perform evaluation after this epoch.

        Returns:
            bool: The flag indicating whether to perform evaluation.
        """
        if self.start is None:
            if not self.every_n_epochs(runner, self.interval):
                # No evaluation during the interval epochs.
                return False
        elif (runner.epoch + 1) < self.start:
            # No evaluation if start is larger than the current epoch.
            return False
        else:
            # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
            if (runner.epoch + 1 - self.start) % self.interval:
                return False
        return True

    def after_train_epoch(self, runner):
        if not self.evaluation_flag(runner):
            return
        from mmdet.apis import single_gpu_test
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, results)

    def evaluate(self, runner, results):
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True
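
# A minimal single-GPU usage sketch, kept as comments so this module stays
# import-safe. It assumes ``model``, ``optimizer``, ``val_dataset`` and
# ``logger`` come from your own training script, and that an epoch-based mmcv
# runner is used (the exact runner class name may differ across mmcv versions).
#
#   from mmcv.runner import EpochBasedRunner
#   from torch.utils.data import DataLoader
#
#   val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
#   runner = EpochBasedRunner(
#       model, optimizer=optimizer, work_dir='./work_dirs/demo', logger=logger)
#   # Evaluate after epoch 3, then every 2 epochs (i.e. epochs 3, 5, 7, ...).
#   runner.register_hook(EvalHook(val_loader, start=3, interval=2))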

class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    Notes:
        If new arguments are added, tools/test.py may be affected.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        start (int, optional): Evaluation starting epoch. It enables evaluation
            before the training starts if ``start`` <= the resuming epoch.
            If None, whether to evaluate is merely decided by ``interval``.
            Default: None.
        interval (int): Evaluation interval (by epochs). Default: 1.
        tmpdir (str | None): Temporary directory to save the results of all
            processes. Default: None.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.
    """

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 tmpdir=None,
                 gpu_collect=False,
                 **eval_kwargs):
        super().__init__(
            dataloader, start=start, interval=interval, **eval_kwargs)
        self.tmpdir = tmpdir
        self.gpu_collect = gpu_collect

    def after_train_epoch(self, runner):
        if not self.evaluation_flag(runner):
            return
        from mmdet.apis import multi_gpu_test
        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)
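
# A minimal distributed-usage sketch, again as comments. It assumes one process
# per GPU has already been launched (e.g. via ``torch.distributed.launch``) and
# that ``model``, ``optimizer``, ``val_dataset`` and ``logger`` are built by the
# training script; the runner class name may differ across mmcv versions.
#
#   from mmcv.runner import EpochBasedRunner
#   from torch.utils.data import DataLoader, DistributedSampler
#
#   val_loader = DataLoader(
#       val_dataset,
#       batch_size=1,
#       sampler=DistributedSampler(val_dataset, shuffle=False))
#   runner = EpochBasedRunner(
#       model, optimizer=optimizer, work_dir='./work_dirs/demo', logger=logger)
#   # gpu_collect=True gathers per-rank results on GPU; omit it to gather
#   # through a temporary directory (runner.work_dir/.eval_hook by default).
#   runner.register_hook(DistEvalHook(val_loader, interval=1, gpu_collect=True))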