[Core] enable out-of-tree model register #3871

Merged · 14 commits · Apr 7, 2024
5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
command: pytest -v -s engine tokenization test_sequence.py test_config.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
commands:
# these tests have to be separated, because each one will allocate all possible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py

- label: Examples Test
working_dir: "/vllm-workspace/examples"
27 changes: 27 additions & 0 deletions docs/source/models/adding_model.rst
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.

.. tip::
If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.

1. Bring your model code
------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
----------------------

Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
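
As a rough illustration, a registry entry maps the architecture name to the module and class that implement it. The exact file and dictionary name may differ between versions; the sketch below follows the :code:`_MODELS` table in :code:`vllm/model_executor/models/__init__.py`, with placeholder names:

.. code-block:: python

    # A sketch of an in-tree registry entry (names are placeholders).
    # "your_new_model" stands for the file under vllm/model_executor/models/
    # that defines your model class.
    _MODELS = {
        "YourModelForCausalLM": ("your_new_model", "YourModelForCausalLM"),
    }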

6. Out-of-Tree Model Integration
--------------------------------------------

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines to your code:

.. code-block:: python

from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

If you are running the API server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:

.. code-block:: python

from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

Save the above code in a file and run it with `python your_file.py args`.
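
The same registration also works for offline inference with the :code:`LLM` class. Below is a minimal sketch modeled on the tests added in this PR; the import path and model name are placeholders:

.. code-block:: python

    from vllm import LLM, ModelRegistry, SamplingParams

    from your_code import YourModelForCausalLM  # placeholder import

    # Register the out-of-tree class before constructing the engine.
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

    llm = LLM(model="your-org/your-model")  # placeholder HuggingFace model name
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0))
    print(outputs[0].outputs[0].text)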
66 changes: 66 additions & 0 deletions tests/entrypoints/test_server_oot_registration.py
@@ -0,0 +1,66 @@
import multiprocessing
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def server_function(port):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype"
f" float32 --api-key token-abc123 --port {port}").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, ))
server.start()
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
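    # Poll until the server is up: retry on connection errors, re-raise anything else.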
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
else:
raise e
server.kill()
generated_text = completion.choices[0].message.content
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
32 changes: 32 additions & 0 deletions tests/models/test_oot_registration.py
@@ -0,0 +1,32 @@
import torch

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def test_oot_registration():
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m")
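    # The dummy model always ranks token id 0 highest, so decode that token.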
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
2 changes: 2 additions & 0 deletions vllm/__init__.py
@@ -5,13 +5,15 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.4.0.post1"

__all__ = [
"LLM",
"ModelRegistry",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
18 changes: 17 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -1,5 +1,5 @@
import importlib
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import torch.nn as nn

@@ -55,6 +55,10 @@
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}

# Out-of-tree models: architecture name -> model class.
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []

@@ -74,6 +78,8 @@ class ModelRegistry:

@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
@@ -95,6 +101,16 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())

@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
f"Model architecture {model_arch} is already registered, "
"and will be overwritten by the new model "
f"class {model_cls.__name__}.")
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls


__all__ = [
"ModelRegistry",
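
For reference, the tests in this PR override the in-tree `OPTForCausalLM` architecture; with the `register_model` method above, doing so logs the "already registered" warning and then serves the custom class. A minimal sketch of that flow (the subclass body is illustrative):

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM


class MyOPTForCausalLM(OPTForCausalLM):
    # Same architecture name as the in-tree model, so registering it
    # triggers the overwrite warning and replaces the in-tree class.
    pass


ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
assert ModelRegistry.load_model_cls("OPTForCausalLM") is MyOPTForCausalLM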