-
Notifications
You must be signed in to change notification settings - Fork 486
/
test_train_mp_mnist.py
220 lines (192 loc) · 6.68 KB
/
test_train_mp_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import args_parse
from torch_xla import runtime as xr
# Script-specific boolean switches, passed to args_parse.parse_common_options
# in addition to the common options it already understands.
MODEL_OPTS = {
    switch: {'action': 'store_true'}
    for switch in ('--ddp', '--pjrt_distributed')
}

# Parse the command line at import time, supplying MNIST-specific defaults.
FLAGS = args_parse.parse_common_options(
    datadir='/tmp/mnist-data', batch_size=128, momentum=0.5, lr=0.01,
    target_accuracy=98.0, num_epochs=18, opts=MODEL_OPTS.items())
import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.distributed.xla_backend
class MNIST(nn.Module):
  """Small CNN classifier returning log-probabilities over 10 classes.

  Expects single-channel input (``conv1`` has one input channel). The
  ``fc1`` input width of 320 corresponds to 20 channels * 4 * 4 positions,
  which is what two conv(5x5)/max-pool(2) stages leave of a 28x28 image.
  """

  def __init__(self):
    super().__init__()
    # Two feature stages; each is conv -> 2x2 max-pool -> ReLU -> batch norm.
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    # Fully connected classifier head.
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  @staticmethod
  def _conv_stage(inputs, conv, norm):
    """Apply ``conv``, 2x2 max-pooling, ReLU and then ``norm``."""
    return norm(F.relu(F.max_pool2d(conv(inputs), 2)))

  def forward(self, x):
    """Return per-class log-probabilities (log-softmax over dim 1)."""
    features = self._conv_stage(x, self.conv1, self.bn1)
    features = self._conv_stage(features, self.conv2, self.bn2)
    hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
    logits = self.fc2(hidden)
    return logits.log_softmax(dim=1)
def _train_update(device, step, loss, tracker, epoch, writer):
  """Report training progress for one step via test_utils.

  Args:
    device: Device the step ran on (forwarded for display).
    step: Step index within the current epoch.
    loss: Loss tensor for the step; its Python value is read with ``.item()``.
    tracker: Rate tracker providing ``rate()`` and ``global_rate()``
      (``train_mnist`` passes an ``xm.RateTracker``).
    epoch: Current epoch number.
    writer: Summary writer, or ``None`` on non-master replicas.
  """
  loss_value = loss.item()
  current_rate = tracker.rate()
  overall_rate = tracker.global_rate()
  test_utils.print_training_update(
      device, step, loss_value, current_rate, overall_rate, epoch,
      summary_writer=writer)
def _make_data_loaders(flags):
  """Build this replica's training and test data sources.

  Args:
    flags: Parsed command-line options (reads ``fake_data``, ``batch_size``,
      ``datadir``, ``drop_last`` and ``num_workers``).

  Returns:
    A ``(train_loader, test_loader, train_sampler)`` tuple. ``train_sampler``
    is the ``DistributedSampler`` that shards and shuffles the training set
    across replicas, or ``None`` when fake data is used or there is only one
    replica.
  """
  train_sampler = None
  if flags.fake_data:
    # Constant all-zero batches; sample counts mirror MNIST's 60000/10000
    # train/test sizes, split across replicas.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=60000 // flags.batch_size // xr.world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(flags.batch_size, 1, 28, 28),
              torch.zeros(flags.batch_size, dtype=torch.int64)),
        sample_count=10000 // flags.batch_size // xr.world_size())
    return train_loader, test_loader, train_sampler

  # Each replica uses its own subdirectory (keyed by global ordinal),
  # presumably so concurrent downloads do not collide.
  replica_dir = os.path.join(flags.datadir, str(xr.global_ordinal()))
  train_dataset = datasets.MNIST(
      replica_dir,
      train=True,
      download=True,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))
  test_dataset = datasets.MNIST(
      replica_dir,
      train=False,
      download=True,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))

  if xr.world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xr.world_size(),
        rank=xr.global_ordinal(),
        shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags.batch_size,
      sampler=train_sampler,
      drop_last=flags.drop_last,
      # DataLoader rejects shuffle=True together with a sampler; when a
      # sampler is present it does the shuffling instead.
      shuffle=train_sampler is None,
      num_workers=flags.num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags.batch_size,
      drop_last=flags.drop_last,
      shuffle=False,
      num_workers=flags.num_workers)
  return train_loader, test_loader, train_sampler


def train_mnist(flags, **kwargs):
  """Train the MNIST model on this replica's XLA device.

  Args:
    flags: Parsed command-line options.
    **kwargs: Accepted for call-site compatibility; not used.

  Returns:
    The highest test accuracy (percent, averaged across replicas with
    ``xm.mesh_reduce``) observed over all epochs.
  """
  if flags.ddp or flags.pjrt_distributed:
    dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(1)
  train_loader, test_loader, train_sampler = _make_data_loaders(flags)

  # Scale learning rate to num cores
  lr = flags.lr * xr.world_size()

  device = xm.xla_device()
  model = MNIST().to(device)

  # Initialization is nondeterministic with multiple threads in PjRt.
  # Synchronize model parameters across replicas manually.
  xm.broadcast_master_param(model)

  if flags.ddp:
    model = DDP(model)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(flags.logdir)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader, epoch):
    """Run one training epoch over ``loader``."""
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      if flags.ddp:
        # DDP synchronizes gradients itself during backward().
        optimizer.step()
      else:
        # xm.optimizer_step reduces gradients across replicas, then steps.
        xm.optimizer_step(optimizer)
      tracker.add(flags.batch_size)
      if step % flags.log_steps == 0:
        # Report from a step closure so reading loss.item() happens after
        # the step's computation rather than forcing an early sync.
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, epoch, writer),
            run_async=flags.async_closures)

  def test_loop_fn(loader):
    """Return the test accuracy (percent) averaged across replicas."""
    total_samples = 0
    correct = 0
    model.eval()
    # No gradients are needed for evaluation; skip autograd bookkeeping.
    with torch.no_grad():
      for data, target in loader:
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum()
        total_samples += data.size()[0]

    if total_samples:
      local_accuracy = 100.0 * correct.item() / total_samples
    else:
      # Empty loader: `correct` is still the int 0 and we must not divide by
      # zero. Still fall through to mesh_reduce, which every replica must
      # call.
      local_accuracy = 0.0
    return xm.mesh_reduce('test_accuracy', local_accuracy, np.mean)

  train_device_loader = pl.MpDeviceLoader(train_loader, device)
  test_device_loader = pl.MpDeviceLoader(test_loader, device)
  accuracy, max_accuracy = 0.0, 0.0
  for epoch in range(1, flags.num_epochs + 1):
    if train_sampler is not None:
      # DistributedSampler derives its shuffle order from the epoch; without
      # this call every epoch would see the identical ordering.
      train_sampler.set_epoch(epoch)
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train_loop_fn(train_device_loader, epoch)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

    accuracy = test_loop_fn(test_device_loader)
    xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
        epoch, test_utils.now(), accuracy))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(
        writer,
        epoch,
        dict_to_write={'Accuracy/test': accuracy},
        write_xla_metrics=True)
    if flags.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
def _mp_fn(index, flags):
  """Per-process entry point passed to ``torch_xla.launch``.

  Args:
    index: Process index supplied by the launcher (not used here).
    flags: Parsed command-line options.

  Exits the process with status 21 if the best test accuracy falls below
  ``flags.target_accuracy``.
  """
  torch.set_default_dtype(torch.float32)
  accuracy = train_mnist(flags)
  if flags.tidy:
    # Every replica runs this function against the same flags.datadir.
    # Wait until all replicas are done with it, then let only the host-local
    # master remove it, so concurrent rmtree calls cannot race on the same
    # files.
    xm.rendezvous('mnist_tidy_datadir')
    if xm.is_master_ordinal() and os.path.isdir(flags.datadir):
      shutil.rmtree(flags.datadir)
  if accuracy < flags.target_accuracy:
    print('Accuracy {} is below target {}'.format(accuracy,
                                                  flags.target_accuracy))
    sys.exit(21)
if __name__ == '__main__':
  # With a single requested core, ask torch_xla.launch to run _mp_fn in
  # debug single-process mode; otherwise use its normal launch path.
  torch_xla.launch(
      _mp_fn,
      args=(FLAGS,),
      debug_single_process=(FLAGS.num_cores == 1))