-
Notifications
You must be signed in to change notification settings - Fork 3
/
mean_accumulator.py
97 lines (84 loc) · 3.43 KB
/
mean_accumulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from copy import deepcopy
class MeanAccumulator:
    """Running weighted mean of a scalar, a tensor, or a dict/list of those.

    The first value passed to :meth:`add` fixes the layout: a ``dict`` or
    ``list`` creates one child ``MeanAccumulator`` per entry, anything else
    is stored directly as the current average. Later values must have the
    same layout.
    """

    def __init__(self, update_weight=1):
        """Create an empty accumulator.

        Args:
            update_weight: Multiplier applied to the weight of every new
                sample in the update rule (see :meth:`_update`). With the
                default of 1 the result is the plain weighted mean.
        """
        self.average = None
        self.counter = 0
        self.update_weight = update_weight

    def value(self):
        """Return the current average.

        Returns:
            A dict/list of the children's values when the accumulator holds
            a dict/list, otherwise the stored average itself (``None`` if
            nothing has been added yet).
        """
        if isinstance(self.average, dict):
            return {k: v.value() for k, v in self.average.items()}
        if isinstance(self.average, list):
            return [v.value() for v in self.average]
        return self.average

    def reduce(self):
        """Combine this accumulator with those of the other workers.

        Does nothing when ``torch.distributed`` is unavailable, when no
        process group has been initialized, or when there is a single
        worker. Otherwise the count and ``average * count`` are all-reduced
        and the average is recomputed from the reduced totals.
        """
        dist = torch.distributed
        # get_world_size() requires an initialized process group, so check
        # for that first instead of raising in plain single-process runs.
        if not dist.is_available() or not dist.is_initialized():
            return
        if dist.get_world_size() == 1:
            # Skip this if there is only one worker
            return

        if isinstance(self.average, dict):
            # Sorted iteration; presumably so that every worker issues its
            # collectives in the same key order.
            for key in sorted(self.average.keys()):
                self.average[key].reduce()
            return
        if isinstance(self.average, list):
            for avg in self.average:
                avg.reduce()
            return

        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        total_count = torch.tensor(self.counter, dtype=torch.float32, device=device)
        handle_tc = dist.all_reduce(total_count, async_op=True)

        # Average * count, so that the reduced value divided by the reduced
        # count gives the combined mean.
        is_tensor = isinstance(self.average, torch.Tensor)
        if is_tensor:
            multiplied = self.average.clone()
        else:
            multiplied = torch.tensor(self.average, dtype=torch.float32, device=device)
        multiplied.mul_(self.counter)
        handle_mul = dist.all_reduce(multiplied, async_op=True)

        # Both reductions were launched asynchronously; wait for both.
        handle_tc.wait()
        handle_mul.wait()

        self.counter = total_count.item()
        if is_tensor:
            self.average.data = multiplied / total_count
        else:
            self.average = (multiplied / total_count).item()

    def add(self, value, weight=1.0):
        """Add a value to the average.

        Args:
            value: Scalar, tensor, or dict/list of those. Must match the
                layout of the first value that was added.
            weight: Weight of this sample; it is added to ``counter``.
        """
        self.counter += weight
        if self.average is None:
            self._init(value, weight)
        elif isinstance(self.average, dict):
            for k, v in value.items():
                self.average[k].add(v, weight)
        elif isinstance(self.average, list):
            for avg, new_value in zip(self.average, value):
                avg.add(new_value, weight)
        else:
            self._update(value, weight)

    def _update(self, value, weight):
        """Blend ``value`` into the stored scalar or tensor average.

        Uses ``alpha = update_weight * weight / (counter + update_weight - 1)``
        and sets ``average = (1 - alpha) * average + alpha * value``.
        ``counter`` already includes ``weight`` when this is called.

        Raises:
            ValueError: If the stored average is neither a tensor nor a
                Python number.
        """
        alpha = float(self.update_weight * weight) / float(self.counter + self.update_weight - 1)
        if isinstance(self.average, torch.Tensor):
            self.average.mul_(1.0 - alpha)
            # Keyword form; the positional add_(alpha, other) overload is
            # deprecated.
            self.average.add_(value, alpha=alpha)
        elif isinstance(self.average, (int, float)):
            # An int first sample is accepted too; the arithmetic below
            # turns the stored average into a float.
            self.average *= 1.0 - alpha
            self.average += alpha * value
        else:
            raise ValueError("Unknown type")

    def _init(self, value, weight):
        """Set up storage from the first value that is added.

        Dicts and lists get one child accumulator per entry, created with
        this accumulator's ``update_weight``; any other value is deep-copied
        and becomes the initial average.
        """
        if isinstance(value, dict):
            self.average = {}
            for key in value:
                self.average[key] = MeanAccumulator(self.update_weight)
                self.average[key].add(value[key], weight)
        elif isinstance(value, list):
            self.average = []
            for v in value:
                acc = MeanAccumulator(self.update_weight)
                acc.add(v, weight)
                self.average.append(acc)
        else:
            # Copy so later in-place updates do not modify the caller's object.
            self.average = deepcopy(value)

    def reset(self):
        """Forget everything that has been accumulated."""
        self.average = None
        self.counter = 0