Add sparse attention example #36

Merged
merged 2 commits on Sep 2, 2020
6 changes: 1 addition & 5 deletions BingBertSquad/utils.py
@@ -207,11 +207,7 @@ def get_argument_parser():
action='store_true',
help='Use DeepSpeed transformer kernel to accelerate.')

parser.add_argument(
'--dropout',
type=float,
default=0.1,
help='dropout')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')

return parser

10 changes: 10 additions & 0 deletions bing_bert/deepspeed_bsz64k_lamb_config_seq128.json
@@ -20,5 +20,15 @@
"fp16": {
"enabled": true,
"loss_scale": 0
},
"sparse_attention": {
"mode": "fixed",
"block": 16,
"different_layout_per_head": true,
"num_local_blocks": 4,
"num_global_blocks": 1,
"attention": "bidirectional",
"horizontal_global_attention": false,
"num_different_global_patterns": 4
}
}
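Note: the keys in the sparse_attention section above map directly onto the keyword arguments of DeepSpeed's FixedSparsityConfig; the "mode" key only selects which sparsity class to build (see get_sparse_attention_config in bing_bert/nvidia/modelingpreln.py below). The following is a minimal standalone sketch of that mapping, assuming a DeepSpeed install with sparse attention support, BERT-large's 16 attention heads, and that make_layout(seq_len) returns the per-head block mask as in the DeepSpeed sparsity-config API:

from deepspeed.ops.sparse_attention import FixedSparsityConfig

# Mirrors the JSON section above; "mode": "fixed" selects FixedSparsityConfig.
sparsity = FixedSparsityConfig(
    num_heads=16,                       # BERT-large (bert_large_lamb.json) uses 16 heads
    block=16,                           # block-sparse tile size, in tokens
    different_layout_per_head=True,
    num_local_blocks=4,                 # each query block attends to a window of 4 local blocks
    num_global_blocks=1,                # plus 1 global block per pattern
    attention='bidirectional',
    horizontal_global_attention=False,
    num_different_global_patterns=4)

# With seq_len=128 and block=16 the layout covers 8x8 blocks per head
# (assumed make_layout behavior: a 0/1 block mask tensor).
layout = sparsity.make_layout(128)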
25 changes: 25 additions & 0 deletions bing_bert/ds_sa_train_bert_bsz64k_seq128.sh
@@ -0,0 +1,25 @@
#!/bin/bash

# This script runs deepspeed using sparse attention for BertEncoderLayer.

base_dir=`pwd`

# Where should we save checkpoints and tensorboard events?
JOB_NAME=lamb_64k_seq128
OUTPUT_DIR=${base_dir}/bert_model_outputs

mkdir -p $OUTPUT_DIR

NCCL_TREE_THRESHOLD=0 deepspeed ${base_dir}/deepspeed_train.py \
--cf ${base_dir}/bert_large_lamb.json \
--max_seq_length 128 \
--output_dir $OUTPUT_DIR \
--deepspeed \
--deepspeed_sparse_attention \
--print_steps 100 \
--lr_schedule "EE" \
--lr_offset 10e-4 \
--job_name $JOB_NAME \
--deepspeed_config ${base_dir}/deepspeed_bsz64k_lamb_config_seq128.json \
--data_path_prefix /data/bert \
&> ${JOB_NAME}.log
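As a sanity check before launching, the block size in the DeepSpeed config should tile the 128-token sequences this script trains on; otherwise BertModel.forward pads each batch up to the next block boundary. A small hypothetical check (not part of this PR), reading the same JSON file the script passes via --deepspeed_config:

import json

with open('deepspeed_bsz64k_lamb_config_seq128.json') as f:
    ds_config = json.load(f)

block = ds_config['sparse_attention']['block']   # 16 in this example
max_seq_length = 128                             # matches --max_seq_length above

# 128 / 16 = 8 blocks per sequence; any remainder would be padded by
# SparseAttentionUtils.pad_to_block_size inside BertModel.forward.
print(f'{max_seq_length // block} blocks per sequence, '
      f'{max_seq_length % block} padding tokens needed')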
91 changes: 83 additions & 8 deletions bing_bert/nvidia/modelingpreln.py
@@ -43,6 +43,8 @@
import torch.nn.functional as F
import torch.nn.init as init

from deepspeed.ops.sparse_attention import SparseAttentionUtils

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
@@ -66,6 +68,47 @@
TF_WEIGHTS_NAME = 'model.ckpt'


def get_deepspeed_config(args):
if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
from deepspeed import DeepSpeedConfig
return DeepSpeedConfig(args.deepspeed_config)
else:
raise RuntimeError('deepspeed_config is not found in args.')


def get_sparse_attention_config(args, num_heads):
if args.deepspeed_sparse_attention:
ds_config = get_deepspeed_config(args)
if hasattr(ds_config,
'sparse_attention') and ds_config.sparse_attention:
sa_config = ds_config.sparse_attention
sa_mode = sa_config.get('mode')
if (sa_mode == 'dense'):
from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig
elif (sa_mode == 'fixed'):
from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
elif (sa_mode == 'bigbird'):
from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig
elif (sa_mode == 'bslongformer'):
from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig
elif (sa_mode == 'variable'):
from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig
else:
raise NotImplementedError(
f'Given sparsity mode, {sa_mode}, has not been implemented yet!'
)
del sa_config['mode']
return STConfig(num_heads=num_heads, **sa_config)
else:
from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
print(
'deepspeed sparse attention is not set; Fixed sparsity is used as default.'
)
return STConfig(num_heads=num_heads)
else:
return None


def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
@@ -510,20 +553,21 @@ def forward(self, hidden_states, attention_mask):


class BertEncoder(nn.Module):
def __init__(self, config, args):
def __init__(self, config, args, sparse_attention_config=None):
super(BertEncoder, self).__init__()

#Added later to make it similar to GPT-2
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

if args.deepspeed_transformer_kernel:
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig, DeepSpeedConfig
if args.deepspeed_transformer_kernel and args.deepspeed_sparse_attention:
raise NotImplementedError(
f'Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!'
)

if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
ds_config = DeepSpeedConfig(args.deepspeed_config)
else:
raise RuntimeError('deepspeed_config is not found in args.')
if args.deepspeed_transformer_kernel:
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

ds_config = get_deepspeed_config(args)
cuda_config = DeepSpeedTransformerConfig(
batch_size=ds_config.train_micro_batch_size_per_gpu,
max_seq_length=args.max_seq_length,
@@ -549,6 +593,12 @@ def __init__(self, config, args):
])
else:
layer = BertLayer(config)
if sparse_attention_config is not None:
from deepspeed.ops.sparse_attention import BertSparseSelfAttention

layer.attention.self = BertSparseSelfAttention(
config, sparsity_config=sparse_attention_config)

self.layer = nn.ModuleList([
copy.deepcopy(layer) for _ in range(config.num_hidden_layers)
])
@@ -936,7 +986,14 @@ class BertModel(BertPreTrainedModel):
def __init__(self, config, args=None):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config, args)
# set pad_token_id that is used for sparse attention padding
self.pad_token_id = config.pad_token_id if hasattr(
config, 'pad_token_id') and config.pad_token_id is not None else 0
# set sparse_attention_config if it has been selected
self.sparse_attention_config = get_sparse_attention_config(
args, config.num_attention_heads)
self.encoder = BertEncoder(
config, args, sparse_attention_config=self.sparse_attention_config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
logger.info("Init BERT pretrain model")
@@ -968,6 +1025,18 @@ def forward(self,
dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

# If BertEncoder uses sparse attention, it needs to be padded based on the sparse attention block size
if self.sparse_attention_config is not None:
pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = SparseAttentionUtils.pad_to_block_size(
block_size=self.sparse_attention_config.block,
input_ids=input_ids,
attention_mask=extended_attention_mask,
token_type_ids=token_type_ids,
position_ids=None,
inputs_embeds=None,
pad_token_id=self.pad_token_id,
model_embeddings=self.embeddings)

embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(
embedding_output,
@@ -976,6 +1045,12 @@ def forward(self,
checkpoint_activations=checkpoint_activations)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)

# If BertEncoder uses sparse attention, and input_ids were padded, sequence output needs to be unpadded to original length
if self.sparse_attention_config is not None and pad_len > 0:
encoded_layers[-1] = SparseAttentionUtils.unpad_sequence_output(
pad_len, encoded_layers[-1])

if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
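Putting the pieces above together, here is a rough single-GPU sketch of the intended flow: the --deepspeed_sparse_attention flag plus the JSON config select a sparsity config, BertEncoder swaps BertSparseSelfAttention into each layer, and forward pads and unpads inputs around the sparse block size. This is a hypothetical driver, not code from this PR; the import path, BertConfig arguments, and fp16 requirement are assumptions, and the sparse kernels need a CUDA device:

import argparse
import torch
from nvidia.modelingpreln import BertConfig, BertModel  # assumed path inside bing_bert/

args = argparse.Namespace(
    deepspeed_sparse_attention=True,      # flag added in bing_bert/utils.py below
    deepspeed_config='deepspeed_bsz64k_lamb_config_seq128.json',
    deepspeed_transformer_kernel=False,   # combining it with sparse attention is rejected above
    max_seq_length=128)

config = BertConfig(vocab_size_or_config_json_file=30522)  # assumed constructor signature
model = BertModel(config, args).cuda().half()  # sparse kernels are assumed to expect fp16

# 120 tokens is not a multiple of block=16, so forward() pads to 128 internally
# and unpads the returned sequence output.
input_ids = torch.randint(0, 30522, (2, 120)).cuda()
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)
sequence_output, pooled_output = model(input_ids,
                                       token_type_ids=token_type_ids,
                                       attention_mask=attention_mask,
                                       output_all_encoded_layers=False)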
4 changes: 4 additions & 0 deletions bing_bert/utils.py
@@ -186,6 +186,10 @@ def get_argument_parser():
help=
'Use DeepSpeed transformer kernel memory optimization to checkpoint GELU activation.'
)
parser.add_argument('--deepspeed_sparse_attention',
default=False,
action='store_true',
help='Use DeepSpeed sparse self attention.')

parser.add_argument('--use_nvidia_dataset',
default=False,