From 6a4744000e542e5c0a4f2d539e3896cfa73e7bfa Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Fri, 11 Oct 2024 09:20:47 +0000 Subject: [PATCH 1/9] support qwenvl2 --- xinference/model/llm/transformers/qwen2_vl.py | 28 ------------- xinference/model/llm/utils.py | 30 ++++++++++++++ xinference/model/llm/vllm/core.py | 41 ++++++++++++++++--- 3 files changed, 65 insertions(+), 34 deletions(-) diff --git a/xinference/model/llm/transformers/qwen2_vl.py b/xinference/model/llm/transformers/qwen2_vl.py index 3eccc0c736..4a371886db 100644 --- a/xinference/model/llm/transformers/qwen2_vl.py +++ b/xinference/model/llm/transformers/qwen2_vl.py @@ -75,34 +75,6 @@ def load(self): self.model_path, device_map=device, trust_remote_code=True ).eval() - def _transform_messages( - self, - messages: List[ChatCompletionMessage], - ): - transformed_messages = [] - for msg in messages: - new_content = [] - role = msg["role"] - content = msg["content"] - if isinstance(content, str): - new_content.append({"type": "text", "text": content}) - elif isinstance(content, List): - for item in content: # type: ignore - if "text" in item: - new_content.append({"type": "text", "text": item["text"]}) - elif "image_url" in item: - new_content.append( - {"type": "image", "image": item["image_url"]["url"]} - ) - elif "video_url" in item: - new_content.append( - {"type": "video", "video": item["video_url"]["url"]} - ) - new_message = {"role": role, "content": new_content} - transformed_messages.append(new_message) - - return transformed_messages - def chat( self, messages: List[ChatCompletionMessage], # type: ignore diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index a70341fff5..2a30f93f13 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -29,6 +29,7 @@ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, + ChatCompletionMessage, Completion, CompletionChoice, CompletionChunk, @@ -488,6 +489,35 @@ def _tool_calls_completion(cls, model_family, model_uid, c): "usage": usage, } + @classmethod + def _transform_messages( + self, + messages: List[ChatCompletionMessage], + ): + transformed_messages = [] + for msg in messages: + new_content = [] + role = msg["role"] + content = msg["content"] + if isinstance(content, str): + new_content.append({"type": "text", "text": content}) + elif isinstance(content, List): + for item in content: # type: ignore + if "text" in item: + new_content.append({"type": "text", "text": item["text"]}) + elif "image_url" in item: + new_content.append( + {"type": "image", "image": item["image_url"]["url"]} + ) + elif "video_url" in item: + new_content.append( + {"type": "video", "video": item["video_url"]["url"]} + ) + new_message = {"role": role, "content": new_content} + transformed_messages.append(new_message) + + return transformed_messages + def get_file_location( llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index e4fe3fd05a..43215f5b38 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -174,6 +174,7 @@ class VLLMGenerateConfig(TypedDict, total=False): if VLLM_INSTALLED and vllm.__version__ >= "0.6.1": VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2") + VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct") class VLLMModel(LLM): @@ -309,11 +310,14 @@ def _sanitize_model_config( model_config.setdefault("max_num_seqs", 256) model_config.setdefault("quantization", None) 
model_config.setdefault("max_model_len", None) - model_config["limit_mm_per_prompt"] = ( - json.loads(model_config.get("limit_mm_per_prompt")) # type: ignore - if model_config.get("limit_mm_per_prompt") - else None - ) + if vllm.__version__ >= "0.6.1": + model_config["limit_mm_per_prompt"] = ( + json.loads(model_config.get("limit_mm_per_prompt")) # type: ignore + if model_config.get("limit_mm_per_prompt") + else { + "image": 2, # default 2 images all chat + } + ) return model_config @@ -733,6 +737,18 @@ def match( return False return VLLM_INSTALLED + def load(self): + super().load() + + self._processor = None + model_family = self.model_family.model_family or self.model_family.model_name + if "qwen2-vl" in model_family.lower(): + from transformers import AutoProcessor + + self._processor = AutoProcessor.from_pretrained( + self.model_path, trust_remote_code=True + ) + def _sanitize_chat_config( self, generate_config: Optional[Dict] = None, @@ -759,8 +775,21 @@ async def async_chat( generate_config: Optional[Dict] = None, request_id: Optional[str] = None, ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]: + messages = self._transform_messages(messages) + model_family = self.model_family.model_family or self.model_family.model_name - prompt, images = self.get_specific_prompt(model_family, messages) + + if "qwen2-vl" in model_family.lower(): + from qwen_vl_utils import process_vision_info + + prompt = self._processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + images, video_inputs = process_vision_info(messages) + if video_inputs: + raise ValueError("Not support video input now.") + else: + prompt, images = self.get_specific_prompt(model_family, messages) if len(images) == 0: inputs = { From 78f8ccf69712458ad211aaed71b5df7ad01b4214 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 03:36:47 +0000 Subject: [PATCH 2/9] add vision model cache clean --- .../model/llm/llm_family_modelscope.json | 3 +- xinference/model/llm/transformers/cogvlm2.py | 3 +- .../model/llm/transformers/cogvlm2_video.py | 2 ++ .../model/llm/transformers/deepseek_vl.py | 2 ++ xinference/model/llm/transformers/glm4v.py | 3 +- .../model/llm/transformers/intern_vl.py | 2 ++ .../model/llm/transformers/minicpmv25.py | 2 ++ .../model/llm/transformers/minicpmv26.py | 2 ++ xinference/model/llm/transformers/omnilmm.py | 2 ++ .../model/llm/transformers/qwen2_audio.py | 11 ++++-- xinference/model/llm/transformers/qwen2_vl.py | 2 ++ xinference/model/llm/transformers/qwen_vl.py | 3 +- xinference/model/llm/transformers/utils.py | 34 ++++++++++++++++++- xinference/model/llm/transformers/yi_vl.py | 2 ++ xinference/model/llm/vllm/core.py | 3 +- xinference/model/llm/vllm/utils.py | 1 - 16 files changed, 68 insertions(+), 9 deletions(-) diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index cc856f44da..ae5d0fca2b 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -4395,7 +4395,8 @@ ], "model_ability": [ "chat", - "vision" + "vision", + "video" ], "model_description": "CogVLM2-Video achieves state-of-the-art performance on multiple video question answering tasks.", "model_specs": [ diff --git a/xinference/model/llm/transformers/cogvlm2.py b/xinference/model/llm/transformers/cogvlm2.py index f3c27454d9..27cdc23dc6 100644 --- a/xinference/model/llm/transformers/cogvlm2.py +++ b/xinference/model/llm/transformers/cogvlm2.py @@ 
-29,7 +29,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig -from .utils import get_max_src_len +from .utils import cache_clean, get_max_src_len logger = logging.getLogger(__name__) @@ -176,6 +176,7 @@ def get_query_and_history( query = content return query, image, history + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/cogvlm2_video.py b/xinference/model/llm/transformers/cogvlm2_video.py index f39119f7aa..9fa7272a8e 100644 --- a/xinference/model/llm/transformers/cogvlm2_video.py +++ b/xinference/model/llm/transformers/cogvlm2_video.py @@ -28,6 +28,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -227,6 +228,7 @@ def get_query_and_history( return query, image, video, history + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/deepseek_vl.py b/xinference/model/llm/transformers/deepseek_vl.py index cfec06b7d8..515644fec5 100644 --- a/xinference/model/llm/transformers/deepseek_vl.py +++ b/xinference/model/llm/transformers/deepseek_vl.py @@ -28,6 +28,7 @@ from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -137,6 +138,7 @@ def _fill_placeholder(_url, _index): return "".join(new_content), images return content, [] + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/glm4v.py b/xinference/model/llm/transformers/glm4v.py index c16a167688..b1109d4b04 100644 --- a/xinference/model/llm/transformers/glm4v.py +++ b/xinference/model/llm/transformers/glm4v.py @@ -26,7 +26,7 @@ from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import _decode_image, generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig -from .utils import get_max_src_len +from .utils import cache_clean, get_max_src_len logger = logging.getLogger(__name__) @@ -129,6 +129,7 @@ def _get_processed_msgs(messages: List[Dict]) -> List[Dict]: res.append({"role": role, "content": text}) return res + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/intern_vl.py b/xinference/model/llm/transformers/intern_vl.py index 242d4d27ac..8150711e00 100644 --- a/xinference/model/llm/transformers/intern_vl.py +++ b/xinference/model/llm/transformers/intern_vl.py @@ -27,6 +27,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -326,6 +327,7 @@ def load(self, **kwargs): use_fast=False, ) + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/minicpmv25.py b/xinference/model/llm/transformers/minicpmv25.py index 41b100d867..81fbc69706 100644 --- a/xinference/model/llm/transformers/minicpmv25.py +++ b/xinference/model/llm/transformers/minicpmv25.py @@ -29,6 +29,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -119,6 +120,7 @@ def _message_content_to_chat(self, content): raise RuntimeError("Only one image per message is supported") return content, [] + @cache_clean def chat( self, messages: List[Dict], diff --git 
a/xinference/model/llm/transformers/minicpmv26.py b/xinference/model/llm/transformers/minicpmv26.py index 340ac841e2..cc6ba5e7a8 100644 --- a/xinference/model/llm/transformers/minicpmv26.py +++ b/xinference/model/llm/transformers/minicpmv26.py @@ -30,6 +30,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -198,6 +199,7 @@ def _convert_to_specific_style(self, messages: List[Dict]) -> Tuple: msgs.append({"role": "user", "content": images_chat + [content]}) return msgs, video_existed + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/omnilmm.py b/xinference/model/llm/transformers/omnilmm.py index 3ddffda0a4..137ef5add1 100644 --- a/xinference/model/llm/transformers/omnilmm.py +++ b/xinference/model/llm/transformers/omnilmm.py @@ -24,6 +24,7 @@ from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, parse_messages from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -87,6 +88,7 @@ def _ensure_url(_url): return images, other_content return [], [{"type": "text", "text": content}] + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/qwen2_audio.py b/xinference/model/llm/transformers/qwen2_audio.py index 653f7217f8..603006c73c 100644 --- a/xinference/model/llm/transformers/qwen2_audio.py +++ b/xinference/model/llm/transformers/qwen2_audio.py @@ -20,10 +20,16 @@ import numpy as np from ....model.utils import select_device -from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + CompletionChunk, +) from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -89,9 +95,10 @@ def _transform_messages( return text, audios + @cache_clean def chat( self, - messages: List[Dict], + messages: List[ChatCompletionMessage], # type: ignore generate_config: Optional[PytorchGenerateConfig] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: text, audios = self._transform_messages(messages) diff --git a/xinference/model/llm/transformers/qwen2_vl.py b/xinference/model/llm/transformers/qwen2_vl.py index 4a371886db..900f261113 100644 --- a/xinference/model/llm/transformers/qwen2_vl.py +++ b/xinference/model/llm/transformers/qwen2_vl.py @@ -27,6 +27,7 @@ from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -75,6 +76,7 @@ def load(self): self.model_path, device_map=device, trust_remote_code=True ).eval() + @cache_clean def chat( self, messages: List[ChatCompletionMessage], # type: ignore diff --git a/xinference/model/llm/transformers/qwen_vl.py b/xinference/model/llm/transformers/qwen_vl.py index 0b1e5b34f6..d803af75d7 100644 --- a/xinference/model/llm/transformers/qwen_vl.py +++ b/xinference/model/llm/transformers/qwen_vl.py @@ -28,7 +28,7 @@ from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, 
PytorchGenerateConfig -from .utils import pad_prefill_tokens +from .utils import cache_clean, pad_prefill_tokens logger = logging.getLogger(__name__) @@ -137,6 +137,7 @@ def _get_prompt_and_chat_history(self, messages: List[Dict]): prompt = self._message_content_to_qwen(messages[-1]["content"]) return prompt, qwen_history + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/transformers/utils.py b/xinference/model/llm/transformers/utils.py index 36b1565c96..f3cf1a6c27 100644 --- a/xinference/model/llm/transformers/utils.py +++ b/xinference/model/llm/transformers/utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio +import functools import gc import logging import os @@ -777,3 +778,34 @@ def batch_inference_one_step( for r in req_list: r.stopped = True r.error_msg = str(e) + + +def cache_clean(fn): + @functools.wraps(fn) + async def _async_wrapper(self, *args, **kwargs): + import gc + + from ....device_utils import empty_cache + + result = await fn(self, *args, **kwargs) + + gc.collect() + empty_cache() + return result + + @functools.wraps(fn) + def _wrapper(self, *args, **kwargs): + import gc + + from ....device_utils import empty_cache + + result = fn(self, *args, **kwargs) + + gc.collect() + empty_cache() + return result + + if asyncio.iscoroutinefunction(fn): + return _async_wrapper + else: + return _wrapper diff --git a/xinference/model/llm/transformers/yi_vl.py b/xinference/model/llm/transformers/yi_vl.py index 9cfa87a536..69ce724402 100644 --- a/xinference/model/llm/transformers/yi_vl.py +++ b/xinference/model/llm/transformers/yi_vl.py @@ -29,6 +29,7 @@ parse_messages, ) from .core import PytorchChatModel, PytorchGenerateConfig +from .utils import cache_clean logger = logging.getLogger(__name__) @@ -99,6 +100,7 @@ def _message_content_to_yi(content) -> Union[str, tuple]: raise RuntimeError("Only one image per message is supported by Yi VL.") return content + @cache_clean def chat( self, messages: List[Dict], diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 43215f5b38..3eb91f59ec 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -34,6 +34,7 @@ from ....types import ( ChatCompletion, ChatCompletionChunk, + ChatCompletionMessage, Completion, CompletionChoice, CompletionChunk, @@ -771,7 +772,7 @@ def _sanitize_chat_config( @vllm_check async def async_chat( self, - messages: List[Dict], + messages: List[ChatCompletionMessage], # type: ignore generate_config: Optional[Dict] = None, request_id: Optional[str] = None, ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]: diff --git a/xinference/model/llm/vllm/utils.py b/xinference/model/llm/vllm/utils.py index 6984f87730..97af5ba580 100644 --- a/xinference/model/llm/vllm/utils.py +++ b/xinference/model/llm/vllm/utils.py @@ -26,7 +26,6 @@ def vllm_check(fn): @functools.wraps(fn) async def _async_wrapper(self, *args, **kwargs): - logger.info("vllm_check") try: return await fn(self, *args, **kwargs) except AsyncEngineDeadError: From 659b38c222676165d3c6581c8d4dc4a9e504f332 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 03:43:50 +0000 Subject: [PATCH 3/9] fix error --- xinference/model/llm/transformers/qwen2_audio.py | 9 ++------- xinference/model/llm/utils.py | 2 +- 2 files changed, 3 insertions(+), 8 
deletions(-) diff --git a/xinference/model/llm/transformers/qwen2_audio.py b/xinference/model/llm/transformers/qwen2_audio.py index 603006c73c..fe46abdd7d 100644 --- a/xinference/model/llm/transformers/qwen2_audio.py +++ b/xinference/model/llm/transformers/qwen2_audio.py @@ -20,12 +20,7 @@ import numpy as np from ....model.utils import select_device -from ....types import ( - ChatCompletion, - ChatCompletionChunk, - ChatCompletionMessage, - CompletionChunk, -) +from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig @@ -98,7 +93,7 @@ def _transform_messages( @cache_clean def chat( self, - messages: List[ChatCompletionMessage], # type: ignore + messages: List[Dict], generate_config: Optional[PytorchGenerateConfig] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: text, audios = self._transform_messages(messages) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 2a30f93f13..0825658a93 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -119,7 +119,7 @@ def get_full_context( return self._build_from_raw_template(messages, chat_template, **kwargs) @staticmethod - def get_specific_prompt(model_family: str, messages: List[Dict]): + def get_specific_prompt(model_family: str, messages: List[ChatCompletionMessage]): """ Inspired by FastChat. Format chat history into a prompt according to the prompty style of different models. From ca6413081ef2fff1f375739e0ea716e09e2967fb Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 04:05:15 +0000 Subject: [PATCH 4/9] fix error --- xinference/model/llm/transformers/qwen2_audio.py | 13 +++++++++---- xinference/model/llm/utils.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/xinference/model/llm/transformers/qwen2_audio.py b/xinference/model/llm/transformers/qwen2_audio.py index fe46abdd7d..e5ea0da981 100644 --- a/xinference/model/llm/transformers/qwen2_audio.py +++ b/xinference/model/llm/transformers/qwen2_audio.py @@ -14,13 +14,18 @@ import logging import uuid from io import BytesIO -from typing import Dict, Iterator, List, Optional, Union +from typing import Iterator, List, Optional, Union from urllib.request import urlopen import numpy as np from ....model.utils import select_device -from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + CompletionChunk, +) from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import generate_chat_completion, generate_completion_chunk from .core import PytorchChatModel, PytorchGenerateConfig @@ -69,7 +74,7 @@ def load(self): def _transform_messages( self, - messages: List[Dict], + messages: List[ChatCompletionMessage], ): import librosa @@ -93,7 +98,7 @@ def _transform_messages( @cache_clean def chat( self, - messages: List[Dict], + messages: List[ChatCompletionMessage], generate_config: Optional[PytorchGenerateConfig] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: text, audios = self._transform_messages(messages) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 0825658a93..98c9f666e3 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -135,7 +135,7 @@ def get_specific_prompt(model_family: str, 
messages: List[ChatCompletionMessage] ret = ( "" if system_prompt == "" - else "<|im_start|>system\n" + else "<|im_start|>system\n" # type: ignore + system_prompt + intra_message_sep + "\n" From dae0e200590953dc36cd1f5574bc550cb3997e57 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 04:07:32 +0000 Subject: [PATCH 5/9] fix error --- xinference/model/llm/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 98c9f666e3..0db9afe74b 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -489,7 +489,6 @@ def _tool_calls_completion(cls, model_family, model_uid, c): "usage": usage, } - @classmethod def _transform_messages( self, messages: List[ChatCompletionMessage], From 0db5574f952903696ea90af75e2517049b2e80a3 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 04:17:05 +0000 Subject: [PATCH 6/9] reset --- xinference/model/llm/llm_family_modelscope.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index ae5d0fca2b..cc856f44da 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -4395,8 +4395,7 @@ ], "model_ability": [ "chat", - "vision", - "video" + "vision" ], "model_description": "CogVLM2-Video achieves state-of-the-art performance on multiple video question answering tasks.", "model_specs": [ From 42576332f41ca02449f6df3b3a78f890fdfc8ae9 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 08:15:21 +0000 Subject: [PATCH 7/9] fix vllm load error --- xinference/model/llm/llm_family.json | 21 ++++---- .../model/llm/llm_family_modelscope.json | 17 +++--- xinference/model/llm/vllm/core.py | 52 ++++++++++++------- 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 644bf0e714..8f442c1837 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -6909,18 +6909,15 @@ "model_id":"Qwen/Qwen2-VL-72B-Instruct-GPTQ-{quantization}" } ], - "prompt_style":{ - "style_name":"QWEN", - "system_prompt":"You are a helpful assistant", - "roles":[ - "user", - "assistant" - ], - "stop": [ - "<|im_end|>", - "<|endoftext|>" - ] - } + "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", + "stop_token_ids": [ + 151645, + 151643 + ], + 
"stop": [ + "<|im_end|>", + "<|endoftext|>" + ] }, { "version": 1, diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index cc856f44da..d2bcf446e7 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -4627,14 +4627,15 @@ "model_hub": "modelscope" } ], - "prompt_style": { - "style_name": "QWEN", - "system_prompt": "You are a helpful assistant", - "roles": [ - "user", - "assistant" - ] - } + "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", + "stop_token_ids": [ + 151645, + 151643 + ], + "stop": [ + "<|im_end|>", + "<|endoftext|>" + ] }, { "version": 1, diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 3eb91f59ec..5b5233fc9a 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -311,14 +311,6 @@ def _sanitize_model_config( model_config.setdefault("max_num_seqs", 256) model_config.setdefault("quantization", None) model_config.setdefault("max_model_len", None) - if vllm.__version__ >= "0.6.1": - model_config["limit_mm_per_prompt"] = ( - json.loads(model_config.get("limit_mm_per_prompt")) # type: ignore - if model_config.get("limit_mm_per_prompt") - else { - "image": 2, # default 2 images all chat - } - ) return model_config @@ -738,17 +730,32 @@ def match( return False return VLLM_INSTALLED - def load(self): - super().load() + def _sanitize_model_config( + self, model_config: Optional[VLLMModelConfig] + ) -> VLLMModelConfig: + if model_config is None: + model_config = VLLMModelConfig() - self._processor = None - model_family = self.model_family.model_family or self.model_family.model_name - if "qwen2-vl" in model_family.lower(): - from transformers import AutoProcessor + cuda_count = self._get_cuda_count() - self._processor = AutoProcessor.from_pretrained( - self.model_path, trust_remote_code=True - ) + model_config.setdefault("tokenizer_mode", "auto") + model_config.setdefault("trust_remote_code", True) + model_config.setdefault("tensor_parallel_size", cuda_count) + model_config.setdefault("block_size", 16) + model_config.setdefault("swap_space", 4) + model_config.setdefault("gpu_memory_utilization", 0.90) + model_config.setdefault("max_num_seqs", 256) + model_config.setdefault("quantization", None) + model_config.setdefault("max_model_len", None) + model_config["limit_mm_per_prompt"] = ( + json.loads(model_config.get("limit_mm_per_prompt")) # type: ignore + if model_config.get("limit_mm_per_prompt") + else { + "image": 
2, # default 2 images all chat + } + ) + + return model_config def _sanitize_chat_config( self, @@ -777,14 +784,19 @@ async def async_chat( request_id: Optional[str] = None, ) -> Union[ChatCompletion, AsyncGenerator[ChatCompletionChunk, None]]: messages = self._transform_messages(messages) + tools = generate_config.pop("tools", []) if generate_config else None model_family = self.model_family.model_family or self.model_family.model_name - if "qwen2-vl" in model_family.lower(): + if "internvl2" not in model_family.lower(): from qwen_vl_utils import process_vision_info - prompt = self._processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + full_context_kwargs = {} + if tools and model_family in QWEN_TOOL_CALL_FAMILY: + full_context_kwargs["tools"] = tools + assert self.model_family.chat_template is not None + prompt = self.get_full_context( + messages, self.model_family.chat_template, **full_context_kwargs ) images, video_inputs = process_vision_info(messages) if video_inputs: From eb3baa3a64cc80d32dcc12d512b1c6d5841e0077 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 09:07:37 +0000 Subject: [PATCH 8/9] fix vllm load error --- xinference/model/llm/vllm/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 5b5233fc9a..f77a51f0be 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -175,6 +175,8 @@ class VLLMGenerateConfig(TypedDict, total=False): if VLLM_INSTALLED and vllm.__version__ >= "0.6.1": VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2") + +if VLLM_INSTALLED and vllm.__version__ >= "0.6.3": VLLM_SUPPORTED_VISION_MODEL_LIST.append("qwen2-vl-instruct") From 9e813f823ac383768ca067649a34eebd4c36fa34 Mon Sep 17 00:00:00 2001 From: wuzhaoxin <15667065080@162.com> Date: Sat, 12 Oct 2024 09:55:51 +0000 Subject: [PATCH 9/9] fix no image chat error --- xinference/model/llm/vllm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index f77a51f0be..5ecd01f81a 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -806,7 +806,7 @@ async def async_chat( else: prompt, images = self.get_specific_prompt(model_family, messages) - if len(images) == 0: + if not images: inputs = { "prompt": prompt, }
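
Note: the core of this series is the `_transform_messages` helper that PATCH 1/9 moves from the Qwen2-VL transformers model onto the shared chat-model mixin in xinference/model/llm/utils.py, so that both the transformers and vLLM backends can normalize OpenAI-style multimodal messages before the Qwen2-VL prompt is built. A minimal standalone sketch of that normalization follows; the function name, sample message, and image URL below are illustrative only and are not part of the patch.

    # Standalone sketch of the normalization performed by _transform_messages:
    # OpenAI-style content parts ("text", "image_url", "video_url") are rewritten
    # into the "text" / "image" / "video" parts that the Qwen2-VL processor and
    # qwen_vl_utils.process_vision_info expect.
    from typing import Any, Dict, List

    def transform_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        transformed: List[Dict[str, Any]] = []
        for msg in messages:
            new_content: List[Dict[str, Any]] = []
            content = msg["content"]
            if isinstance(content, str):
                # Plain-string content becomes a single text part.
                new_content.append({"type": "text", "text": content})
            else:
                for item in content:
                    if "text" in item:
                        new_content.append({"type": "text", "text": item["text"]})
                    elif "image_url" in item:
                        new_content.append(
                            {"type": "image", "image": item["image_url"]["url"]}
                        )
                    elif "video_url" in item:
                        new_content.append(
                            {"type": "video", "video": item["video_url"]["url"]}
                        )
            transformed.append({"role": msg["role"], "content": new_content})
        return transformed

    if __name__ == "__main__":
        # Hypothetical request carrying one image part and one text part.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                    {"type": "text", "text": "What is in this picture?"},
                ],
            }
        ]
        print(transform_messages(messages))

The normalized parts line up with the placeholders in the new chat template added in PATCH 7/9 (<|vision_start|><|image_pad|><|vision_end|> for image parts, <|vision_start|><|video_pad|><|vision_end|> for video parts), which is why the vLLM path can render the prompt with the shared chat template instead of a model-specific prompt builder.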