Skip to content

Commit

Permalink
TPU support (#2137)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2137

Reviewed By: ngoyal2707

Differential Revision: D21588555

Pulled By: myleott

fbshipit-source-id: d0c38498356aa8a97cb347bb1b943bb58e59489e
  • Loading branch information
myleott authored and facebook-github-bot committed May 18, 2020
1 parent 132ee8a commit 7751229
Show file tree
Hide file tree
Showing 17 changed files with 460 additions and 208 deletions.
18 changes: 12 additions & 6 deletions fairseq/criterions/masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class MaskedLmLoss(FairseqCriterion):
Implementation for the loss used in masked language model (MLM) training.
"""

def __init__(self, task, tpu):
super().__init__(task)
self.tpu = tpu

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Expand All @@ -26,16 +30,18 @@ def forward(self, model, sample, reduce=True):
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
# compute MLM loss
masked_tokens = sample['target'].ne(self.padding_idx)
sample_size = masked_tokens.int().sum()

# Rare: when all tokens are masked, project all tokens.
# We use torch.where to avoid device-to-host transfers,
# except on CPU where torch.where is not well supported
# (see github.com/pytorch/pytorch/issues/26247).
if masked_tokens.device == torch.device('cpu'):
if self.tpu:
masked_tokens = None # always project all tokens on TPU
elif masked_tokens.device == torch.device('cpu'):
if not masked_tokens.any():
masked_tokens.fill_(True)
masked_tokens = None
else:
masked_tokens = torch.where(
masked_tokens.any(),
Expand All @@ -45,7 +51,8 @@ def forward(self, model, sample, reduce=True):

logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
targets = targets[masked_tokens]
if masked_tokens is not None:
targets = targets[masked_tokens]

loss = modules.cross_entropy(
logits.view(-1, logits.size(-1)),
Expand All @@ -54,9 +61,8 @@ def forward(self, model, sample, reduce=True):
ignore_index=self.padding_idx,
)

sample_size = masked_tokens.int().sum()
logging_output = {
'loss': loss.data,
'loss': loss,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
Expand Down
68 changes: 39 additions & 29 deletions fairseq/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def is_master(args):


def infer_init_method(args):
if args.distributed_init_method is not None:
if args.distributed_init_method is not None or getattr(args, 'tpu', False):
return

# support torch.distributed.launch
Expand Down Expand Up @@ -80,34 +80,40 @@ def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')

if torch.distributed.is_initialized():
warnings.warn('Distributed is already initialized, cannot initialize twice!')
else:
logger.info('distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method,
))
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
logger.info('initialized host {} as rank {}'.format(
socket.gethostname(), args.distributed_rank,
))

# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
if not getattr(args, 'tpu', False):
if torch.distributed.is_initialized():
warnings.warn('Distributed is already initialized, cannot initialize twice!')
else:
dist.all_reduce(torch.zeros(1))
logger.info('distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method,
))
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
logger.info('initialized host {} as rank {}'.format(
socket.gethostname(), args.distributed_rank,
))

if is_master(args):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)
# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())

args.distributed_rank = torch.distributed.get_rank()
args.distributed_rank = torch.distributed.get_rank()
else:
import torch_xla.core.xla_model as xm
assert xm.xrt_world_size() == args.distributed_world_size
args.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal()
xm.rendezvous('distributed_init') # wait for all workers
xm.mark_step()

if is_master(args):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)

if args.model_parallel_size > 1:
try:
Expand Down Expand Up @@ -186,9 +192,13 @@ def get_default_group():


def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
return dist.all_reduce(tensor, group=group)
if isinstance(group, tuple) and group[0] == 'tpu':
import torch_xla.core.xla_model as xm
return xm.all_reduce('sum', [tensor], groups=group[1])
else:
if group is None:
group = get_default_group()
return dist.all_reduce(tensor, group=group)


def all_gather_list(data, group=None, max_size=16384):
Expand Down
15 changes: 15 additions & 0 deletions fairseq/models/fairseq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,21 @@ def apply_prepare_for_onnx_export_(module):

self.apply(apply_prepare_for_onnx_export_)

def prepare_for_tpu_(self, **kwargs):
"""Optionally modify model for use on TPUs."""
seen = set()

def apply_prepare_for_tpu_(module):
if (
module != self
and hasattr(module, "prepare_for_tpu_")
and module not in seen
):
seen.add(module)
module.prepare_for_tpu_(**kwargs)

self.apply(apply_prepare_for_tpu_)

@classmethod
def from_pretrained(
cls,
Expand Down
2 changes: 1 addition & 1 deletion fairseq/models/huggingface/transformers
28 changes: 16 additions & 12 deletions fairseq/modules/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,14 @@ def __init__(
self.reset_parameters()

self.onnx_trace = False

self.enable_torch_version = False
if hasattr(F, "multi_head_attention_forward"):
self.enable_torch_version = True
else:
self.enable_torch_version = False
self.tpu = False

def prepare_for_onnx_export_(self):
self.onnx_trace = True

def prepare_for_tpu_(self, **kwargs):
self.tpu = True

def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
Expand Down Expand Up @@ -143,8 +141,8 @@ def forward(
assert list(query.size()) == [tgt_len, bsz, embed_dim]

if (
self.enable_torch_version
and not self.onnx_trace
not self.onnx_trace
and not self.tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
Expand Down Expand Up @@ -327,9 +325,15 @@ def forward(
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
if not self.tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf")
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf'))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

if before_softmax:
Expand All @@ -340,7 +344,7 @@ def forward(
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(
attn_weights_float.type_as(attn_weights),
attn_weights,
p=self.dropout,
training=self.training,
)
Expand Down
6 changes: 5 additions & 1 deletion fairseq/modules/transformer_sentence_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.apply_bert_init = apply_bert_init
self.learned_pos_embedding = learned_pos_embedding
self.traceable = traceable
self.tpu = False # whether we're on TPU

self.embed_tokens = nn.Embedding(
self.vocab_size, self.embedding_dim, self.padding_idx
Expand Down Expand Up @@ -187,6 +188,9 @@ def freeze_module_params(m):
for layer in range(n_trans_layers_to_freeze):
freeze_module_params(self.layers[layer])

def prepare_for_tpu_(self, **kwargs):
self.tpu = True

def forward(
self,
tokens: torch.Tensor,
Expand All @@ -197,7 +201,7 @@ def forward(

# compute padding mask. This is needed for multi-head attention
padding_mask = tokens.eq(self.padding_idx)
if not self.traceable and not padding_mask.any():
if not self.traceable and not self.tpu and not padding_mask.any():
padding_mask = None

x = self.embed_tokens(tokens)
Expand Down
41 changes: 24 additions & 17 deletions fairseq/optim/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,12 @@ def _get_options(self, param_group, param_shape):
def _rms(self, tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)

def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1).unsqueeze(-1)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (
exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
).rsqrt_()
c_factor = exp_avg_sq_col.rsqrt()
return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))

def step(self, closure=None):
"""Performs a single optimization step.
Expand All @@ -155,7 +157,9 @@ def step(self, closure=None):
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
grad = p.grad.data
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError('Adafactor does not support sparse gradients.')

Expand All @@ -171,22 +175,24 @@ def step(self, closure=None):
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(grad)
if factored:
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).type_as(grad)
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).type_as(grad)
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state['exp_avg_sq'] = torch.zeros_like(grad)

state['RMS'] = 0
else:
if use_first_moment:
state['exp_avg'] = state['exp_avg'].type_as(grad)
state['exp_avg'] = state['exp_avg'].to(grad)
if factored:
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].type_as(grad)
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].type_as(grad)
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
else:
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(grad)
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

p_data_fp32 = p.data.float()
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()

state['step'] += 1
state['RMS'] = self._rms(p_data_fp32)
Expand All @@ -202,15 +208,17 @@ def step(self, closure=None):
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))

# Approximation of exponential moving average of square of gradient
self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state['exp_avg_sq']

exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
torch.rsqrt(exp_avg_sq, out=update).mul_(grad)
update = exp_avg_sq.rsqrt().mul_(grad)

update.div_(max(1.0, self._rms(update) / group['clip_threshold']))
update.div_(
(self._rms(update) / group['clip_threshold']).clamp_(min=1.0)
)
update.mul_(group['lr'])

if use_first_moment:
Expand All @@ -223,8 +231,7 @@ def step(self, closure=None):

p_data_fp32.add_(-update)

# TODO: remove check once pyTorch avoids a copy for this case
if p.data_ptr() != p_data_fp32.data_ptr():
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)

return loss
Loading

0 comments on commit 7751229

Please sign in to comment.