New apex compatible squad (#19)
* change squad baseline to use new apex
jeffra authored Apr 30, 2020
1 parent 6a698b2 commit 9e2c735
Showing 2 changed files with 47 additions and 18 deletions.
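
The "new apex" in the commit message is the apex.amp API: rather than calling model.half() and wrapping FusedAdam in FP16_Optimizer, the baseline now calls amp.initialize() once and lets amp keep the fp32 master weights and handle dynamic loss scaling. Below is a minimal sketch of the resulting setup, assuming apex is installed; build_fp16_model is a hypothetical helper, and model, optimizer_grouped_parameters, args, and n_gpu stand in for the objects built in nvidia_run_squad_baseline.py.

    import torch
    from apex import amp
    from apex.optimizers import FusedAdam
    from apex.parallel import DistributedDataParallel as DDP

    def build_fp16_model(model, optimizer_grouped_parameters, args, n_gpu):
        # The new FusedAdam no longer clips gradients internally, so max_grad_norm is
        # dropped from its arguments; clipping happens explicitly after backward
        # (see the GradientClipper class added in this commit).
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
        # O2 keeps fp16 weights on the model and fp32 master weights for the optimizer,
        # with dynamic loss scaling, matching the opt_level used in this commit.
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=False,
                                          loss_scale="dynamic")
        # amp.initialize must run before the model is wrapped for data parallelism,
        # which is why the DDP/DataParallel block moves below the optimizer setup in the diff.
        if args.local_rank != -1:
            model = DDP(model)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)
        return model, optimizer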
BingBertSquad/deepspeed_bsz24_config.json  (2 changes: 1 addition & 1 deletion)
@@ -6,11 +6,11 @@
"type": "Adam",
"params": {
"lr": 3e-5,
"max_grad_norm": 1.0,
"weight_decay": 0.0,
"bias_correction": false
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": true
}
BingBertSquad/nvidia_run_squad_baseline.py  (63 changes: 46 additions & 17 deletions)
@@ -42,6 +42,7 @@
 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
 
+from apex import amp
 from turing.nvidia_modeling import BertForQuestionAnswering, BertConfig
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -712,6 +713,31 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n
     return is_nan
 
 
+from apex.multi_tensor_apply import multi_tensor_applier
+class GradientClipper:
+    """
+    Clips gradient norm of an iterable of parameters.
+    """
+    def __init__(self, max_grad_norm):
+        self.max_norm = max_grad_norm
+        if multi_tensor_applier.available:
+            import amp_C
+            self._overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+        else:
+            raise RuntimeError('Gradient clipping requires cuda extensions')
+
+    def step(self, parameters):
+        l = [p.grad for p in parameters if p.grad is not None]
+        total_norm, _ = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [l], False)
+        total_norm = total_norm.item()
+        if (total_norm == float('inf')): return
+        clip_coef = self.max_norm / (total_norm + 1e-6)
+        if clip_coef < 1:
+            multi_tensor_applier(self.multi_tensor_scale, self._overflow_buf, [l, l], clip_coef)
+
+
 def main():
     parser = get_argument_parser()
     args = parser.parse_args()
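
The GradientClipper added above uses apex's fused multi-tensor kernels (amp_C) to compute the global L2 norm of all gradients and rescale them in place. It plays the same role as torch.nn.utils.clip_grad_norm_ applied to the fp32 master parameters that amp exposes; a rough pure-PyTorch equivalent, shown only for comparison (clip_master_grads is a hypothetical helper, not part of the commit):

    import torch
    from apex import amp

    def clip_master_grads(optimizer, max_grad_norm=1.0):
        # amp.master_params yields the parameters amp maintains for an initialized
        # optimizer; clipping their global L2 norm mirrors GradientClipper.step.
        params = [p for p in amp.master_params(optimizer) if p.grad is not None]
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

The fused path avoids a Python loop and per-tensor kernel launches, which is why the class requires the apex CUDA extensions instead of falling back silently.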
@@ -813,18 +839,7 @@ def main():
     #model.bert.load_state_dict(bert_state_dict, strict=False)
     logger.info(f"Pretrained Bert Encoder Loaded from: {args.model_file}")
 
-    if args.fp16:
-        model.half()
     model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
 
     # Prepare optimizer
     param_optimizer = list(model.named_parameters())
@@ -844,25 +859,33 @@ def main():
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
         try:
-            from apex.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
 
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
+                              bias_correction=False)
         if args.loss_scale == 0:
-            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", keep_batchnorm_fp32=False, loss_scale="dynamic")
         else:
-            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
+            raise NotImplementedError("dynamic loss scale is only supported in baseline, please set loss_scale=0")
     else:
         optimizer = BertAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup_proportion,
                              t_total=t_total)
 
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     global_step = 0
     if args.do_train:
         cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
@@ -901,6 +924,8 @@ def main():
             train_sampler = DistributedSampler(train_data)
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
 
+        gradClipper = GradientClipper(max_grad_norm=1.0)
+
         model.train()
         ema_loss = 0.
         sample_count = 0
@@ -928,10 +953,14 @@ def main():
                     model.enable_allreduce()
 
                 if args.fp16:
-                    optimizer.backward(loss)
+                    with amp.scale_loss(loss, optimizer) as scaled_loss:
+                        scaled_loss.backward()
                 else:
                     loss.backward()
 
+                # gradient clipping
+                gradClipper.step(amp.master_params(optimizer))
+
                 sample_count += (args.train_batch_size * torch.distributed.get_world_size())
 
                 if (step + 1) % args.gradient_accumulation_steps == 0:
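Taken together, the last two hunks give the standard amp training step: the loss is scaled before backward so fp16 gradients do not underflow, the gradients are unscaled when the scale_loss context exits, and only then are the master gradients clipped. A condensed sketch of one step, with loss, optimizer, grad_clipper, and args standing in for the script's objects:

    from apex import amp

    def backward_and_clip(loss, optimizer, grad_clipper, args):
        if args.fp16:
            # Scaled backward; amp unscales the gradients when the context exits.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # Clip the global norm of the (now unscaled) gradients, as the baseline does.
        grad_clipper.step(amp.master_params(optimizer))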
