Add LoRA support for Mixtral (vllm-project#2831)
* add mixtral lora support

* formatting

* fix incorrectly ported logic

* polish tests

* minor fixes and refactoring

* minor fixes

* formatting

* rename and remove redundant logic

* refactoring

* refactoring

* minor fix

* minor refactoring

* fix code smell
tterrysun authored and jimpang committed Feb 22, 2024
1 parent efea78e commit dba6048
Showing 10 changed files with 251 additions and 121 deletions.
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -121,6 +121,11 @@ def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
82 changes: 47 additions & 35 deletions tests/lora/test_lora_manager.py
@@ -11,25 +11,35 @@
RowParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]


def test_from_lora_tensors(sql_lora_files):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
lora_model = LoRAModel.from_lora_tensors(1,
8,
16,
tensors,
"cuda",
embeddings=new_embeddings)
lora_model = LoRAModel.from_lora_tensors(
1,
8,
16,
tensors,
"cuda",
embeddings=new_embeddings,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES)
for module_name, lora in lora_model.loras.items():
assert lora.module_name == module_name
assert lora.rank == 8
@@ -90,14 +100,11 @@ def create_packed_lora(

def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(model,
1,
1,
1,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=8,
max_loras=8),
lora_target_modules=["dense1", "layer1.dense2"])
model.supported_lora_modules = ["dense1", "layer1.dense2"]
model.packed_modules_mapping = {}
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
model = manager.model

assert isinstance(model.get_submodule("dense1"),
@@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model):

def test_lora_model_manager(dist_init, dummy_model):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
lora_target_modules=["dense1", "dense2", "lm_head"])
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
@@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model):

def test_lora_lru_cache_model_manager(dist_init, dummy_model):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model,
2,
2,
2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
lora_target_modules=["dense1", "dense2", "lm_head"])
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
@@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
model.packed_modules_mapping = {}
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
["dense1", "dense2", "lm_head"])
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))

assert all(x is None for x in manager.lora_index_to_id)

@@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

mapping = LoRAMapping([], [])
@@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

mapping = LoRAMapping([], [])
@@ -428,6 +434,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up):

def test_packed_loras(dist_init, dummy_model_gate_up):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
model.packed_modules_mapping = {
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
model_lora = create_packed_lora(
1,
model,
@@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):

manager = LoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
["gate_up_proj"])
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
model = manager.model

assert isinstance(model.get_submodule("gate_up_proj"),
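Note on the refactor visible in this file: the managers no longer take lora_target_modules (or the embedding-module maps) as constructor arguments. The tests now attach supported_lora_modules and packed_modules_mapping to the dummy model and pass EMBEDDING_MODULES / EMBEDDING_PADDING_MODULES into the worker managers. The per-model values for Mixtral live in the model definition, which is not among the files shown here; the sketch below only illustrates the shape of those class-level attributes, and the concrete module names are assumptions, not taken from this diff.

from torch import nn

# Sketch only: the kind of LoRA metadata a model class is expected to expose
# after this refactor. Attribute names follow the tests above; the concrete
# module names below are illustrative assumptions, not from this diff.
class ExampleLoRACapableModel(nn.Module):
    # Fused module name -> the per-projection LoRAs packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    # Module names that may carry LoRA weights.
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    # Mapping used when loading extra input/output embedding tensors.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    # Embedding modules whose vocab dimension is padded for LoRA.
    embedding_padding_modules = ["lm_head"]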
53 changes: 53 additions & 0 deletions tests/lora/test_mixtral.py
@@ -0,0 +1,53 @@
import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
if torch.cuda.device_count() < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size,
worker_use_ray=True)

expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
]

assert do_sample(llm, mixtral_lora_files,
lora_id=1) == expected_lora_output
assert do_sample(llm, mixtral_lora_files,
lora_id=2) == expected_lora_output
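
For reference outside the test harness, here is a minimal offline-inference sketch following the same pattern as test_mixtral_lora above. The adapter path and prompt are placeholders, and four GPUs are assumed as in the test; this is not part of the commit.

import vllm
from vllm.lora.request import LoRARequest

# Minimal sketch mirroring the test above; path and prompt are placeholders.
llm = vllm.LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
               enable_lora=True,
               max_num_seqs=16,
               max_loras=4,
               tensor_parallel_size=4,
               worker_use_ray=True)

outputs = llm.generate(
    ["[user] Describe SpellForce 3 in one sentence. [/user] [assistant]"],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("mixtral-adapter", 1, "/path/to/lora/adapter"))
print(outputs[0].outputs[0].text.strip())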
