diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index acbcfb6aa..5a3a71bba 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -5,15 +5,15 @@
 
 if TYPE_CHECKING:
     # Recommend using `xxx_main`
-    from .infer import (VllmEngine, InferRequest, RequestConfig, InferStats, LmdeployEngine, PtEngine, infer_main,
-                        deploy_main, PtLoRARequest, InferClient)
+    from .infer import (VllmEngine, RequestConfig, InferStats, LmdeployEngine, PtEngine, infer_main, deploy_main,
+                        PtLoRARequest, InferClient)
     from .export import export_main, merge_lora
     from .eval import eval_main
     from .train import sft_main, pt_main, rlhf_main
     from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments,
                            RLHFArguments, WebUIArguments, AppUIArguments)
     from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
-                           TemplateInputs, Messages, TemplateMeta, get_template_meta)
+                           TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest)
     from .model import (MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory,
                         ModelInfo, ModelMeta, get_model_meta)
     from .dataset import (AlpacaPreprocessor, MessagesPreprocessor, AutoPreprocessor, DatasetName, DATASET_MAPPING,
@@ -29,8 +29,8 @@
     _import_structure = {
         'rlhf': ['rlhf_main'],
         'infer': [
-            'deploy_main', 'VllmEngine', 'InferRequest', 'RequestConfig', 'InferStats', 'LmdeployEngine', 'PtEngine',
-            'infer_main', 'PtLoRARequest', 'InferClient'
+            'deploy_main', 'VllmEngine', 'RequestConfig', 'InferStats', 'LmdeployEngine', 'PtEngine', 'infer_main',
+            'PtLoRARequest', 'InferClient'
         ],
         'export': ['export_main', 'merge_lora'],
         'eval': ['eval_main'],
@@ -41,7 +41,7 @@
         ],
         'template': [
             'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
-            'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta'
+            'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta', 'InferRequest'
         ],
         'model': [
             'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory',
diff --git a/swift/llm/infer/__init__.py b/swift/llm/infer/__init__.py
index c3e3e27eb..fb59de556 100644
--- a/swift/llm/infer/__init__.py
+++ b/swift/llm/infer/__init__.py
@@ -6,7 +6,7 @@
 if TYPE_CHECKING:
     from .infer import infer_main
     from .deploy import deploy_main
-    from .protocol import InferRequest, RequestConfig
+    from .protocol import RequestConfig
     from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferStats, PtLoRARequest,
                                InferClient)
 else:
@@ -14,9 +14,9 @@
     _import_structure = {
         'deploy': ['deploy_main'],
         'infer': ['infer_main'],
-        'protocol': ['InferRequest', 'RequestConfig'],
-        'infer_engine': ['InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferStats', 'PtLoRARequest',
-                         'InferClient'],
+        'protocol': ['RequestConfig'],
+        'infer_engine':
+        ['InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferStats', 'PtLoRARequest', 'InferClient'],
     }
 
     import sys
diff --git a/swift/llm/infer/client_utils.py b/swift/llm/infer/client_utils.py
deleted file mode 100644
index 6a5eeedb2..000000000
--- a/swift/llm/infer/client_utils.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import os
-import re
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
-
-import aiohttp
-import json
-import requests
-from dacite import from_dict
-from requests.exceptions import HTTPError
-
-from swift.llm import History, Messages
-from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionStreamResponse, CompletionResponse,
-                                      CompletionStreamResponse, ModelList, XRequestConfig)
-
-
-def _get_request_kwargs(api_key: Optional[str] = None) -> Dict[str, Any]:
-    timeout = float(os.getenv('TIMEOUT', '300'))
-    request_kwargs = {}
-    if timeout > 0:
-        request_kwargs['timeout'] = timeout
-    if api_key is not None:
-        request_kwargs['headers'] = {'Authorization': f'Bearer {api_key}'}
-    return request_kwargs
-
-
-def get_model_list_client(host: str = '127.0.0.1', port: str = '8000', api_key: str = 'EMPTY', **kwargs) -> ModelList:
-    url = kwargs.pop('url', None)
-    if url is None:
-        url = f'http://{host}:{port}/v1'
-    url = url.rstrip('/')
-    url = f'{url}/models'
-    resp_obj = requests.get(url, **_get_request_kwargs(api_key)).json()
-    return from_dict(ModelList, resp_obj)
-
-
-async def get_model_list_client_async(host: str = '127.0.0.1',
-                                      port: str = '8000',
-                                      api_key: str = 'EMPTY',
-                                      **kwargs) -> ModelList:
-    url = kwargs.pop('url', None)
-    if url is None:
-        url = f'http://{host}:{port}/v1'
-    url = url.rstrip('/')
-    url = f'{url}/models'
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url, **_get_request_kwargs(api_key)) as resp:
-            resp_obj = await resp.json()
-    return from_dict(ModelList, resp_obj)
-
-
-def _parse_stream_data(data: bytes) -> Optional[str]:
-    data = data.decode(encoding='utf-8')
-    data = data.strip()
-    if len(data) == 0:
-        return
-    assert data.startswith('data:'), f'data: {data}'
-    return data[5:].strip()
-
-
-def compat_openai(messages: Messages, request) -> None:
-    for message in messages:
-        content = message['content']
-        if isinstance(content, list):
-            text = ''
-            for line in content:
-                _type = line['type']
-                value = line[_type]
-                if _type == 'text':
-                    text += value
-                elif _type in {'image_url', 'audio_url', 'video_url'}:
-                    value = value['url']
-                    if value.startswith('data:'):
-                        match_ = re.match(r'data:(.+?);base64,(.+)', value)
-                        assert match_ is not None
-                        value = match_.group(2)
-                    if _type == 'image_url':
-                        text += '<image>'
-                        request.images.append(value)
-                    elif _type == 'audio_url':
-                        text += '<audio>'