Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Oct 29, 2024
1 parent 5781b18 commit cd0f001
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 240 deletions.
9 changes: 9 additions & 0 deletions swift/llm/argument/deploy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ class DeployArguments(InferArguments):
served_model_name: Optional[str] = None
verbose: bool = True # Whether to log request_info
log_interval: int = 10 # Interval for printing global statistics

def _init_stream(self):
    """Deliberately do nothing for deployment arguments.

    The ``__post_init__`` shown for ``InferArguments`` calls
    ``self._init_stream()``; overriding it with an empty body turns that
    step into a no-op for ``DeployArguments`` instances.
    NOTE(review): assumes ``DeployArguments`` reuses the inherited
    ``__post_init__`` -- confirm against the full file.
    """
    pass

def _init_eval_human(self):
    """Deliberately do nothing for deployment arguments.

    ``InferArguments`` defines ``_init_eval_human`` and calls it from its
    ``__post_init__``; this empty override disables that step for
    ``DeployArguments`` instances.
    NOTE(review): assumes ``DeployArguments`` reuses the inherited
    ``__post_init__`` -- confirm against the full file.
    """
    pass

def _init_result_dir(self, folder_name: str = 'deploy_result') -> None:
    """Initialise the result directory through the parent implementation.

    The only difference from the parent is the default ``folder_name``:
    ``'deploy_result'`` here versus ``'infer_result'`` in
    ``InferArguments._init_result_dir``.

    Args:
        folder_name: Forwarded unchanged to the parent method, which passes
            it to ``os.path.join`` when it derives the result directory
            itself. NOTE(review): the exact condition for that branch is
            not visible here -- confirm in ``infer_args.py``.
    """
    super()._init_result_dir(folder_name=folder_name)
11 changes: 6 additions & 5 deletions swift/llm/argument/infer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def __post_init__(self):
class InferArguments(BaseArguments, MergeArguments, VllmArguments, LmdeployArguments):
infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'
ckpt_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/vx-xxx/checkpoint-xxx'})
max_batch_size: int = 16 # for pt engine

# only for inference
val_dataset_sample: Optional[int] = None
result_dir: Optional[str] = field(default=None, metadata={'help': '/path/to/your/infer_result'})
save_result: bool = True

max_batch_size: int = 16 # for pt engine
stream: Optional[bool] = None

def _init_result_dir(self) -> None:
def _init_result_dir(self, folder_name: str = 'infer_result') -> None:
self.result_path = None
if not self.save_result:
return
Expand All @@ -66,7 +66,7 @@ def _init_result_dir(self) -> None:
result_dir = self.model_info.model_dir
else:
result_dir = self.ckpt_dir
result_dir = os.path.join(result_dir, 'infer_result')
result_dir = os.path.join(result_dir, folder_name)
else:
result_dir = self.result_dir
result_dir = to_abspath(result_dir)
Expand All @@ -90,11 +90,11 @@ def __post_init__(self) -> None:
BaseArguments.__post_init__(self)
MergeArguments.__post_init__(self)
VllmArguments.__post_init__(self)
self._parse_lora_modules()

self._init_result_dir()
self._init_stream()
self._init_eval_human()
self._parse_lora_modules()

def _init_eval_human(self):
if len(self.dataset) == 0 and len(self.val_dataset) == 0:
Expand All @@ -106,6 +106,7 @@ def _init_eval_human(self):

def _parse_lora_modules(self) -> None:
if len(self.lora_modules) == 0:
self.lora_request_list = []
return
assert self.infer_backend in {'vllm', 'pt'}
if self.infer_backend == 'vllm':
Expand Down
Loading

0 comments on commit cd0f001

Please sign in to comment.