-
Notifications
You must be signed in to change notification settings - Fork 191
/
v3.0.2.patch
498 lines (454 loc) · 20.7 KB
/
v3.0.2.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 102e890..c3504bd 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@ import os
import torch
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -43,6 +46,9 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
+ # FastMoE arguments.
+ parser = _add_fmoe_args(parser)
+
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
@@ -316,6 +322,12 @@ def parse_args(extra_args_provider=None, defaults={},
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
+ # if fmoe_num_experts is not specified,
+ # we are using lower version of megatron,
+ # copy num_experts to fmoe_num_experts
+ if not hasattr(args, 'fmoe_num_experts'):
+ args.fmoe_num_experts = args.num_experts
+
_print_args(args)
return args
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index ceba352..01754d0 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -124,6 +124,10 @@ def read_metadata(tracker_filename):
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
+
+ args = get_args()
+ if args.fmoefy:
+ return iteration, release
# Get the max iteration retrieved across the ranks.
iters_cuda = torch.cuda.LongTensor([iteration])
@@ -134,6 +138,7 @@ def read_metadata(tracker_filename):
# If not, print a warning and chose the maximum
# iteration across all ranks.
if iteration != max_iter:
+ rank = torch.distributed.get_rank()
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
@@ -399,7 +404,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
- except KeyError:
+ except KeyError as e:
+ print(e)
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 2f6e1b8..e2483db 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
- 6: np.float,
+ 6: np.float32,
7: np.double,
8: np.uint16
}
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
- np.float: 4,
+ np.float32: 4,
np.double: 8
}
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index d8bee27..6f4ecfb 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -101,8 +101,9 @@ def get_megatron_optimizer(model,
# Determine whether the params have main-grad field.
params_have_main_grad = False
- if args.DDP_impl == 'local':
- params_have_main_grad = True
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # params_have_main_grad = True
if args.fp16 or args.bf16:
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 36cd915..7b4eaaa 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,6 +54,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ # FastMoE
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
@@ -65,7 +67,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ # FastMoE
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -74,6 +80,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -98,7 +106,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ # FastMoE
+ if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
else:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index d6ac42e..7eecff4 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -257,6 +257,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ # FastMoE
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
@@ -411,18 +414,27 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
-
+
+ # move to L433-L436
# If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
+ # if found_inf_flag:
+ # return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad)
+
+ # remove if branch to avoid dead-lock in FastMoE
+ # if self.clip_grad > 0.0:
+ # grad_norm = self.clip_grad_norm(self.clip_grad)
+ grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
+ # move early return to here to avoid dead-lock in FastMoE
+ # If we found inf/nan, skip the update.
+ if found_inf_flag:
+ return False, None, None
+
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
diff --git a/megatron/schedules.py b/megatron/schedules.py
index ac5ba6f..26b717a 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -24,7 +24,10 @@ from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
@@ -66,7 +69,7 @@ def deallocate_output_tensor(out):
dtype = out.dtype,
)
-def custom_backward(output, grad_output):
+def custom_backward(output, grad_output, bal_loss):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
@@ -89,11 +92,16 @@ def custom_backward(output, grad_output):
output,
memory_format = torch.preserve_format,
)
+ tensors = (output,)
+ grad_tensors = (grad_output,)
+ else:
+ tensors = (output, bal_loss)
+ grad_tensors = (grad_output, None)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
- tensors = (output,),
- grad_tensors = (grad_output,),
+ tensors = tensors,
+ grad_tensors = grad_tensors,
keep_graph = False,
create_graph = False,
inputs = tuple(),
@@ -127,7 +135,8 @@ def forward_step(forward_step_func,
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
- output_tensor, loss_func = forward_step_func(data_iterator, model)
+ output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+ bal_loss = bal_loss / get_num_microbatches()
if mpu.is_pipeline_last_stage():
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
@@ -145,13 +154,14 @@ def forward_step(forward_step_func,
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
+ assert False, f"encoder-decoder model is not supported by FastMoE "
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
- return output_tensor
- return [output_tensor]
+ return output_tensor, bal_loss
+ return [output_tensor, bal_loss]
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -185,7 +195,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
- custom_backward(output_tensor[0], output_tensor_grad[0])
+ custom_backward(output_tensor[0], output_tensor_grad[0], bal_loss)
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
@@ -241,20 +251,20 @@ def forward_backward_no_pipelining(forward_step_func,
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
- backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+ backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
return forward_data_store
@@ -269,6 +279,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
+ # FastMoE TODO
+ assert False, "FastMoE not supports pipeline with interleaving"
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
@@ -646,15 +658,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
+ bal_losses = None
if not forward_only:
input_tensors = []
output_tensors = []
+ bal_losses = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -662,6 +676,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Before running 1F1B, need to receive first forward tensor.
@@ -674,7 +689,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only:
@@ -692,16 +707,18 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_loss.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
if last_iteration:
input_tensor = None
@@ -716,12 +733,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
diff --git a/megatron/training.py b/megatron/training.py
index 023bdf1..caefb88 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -36,8 +36,13 @@ from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
@@ -45,7 +50,11 @@ from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -119,6 +128,13 @@ def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v3.0.2")
+ model_provider = patch_model_provider(model_provider, Megatron_Version='v3.0.2')
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
@@ -466,10 +482,12 @@ def train_step(forward_step_func, data_iterator,
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
+ grad = word_embeddings_weight.grad
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
@@ -568,26 +586,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
- timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
+ # FastMoE add several timers.
+ # For simplicity, add all timers to log.
+ def add_all():
+ for name in timers.timers:
+ timers_to_log.append(name)
+
+ add_all()
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
--
2.25.1