-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgroupDRO.py
104 lines (99 loc) · 4.03 KB
/
groupDRO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model
class GroupDRO(SingleModelAlgorithm):
    """
    Group distributionally robust optimization.

    The training objective is a weighted average of per-group losses. The group
    weights are a probability vector over groups that is updated adversarially
    after every training batch by an exponentiated-gradient step on the
    observed group losses, so groups with higher loss receive more weight.

    Original paper:
        @inproceedings{sagawa2019distributionally,
          title={Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization},
          author={Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy},
          booktitle={International Conference on Learning Representations},
          year={2019}
        }
    """

    def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group_in_train):
        """
        Build the model and initialize the adversarial group weights.

        Args:
            - config: experiment config. Must have ``uniform_over_groups`` set;
              ``device`` and ``group_dro_step_size`` are also read here.
            - d_out (int): output dimension passed to ``initialize_model``.
            - grouper: defines the groups; ``grouper.n_groups`` is the number of groups.
            - loss: loss object exposing ``compute_group_wise``.
            - metric: forwarded to ``SingleModelAlgorithm``.
            - n_train_steps (int): forwarded to ``SingleModelAlgorithm``.
            - is_group_in_train: index or mask into the group axis selecting the
              groups present in the training set; only these get nonzero weight.
        Raises:
            - ValueError: if ``config.uniform_over_groups`` is falsy, or if
              ``is_group_in_train`` selects no group (the initial weights could
              not be normalized).
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not config.uniform_over_groups:
            raise ValueError('GroupDRO requires config.uniform_over_groups to be set')
        # initialize model
        model = initialize_model(config, d_out).to(config.device)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # additional logging: current group weights are reported with each batch
        self.logged_fields.append('group_weight')
        # step size (eta) of the exponentiated-gradient update on the group weights
        self.group_weights_step_size = config.group_dro_step_size
        # Initialize adversarial weights: uniform over the training groups, zero
        # elsewhere. A zero weight stays zero under the multiplicative update.
        group_weights = torch.zeros(grouper.n_groups)
        group_weights[is_group_in_train] = 1
        total = group_weights.sum()
        # Without this guard, dividing by a zero total would silently yield all-NaN weights.
        if total <= 0:
            raise ValueError('is_group_in_train must select at least one group')
        self.group_weights = (group_weights / total).to(self.device)

    def process_batch(self, batch):
        """
        A helper function for update() and evaluate() that processes the batch.

        Args:
            - batch (tuple of Tensors): a batch of data yielded by data loaders
        Output:
            - results (dictionary): the output of
              SingleModelAlgorithm.process_batch() (e.g. g, y_true, metadata,
              y_pred), plus:
                - group_weight (Tensor): the current group weights, of size
                  (n_groups,). The other per-example tensors are of size (batch_size,).
        """
        results = super().process_batch(batch)
        results['group_weight'] = self.group_weights
        return results

    def _compute_group_losses(self, results):
        """
        Return the per-group losses for the batch described by ``results``.

        Shared by objective() and _update() so both use the same computation.
        Args:
            - results (dictionary): must contain 'y_pred', 'y_true' and 'g'.
        Output:
            - group_losses (Tensor): one loss per group, as returned by
              ``self.loss.compute_group_wise`` with ``self.grouper.n_groups`` groups.
        """
        group_losses, _, _ = self.loss.compute_group_wise(
            results['y_pred'],
            results['y_true'],
            results['g'],
            self.grouper.n_groups,
            return_dict=False)
        return group_losses

    def objective(self, results):
        """
        Compute the optimized objective from the output of process_batch().

        For group DRO, the objective is the weighted average of per-group losses,
        weighted by ``self.group_weights``.

        Args:
            - results (dictionary): output of SingleModelAlgorithm.process_batch()
        Output:
            - objective (Tensor): the optimized objective, a scalar (the dot
              product of the per-group losses with the group weights).
        """
        return self._compute_group_losses(results) @ self.group_weights

    def _update(self, results):
        """
        Update the group weights from this batch's group losses, then update the model.

        Args:
            - results (dictionary): output of process_batch() for the current
              training batch. It must contain 'y_pred', 'y_true' and 'g'.
        Output:
            - None. ``results`` is modified in place: 'group_weight' is set to the
              updated weights. SingleModelAlgorithm._update may add further
              entries (e.g. 'objective'); see that method.
        """
        group_losses = self._compute_group_losses(results)
        # Exponentiated-gradient step: w <- w * exp(eta * loss), then renormalize.
        # This is done in log space (softmax of log w + eta * loss), which is the
        # same distribution but cannot overflow exp() into inf/NaN weights.
        # log(0) = -inf, so groups with zero weight stay exactly zero.
        # The losses are detached so the weight update is not tracked by autograd.
        log_weights = (torch.log(self.group_weights)
                       + self.group_weights_step_size * group_losses.detach())
        self.group_weights = torch.softmax(log_weights, dim=0)
        # save updated group weights
        results['group_weight'] = self.group_weights
        # Update the model. The objective is computed with the freshly updated
        # weights, as the weights were updated above, before this call.
        super()._update(results)