-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
train.py
executable file
·159 lines (125 loc) · 4.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
import os
import argparse
import torch
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from torchvision.models import AlexNet
from torchvision.models import vgg19
import deepspeed
from deepspeed.pipe import PipelineModule
from deepspeed.utils import RepeatingLoader
def cifar_trainset(local_rank, dl_path='/tmp/cifar10-data'):
    """Build the CIFAR-10 training dataset, downloading it at most once.

    Args:
        local_rank: this process's local rank; only rank 0 performs the
            download while the other ranks wait at a barrier.
        dl_path: directory where the CIFAR-10 archive is stored/extracted.

    Returns:
        A ``torchvision.datasets.CIFAR10`` training set with 224x224
        ImageNet-style preprocessing (suitable for AlexNet/VGG inputs).
    """
    # Resize/crop to 224x224 and normalize with ImageNet statistics so the
    # 32x32 CIFAR images fit the torchvision AlexNet/VGG input contract.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Ensure only one rank downloads.
    # Note: if the download path is not on a shared filesystem, remove the semaphore
    # and switch to args.local_rank
    dist.barrier()
    # Non-zero ranks block here until rank 0 finishes the download below.
    if local_rank != 0:
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root=dl_path,
                                            train=True,
                                            download=True,
                                            transform=transform)
    # Rank 0 releases the waiting ranks once the data is on disk.
    if local_rank == 0:
        dist.barrier()
    return trainset
def get_args():
    """Parse command-line options, including DeepSpeed's config arguments.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='CIFAR')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('-s', '--steps', type=int, default=100,
                        help='quit after this many steps')
    parser.add_argument('-p', '--pipeline-parallel-size', type=int, default=2,
                        help='pipeline parallelism')
    parser.add_argument('--backend', type=str, default='nccl',
                        help='distributed backend')
    parser.add_argument('--seed', type=int, default=1138, help='PRNG seed')
    # Let DeepSpeed register its own options (--deepspeed, --deepspeed_config, ...).
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()
def train_base(args):
    """Train AlexNet on CIFAR-10 with plain (non-pipelined) DeepSpeed.

    Runs ``args.steps`` optimizer steps, where each optimizer step spans
    ``engine.gradient_accumulation_steps()`` micro-batches. Rank 0 prints
    the loss every 10 optimizer steps.

    Args:
        args: parsed namespace carrying ``seed``, ``steps``, ``local_rank``
            and the DeepSpeed config arguments.
    """
    torch.manual_seed(args.seed)

    # VGG also works :-)
    #net = vgg19(num_classes=10)
    net = AlexNet(num_classes=10)

    trainset = cifar_trainset(args.local_rank)

    engine, _, dataloader, __ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=[p for p in net.parameters() if p.requires_grad],
        training_data=trainset)

    # Cycle the dataloader forever so any number of steps can be run.
    dataloader = RepeatingLoader(dataloader)
    data_iter = iter(dataloader)

    rank = dist.get_rank()
    # Hoist the accumulation factor once and reuse it below (the original
    # assigned it but then re-queried the engine on every use).
    gas = engine.gradient_accumulation_steps()

    criterion = torch.nn.CrossEntropyLoss()

    total_steps = args.steps * gas
    step = 0
    for micro_step in range(total_steps):
        batch = next(data_iter)
        inputs = batch[0].to(engine.device)
        labels = batch[1].to(engine.device)

        outputs = engine(inputs)
        loss = criterion(outputs, labels)
        engine.backward(loss)
        engine.step()

        # Count one full optimizer step per accumulation boundary.
        if micro_step % gas == 0:
            step += 1
            if rank == 0 and (step % 10 == 0):
                print(f'step: {step:3d} / {args.steps:3d} loss: {loss}')
def join_layers(vision_model):
    """Flatten a torchvision-style model into a single sequential layer list.

    The returned list is ``features`` + ``avgpool`` + a flatten step +
    ``classifier``, suitable for handing to ``PipelineModule``.

    Args:
        vision_model: a model exposing ``features``, ``avgpool`` and
            ``classifier`` (e.g. torchvision AlexNet/VGG).

    Returns:
        A flat list of layers/callables in forward-pass order.
    """
    flatten_step = lambda x: torch.flatten(x, 1)
    stages = list(vision_model.features)
    stages.append(vision_model.avgpool)
    stages.append(flatten_step)
    stages.extend(vision_model.classifier)
    return stages
def train_pipe(args, part='parameters'):
    """Train AlexNet on CIFAR-10 using DeepSpeed pipeline parallelism.

    Args:
        args: parsed namespace carrying ``seed``, ``steps``, ``local_rank``,
            ``pipeline_parallel_size`` and the DeepSpeed config arguments.
        part: partitioning strategy forwarded to ``PipelineModule``.
    """
    torch.manual_seed(args.seed)
    deepspeed.runtime.utils.set_random_seed(args.seed)

    # Build the base network, then wrap it for pipelined execution.
    # VGG also works :-)
    #net = vgg19(num_classes=10)
    base_model = AlexNet(num_classes=10)
    pipe_model = PipelineModule(layers=join_layers(base_model),
                                loss_fn=torch.nn.CrossEntropyLoss(),
                                num_stages=args.pipeline_parallel_size,
                                partition_method=part,
                                activation_checkpoint_interval=0)

    dataset = cifar_trainset(args.local_rank)
    trainable = [p for p in pipe_model.parameters() if p.requires_grad]

    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=pipe_model,
                                           model_parameters=trainable,
                                           training_data=dataset)

    # train_batch() drives a full forward/backward/optimizer step across
    # all pipeline stages for one (accumulated) batch.
    for _ in range(args.steps):
        engine.train_batch()
if __name__ == '__main__':
    args = get_args()

    # Bring up the distributed backend before any collective call
    # (cifar_trainset relies on dist.barrier()).
    deepspeed.init_distributed(dist_backend=args.backend)
    args.local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(args.local_rank)

    # A pipeline size of 0 means "no pipelining": run plain data parallel.
    trainer = train_base if args.pipeline_parallel_size == 0 else train_pipe
    trainer(args)