Commit

pglorio committed Oct 18, 2024
2 parents 3969552 + 4d38bb0 commit 7593823
Showing 2 changed files with 17 additions and 2 deletions.
src/transformers/models/zamba2/configuration_zamba2.py (4 additions, 0 deletions)
@@ -145,6 +145,7 @@ def __init__(
         eos_token_id=2,
         ft_lora = False,
+        use_long_context=False,
         **kwargs,
     ):

@@ -170,6 +171,9 @@ def __init__(
         self.use_shared_block_lora = use_shared_block_lora
         self.use_shared_attention_lora = use_shared_attention_lora
         self.lora_rank = lora_rank
+        self.use_long_context = use_long_context
+        if use_long_context:
+            self.max_position_embeddings = 16384

         # for backward compatibility
         if num_key_value_heads is None:
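
A quick usage sketch of the new flag, not part of the commit (a minimal example assuming `Zamba2Config` can be built from its default arguments; the import path simply mirrors the file above): passing `use_long_context=True` forces `max_position_embeddings` to 16384, while omitting it leaves the default in place.

# Minimal sketch, not part of the commit: exercises the `use_long_context`
# flag added to Zamba2Config above. Assumes the config class can be
# instantiated with its default arguments.
from transformers.models.zamba2.configuration_zamba2 import Zamba2Config

default_cfg = Zamba2Config()                    # max_position_embeddings keeps its default
long_cfg = Zamba2Config(use_long_context=True)  # overridden to 16384 by the new branch

print(default_cfg.max_position_embeddings)
print(long_cfg.max_position_embeddings)         # 16384
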
src/transformers/models/zamba2/modeling_zamba2.py (13 additions, 2 deletions)
@@ -114,14 +114,17 @@ def forward(self, hidden_states):


 class Zamba2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=4096, base=10000, device=None):
+    def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None):
         super().__init__()
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
+        if config.use_long_context:
+            a = 8  # Alpha value
+            base = base * a ** (dim / (dim - 2))  # Base change formula
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)


     @torch.no_grad()
     # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
@@ -346,11 +349,14 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_me

         if config.use_mem_rope:
             self.rotary_emb = Zamba2RotaryEmbedding(
+                config,
                 self.head_dim,
                 max_position_embeddings=config.max_position_embeddings,
                 base=self.rope_theta,
             )



     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1131,6 +1137,11 @@ def __init__(self, config: Zamba2Config):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+
+        if config.use_long_context:
+            logger.warning_once(
+                f"`use_long_context` has been set to True, therefore `max_position_embeddings` will be set to 16384."
+            )

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)])
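
The modeling change applies a base change (often described as NTK-aware scaling) to the rotary embedding: when `use_long_context` is enabled, the RoPE base is multiplied by alpha ** (dim / (dim - 2)) with a hard-coded alpha of 8, which lowers the rotation frequencies so that the extended 16384-token window can be covered. Below is a standalone sketch of just that arithmetic, not part of the commit; the `dim` and `base` values are illustrative.

# Illustrative sketch of the base change formula used in Zamba2RotaryEmbedding
# when config.use_long_context is True. The dim and base values are examples,
# not taken from a released checkpoint.
import torch

dim = 128        # per-head rotary dimension (example value)
base = 10000.0   # standard RoPE theta (example value)
alpha = 8        # hard-coded alpha value from the commit

scaled_base = base * alpha ** (dim / (dim - 2))  # base change formula
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

print(f"scaled base: {scaled_base:,.0f}")            # roughly 8x the original base
print(f"lowest inverse frequency: {inv_freq[-1].item():.3e}")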
