diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py
index 504eaad43c8d7..3dde498bcd639 100644
--- a/tests/models/test_big_models.py
+++ b/tests/models/test_big_models.py
@@ -43,3 +43,18 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype)
+    # This test is for verifying whether the model's extra_repr
+    # can be printed correctly.
+    print(vllm_model.model.llm_engine.model_executor.driver_worker.
+          model_runner.model)
+    del vllm_model
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index cfe2539e3a052..e4609620387fa 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -49,3 +49,18 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype)
+    # This test is for verifying whether the model's extra_repr
+    # can be printed correctly.
+    print(vllm_model.model.llm_engine.model_executor.driver_worker.
+          model_runner.model)
+    del vllm_model
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index fc65ae108dbb1..ee7be26c0876c 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -47,3 +47,10 @@ def forward(
     ) -> torch.Tensor:
         return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                                  kv_scale)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.impl.head_size}"  # type: ignore
+        s += f", num_heads={self.impl.num_heads}"  # type: ignore
+        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
+        s += f", scale={self.impl.scale}"  # type: ignore
+        return s
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index baf1d4f266181..d101aa323b0e1 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -67,6 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             ops.gelu_tanh_and_mul(out, x)
         return out
 
+    def extra_repr(self) -> str:
+        return f'approximate={repr(self.approximate)}'
+
 
 class NewGELU(nn.Module):
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index a6619714b8aab..8de0794158986 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -64,3 +64,8 @@ def forward(
             self.variance_epsilon,
         )
         return out
+
+    def extra_repr(self) -> str:
+        s = f"hidden_size={self.weight.data.size(0)}"
+        s += f", eps={self.variance_epsilon}"
+        return s
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 289b317cc991e..7726dcb9a5fbd 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -181,6 +181,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", output_features={self.output_size}"
+        s += f", bias={self.bias is not None}"
+= f", bias={self.bias is not None}" + return s + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -281,6 +287,14 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -685,3 +699,11 @@ def forward(self, input_): output = output_ output_bias = self.bias return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 22620d9fc86d9..91eb96998c3cf 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -70,6 +70,12 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = logits[:, :self.org_vocab_size] return logits + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + def _prune_hidden_states( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 25365a9b50a1f..857d70fadcb57 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -156,6 +156,12 @@ def forward( self.cos_sin_cache, self.is_neox_style) return query, key + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 088c0849243c0..4585b1679cb5c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -105,6 +105,14 @@ def forward(self, input_): output = tensor_model_parallel_all_reduce(output_parallel) return output + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head.