Add LoRA support for Mixtral #2831
Conversation
vllm/worker/worker.py
Outdated
@@ -94,6 +95,24 @@ def init_model(self) -> None:
        # Initialize the model.
        set_random_seed(self.model_config.seed)

        self.model = get_model(self.model_config, self.device_config,
This block of code is in https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py#L84, it shouldn't be here.
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
Can you say why this is different from the code in mistral.py? There is the minor cosmetic difference of unpadded_vocab_size vs. self.unpadded_vocab_size that we should fix, but also the larger difference of the padding_size in the ParallelLMHead, as well as the different parameters of the Sampler. Is there a reason why the code is not the same? :)
Same comment comparing with llama.py.
should be the same, fixed
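For context on the question above: the reason llama.py tracks an unpadded vocab size separately is that the embedding and LM-head weights are padded to an alignment boundary for the kernels (a larger one when LoRA is enabled), while sampling must still be restricted to the real tokens. A minimal, self-contained sketch of that padding arithmetic (the alignment values 64 and 256 and the example numbers are assumptions, not taken from the actual file):

```python
def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    """Round vocab_size up to the next multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


base_vocab = 32000          # e.g. a Mistral/Mixtral tokenizer size
extra_lora_tokens = 16      # extra tokens an adapter may add (illustrative)

# The sampler needs the real (unpadded) vocab size, including adapter tokens.
unpadded_vocab_size = base_vocab + extra_lora_tokens
# The LM-head weight is padded; LoRA kernels assume a coarser alignment.
padded_no_lora = pad_vocab_size(base_vocab, 64)
padded_with_lora = pad_vocab_size(unpadded_vocab_size, 256)

print(unpadded_vocab_size, padded_no_lora, padded_with_lora)
# 32016 32000 32256
```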
vllm/model_executor/models/llama.py
Outdated
@@ -270,6 +270,30 @@ def forward(

class LlamaForCausalLM(nn.Module):
    supports_lora = True
    lora_target_modules = [
The biggest question I have about this PR is these configurations; I don't think they should be here. There might be different ways to run a model (e.g. different layers get LoRAified: for the Mixtral model, do you apply LoRA to the MoE weight matrices or not?), and it also seems odd to have the LoRA configurations in the model itself. Are there better places where this could fit? Maybe @Yard1 has some ideas here as well :)
The other issue of the override code in vllm/lora/models.py being a little too complex will flow naturally from what we do here.
Here are my thoughts so far: lora_target_modules, I believe, is in the adapter_config.json of the LoRA model file, so we can probably use the information from there; it will just need to be adapted a little via the packed_module_mapping. So it seems we can remove the LoRA-specific stuff, which is already a big relief.
The other parameters are also useful for loading the model etc., so maybe we should keep them here? It would be worth investigating whether we can reuse this in the model loading function for the base model (load_weights).
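To make the idea concrete, a small sketch of reading an adapter's target_modules and translating them through a packed-module mapping (the mapping dict and file path here are illustrative, not the actual vLLM data structures):

```python
import json

# Hypothetical packed-module mapping: separate q/k/v and gate/up projections
# are fused into single modules inside the model.
packed_module_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

# adapter_config.json is part of the (PEFT-style) LoRA checkpoint.
with open("adapter_config.json") as f:
    adapter_config = json.load(f)

# Translate each adapter target module onto the model's (possibly packed) name.
adapter_modules = {
    packed_module_mapping.get(name, name)
    for name in adapter_config["target_modules"]
}
print(sorted(adapter_modules))
```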
This defines all the layers we support for LoRA on the given model. The adapter can use a subset of those layers. We should not be using adapter_config.json directly because:
- different adapters will have different layers, therefore we need a common superset
- that superset has to be constant
- we need to implement support for each layer type, so if an adapter specifies a layer we do not support, we need to throw exceptions.
Given the above, defining those variables as attributes of the model class seems to be the best option.
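A minimal sketch of that scheme (the class name, module names, and error message are illustrative): the model class declares a constant superset of LoRAifiable modules, and each adapter is validated against it.

```python
from typing import List, Set


class MixtralForCausalLMSketch:
    # Constant superset of modules this model supports applying LoRA to
    # (illustrative; note the MoE expert weights w1/w2/w3 are absent).
    supported_lora_modules: List[str] = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head",
    ]


def validate_adapter(adapter_modules: List[str], supported: List[str]) -> None:
    """Raise if the adapter targets a module the model cannot LoRAify."""
    unsupported: Set[str] = set(adapter_modules) - set(supported)
    if unsupported:
        raise ValueError(
            f"Adapter targets unsupported modules: {sorted(unsupported)}")


# An adapter may target a subset of the supported modules...
validate_adapter(["qkv_proj", "o_proj"],
                 MixtralForCausalLMSketch.supported_lora_modules)
# ...but targeting e.g. an expert weight would raise:
# validate_adapter(["w1"], MixtralForCausalLMSketch.supported_lora_modules)
```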
Sounds great. I believe my confusion will go away if we rename lora_target_modules to supported_lora_modules, so it is clearer that this is the superset of modules that we support and not the actual ones that are loaded. Let's rename it, and in parallel I'll make a PR that shows how it would look if we use packed_module_mapping in the load_weights function, so the information lives in only one place.
Let's also reorder things so the generic attributes come before the LoRA-specific ones (and add a comment marking the LoRA-specific ones as such).
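A sketch of the layout being asked for (the attribute contents below are illustrative and borrow names used elsewhere in this thread; they are not a claim about the final code):

```python
from torch import nn


class LlamaForCausalLMSketch(nn.Module):
    # Generic attributes, useful for model loading in general.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA-specific attributes: the superset of modules that *can* be
    # LoRAified, not the modules of any particular adapter.
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "embed_tokens", "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
```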
I put up a PR in #2843 :)
tests/lora/test_lora_manager.py
Outdated
@@ -120,7 +129,7 @@ def test_lora_model_manager(dist_init, dummy_model):
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
-       lora_target_modules=["dense1", "dense2", "lm_head"])
+       support_lora_modules=["dense1", "dense2", "lm_head"])
Can you rename this to supported_lora_modules? The idea here is that this is the list of modules that can be LoRAified (i.e. the list of modules for which the model supports applying LoRA adapters).
vllm/lora/models.py
Outdated
        # allow overriding the target modules and mapping with initialization
        if support_lora_modules:
            self.support_lora_modules: List[str] = (
                [support_lora_modules] if isinstance(
We don't need the isinstance part any more now that only lists are supported, right?
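A small sketch of the simplification being suggested (the function name and types are illustrative):

```python
from typing import List, Optional


def resolve_supported_modules(
        supported_lora_modules: Optional[List[str]]) -> List[str]:
    # Previously the override also accepted a bare string and wrapped it:
    #   [modules] if isinstance(modules, str) else modules
    # With only lists supported, a plain copy (or direct assignment) suffices.
    return list(supported_lora_modules) if supported_lora_modules else []
```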
vllm/lora/models.py
Outdated
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
-   if not getattr(model, "supports_lora", False):
+   if not getattr(model, "supported_lora_modules", False):
Nit: This is slightly odd; if not hasattr(model, "supported_lora_modules") would be more natural here, since supported_lora_modules is not of type bool :)
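For illustration, the shape of the suggested check (the error message and surrounding function body are hypothetical):

```python
def create_lora_manager(model, **kwargs):
    # Check for the attribute's presence rather than relying on the
    # truthiness of a list returned by getattr.
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"{type(model).__name__} does not support LoRA.")
    ...
```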
vllm/lora/worker_manager.py
Outdated
@@ -195,11 +198,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
-       target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
+       supported_lora_modules: Optional[Union[str, List[str]]] = None,
supported_lora_modules should be removed, right?
vllm/model_executor/model_loader.py
Outdated
@@ -66,7 +66,7 @@ def get_model(model_config: ModelConfig,
    # Create a model instance.
    # The weights will be initialized as empty tensors.
    with torch.device(device_config.device):
-       if getattr(model_class, "supports_lora", False):
+       if getattr(model_class, "supported_lora_modules", False):
same comment as above with hasattr
vllm/worker/model_runner.py
Outdated
        assert hasattr(
            self.model, "supported_lora_modules"
        ) and self.model.supported_lora_modules, "Model does not support LoRA"
        assert hasattr(self.model, "embedding_modules") and hasattr(
Nit: Can you split this into two asserts? That will make the error message a little clearer if one of them is missing (for debugging)
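A sketch of what the split might look like (the second attribute name, embedding_padding_modules, is an assumption since the diff above is truncated; separate asserts make it obvious which attribute is missing):

```python
def check_lora_support(model) -> None:
    # One assert per requirement, so a failure names the missing attribute.
    assert getattr(model, "supported_lora_modules", None), \
        "Model does not support LoRA"
    assert hasattr(model, "embedding_modules"), \
        "Model is missing embedding_modules"
    assert hasattr(model, "embedding_padding_modules"), \
        "Model is missing embedding_padding_modules"
```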
The code looks great to me now, thanks for doing this :)
I also did some manual testing on this PR: Merge in #2775, and then run
and query the server on some workload I care about, and it is working well!
LGTM, thanks!
* add mixtral lora support
* formatting
* fix incorrectly ported logic
* polish tests
* minor fixes and refactoring
* minor fixes
* formatting
* rename and remove redundant logic
* refactoring
* refactoring
* minor fix
* minor refactoring
* fix code smell
@tterrysun @Yard1 seems like the Mixtral implementation does not support the expert linear layers: w1, w2, w3.
Problem: We don't have LoRA support for Mixtral.
Solution: Add LoRA configurations for Mixtral and refactor relevant parts to allow this.
Testing: added correctness tests and updated existing tests.
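For reference, a minimal usage sketch of running a Mixtral LoRA adapter offline with vLLM (the model name and adapter path are placeholders, and the LoRARequest call assumes the API shape introduced with the original multi-LoRA support):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Load the base model with LoRA support enabled.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", enable_lora=True)

# Generate with a specific adapter applied to the supported modules.
outputs = llm.generate(
    ["Write a haiku about GPUs."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("my-mixtral-adapter", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)
```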