From d036198e23345f3c25438f082396f7487028e8b6 Mon Sep 17 00:00:00 2001 From: Roy Date: Tue, 9 Apr 2024 06:17:21 +0800 Subject: [PATCH] [BugFix][Model] Fix commandr RoPE max_position_embeddings (#3919) --- vllm/model_executor/models/commandr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 4674dcbc14da6..29ba3844eb11d 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -140,7 +140,9 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.max_position_embeddings = config.max_position_embeddings + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) self.use_qk_norm = getattr(config, "use_qk_norm", False)