From 9f61138df3fc89c881e930a50ae4430050b05bdd Mon Sep 17 00:00:00 2001 From: Mddct Date: Thu, 29 Feb 2024 10:00:54 +0800 Subject: [PATCH] fix init rope attention and rope --- wenet/transformer/decoder.py | 54 ++++++++++++++++++++---------- wenet/transformer/decoder_layer.py | 3 +- wenet/transformer/embedding.py | 17 +++++++--- wenet/transformer/encoder.py | 16 +++++---- wenet/transformer/encoder_layer.py | 7 +++- 5 files changed, 65 insertions(+), 32 deletions(-) diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index 95584be22b..901f9f1545 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -82,17 +82,21 @@ def __init__( eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + selfattention_layer_type: str = "selfattn", ): + assert selfattention_layer_type in ['selfattn', 'rope_selfattn'] super().__init__() attention_dim = encoder_output_size activation = WENET_ACTIVATION_CLASSES[activation_type]() + pos_emb_class = WENET_EMB_CLASSES[input_layer] self.embed = torch.nn.Sequential( torch.nn.Identity() if input_layer == "no_pos" else torch.nn.Embedding(vocab_size, attention_dim), - WENET_EMB_CLASSES[input_layer](attention_dim, - positional_dropout_rate), - ) + pos_emb_class(attention_dim, positional_dropout_rate) + if input_layer != 'rope' else pos_emb_class( + attention_dim, attention_dim // + attention_heads, positional_dropout_rate)) self.normalize_before = normalize_before self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim, @@ -105,11 +109,12 @@ def __init__( else: self.output_layer = torch.nn.Identity() self.num_blocks = num_blocks + mlp_class = WENET_MLP_CLASSES[mlp_type] self.decoders = torch.nn.ModuleList([ DecoderLayer( attention_dim, - WENET_ATTENTION_CLASSES["selfattn"]( + WENET_ATTENTION_CLASSES[selfattention_layer_type]( attention_heads, attention_dim, self_attention_dropout_rate, @@ -119,7 +124,7 @@ def __init__( n_kv_head=n_kv_head, head_dim=head_dim, ), - WENET_ATTENTION_CLASSES["selfattn"]( + WENET_ATTENTION_CLASSES['selfattn']( attention_heads, attention_dim, src_attention_dropout_rate, @@ -191,12 +196,12 @@ def forward( tgt_mask = mask_to_bias(tgt_mask, tgt.dtype) memory_mask = mask_to_bias(memory_mask, memory_mask.dtype) - x, _ = self.embed(tgt) + x, pos_emb = self.embed(tgt) if self.gradient_checkpointing and self.training: x = self.forward_layers_checkpointed(x, tgt_mask, memory, - memory_mask) + memory_mask, pos_emb) else: - x = self.forward_layers(x, tgt_mask, memory, memory_mask) + x = self.forward_layers(x, tgt_mask, memory, memory_mask, pos_emb) if self.normalize_before: x = self.after_norm(x) if self.use_output_layer: @@ -204,22 +209,31 @@ def forward( olens = tgt_mask.sum(1) return x, torch.tensor(0.0), olens - def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor, - memory: torch.Tensor, - memory_mask: torch.Tensor) -> torch.Tensor: + def forward_layers( + self, + x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, + pos_emb: torch.Tensor = torch.empty(0), + ) -> torch.Tensor: for layer in self.decoders: x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, - memory_mask) + memory_mask, pos_emb) return x @torch.jit.ignore(drop=True) - def forward_layers_checkpointed(self, x: torch.Tensor, - tgt_mask: torch.Tensor, - memory: torch.Tensor, - memory_mask: torch.Tensor) -> torch.Tensor: + def forward_layers_checkpointed( + self, + x: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + 
memory_mask: torch.Tensor, + pos_emb: torch.Tensor = torch.empty(0), + ) -> torch.Tensor: for layer in self.decoders: x, tgt_mask, memory, memory_mask = ckpt.checkpoint( - layer.__call__, x, tgt_mask, memory, memory_mask) + layer.__call__, x, tgt_mask, memory, memory_mask, pos_emb) return x def forward_one_step( @@ -244,7 +258,7 @@ def forward_one_step( y, cache: NN output value and cache per `self.decoders`. y.shape` is (batch, maxlen_out, token) """ - x, _ = self.embed(tgt) + x, pos_emb = self.embed(tgt) new_cache = [] for i, decoder in enumerate(self.decoders): if cache is None: @@ -255,6 +269,7 @@ def forward_one_step( tgt_mask, memory, memory_mask, + pos_emb, cache=c) new_cache.append(x) if self.normalize_before: @@ -336,6 +351,7 @@ def __init__( eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + selfattention_layer_type: str = "selfattn", ): super().__init__() @@ -363,6 +379,7 @@ def __init__( eps=eps, n_kv_head=n_kv_head, head_dim=head_dim, + selfattention_layer_type=selfattention_layer_type, ) self.right_decoder = TransformerDecoder( @@ -388,6 +405,7 @@ def __init__( eps=eps, n_kv_head=n_kv_head, head_dim=head_dim, + selfattention_layer_type=selfattention_layer_type, ) def forward( diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 017d8dfccd..6f6ba64afd 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -69,6 +69,7 @@ def forward( tgt_mask: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor, + pos_emb: torch.Tensor = torch.empty(0), cache: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute decoded features. @@ -110,7 +111,7 @@ def forward( tgt_q_mask = tgt_mask[:, -1:, :] x = residual + self.dropout( - self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]) if not self.normalize_before: x = self.norm1(x) diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index ce53cb9a6c..4e5c412d36 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -209,15 +209,15 @@ def precompute_freqs_cis(dim: int, return freqs_cis -# copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95 +# modified from: +# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """Applies the rotary embedding to the query and key tensors.""" x_ = torch.view_as_complex( - torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1)) + torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1)) x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) - x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], - -1).transpose(1, 2) + x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1) return x_out @@ -225,13 +225,16 @@ class RopePositionalEncoding(PositionalEncoding): def __init__(self, d_model: int, + pos_dim: int, dropout_rate: float, max_len: int = 1500, rope_theta=10000.0): + # NOTE(Mddct): pos_dim == attention_dim // attention_head super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len) delattr(self, 'pe') - self.pe = precompute_freqs_cis(d_model, max_len * 2, rope_theta) + self.pe = precompute_freqs_cis(pos_dim, max_len * 2, rope_theta) self.dropout_rate = dropout_rate + self.expand = False def forward( self, 
@@ -240,7 +243,11 @@ def forward( torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]: self.pe = self.pe.to(x.device) + if not self.expand: + self.pe = self.pe.unsqueeze(0) + self.expand = True pos_emb = self.position_encoding(offset, x.size(1), False) + pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2] # NOTE(Mddct): some model don't scale # TODO(Mddct): fix x = x * self.xscale diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index acbe179e32..6abdc88dcc 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -95,13 +95,13 @@ def __init__( self._output_size = output_size self.global_cmvn = global_cmvn + pos_emb_class = WENET_EMB_CLASSES[pos_enc_layer_type] self.embed = WENET_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) + input_size, output_size, dropout_rate, + pos_emb_class(output_size, positional_dropout_rate) + if pos_enc_layer_type != 'rope' else pos_emb_class( + output_size, output_size // + attention_heads, positional_dropout_rate)) self.normalize_before = normalize_before assert layer_norm_type in ['layer_norm', 'rms_norm'] @@ -373,6 +373,7 @@ def __init__( eps: float = 1e-5, n_kv_head: Optional[int] = None, head_dim: Optional[int] = None, + selfattention_layer_type: str = "selfattn", ): """ Construct TransformerEncoder @@ -385,12 +386,13 @@ def __init__( static_chunk_size, use_dynamic_chunk, global_cmvn, use_dynamic_left_chunk, gradient_checkpointing, use_sdpa, layer_norm_type, eps) + assert selfattention_layer_type in ['selfattn', 'rope_selfattn'] activation = WENET_ACTIVATION_CLASSES[activation_type]() mlp_class = WENET_MLP_CLASSES[mlp_type] self.encoders = torch.nn.ModuleList([ TransformerEncoderLayer( output_size, - WENET_ATTENTION_CLASSES["selfattn"]( + WENET_ATTENTION_CLASSES[selfattention_layer_type]( attention_heads, output_size, attention_dropout_rate, diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index 7ed9f4249b..da89920d2d 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -95,7 +95,12 @@ def forward( residual = x if self.normalize_before: x = self.norm1(x) - x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) + x_att, new_att_cache = self.self_attn(x, + x, + x, + mask, + pos_emb, + cache=att_cache) x = residual + self.dropout(x_att) if not self.normalize_before: x = self.norm1(x)
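
For reference, a minimal self-contained sketch of the shape convention this patch settles on: precompute_freqs_cis is built from the per-head dimension (the new pos_dim argument) rather than d_model, the positional tensor handed to attention is complex-valued with shape [1, 1, seq, head_dim // 2], and apply_rotary_emb rotates q/k laid out as (batch, head, time, head_dim) without the transposes that were removed above. The sketch re-derives precompute_freqs_cis in the standard way (its body is not part of this diff) and copies the patched apply_rotary_emb; the tensor sizes and the norm check are illustrative only, not taken from the patch.

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE table: one unit-magnitude complex rotation per
    # (position, channel-pair). Shape: (end, dim // 2), complex64.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(end).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same math as the patched wenet/transformer/embedding.py: x already
    # arrives as (batch, head, time, head_dim), so no transposes are needed.
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    return x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1)


if __name__ == "__main__":
    batch, heads, time, d_model = 2, 4, 16, 256
    head_dim = d_model // heads                      # == pos_dim in the patch
    pe = precompute_freqs_cis(head_dim, 2 * time)    # (2 * time, head_dim // 2)
    pos_emb = pe[:time].unsqueeze(0).unsqueeze(1)    # [1, 1, seq, head_dim // 2]

    q = torch.randn(batch, heads, time, head_dim)
    q_rot = apply_rotary_emb(q, pos_emb)
    assert q_rot.shape == q.shape
    # RoPE only rotates channel pairs, so per-position norms are unchanged.
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)

Because the rotation is applied to q/k per head inside self-attention rather than added to the embedded input once, forward_layers, the decoder layers, and the encoder layers now have to thread pos_emb through explicitly, which is what most of the plumbing in this patch does.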
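
A sketch of how the knobs added here would be switched on. It assumes a wenet checkout that contains this change plus the companion rope attention class registered as 'rope_selfattn' and a 'rope' entry in WENET_EMB_CLASSES; neither registration is part of this diff, and all sizes below are placeholders.

from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder

encoder = TransformerEncoder(
    input_size=80,                  # e.g. fbank feature dimension
    output_size=256,
    attention_heads=4,
    # picks RopePositionalEncoding(output_size, output_size // attention_heads, ...)
    pos_enc_layer_type='rope',
    selfattention_layer_type='rope_selfattn',
)
decoder = TransformerDecoder(
    vocab_size=5000,
    encoder_output_size=256,
    attention_heads=4,
    # Embedding + rope positions in self.embed, pos_dim = attention_dim // attention_heads
    input_layer='rope',
    selfattention_layer_type='rope_selfattn',
)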