Skip to content

Commit

Permalink
improve iterable support
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 2, 2025
1 parent e52cf85 commit a952e84
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 76 deletions.
9 changes: 7 additions & 2 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ def cli():

@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--iterable/--no-iterable",
default=False,
help="Use IterableDataset for streaming processing of large datasets",
)
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
def preprocess(config: str, iterable: bool, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

from axolotl.cli.preprocess import do_cli

do_cli(config=config, **kwargs)
do_cli(config=config, iterable=iterable, **kwargs)


@cli.command()
Expand Down
10 changes: 8 additions & 2 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import warnings
from pathlib import Path
from typing import Union
from typing import Optional, Union

import fire
import transformers
Expand All @@ -28,11 +28,17 @@
LOG = logging.getLogger("axolotl.cli.preprocess")


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(
config: Union[Path, str] = Path("examples/"),
iterable: Optional[bool] = False,
**kwargs,
):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
if iterable:
parsed_cfg.preprocess_iterable = iterable
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def process(self, dataset):
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
map_kwargs["batch_size"] = 1_000
if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/integrations/kd/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ def transform_logprobs(self, sample):

return sample

def tokenize_prompt(self, prompt):
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)

Expand Down
45 changes: 37 additions & 8 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional

from transformers import ProcessorMixin
Expand Down Expand Up @@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

def __init__(
self,
prompter,
prompter: ChatTemplatePrompter,
tokenizer,
train_on_inputs,
sequence_len,
Expand All @@ -220,22 +221,50 @@ def messages(self):
def messages(self, messages):
self._messages = messages

def tokenize_prompt(self, prompt):
@property
def supports_batched(self) -> bool:
# Always True: tokenize_prompt walks a dict of per-feature lists row by row,
# so callers may hand it a batch of examples instead of a single example.
return True

def tokenize_prompt(self, prompt: dict[str, Any]) -> Dict[str, List[List[int]]]:
"""
Public method that can handle either a single prompt or a batch of prompts.
"""

res = defaultdict(lambda: [])
feature_names = list(prompt.keys())

# Process each prompt individually
for row in zip(*prompt.values()):
tokenized_prompt = self._tokenize_single_prompt(
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])

# If there are no examples left, return an empty dictionary
if not res:
return {}

return dict(res)

def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail # type: ignore
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
Expand All @@ -256,7 +285,7 @@ def tokenize_prompt(self, prompt):
return tokenized_prompt

turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
input_ids = self.prompter.build_prompt(turns) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)

last_eos_idx = -1
Expand Down Expand Up @@ -286,7 +315,7 @@ def tokenize_prompt(self, prompt):

if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class SFTDataset(BaseModel):
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
preprocess_shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Expand Down Expand Up @@ -800,6 +801,7 @@ class Config:

# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
preprocess_iterable: Optional[bool] = None

total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None
Expand Down
31 changes: 29 additions & 2 deletions src/axolotl/utils/data/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
concatenate_datasets,
load_dataset,
load_from_disk,
Expand Down Expand Up @@ -250,13 +251,25 @@ def for_d_in_datasets(dataset_configs):
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset

streaming_ds = False
if cfg.preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token
config_dataset, use_auth_token, streaming=streaming_ds
)

d_base_type = d_prompt_style = None
Expand Down Expand Up @@ -313,7 +326,21 @@ def for_d_in_datasets(dataset_configs):

if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if isinstance(dataset, IterableDataset):

def gen_from_iter_ds(_ds, _=None, num_shards=None):
    """
    Yield the rows of an iterable dataset, restricted to the caller's shards.

    Args:
        _ds: The iterable dataset (any iterable of rows) to materialize.
        _: List of shard ids assigned to this generator call. When the
            generator is run with several processes, the list passed via
            ``gen_kwargs`` is split between the workers, so each call only
            sees its own ids. ``None`` means "no sharding": yield every row.
        num_shards: Total number of shards the stream is divided into.
            Defaults to ``cfg.dataset_processes`` from the enclosing scope,
            which is also the length of the shard-id list built by the caller.

    Yields:
        Rows whose position modulo ``num_shards`` is one of the ids in ``_``.
        Without this filter every worker would emit the complete stream and
        the saved dataset would contain each row once per worker.
    """
    if _ is None:
        yield from _ds
        return

    if num_shards is None:
        # NOTE(review): relies on `cfg` being visible from the enclosing
        # function, as it is for the from_generator call that follows.
        num_shards = cfg.dataset_processes

    shard_ids = set(_)
    # Each worker still walks the whole stream but keeps only its own rows.
    for position, row in enumerate(_ds):
        if position % num_shards in shard_ids:
            yield row

ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
Expand Down
14 changes: 8 additions & 6 deletions src/axolotl/utils/data/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type


def load_dataset_w_config(config_dataset, auth_token):
def load_dataset_w_config(
config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
Expand Down Expand Up @@ -117,7 +119,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
streaming=streaming,
split=None,
)
else:
Expand Down Expand Up @@ -153,7 +155,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
streaming=streaming,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
Expand All @@ -172,7 +174,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
Expand All @@ -183,7 +185,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
Expand Down Expand Up @@ -213,7 +215,7 @@ def load_dataset_w_config(config_dataset, auth_token):
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
streaming=streaming,
split=None,
)
if not ds:
Expand Down
Loading

0 comments on commit a952e84

Please sign in to comment.