-
-
Notifications
You must be signed in to change notification settings - Fork 927
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor utils.data module for line count linter (#1476)
- Loading branch information
Showing
5 changed files
with
376 additions
and
330 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Data processing modules | ||
""" | ||
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401 | ||
from axolotl.utils.data.pretraining import ( # noqa: F401 | ||
encode_pretraining, | ||
wrap_pretraining_dataset, | ||
) | ||
from axolotl.utils.data.sft import ( # noqa: F401 | ||
get_dataset_wrapper, | ||
load_prepare_datasets, | ||
load_tokenized_prepared_datasets, | ||
prepare_dataset, | ||
) | ||
from axolotl.utils.data.utils import md5 # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""data handling specific to DPO""" | ||
|
||
import logging | ||
from pathlib import Path | ||
from typing import Any, List | ||
|
||
import yaml | ||
from datasets import concatenate_datasets, load_dataset, load_from_disk | ||
|
||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH | ||
from axolotl.prompt_strategies.dpo import load as load_dpo | ||
from axolotl.utils.data.utils import md5 | ||
from axolotl.utils.dict import DictDefault | ||
from axolotl.utils.distributed import is_main_process, zero_first | ||
|
||
LOG = logging.getLogger("axolotl") | ||
|
||
|
||
def _get_path(ds_hash, cfg): | ||
prepared_ds_path = ( | ||
Path(cfg.dataset_prepared_path) / ds_hash | ||
if cfg.dataset_prepared_path | ||
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash | ||
) | ||
|
||
return prepared_ds_path | ||
|
||
|
||
def _load_preprocessed_ds(cfg, sub_cfg): | ||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) | ||
prepared_ds_path = _get_path(ds_hash, cfg) | ||
dataset = None | ||
|
||
# pylint: disable=duplicate-code | ||
if ( | ||
cfg.dataset_prepared_path | ||
and any(prepared_ds_path.glob("*")) | ||
and not cfg.is_preprocess | ||
): | ||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") | ||
dataset = load_from_disk(str(prepared_ds_path)) | ||
|
||
return dataset | ||
|
||
|
||
def _save_preprocessed_ds(cfg, sub_cfg, dataset): | ||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) | ||
prepared_ds_path = _get_path(ds_hash, cfg) | ||
|
||
if cfg.is_preprocess and is_main_process(): | ||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") | ||
dataset.save_to_disk(str(prepared_ds_path)) | ||
|
||
|
||
def load_prepare_dpo_datasets(cfg): | ||
def load_split(dataset_cfgs, _cfg): | ||
split_datasets: List[Any] = [] | ||
for i, ds_cfg in enumerate(dataset_cfgs): | ||
if ds_cfg["ds_type"] == "json": | ||
for data_file in ds_cfg["data_files"]: | ||
data_files = {ds_cfg["split"]: data_file} | ||
ds = load_dataset( # pylint: disable=invalid-name | ||
"json", | ||
data_files=data_files, | ||
split=ds_cfg["split"], | ||
) | ||
split_datasets.insert(i, ds) | ||
else: | ||
ds = load_dataset( # pylint: disable=invalid-name | ||
ds_cfg["path"], | ||
split=ds_cfg["split"], | ||
) | ||
split_datasets.insert(i, ds) | ||
|
||
for i, data_set in enumerate(split_datasets): | ||
_type = dataset_cfgs[i]["type"] | ||
if _type: | ||
if isinstance(_type, DictDefault): | ||
_type = "user_defined.default" | ||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) | ||
split_datasets[i] = data_set.map( | ||
ds_transform_fn, | ||
desc="Mapping RL Dataset", | ||
) | ||
else: | ||
# If no `type` is provided, assume the dataset is already in the expected format with | ||
# "prompt", "chosen" and "rejected" already preprocessed | ||
split_datasets[i] = data_set | ||
|
||
return concatenate_datasets(split_datasets) | ||
|
||
with zero_first(is_main_process()): | ||
train_is_preprocessed = False | ||
eval_is_preprocessed = False | ||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets): | ||
train_is_preprocessed = True | ||
else: | ||
train_dataset = load_split(cfg.datasets, cfg) | ||
|
||
eval_dataset = None | ||
if cfg.test_datasets: | ||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets): | ||
eval_is_preprocessed = True | ||
else: | ||
eval_dataset = load_split(cfg.test_datasets, cfg) | ||
if not eval_dataset: | ||
eval_dataset = None | ||
|
||
if not train_is_preprocessed: | ||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset) | ||
if eval_dataset and not eval_is_preprocessed: | ||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) | ||
|
||
return train_dataset, eval_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
"""data handling specific to pretraining""" | ||
|
||
import functools | ||
import logging | ||
from collections import defaultdict | ||
from typing import Callable, Dict, List, Optional | ||
|
||
import torch | ||
from datasets import Dataset | ||
from torch.utils.data import RandomSampler | ||
from transformers import PreTrainedTokenizerBase | ||
|
||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq | ||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths | ||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing | ||
|
||
LOG = logging.getLogger("axolotl") | ||
|
||
|
||
def encode_pretraining( | ||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] | ||
) -> Dict[str, List]: | ||
res = tokenizer( | ||
examples, | ||
truncation=True, | ||
max_length=max_tokens - 2, | ||
add_special_tokens=True, | ||
) | ||
# Convert to PyTorch tensors | ||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]] | ||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] | ||
new_input_ids = [] | ||
new_attention_mask = [] | ||
# Append EOS and PAD tokens to input_ids, and correct attention_mask | ||
for i, _ in enumerate(input_ids): | ||
input_ids[i] = torch.cat( | ||
( | ||
input_ids[i], | ||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]), | ||
), | ||
dim=0, | ||
) | ||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) | ||
|
||
# Concatenate tokens so that their lengths are less than max_tokens | ||
buffer_input_ids = torch.tensor([], dtype=torch.long) | ||
buffer_attention_mask = torch.tensor([], dtype=torch.long) | ||
|
||
for ids, mask in zip(input_ids, attention_mask): | ||
if buffer_input_ids.numel() == max_tokens: | ||
new_input_ids.append(buffer_input_ids) | ||
new_attention_mask.append(buffer_attention_mask) | ||
buffer_input_ids = torch.tensor([], dtype=torch.long) | ||
buffer_attention_mask = torch.tensor([], dtype=torch.long) | ||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) | ||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) | ||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens: | ||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) | ||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) | ||
else: | ||
buffer_input_ids = torch.cat( | ||
( | ||
buffer_input_ids, | ||
torch.full( | ||
(max_tokens - buffer_input_ids.numel(),), | ||
tokenizer.pad_token_id, | ||
dtype=torch.long, | ||
), | ||
), | ||
dim=0, | ||
) | ||
buffer_attention_mask = torch.cat( | ||
( | ||
buffer_attention_mask, | ||
torch.full( | ||
(max_tokens - buffer_attention_mask.numel(),), | ||
0, | ||
dtype=torch.long, | ||
), | ||
), | ||
dim=0, | ||
) | ||
new_input_ids.append(buffer_input_ids) | ||
new_attention_mask.append(buffer_attention_mask) | ||
buffer_input_ids = torch.tensor([], dtype=torch.long) | ||
buffer_attention_mask = torch.tensor([], dtype=torch.long) | ||
|
||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) | ||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) | ||
|
||
if buffer_input_ids.numel() > 0: # for any leftover tokens | ||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size | ||
buffer_input_ids = torch.cat( | ||
( | ||
buffer_input_ids, | ||
torch.full( | ||
(max_tokens - buffer_input_ids.numel(),), | ||
tokenizer.pad_token_id, | ||
dtype=torch.long, | ||
), | ||
), | ||
dim=0, | ||
) | ||
buffer_attention_mask = torch.cat( | ||
( | ||
buffer_attention_mask, | ||
torch.full( | ||
(max_tokens - buffer_attention_mask.numel(),), | ||
0, | ||
dtype=torch.long, | ||
), | ||
), | ||
dim=0, | ||
) | ||
new_input_ids.append(buffer_input_ids) | ||
new_attention_mask.append(buffer_attention_mask) | ||
|
||
ret = { | ||
"input_ids": [seq.tolist() for seq in new_input_ids], | ||
"labels": [seq.tolist() for seq in new_input_ids], | ||
"attention_mask": [seq.tolist() for seq in new_attention_mask], | ||
} | ||
|
||
LOG.debug(len(ret["input_ids"])) | ||
return ret | ||
|
||
|
||
def wrap_pretraining_dataset( | ||
dataset, | ||
tokenizer, | ||
cfg, | ||
ds_wrapper_fn, | ||
max_tokens=2048, | ||
batch_size=1, | ||
seed=42, | ||
buffer_size=10_000, | ||
): | ||
if cfg.sample_packing: | ||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( | ||
tokenizer, | ||
return_tensors="pt", | ||
padding=True, | ||
pad_to_multiple_of=max_tokens * batch_size, | ||
multipack_attn=cfg.pretrain_multipack_attn, | ||
) | ||
encode = functools.partial( | ||
encode_packed_pretraining, | ||
collate_fn, | ||
ds_wrapper_fn, | ||
max_seq_length=max_tokens, | ||
batch_size=batch_size, | ||
multipack_attn=cfg.pretrain_multipack_attn, | ||
) | ||
# set this to 1 so downstream data_loader doesn't try to increase the batch again | ||
cfg.micro_batch_size = 1 | ||
else: | ||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) | ||
|
||
if cfg.shuffle_merged_datasets: | ||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) | ||
else: | ||
LOG.debug("NOT shuffling merged pretraining datasets") | ||
|
||
# remove all the existing columns after mapping since they end up having | ||
# a different length than the encoded/tokenized column | ||
# this is empty during streaming/pretraining | ||
remove_columns = [] | ||
if dataset.features is None: | ||
for first_row in dataset: | ||
remove_columns = first_row.keys() | ||
break | ||
else: | ||
remove_columns = dataset.features.keys() | ||
|
||
dataset = dataset.map( | ||
encode, | ||
batched=True, | ||
batch_size=buffer_size, | ||
# input_columns="text", | ||
remove_columns=remove_columns, | ||
) | ||
return dataset | ||
|
||
|
||
def encode_packed_pretraining( | ||
collate_fn, | ||
ds_wrapper: Callable, | ||
examples: Dict[str, List], | ||
max_seq_length: int = 2048, | ||
batch_size: int = 4, | ||
multipack_attn: Optional[bool] = False, | ||
) -> Dict[str, List]: | ||
# pylint: disable=duplicate-code | ||
# tokenize all the examples | ||
# rows get split with stride (overlap) | ||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] | ||
|
||
train_dataset = process_pretraining_datasets_for_packing( | ||
train_dataset, | ||
max_seq_length, | ||
skip_position_ids=not multipack_attn, | ||
) | ||
|
||
sampler = MultipackBatchSampler( | ||
RandomSampler(train_dataset), | ||
batch_size=1, | ||
drop_last=True, | ||
batch_max_len=batch_size * max_seq_length, | ||
lengths=get_dataset_lengths(train_dataset), | ||
) | ||
|
||
chunked_data = defaultdict(list) | ||
|
||
for batch in sampler: | ||
for data in batch: | ||
features = train_dataset[data] | ||
if "num_truncated_tokens" in features: | ||
del features["num_truncated_tokens"] | ||
if "num_truncated_tokens" in features: | ||
del features["num_truncated_tokens"] | ||
if "overflow_to_sample_mapping" in features: | ||
del features["overflow_to_sample_mapping"] | ||
if "labels" not in features: | ||
features["labels"] = features["input_ids"].copy() | ||
collated_features = collate_fn(features) | ||
|
||
for feature in features.keys(): | ||
if feature == "length": | ||
continue | ||
chunked_data[feature].append(collated_features[feature].squeeze(0)) | ||
|
||
return chunked_data |
Oops, something went wrong.