[Bug]: Model architectures ['LlavaForCausalLM'] are not supported for now in vllm 0.4.0.post1 #4008
Comments
Maybe this is a feature request; I don't know enough about the intent of supporting LLaVA.
Don't know how I missed that when writing PR #3978. Interesting... was the 7b model modified specifically to facilitate the proof-of-concept for vLLM?
Unsure. I went and reviewed #3042, and it seems llava 7b and 13b have different classes. We only support LlavaForConditionalGeneration, not LlavaForCausalLM.
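For anyone who wants to verify which class a checkpoint declares, the architectures field can be read straight from the hub config. A quick sketch (not from this thread; the printed values reflect the configs as described at the time of this issue):

```python
from transformers import AutoConfig

for repo in ("llava-hf/llava-1.5-7b-hf", "llava-hf/llava-1.5-13b-hf"):
    config = AutoConfig.from_pretrained(repo)
    # As reported in this issue: ['LlavaForConditionalGeneration'] for the 7b
    # checkpoint vs ['LlavaForCausalLM'] for the 13b checkpoint.
    print(repo, config.architectures)
```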
@xwjiang2010 do you happen to know the effort required to support LLaVA 13b? Pinging you since you worked on the initial PR for vision models.
I think it's just a typo in their HuggingFace config.json:
diff --git a/tests/conftest.py b/tests/conftest.py
index a7e8963..61ca4fe 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -131,6 +131,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
_VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
+ "llava-hf/llava-1.5-13b-hf": LlavaForConditionalGeneration,
}
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 17fc970..54a077d 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -33,6 +33,8 @@ _MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
+ "LlavaForCausalLM":
+ ("llava", "LlavaForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"), Update: I have tested with some other models:
Hmm, okay, then I can probably clone their model and just modify the config.json to use the LlavaForConditionalGeneration architecture.
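A sketch of that workaround, assuming the 13b weights are compatible with the existing LlavaForConditionalGeneration implementation (the local directory name and the omission of vision-specific engine arguments are my assumptions, not something confirmed in this thread):

```python
import json
from pathlib import Path

from huggingface_hub import snapshot_download
from vllm import LLM

# Download the checkpoint into a plain local directory so config.json can be edited.
local_dir = Path(
    snapshot_download("llava-hf/llava-1.5-13b-hf", local_dir="llava-1.5-13b-local")
)

config_path = local_dir / "config.json"
config = json.loads(config_path.read_text())
# Swap the unsupported architecture name for the one vLLM already registers.
config["architectures"] = ["LlavaForConditionalGeneration"]
config_path.write_text(json.dumps(config, indent=2))

# Depending on the vLLM version, LLaVA may also require the vision-language
# engine arguments (image input type/shape, image token id); omitted here.
llm = LLM(model=str(local_dir))
```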
I have updated my PR accordingly.
Thank you for your effort in supporting this. After checking the code in the original LLaVA repo (https://github.com/haotian-liu/LLaVA), I found that llava-1.5-13b does use a different architecture class. I don't know whether there will be a runtime bug or some performance loss after directly modifying the config.json to use the LlavaForConditionalGeneration architecture.
@Jianzhao-Huang It doesn't look like HuggingFace has that class. In any case, I have just pushed a commit to #3978 which adds the 13b model to the LLaVA test case that checks its consistency against the native HuggingFace model.
Hmm, it seems that GPU memory fails to be freed between tests of each model. Does anyone know how to fix this issue?
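Not a confirmed fix, but the usual pattern between per-model tests is to drop every reference to the previous model and then force collection. A minimal sketch (not the actual conftest.py fixture):

```python
import gc

import torch


def free_gpu_memory():
    """Drop dead references and ask PyTorch to return cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Intended usage between models in a test loop:
#   del vllm_model        # drop the last reference first
#   free_gpu_memory()
```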
I have investigated further and it seems that the CI/CD infrastructure cannot even load the 13B model into memory (I removed all other LLaVA models from the test and it still OOMed). Not sure what I can do about that... |
Any fix? |
The incorrect architecture in
Your current environment
🐛 Describe the bug
I am using vllm 0.4.0.post1 with llava-hf/llava-1.5-13b-hf:
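A minimal reproduction along these lines (the script itself is assumed; only the model name and vLLM version come from this report):

```python
from vllm import LLM

# With vllm 0.4.0.post1 this fails while resolving the model class, because the
# checkpoint's config.json declares "LlavaForCausalLM", which is not in vLLM's
# architecture registry.
llm = LLM(model="llava-hf/llava-1.5-13b-hf")
```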
This throws the following error: ValueError: Model architectures ['LlavaForCausalLM'] are not supported for now. Supported architectures: ['AquilaModel', 'AquilaForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'CohereForCausalLM', 'DbrxForCausalLM', 'DeciLMForCausalLM', 'DeepseekForCausalLM', 'FalconForCausalLM', 'GemmaForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'InternLMForCausalLM', 'InternLM2ForCausalLM', 'JAISLMHeadModel', 'LlamaForCausalLM', 'LlavaForConditionalGeneration', 'LLaMAForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'QuantMixtralForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'OLMoForCausalLM', 'OPTForCausalLM', 'OrionForCausalLM', 'PhiForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RWForCausalLM', 'StableLMEpochForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'XverseForCausalLM']
After checking HuggingFace: llava-1.5-7b-hf uses LlavaForConditionalGeneration, while llava-1.5-13b-hf uses LlavaForCausalLM. Is there any easy workaround or fix for this?