Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support configuration of RoPE theta #351

Merged
merged 1 commit into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,9 @@ def _add_network_size_args(parser):
'This is the size of position embedding.')
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not')
group.add_argument('--rotary-position-embeddings-theta', type=int, default=10000,
help='Rotary positional embeddings theta value.',
dest='rope_theta')
group.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%')
group.add_argument('--no-position-embedding',
Expand Down
2 changes: 1 addition & 1 deletion megatron/model/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def __init__(self,
# partial rotary embeddings, which is better than full rotary
# Wang and Komatsuzaki et al
# https://github.com/kingoflolz/mesh-transformer-jax/
self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim, theta=args.rope_theta)

# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
Expand Down
4 changes: 2 additions & 2 deletions megatron/model/rotary_pos_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']

class RotaryEmbedding(nn.Module):
def __init__(self, dim):
def __init__(self, dim, theta=10000):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
if importlib.util.find_spec('einops') is None:
raise RuntimeError("einops is required for Rotary Embedding")
Expand Down
2 changes: 1 addition & 1 deletion megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def update_rotary_pos_emb(seq_length):
# partial rotary embeddings, which is better than full rotary
# Wang and Komatsuzaki et al
# https://github.com/kingoflolz/mesh-transformer-jax/
rotary_pos_emb = RotaryEmbedding(rotary_dim)(seq_length).to(
rotary_pos_emb = RotaryEmbedding(rotary_dim, theta=args.rope_theta)(seq_length).to(
get_accelerator().current_device_name())
args.rotary_pos_emb = rotary_pos_emb

Expand Down