-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
112 lines (80 loc) · 2.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import logging
import os
import wandb
import re
def atoi(text):
    """Convert ``text`` to ``int`` if it is a run of decimal digits, else return it unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``: ``isdigit`` is also True
    for characters such as superscripts (``'²'``), which ``int()`` rejects with
    ``ValueError``. ``isdecimal`` accepts exactly the characters ``int()`` parses.

    Args:
        text: a string chunk (e.g. one piece produced by ``natural_keys``).

    Returns:
        ``int(text)`` for a non-empty all-decimal string, otherwise ``text`` itself.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Sort key that puts strings in "human" (natural) order.

    ``text`` is split on runs of digits. Each digit run is turned into an ``int``
    by ``atoi``, and the pieces in between stay strings. As a result,
    ``'item10'`` sorts after ``'item9'``.

    Usage::

        alist.sort(key=natural_keys)

    Based on http://nedbatchelder.com/blog/200712/human_sorting.html
    (see Toothy's implementation in the comments).
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))
def init_wandb(args):
    """Start a Weights & Biases run and log the parsed arguments to its config.

    Args:
        args: parsed command-line arguments (presumably an ``argparse.Namespace``;
            confirm against callers). ``args.wandb`` is used as the W&B project
            name, and the whole object is passed to ``config.update``.

    Returns:
        The name of the newly started run.
    """
    run = wandb.init(project=args.wandb)
    # The run name is available as soon as ``wandb.init`` returns. An
    # argument-less ``run.save()`` is a deprecated no-op in current wandb
    # releases (it only emits a deprecation warning), so it is not called.
    run_name = run.name
    # Log args.
    run.config.update(args)
    return run_name
def set_logger(log_path, file_name):
    """Configure the root logger to write to ``<log_path>/<file_name>`` and the console.

    Creates ``log_path`` (including parents) if it does not exist yet. An empty
    ``log_path`` means the current directory. Sets the root logger's level to INFO
    and attaches two handlers:

    * a UTF-8 ``FileHandler`` using the format
      ``'%(asctime)s:%(levelname)s: %(message)s'``;
    * a ``StreamHandler`` (stderr by default) printing only the message.

    Calling the function again does not stack duplicate handlers. The console
    handler, and the file handler for a given path, are each attached once, so
    records are not emitted twice.

    Args:
        log_path: directory that will contain the log file.
        file_name: name of the log file inside ``log_path``.

    Returns:
        The root ``logging.Logger``.
    """
    if log_path:
        # exist_ok lets a run re-use / resume an existing log directory instead
        # of crashing with FileExistsError.
        os.makedirs(log_path, exist_ok=True)
    path = os.path.join(log_path, file_name)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Handlers added here carry a name tag so repeated calls can recognise them.
    attached = {handler.get_name() for handler in logger.handlers}

    # Log to .txt
    file_tag = 'set_logger:file:' + os.path.abspath(path)
    if file_tag not in attached:
        file_handler = logging.FileHandler(path, encoding='utf-8')
        file_handler.set_name(file_tag)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

    # Log to console
    console_tag = 'set_logger:console'
    if console_tag not in attached:
        stream_handler = logging.StreamHandler()
        stream_handler.set_name(console_tag)
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
    return logger
class AverageMeter:
    """Computes and stores the latest value and the running weighted average.

    Args:
        name: label printed first by ``__str__``.
        fmt: format spec including the leading colon (e.g. ``':f'``, ``':.4f'``).
            It is spliced into ``'{val<fmt>}'`` and ``'{avg<fmt>}'`` when printing.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and sample count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value and fold ``val * n`` into the running sum.

        ``val`` is presumably a mean over a batch of ``n`` samples, which is why it
        is weighted by ``n``. While the count is still zero (e.g. ``update(x, n=0)``
        on a fresh meter), the average stays 0 instead of raising
        ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` with both numbers formatted by ``fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class Scores:
    """Accumulate binary-segmentation confusion counts and Dice / IoU scores.

    Each ``record`` call adds one case, presumably one patient (the per-case lists
    are named ``patient_*``). The case's Dice and IoU are appended to
    ``patient_dice`` / ``patient_iou``, and its TP/TN/FP/FN are added to the
    dataset-wide totals. After the first ``record``, the totals are torch tensors
    (``torch.sum`` results); call ``.item()`` for plain Python numbers.
    """

    def __init__(self):
        self.TP = 0
        self.TN = 0
        self.FP = 0
        self.FN = 0
        self.patient_dice = []
        self.patient_iou = []

    def record(self, preds, label):
        """Add one case's predicted mask and ground-truth mask.

        Args:
            preds: tensor whose values must all be 0 or 1 (bool tensors also work).
            label: tensor comparable element-wise with ``preds``. Positions whose
                label is neither 0 nor 1 are counted in none of TP/TN/FP/FN.

        Raises:
            ValueError: if ``preds`` contains a value other than 0 or 1. Nothing
                is recorded in that case.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        # Every value must be 0 or 1; any other value would silently fall outside
        # all four counts.
        if not torch.all((preds == 0) | (preds == 1)):
            raise ValueError('preds must be a binary mask containing only 0 and 1')
        pred_pos = preds == 1
        pred_neg = preds == 0
        label_pos = label == 1
        label_neg = label == 0
        tp = torch.sum(label_pos & pred_pos)
        tn = torch.sum(label_neg & pred_neg)
        fp = torch.sum(label_neg & pred_pos)
        fn = torch.sum(label_pos & pred_neg)
        # NOTE: when tp + fp + fn == 0 (no positives in either mask), both ratios
        # are 0/0 on tensors and the appended per-case score is NaN.
        self.patient_dice.append(2 * tp / (2 * tp + fp + fn))
        self.patient_iou.append(tp / (tp + fp + fn))
        self.TP += tp
        self.TN += tn
        self.FP += fp
        self.FN += fn

    def compute_dice(self):
        """Dataset-level Dice, 2*TP / (2*TP + FP + FN), computed from the summed counts.

        Raises ``ZeroDivisionError`` if nothing has been recorded yet (the counts
        are still the int 0). Returns a NaN tensor if the recorded cases contain
        no positives at all.
        """
        return 2 * self.TP / (2 * self.TP + self.FP + self.FN)

    def compute_iou(self):
        """Dataset-level IoU, TP / (TP + FP + FN), computed from the summed counts.

        Has the same zero-denominator behaviour as ``compute_dice``.
        """
        return self.TP / (self.TP + self.FP + self.FN)