diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 14a12b9394059d..bee0269ff82e9a 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -354,6 +354,11 @@ def _compute_llama3_parameters(
 
 def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
+    # BC: "rope_type" was originally "type" -- let's gracefully handle it
+    if "rope_type" not in received_keys and "type" in received_keys:
+        received_keys -= {"type"}
+        received_keys.add("rope_type")
+
     missing_keys = required_keys - received_keys
     if missing_keys:
         raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@@ -361,14 +366,14 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
     if optional_keys is not None:
         unused_keys = received_keys - required_keys - optional_keys
     else:
-        unused_keys = received_keys - received_keys
+        unused_keys = received_keys - required_keys
     if unused_keys:
-        raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
 
 
 def _validate_default_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)
@@ -376,19 +381,19 @@ def _validate_default_rope_parameters(config: PretrainedConfig):
 
 def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
@@ -397,12 +402,12 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
     received_keys = set(rope_scaling.keys())
@@ -410,22 +415,22 @@ def _validate_yarn_parameters(config: PretrainedConfig):
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     attention_factor = rope_scaling.get("attention_factor")
     if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
-        raise ValueError(
+        logger.warning(
             f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
         )
     beta_fast = rope_scaling.get("beta_fast")
     if beta_fast is not None and not isinstance(beta_fast, float):
-        raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
     beta_slow = rope_scaling.get("beta_slow")
     if beta_slow is not None and not isinstance(beta_slow, float):
-        raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
 
     if (beta_fast or 32) < (beta_slow or 1):
-        raise ValueError(
+        logger.warning(
             f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
@@ -433,7 +438,7 @@ def _validate_yarn_parameters(config: PretrainedConfig):
 
 def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
@@ -445,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):
 
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
-        raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
     if not len(short_factor) == dim // 2:
-        raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
 
     long_factor = rope_scaling.get("long_factor")
     if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
-        raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
     if not len(long_factor) == dim // 2:
-        raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
 
     # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
     # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@@ -468,48 +473,48 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     else:
         factor = rope_scaling.get("factor")
         if factor is None:
-            raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
+            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
         elif not isinstance(factor, float) or factor < 1.0:
-            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
         if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            raise ValueError(
+            logger.warning(
                 f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
             )
 
 
 def _validate_llama3_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
-    rope_type = rope_scaling["rope_type"]
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     low_freq_factor = rope_scaling["low_freq_factor"]
     high_freq_factor = rope_scaling["high_freq_factor"]
     if low_freq_factor is None or not isinstance(low_freq_factor, float):
-        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
-        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
     if high_freq_factor < low_freq_factor:
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
         )
 
     original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
     if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
             f"{original_max_position_embeddings}"
         )
     if original_max_position_embeddings >= config.max_position_embeddings:
-        raise ValueError(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
             f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
         )
@@ -534,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
     if rope_scaling is None:
         return
 
-    possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
-    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
-    if rope_type is None:
-        raise ValueError(
-            f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
-        )
-
+    # BC: "rope_type" was originally "type"
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
         validation_fn(config)
     else:
-        raise ValueError(
+        logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
        )
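The net effect of the `modeling_rope_utils.py` changes above: the per-type validators fall back to the legacy `type` key when `rope_type` is absent, and most out-of-range values are downgraded from exceptions to logged warnings. A minimal sketch of the resulting behavior, not part of the patch itself; `SimpleNamespace` stands in here for a real `PretrainedConfig`:

```python
# Sketch only: exercises the patched validators directly, assuming this patch is applied.
from types import SimpleNamespace

from transformers.modeling_rope_utils import rope_config_validation

# Legacy "type" key plus an out-of-range "factor": previously a ValueError,
# now resolved via the BC fallback and reported as a logged warning instead.
config = SimpleNamespace(rope_scaling={"type": "linear", "factor": 0.5})
rope_config_validation(config)  # warns about the factor field, does not raise
```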
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index c632a870be7a18..710809093f3849 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -189,6 +189,9 @@ def __init__(
         self.mlp_bias = mlp_bias
 
         # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
 
         super().__init__(
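A hypothetical quick check of the config-level shim above: passing a legacy `rope_scaling` dict through `LlamaConfig` mirrors `type` into `rope_type`, so both keys coexist afterwards (the leftover `type` key may still trigger the "Unrecognized keys" warning from `_check_received_keys`):

```python
# Sketch only, assuming this patch is applied.
from transformers import LlamaConfig

config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})
# The legacy key is mirrored rather than popped, so both spellings survive.
assert config.rope_scaling["rope_type"] == "linear"
assert config.rope_scaling["type"] == "linear"
```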
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 85d352fc814f6f..19cb7bd6be393c 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -526,6 +526,60 @@ def test_rope_class_retrocompatibility(self):
         torch.testing.assert_close(old_cos_long, new_cos_long)
         torch.testing.assert_close(old_sin_long, new_sin_long)
 
+    def test_model_loading_old_rope_configs(self):
+        def _reinitialize_config(base_config, new_kwargs):
+            # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
+            # steps.
+            base_config_dict = base_config.to_dict()
+            new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
+            return new_config
+
+        # from untouched config -> ✅
+        base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        original_model = LlamaForCausalLM(base_config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with the expected rope configuration -> ✅
+        config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
+        config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
+        config = _reinitialize_config(
+            base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
+        )
+        self.assertTrue(config.rope_scaling["type"] == "linear")
+        self.assertTrue(config.rope_scaling["rope_type"] == "linear")
+        original_model = LlamaForCausalLM(config).to(torch_device)
+        original_model(**model_inputs)
+
+        # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
+            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model(**model_inputs)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("factor field", logs.output[0])
+
+        # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            config = _reinitialize_config(
+                base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
+            )
+            original_model = LlamaForCausalLM(config).to(torch_device)
+            original_model(**model_inputs)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("Unrecognized keys", logs.output[0])
+
+        # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
+        with self.assertRaises(KeyError):
+            config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}})  # missing "factor"
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
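For completeness, the one path the new test keeps as a hard failure: a recognized `rope_type` that is missing one of its mandatory keys still raises `KeyError`. A hypothetical minimal repro, assuming the patch is applied:

```python
# Sketch only: mirrors the final assertion in test_model_loading_old_rope_configs.
from transformers import LlamaConfig

try:
    LlamaConfig(rope_scaling={"rope_type": "linear"})  # "factor" is missing
except KeyError as err:
    print(f"Still raises: {err}")
```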