Commit 0a17ee7

update infer pipeline
Jintao-Huang committed Oct 24, 2024
1 parent ee4bae1 commit 0a17ee7
Showing 10 changed files with 154 additions and 118 deletions.
5 changes: 3 additions & 2 deletions swift/llm/__init__.py
@@ -18,7 +18,7 @@
     from .dataset import (AlpacaPreprocessor, MessagesPreprocessor, AutoPreprocessor, DatasetName, DATASET_MAPPING,
                           MediaResource, register_dataset, register_dataset_info, dataset_map, stat_dataset, LLMDataset,
                           LLMIterableDataset, LazyLLMDataset, ConstantLengthDataset, print_example, sort_by_max_length,
-                          standard_keys, load_dataset, DATASET_TYPE, HfDataset)
+                          standard_keys, load_dataset, DATASET_TYPE, HfDataset, sample_dataset)
     from .utils import (deep_getattr, to_device, Messages, History, decode_base64, history_to_messages,
                         messages_to_history, safe_tokenizer_decode)
     from .module_mapping import MODEL_KEYS_MAPPING, MultiModelKeys
@@ -45,7 +45,8 @@
         'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DatasetName',
         'DATASET_MAPPING', 'MediaResource', 'register_dataset', 'register_dataset_info', 'dataset_map',
         'stat_dataset', 'LLMDataset', 'LLMIterableDataset', 'LazyLLMDataset', 'ConstantLengthDataset',
-        'print_example', 'sort_by_max_length', 'standard_keys', 'load_dataset', 'DATASET_TYPE', 'HfDataset'
+        'print_example', 'sort_by_max_length', 'standard_keys', 'load_dataset', 'DATASET_TYPE', 'HfDataset',
+        'sample_dataset'
     ],
     'utils': [
         'deep_getattr', 'to_device', 'History', 'Messages', 'decode_base64', 'history_to_messages',
6 changes: 3 additions & 3 deletions swift/llm/base.py
@@ -1,14 +1,15 @@
 import os
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Callable, List, Optional, Type, TypeVar, Union, Generic
+from typing import Callable, List, Optional, Type, TypeVar, Union

 from swift.utils import get_logger, parse_args, seed_everything

 logger = get_logger()


 T_Args = TypeVar('T_Args')


 class Pipeline(ABC):
     args_class = None

@@ -45,4 +46,3 @@ def main(self):
     @abstractmethod
     def run(self):
         pass
-
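The hunks above show the shape of the `Pipeline` base class: an `args_class` hook plus an abstract `run()`. A minimal, self-contained sketch of that pattern (not part of this commit; `MyArgs` and `MyPipeline` are hypothetical names, and the base class here only mirrors what the diff makes visible):

```python
from abc import ABC, abstractmethod


class Pipeline(ABC):  # mirrors the shape visible in the diff above
    args_class = None

    @abstractmethod
    def run(self):
        pass


class MyArgs:  # hypothetical argument container
    seed = 42


class MyPipeline(Pipeline):
    args_class = MyArgs

    def run(self):
        print(f'running with seed={self.args_class.seed}')


MyPipeline().run()  # prints: running with seed=42
```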
2 changes: 1 addition & 1 deletion swift/llm/dataset/__init__.py
@@ -11,7 +11,7 @@
                     RowPreprocessor)
 from .register import DATASET_MAPPING, register_dataset, register_dataset_info
 from .utils import (ConstantLengthDataset, HfDataset, LazyLLMDataset, LLMDataset, LLMIterableDataset, dataset_map,
-                    print_example, sort_by_max_length, stat_dataset)
+                    print_example, sample_dataset, sort_by_max_length, stat_dataset)


 def _update_fingerprint_mac(*args, **kwargs):
1 change: 0 additions & 1 deletion swift/llm/dataset/dataset/dataset.py
@@ -11,7 +11,6 @@
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
 from datasets import concatenate_datasets, interleave_datasets
-from numpy.random import RandomState
 from tqdm.auto import tqdm
 from transformers.utils import strtobool

39 changes: 6 additions & 33 deletions swift/llm/dataset/loader.py
@@ -11,11 +11,11 @@
 from datasets import concatenate_datasets, interleave_datasets
 from modelscope.hub.api import ModelScopeConfig
 from modelscope.utils.config_ds import MS_CACHE_HOME
-from numpy.random import RandomState

 from swift.hub import HFHub, MSHub
 from swift.utils import download_ms_file, get_logger, get_seed, safe_ddp_context, use_hf_hub
 from .register import DATASET_MAPPING, DATASET_TYPE, DatasetMeta, SubsetDataset, register_dataset_info
+from .utils import sample_dataset

 logger = get_logger()

@@ -246,39 +246,12 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
         subsets = [subset_mapping[subset_name].set_default(dataset_meta) for subset_name in subsets]
         return subsets

-    @staticmethod
-    def sample_dataset(dataset: HfDataset,
-                       dataset_sample: int,
-                       random_state: Optional[RandomState] = None) -> HfDataset:
-        """Sample dataset by a dataset_sample number
-        Args:
-            dataset: The dataset instance, iterable dataset is not supported
-            dataset_sample: The sample number
-            random_state: The random state
-        Returns:
-            The sampled dataset
-        """
-        if random_state is None:
-            random_state = RandomState()
-
-        n_repeat_sample = dataset_sample // len(dataset)
-        n_random_sample = dataset_sample % len(dataset)
-        if n_repeat_sample >= 1 and n_random_sample >= 1:
-            logger.warning(f'dataset_sample:{dataset_sample} is greater than len(dataset):{len(dataset)}, '
-                           'repeated sampling will be performed.')
-        idx = np.tile(range(len(dataset)), n_repeat_sample)
-        if n_random_sample >= 1:
-            idx_random = random_state.permutation(len(dataset))[:n_random_sample]
-            idx = np.concatenate([idx, idx_random])
-        dataset = dataset.select(idx)
-        return dataset
-
     @staticmethod
     def post_preprocess(
             train_dataset: DATASET_TYPE,
             dataset_sample: Optional[int] = None,
             split_dataset_ratio: float = 0.,
-            random_state: Optional[RandomState] = None,
+            random_state: Optional[np.random.RandomState] = None,
             streaming: bool = False,
             *,
             load_from_cache_file: bool = False,
@@ -305,7 +278,7 @@
             val_sample = dataset_sample
             assert val_sample <= len(
                 val_dataset), f'dataset_sample: {dataset_sample}, len(val_dataset): {len(val_dataset)}'
-            val_dataset = DatasetLoader.sample_dataset(val_dataset, val_sample, random_state)
+            val_dataset = sample_dataset(val_dataset, val_sample, random_state)
         else:
             if split_dataset_ratio == 0:
                 train_sample = dataset_sample
@@ -319,7 +292,7 @@
                 test_size=val_sample, seed=get_seed(random_state),
                 load_from_cache_file=load_from_cache_file).values()
         assert train_sample > 0
-        train_dataset = DatasetLoader.sample_dataset(train_dataset, train_sample, random_state)
+        train_dataset = sample_dataset(train_dataset, train_sample, random_state)
         return train_dataset, val_dataset

     @staticmethod
@@ -430,7 +403,7 @@ def _parse_datasets(datasets: List[str]) -> List[str]:
 def load_dataset(
         datasets: List[str],
         split_dataset_ratio: float = 0.,
-        dataset_seed: Union[int, RandomState] = 42,
+        dataset_seed: Union[int, np.random.RandomState] = 42,
         *,
         num_proc: int = 1,
         strict: bool = True,
@@ -458,7 +431,7 @@ def load_dataset(
     if isinstance(datasets, str):
         datasets = [datasets]
     if isinstance(dataset_seed, int):
-        dataset_seed = np.random.RandomState(dataset_seed)
+        dataset_seed = np.random.RandomState(dataset_seed)
     datasets: List[str] = DatasetLoader._parse_datasets(datasets)  # to dataset_names and register
     train_datasets = []
    val_datasets = []
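The loader changes replace the bare `RandomState` import with fully qualified `np.random.RandomState` and delegate sampling to the relocated `sample_dataset` helper. A hedged sketch of the seed handling shown in the `load_dataset` hunks (assumes `swift` with this commit installed; the dataset name and the `(train, val)` return shape are assumptions inferred from the surrounding diff, not confirmed by it):

```python
import numpy as np

from swift.llm import load_dataset  # re-exported per the swift/llm/__init__.py diff

# 'alpaca-zh' is a hypothetical dataset name; any registered dataset would do.
# Per the diff, an int seed is coerced via np.random.RandomState(dataset_seed),
# so these two calls should produce identical splits.
train1, val1 = load_dataset(['alpaca-zh'], split_dataset_ratio=0.01, dataset_seed=42)
train2, val2 = load_dataset(['alpaca-zh'], split_dataset_ratio=0.01,
                            dataset_seed=np.random.RandomState(42))
assert train1[0] == train2[0]  # same seed -> same sampling
```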
27 changes: 27 additions & 0 deletions swift/llm/dataset/utils.py
@@ -25,6 +25,33 @@
 os.environ['TOKENIZERS_PARALLELISM'] = 'true'


+def sample_dataset(dataset: HfDataset,
+                   dataset_sample: int,
+                   random_state: Optional[np.random.RandomState] = None) -> HfDataset:
+    """Sample dataset by a dataset_sample number
+    Args:
+        dataset: The dataset instance, iterable dataset is not supported
+        dataset_sample: The sample number
+        random_state: The random state
+    Returns:
+        The sampled dataset
+    """
+    if random_state is None:
+        random_state = np.random.RandomState()
+
+    n_repeat_sample = dataset_sample // len(dataset)
+    n_random_sample = dataset_sample % len(dataset)
+    if n_repeat_sample >= 1 and n_random_sample >= 1:
+        logger.warning(f'dataset_sample:{dataset_sample} is greater than len(dataset):{len(dataset)}, '
+                       'repeated sampling will be performed.')
+    idx = np.tile(range(len(dataset)), n_repeat_sample)
+    if n_random_sample >= 1:
+        idx_random = random_state.permutation(len(dataset))[:n_random_sample]
+        idx = np.concatenate([idx, idx_random])
+    dataset = dataset.select(idx)
+    return dataset
+
+
 class LLMDataset(Dataset):
     """This class wraps the Dataset class, to offer the ability of custom dataset tokenizing"""

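Since `sample_dataset` is now a module-level helper re-exported from `swift.llm` (see the `__init__.py` diff above), it can be called directly. A usage sketch based on the function body in this hunk (assumes `swift` with this commit and the `datasets` library installed):

```python
import numpy as np
from datasets import Dataset

from swift.llm import sample_dataset

ds = Dataset.from_dict({'text': [f'row {i}' for i in range(10)]})

# Downsample: 5 // 10 == 0 repeats, so 5 rows are drawn at random.
small = sample_dataset(ds, 5, np.random.RandomState(42))
assert len(small) == 5

# Oversample: 25 // 10 == 2 full copies plus 25 % 10 == 5 random extras;
# this path logs the "repeated sampling will be performed" warning.
large = sample_dataset(ds, 25, np.random.RandomState(42))
assert len(large) == 25
```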
(4 more changed files not shown.)