Commit 7407569
Merge remote-tracking branch 'refs/remotes/origin/feat/refactor3' into feat/refactor3
Jintao-Huang committed Oct 29, 2024
2 parents: 75386c8 + 5b5315b
Showing 12 changed files with 126 additions and 393 deletions.
12 changes: 6 additions & 6 deletions swift/llm/__init__.py
@@ -5,15 +5,15 @@

if TYPE_CHECKING:
# Recommend using `xxx_main`
- from .infer import (VllmEngine, InferRequest, RequestConfig, InferStats, LmdeployEngine, PtEngine, infer_main,
- deploy_main, PtLoRARequest, InferClient)
+ from .infer import (VllmEngine, RequestConfig, InferStats, LmdeployEngine, PtEngine, infer_main, deploy_main,
+ PtLoRARequest, InferClient)
from .export import export_main, merge_lora
from .eval import eval_main
from .train import sft_main, pt_main, rlhf_main
from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments, RLHFArguments,
WebUIArguments, AppUIArguments)
from .template import (TEMPLATE_MAPPING, Template, Word, get_template, TemplateType, register_template,
- TemplateInputs, Messages, TemplateMeta, get_template_meta)
+ TemplateInputs, Messages, TemplateMeta, get_template_meta, InferRequest)
from .model import (MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory,
ModelInfo, ModelMeta, get_model_meta)
from .dataset import (AlpacaPreprocessor, MessagesPreprocessor, AutoPreprocessor, DatasetName, DATASET_MAPPING,
@@ -29,8 +29,8 @@
_import_structure = {
'rlhf': ['rlhf_main'],
'infer': [
- 'deploy_main', 'VllmEngine', 'InferRequest', 'RequestConfig', 'InferStats', 'LmdeployEngine', 'PtEngine',
- 'infer_main', 'PtLoRARequest', 'InferClient'
+ 'deploy_main', 'VllmEngine', 'RequestConfig', 'InferStats', 'LmdeployEngine', 'PtEngine', 'infer_main',
+ 'PtLoRARequest', 'InferClient'
],
'export': ['export_main', 'merge_lora'],
'eval': ['eval_main'],
Expand All @@ -41,7 +41,7 @@
],
'template': [
'TEMPLATE_MAPPING', 'Template', 'Word', 'get_template', 'TemplateType', 'register_template',
- 'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta'
+ 'TemplateInputs', 'Messages', 'TemplateMeta', 'get_template_meta', 'InferRequest'
],
'model': [
'MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory',
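For context on why every symbol here changes in two places: swift/llm/__init__.py follows the TYPE_CHECKING-plus-_import_structure lazy-import layout, so moving InferRequest from the 'infer' entry to the 'template' entry has to be mirrored in both the typing block and the string map. Below is a minimal sketch of that pattern, assuming a PEP 562 module-level __getattr__ for brevity; the real module appears to wire this up through sys.modules instead (note the `import sys` at the bottom of the infer/__init__.py hunk below).

# Hedged, simplified sketch of the lazy-import pattern used by swift/llm/__init__.py;
# illustrative only, not the project's actual implementation.
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static type checkers resolve the real symbol eagerly.
    from .template import InferRequest  # noqa: F401

# Runtime map: submodule -> names it contributes to the package namespace.
_import_structure = {'template': ['InferRequest', 'Template'], 'infer': ['RequestConfig']}


def __getattr__(name):
    # Import the owning submodule only on first attribute access (PEP 562).
    for submodule, symbols in _import_structure.items():
        if name in symbols:
            module = importlib.import_module(f'.{submodule}', __name__)
            return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
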
8 changes: 4 additions & 4 deletions swift/llm/infer/__init__.py
@@ -6,17 +6,17 @@
if TYPE_CHECKING:
from .infer import infer_main
from .deploy import deploy_main
- from .protocol import InferRequest, RequestConfig
+ from .protocol import RequestConfig
from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferStats, PtLoRARequest,
InferClient)
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
'deploy': ['deploy_main'],
'infer': ['infer_main'],
- 'protocol': ['InferRequest', 'RequestConfig'],
- 'infer_engine': ['InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferStats', 'PtLoRARequest',
- 'InferClient'],
+ 'protocol': ['RequestConfig'],
+ 'infer_engine':
+ ['InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferStats', 'PtLoRARequest', 'InferClient'],
}

import sys
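The net effect of these two __init__ changes: InferRequest is no longer exported from swift.llm.infer, it now lives alongside the template code and is re-exported from the top-level swift.llm package, while RequestConfig stays in infer.protocol. A hedged sketch of what caller code looks like after the move; the constructor arguments are assumptions suggested by the Messages type exported next to InferRequest, not something this diff shows:

# Caller-side usage after this commit (field names are assumptions for illustration).
from swift.llm import InferRequest, RequestConfig  # both re-exported at the top level

# Old imports of InferRequest from swift.llm.infer would need to change,
# since that export is removed above.
request = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])
config = RequestConfig(max_tokens=64, temperature=0.0)
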
293 changes: 0 additions & 293 deletions swift/llm/infer/client_utils.py

This file was deleted.

10 changes: 5 additions & 5 deletions swift/llm/infer/infer.py
@@ -6,11 +6,11 @@

import numpy as np

- from swift.llm import (HfDataset, InferArguments, Messages, SwiftPipeline, Template, get_template, load_dataset,
- merge_lora, sample_dataset)
+ from swift.llm import (HfDataset, InferArguments, InferRequest, Messages, SwiftPipeline, Template, get_template,
+ load_dataset, merge_lora, sample_dataset)
from swift.utils import append_to_jsonl, get_logger
from .infer_engine import InferEngine
- from .protocol import InferRequest, RequestConfig
+ from .protocol import RequestConfig

logger = get_logger()

@@ -131,8 +131,8 @@ def run(self) -> List[Dict[str, Any]]:
result = self.infer_cli()
else:
result = self.infer_dataset()
- if args.result_path is not None:
- logger.info(f'The inference results have been saved to result_path: `{result_path}`.')
+ if self.result_path is not None:
+ logger.info(f'The inference results have been saved to result_path: `{self.result_path}`.')
return result

@staticmethod
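Besides tracking the InferRequest move, the second hunk in infer.py is a small bug fix: the old branch guarded on args.result_path but logged a bare result_path (a name not defined anywhere in the visible hunk), whereas the fixed code reads self.result_path in both places, keeping the guard and the message consistent. A minimal runnable sketch of the corrected control flow, with the class and helper names invented here rather than taken from the repo:

# Hedged sketch of the fixed run() logic; names are placeholders, not ms-swift's API.
class InferPipelineSketch:

    def __init__(self, result_path=None):
        self.result_path = result_path

    def infer_dataset(self):
        return []  # stand-in for the real batched inference

    def run(self):
        result = self.infer_dataset()
        if self.result_path is not None:
            # Guard and message now read the same attribute.
            print(f'The inference results have been saved to result_path: `{self.result_path}`.')
        return result


InferPipelineSketch(result_path='result.jsonl').run()
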
3 changes: 2 additions & 1 deletion swift/llm/infer/infer_engine/base.py
@@ -2,8 +2,9 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Iterator, List, Optional, Union

+ from swift.llm import InferRequest
from swift.plugin import Metric
- from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, InferRequest, RequestConfig
+ from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig


class BaseInferEngine(ABC):
(Diffs for the remaining 7 changed files are not rendered in this view.)
