From 635f816b5bb2403620c2f946a75e224abcb864a3 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 6 Apr 2021 21:54:15 +0530 Subject: [PATCH] [WIP] GPT Neo cleanup (#10985) * better names * add attention mixin * all slow tests in one class * make helper methods static so we can test * add local attention tests * better names * doc * apply review suggestions --- .../models/gpt_neo/modeling_gpt_neo.py | 417 ++++++++++-------- tests/test_modeling_gpt_neo.py | 191 ++++++-- 2 files changed, 392 insertions(+), 216 deletions(-) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 9fb0d7475fb9d6..72ccaf15e86638 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -130,7 +130,130 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): return model -class GPTNeoSelfAttention(nn.Module): +class GPTNeoAttentionMixin: + """ + A few attention related utilities for attention modules in GPT Neo, to be used as a mixin. + """ + + @staticmethod + def _get_block_length_and_num_blocks(seq_length, window_size): + """ + Computes ``block_length`` and ``num_blocks`` such that ``seq_length`` becomes evenly divisible by + ``block_length``. + """ + block_length = window_size + while seq_length % block_length != 0: + block_length -= 1 + num_blocks = seq_length // block_length + return block_length, num_blocks + + @staticmethod + def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True): + """ + Used to implement attention between consecutive blocks. This method assumes that dim 1 of :obj:`tensor` + represents the :obj:`seq_length` dimention. It splits :obj:`seq_length` dimention into :obj:`num_blocks` and + :obj:`window_size` + :obj:`block_length`. It pads the :obj:`seq_length` dimention if necessary. + + Example:: + + tensor: torch.tensor([[[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]) + with shape (1, 8, 1) + block_length = window_size = 4 + _look_back => + torch.tensor([[[[ 0.0000], [ 0.0000], [ 0.0000], [ 0.0000], [ 0.4983], [ 2.6918], [-0.0071], [ 1.0492]], + [[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]]) + + Args: + tensor (:obj:`torch.Tensor`): tensor of shape :obj:`[batch_size, seq_length, hidden_dim]` or :obj:`[batch_size, seq_length]` + block_length (:obj:`int`): An integer specifying the length of each block, used as a step size when creating the blocks. + window_size (:obj:`int`): An integer specifying the size of attention window, used to calculate the final block size when creating the block. + pad_value (obj:`int`): An integer specifying the value to use when padding the :obj:`tensor`. + is_key_value (:obj:`bool`): A boolean indicating if the :obj:`tensor` is a key/value tensor. + + Returns: + tensor of shape :obj:`[batch_size, num_blocks, window_size + block_length, ...]` if :obj:`is_key_value` is + :obj:`True` else a tensor of shape :obj:`[batch_size, window_size + block_length, num_blocks, ...]` + """ + if len(tensor.shape) == 3: + padding_side = (0, 0, window_size, 0) + elif len(tensor.shape) == 2: + padding_side = (window_size, 0) + else: + raise ValueError(f"Input tensor rank should be one of [2, 3], but is: {len(tensor.shape)}") + + padded_tensor = F.pad(tensor, padding_side, value=pad_value) + padded_tensor = padded_tensor.unfold(dimension=1, size=window_size + block_length, step=block_length) + + if is_key_value: + padded_tensor = padded_tensor.transpose(-2, -1) + return padded_tensor + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(*new_shape) + if len(tensor.shape) == 5: + return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) + elif len(tensor.shape) == 4: + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size): + """ + Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims + """ + batch_size = tensors.shape[0] + split_dim_shape = (batch_size, dim_factor_1, dim_factor_2) + + if len(tensors.shape) == 3: + return torch.reshape(tensors, split_dim_shape + (hidden_size,)) + elif len(tensors.shape) == 2: + return torch.reshape(tensors, split_dim_shape) + else: + raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}") + + def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None): + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(value.dtype) + attn_weights = attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + +class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin): def __init__(self, config): super().__init__() @@ -149,56 +272,16 @@ def __init__(self, config): self.embed_dim = config.hidden_size self.num_heads = config.num_heads self.head_dim = self.embed_dim // self.num_heads - assert ( - self.head_dim * self.num_heads == self.embed_dim - ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + ) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) - def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): - # Keep the attention weights computation in fp32 to avoid overflow issues - q = q.to(torch.float32) - k = k.to(torch.float32) - - attn_weights = torch.matmul(q, k) - nd, ns = attn_weights.size(-2), attn_weights.size(-1) - - mask = self.bias[:, :, ns - nd : ns, :ns] - attn_weights = torch.where(mask.bool(), attn_weights, self.masked_bias.to(attn_weights.dtype)) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.Softmax(dim=-1)(attn_weights) - attn_weights = attn_weights.to(v.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - outputs = (torch.matmul(attn_weights, v),) - if output_attentions: - outputs += (attn_weights,) - return outputs - - def merge_heads(self, x): - x = x.permute(0, 2, 1, 3).contiguous() - new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) - return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states - - def split_heads(self, x, k=False): - new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads) - x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states - if k: - return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) - else: - return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - def forward( self, hidden_states, @@ -213,31 +296,40 @@ def forward( key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - query = self.split_heads(query) - key = self.split_heads(key, k=True) - value = self.split_heads(value) + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: - past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below - key = torch.cat((past_key, key), dim=-1) + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) value = torch.cat((past_value, value), dim=-2) if use_cache is True: - present = (key.transpose(-2, -1), value) # transpose to have same shapes + present = (key, value) else: present = None - attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) - a = attn_outputs[0] + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + + attn_output, attn_weights = self._attn( + query, key, value, causal_mask, self.masked_bias, self.attn_dropout, attention_mask, head_mask + ) - a = self.merge_heads(a) - a = self.out_proj(a) - a = self.resid_dropout(a) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) - return (a, present) + attn_outputs[1:] # a, present, (attentions) + return outputs # a, present, (attentions) -class GPTNeoLocalSelfAttention(nn.Module): +class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin): def __init__(self, config): super().__init__() @@ -249,9 +341,10 @@ def __init__(self, config): self.embed_dim = config.hidden_size self.num_heads = config.num_heads self.head_dim = self.embed_dim // self.num_heads - assert ( - self.head_dim * self.num_heads == self.embed_dim - ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + ) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) @@ -260,94 +353,39 @@ def __init__(self, config): self.window_size = config.window_size - def shift(self, x, offset, pad_value=0, dim=2): - t = x.shape[1] - dims = (len(x.shape) - dim) * (0, 0) - padded_x = F.pad(x, (*dims, offset, 0), value=pad_value) - return padded_x[:, :t, ...] - - def look_around(self, x, block_length, window_size): - num_complete_blocks = window_size // block_length - - parts = [x] - for i in range(1, num_complete_blocks + 1): - parts = [self.shift(x, i)] + parts - - partial_size = window_size % block_length - if partial_size > 0: - margin = x[:, :, block_length - partial_size : block_length, ...] - parts = [self.shift(margin, num_complete_blocks + 1)] + parts - return torch.cat(parts, dim=2) - - def split_heads(self, x, k=False): - new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads) - x = x.view(*new_x_shape) - if k: - return x.permute(0, 1, 3, 4, 2) # (batch, chunks, head, head_features, seq_length) - else: - return x.permute(0, 1, 3, 2, 4) # (batch, chunks, head, seq_length, head_features) - - def merge_heads(self, x): - x = x.permute(0, 1, 3, 2, 4).contiguous() - new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) - return x.view(*new_x_shape) + def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None): + indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1) - def _split_seq_length_dim_to(self, tensors, num_blocks, block_length): - return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1) + query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim) + key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False) - def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask): - ticker = torch.arange(seq_len)[None, :] - b_t = ticker.reshape(1, windows, block_length) + # create mask tensor such that each block contains a causal_mask for that block + causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)) - bq_t = b_t - bq_k = self.look_around(b_t, block_length, self.window_size) + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device) - # compute attn mask - # this matches the original implem in mess-tensorflow - # https://github.com/tensorflow/mesh/blob/8bd599a21bad01cef1300a8735c17306ce35db6e/mesh_tensorflow/transformer/attention.py#L805 - relative_position = bq_k.unsqueeze(-2) - bq_t.unsqueeze(-1) - relative_position = relative_position.transpose(-1, -2) + # A block can also be padded becuase of the _look_back operation + # look back into the attention_block such that it will also get padded the same way + # and have 0s in the padded position + attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False) + attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimention to account for hidden_dim - sequence_id = torch.ones(bs, seq_len) - q_seq = sequence_id.reshape(-1, windows, block_length) - m_seq = sequence_id.reshape(-1, windows, block_length) - m_seq = self.look_around(m_seq, block_length, self.window_size) + # Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation) + # will contain 0s. + # This also makes sure that other positions ignored by the attention_mask will also be ignored + # in the causal_mask. + causal_mask = causal_mask * attention_mask - if attention_mask is not None: - attention_mask = attention_mask.to(m_seq.device) - attention_mask = attention_mask.reshape(-1, windows, block_length) - attention_mask = self.look_around(attention_mask, block_length, self.window_size) - m_seq *= attention_mask + # In GPT Neo's local attention each window can attend to at most window_size tokens + # rest of the tokens should be ignored. + relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1) + visible = torch.gt(relative_position, -self.window_size) - visible = torch.eq(q_seq.unsqueeze(-1), m_seq.unsqueeze(-2)).transpose(-1, -2) - visible = torch.logical_and(visible, torch.gt(relative_position, -self.window_size)) - mask = torch.logical_and(visible, torch.less_equal(relative_position, 0)).transpose(-1, -2).unsqueeze(2) - return mask + causal_mask = causal_mask * visible + causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimention to account for num_heads - def _attn(self, q, k, v, causal_mask, head_mask=None, output_attentions=False): - # attn - - # Keep the attention weights computation in fp32 to avoid overflow issues - q = q.to(torch.float32) - k = k.to(torch.float32) - - attn_weights = torch.matmul(q, k) - attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) - - attn_weights = nn.Softmax(dim=-1)(attn_weights) - attn_weights = attn_weights.to(v.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, v) - - outputs = (attn_output,) - if output_attentions: - outputs += (attn_weights,) - return outputs + return causal_mask def forward( self, @@ -371,51 +409,58 @@ def forward( key = self.k_proj(key_value_hidden_states) value = self.v_proj(key_value_hidden_states) - # compute block length and windows - bs, seq_len = hidden_states.shape[:2] - full_seq_length = seq_len + past_length - block_length = self.window_size - while full_seq_length % block_length != 0: - block_length -= 1 - num_blocks = full_seq_length // block_length + # compute block length and num_blocks + batch_size, seq_length = hidden_states.shape[:2] + full_seq_length = seq_length + past_length + block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size) # create buckets if layer_past is not None: - # we just need 1 window with block_length 1 when caching is enabled - query = self._split_seq_length_dim_to(query, 1, 1) + # we just need 1 block with block_length 1 when caching is enabled + query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim) else: - query = self._split_seq_length_dim_to(query, num_blocks, block_length) - - key = self._split_seq_length_dim_to(key, num_blocks, block_length) - value = self._split_seq_length_dim_to(value, num_blocks, block_length) + query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim) - key = self.look_around(key, block_length, self.window_size) - value = self.look_around(value, block_length, self.window_size) + key = self._look_back(key, block_length, self.window_size) + value = self._look_back(value, block_length, self.window_size) - # select key/value vectors only for the last window + # select key/value vectors only for the last block if layer_past is not None: key = key[:, -1:, ...] value = value[:, -1:, ...] - query = self.split_heads(query) - key = self.split_heads(key, k=True) - value = self.split_heads(value) + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) - mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask) + mask = self._create_attention_mask( + batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask + ) if layer_past is not None: - mask = mask[:, -1:, :, -1:, :] # only take the mask for the last window - mask = mask.to(hidden_states.device) + mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block # attn - attn_outputs = self._attn(query, key, value, mask, head_mask, output_attentions) - attn = attn_outputs[0] + attn_output, attn_weights = self._attn( + query, + key, + value, + causal_mask=mask, + masked_bias=self.masked_bias, + attn_dropout=self.attn_dropout, + head_mask=head_mask, + ) - attn = self.merge_heads(attn) - attn = attn.reshape(bs, seq_len, self.embed_dim) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim) - attn = self.out_proj(attn) - attn = self.resid_dropout(attn) - return (attn,) + attn_outputs[1:] + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output,) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, (attentions) class GPTNeoAttention(nn.Module): @@ -464,7 +509,7 @@ def forward( return outputs -class MLP(nn.Module): +class GPTNeoMLP(nn.Module): def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size super().__init__() embed_dim = config.hidden_size @@ -473,13 +518,15 @@ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_dropout) - def forward(self, x): - h = self.act(self.c_fc(x)) - h2 = self.c_proj(h) - return self.dropout(h2) + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states -class Block(nn.Module): +class GPTNeoBlock(nn.Module): def __init__(self, config, layer_id): super().__init__() hidden_size = config.hidden_size @@ -487,7 +534,7 @@ def __init__(self, config, layer_id): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPTNeoAttention(config, layer_id) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = MLP(inner_dim, config) + self.mlp = GPTNeoMLP(inner_dim, config) def forward( self, @@ -498,8 +545,10 @@ def forward( use_cache=False, output_attentions=False, ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( - self.ln_1(hidden_states), + hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, @@ -509,11 +558,13 @@ def forward( attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] # residual connection - hidden_states = attn_output + hidden_states + hidden_states = attn_output + residual - feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) # residual connection - hidden_states = hidden_states + feed_forward_hidden_states + hidden_states = residual + feed_forward_hidden_states if use_cache: outputs = (hidden_states,) + outputs @@ -638,7 +689,7 @@ def _init_weights(self, module): @add_start_docstrings( - "The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.", + "The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.", GPT_NEO_START_DOCSTRING, ) class GPTNeoModel(GPTNeoPreTrainedModel): @@ -649,7 +700,7 @@ def __init__(self, config): self.wte = nn.Embedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.drop = nn.Dropout(config.embed_dropout) - self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.num_layers)]) + self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.init_weights() diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index 023a9d265edfdb..14d966d61b4bce 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -18,6 +18,7 @@ import unittest from transformers import is_torch_available +from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester @@ -35,6 +36,7 @@ GPTNeoForCausalLM, GPTNeoModel, ) + from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention class GPTNeoModelTester: @@ -430,11 +432,164 @@ def _check_attentions_for_generate( # check attn size self.assertListEqual(shapes, expected_shape) + +@require_torch +class GPTNeoLocalAttentionTest(unittest.TestCase): + def _get_hidden_states(self): + return torch.tensor( + [ + [ + [0.4983, -0.7584, -1.6944, 0.5440], + [2.6918, 0.4206, 0.4176, 0.2055], + [-0.0071, -0.0405, -1.4920, -0.3630], + [1.0492, 0.1599, -1.7648, 0.2419], + [-1.8348, 2.0514, -0.1946, 0.3203], + [0.7672, -1.1600, -1.7118, -0.9056], + [0.2986, 0.5372, 0.7729, -0.1927], + [0.0285, 0.2629, -1.1156, -1.1992], + ] + ], + dtype=torch.float32, + device=torch_device, + ) + + def test_look_back(self): + hidden_states = self._get_hidden_states() + batch_size, seq_length, hidden_size = hidden_states.shape + + # check when seq_length is divisible by window_size + window_size = 4 + block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) + blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size) + expected_shape = [batch_size, num_block, window_size + block_length, hidden_size] + self.assertListEqual(list(blocked_hidden_states.shape), expected_shape) + # The last block should contain the last (window_size + block_length) hidden_states + self.assertTrue( + torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...]) + ) + + # check when seq_length is not divisible by window_size + window_size = 3 + block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) + blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size) + expected_shape = [batch_size, num_block, window_size + block_length, hidden_size] + self.assertListEqual(list(blocked_hidden_states.shape), expected_shape) + # The last block should contain the last (window_size + block_length) hidden_states + self.assertTrue( + torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...]) + ) + + # check when window_size is > seq_length + window_size = 19 + block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) + blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size) + expected_shape = [batch_size, num_block, window_size + block_length, hidden_size] + self.assertListEqual(list(blocked_hidden_states.shape), expected_shape) + + # when window_size > seq_length, num_blocks becomes 1, in this case + # the first window_size values in blocked_hidden_staes are all zeros + # and the last block_length values are equal to the hidden_states + values = blocked_hidden_states[:, -1, :window_size, ...] + expected_values = torch.zeros_like(values) + self.assertTrue(torch.all(values == expected_values)) + + self.assertTrue(torch.all(blocked_hidden_states[:, -1, -block_length:, ...] == hidden_states)) + + def test_create_attention_mask(self): + config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny") + layer = GPTNeoLocalSelfAttention(config) + window_size = config.window_size + batch_size, seq_length = 8, 1 + block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size) + + causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device) + # check shapes + expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length] + self.assertListEqual(list(causal_mask.shape), expected_shape) + # first window_size tokens in the first block are always padded + # and should not be attended + self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0)) + # each window can attend at most window_size tokens + self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size)) + + # check if user provided attention_mask is handled correctly + attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device) + attention_mask[:, -3:] = 0 # don't attend last 3 tokens + + causal_mask = layer._create_attention_mask( + batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask + ) + # last 3 tokens will be in the last block and shoul have 0s in causal_mask + self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0)) + # check shapes + expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length] + self.assertListEqual(list(causal_mask.shape), expected_shape) + # first window_size tokens in the first block are always padded + # and should not be attended + self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0)) + # each window can attend at most window_size tokens + self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size)) + + def test_local_attn_probs(self): + model = GPTNeoModel.from_pretrained("valhalla/gpt-neo-random-tiny").eval() + layer = model.h[1].attn.attention.to(torch_device) + hidden_states = self._get_hidden_states() + hidden_states = torch.cat([hidden_states, hidden_states - 0.5], dim=2) + batch_size, seq_length, hidden_size = hidden_states.shape + mask_tokens = 3 + attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long) + attention_mask[:, -mask_tokens:] = 0 # dont atten last mask_tokens + + _, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True) + + # the last 3 tokens will be in the last block, and should have 0 attn_probs + self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0)) + # the first config.window_size tokens in the first block are always padded + # and should have 0 attn_probs + self.assertTrue(torch.all(attn_probs[:, 0, :, : model.config.window_size :, : model.config.window_size] == 0)) + + +@require_torch +class GPTNeoModelLanguageGenerationTest(unittest.TestCase): + @cached_property + def model(self): + return GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(torch_device) + + @cached_property + def tokenizer(self): + return GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") + + @slow + def test_lm_generate_gpt_neo(self): + for checkpointing in [True, False]: + model = self.model + model.config.gradient_checkpointing = checkpointing + input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog + # fmt: off + # The dog-eared copy of the book, which is a collection of essays by the late author, + expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11] + # fmt: on + output_ids = model.generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + + @slow + def test_gpt_neo_sample(self): + model = self.model + tokenizer = self.tokenizer + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + output_ids = model.generate(input_ids, do_sample=True) + output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can" + self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + @slow def test_batch_generation(self): - model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") - model.to(torch_device) - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = self.model + tokenizer = self.tokenizer tokenizer.padding_side = "left" @@ -479,33 +634,3 @@ def test_model_from_pretrained(self): for model_name in GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = GPTNeoModel.from_pretrained(model_name) self.assertIsNotNone(model) - - -@require_torch -class GPTNeoModelLanguageGenerationTest(unittest.TestCase): - @slow - def test_lm_generate_gpt_neo(self): - for checkpointing in [True, False]: - model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", gradient_checkpointing=checkpointing) - model.to(torch_device) - input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog - # fmt: off - expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11] # The dog-eared copy of the book, which is a collection of essays by the late author, - # fmt: on - output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids) - - @slow - def test_gpt_neo_sample(self): - tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") - model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") - model.to(torch_device) - - torch.manual_seed(0) - tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) - input_ids = tokenized.input_ids.to(torch_device) - output_ids = model.generate(input_ids, do_sample=True) - output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - - EXPECTED_OUTPUT_STR = "Today is a nice day and if you don’t get the memo here is what you can" - self.assertEqual(output_str, EXPECTED_OUTPUT_STR)