Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: replace os.path.join with yarl #2690

Merged
Merged 4 commits into the base branch on Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os import path
from threading import Lock
from time import time

from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class XinferenceModelExtraParameter:
Expand Down Expand Up @@ -55,7 +55,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
get xinference model extra parameter like model_format and model_handle_type
"""

url = path.join(server_url, 'v1/models', model_uid)
if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
raise RuntimeError('model_uid is empty')

url = str(URL(server_url) / 'v1' / 'models' / model_uid)

# this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session()
Expand All @@ -66,7 +69,6 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')

if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')

Expand Down
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,5 @@ pydub~=0.25.1
gmpy2~=2.1.5
numexpr~=2.9.0
duckduckgo-search==4.4.3
arxiv==2.1.0
arxiv==2.1.0
yarl~=1.9.4
80 changes: 41 additions & 39 deletions api/tests/integration_tests/model_runtime/__mock/xinference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,68 +32,70 @@ def get(self: Session, url: str, **kwargs):
response = Response()
if 'v1/models/' in url:
# get model uid
model_uid = url.split('/')[-1]
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404
response._content = b'{}'
return response

# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404
response._content = b'{}'
return response

if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response

elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response

elif 'v1/cluster/auth' in url:
response.status_code = 200
response._content = b'''{
"auth": true
}'''
"auth": true
}'''
return response

def _check_cluster_authenticated(self):
Expand Down
Loading