-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
cifar10_deepspeed.py
executable file
·402 lines (353 loc) · 13.2 KB
/
cifar10_deepspeed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import argparse
import os
import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
def add_argument():
    """Build the command-line parser for this example and parse ``sys.argv``.

    The options cover the training loop (epochs, logging interval), the
    training datatype, the ZeRO stage and the Mixture-of-Experts (MoE)
    settings. DeepSpeed's own configuration arguments are appended through
    ``deepspeed.add_config_arguments``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="CIFAR")

    # For train.
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # For mixed precision training.
    parser.add_argument(
        "--dtype",
        default="fp16",
        type=str,
        choices=["bf16", "fp16", "fp32"],
        help="Datatype used for training",
    )

    # For ZeRO Optimization.
    parser.add_argument(
        "--stage",
        default=0,
        type=int,
        choices=[0, 1, 2, 3],
        # The help text used to be a copy of the --dtype help.
        help="ZeRO optimization stage (default: 0)",
    )

    # For MoE (Mixture of Experts).
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        default=[
            1,
        ],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    # Include DeepSpeed configuration arguments.
    parser = deepspeed.add_config_arguments(parser)

    return parser.parse_args()
def create_moe_param_groups(model):
    """Split the model's parameters into MoE-aware optimizer groups.

    Every parameter of ``model`` is collected into a single group named
    ``"parameters"``; that group is handed to DeepSpeed's
    ``split_params_into_different_moe_groups_for_optimizer`` and its result
    is returned unchanged.
    """
    all_params = list(model.parameters())
    base_group = {"params": all_params, "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(base_group)
def get_ds_config(args):
    """Assemble the DeepSpeed configuration dictionary for this example.

    Only two entries depend on ``args``: ``args.dtype`` decides which of the
    ``bf16`` / ``fp16`` sections is enabled (neither for ``"fp32"``), and
    ``args.stage`` is used as the ZeRO stage. Every other value is fixed.

    Returns:
        dict: a freshly built configuration dictionary.
    """
    adam_optimizer = {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7,
        },
    }

    warmup_scheduler = {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000,
        },
    }

    bf16_section = {"enabled": args.dtype == "bf16"}

    fp16_section = {
        "enabled": args.dtype == "fp16",
        "fp16_master_weights_and_grads": False,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 15,
    }

    zero_section = {
        "stage": args.stage,
        "allgather_partitions": True,
        "reduce_scatter": True,
        "allgather_bucket_size": 50000000,
        "reduce_bucket_size": 50000000,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "cpu_offload": False,
    }

    # Top-level keys are inserted in the same order as before.
    return {
        "train_batch_size": 16,
        "steps_per_print": 2000,
        "optimizer": adam_optimizer,
        "scheduler": warmup_scheduler,
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "bf16": bf16_section,
        "fp16": fp16_section,
        "wall_clock_breakdown": False,
        "zero_optimization": zero_section,
    }
class Net(nn.Module):
    """Small convolutional classifier with an optional DeepSpeed MoE head.

    Two conv + max-pool stages feed two fully connected layers. The
    flattened conv output must hold ``16 * 5 * 5`` features per sample, and
    the network produces 10 outputs per sample.

    When ``args.moe`` is set, one ``deepspeed.moe.layer.MoE`` layer is
    created per entry of ``args.num_experts`` and applied in sequence before
    the final linear layer ``fc4``; otherwise a single linear layer ``fc3``
    produces the outputs.
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.moe = args.moe

        if not self.moe:
            self.fc3 = nn.Linear(84, 10)
            return

        # The same 84 -> 84 linear layer is passed as the expert to every
        # MoE layer; one MoE layer is built per requested expert count.
        expert = nn.Linear(84, 84)
        self.moe_layer_list = nn.ModuleList(
            [
                deepspeed.moe.layer.MoE(
                    hidden_size=84,
                    expert=expert,
                    num_experts=expert_count,
                    ep_size=args.ep_world_size,
                    use_residual=args.mlp_type == "residual",
                    k=args.top_k,
                    min_capacity=args.min_capacity,
                    noisy_gate_policy=args.noisy_gate_policy,
                )
                for expert_count in args.num_experts
            ]
        )
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch and return the 10 outputs per sample."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        if not self.moe:
            return self.fc3(x)

        # Each MoE layer returns a 3-tuple; only the first element is used.
        for moe_layer in self.moe_layer_list:
            x, _, _ = moe_layer(x)
        return self.fc4(x)
def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data and print the accuracy.

    The overall accuracy and the accuracy per CIFAR10 class are printed by
    the process whose ``model_engine.local_rank`` is 0. Classes without any
    test sample are reported as ``n/a`` instead of dividing by zero.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data,
            or None when no conversion is needed.
        test_batch_size (int): the test batch size.
    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = [0.0] * len(classes)
    class_total = [0.0] * len(classes)

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for images, labels in testloader:
            if target_dtype is not None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels.to(local_device)

            # Count the total accuracy.
            total += labels.size(0)
            correct += hits.sum().item()

            # Count the accuracy per class. Iterate over the samples that are
            # actually in this batch: the last batch may hold fewer than
            # ``test_batch_size`` samples, and a batch may hold just one.
            for label, hit in zip(labels.tolist(), hits.tolist()):
                class_correct[label] += hit
                class_total[label] += 1

    if model_engine.local_rank == 0:
        if total == 0:
            print("No test images were evaluated.")
            return

        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for name, hit_count, seen_count in zip(classes, class_correct, class_total):
            if seen_count:
                print(
                    f"Accuracy of {name : >5s} : {100 * hit_count / seen_count : 2.0f} %"
                )
            else:
                print(f"Accuracy of {name : >5s} :  n/a (no test samples)")


# The name matches pytest's default ``test*`` pattern; opt out of collection
# so importing this function into a test module does not run it as a test case.
test.__test__ = False
def main(args):
    """Train the CIFAR10 network with DeepSpeed and evaluate it.

    Args:
        args (argparse.Namespace): the parsed command-line arguments, as
            returned by ``add_argument``.

    Raises:
        RuntimeError: if the local rank is available neither from the
            ``LOCAL_RANK`` environment variable nor from ``--local_rank``.
    """
    # Initialize DeepSpeed distributed backend.
    deepspeed.init_distributed()

    # Prefer the LOCAL_RANK environment variable; fall back to --local_rank
    # instead of failing with an opaque ``int(None)`` TypeError.
    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        _local_rank = int(local_rank_env)
    elif getattr(args, "local_rank", -1) >= 0:
        _local_rank = args.local_rank
    else:
        raise RuntimeError(
            "Cannot determine the local rank: set the LOCAL_RANK environment "
            "variable or pass --local_rank (both are normally provided by the "
            "distributed launcher)."
        )
    get_accelerator().set_device(_local_rank)

    ########################################################################
    # Step1. Data Preparation.
    #
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1].
    #
    # Note:
    #     If running on Windows and you get a BrokenPipeError, try setting
    #     the num_worker of torch.utils.data.DataLoader() to 0.
    ########################################################################
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    if torch.distributed.get_rank() != 0:
        # Might be downloading cifar data, let rank 0 download first.
        torch.distributed.barrier()

    # Load or download cifar data.
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )

    if torch.distributed.get_rank() == 0:
        # Cifar data is downloaded, indicate other ranks can proceed.
        torch.distributed.barrier()

    ########################################################################
    # Step 2. Define the network with DeepSpeed.
    #
    # First, we define a Convolution Neural Network.
    # Then, we define the DeepSpeed configuration dictionary and use it to
    # initialize the DeepSpeed engine.
    ########################################################################
    net = Net(args)

    # Get list of parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, net.parameters())

    # If using MoE, create separate param groups for each expert.
    if args.moe_param_group:
        parameters = create_moe_param_groups(net)

    # Initialize DeepSpeed to use the following features.
    #   1) Distributed model.
    #   2) Distributed data loader.
    #   3) DeepSpeed optimizer.
    ds_config = get_ds_config(args)
    # Only the engine and the data loader of the returned 4-tuple are used.
    model_engine, _, trainloader, _ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=parameters,
        training_data=trainset,
        config=ds_config,
    )

    # Get the local device name (str) and local rank (int).
    local_device = get_accelerator().device_name(model_engine.local_rank)
    local_rank = model_engine.local_rank

    # For float32, target_dtype will be None so no datatype conversion needed.
    target_dtype = None
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
    elif model_engine.fp16_enabled():
        target_dtype = torch.half

    # Define the Classification Cross-Entropy loss function.
    criterion = nn.CrossEntropyLoss()

    ########################################################################
    # Step 3. Train the network.
    #
    # This is when things start to get interesting.
    # We simply have to loop over our data iterator, and feed the inputs to the
    # network and optimize. (DeepSpeed handles the distributed details for us!)
    ########################################################################
    for epoch in range(args.epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            # Get the inputs. ``data`` is a list of [inputs, labels].
            inputs, labels = data[0].to(local_device), data[1].to(local_device)

            # Try to convert to target_dtype if needed.
            if target_dtype is not None:
                inputs = inputs.to(target_dtype)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

            # Print statistics
            running_loss += loss.item()
            if local_rank == 0 and i % args.log_interval == (
                args.log_interval - 1
            ):  # Print every log_interval mini-batches.
                print(
                    f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
                )
                running_loss = 0.0
    print("Finished Training")

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    main(add_argument())