Commit 7b4ad89

update client
Jintao-Huang committed Oct 29, 2024
1 parent f2665df commit 7b4ad89
Showing 1 changed file with 96 additions and 24 deletions.
120 changes: 96 additions & 24 deletions swift/llm/infer/infer_engine/infer_client.py
@@ -1,45 +1,117 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-from typing import AsyncIterator, Iterator, List, Optional, Union
+import asyncio
+from copy import deepcopy
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

+import aiohttp
+import json
+from dacite import from_dict
+from requests.exceptions import HTTPError

 from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig
-from .base import BaseInferEngine
+from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, ModelList, RequestConfig
+from .infer_engine import InferEngine


-class InferClient(BaseInferEngine):
+class InferClient(InferEngine):

     def __init__(self,
-                 model: str,
                  host: str = '127.0.0.1',
                  port: str = '8000',
                  api_key: str = 'EMPTY',
                  *,
-                 url: Optional[str] = None) -> None:
-        self.model = model
+                 timeout: Optional[int] = None) -> None:
         self.api_key = api_key
         self.host = host
         self.port = port
+        self.timeout = timeout
+        self.models = []
+        for model in self.get_model_list().data:
+            self.models.append(model.id)
+        assert len(self.models) > 0, f'self.models: {self.models}'

+    def get_model_list(self, *, url: Optional[str] = None) -> ModelList:
+        return asyncio.run(self.get_model_list_async(url=url))

+    def _get_request_kwargs(self) -> Dict[str, Any]:
+        request_kwargs = {}
+        if isinstance(self.timeout, int) and self.timeout >= 0:
+            request_kwargs['timeout'] = self.timeout
+        if self.api_key is not None:
+            request_kwargs['headers'] = {'Authorization': f'Bearer {self.api_key}'}
+        return request_kwargs

-        if url is not None:
-            url = f'http://{host}:{port}/v1/chat/completions'
-        self.url = url
+    async def get_model_list_async(self, *, url: Optional[str] = None) -> ModelList:
+        if url is None:
+            url = f'http://{self.host}:{self.port}/v1/models'
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url, **self._get_request_kwargs()) as resp:
+                resp_obj = await resp.json()
+        return from_dict(ModelList, resp_obj)

+    def infer(
+        self,
+        infer_requests: List[InferRequest],
+        request_config: Optional[RequestConfig] = None,
+        metrics: Optional[List[Metric]] = None,
+        *,
+        model: Optional[str] = None,
+        url: Optional[str] = None,
+        use_tqdm: Optional[bool] = None
+    ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+        return super().infer(infer_requests, request_config, metrics, model=model, url=url, use_tqdm=use_tqdm)

     @staticmethod
-    def get_model_list():
+    def _prepare_request_data(model: str, infer_request: InferRequest, request_config: RequestConfig) -> Dict[str, Any]:
         pass

     @staticmethod
-    async def get_model_list_async():
-        pass
+    def _parse_stream_data(data: bytes) -> Optional[str]:
+        data = data.decode(encoding='utf-8')
+        data = data.strip()
+        if len(data) == 0:
+            return
+        assert data.startswith('data:'), f'data: {data}'
+        return data[5:].strip()

-    def infer(self,
-              infer_requests: List[InferRequest],
-              request_config: Optional[RequestConfig] = None,
-              metrics: Optional[List[Metric]] = None,
-              *,
-              use_tqdm: Optional[bool] = None,
-              **kwargs) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
-        pass
+    async def infer_async(
+        self,
+        infer_request: InferRequest,
+        request_config: Optional[RequestConfig] = None,
+        *,
+        model: Optional[str] = None,
+        url: Optional[str] = None) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
+        request_config = deepcopy(request_config or RequestConfig())
+        if model is None:
+            if len(self.models) == 1:
+                model = self.models[0]
+            else:
+                raise ValueError(f'Please explicitly specify the model. Available models: {self.models}.')
+        if url is None:
+            url = f'http://{self.host}:{self.port}/v1/chat/completions'

-    async def infer_async(self, *args,
-                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
-        pass
+        request_data = self._prepare_request_data(model, infer_request, request_config)
+        if request_config.stream:

+            async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]:
+                async with aiohttp.ClientSession() as session:
+                    async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
+                        async for data in resp.content:
+                            data = self._parse_stream_data(data)
+                            if data == '[DONE]':
+                                break
+                            if data is not None:
+                                resp_obj = json.loads(data)
+                                if resp_obj['object'] == 'error':
+                                    raise HTTPError(resp_obj['message'])
+                                yield from_dict(ChatCompletionStreamResponse, resp_obj)

+            return _gen_stream()
+        else:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
+                    resp_obj = await resp.json()
+            if resp_obj['object'] == 'error':
+                raise HTTPError(resp_obj['message'])
+            return from_dict(ChatCompletionResponse, resp_obj)
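
A minimal usage sketch of the updated client follows, assuming an OpenAI-compatible ms-swift server is already serving a single model on 127.0.0.1:8000 (e.g. started via `swift deploy`); the top-level import path, the model id, and the message and sampling values are illustrative assumptions rather than part of this commit, and the response field names follow the OpenAI-style protocol classes imported from swift/llm/infer/protocol.

# Hypothetical usage sketch; names below are assumptions, not part of this commit.
from swift.llm import InferClient, InferRequest, RequestConfig  # assumed export path

# The constructor calls GET /v1/models and stores the discovered ids in `client.models`.
client = InferClient(host='127.0.0.1', port='8000', api_key='EMPTY')
print(client.models)  # e.g. ['qwen2-7b-instruct'] (hypothetical model id)

infer_request = InferRequest(messages=[{'role': 'user', 'content': 'Who are you?'}])

# Non-streaming: each request is POSTed to /v1/chat/completions and parsed into a ChatCompletionResponse.
resp_list = client.infer([infer_request], RequestConfig(max_tokens=128, temperature=0.))
print(resp_list[0].choices[0].message.content)

# Streaming: with stream=True, infer yields lists of ChatCompletionStreamResponse chunks (one slot per request).
for chunk_list in client.infer([infer_request], RequestConfig(max_tokens=128, stream=True)):
    chunk = chunk_list[0]
    if chunk is not None:
        print(chunk.choices[0].delta.content, end='', flush=True)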
