Commit 7b4ad89
1 parent: f2665df
Showing 1 changed file with 96 additions and 24 deletions.
@@ -1,45 +1,117 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import AsyncIterator, Iterator, List, Optional, Union
+import asyncio
+from copy import deepcopy
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

+import aiohttp
 import json
 from dacite import from_dict
 from requests.exceptions import HTTPError

 from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig
-from .base import BaseInferEngine
+from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, ModelList, RequestConfig
+from .infer_engine import InferEngine


-class InferClient(BaseInferEngine):
+class InferClient(InferEngine):

     def __init__(self,
-                 model: str,
                  host: str = '127.0.0.1',
                  port: str = '8000',
                  api_key: str = 'EMPTY',
                  *,
-                 url: Optional[str] = None) -> None:
-        self.model = model
+                 timeout: Optional[int] = None) -> None:
         self.api_key = api_key
         self.host = host
         self.port = port
+        self.timeout = timeout
+        self.models = []
+        for model in self.get_model_list().data:
+            self.models.append(model.id)
+        assert len(self.models) > 0, f'self.models: {self.models}'

+    def get_model_list(self, *, url: Optional[str] = None) -> ModelList:
+        return asyncio.run(self.get_model_list_async(url=url))
+
+    def _get_request_kwargs(self) -> Dict[str, Any]:
+        request_kwargs = {}
+        if isinstance(self.timeout, int) and self.timeout >= 0:
+            request_kwargs['timeout'] = self.timeout
+        if self.api_key is not None:
+            request_kwargs['headers'] = {'Authorization': f'Bearer {self.api_key}'}
+        return request_kwargs

-        if url is not None:
-            url = f'http://{host}:{port}/v1/chat/completions'
-        self.url = url
+    async def get_model_list_async(self, *, url: Optional[str] = None) -> ModelList:
+        if url is None:
+            url = f'http://{self.host}:{self.port}/v1/models'
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url, **self._get_request_kwargs()) as resp:
+                resp_obj = await resp.json()
+        return from_dict(ModelList, resp_obj)
+
+    def infer(
+        self,
+        infer_requests: List[InferRequest],
+        request_config: Optional[RequestConfig] = None,
+        metrics: Optional[List[Metric]] = None,
+        *,
+        model: Optional[str] = None,
+        url: Optional[str] = None,
+        use_tqdm: Optional[bool] = None
+    ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+        return super().infer(infer_requests, request_config, metrics, model=model, url=url, use_tqdm=use_tqdm)

     @staticmethod
-    def get_model_list():
+    def _prepare_request_data(model: str, infer_request: InferRequest, request_config: RequestConfig) -> Dict[str, Any]:
         pass

     @staticmethod
-    async def get_model_list_async():
-        pass
+    def _parse_stream_data(data: bytes) -> Optional[str]:
+        data = data.decode(encoding='utf-8')
+        data = data.strip()
+        if len(data) == 0:
+            return
+        assert data.startswith('data:'), f'data: {data}'
+        return data[5:].strip()

-    def infer(self,
-              infer_requests: List[InferRequest],
-              request_config: Optional[RequestConfig] = None,
-              metrics: Optional[List[Metric]] = None,
-              *,
-              use_tqdm: Optional[bool] = None,
-              **kwargs) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
-        pass
+    async def infer_async(
+        self,
+        infer_request: InferRequest,
+        request_config: Optional[RequestConfig] = None,
+        *,
+        model: Optional[str] = None,
+        url: Optional[str] = None) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
+        request_config = deepcopy(request_config or RequestConfig())
+        if model is None:
+            if len(self.models) == 1:
+                model = self.models[0]
+            else:
+                raise ValueError(f'Please explicitly specify the model. Available models: {self.models}.')
+        if url is None:
+            url = f'http://{self.host}:{self.port}/v1/chat/completions'

-    async def infer_async(self, *args,
-                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
-        pass
+        request_data = self._prepare_request_data(model, infer_request, request_config)
+        if request_config.stream:
+
+            async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]:
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
+                        async for data in resp.content:
+                            data = self._parse_stream_data(data)
+                            if data == '[DONE]':
+                                break
+                            if data is not None:
+                                resp_obj = json.loads(data)
+                                if resp_obj['object'] == 'error':
+                                    raise HTTPError(resp_obj['message'])
+                                yield from_dict(ChatCompletionStreamResponse, resp_obj)
+
+            return _gen_stream()
+        else:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
+                    resp_obj = await resp.json()
+            if resp_obj['object'] == 'error':
+                raise HTTPError(resp_obj['message'])
+            return from_dict(ChatCompletionResponse, resp_obj)
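Taken as a whole, the commit turns the stubbed client into a working OpenAI-compatible HTTP client built on aiohttp: the constructor discovers served models via GET /v1/models and records their ids in self.models, while infer_async posts to /v1/chat/completions and returns either a single ChatCompletionResponse or, when request_config.stream is set, an async generator of ChatCompletionStreamResponse chunks parsed from the SSE "data:" lines. The sketch below shows how this interface would be called. The import path swift.llm, the messages field of InferRequest, and the choices/delta fields of the response objects are assumptions inferred from the identifiers in the diff, not details the commit itself shows; note also that _prepare_request_data is still a stub in this revision, so the sketch reflects the intended API rather than a request that would succeed against this exact commit.

# Usage sketch for the InferClient interface introduced above (assumptions flagged inline).
import asyncio

from swift.llm import InferClient, InferRequest, RequestConfig  # assumed import path

# __init__ runs get_model_list() (GET /v1/models) and asserts at least one model,
# so a reachable deployment is required at construction time.
client = InferClient(host='127.0.0.1', port='8000', api_key='EMPTY', timeout=60)
print(client.models)  # model ids discovered from the server

# The `messages` schema is assumed; it is not shown in this diff.
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'Who are you?'}])

# Non-streaming: infer_async resolves to a single ChatCompletionResponse.
# If the server hosts more than one model, pass model='<model-id>' explicitly.
resp = asyncio.run(client.infer_async(infer_request, RequestConfig()))
print(resp.choices[0].message.content)  # OpenAI-style response fields, assumed

# Streaming: with stream=True the coroutine returns an async generator of
# ChatCompletionStreamResponse chunks parsed from the SSE 'data:' lines.
async def stream_demo() -> None:
    gen = await client.infer_async(infer_request, RequestConfig(stream=True))
    async for chunk in gen:
        print(chunk.choices[0].delta.content, end='', flush=True)

asyncio.run(stream_demo())

Wrapping get_model_list_async in asyncio.run keeps get_model_list synchronous and lets the constructor validate connectivity eagerly, at the cost that an InferClient cannot be constructed from inside an already running event loop.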