Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Oct 29, 2024
1 parent cd0f001 commit 95b02da
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion swift/llm/infer/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def _gen_wrapper():
yield f'data: {json.dumps(asdict(res), ensure_ascii=False)}\n\n'
yield 'data: [DONE]\n\n'

return _gen_wrapper
return StreamingResponse(_gen_wrapper())
else:
return self._post_process(res_or_gen, return_cmpl_response)

Expand Down
1 change: 0 additions & 1 deletion swift/llm/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
infer_requests = []
for data in val_dataset:
infer_request = InferRequest(**data)
infer_request.remove_response()
infer_requests.append(infer_request)
resp_list = self.infer(infer_requests, request_config, template=self.template, use_tqdm=True)
for data, resp in zip(val_dataset, resp_list):
Expand Down
12 changes: 8 additions & 4 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import re
from contextlib import contextmanager
from dataclasses import asdict
from functools import partial, wraps
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

Expand All @@ -18,7 +19,7 @@
from transformers.integrations import is_deepspeed_zero3_enabled

from .agent import loss_scale_map, split_str_parts_by
from .template_inputs import StdTemplateInputs, TemplateInputs
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
from .utils import Context, ContextType, Prompt, Word, fetch_one, findall
from .vision_utils import load_batch, load_image, normalize_bbox, rescale_image

Expand Down Expand Up @@ -135,7 +136,7 @@ def _preprocess_inputs(

def encode(
self,
inputs: TemplateInputs,
inputs: Union[TemplateInputs, Dict[str, Any], StdTemplateInputs, InferRequest],
*,
model=None,
) -> Dict[str, Any]:
Expand All @@ -144,10 +145,13 @@ def encode(
Returns:
return {'input_ids': List[int], 'labels': Optional[List[int]], ...}
"""
if isinstance(inputs, InferRequest):
inputs = asdict(inputs)
elif isinstance(inputs, TemplateInputs):
inputs = asdict(inputs)

if isinstance(inputs, dict):
inputs = StdTemplateInputs.from_dict(inputs, tools_prompt=self.tools_prompt)
elif isinstance(inputs, TemplateInputs):
inputs = StdTemplateInputs.from_template_inputs(inputs, tools_prompt=self.tools_prompt)
elif isinstance(inputs, StdTemplateInputs):
inputs = inputs.copy()

Expand Down
12 changes: 4 additions & 8 deletions swift/llm/template/template_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ class InferRequest:

tools: Optional[List[Tool]] = None

def remove_response(self):
def __post_init__(self):
self._remove_response()

def _remove_response(self):
last_role = self.messages[-1]['role']
if last_role == 'assistant':
self.messages.pop()
Expand Down Expand Up @@ -153,13 +156,6 @@ def remove_messages_media(messages: Messages) -> Dict[str, Any]:
message['content'] = new_content
return res

@classmethod
def from_template_inputs(cls,
template_inputs: TemplateInputs,
*,
tools_prompt: str = 'react_en') -> 'StdTemplateInputs':
return cls.from_dict(asdict(template_inputs), tools_prompt=tools_prompt)

@staticmethod
def messages_join_observation(messages: Messages) -> None:
"""
Expand Down

0 comments on commit 95b02da

Please sign in to comment.