diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index 24a91af62c7b7..66dab65804b70 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -1,10 +1,10 @@
-from os import path
 from threading import Lock
 from time import time
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.sessions import Session
+from yarl import URL
 
 
 class XinferenceModelExtraParameter:
@@ -55,7 +55,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
     get xinference model extra parameter like model_format and model_handle_type
     """
 
-    url = path.join(server_url, 'v1/models', model_uid)
+    if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+        raise RuntimeError('server_url or model_uid is empty')
+
+    url = str(URL(server_url) / 'v1' / 'models' / model_uid)
 
     # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
     session = Session()
@@ -66,7 +69,6 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
         response = session.get(url, timeout=10)
     except (MissingSchema, ConnectionError, Timeout) as e:
         raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
-
     if response.status_code != 200:
         raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
 
diff --git a/api/requirements.txt b/api/requirements.txt
index 1c3e89e78094b..6aadcf77235dd 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -67,4 +67,5 @@ pydub~=0.25.1
 gmpy2~=2.1.5
 numexpr~=2.9.0
 duckduckgo-search==4.4.3
-arxiv==2.1.0
\ No newline at end of file
+arxiv==2.1.0
+yarl~=1.9.4
diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py
index e4cc2ceea6083..bba5704d2eb72 100644
--- a/api/tests/integration_tests/model_runtime/__mock/xinference.py
+++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py
@@ -32,68 +32,70 @@ def get(self: Session, url: str, **kwargs):
         response = Response()
         if 'v1/models/' in url:
             # get model uid
-            model_uid = url.split('/')[-1]
+            model_uid = url.split('/')[-1] or ''
 
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
+                response._content = b'{}'
                 return response
 
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
+                response._content = b'{}'
                 return response
 
             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response._content = b'''{
-            "model_type": "LLM",
-            "address": "127.0.0.1:43877",
-            "accelerators": [
-                "0",
-                "1"
-            ],
-            "model_name": "chatglm3-6b",
-            "model_lang": [
-                "en"
-            ],
-            "model_ability": [
-                "generate",
-                "chat"
-            ],
-            "model_description": "latest chatglm3",
-            "model_format": "pytorch",
-            "model_size_in_billions": 7,
-            "quantization": "none",
-            "model_hub": "huggingface",
-            "revision": null,
-            "context_length": 2048,
-            "replica": 1
-            }'''
+                    "model_type": "LLM",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "chatglm3-6b",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "model_ability": [
+                        "generate",
+                        "chat"
+                    ],
+                    "model_description": "latest chatglm3",
+                    "model_format": "pytorch",
+                    "model_size_in_billions": 7,
+                    "quantization": "none",
+                    "model_hub": "huggingface",
+                    "revision": null,
+                    "context_length": 2048,
+                    "replica": 1
+                }'''
                 return response
             elif model_uid == 'embedding':
                 response.status_code = 200
                 response._content = b'''{
-    "model_type": "embedding",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "bge",
-    "model_lang": [
-        "en"
-    ],
-    "revision": null,
-    "max_tokens": 512
-}'''
+                    "model_type": "embedding",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "bge",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "revision": null,
+                    "max_tokens": 512
+                }'''
                 return response
 
         elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response._content = b'''{
-    "auth": true
-}'''
+                "auth": true
+            }'''
             return response
 
     def _check_cluster_authenticated(self):