From d0bc19780a2555f4eabb81a54df7be622581ca54 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 6 Apr 2024 17:11:41 -0700
Subject: [PATCH] [Core] enable out-of-tree model register (#3871)

---
 .buildkite/test-pipeline.yaml                 |  5 +-
 docs/source/models/adding_model.rst           | 27 ++++++++
 .../test_server_oot_registration.py           | 66 +++++++++++++++++++
 tests/models/test_oot_registration.py         | 32 +++++++++
 vllm/__init__.py                              |  2 +
 vllm/model_executor/models/__init__.py        | 18 ++++-
 6 files changed, 148 insertions(+), 2 deletions(-)
 create mode 100644 tests/entrypoints/test_server_oot_registration.py
 create mode 100644 tests/models/test_oot_registration.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7ad3386fa499e..27e44463a30a6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
   command: pytest -v -s engine tokenization test_sequence.py test_config.py
 
 - label: Entrypoints Test
-  command: pytest -v -s entrypoints
+  commands:
+    # these tests have to be separated, because each one will allocate all possible GPU memory
+    - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
+    - pytest -v -s entrypoints/test_server_oot_registration.py
 
 - label: Examples Test
   working_dir: "/vllm-workspace/examples"

diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst
index 45ef0340aae25..a82c2cef10e83 100644
--- a/docs/source/models/adding_model.rst
+++ b/docs/source/models/adding_model.rst
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
 Start by forking our `GitHub`_ repository and then :ref:`build it from source `.
 This gives you the ability to modify the codebase and test your model.
 
+.. tip::
+    If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.
 
 1. Bring your model code
 ------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
 ----------------------
 
 Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_.
+
+6. Out-of-Tree Model Integration
+--------------------------------------------
+
+We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.
+
+Just add the following lines to your code:
+
+.. code-block:: python
+
+    from vllm import ModelRegistry
+    from your_code import YourModelForCausalLM
+    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
+
+If you are running the API server with :code:`python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:
+
+.. code-block:: python
+
+    from vllm import ModelRegistry
+    from your_code import YourModelForCausalLM
+    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
+    import runpy
+    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
+
+Save the above code in a file and run it with :code:`python your_file.py args`.
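For readers following the docs hunk above, here is a minimal offline-inference sketch of the same out-of-tree flow. It is not part of this patch: ``your_code``, ``YourModelForCausalLM``, and the ``your-org/your-model`` checkpoint are placeholder names, and the checkpoint's ``config.json`` is assumed to list ``YourModelForCausalLM`` under ``architectures``.

.. code-block:: python

    from vllm import LLM, ModelRegistry, SamplingParams
    from your_code import YourModelForCausalLM  # placeholder out-of-tree model

    # registration must happen before the engine resolves the architecture name
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

    # "your-org/your-model" is a placeholder HuggingFace repo id
    llm = LLM(model="your-org/your-model")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
    print(outputs[0].outputs[0].text)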
diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py
new file mode 100644
index 0000000000000..22e65bf7e7da1
--- /dev/null
+++ b/tests/entrypoints/test_server_oot_registration.py
@@ -0,0 +1,66 @@
+import multiprocessing
+import sys
+import time
+
+import torch
+from openai import OpenAI, OpenAIError
+
+from vllm import ModelRegistry
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.utils import get_open_port
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        logits.zero_()
+        logits[:, 0] += 1.0
+        return logits
+
+
+def server_function(port):
+    # register our dummy model
+    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
+    sys.argv = ["placeholder.py"] + \
+        ("--model facebook/opt-125m --dtype"
+         f" float32 --api-key token-abc123 --port {port}").split()
+    import runpy
+    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
+
+
+def test_oot_registration_for_api_server():
+    port = get_open_port()
+    server = multiprocessing.Process(target=server_function, args=(port, ))
+    server.start()
+    client = OpenAI(
+        base_url=f"http://localhost:{port}/v1",
+        api_key="token-abc123",
+    )
+    while True:
+        try:
+            completion = client.chat.completions.create(
+                model="facebook/opt-125m",
+                messages=[{
+                    "role": "system",
+                    "content": "You are a helpful assistant."
+                }, {
+                    "role": "user",
+                    "content": "Hello!"
+                }],
+                temperature=0,
+            )
+            break
+        except OpenAIError as e:
+            if "Connection error" in str(e):
+                time.sleep(3)
+            else:
+                raise e
+    server.kill()
+    generated_text = completion.choices[0].message.content
+    # make sure only the first token is generated
+    rest = generated_text.replace("<s>", "")
+    assert rest == ""
diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py
new file mode 100644
index 0000000000000..50ab06631500b
--- /dev/null
+++ b/tests/models/test_oot_registration.py
@@ -0,0 +1,32 @@
+import torch
+
+from vllm import LLM, ModelRegistry, SamplingParams
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        logits.zero_()
+        logits[:, 0] += 1.0
+        return logits
+
+
+def test_oot_registration():
+    # register our dummy model
+    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
+    prompts = ["Hello, my name is", "The text does not matter"]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model="facebook/opt-125m")
+    first_token = llm.get_tokenizer().decode(0)
+    outputs = llm.generate(prompts, sampling_params)
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        # make sure only the first token is generated
+        rest = generated_text.replace(first_token, "")
+        assert rest == ""
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 52c36f55e9ebe..2c1fd40573240 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -5,6 +5,7 @@
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
+from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
@@ -12,6 +13,7 @@
 
 __all__ = [
     "LLM",
+    "ModelRegistry",
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index b5c7e44de619c..4647947f695aa 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,5 +1,5 @@
 import importlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Type
 
 import torch.nn as nn
 
@@ -55,6 +55,10 @@
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
 }
 
+# Architecture -> type.
+# out of tree models
+_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
+
 # Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS = []
 
@@ -74,6 +78,8 @@ class ModelRegistry:
 
     @staticmethod
     def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch in _OOT_MODELS:
+            return _OOT_MODELS[model_arch]
         if model_arch not in _MODELS:
             return None
         if is_hip():
@@ -95,6 +101,16 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
     def get_supported_archs() -> List[str]:
         return list(_MODELS.keys())
 
+    @staticmethod
+    def register_model(model_arch: str, model_cls: Type[nn.Module]):
+        if model_arch in _MODELS:
+            logger.warning(
+                f"Model architecture {model_arch} is already registered, "
+                "and will be overwritten by the new model "
+                f"class {model_cls.__name__}.")
+        global _OOT_MODELS
+        _OOT_MODELS[model_arch] = model_cls
+
 
 __all__ = [
     "ModelRegistry",
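A quick sketch of the registry semantics introduced above, mirroring the behavior exercised by the new tests. This is illustrative only: ``TinyDummyModel`` is a placeholder class for showing registration and lookup, not a usable vLLM model.

.. code-block:: python

    import torch.nn as nn

    from vllm import ModelRegistry


    class TinyDummyModel(nn.Module):
        """Placeholder used only to demonstrate registration and lookup."""


    # a new architecture name is stored in the out-of-tree registry and
    # resolved by ModelRegistry.load_model_cls before the in-tree table
    ModelRegistry.register_model("TinyDummyModel", TinyDummyModel)
    assert ModelRegistry.load_model_cls("TinyDummyModel") is TinyDummyModel

    # re-registering an existing in-tree architecture logs a warning, and the
    # out-of-tree class then takes precedence in load_model_cls
    ModelRegistry.register_model("OPTForCausalLM", TinyDummyModel)
    assert ModelRegistry.load_model_cls("OPTForCausalLM") is TinyDummyModel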