Add time logging in eval_metric.py (#150)
* Add time logging in eval_metric.py

* Unify the printing format

* Minor fixes

* Add blank lines to make the code look better
zhiqwang authored Aug 22, 2021
1 parent db7e7dd commit ed73a5d
Showing 2 changed files with 198 additions and 13 deletions.
47 changes: 34 additions & 13 deletions tools/eval_metric.py
@@ -1,15 +1,17 @@
from pathlib import Path
import io
import time
import contextlib
import argparse

import torch
import torchvision

import yolort

from yolort.data import COCOEvaluator, _helper as data_helper
from yolort.data.coco import COCODetection
from yolort.data.transforms import default_val_transforms, collate_fn
from yolort.utils.logger import MetricLogger


def get_parser():
@@ -37,6 +39,8 @@ def get_parser():
parser.add_argument('--num_workers', default=8, type=int, metavar='N',
help='Number of data loading workers (default: 8)')

parser.add_argument('--print_freq', default=20, type=int,
help='The frequency of printing the logging')
parser.add_argument('--output_dir', default='.',
help='Path where to save')
return parser
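
The new --print_freq flag is consumed by the evaluate() loop added further down. As a minimal, runnable sketch of how the flag parses (the real get_parser() in tools/eval_metric.py defines more options than this hunk shows):

import argparse

# Minimal sketch mirroring the hunk above; the real parser carries
# additional options (--num_workers, --output_dir, ...).
parser = argparse.ArgumentParser()
parser.add_argument('--print_freq', default=20, type=int,
                    help='The frequency of printing the logging')

args = parser.parse_args(['--print_freq', '50'])
assert args.print_freq == 50  # the CLI value overrides the default of 20
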
@@ -107,22 +111,39 @@ def eval_metric(args):
model = model.eval()
model = model.to(device)

-    # COCO evaluation
-    print('Computing the mAP...')
-    with torch.no_grad():
-        for images, targets in data_loader:
-            images = [image.to(device) for image in images]
-            preds = model(images)
-            coco_evaluator.update(preds, targets)
+    results = evaluate(model, data_loader, coco_evaluator, device, args.print_freq)

-    results = coco_evaluator.compute()
+    # mAP results
+    print(f"The evaluated mAP at 0.50:0.95 is {results['AP']:0.3f}, "
+          f"and mAP at 0.50 is {results['AP50']:0.3f}.")

    # Format the results
    # coco_evaluator.derive_coco_results()

-    # mAP results
-    print(f"The evaluated mAP 0.5:095 is {results['AP']:0.3f}, "
-          f"and mAP 0.5 is {results['AP50']:0.3f}.")
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, coco_evaluator, device, print_freq):
+    # COCO evaluation
+    metric_logger = MetricLogger(delimiter="  ")
+    header = 'Test:'
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        images = list(image.to(device) for image in images)
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        model_time = time.time()
+        preds = model(images)
+        model_time = time.time() - model_time
+
+        evaluator_time = time.time()
+        coco_evaluator.update(preds, targets)
+        evaluator_time = time.time() - evaluator_time
+
+        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    results = coco_evaluator.compute()
+    return results


def cli_main():
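The heart of the commit is the per-phase timing inside evaluate() above: it synchronizes CUDA before starting the clock so previously queued kernels are not billed to the wrong phase, then times the model forward and the evaluator update separately. A standalone sketch of that pattern, with a hypothetical helper timed() standing in for the inline stopwatch code:

import time
import torch

def timed(fn, *args):
    """Time a callable the way evaluate() does: flush pending CUDA work
    first so the clock starts from an idle device."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    out = fn(*args)
    return out, time.time() - start

# Hypothetical usage, e.g.: preds, model_time = timed(model, images)
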
164 changes: 164 additions & 0 deletions yolort/utils/logger.py
@@ -1,4 +1,10 @@
from tabulate import tabulate
from collections import defaultdict, deque
import datetime
import time

import torch
import torch.distributed as dist


def create_small_table(small_dict):
@@ -23,3 +29,161 @@ def create_small_table(small_dict):
numalign="center",
)
return table
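
Only the tail of create_small_table is visible in this hunk; assuming it returns a tabulate-formatted table of metric names and values (which the visible headers/numalign arguments suggest), usage would look like:

# Hedged sketch: create_small_table's body is truncated above; this assumes
# it renders a small metrics dict as a table via tabulate.
from yolort.utils.logger import create_small_table

print(create_small_table({'AP': 0.365, 'AP50': 0.553}))
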


class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""

def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt

def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n

def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]

@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()

@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()

@property
def global_avg(self):
return self.total / self.count

@property
def max(self):
return max(self.deque)

@property
def value(self):
return self.deque[-1]

def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
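
A quick illustration of SmoothedValue (importable from yolort.utils.logger once this file lands): the deque smooths over a sliding window, while total/count track the series-wide average.

from yolort.utils.logger import SmoothedValue

sv = SmoothedValue(window_size=3, fmt='{median:.2f} ({global_avg:.2f})')
for v in (1.0, 2.0, 3.0, 10.0):
    sv.update(v)

print(sv)      # '3.00 (4.00)': median of the last 3 values, global average of all 4
print(sv.max)  # 10.0, the maximum inside the current window
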


class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter

def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)

def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)

def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()

def add_meter(self, name, meter):
self.meters[name] = meter

def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)')


def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
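
A hedged end-to-end sketch of MetricLogger.log_every, assuming this module is importable as yolort.utils.logger: it wraps any sized iterable and prints progress (ETA, per-iteration time, registered meters) every print_freq steps.

import time
from yolort.utils.logger import MetricLogger

metric_logger = MetricLogger(delimiter="  ")
for step in metric_logger.log_every(range(100), print_freq=20, header='Demo:'):
    time.sleep(0.01)                             # stand-in for real work
    metric_logger.update(loss=1.0 / (step + 1))  # floats, ints, or 0-dim tensors

Because __getattr__ falls through to self.meters, the accumulated statistics remain available after the loop, e.g. metric_logger.loss.global_avg.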
