Skip to content

Commit

Permalink
Use head_dim if in config for RoPE (#32495)
Browse files Browse the repository at this point in the history
* use head_dim if in config for RoPE

* typo

* simplify with getattr
  • Loading branch information
suiyoubi authored Aug 16, 2024
1 parent c215523 commit 5fd7ca7
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)

attention_factor = 1.0 # Unused in this type of RoPE

Expand Down Expand Up @@ -143,7 +144,8 @@ def _compute_dynamic_ntk_parameters(
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]

Expand Down Expand Up @@ -185,7 +187,8 @@ def _compute_yarn_parameters(

base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]

Expand Down Expand Up @@ -265,7 +268,8 @@ def _compute_longrope_parameters(

base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"]
factor = config.rope_scaling.get("factor")
Expand Down Expand Up @@ -450,7 +454,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)

partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)

short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
Expand Down

0 comments on commit 5fd7ca7

Please sign in to comment.