diff --git a/docker-compose.yml b/docker-compose.yml index 1a28a65..8e897e5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,8 @@ services: - ./out/:/workspace/out/ - ./last_run_prepared/:/workspace/last_run_prepared/ - ./huggingface/:/root/.cache/huggingface/ + # Axolotl Patch + - ./patch/axolotl/data.py:/workspace/axolotl/src/axolotl/utils/data.py environment: - HF_HUB_OFFLINE=1 # environment: diff --git a/patch/axolotl/data.py b/patch/axolotl/data.py new file mode 100644 index 0000000..5785e17 --- /dev/null +++ b/patch/axolotl/data.py @@ -0,0 +1,951 @@ +# This file is copied from https://github.com/OpenAccess-AI-Collective/axolotl/blob/v0.4.0/src/axolotl/utils/data.py + +"""Module containing data utilities""" +import functools +import hashlib +import logging +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import yaml +from datasets import ( + Dataset, + DatasetDict, + concatenate_datasets, + load_dataset, + load_from_disk, +) +from huggingface_hub import hf_hub_download +from torch.utils.data import RandomSampler +from transformers import PreTrainedTokenizerBase + +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.datasets import TokenizedPromptDataset +from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.dpo import load as load_dpo +from axolotl.prompt_tokenizers import ( + AlpacaMultipleChoicePromptTokenizingStrategy, + AlpacaPromptTokenizingStrategy, + AlpacaReflectionPTStrategy, + GPTeacherPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + OpenAssistantPromptTokenizingStrategy, + SummarizeTLDRPromptTokenizingStrategy, +) +from axolotl.prompters import ( + AlpacaPrompter, + GPTeacherPrompter, + JeopardyPrompter, + MultipleChoiceConcisePrompter, + MultipleChoiceExplainPrompter, + Prompter, + ReflectAlpacaPrompter, + SummarizeTLDRPrompter, + UnsupportedPrompter, +) +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.trainer import ( + calculate_total_num_steps, + process_datasets_for_packing, + process_pretraining_datasets_for_packing, +) + +LOG = logging.getLogger("axolotl") + + +def md5(to_hash: str, encoding: str = "utf-8") -> str: + try: + return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() + except TypeError: + return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec + + +def prepare_dataset(cfg, tokenizer): + prompters = [] + if not cfg.pretraining_dataset: + with zero_first(is_main_process()): + if cfg.test_datasets: + train_dataset, _, prompters = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" + ) + _, eval_dataset, _ = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test" + ) + else: + train_dataset, eval_dataset, prompters = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + ) + else: + path = cfg.pretraining_dataset + name = None + if isinstance(cfg.pretraining_dataset, list) and isinstance( + cfg.pretraining_dataset[0], dict + ): + path = cfg.pretraining_dataset[0]["path"] + name = cfg.pretraining_dataset[0]["name"] + + train_dataset = load_pretraining_dataset( + path, + tokenizer, + cfg, + name=name, + max_tokens=cfg.sequence_len, + seed=cfg.seed or 42, + ) + # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 + train_dataset = train_dataset.with_format("torch") + eval_dataset = None + return train_dataset, eval_dataset, cfg.max_steps, prompters + + if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: + total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) + if total_eval_steps == 0: + raise ValueError( + "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. " + ) + + if cfg.max_steps: + total_num_steps = min( + calculate_total_num_steps(cfg, train_dataset), cfg.max_steps + ) + LOG.info(f"Maximum number of steps set at {total_num_steps}") + else: + total_num_steps = calculate_total_num_steps(cfg, train_dataset) + return train_dataset, eval_dataset, total_num_steps, prompters + + +def load_tokenized_prepared_datasets( + tokenizer, + cfg, + default_dataset_prepared_path, + split="train", +) -> Tuple[DatasetDict, List[Prompter]]: + cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets + tokenizer_name = tokenizer.__class__.__name__ + ds_hash = str( + md5( + ( + str(cfg.sequence_len) + + "@" + + str(cfg.sample_packing) + + "@" + + str(cfg.eval_sample_packing) + + "@" + + str(cfg.group_by_length) + + "@" + + "|".join( + sorted( + [ + f"{d.path}:{d.type}:{d.shards}:{d.conversation}" + for d in cfg_datasets + ] + ) + ) + + "|" + + tokenizer_name + ) + ) + ) + prepared_ds_path = ( + Path(cfg.dataset_prepared_path) / ds_hash + if cfg.dataset_prepared_path + else Path(default_dataset_prepared_path) / ds_hash + ) + dataset = None + prompters = [] + use_auth_token = cfg.hf_use_auth_token + try: + if cfg.push_dataset_to_hub: + dataset = load_dataset( + f"{cfg.push_dataset_to_hub}/{ds_hash}", + token=use_auth_token, + ) + dataset = dataset[split] + except Exception: # pylint: disable=broad-except # nosec + pass + + if dataset: + ... + elif ( + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.is_preprocess + ): + LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + dataset = load_from_disk(str(prepared_ds_path)) + LOG.info("Prepared dataset loaded from disk...") + else: + LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") + LOG.info("Loading raw datasets...") + if not cfg.is_preprocess: + LOG.warning( + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." + ) + + if cfg.seed: + seed = cfg.seed + else: + LOG.info("No seed provided, using default seed of 42") + seed = 42 + + datasets = [] + + def for_d_in_datasets(dataset_configs): + for dataset in dataset_configs: + if dataset.name and isinstance(dataset.name, list): + for name in dataset.name: + yield DictDefault({**dataset, "name": name}) + else: + yield dataset + + # pylint: disable=invalid-name + for config_dataset in for_d_in_datasets(cfg_datasets): + ds: Optional[Union[Dataset, DatasetDict]] = None + ds_from_hub = False + try: + load_dataset( + config_dataset.path, + name=config_dataset.name, + streaming=True, + token=use_auth_token, + ) + ds_from_hub = True + except (FileNotFoundError, ConnectionError): + pass + + ds_from_cloud = False + storage_options = {} + remote_file_system = None + if config_dataset.path.startswith("s3://"): + try: + import aiobotocore.session # type: ignore + import s3fs # type: ignore + except ImportError as exc: + raise ImportError( + "s3:// paths require aiobotocore and s3fs to be installed" + ) from exc + + # Takes credentials from ~/.aws/credentials for default profile + s3_session = aiobotocore.session.AioSession(profile="default") + storage_options = {"session": s3_session} + remote_file_system = s3fs.S3FileSystem(**storage_options) + elif config_dataset.path.startswith( + "gs://" + ) or config_dataset.path.startswith("gcs://"): + try: + import gcsfs # type: ignore + except ImportError as exc: + raise ImportError( + "gs:// or gcs:// paths require gcsfs to be installed" + ) from exc + + # gcsfs will use default credentials from the environment else anon + # https://gcsfs.readthedocs.io/en/latest/#credentials + storage_options = {"token": None} + remote_file_system = gcsfs.GCSFileSystem(**storage_options) + # TODO: Figure out how to get auth creds passed + # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"): + # try: + # import adlfs + # except ImportError as exc: + # raise ImportError( + # "adl:// or abfs:// paths require adlfs to be installed" + # ) from exc + + # # Gen 1 + # storage_options = { + # "tenant_id": TENANT_ID, + # "client_id": CLIENT_ID, + # "client_secret": CLIENT_SECRET, + # } + # # Gen 2 + # storage_options = { + # "account_name": ACCOUNT_NAME, + # "account_key": ACCOUNT_KEY, + # } + + # remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) + try: + if remote_file_system and remote_file_system.exists( + config_dataset.path + ): + ds_from_cloud = True + except (FileNotFoundError, ConnectionError): + pass + + # prefer local dataset, even if hub exists + local_path = Path(config_dataset.path) + if local_path.exists(): + if local_path.is_dir(): + # TODO dirs with arrow or parquet files could be loaded with `load_from_disk` + ds = load_dataset( + config_dataset.path, + name=config_dataset.name, + data_files=config_dataset.data_files, + streaming=False, + split=None, + ) + elif local_path.is_file(): + ds_type = get_ds_type(config_dataset) + + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + ) + else: + raise ValueError( + "unhandled dataset load: local path exists, but is neither a directory or a file" + ) + elif ds_from_hub: + ds = load_dataset( + config_dataset.path, + name=config_dataset.name, + streaming=False, + data_files=config_dataset.data_files, + token=use_auth_token, + ) + elif ds_from_cloud and remote_file_system: + if remote_file_system.isdir(config_dataset.path): + ds = load_from_disk( + config_dataset.path, + storage_options=storage_options, + ) + elif remote_file_system.isfile(config_dataset.path): + ds_type = get_ds_type(config_dataset) + ds = load_dataset( + ds_type, + name=config_dataset.name, + data_files=config_dataset.path, + streaming=False, + split=None, + storage_options=storage_options, + ) + else: + if isinstance(config_dataset.data_files, str): + fp = hf_hub_download( + repo_id=config_dataset.path, + repo_type="dataset", + filename=config_dataset.data_files, + ) + elif isinstance(config_dataset.data_files, list): + fp = [] + for file in config_dataset.data_files: + fp.append( + hf_hub_download( + repo_id=config_dataset.path, + repo_type="dataset", + filename=file, + ) + ) + else: + raise ValueError( + "data_files must be either a string or list of strings" + ) + ds = load_dataset( + "json", + name=config_dataset.name, + data_files=fp, + streaming=False, + split=None, + ) + if not ds: + raise ValueError("unhandled dataset load") + + d_base_type = d_prompt_style = None + d_type = config_dataset.type + if isinstance(d_type, str): + d_type_split = d_type.split(":") + d_base_type = d_type_split[0] + d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None + + if config_dataset.split and config_dataset.split in ds: + ds = ds[config_dataset.split] + elif split in ds: + ds = ds[split] + elif isinstance(ds, DatasetDict): + raise ValueError( + f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" + ) + + # support for using a subset of the data + if config_dataset.shards: + shards_idx = config_dataset.get("shards_idx", 0) + ds = ds.shuffle(seed=seed).shard( + num_shards=config_dataset.shards, index=shards_idx + ) + + dataset_wrapper, dataset_prompter = get_dataset_wrapper( + config_dataset=config_dataset, + dataset=ds, + tokenizer=tokenizer, + cfg=cfg, + d_base_type=d_base_type, + d_prompt_style=d_prompt_style, + ) + datasets.append(dataset_wrapper) + prompters.append(dataset_prompter) + + LOG.info("merging datasets") + dataset = concatenate_datasets(datasets) + + if len(datasets) > 1: + LOG.info("shuffle merged datasets") + dataset = dataset.shuffle(seed=seed) + + dataset, _ = process_datasets_for_packing(cfg, dataset, None) + + if cfg.local_rank == 0: + LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") + dataset.save_to_disk(str(prepared_ds_path)) + if cfg.push_dataset_to_hub: + LOG.info( + f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset.push_to_hub( + f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True + ) + + return dataset, prompters + + +def get_ds_type(config_dataset: DictDefault): + """ + Get the dataset type from the path if it's not specified + """ + ds_type = "json" + if config_dataset.ds_type: + ds_type = config_dataset.ds_type + elif ".parquet" in config_dataset.path: + ds_type = "parquet" + elif ".arrow" in config_dataset.path: + ds_type = "arrow" + elif ".csv" in config_dataset.path: + ds_type = "csv" + elif ".txt" in config_dataset.path: + ds_type = "text" + return ds_type + + +def load_prepare_datasets( + tokenizer: PreTrainedTokenizerBase, + cfg, + default_dataset_prepared_path, + split="train", +) -> Tuple[Dataset, Dataset, List[Prompter]]: + dataset, prompters = load_tokenized_prepared_datasets( + tokenizer, cfg, default_dataset_prepared_path + ) + + if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: + LOG.info( + f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" + ) + dataset = dataset.shard( + num_shards=cfg.dataset_shard_num, + index=cfg.dataset_shard_idx, + ) + + if split == "train" and cfg.val_set_size: + # ensure we end up with the same fingerprint by doing rank0 first and being able to cache + to_hash_train = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "train" + + "|" + + str(cfg.seed or 42) + ) + to_hash_test = ( + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "test" + + "|" + + str(cfg.seed or 42) + ) + train_fingerprint = md5(to_hash_train) + test_fingerprint = md5(to_hash_test) + + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed or 42, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + elif split == "test": + train_dataset = None + eval_dataset = dataset + else: + train_dataset = dataset + eval_dataset = None + + return train_dataset, eval_dataset, prompters + + +def get_dataset_wrapper( + config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style +): + dataset_wrapper = None + dataset_prompter = None + + ds_kwargs = { + "process_count": cfg.dataset_processes, + "keep_in_memory": cfg.dataset_keep_in_memory is True, + } + + if ( + "input_ids" in dataset.features + and "attention_mask" in dataset.features + and "labels" in dataset.features + ): + # dataset is already tokenized, just drop it straight in + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = dataset + elif isinstance(config_dataset.type, DictDefault): + ds_strategy = load( + "user_defined", tokenizer, cfg, config_dataset.type.to_dict() + ) + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + elif d_base_type == "alpaca": + dataset_prompter = AlpacaPrompter(d_prompt_style) + ds_strategy = AlpacaPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "explainchoice": + dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) + ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "concisechoice": + dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) + ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "summarizetldr": + dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) + ds_strategy = SummarizeTLDRPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "jeopardy": + dataset_prompter = JeopardyPrompter(d_prompt_style) + ds_strategy = JeopardyPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "oasst": + dataset_prompter = AlpacaPrompter(d_prompt_style) + ds_strategy = OpenAssistantPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "gpteacher": + dataset_prompter = GPTeacherPrompter(d_prompt_style) + ds_strategy = GPTeacherPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + elif d_base_type == "reflection": + dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) + ds_strategy = AlpacaReflectionPTStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + dataset_wrapper = ds_wrapper + else: + suffix = "" + if ":load_" in config_dataset.type: + suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?" + LOG.error( + f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}" + ) + raise ValueError( + f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}" + ) + + return dataset_wrapper, dataset_prompter + + +def encode_pretraining( + tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] +) -> Dict[str, List]: + res = tokenizer( + examples, + truncation=True, + max_length=max_tokens - 2, + add_special_tokens=True, + ) + # Convert to PyTorch tensors + input_ids = [torch.tensor(seq) for seq in res["input_ids"]] + attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] + new_input_ids = [] + new_attention_mask = [] + # Append EOS and PAD tokens to input_ids, and correct attention_mask + for i, _ in enumerate(input_ids): + input_ids[i] = torch.cat( + ( + input_ids[i], + torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]), + ), + dim=0, + ) + attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) + + # Concatenate tokens so that their lengths are less than max_tokens + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) + + for ids, mask in zip(input_ids, attention_mask): + if buffer_input_ids.numel() == max_tokens: + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + elif buffer_input_ids.numel() + ids.numel() <= max_tokens: + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + else: + buffer_input_ids = torch.cat( + ( + buffer_input_ids, + torch.full( + (max_tokens - buffer_input_ids.numel(),), + tokenizer.pad_token_id, + dtype=torch.long, + ), + ), + dim=0, + ) + buffer_attention_mask = torch.cat( + ( + buffer_attention_mask, + torch.full( + (max_tokens - buffer_attention_mask.numel(),), + 0, + dtype=torch.long, + ), + ), + dim=0, + ) + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) + + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + + if buffer_input_ids.numel() > 0: # for any leftover tokens + while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size + buffer_input_ids = torch.cat( + ( + buffer_input_ids, + torch.full( + (max_tokens - buffer_input_ids.numel(),), + tokenizer.pad_token_id, + dtype=torch.long, + ), + ), + dim=0, + ) + buffer_attention_mask = torch.cat( + ( + buffer_attention_mask, + torch.full( + (max_tokens - buffer_attention_mask.numel(),), + 0, + dtype=torch.long, + ), + ), + dim=0, + ) + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + + ret = { + "input_ids": [seq.tolist() for seq in new_input_ids], + "labels": [seq.tolist() for seq in new_input_ids], + "attention_mask": [seq.tolist() for seq in new_attention_mask], + } + + LOG.debug(len(ret["input_ids"])) + return ret + + +def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): + if cfg.sample_packing: + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens * cfg.micro_batch_size, + ) + encode = functools.partial( + encode_packed_pretraining, + tokenizer, + collate_fn, + max_seq_length=max_tokens, + batch_size=cfg.micro_batch_size, + ) + # set this to 1 so downstream data_loader doesn't try to increase the batch again + cfg.micro_batch_size = 1 + else: + encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + + dataset = load_dataset(path, streaming=True, split="train", name=name) + dataset = dataset.shuffle(seed=seed, buffer_size=10_000) + dataset = dataset.map( + encode, + batched=True, + batch_size=10_000, + input_columns="text", + # remove all the existing columns after mapping since they end up having + # a different length than the encoded/tokenized column + remove_columns=dataset.features.keys(), + desc="Encoding Pretraining", + ) + return dataset + + +def encode_packed_pretraining( + tokenizer: PreTrainedTokenizerBase, + collate_fn, + examples: List[str], + max_seq_length: int = 2048, + batch_size: int = 4, +) -> Dict[str, List]: + # pylint: disable=duplicate-code + # tokenize all the examples + # rows get split with stride (overlap) + res = tokenizer( + examples, + truncation=True, + max_length=max_seq_length - 1, + add_special_tokens=True, + return_overflowing_tokens=True, + stride=256, + ) + + input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]] + attention_mask = [seq + [1] for seq in res["attention_mask"]] + + tokenized_examples = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + train_dataset = Dataset.from_dict(tokenized_examples) + train_dataset = process_pretraining_datasets_for_packing( + train_dataset, max_seq_length + ) + + sampler = MultipackBatchSampler( + RandomSampler(train_dataset), + batch_size=batch_size, + drop_last=True, + batch_max_len=batch_size * max_seq_length, + lengths=get_dataset_lengths(train_dataset), + ) + + chunked_data = defaultdict(list) + + for data in sampler: + features = train_dataset[data] + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) + + for feature in features.keys(): + if feature == "length": + continue + chunked_data[feature].append(collated_features[feature].squeeze(0)) + + return chunked_data + + +def _get_path(ds_hash, cfg): + prepared_ds_path = ( + Path(cfg.dataset_prepared_path) / ds_hash + if cfg.dataset_prepared_path + else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash + ) + + return prepared_ds_path + + +def _load_preprocessed_ds(cfg, sub_cfg): + ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) + prepared_ds_path = _get_path(ds_hash, cfg) + dataset = None + + if ( + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.is_preprocess + ): + LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + dataset = load_from_disk(str(prepared_ds_path)) + + return dataset + + +def _save_preprocessed_ds(cfg, sub_cfg, dataset): + ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) + prepared_ds_path = _get_path(ds_hash, cfg) + + if cfg.is_preprocess and is_main_process(): + LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + dataset.save_to_disk(str(prepared_ds_path)) + + +def load_prepare_dpo_datasets(cfg): + def load_split(dataset_cfgs, _cfg): + split_datasets: List[Any] = [] + for i, ds_cfg in enumerate(dataset_cfgs): + if ds_cfg["ds_type"] == "json": + for data_file in ds_cfg["data_files"]: + data_files = {ds_cfg["split"]: data_file} + ds = load_dataset( # pylint: disable=invalid-name + "json", + data_files=data_files, + split=ds_cfg["split"], + ) + split_datasets.insert(i, ds) + else: + ds = load_dataset( # pylint: disable=invalid-name + ds_cfg["path"], + split=ds_cfg["split"], + ) + split_datasets.insert(i, ds) + + for i, data_set in enumerate(split_datasets): + _type = dataset_cfgs[i]["type"] + if _type: + ds_transform_fn = load_dpo(_type, _cfg) + split_datasets[i] = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + else: + # If no `type` is provided, assume the dataset is already in the expected format with + # "prompt", "chosen" and "rejected" already preprocessed + split_datasets[i] = data_set + + return concatenate_datasets(split_datasets) + + with zero_first(is_main_process()): + train_is_preprocessed = False + eval_is_preprocessed = False + if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets): + train_is_preprocessed = True + else: + train_dataset = load_split(cfg.datasets, cfg) + + eval_dataset = None + if cfg.test_datasets: + if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets): + eval_is_preprocessed = True + else: + eval_dataset = load_split(cfg.test_datasets, cfg) + if not eval_dataset: + eval_dataset = None + + if not train_is_preprocessed: + _save_preprocessed_ds(cfg, cfg.datasets, train_dataset) + if eval_dataset and not eval_is_preprocessed: + _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) + + return train_dataset, eval_dataset \ No newline at end of file