-
Notifications
You must be signed in to change notification settings - Fork 191
/
fmoefy-v2.2.patch
107 lines (97 loc) · 4.78 KB
/
fmoefy-v2.2.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@ import os
import torch
from megatron import fused_kernels
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
+ parser = _add_fmoe_args(parser)
# Custom arguments.
if extra_args_provider is not None:
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@ from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
from megatron.model import FP16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log
def print_datetime(string):
"""Note that this call will sync across all ranks."""
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func)
+ model_provider = patch_model_provider(model_provider)
+
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
- loss_scale, report_memory_flag, skipped_iter):
+ loss_scale, report_memory_flag, skipped_iter, model):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
@@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
+ if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+ add_balance_log(model, writer, iteration)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
- report_memory_flag, skipped_iter)
+ report_memory_flag, skipped_iter, model)
# Autoresume
if args.adlr_autoresume and \