
[Core] enable out-of-tree model register #3871

Merged 14 commits on Apr 7, 2024
15 changes: 15 additions & 0 deletions docs/source/models/adding_model.rst
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.

.. tip::
If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.

1. Bring your model code
------------------------
@@ -94,3 +96,16 @@ This method should load the weights from the HuggingFace's checkpoint file and a
----------------------

Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.

6. Out-of-Tree Model Integration
--------------------------------------------

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines in your code:

.. code-block:: python

from vllm.model_executor.models import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_out_of_tree_model("YourModelForCausalLM", YourModelForCausalLM)
34 changes: 34 additions & 0 deletions tests/models/test_oot_registration.py
@@ -0,0 +1,34 @@
import torch

from vllm import LLM, SamplingParams
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits


def test_oot_registration():
# register our dummy model
ModelRegistry.register_out_of_tree_model("OPTForCausalLM",
MyOPTForCausalLM)
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m")
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
generated_text = output.outputs[0].text
# make sure only the first token is generated
rest = generated_text.replace(first_token, "")
assert rest == ""
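
The test's trick can be illustrated in isolation (a plain-Python sketch with a hypothetical vocabulary size, not the real OPT logits): zeroing all logits and bumping index 0 makes the argmax 0 on every step, so greedy decoding (`temperature=0`) emits only token 0.

```python
# Sketch of the dummy model's compute_logits, on a plain list of floats.

def dummy_logits(vocab_size: int) -> list:
    logits = [0.0] * vocab_size  # mirrors logits.zero_()
    logits[0] += 1.0             # mirrors logits[:, 0] += 1.0
    return logits

logits = dummy_logits(50272)  # hypothetical vocab size for illustration
# Greedy sampling picks the argmax, which is always index 0 here.
next_token = max(range(len(logits)), key=logits.__getitem__)
print(next_token)  # 0
```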
14 changes: 13 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -1,5 +1,5 @@
import importlib
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import torch.nn as nn

@@ -55,6 +55,10 @@
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}

# Out-of-tree models: architecture name -> model class.
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []

@@ -74,6 +78,8 @@ class ModelRegistry:

@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
return None
if is_hip():
@@ -95,6 +101,12 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())

@staticmethod
def register_out_of_tree_model(model_arch: str,
model_cls: Type[nn.Module]):
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls


__all__ = [
"ModelRegistry",