
Add LoRA support for Mixtral #2831

Merged: 15 commits into vllm-project:main on Feb 13, 2024
Conversation

tterrysun (Contributor):

Problem: We don't have LoRA support for Mixtral.

Solution: Add LoRA configurations for Mixtral and refactor relevant parts to allow this.

Testing: added correctness tests and updated existing tests.
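
For reference, exercising a Mixtral LoRA adapter through vLLM's offline Python API should look roughly like the sketch below (the adapter path and prompt are placeholders, not part of this PR):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder path to a Mixtral LoRA adapter checkpoint.
lora_path = "/path/to/mixtral_lora_checkpoint"

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    enable_lora=True,
    tensor_parallel_size=4,
)

outputs = llm.generate(
    ["Explain mixture-of-experts in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest takes a name, a unique integer id, and the local adapter path.
    lora_request=LoRARequest("mixtral_lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)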

tterrysun changed the title from "Add LoRA support for Mitral" to "Add LoRA support for Mixtral" on Feb 10, 2024
@@ -94,6 +95,24 @@ def init_model(self) -> None:
# Initialize the model.
set_random_seed(self.model_config.seed)

self.model = get_model(self.model_config, self.device_config,
tterrysun marked this pull request as ready for review on February 12, 2024, 16:56
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.unpadded_vocab_size = config.vocab_size
pcmoritz (Collaborator), Feb 12, 2024:

Can you say why this is different from the code in mistral.py? There is the minor cosmetic difference of unpadded_vocab_size vs. self.unpadded_vocab_size that we should fix, but also the larger differences of the padding_size passed to ParallelLMHead and the parameters passed to the Sampler. Is there a reason why the code is not the same? :)

Same comment applies when comparing with llama.py.

tterrysun (Contributor, PR author):

should be the same, fixed
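
For context, the llama.py pattern being aligned with here looks roughly like the following sketch of the __init__ body (recalled from the code of that era, not quoted from the diff; names such as lora_extra_vocab_size and DEFAULT_VOCAB_PADDING_SIZE may differ slightly):

# lora_config is the optional LoRAConfig passed into the model constructor.
self.unpadded_vocab_size = config.vocab_size
if lora_config:
    # LoRA adapters may add extra tokens on top of the base vocabulary.
    self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
    self.unpadded_vocab_size,
    config.hidden_size,
    org_num_embeddings=config.vocab_size,
    # Padding may need to be larger when LoRA is enabled, for kernel
    # compatibility.
    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)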

@@ -270,6 +270,30 @@ def forward(

class LlamaForCausalLM(nn.Module):
supports_lora = True
lora_target_modules = [
pcmoritz (Collaborator), Feb 12, 2024:

The biggest question I have about this PR is these configurations; I don't think they should be here. There might be different ways to run a model (e.g. different layers get LoRAified; for the Mixtral model, do you apply LoRA to the MoE weight matrices or not?), and it also seems odd to have the LoRA configurations in the model itself. Are there better places where this could fit? Maybe @Yard1 has some ideas here as well :)

The other issue, of the override code in vllm/lora/models.py being a little too complex, will flow naturally from what we do here.

Collaborator:

Here are my thoughts so far: I believe lora_target_modules is in the adapter_config.json of the LoRA model file, so we can probably use the information from there; it just needs to be adapted a little via the packed_module_mapping. So it seems we can remove the LoRA-specific stuff, which is already a big relief.

The other parameters are also useful for loading the model etc., so maybe we should keep them here? It would be worth investigating whether we can reuse this in the model loading function for the base model (load_weights).

Yard1 (Collaborator), Feb 12, 2024:

This defines all the layers we support for LoRA on the given model. The adapter can use a subset of those layers. We should not be using adapter_config.json directly because:

  1. different adapters will have different layers, so we need a common superset;
  2. that superset has to be constant;
  3. we need to implement support for each layer type, so if an adapter specifies a layer we do not support, we need to raise an exception.

Given the above, defining those variables as attributes of the model class seems to be the best option.

pcmoritz (Collaborator), Feb 12, 2024:

Sounds great. I believe my confusion will go away if we rename lora_target_modules to supported_lora_modules, so it is clearer that this is the superset of modules we support and not the actual ones that are loaded. Let's rename it, and in parallel I'll make a PR that shows how it would look if we use packed_module_mapping in the load_weights function, so the information lives in only one place.

Let's also reorder things so the generic attributes come before the LoRA-specific ones (and add a comment marking the LoRA-specific ones as such).
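
To make the outcome of this thread concrete, the class-attribute approach looks roughly like the sketch below (the module lists are illustrative, not a quote of the merged code):

class MixtralForCausalLM(nn.Module):
    # Generic attribute: how adapter-side module names map onto vLLM's
    # fused ("packed") layers.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    # LoRA-specific: the superset of modules this model can apply LoRA to.
    # A given adapter may target only a subset of these.
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    # LoRA-specific: embedding modules and those that need vocab padding.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]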

Collaborator:

I put up a PR in #2843 :)

@@ -120,7 +129,7 @@ def test_lora_model_manager(dist_init, dummy_model):
2,
2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
lora_target_modules=["dense1", "dense2", "lm_head"])
support_lora_modules=["dense1", "dense2", "lm_head"])
pcmoritz (Collaborator), Feb 13, 2024:

Can you rename this to supported_lora_modules? The idea is that this is the list of modules that can be LoRAified (i.e. the modules for which the model supports applying LoRA adapters).

# allow overriding the target modules and mapping with initialization
if support_lora_modules:
self.support_lora_modules: List[str] = (
[support_lora_modules] if isinstance(
Collaborator:

We don't need the isinstance part any more now that only lists are supported, right?
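
A sketch of the renamed and simplified override, assuming only List[str] is accepted (illustrative, not the merged code):

# With only lists accepted, the isinstance branch goes away entirely.
if supported_lora_modules:
    self.supported_lora_modules: List[str] = list(supported_lora_modules)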

lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not getattr(model, "supports_lora", False):
if not getattr(model, "supported_lora_modules", False):
Collaborator:

Nit: This is slightly odd; if not hasattr(model, "supported_lora_modules") would be more natural here, since supported_lora_modules is not of type bool :)
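
In other words, something along these lines (a sketch; the error handling shown is illustrative):

# hasattr reads more naturally than getattr(..., False) here, since
# supported_lora_modules is a list of module names rather than a bool.
if not hasattr(model, "supported_lora_modules"):
    raise ValueError("Model does not support LoRA.")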

@@ -195,11 +198,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
supported_lora_modules: Optional[Union[str, List[str]]] = None,
Collaborator:

supported_lora_modules should be removed, right?

@@ -66,7 +66,7 @@ def get_model(model_config: ModelConfig,
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device(device_config.device):
if getattr(model_class, "supports_lora", False):
if getattr(model_class, "supported_lora_modules", False):
Collaborator:

Same comment as above regarding hasattr.

assert hasattr(
self.model, "supported_lora_modules"
) and self.model.supported_lora_modules, "Model does not support LoRA"
assert hasattr(self.model, "embedding_modules") and hasattr(
Collaborator:

Nit: Can you split this into two asserts? That will make the error message a little clearer if one of them is missing (for debugging)
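
A sketch of the suggested split (assuming the truncated second assert covers embedding_padding_modules):

assert hasattr(self.model, "supported_lora_modules") and \
    self.model.supported_lora_modules, "Model does not support LoRA"
# Splitting the combined hasattr check yields a clearer error message
# when only one of the two attributes is missing.
assert hasattr(self.model, "embedding_modules"), \
    "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"), \
    "Model does not have embedding_padding_modules"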

pcmoritz (Collaborator) left a review comment:

The code looks great to me now, thanks for doing this :)

pcmoritz (Collaborator):

I also did some manual testing on this PR: merge in #2775, then run

export LORA_PATH=/home/ray/mixtral_lora_checkpoint
python -m vllm.entrypoints.openai.api_server --model mistralai/Mixtral-8x7B-Instruct-v0.1 --enable-lora --lora-modules mixtral_lora=$LORA_PATH --tensor-parallel-size 4

and query the server on a workload I care about. It is working well!
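
Querying the registered adapter then goes through the OpenAI-compatible endpoint, roughly like this sketch (host, port, and prompt are placeholders):

import requests

# The --lora-modules flag above registers the adapter under the name
# "mixtral_lora", which can then be requested as the model.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "mixtral_lora",
        "prompt": "Summarize the benefits of LoRA fine-tuning.",
        "max_tokens": 64,
        "temperature": 0,
    },
)
print(response.json()["choices"][0]["text"])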

Yard1 (Collaborator) left a review comment:

LGTM, thanks!

Yard1 merged commit 2a543d6 into vllm-project:main on Feb 13, 2024. 19 checks passed.
jvmncs pushed a commit to jvmncs/vllm that referenced this pull request Feb 14, 2024
* add mixtral lora support

* formatting

* fix incorrectly ported logic

* polish tests

* minor fixes and refactoring

* minor fixes

* formatting

* rename and remove redundant logic

* refactoring

* refactoring

* minor fix

* minor refactoring

* fix code smell
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 20, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 22, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024
sfc-gh-ybsat:

@tterrysun @Yard1 It seems the Mixtral implementation does not support the expert linear layers w1, w2, and w3.
What would it take to add such support?
I tried naively adding them to supported_lora_modules in model_executor/models/mixtral.py, but that obviously didn't work.
Would we need to make changes to the Punica kernels for this to work, or what logic would need to be updated?
Thanks in advance.

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024