diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index d3c5a8b6a..bd41ec149 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -18,7 +18,7 @@
 from .dataset import (AlpacaPreprocessor, MessagesPreprocessor, AutoPreprocessor, DatasetName, DATASET_MAPPING,
                       MediaResource, register_dataset, register_dataset_info, dataset_map, stat_dataset, LLMDataset,
                       LLMIterableDataset, LazyLLMDataset, ConstantLengthDataset, print_example, sort_by_max_length,
-                      standard_keys, load_dataset, DATASET_TYPE, HfDataset)
+                      standard_keys, load_dataset, DATASET_TYPE, HfDataset, sample_dataset)
 from .utils import (deep_getattr, to_device, Messages, History, decode_base64, history_to_messages,
                     messages_to_history, safe_tokenizer_decode)
 from .module_mapping import MODEL_KEYS_MAPPING, MultiModelKeys
@@ -45,7 +45,8 @@
         'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DatasetName',
         'DATASET_MAPPING', 'MediaResource', 'register_dataset', 'register_dataset_info', 'dataset_map', 'stat_dataset',
         'LLMDataset', 'LLMIterableDataset', 'LazyLLMDataset', 'ConstantLengthDataset',
-        'print_example', 'sort_by_max_length', 'standard_keys', 'load_dataset', 'DATASET_TYPE', 'HfDataset'
+        'print_example', 'sort_by_max_length', 'standard_keys', 'load_dataset', 'DATASET_TYPE', 'HfDataset',
+        'sample_dataset'
     ],
     'utils': [
         'deep_getattr', 'to_device', 'History', 'Messages', 'decode_base64', 'history_to_messages',
diff --git a/swift/llm/base.py b/swift/llm/base.py
index eb5072442..cfceb1b8a 100644
--- a/swift/llm/base.py
+++ b/swift/llm/base.py
@@ -1,14 +1,15 @@
 import os
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Callable, List, Optional, Type, TypeVar, Union, Generic
+from typing import Callable, List, Optional, Type, TypeVar, Union
+
 from swift.utils import get_logger, parse_args, seed_everything
 
 logger = get_logger()
-
 T_Args = TypeVar('T_Args')
+
 
 class Pipeline(ABC):
 
     args_class = None
@@ -45,4 +46,3 @@ def main(self):
     @abstractmethod
     def run(self):
         pass
-
diff --git a/swift/llm/dataset/__init__.py b/swift/llm/dataset/__init__.py
index 931891271..582482b79 100644
--- a/swift/llm/dataset/__init__.py
+++ b/swift/llm/dataset/__init__.py
@@ -11,7 +11,7 @@
                            RowPreprocessor)
 from .register import DATASET_MAPPING, register_dataset, register_dataset_info
 from .utils import (ConstantLengthDataset, HfDataset, LazyLLMDataset, LLMDataset, LLMIterableDataset, dataset_map,
-                    print_example, sort_by_max_length, stat_dataset)
+                    print_example, sample_dataset, sort_by_max_length, stat_dataset)
 
 
 def _update_fingerprint_mac(*args, **kwargs):
diff --git a/swift/llm/dataset/dataset/dataset.py b/swift/llm/dataset/dataset/dataset.py
index 024164f7c..9f109cd3d 100644
--- a/swift/llm/dataset/dataset/dataset.py
+++ b/swift/llm/dataset/dataset/dataset.py
@@ -11,7 +11,6 @@
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
 from datasets import concatenate_datasets, interleave_datasets
-from numpy.random import RandomState
 from tqdm.auto import tqdm
 from transformers.utils import strtobool
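Note on the two `__init__.py` hunks above: together with the `swift/llm/dataset/__init__.py` change, they put the relocated helper on the public surface. A quick sanity check of the intended import paths (a sketch; it assumes the package's lazy-import mapping resolves the `dataset` group exactly as the hunks suggest):

```python
from swift.llm import sample_dataset
from swift.llm.dataset import sample_dataset as _sample_dataset

# Both names should resolve to the same module-level function
# defined in swift/llm/dataset/utils.py.
assert sample_dataset is _sample_dataset
```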
diff --git a/swift/llm/dataset/loader.py b/swift/llm/dataset/loader.py
index a5d3b8b7b..db2db787d 100644
--- a/swift/llm/dataset/loader.py
+++ b/swift/llm/dataset/loader.py
@@ -11,11 +11,11 @@
 from datasets import concatenate_datasets, interleave_datasets
 from modelscope.hub.api import ModelScopeConfig
 from modelscope.utils.config_ds import MS_CACHE_HOME
-from numpy.random import RandomState
 
 from swift.hub import HFHub, MSHub
 from swift.utils import download_ms_file, get_logger, get_seed, safe_ddp_context, use_hf_hub
 from .register import DATASET_MAPPING, DATASET_TYPE, DatasetMeta, SubsetDataset, register_dataset_info
+from .utils import sample_dataset
 
 logger = get_logger()
@@ -246,39 +246,12 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
         subsets = [subset_mapping[subset_name].set_default(dataset_meta) for subset_name in subsets]
         return subsets
 
-    @staticmethod
-    def sample_dataset(dataset: HfDataset,
-                       dataset_sample: int,
-                       random_state: Optional[RandomState] = None) -> HfDataset:
-        """Sample dataset by a dataset_sample number
-        Args:
-            dataset: The dataset instance, iterable dataset is not supported
-            dataset_sample: The sample number
-            random_state: The random state
-        Returns:
-            The sampled dataset
-        """
-        if random_state is None:
-            random_state = RandomState()
-
-        n_repeat_sample = dataset_sample // len(dataset)
-        n_random_sample = dataset_sample % len(dataset)
-        if n_repeat_sample >= 1 and n_random_sample >= 1:
-            logger.warning(f'dataset_sample:{dataset_sample} is greater than len(dataset):{len(dataset)}, '
-                           'repeated sampling will be performed.')
-        idx = np.tile(range(len(dataset)), n_repeat_sample)
-        if n_random_sample >= 1:
-            idx_random = random_state.permutation(len(dataset))[:n_random_sample]
-            idx = np.concatenate([idx, idx_random])
-        dataset = dataset.select(idx)
-        return dataset
-
     @staticmethod
     def post_preprocess(
         train_dataset: DATASET_TYPE,
         dataset_sample: Optional[int] = None,
         split_dataset_ratio: float = 0.,
-        random_state: Optional[RandomState] = None,
+        random_state: Optional[np.random.RandomState] = None,
         streaming: bool = False,
         *,
         load_from_cache_file: bool = False,
@@ -305,7 +278,7 @@ def post_preprocess(
             val_sample = dataset_sample
             assert val_sample <= len(
                 val_dataset), f'dataset_sample: {dataset_sample}, len(val_dataset): {len(val_dataset)}'
-            val_dataset = DatasetLoader.sample_dataset(val_dataset, val_sample, random_state)
+            val_dataset = sample_dataset(val_dataset, val_sample, random_state)
         else:
             if split_dataset_ratio == 0:
                 train_sample = dataset_sample
@@ -319,7 +292,7 @@ def post_preprocess(
                 test_size=val_sample, seed=get_seed(random_state),
                 load_from_cache_file=load_from_cache_file).values()
             assert train_sample > 0
-        train_dataset = DatasetLoader.sample_dataset(train_dataset, train_sample, random_state)
+        train_dataset = sample_dataset(train_dataset, train_sample, random_state)
         return train_dataset, val_dataset
 
     @staticmethod
@@ -430,7 +403,7 @@ def _parse_datasets(datasets: List[str]) -> List[str]:
 def load_dataset(
         datasets: List[str],
         split_dataset_ratio: float = 0.,
-        dataset_seed: Union[int, RandomState] = 42,
+        dataset_seed: Union[int, np.random.RandomState] = 42,
         *,
         num_proc: int = 1,
         strict: bool = True,
@@ -458,7 +431,7 @@ def load_dataset(
     if isinstance(datasets, str):
         datasets = [datasets]
     if isinstance(dataset_seed, int):
-        dataset_seed = RandomState(dataset_seed)
+        dataset_seed = np.random.RandomState(dataset_seed)
     datasets: List[str] = DatasetLoader._parse_datasets(datasets)  # to dataset_names and register
     train_datasets = []
     val_datasets = []
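Since `DatasetLoader.sample_dataset` is deleted outright rather than deprecated, any external caller has to migrate to the free function in the same step. A hypothetical call-site migration (variable names are placeholders):

```python
import numpy as np
from swift.llm import sample_dataset  # was: DatasetLoader.sample_dataset

random_state = np.random.RandomState(42)
# Before: val_dataset = DatasetLoader.sample_dataset(val_dataset, 1000, random_state)
val_dataset = sample_dataset(val_dataset, 1000, random_state)
```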
diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py
index 9abb7ea41..9ec0c201d 100644
--- a/swift/llm/dataset/utils.py
+++ b/swift/llm/dataset/utils.py
@@ -25,6 +25,33 @@
 os.environ['TOKENIZERS_PARALLELISM'] = 'true'
 
 
+def sample_dataset(dataset: HfDataset,
+                   dataset_sample: int,
+                   random_state: Optional[np.random.RandomState] = None) -> HfDataset:
+    """Sample a dataset down (or up, with repetition) to `dataset_sample` rows.
+    Args:
+        dataset: The dataset instance; iterable datasets are not supported.
+        dataset_sample: The number of samples to draw.
+        random_state: The random state used for sampling.
+    Returns:
+        The sampled dataset.
+    """
+    if random_state is None:
+        random_state = np.random.RandomState()
+
+    n_repeat_sample = dataset_sample // len(dataset)
+    n_random_sample = dataset_sample % len(dataset)
+    if n_repeat_sample >= 1 and n_random_sample >= 1:
+        logger.warning(f'dataset_sample: {dataset_sample} is greater than len(dataset): {len(dataset)}; '
+                       'repeated sampling will be performed.')
+    idx = np.tile(range(len(dataset)), n_repeat_sample)
+    if n_random_sample >= 1:
+        idx_random = random_state.permutation(len(dataset))[:n_random_sample]
+        idx = np.concatenate([idx, idx_random])
+    dataset = dataset.select(idx)
+    return dataset
+
+
 class LLMDataset(Dataset):
     """This class wraps the Dataset class, to offer the ability of custom dataset tokenizing"""
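The sampling scheme splits `dataset_sample` into whole-dataset repeats plus a random remainder: with `len(dataset) == 100` and `dataset_sample == 250`, it takes the full index range twice and then 50 indices from a random permutation. A standalone sketch of just the index arithmetic (hypothetical helper, numpy only):

```python
import numpy as np

def sample_indices(n_rows: int, n_samples: int, rs: np.random.RandomState) -> np.ndarray:
    n_repeat, n_rand = divmod(n_samples, n_rows)
    idx = np.tile(np.arange(n_rows), n_repeat)  # whole-epoch repeats
    if n_rand:
        # remainder drawn without replacement from a shuffled index range
        idx = np.concatenate([idx, rs.permutation(n_rows)[:n_rand]])
    return idx

idx = sample_indices(100, 250, np.random.RandomState(42))
assert len(idx) == 250  # 2 full passes + 50 randomly chosen rows
```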
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 5db594c3e..5fbbc8c14 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -4,21 +4,18 @@
 import re
 from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 
-from swift.llm import (
-    HfDataset, InferArguments, Messages, Pipeline, Template, get_template, load_dataset, merge_lora
-)
-from swift.tuners import Swift
-from swift.utils import append_to_jsonl, get_logger, get_main, read_multi_line, seed_everything
+from swift.llm import (HfDataset, InferArguments, Messages, Pipeline, Template, get_template, load_dataset, merge_lora,
+                       sample_dataset)
+from swift.utils import append_to_jsonl, get_logger
 from .infer_engine import InferEngine, InferRequest, RequestConfig
 
 logger = get_logger()
 
-
 @dataclass
 class InferCliState:
     # None: use default-system. '': not use system.
@@ -43,12 +40,6 @@ def copy(self):
         return InferCliState(self.system, deepcopy(self.messages), self.images.copy(), self.audios.copy(),
                              self.videos.copy(), self.multiline_mode, self.input_system)
 
-    def to_infer_request(self) -> InferRequest:
-        infer_state = self.copy()
-        if infer_state.system is not None:
-            infer_state.messages.insert(0, {'role': 'system', 'content': infer_state.system})
-        return InferRequest(infer_state.messages, infer_state.images, infer_state.audios, infer_state.videos)
-
     def add_query(self, query: str) -> None:
         self.messages.append({'role': 'user', 'content': query})
 
@@ -57,6 +48,8 @@ def add_response(self, response: str) -> None:
 
     def to_dict(self):
         infer_state = self.copy()
+        if infer_state.system is not None:
+            infer_state.messages.insert(0, {'role': 'system', 'content': infer_state.system})
         return {
             'messages': infer_state.messages,
             'images': infer_state.images,
@@ -65,15 +58,16 @@
-class InferPipeline(Pipeline):
+class InferPipeline(Pipeline, InferEngine):
     args_class = InferArguments
 
     def __init__(self, args: Union[List[str], InferArguments, None] = None) -> None:
-        self.args: InferArguments = self.parse_args(args)
+        self.args = self.parse_args(args)
         if args.merge_lora:
             merge_lora(args, device_map=args.merge_device_map)
-        self.infer_engine = self.get_infer_engine()
-        self.template = self.get_template(self.infer_engine.tokenizer)
+        self.template = self._get_template(self.tokenizer)
+        self.random_state = np.random.RandomState(args.dataset_seed)
+        super().__init__()
 
     def get_infer_engine(self) -> InferEngine:
         args = self.args
@@ -82,7 +76,6 @@ def get_infer_engine(self) -> InferEngine:
             'model_type': args.model_type,
             'revision': args.model_revision,
             'torch_dtype': args.torch_dtype,
-            'use_hf': args.use_hf,
         }
         if args.infer_backend == 'pt':
             from .infer_engine import PtEngine
@@ -119,10 +112,25 @@
         return infer_engine_cls(**kwargs)
 
-
-    def run(self) -> None:
+    def _get_template(self, tokenizer) -> Template:
         args = self.args
+        template = get_template(
+            args.template_type,
+            tokenizer,
+            args.system,
+            args.max_length,
+            truncation_strategy=args.truncation_strategy,
+            max_pixels=args.max_pixels,
+            tools_prompt=args.tools_prompt)
+        logger.info(f'default_system: {template.default_system}')
+        return template
 
+    def run(self) -> List[Dict[str, Any]]:
+        args = self.args
+        if args.dataset and args.split_dataset_ratio > 0 or args.val_dataset:
+            return self.infer_dataset()
+        else:
+            return self.infer_cli()
 
     @staticmethod
     def _input_mm_data(infer_state: InferCliState) -> None:
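One subtlety in the new `run()`: because `and` binds tighter than `or`, the dispatch condition groups as `(args.dataset and args.split_dataset_ratio > 0) or args.val_dataset`, so a non-empty `--val_dataset` routes to `infer_dataset()` even when `--dataset` is unset. A one-line reminder of the precedence:

```python
# `a and b or c` parses as `(a and b) or c`, not `a and (b or c)`.
assert (False and True or True) is True
```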
@@ -141,7 +149,8 @@ def _input_mm_file(mm_type: Literal['image', 'video', 'audio']) -> str:
             mm_val = getattr(infer_state, mm_key)
             mm_val.append(_input_mm_file(mm_type))
 
-    def _prepare_save_result(self, args: InferArguments) -> str:
+    def _prepare_save_result(self) -> str:
+        args = self.args
         if args.result_dir is not None:
             result_dir = args.result_dir
         else:
@@ -172,7 +181,7 @@ def _input_multiline(prompt: str) -> str:
     def _input_text(multiline_mode: bool, input_system: bool) -> str:
         if multiline_mode:
             addi_prompt = 'MS' if input_system else 'M'
-            text = InferEngine._input_multiline(f'<<<[{addi_prompt}] ')
+            text = InferPipeline._input_multiline(f'<<<[{addi_prompt}] ')
         else:
             addi_prompt = 'S' if input_system else ''
             text = input(f'<<<[{addi_prompt}] ')
@@ -206,8 +215,8 @@ def _check_query(infer_state: InferCliState, query: str) -> Optional[str]:
             return
         return query
 
-    @staticmethod
-    def _prepare_request_config(args: InferArguments) -> RequestConfig:
+    def _prepare_request_config(self) -> RequestConfig:
+        args = self.args
         temperature = args.temperature
         if not args.do_sample:
             temperature = 0
@@ -221,29 +230,30 @@ def _prepare_request_config(args: InferArguments) -> RequestConfig:
             stream=args.stream,
             repetition_penalty=args.repetition_penalty)
 
-    def get_template(self, tokenizer) -> Template:
-        args = self.args
-        template = get_template(
-            args.template_type,
-            tokenizer,
-            args.system,
-            args.max_length,
-            truncation_strategy=args.truncation_strategy,
-            loss_scale=args.loss_scale_config,
-            max_pixels=args.max_pixels,
-            sequence_parallel_size=args.sequence_parallel_size,
-            tools_prompt=args.tools_prompt)
-        logger.info(f'default_system: {template.default_system}')
-        return template
+    def infer_single(self, infer_request: InferRequest, request_config: RequestConfig) -> Tuple[str, Messages]:
+        messages = infer_request.messages
+        res_or_gen = self.infer(self.template, [infer_request], request_config, use_tqdm=False)
+        if request_config.stream:
+            response = ''
+            for res in res_or_gen:
+                delta = res[0].choices[0].delta
+                print(delta, end='', flush=True)
+                response += delta
+            print()
+        else:
+            response = res_or_gen[0].choices[0].message.content
+            print(response)
+        messages.append({'role': 'assistant', 'content': response})
+        return response, messages
 
-    def infer_cli(self, args: InferArguments) -> List[Dict[str, Any]]:
-        template = self.prepare_template(args)
+    def infer_cli(self) -> List[Dict[str, Any]]:
+        args = self.args
+        template = self.template
         result_path = None
         if args.save_result:
-            result_path = self._prepare_save_result(args)
-        request_config = self._prepare_request_config(args)
+            result_path = self._prepare_save_result()
+        request_config = self._prepare_request_config()
 
-        result = []
         logger.info('Input `exit` or `quit` to exit the conversation.')
         logger.info('Input `multi-line` to switch to multi-line input mode.')
         logger.info('Input `reset-system` to reset the system and clear the history.')
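`infer_single` now owns the handling of both shapes `self.infer` can return: a generator of chunk lists when `request_config.stream` is true, and a list of finished completions otherwise. A condensed sketch of the two consumption patterns (stand-in objects; the real response classes come from the protocol definitions in `.infer_engine`):

```python
def collect_response(res_or_gen, stream: bool) -> str:
    if stream:
        # Generator: each item is a batch-aligned list; entry 0 belongs to our
        # single request, and `.delta` carries the incremental text.
        return ''.join(res[0].choices[0].delta for res in res_or_gen)
    # List: one finished completion per request, full text in `.message.content`.
    return res_or_gen[0].choices[0].message.content
```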
@@ -253,6 +263,7 @@ def infer_cli(self, args: InferArguments) -> List[Dict[str, Any]]:
             logger.info('The current template only supports single-round dialogues.')
 
         infer_state = InferCliState()
+        result_list = []
         while True:
             if not template.support_multi_round:
                 infer_state.clear()
@@ -264,39 +275,66 @@
                 continue
             infer_state.add_query(query)
             self._input_mm_data(infer_state)
-            infer_request = infer_state.to_infer_request()
-            res_or_gen = self.infer(template, [infer_request], request_config, use_tqdm=False)
-            if request_config.stream:
-                response = ''
-                for res in res_or_gen:
-                    delta = res[0].choices[0].delta
-                    print(delta, end='', flush=True)
-                    response += delta
-                print()
-            else:
-                response = res_or_gen[0].choices[0].message.content
+            data = infer_state.to_dict()
+            response, messages = self.infer_single(InferRequest(**data), request_config)
             infer_state.add_response(response)
-            data = infer_state.to_dict()
-            result.append(data)
+            data['messages'] = messages
+            result_list.append(data)
             if result_path is not None:
                 append_to_jsonl(result_path, data, strict=False)
-        return result
+        return result_list
 
-    def prepare_dataset(self, args: InferArguments) -> HfDataset:
+    def prepare_val_dataset(self) -> HfDataset:
+        args = self.args
         load_dataset(args.val_dataset, args.split_dataset_ratio)
+        dataset_kwargs = {
+            'dataset_seed': args.dataset_seed,
+            'num_proc': args.num_proc,
+            'load_from_cache_file': args.load_from_cache_file,
+            'download_mode': args.download_mode,
+            'model_name': args.model_name,
+            'model_author': args.model_author,
+            'strict': False
+        }
         if len(args.val_dataset) > 0:
             _, val_dataset = load_dataset(args.val_dataset, 1.0, **dataset_kwargs)
         else:
-            _, val_dataset = load_dataset(args.dataset, args.dataset_test_ratio, **dataset_kwargs)
+            _, val_dataset = load_dataset(args.dataset, args.split_dataset_ratio, **dataset_kwargs)
+        assert val_dataset is not None
+        if args.val_dataset_sample is not None:
+            val_dataset = sample_dataset(val_dataset, args.val_dataset_sample, self.random_state)
+        return val_dataset
 
-    def infer_dataset(self, args: InferArguments):
-        template = self.prepare_template(args)
+    def infer_dataset(self) -> List[Dict[str, Any]]:
+        args = self.args
         result_path = None
         if args.save_result:
-            result_path = self._prepare_save_result(args)
-        request_config = self._prepare_request_config(args)
-
-
-
+            result_path = self._prepare_save_result()
+        logger.info(f'result_path: {result_path}')
+        request_config = self._prepare_request_config()
+        logger.info(f'request_config: {request_config}')
+
+        val_dataset = self.prepare_val_dataset()
+        logger.info(f'val_dataset: {val_dataset}')
+        result_list = []
+        if request_config.stream:
+            for data in val_dataset:
+                response, messages = self.infer_single(InferRequest(**data), request_config)
+                data['messages'] = messages
+                result_list.append(data)
+                if result_path is not None:
+                    append_to_jsonl(result_path, data)
+        else:
+            infer_requests = []
+            for data in val_dataset:
+                infer_requests.append(InferRequest(**data))
+            resp_list = self.infer(self.template, infer_requests, request_config, use_tqdm=True)
+            for data, resp in zip(val_dataset, resp_list):
+                response = resp.choices[0].message.content
+                data['messages'].append({'role': 'assistant', 'content': response})
+                result_list.append(data)
+            if result_path is not None:
+                append_to_jsonl(result_path, result_list)
+        return result_list
diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py
index 7e159d4bc..8ab0e8dbb 100644
--- a/swift/llm/infer/infer_engine/infer_engine.py
+++ b/swift/llm/infer/infer_engine/infer_engine.py
@@ -22,7 +22,6 @@
 
 class InferEngine(BaseInferEngine):
-
     def _prepare_model_tokenizer(
         self,
         model_id_or_path: str,
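The non-streaming branch of `infer_dataset` re-attaches responses to rows with `zip(val_dataset, resp_list)`, which is only correct because `self.infer` returns exactly one response per request, in request order. If that invariant is worth guarding, a cheap assertion would do (a suggestion, not part of the diff):

```python
resp_list = self.infer(self.template, infer_requests, request_config, use_tqdm=True)
assert len(resp_list) == len(infer_requests)  # 1:1 and order-preserving
```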
diff --git a/swift/llm/template/register.py b/swift/llm/template/register.py
index fbc497e9f..dfe2fa049 100644
--- a/swift/llm/template/register.py
+++ b/swift/llm/template/register.py
@@ -25,10 +25,11 @@ def get_template(
         max_length: Optional[int] = None,
         *,
         truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
-        loss_scale: str = 'default',
         max_pixels: int = -1,  # h * w
-        sequence_parallel_size: int = 1,
-        tools_prompt: str = 'react_en') -> 'Template':
+        tools_prompt: str = 'react_en',
+        # train
+        loss_scale: str = 'default',
+        sequence_parallel_size: int = 1) -> 'Template':
     template_info = TEMPLATE_MAPPING[template_type]
     # To ensure that obtaining the same template_type multiple times does not interfere with each other.
     template = deepcopy(template_info['template'])
diff --git a/swift/utils/np_utils.py b/swift/utils/np_utils.py
index aeee8b2f3..90148b1fd 100644
--- a/swift/utils/np_utils.py
+++ b/swift/utils/np_utils.py
@@ -2,12 +2,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-from numpy import ndarray
-from numpy.random import RandomState
-from pandas import DataFrame
+import pandas as pd
 
 
-def transform_jsonl_to_df(dict_list: List[Dict[str, Any]]) -> DataFrame:
+def transform_jsonl_to_df(dict_list: List[Dict[str, Any]]) -> pd.DataFrame:
     """Relevant function: `io_utils.read_from_jsonl()`"""
     data_dict: Dict[str, List[Any]] = {}
     for i, obj in enumerate(dict_list):
@@ -17,18 +15,18 @@
             data_dict[k].append(v)
         for k in set(data_dict.keys()) - set(obj.keys()):
             data_dict[k].append(None)
-    return DataFrame.from_dict(data_dict)
+    return pd.DataFrame.from_dict(data_dict)
 
 
-def get_seed(random_state: Optional[RandomState] = None) -> int:
+def get_seed(random_state: Optional[np.random.RandomState] = None) -> int:
     if random_state is None:
-        random_state = RandomState()
+        random_state = np.random.RandomState()
     seed_max = np.iinfo(np.int32).max
     seed = random_state.randint(0, seed_max)
     return seed
 
 
-def stat_array(array: Union[ndarray, List[int], 'torch.Tensor']) -> Tuple[Dict[str, float], str]:
+def stat_array(array: Union[np.ndarray, List[int], 'torch.Tensor']) -> Tuple[Dict[str, float], str]:
     if isinstance(array, list):
         array = np.array(array)
     mean = array.mean().item()
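The `np_utils.py` changes are import hygiene only (fully qualified `np.`/`pd.` names), so behavior is unchanged. For reference, `get_seed` is what bridges the `RandomState` objects threaded through the dataset code and integer-seeded APIs such as `train_test_split` in `loader.py` (`seed=get_seed(random_state)`), and it stays deterministic per seed:

```python
import numpy as np
from swift.utils import get_seed

# Equal RandomStates yield equal derived seeds, so dataset splits are reproducible.
assert get_seed(np.random.RandomState(42)) == get_seed(np.random.RandomState(42))
```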