diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index f468617f7..0e7dc9557 100755 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -119,7 +119,7 @@ def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device self.dim = dim self.max_position_embeddings = max_position_embeddings if config.use_long_context: - a = 16 #Alpha value + a = 8 #Alpha value base = base * a ** (dim / (dim-2)) #Base change formula self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))