From 8fe82797461e7c59fcf55c96c2952ba7bdad00cd Mon Sep 17 00:00:00 2001
From: Arash Ashari
Date: Mon, 31 Aug 2020 17:21:53 -0700
Subject: [PATCH 1/2] update bing_bert example to use sparse transformer (#19)

* update bing_bert example to use sparse transformer

* Updated the BertSparseSelfAttention example based on the ST updates

* updated bing_bert example based on final updates for Sparse Attention; also added un/pad of Bert layer input

* updated based on Tunji's comment: added a separate script for SA

* fixed a typo

* added an exception when both transformer kernel and SA are set together.
---
 BingBertSquad/utils.py                        |  6 +-
 .../deepspeed_bsz64k_lamb_config_seq128.json  | 10 ++
 bing_bert/ds_sa_train_bert_bsz64k_seq128.sh   | 25 +++++
 bing_bert/nvidia/modelingpreln.py             | 91 +++++++++++++++++--
 bing_bert/utils.py                            |  4 +
 5 files changed, 123 insertions(+), 13 deletions(-)
 create mode 100644 bing_bert/ds_sa_train_bert_bsz64k_seq128.sh

diff --git a/BingBertSquad/utils.py b/BingBertSquad/utils.py
index 75e754627..47db485f8 100755
--- a/BingBertSquad/utils.py
+++ b/BingBertSquad/utils.py
@@ -207,11 +207,7 @@ def get_argument_parser():
                         action='store_true',
                         help='Use DeepSpeed transformer kernel to accelerate.')
 
-    parser.add_argument(
-        '--dropout',
-        type=float,
-        default=0.1,
-        help='dropout')
+    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
 
     return parser
 
diff --git a/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json b/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json
index d8f0457e2..2ae7ae5e4 100644
--- a/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json
+++ b/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json
@@ -20,5 +20,15 @@
   "fp16": {
     "enabled": true,
     "loss_scale": 0
+  },
+  "sparse_attention": {
+    "mode": "fixed",
+    "block": 16,
+    "different_layout_per_head": true,
+    "num_local_blocks": 4,
+    "num_global_blocks": 1,
+    "attention": "bidirectional",
+    "horizontal_global_attention": false,
+    "num_different_global_patterns": 4
   }
 }
diff --git a/bing_bert/ds_sa_train_bert_bsz64k_seq128.sh b/bing_bert/ds_sa_train_bert_bsz64k_seq128.sh
new file mode 100644
index 000000000..3132cc2b8
--- /dev/null
+++ b/bing_bert/ds_sa_train_bert_bsz64k_seq128.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# This script runs deepspeed using sparse attention for BertEncoderLayer.
+
+base_dir=`pwd`
+
+# Where should we save checkpoints and tensorboard events?
+JOB_NAME=lamb_64k_seq128
+OUTPUT_DIR=${base_dir}/bert_model_outputs
+
+mkdir -p $OUTPUT_DIR
+
+NCCL_TREE_THRESHOLD=0 deepspeed ${base_dir}/deepspeed_train.py \
+--cf ${base_dir}/bert_large_lamb.json \
+--max_seq_length 128 \
+--output_dir $OUTPUT_DIR \
+--deepspeed \
+--deepspeed_sparse_attention \
+--print_steps 100 \
+--lr_schedule "EE" \
+--lr_offset 10e-4 \
+--job_name $JOB_NAME \
+--deepspeed_config ${base_dir}/deepspeed_bsz64k_lamb_config_seq128.json \
+--data_path_prefix /data/bert \
+&> ${JOB_NAME}.log
diff --git a/bing_bert/nvidia/modelingpreln.py b/bing_bert/nvidia/modelingpreln.py
index d221acd0f..b91c65cce 100755
--- a/bing_bert/nvidia/modelingpreln.py
+++ b/bing_bert/nvidia/modelingpreln.py
@@ -43,6 +43,8 @@
 import torch.nn.functional as F
 import torch.nn.init as init
 
+from deepspeed.ops.sparse_attention import SparseAttentionUtils
+
 logger = logging.getLogger(__name__)
 
 PRETRAINED_MODEL_ARCHIVE_MAP = {
@@ -66,6 +68,47 @@
 TF_WEIGHTS_NAME = 'model.ckpt'
 
 
+def get_deepspeed_config(args):
+    if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
+        from deepspeed import DeepSpeedConfig
+        return DeepSpeedConfig(args.deepspeed_config)
+    else:
+        raise RuntimeError('deepspeed_config is not found in args.')
+
+
+def get_sparse_attention_config(args, num_heads):
+    if args.deepspeed_sparse_attention:
+        ds_config = get_deepspeed_config(args)
+        if hasattr(ds_config,
+                   'sparse_attention') and ds_config.sparse_attention:
+            sa_config = ds_config.sparse_attention
+            sa_mode = sa_config.get('mode')
+            if (sa_mode == 'dense'):
+                from deepspeed.ops.sparse_attention import DenseSparsityConfig as STConfig
+            elif (sa_mode == 'fixed'):
+                from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
+            elif (sa_mode == 'bigbird'):
+                from deepspeed.ops.sparse_attention import BigBirdSparsityConfig as STConfig
+            elif (sa_mode == 'bslongformer'):
+                from deepspeed.ops.sparse_attention import BSLongformerSparsityConfig as STConfig
+            elif (sa_mode == 'variable'):
+                from deepspeed.ops.sparse_attention import VariableSparsityConfig as STConfig
+            else:
+                raise NotImplementedError(
+                    f'Given sparsity mode, {sa_mode}, has not been implemented yet!'
+                )
+            del sa_config['mode']
+            return STConfig(num_heads=num_heads, **sa_config)
+        else:
+            from deepspeed.ops.sparse_attention import FixedSparsityConfig as STConfig
+            print(
+                'deepspeed sparse attention is not set; Fixed sparsity is used as default.'
+            )
+            return STConfig(num_heads=num_heads)
+    else:
+        return None
+
+
 def load_tf_weights_in_bert(model, tf_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
     """
@@ -510,20 +553,21 @@ def forward(self, hidden_states, attention_mask):
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config, args):
+    def __init__(self, config, args, sparse_attention_config=None):
         super(BertEncoder, self).__init__()
 
         #Added later to make it similar to GPT-2
         self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
 
-        if args.deepspeed_transformer_kernel:
-            from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig, DeepSpeedConfig
+        if args.deepspeed_transformer_kernel and args.deepspeed_sparse_attention:
+            raise NotImplementedError(
+                f'Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!'
+            )
 
-            if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
-                ds_config = DeepSpeedConfig(args.deepspeed_config)
-            else:
-                raise RuntimeError('deepspeed_config is not found in args.')
+        if args.deepspeed_transformer_kernel:
+            from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 
+            ds_config = self.get_deepspeed_config(args)
             cuda_config = DeepSpeedTransformerConfig(
                 batch_size=ds_config.train_micro_batch_size_per_gpu,
                 max_seq_length=args.max_seq_length,
@@ -549,6 +593,12 @@ def __init__(self, config, args):
             ])
         else:
             layer = BertLayer(config)
+            if sparse_attention_config is not None:
+                from deepspeed.ops.sparse_attention import BertSparseSelfAttention
+
+                layer.attention.self = BertSparseSelfAttention(
+                    config, sparsity_config=sparse_attention_config)
+
             self.layer = nn.ModuleList([
                 copy.deepcopy(layer) for _ in range(config.num_hidden_layers)
             ])
@@ -936,7 +986,14 @@ class BertModel(BertPreTrainedModel):
     def __init__(self, config, args=None):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, args)
+        # set pad_token_id that is used for sparse attention padding
+        self.pad_token_id = config.pad_token_id if hasattr(
+            config, 'pad_token_id') and config.pad_token_id is not None else 0
+        # set sparse_attention_config if it has been selected
+        self.sparse_attention_config = get_sparse_attention_config(
+            args, config.num_attention_heads)
+        self.encoder = BertEncoder(
+            config, args, sparse_attention_config=self.sparse_attention_config)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
         logger.info("Init BERT pretrain model")
@@ -968,6 +1025,18 @@ def forward(self,
             dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
+        # If BertEncoder uses sparse attention, it needs to be padded based on the sparse attention block size
+        if self.sparse_attention_config is not None:
+            pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = SparseAttentionUtils.pad_to_block_size(
+                block_size=self.sparse_attention_config.block,
+                input_ids=input_ids,
+                attention_mask=extended_attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=None,
+                inputs_embeds=None,
+                pad_token_id=self.pad_token_id,
+                model_mbeddings=self.embeddings)
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         encoded_layers = self.encoder(
             embedding_output,
@@ -976,6 +1045,12 @@ def forward(self,
             checkpoint_activations=checkpoint_activations)
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
+
+        # If BertEncoder uses sparse attention, and input_ids were padded, sequence output needs to be unpadded to original length
+        if self.sparse_attention_config is not None and pad_len > 0:
+            encoded_layers[-1] = SparseAttentionUtils.unpad_sequence_output(
+                pad_len, encoded_layers[-1])
+
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
         return encoded_layers, pooled_output
diff --git a/bing_bert/utils.py b/bing_bert/utils.py
index 21031b978..a1892221e 100755
--- a/bing_bert/utils.py
+++ b/bing_bert/utils.py
@@ -186,6 +186,10 @@ def get_argument_parser():
         help=
         'Use DeepSpeed transformer kernel memory optimization to checkpoint GELU activation.'
     )
+    parser.add_argument('--deepspeed_sparse_attention',
+                        default=False,
+                        action='store_true',
+                        help='Use DeepSpeed sparse self attention.')
 
     parser.add_argument('--use_nvidia_dataset',
                         default=False,

From 185c376f2fbd90c871bfe3568b16462979835bcc Mon Sep 17 00:00:00 2001
From: arashashari
Date: Tue, 1 Sep 2020 16:23:54 +0000
Subject: [PATCH 2/2] fixed an issue with last PR: removed keyword self for function call as it was moved out of class

---
 bing_bert/nvidia/modelingpreln.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bing_bert/nvidia/modelingpreln.py b/bing_bert/nvidia/modelingpreln.py
index b91c65cce..239f20ee8 100755
--- a/bing_bert/nvidia/modelingpreln.py
+++ b/bing_bert/nvidia/modelingpreln.py
@@ -567,7 +567,7 @@ def __init__(self, config, args, sparse_attention_config=None):
         if args.deepspeed_transformer_kernel:
             from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 
-            ds_config = self.get_deepspeed_config(args)
+            ds_config = get_deepspeed_config(args)
             cuda_config = DeepSpeedTransformerConfig(
                 batch_size=ds_config.train_micro_batch_size_per_gpu,
                 max_seq_length=args.max_seq_length,
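
Usage note (a minimal sketch, not part of either patch above): the new "sparse_attention" section of deepspeed_bsz64k_lamb_config_seq128.json is consumed by the get_sparse_attention_config helper added in modelingpreln.py, which uses the "mode" key to pick the sparsity-config class and passes the remaining keys through as keyword arguments. The snippet below mirrors that mapping for the "fixed" mode used in this example; num_heads=16 is an assumption (BERT-large) and should come from the model config in real use.

    # Hypothetical, standalone sketch of the config-to-object mapping described above.
    # Assumes DeepSpeed with sparse attention support is installed; num_heads=16 is assumed (BERT-large).
    from deepspeed.ops.sparse_attention import FixedSparsityConfig

    # Same keys as the "sparse_attention" JSON section above; "mode" only selects the class.
    sa_section = {
        "mode": "fixed",
        "block": 16,
        "different_layout_per_head": True,
        "num_local_blocks": 4,
        "num_global_blocks": 1,
        "attention": "bidirectional",
        "horizontal_global_attention": False,
        "num_different_global_patterns": 4,
    }

    kwargs = dict(sa_section)
    kwargs.pop("mode")  # "fixed" -> FixedSparsityConfig, mirroring get_sparse_attention_config
    sparsity_config = FixedSparsityConfig(num_heads=16, **kwargs)

    # BertModel.forward pads the input sequence to a multiple of this block size
    # (SparseAttentionUtils.pad_to_block_size) before BertSparseSelfAttention runs,
    # then removes the padding again with SparseAttentionUtils.unpad_sequence_output.
    print(sparsity_config.block)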