[Core] Support Lora lineage and base model metadata management (vllm-project#6315)

Signed-off-by: Sumit Dubey <[email protected]>
Jeffwan authored and sumitd2 committed Nov 14, 2024
1 parent fb40d6c commit 9532f3f
Showing 15 changed files with 337 additions and 45 deletions.
64 changes: 64 additions & 0 deletions docs/source/models/lora.rst
@@ -159,3 +159,67 @@ Example request to unload a LoRA adapter:
-d '{
"lora_name": "sql_adapter"
}'
New format for `--lora-modules`
-------------------------------

In the previous version, users would provide LoRA modules via the following key-value format. For example:

.. code-block:: bash

   --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/

This format only included a `name` and `path` for each LoRA module, with no way to specify a `base_model_name`.
Now, you can specify a `base_model_name` alongside the `name` and `path` using JSON format. For example:

.. code-block:: bash

   --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'

For backward compatibility, you can still use the old key-value format (`name=path`), but the `base_model_name` will remain unspecified in that case.
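
Both formats can also be mixed in a single invocation (a hedged example; the adapter names and paths below are placeholders, not part of this change):

.. code-block:: bash

   --lora-modules \
       'sql-lora=/path/to/sql-lora' \
       '{"name": "chat-lora", "path": "/path/to/chat-lora", "base_model_name": "meta-llama/Llama-2-7b"}'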


LoRA model lineage in model card
--------------------------------

The new format of `--lora-modules` is mainly intended to support the display of parent model information in the model card. Here's how the current `/v1/models` response supports this:

- The `parent` field of the LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter.
- The `root` field points to the artifact location of the LoRA adapter.

.. code-block:: bash

   $ curl http://localhost:8000/v1/models

   {
       "object": "list",
       "data": [
           {
               "id": "meta-llama/Llama-2-7b-hf",
               "object": "model",
               "created": 1715644056,
               "owned_by": "vllm",
               "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/",
               "parent": null,
               "permission": [
                   {
                       .....
                   }
               ]
           },
           {
               "id": "sql-lora",
               "object": "model",
               "created": 1715644056,
               "owned_by": "vllm",
               "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/",
               "parent": "meta-llama/Llama-2-7b-hf",
               "permission": [
                   {
                       ....
                   }
               ]
           }
       ]
   }
91 changes: 91 additions & 0 deletions tests/entrypoints/openai/test_cli_args.py
@@ -0,0 +1,91 @@
import json
import unittest

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

LORA_MODULE = {
"name": "module2",
"path": "/path/to/module2",
"base_model_name": "llama"
}


class TestLoraParserAction(unittest.TestCase):

def setUp(self):
# Setting up argparse parser for tests
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
self.parser = make_arg_parser(parser)

def test_valid_key_value_format(self):
# Test old format: name=path
args = self.parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
])
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
self.assertEqual(args.lora_modules, expected)

def test_valid_json_format(self):
# Test valid JSON format input
args = self.parser.parse_args([
'--lora-modules',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)

def test_invalid_json_format(self):
# Test invalid JSON format input, missing closing brace
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module3", "path": "/path/to/module3"'
])

def test_invalid_type_error(self):
# Test type error when values are not JSON or key=value
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'invalid_format' # This is not JSON or key=value format
])

def test_invalid_json_field(self):
# Test valid JSON format but missing required fields
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module4"}' # Missing required 'path' field
])

def test_empty_values(self):
# Test when no LoRA modules are provided
args = self.parser.parse_args(['--lora-modules', ''])
self.assertEqual(args.lora_modules, [])

def test_multiple_valid_inputs(self):
# Test multiple valid inputs (both old and JSON format)
args = self.parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module1', path='/path/to/module1'),
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)


if __name__ == '__main__':
unittest.main()
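
For context, these tests exercise the `--lora-modules` parsing added to `vllm/entrypoints/openai/cli_args.py`, which is not rendered in this diff. The following is a minimal sketch of the kind of argparse action that would satisfy the tests above; the class name, error messages, and exact error handling are assumptions, not the actual vLLM implementation.

.. code-block:: python

   # Hedged sketch only -- not the actual vllm cli_args.py implementation.
   import argparse
   import json
   from dataclasses import dataclass
   from typing import List, Optional


   @dataclass
   class LoRAModulePath:
       # Mirrors the fields the tests expect; base_model_name is optional so
       # the old "name=path" format keeps working.
       name: str
       path: str
       base_model_name: Optional[str] = None


   class LoRAParserAction(argparse.Action):
       """Accept either "name=path" or a JSON object per --lora-modules value."""

       def __call__(self, parser, namespace, values, option_string=None):
           lora_list: List[LoRAModulePath] = []
           for item in values:
               if not item:
                   # An empty string yields an empty list (see test_empty_values).
                   continue
               if item.startswith("{"):
                   try:
                       lora_list.append(LoRAModulePath(**json.loads(item)))
                   except (json.JSONDecodeError, TypeError) as err:
                       # Malformed JSON or missing required fields exits with an
                       # argparse error (SystemExit), as the tests assert.
                       parser.error(
                           f"Invalid --lora-modules value {item!r}: {err}")
               elif "=" in item:
                   name, path = item.split("=", 1)
                   lora_list.append(LoRAModulePath(name=name, path=path))
               else:
                   parser.error(
                       f"--lora-modules expects name=path or JSON, got {item!r}")
           setattr(namespace, self.dest, lora_list)


   # Registration would then look roughly like:
   #     parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)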
83 changes: 83 additions & 0 deletions tests/entrypoints/openai/test_lora_lineage.py
@@ -0,0 +1,83 @@
import json

import openai # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
# Define the json format LoRA module configurations
lora_module_1 = {
"name": "zephyr-lora",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}

lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}

args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"64",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
async with server_with_lora_modules_json.get_async_client(
) as async_client:
yield async_client


@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
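
Outside the test harness, an equivalent server launch with JSON-format LoRA modules would look roughly like the following (a hedged sketch; the adapter path is a placeholder for the downloaded `typeof/zephyr-7b-beta-lora` snapshot, and only a subset of the fixture's flags is shown):

.. code-block:: bash

   python -m vllm.entrypoints.openai.api_server \
       --model HuggingFaceH4/zephyr-7b-beta \
       --dtype bfloat16 \
       --max-model-len 8192 \
       --enforce-eager \
       --enable-lora \
       --lora-modules \
           '{"name": "zephyr-lora", "path": "/path/to/zephyr-7b-beta-lora", "base_model_name": "HuggingFaceH4/zephyr-7b-beta"}' \
       --max-lora-rank 64 \
       --max-cpu-loras 2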
6 changes: 4 additions & 2 deletions tests/entrypoints/openai/test_models.py
@@ -51,12 +51,14 @@ async def client(server):


 @pytest.mark.asyncio
-async def test_check_models(client: openai.AsyncOpenAI):
+async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
     models = await client.models.list()
     models = models.data
     served_model = models[0]
     lora_models = models[1:]
     assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
+    assert served_model.root == MODEL_NAME
+    assert all(lora_model.root == zephyr_lora_files
+               for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
6 changes: 4 additions & 2 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -7,10 +7,12 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]


 @dataclass
@@ -37,7 +39,7 @@ async def _async_serving_chat_init():

     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           served_model_names=[MODEL_NAME],
+                                           BASE_MODEL_PATHS,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            lora_modules=None,
@@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens():

     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     served_model_names=[MODEL_NAME],
+                                     BASE_MODEL_PATHS,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      lora_modules=None,
5 changes: 3 additions & 2 deletions tests/entrypoints/openai/test_serving_engine.py
@@ -8,9 +8,10 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing

 MODEL_NAME = "meta-llama/Llama-2-7b"
+BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 LORA_LOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' added successfully.")
 LORA_UNLOADING_SUCCESS_MESSAGE = (
@@ -25,7 +26,7 @@ async def _async_serving_engine_init():

     serving_engine = OpenAIServing(mock_engine_client,
                                    mock_model_config,
-                                   served_model_names=[MODEL_NAME],
+                                   BASE_MODEL_PATHS,
                                    lora_modules=None,
                                    prompt_adapters=None,
                                    request_logger=None)
14 changes: 10 additions & 4 deletions vllm/entrypoints/openai/api_server.py
@@ -50,6 +50,7 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
@@ -476,13 +477,18 @@ def init_app_state(
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats

     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         args.response_role,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
@@ -494,7 +500,7 @@
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
@@ -503,13 +509,13 @@
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         request_logger=request_logger,
     )
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
-        served_model_names,
+        base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
         chat_template=args.chat_template,
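
The changes to `vllm/entrypoints/openai/serving_engine.py` and the remaining files are not rendered above. As rough orientation only, the sketch below shows how a `BaseModelPath` container and the `base_model_name` from `--lora-modules` could be used to fill the `root` and `parent` fields of the model listing returned by `/v1/models`; the helper name and the adapter attributes are assumptions, not vLLM's actual code.

.. code-block:: python

   # Hedged sketch -- illustrates the lineage idea, not the real serving_engine.py.
   from dataclasses import dataclass
   from typing import List, Optional


   @dataclass
   class BaseModelPath:
       name: str        # the served model name, e.g. "meta-llama/Llama-2-7b-hf"
       model_path: str  # the artifact location passed as --model


   @dataclass
   class ModelCardInfo:
       id: str
       root: str                     # artifact location (base model or adapter)
       parent: Optional[str] = None  # base model id for LoRA adapters, else None


   def build_model_cards(base_model_paths: List[BaseModelPath],
                         lora_adapters) -> List[ModelCardInfo]:
       cards = [
           ModelCardInfo(id=base.name, root=base.model_path, parent=None)
           for base in base_model_paths
       ]
       for lora in lora_adapters:
           # `name`, `path`, and `base_model_name` mirror the JSON fields of
           # --lora-modules; fall back to the first base model when the old
           # name=path format left base_model_name unset.
           cards.append(
               ModelCardInfo(id=lora.name,
                             root=lora.path,
                             parent=lora.base_model_name
                             or base_model_paths[0].name))
       return cards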
