From 182d5d7755551c6d4bc948d9dabfdb0421555c0a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 5 Apr 2024 09:39:48 -0700 Subject: [PATCH 01/14] enable oot model register --- vllm/model_executor/models/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b5c7e44de619c..a2b41cb7c6d2c 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,5 @@ import importlib -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type import torch.nn as nn @@ -55,6 +55,10 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS = [] @@ -74,6 +78,8 @@ class ModelRegistry: @staticmethod def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] if model_arch not in _MODELS: return None if is_hip(): @@ -95,6 +101,12 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: def get_supported_archs() -> List[str]: return list(_MODELS.keys()) + @staticmethod + def register_out_of_tree_model(model_arch: str, + model_cls: Type[nn.Module]): + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + __all__ = [ "ModelRegistry", From d639c010da98eb8bc8729783241da2dfca333369 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 5 Apr 2024 09:52:27 -0700 Subject: [PATCH 02/14] add doc for guiding oot model register --- docs/source/models/adding_model.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 45ef0340aae25..c8eced1708522 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor Start by forking our `GitHub`_ repository and then :ref:`build it from source `. This gives you the ability to modify the codebase and test your model. +.. tip:: + If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below. 1. Bring your model code ------------------------ @@ -94,3 +96,16 @@ This method should load the weights from the HuggingFace's checkpoint file and a ---------------------- Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_. + +6. Out-of-Tree Model Integration +---------------------- + +We also provide a way to integrate a model without modifying the vLLM codebase. Step 2, 3, 4 are still required, but you can skip step 1 and 5. + +Just add the following lines in your code: + +.. 
code-block:: python + + from vllm.model_executor.models import ModelRegistry + from your_code import YourModelForCausalLM + ModelRegistry.register_out_of_tree_model("YourModelForCausalLM", YourModelForCausalLM) From 3273706a86d0b70154339e6907dc1593bb42807c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 5 Apr 2024 10:22:32 -0700 Subject: [PATCH 03/14] fix doc lint --- docs/source/models/adding_model.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index c8eced1708522..4bcc9446b52d8 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -98,7 +98,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_. 6. Out-of-Tree Model Integration ----------------------- +-------------------------------------------- We also provide a way to integrate a model without modifying the vLLM codebase. Step 2, 3, 4 are still required, but you can skip step 1 and 5. From 85e7553f7b29148f2bee91eb4a315dd77894b8ea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 5 Apr 2024 17:16:44 -0700 Subject: [PATCH 04/14] add test for oot models --- tests/models/test_oot_registration.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/models/test_oot_registration.py diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py new file mode 100644 index 0000000000000..5324d9dbce8fa --- /dev/null +++ b/tests/models/test_oot_registration.py @@ -0,0 +1,31 @@ +import torch +from torch import nn +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm import LLM, SamplingParams + +class MyOPTForCausalLM(OPTForCausalLM): + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + logits.zero_() + logits[:, 0] += 1.0 + return logits + +def test_oot_registration(): + # register our dummy model + ModelRegistry.register_out_of_tree_model("OPTForCausalLM", MyOPTForCausalLM) + prompts = ["Hello, my name is", "The text does not matter"] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model="facebook/opt-125m") + first_token = llm.get_tokenizer().decode(0) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + # make sure only the first token is generated + rest = generated_text.replace(first_token, "") + assert rest == "" From 8fb71706caa24011c165d9cb5953f2541cbc4a1a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 5 Apr 2024 17:17:12 -0700 Subject: [PATCH 05/14] fix linter --- tests/models/test_oot_registration.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 5324d9dbce8fa..91745c24bf14e 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,11 +1,13 @@ import torch -from torch import nn -from vllm.model_executor.sampling_metadata import 
SamplingMetadata + +from vllm import LLM, SamplingParams from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.opt import OPTForCausalLM -from vllm import LLM, SamplingParams +from vllm.model_executor.sampling_metadata import SamplingMetadata + class MyOPTForCausalLM(OPTForCausalLM): + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: # this dummy model always predicts the first token @@ -14,9 +16,11 @@ def compute_logits(self, hidden_states: torch.Tensor, logits[:, 0] += 1.0 return logits + def test_oot_registration(): # register our dummy model - ModelRegistry.register_out_of_tree_model("OPTForCausalLM", MyOPTForCausalLM) + ModelRegistry.register_out_of_tree_model("OPTForCausalLM", + MyOPTForCausalLM) prompts = ["Hello, my name is", "The text does not matter"] sampling_params = SamplingParams(temperature=0) llm = LLM(model="facebook/opt-125m") @@ -24,7 +28,6 @@ def test_oot_registration(): outputs = llm.generate(prompts, sampling_params) for output in outputs: - prompt = output.prompt generated_text = output.outputs[0].text # make sure only the first token is generated rest = generated_text.replace(first_token, "") From c51050f9a4ea3705ab72bcaf72c06eae2d955d6d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 09:21:15 -0700 Subject: [PATCH 06/14] register_out_of_tree_model --> register_model --- docs/source/models/adding_model.rst | 2 +- tests/models/test_oot_registration.py | 3 +-- vllm/model_executor/models/__init__.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 4bcc9446b52d8..3ec1bf6714ca0 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -108,4 +108,4 @@ Just add the following lines in your code: from vllm.model_executor.models import ModelRegistry from your_code import YourModelForCausalLM - ModelRegistry.register_out_of_tree_model("YourModelForCausalLM", YourModelForCausalLM) + ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 91745c24bf14e..33ef561958072 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -19,8 +19,7 @@ def compute_logits(self, hidden_states: torch.Tensor, def test_oot_registration(): # register our dummy model - ModelRegistry.register_out_of_tree_model("OPTForCausalLM", - MyOPTForCausalLM) + ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) prompts = ["Hello, my name is", "The text does not matter"] sampling_params = SamplingParams(temperature=0) llm = LLM(model="facebook/opt-125m") diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a2b41cb7c6d2c..54988d264e05c 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -102,8 +102,7 @@ def get_supported_archs() -> List[str]: return list(_MODELS.keys()) @staticmethod - def register_out_of_tree_model(model_arch: str, - model_cls: Type[nn.Module]): + def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls From c3a3b16b7f67ede7b40414a90a3c4a6055781330 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 09:24:21 -0700 Subject: [PATCH 07/14] add warning for overwritten --- vllm/model_executor/models/__init__.py | 5 +++++ 1 file 
changed, 5 insertions(+) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 54988d264e05c..4647947f695aa 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -103,6 +103,11 @@ def get_supported_archs() -> List[str]: @staticmethod def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + f"Model architecture {model_arch} is already registered, " + "and will be overwritten by the new model " + f"class {model_cls.__name__}.") global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls From 9b6655e0a8c1f22a6c67c7e2bd6973d9971b6481 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 09:26:53 -0700 Subject: [PATCH 08/14] expose ModelRegistry in top-level --- docs/source/models/adding_model.rst | 2 +- tests/models/test_oot_registration.py | 3 +-- vllm/__init__.py | 2 ++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 3ec1bf6714ca0..4007a4370eabb 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -106,6 +106,6 @@ Just add the following lines in your code: .. code-block:: python - from vllm.model_executor.models import ModelRegistry + from vllm import ModelRegistry from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 33ef561958072..50ab06631500b 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,7 +1,6 @@ import torch -from vllm import LLM, SamplingParams -from vllm.model_executor.models import ModelRegistry +from vllm import LLM, ModelRegistry, SamplingParams from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata diff --git a/vllm/__init__.py b/vllm/__init__.py index 52c36f55e9ebe..2c1fd40573240 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM +from vllm.model_executor.models import ModelRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams @@ -12,6 +13,7 @@ __all__ = [ "LLM", + "ModelRegistry", "SamplingParams", "RequestOutput", "CompletionOutput", From f6326a0f193a88721a679f3238d3325b97007bba Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 10:00:11 -0700 Subject: [PATCH 09/14] add guide for openai server --- docs/source/models/adding_model.rst | 13 +++++++++++++ tests/models/test_oot_registration.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 4007a4370eabb..6a2b2d29a3572 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -109,3 +109,16 @@ Just add the following lines in your code: from vllm import ModelRegistry from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) + +If you are running api server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code: + +.. 
code-block:: python + + from vllm import ModelRegistry + from your_code import YourModelForCausalLM + ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) + from vllm.entrypoints.openai.api_server import main + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') + +Save the above code in a file and run it with `python your_file.py args`. diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 50ab06631500b..c9c047c1ad252 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -30,3 +30,19 @@ def test_oot_registration(): # make sure only the first token is generated rest = generated_text.replace(first_token, "") assert rest == "" + + +def test_oot_registration_for_api_server(): + # register our dummy model + ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) + prompts = ["Hello, my name is", "The text does not matter"] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model="facebook/opt-125m") + first_token = llm.get_tokenizer().decode(0) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + generated_text = output.outputs[0].text + # make sure only the first token is generated + rest = generated_text.replace(first_token, "") + assert rest == "" From e315617dd0b36871eeb43fdc1344a26856ebe6f5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 10:15:48 -0700 Subject: [PATCH 10/14] add test for api server --- docs/source/models/adding_model.rst | 1 - tests/models/test_oot_registration.py | 46 ++++++++++++++++++++------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 6a2b2d29a3572..a82c2cef10e83 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -117,7 +117,6 @@ If you are running api server with `python -m vllm.entrypoints.openai.api_server from vllm import ModelRegistry from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) - from vllm.entrypoints.openai.api_server import main import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index c9c047c1ad252..52ae8482b37c5 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,4 +1,8 @@ +import multiprocessing +import sys + import torch +from openai import OpenAI from vllm import LLM, ModelRegistry, SamplingParams from vllm.model_executor.models.opt import OPTForCausalLM @@ -32,17 +36,37 @@ def test_oot_registration(): assert rest == "" -def test_oot_registration_for_api_server(): +def server_function(): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) - prompts = ["Hello, my name is", "The text does not matter"] - sampling_params = SamplingParams(temperature=0) - llm = LLM(model="facebook/opt-125m") - first_token = llm.get_tokenizer().decode(0) - outputs = llm.generate(prompts, sampling_params) + sys.argv = ["placeholder.py"] + \ + ("--model facebook/opt-125m --dtype" + " float32 --api-key token-abc123").split() + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') - for output in outputs: - generated_text = output.outputs[0].text - # make sure only the first token is generated - rest = 
generated_text.replace(first_token, "") - assert rest == "" + +def test_oot_registration_for_api_server(): + server = multiprocessing.Process(target=server_function) + server.start() + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + + completion = client.chat.completions.create( + model="facebook/opt-125m", + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Hello!" + }], + temperature=0, + ) + + generated_text = completion.choices[0].message.content + # make sure only the first token is generated + rest = generated_text.replace("", "") + assert rest == "" From 6f3e010a9154b3b10f4904310d8080fd44c4c12c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 10:27:23 -0700 Subject: [PATCH 11/14] finish openai api server test --- tests/models/test_oot_registration.py | 36 ++++++++++++++++----------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 52ae8482b37c5..334e9e2d9b868 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,8 +1,9 @@ import multiprocessing import sys +import time import torch -from openai import OpenAI +from openai import OpenAI, OpenAIError from vllm import LLM, ModelRegistry, SamplingParams from vllm.model_executor.models.opt import OPTForCausalLM @@ -53,19 +54,26 @@ def test_oot_registration_for_api_server(): base_url="http://localhost:8000/v1", api_key="token-abc123", ) - - completion = client.chat.completions.create( - model="facebook/opt-125m", - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Hello!" - }], - temperature=0, - ) - + while True: + try: + completion = client.chat.completions.create( + model="facebook/opt-125m", + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Hello!" 
+ }], + temperature=0, + ) + break + except OpenAIError as e: + if "Connection error" in str(e): + time.sleep(3) + else: + raise e + server.kill() generated_text = completion.choices[0].message.content # make sure only the first token is generated rest = generated_text.replace("", "") From 9ebcc56aa311b2280a6ba0b8ff0c9c643eec1f51 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 11:07:18 -0700 Subject: [PATCH 12/14] separate test --- .../test_server_oot_registration.py | 64 +++++++++++++++++++ tests/models/test_oot_registration.py | 48 -------------- 2 files changed, 64 insertions(+), 48 deletions(-) create mode 100644 tests/entrypoints/test_server_oot_registration.py diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py new file mode 100644 index 0000000000000..de292eed1f4d2 --- /dev/null +++ b/tests/entrypoints/test_server_oot_registration.py @@ -0,0 +1,64 @@ +import multiprocessing +import sys +import time + +import torch +from openai import OpenAI, OpenAIError + +from vllm import ModelRegistry +from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata + + +class MyOPTForCausalLM(OPTForCausalLM): + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + # this dummy model always predicts the first token + logits = super().compute_logits(hidden_states, sampling_metadata) + logits.zero_() + logits[:, 0] += 1.0 + return logits + + +def server_function(): + # register our dummy model + ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) + sys.argv = ["placeholder.py"] + \ + ("--model facebook/opt-125m --dtype" + " float32 --api-key token-abc123").split() + import runpy + runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') + + +def test_oot_registration_for_api_server(): + server = multiprocessing.Process(target=server_function) + server.start() + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + while True: + try: + completion = client.chat.completions.create( + model="facebook/opt-125m", + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Hello!" 
+ }], + temperature=0, + ) + break + except OpenAIError as e: + if "Connection error" in str(e): + time.sleep(3) + else: + raise e + server.kill() + generated_text = completion.choices[0].message.content + # make sure only the first token is generated + rest = generated_text.replace("", "") + assert rest == "" diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 334e9e2d9b868..50ab06631500b 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,9 +1,4 @@ -import multiprocessing -import sys -import time - import torch -from openai import OpenAI, OpenAIError from vllm import LLM, ModelRegistry, SamplingParams from vllm.model_executor.models.opt import OPTForCausalLM @@ -35,46 +30,3 @@ def test_oot_registration(): # make sure only the first token is generated rest = generated_text.replace(first_token, "") assert rest == "" - - -def server_function(): - # register our dummy model - ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) - sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - " float32 --api-key token-abc123").split() - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') - - -def test_oot_registration_for_api_server(): - server = multiprocessing.Process(target=server_function) - server.start() - client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="token-abc123", - ) - while True: - try: - completion = client.chat.completions.create( - model="facebook/opt-125m", - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Hello!" - }], - temperature=0, - ) - break - except OpenAIError as e: - if "Connection error" in str(e): - time.sleep(3) - else: - raise e - server.kill() - generated_text = completion.choices[0].message.content - # make sure only the first token is generated - rest = generated_text.replace("", "") - assert rest == "" From fbbabb42339a433a54824e7cd3337b6178a96aaf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 6 Apr 2024 11:33:01 -0700 Subject: [PATCH 13/14] use dynamic port --- tests/entrypoints/test_server_oot_registration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index de292eed1f4d2..22e65bf7e7da1 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -8,6 +8,7 @@ from vllm import ModelRegistry from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.utils import get_open_port class MyOPTForCausalLM(OPTForCausalLM): @@ -21,21 +22,22 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits -def server_function(): +def server_function(port): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ ("--model facebook/opt-125m --dtype" - " float32 --api-key token-abc123").split() + f" float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): - server = multiprocessing.Process(target=server_function) + port = get_open_port() + server = multiprocessing.Process(target=server_function, args=(port, )) server.start() client = OpenAI( - 
base_url="http://localhost:8000/v1",
+        base_url=f"http://localhost:{port}/v1",
         api_key="token-abc123",
     )
     while True:
         try:
             completion = client.chat.completions.create(
                 model="facebook/opt-125m",
                 messages=[{
                     "role": "system",
                     "content": "You are a helpful assistant."
                 }, {
                     "role": "user",
                     "content": "Hello!"
                 }],
                 temperature=0,
             )
             break
         except OpenAIError as e:
             if "Connection error" in str(e):
                 time.sleep(3)
             else:
                 raise e
     server.kill()
     generated_text = completion.choices[0].message.content
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""

From b6a114f7c3d916cd6a1a1656eef8b5000efb3bbe Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 6 Apr 2024 11:58:14 -0700
Subject: [PATCH 14/14] separate test commands

---
 .buildkite/test-pipeline.yaml | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7ad3386fa499e..27e44463a30a6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
   command: pytest -v -s engine tokenization test_sequence.py test_config.py

 - label: Entrypoints Test
-  command: pytest -v -s entrypoints
+  commands:
+  # these tests have to be separated, because each one will allocate all possible GPU memory
+  - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
+  - pytest -v -s entrypoints/test_server_oot_registration.py

 - label: Examples Test
   working_dir: "/vllm-workspace/examples"
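
Taken together, the patches above add one public hook, :code:`ModelRegistry.register_model(model_arch, model_cls)`, exported from the top-level :code:`vllm` package, so a downstream project can plug a model into vLLM without forking the repository. Below is a minimal end-to-end sketch of the workflow, adapted from the tests added in this series; the file name and prompt are illustrative, and it assumes a vLLM build that already includes these patches.

.. code-block:: python

    # oot_plugin_example.py (hypothetical file name)
    import torch

    from vllm import LLM, ModelRegistry, SamplingParams
    from vllm.model_executor.models.opt import OPTForCausalLM
    from vllm.model_executor.sampling_metadata import SamplingMetadata


    class MyOPTForCausalLM(OPTForCausalLM):
        """Toy out-of-tree model: reuses OPT weights but rigs the logits."""

        def compute_logits(self, hidden_states: torch.Tensor,
                           sampling_metadata: SamplingMetadata) -> torch.Tensor:
            logits = super().compute_logits(hidden_states, sampling_metadata)
            logits.zero_()
            logits[:, 0] += 1.0  # always predict token id 0
            return logits


    # Register before constructing the engine; the architecture name must match
    # the one reported in the HuggingFace config (config.architectures).
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0))
    # With the dummy logits above, the output is a run of the id-0 token.
    print(outputs[0].outputs[0].text)

For serving, the same :code:`register_model` call can simply precede the :code:`runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')` launch shown in the documentation change of PATCH 09.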