diff --git a/examples/BeautifulPrompt/.gitignore b/examples/BeautifulPrompt/.gitignore new file mode 100644 index 0000000..9c601bb --- /dev/null +++ b/examples/BeautifulPrompt/.gitignore @@ -0,0 +1,152 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/.build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# IDE +.idea/ +.vscode/ + +# macos +*.DS_Store +#data/ + +docs/.build + +# pytorch checkpoint +*.pt + +# wandb log +example/wandb/ + +outputs/* +logs/* +data/* +!data/README.md +wandb/* \ No newline at end of file diff --git a/examples/BeautifulPrompt/README.md b/examples/BeautifulPrompt/README.md new file mode 100644 index 0000000..022994a --- /dev/null +++ b/examples/BeautifulPrompt/README.md @@ -0,0 +1,43 @@ +# BeautifulPrompt +This project is implemented for the EMNLP Industry Track 2023 paper: "BeautifulPrompt: Towards Automatic Prompt Engineering for Text-to-Image Synthesis". Our code is based on pytorch and huggingface transformers. + +## Data & Models +We released our collected [dataset](https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/BeautifulPrompt/data.json), which includes prompt pairs and various scores, and also released a more extensive [rm_aesthetic dataset](https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/BeautifulPrompt/rm_aesthetic.json). + +We released the following models: +- [alibaba-pai/pai-bloom-1b1-text2prompt-sd](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) +- [alibaba-pai/pai-bloom-1b1-text2prompt-sd-v2](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd-v2) + +## Run +### Installation +```bash +conda create -n trlx python=3.10 +conda activate trlx + +pip install -e . 
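+# `pip install -e .` installs the trlx package bundled with this example; its setup.cfg
+# pulls in transformers, accelerate, deepspeed and the other training dependencies.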
+pip install tensorboardX==2.6.0 +``` + +### Training +```bash +# Step 1 +bash scripts/sft.sh +# Step 2 +bash scripts/rm_aes.sh +bash scripts/rm_ps.sh +# Step 3 +bash scripts/ppo.sh +``` + +### Evaluation +```bash +pip install pillow +pip install image-reward==1.2 +pip install git+https://github.com/openai/CLIP +pip install diffusers + +bash scripts/eval.sh +``` + +## Acknowledgement +This repo benefits from [trlx](https://github.com/CarperAI/trlx). Thanks for their wonderful works. diff --git a/examples/BeautifulPrompt/beautiful_prompt/data.py b/examples/BeautifulPrompt/beautiful_prompt/data.py new file mode 100644 index 0000000..03644da --- /dev/null +++ b/examples/BeautifulPrompt/beautiful_prompt/data.py @@ -0,0 +1,211 @@ +import copy +import random +from dataclasses import dataclass, field +from typing import Callable, Dict, Sequence +import string + +import torch +import torch.distributed as dist +from torch.utils.data import Dataset +import transformers +from tqdm import tqdm +import trlx.utils.logging as logging + +from beautiful_prompt.utils import read_json, is_rank_0 + +logger = logging.get_logger() + + +IGNORE_INDEX = -100 + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + +def preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + max_length: int, +) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)] + input_ids = examples_tokenized["input_ids"] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + +class SFTDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int = 512): + super(SFTDataset, self).__init__() + logger.info("Loading data...") + + data = read_json(data_path) + + new_data = [] + for d in data: + if d['pick_score'] < 18.5: + continue + + token_len = len(tokenizer.tokenize(d['prompt'])) + if token_len < 25: + continue + if token_len < 35 and random.random() < 0.3: + continue + new_data.append(d) + + data = new_data + + sources = [f"Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {d['raw_prompt']}\nOutput: " for d in data] + + targets = [d['prompt'].strip() + tokenizer.eos_token for d in data] + + logger.info(f'Num examples: {len(data)}') + logger.info(f'Example 1: {sources[0]}{targets[0]}') + + logger.info("Tokenizing inputs... 
This may take some time...") + data_dict = preprocess(sources, targets, tokenizer, max_length) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +@dataclass +class DataCollatorForSFTDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + +class RMDatasetForAES(Dataset): + """ + Dataset for reward model for aesthetic score + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + """ + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> None: + super().__init__() + + logger.info("Loading data...") + + data = read_json(data_path) + + self.inputs = [] + + for d in tqdm(data, disable=not is_rank_0()): + inp = d['prompt'] + tokenizer.eos_token + inp = tokenizer(inp, + max_length=max_length, + padding=False, + truncation=True, + return_tensors="pt") + self.inputs.append({ + "input_ids": inp['input_ids'][0], + "labels": torch.tensor(d['aesthetic_score']) + }) + + def __len__(self): + length = len(self.inputs) + return length + + def __getitem__(self, idx): + return dict(input_ids=self.inputs[idx]["input_ids"], labels=self.inputs[idx]["labels"]) + +class RMDatasetForPS(Dataset): + """ + Dataset for reward model for Pick Score + + Args: + dataset: dataset for reward model + tokenizer: tokenizer for reward model + max_length: max length of input + """ + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> None: + super().__init__() + + logger.info("Loading data...") + + data = read_json(data_path) + + self.inputs = [] + + for d in tqdm(data, disable=not is_rank_0()): + inp = f"Input: {d['raw_prompt']}\nOutput: {d['prompt']}{tokenizer.eos_token}" + inp = tokenizer(inp, + max_length=max_length, + padding=False, + truncation=True, + return_tensors="pt") + self.inputs.append({ + "input_ids": inp['input_ids'][0], + "labels": torch.tensor(d['pick_score']), + # "labels": torch.tensor(d['image_reward']) + }) + + def __len__(self): + length = len(self.inputs) + return length + + def __getitem__(self, idx): + return dict(input_ids=self.inputs[idx]["input_ids"], labels=self.inputs[idx]["labels"]) + +@dataclass +class DataCollatorForRMDataset(DataCollatorForSFTDataset): + """Collate examples for reward model.""" + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + diff --git 
a/examples/BeautifulPrompt/beautiful_prompt/evaluator.py b/examples/BeautifulPrompt/beautiful_prompt/evaluator.py new file mode 100644 index 0000000..5ccfbdd --- /dev/null +++ b/examples/BeautifulPrompt/beautiful_prompt/evaluator.py @@ -0,0 +1,286 @@ +import os +from typing import Any, List, Mapping, Union + +import clip +import ImageReward as RM +import torch +import torch.nn as nn +import torch.nn.functional as F +from ImageReward.ImageReward import MLP as RM_MLP +from PIL import Image +from transformers import AutoModel, AutoProcessor + + + +class Evaluator(nn.Module): + def __init__(self) -> None: + super().__init__() + + def check_imgs(self, imgs: Union[List[str], List[Image.Image]]): + new_imgs = [] + for img in imgs: + if isinstance(img, str): + assert os.path.isfile(img) + pil_image = Image.open(img) + new_imgs.append(pil_image) + elif isinstance(img, Image.Image): + new_imgs.append(img) + else: + raise TypeError(r'This imgs parameter type has not been supportted yet. Please pass PIL.Image or file path str.') + + return new_imgs + + def forward(self, prompts, imgs): + ''' + Compute scores based on the given prompts and images. + + Args: + prompts: A list of prompts. + imgs: A list of images. + + Returns: + A list of scores for each pair. + ''' + raise NotImplementedError() + +class ImageReward(Evaluator): + '''Reference: https://github.com/THUDM/ImageReward.''' + + def __init__(self, checkpoint: str='ImageReward-v1.0', device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + super().__init__() + self.model = RM.load(checkpoint, device=device) + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + ''' + Reference: https://github.com/THUDM/ImageReward/blob/main/ImageReward/ImageReward.py#L84. + ''' + + assert isinstance(prompts, list) + assert isinstance(imgs, list) + assert len(prompts) == len(imgs) + imgs = self.check_imgs(imgs) + + # text encode + text_input = self.model.blip.tokenizer(prompts, padding='max_length', truncation=True, max_length=35, return_tensors='pt').to(self.device) + + imgs = torch.stack([self.model.preprocess(img) for img in imgs]).to(self.model.device) + img_embeds = self.model.blip.visual_encoder(imgs) + + # text encode cross attention with image + img_atts = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device) + + text_output = self.model.blip.text_encoder(text_input.input_ids, + attention_mask=text_input.attention_mask, + encoder_hidden_states=img_embeds, + encoder_attention_mask=img_atts, + return_dict=True) + + txt_features = text_output.last_hidden_state[:, 0, :].float() # (feature_dim) + scores = self.model.mlp(txt_features).squeeze(dim=1) + scores = (scores - self.model.mean) / self.model.std + + return scores.cpu().tolist() + +class CLIPScore(Evaluator): + def __init__(self, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + super().__init__() + self.model = RM.load_score('CLIP', device=device) + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + ''' + Reference: https://github.com/THUDM/ImageReward/blob/main/ImageReward/models/CLIPScore.py. 
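+        The returned score is the cosine similarity between the L2-normalized CLIP text and image embeddings.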
+ ''' + + assert isinstance(prompts, list) + assert isinstance(imgs, list) + assert len(prompts) == len(imgs) + + texts = clip.tokenize(prompts, truncate=True).to(self.device) + + txt_features = F.normalize(self.model.clip_model.encode_text(texts)).float() + + imgs = self.check_imgs(imgs) + + imgs = torch.stack([self.model.preprocess(img) for img in imgs]).to(self.device) + img_features = F.normalize(self.model.clip_model.encode_image(imgs)).float() + + rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True) + rewards = torch.squeeze(rewards, dim=1) + return rewards.cpu().tolist() + +class AestheticScore(Evaluator): + '''Reference: https://github.com/christophschuhmann/improved-aesthetic-predictor.''' + + def __init__(self, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + super().__init__() + self.model = RM.load_score('Aesthetic', device=device) + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + ''' + Reference: https://github.com/THUDM/ImageReward/blob/main/ImageReward/models/AestheticScore.py#L45. + ''' + + assert isinstance(imgs, list) + + imgs = self.check_imgs(imgs) + + imgs = torch.stack([self.model.preprocess(img) for img in imgs]).to(self.device) + img_features = F.normalize(self.model.clip_model.encode_image(imgs)).float() + + scores = self.model.mlp(img_features) + scores = torch.squeeze(scores, dim=-1) + + return scores.cpu().tolist() + + +class PickScore(Evaluator): + '''Reference: https://github.com/yuvalkirstain/PickScore.''' + + def __init__(self, + processor_checkpoint: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + model_checkpoint: str = 'yuvalkirstain/PickScore_v1', + device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + + super().__init__() + self.processor = AutoProcessor.from_pretrained(processor_checkpoint) + self.model = AutoModel.from_pretrained(model_checkpoint).eval().to(device) + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + '''Reference: https://github.com/yuvalkirstain/PickScore#inference-with-pickscore.''' + + assert isinstance(imgs, list) + + imgs = self.check_imgs(imgs) + + imgs_inputs = self.processor( + images=imgs, + padding=True, + truncation=True, + max_length=77, + return_tensors='pt', + ).to(self.device) + + text_inputs = self.processor( + text=prompts, + padding=True, + truncation=True, + max_length=77, + return_tensors='pt', + ).to(self.device) + + img_embs = self.model.get_image_features(**imgs_inputs) + img_embs = img_embs / torch.norm(img_embs, dim=-1, keepdim=True) + + text_embs = self.model.get_text_features(**text_inputs) + text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True) + + # score + scores = self.model.logit_scale.exp() * (text_embs @ img_embs.T) + scores = scores.diagonal() + + return scores.cpu().tolist() + +class HPS(Evaluator): + '''Reference: https://github.com/tgxs002/align_sd.''' + + def __init__(self, + model_checkpoint: str, + device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + + super().__init__() + self.model, self.processor = clip.load('ViT-L/14', device=device) + params = torch.load(model_checkpoint)['state_dict'] + self.model.load_state_dict(params) + self.model.eval() + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + + assert isinstance(imgs, list) + + imgs = self.check_imgs(imgs) + + img_inputs = torch.stack([self.processor(img) for img in 
imgs]).to(self.device) + + text_inputs = clip.tokenize(prompts).to(self.device) + + image_features = self.model.encode_image(img_inputs) + text_features = self.model.encode_text(text_inputs) + + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + hps = image_features @ text_features.T + hps = hps.diagonal() + + return hps.cpu().tolist() + +class HPSv2(Evaluator): + '''Reference: https://github.com/tgxs002/HPSv2.''' + + + def __init__(self, + model_checkpoint: str, + local_dir: str = None, + cache_dir: str = None, + device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu') -> None: + + super().__init__() + + from hpsv2.open_clip import create_model_and_transforms, get_tokenizer + + self.model, _, self.processor = create_model_and_transforms( + 'ViT-H-14', + 'laion2B-s32B-b79K', + precision='amp', + device=device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False, + cache_dir=cache_dir, + local_dir=local_dir + ) + checkpoint = torch.load(model_checkpoint) + self.model.load_state_dict(checkpoint['state_dict']) + self.model.eval() + self.tokenizer = get_tokenizer('ViT-H-14') + self.device = device + + @torch.no_grad() + def forward(self, prompts, imgs): + + assert isinstance(imgs, list) + + imgs = self.check_imgs(imgs) + + img_inputs = torch.stack([self.processor(img) for img in imgs]).to(self.device) + + text_inputs = self.tokenizer(prompts).to(self.device) + + with torch.cuda.amp.autocast(): + outputs = self.model(img_inputs, text_inputs) + image_features, text_features = outputs['image_features'], outputs['text_features'] + + score = image_features @ text_features.T + score = score.diagonal() + + return score.cpu().tolist() diff --git a/examples/BeautifulPrompt/beautiful_prompt/trainer.py b/examples/BeautifulPrompt/beautiful_prompt/trainer.py new file mode 100644 index 0000000..42c821d --- /dev/null +++ b/examples/BeautifulPrompt/beautiful_prompt/trainer.py @@ -0,0 +1,277 @@ +import math +import time +from abc import ABC + +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer import get_scheduler +import trlx.utils.logging as logging +from accelerate import Accelerator + +from beautiful_prompt.utils import get_optimizer_grouped_parameters + +logger = logging.get_logger() + +class SFTTrainer(ABC): + """ + SFTTrainer + + Args: + model (torch.nn.Module):The model to be trained. + tokenizer (PreTrainedTokenizerBase): The tokenizer used to preprocess the input data. + train_dataloader (DataLoader): The dataloader containing the training data. + save_path (str): The path to save the trained model. + logging_dir (str): The directory to save the training logs. Defaults to 'logs'. + lr (float): The learning rate for the optimizer. Defaults to 1e-5. + batch_size (int): The batch size for training. Defaults to 8. + weight_decay (float): The weight decay for the optimizer. Defaults to 1e-3. + epochs (int): The number of epochs for training. Defaults to 3. 
+ """ + def __init__( + self, + model: torch.nn.Module, + tokenizer: PreTrainedTokenizerBase, + train_dataloader: DataLoader, + save_path: str, + logging_dir: str = 'logs', + + lr: float = 1e-5, + batch_size: int = 8, + weight_decay: float = 1e-3, + epochs: int = 3 + ) -> None: + super().__init__() + + self.accelerator = Accelerator(log_with="tensorboard", project_dir=logging_dir) + ds_plugin = self.accelerator.state.deepspeed_plugin + + self.model = model + self.tokenizer = tokenizer + self.save_path = save_path + + optimizer_grouped_parameters = get_optimizer_grouped_parameters(model, weight_decay) + + self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95)) + + self.train_dataloader = train_dataloader + + self.epochs = epochs + + self.accumulation_steps = ds_plugin.gradient_accumulation_steps + + num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps + max_steps = math.ceil(self.epochs * num_update_steps_per_epoch) + + self.scheduler = get_scheduler("cosine", + self.optimizer, + num_warmup_steps=math.ceil(max_steps * 0.03), + num_training_steps=max_steps) + + self.model, self.optimizer, self.train_dataloader, self.scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.train_dataloader, self.scheduler + ) + + self._init_logger(lr, batch_size, epochs, ds_plugin) + + def _init_logger(self, lr, batch_size, epochs, ds_plugin): + if self.accelerator.is_main_process: + config = { + "lr": lr, + "batch_size": batch_size, + "epochs": epochs, + "mixed_precision": self.accelerator.mixed_precision, + "num_gpus": self.accelerator.num_processes, + "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, + "gradient_clipping": ds_plugin.gradient_clipping, + "zero_stage": ds_plugin.zero_stage, + } + + self.accelerator.init_trackers( + project_name=f'beautilful-prompt sft [{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}]', + config=config + ) + + def train(self): + step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.epochs), + desc=f'steps', + disable=not self.accelerator.is_main_process) + + current_step = 0 + for epoch in range(self.epochs): + + self.model.train() + for batch_id, batch in enumerate(self.train_dataloader): + if isinstance(self.train_dataloader.sampler, DistributedSampler): + self.train_dataloader.sampler.set_epoch(epoch) + + input_ids = batch["input_ids"].to(torch.cuda.current_device()) + attention_mask = batch["attention_mask"].to(torch.cuda.current_device()) + labels = batch["labels"].to(torch.cuda.current_device()) + + with self.accelerator.accumulate(self.model): + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + loss = outputs.loss + self.accelerator.backward(loss) + self.optimizer.step() + self.scheduler.step() + self.optimizer.zero_grad() + + if self.accelerator.sync_gradients: + self.accelerator.log({ + "loss": loss.item(), + "lr": self.scheduler.get_last_lr()[0], + "epoch": epoch, + "batch_id": batch_id + }, step=current_step) + + current_step += 1 + step_bar.update() + + step_bar.close() + self.save_model(self.save_path) + + def save_model(self, save_path: str) -> None: + self.accelerator.wait_for_everyone() + self.accelerator.unwrap_model(self.model).save_pretrained( + save_path, + save_function=self.accelerator.save, + is_main_process=self.accelerator.is_main_process, + state_dict=self.accelerator.get_state_dict(self.model) + ) + + if self.accelerator.is_main_process: + self.tokenizer.save_pretrained(save_path) 
+
+class RMTrainer(SFTTrainer):
+    """
+    Trainer to use while training reward model.
+
+    Args:
+        model (torch.nn.Module): The model to be trained.
+        tokenizer (PreTrainedTokenizerBase): The tokenizer used to preprocess the input data.
+        train_dataloader (DataLoader): The dataloader containing the training data.
+        eval_dataloader (DataLoader): The dataloader containing the evaluation data.
+        save_path (str): The path to save the trained model.
+        logging_dir (str): The directory to save the training logs. Defaults to 'logs'.
+        eval_steps (int): The interval to evaluate the model during training. Defaults to 1000.
+        lr (float): The learning rate for the optimizer. Defaults to 1e-5.
+        batch_size (int): The batch size for training. Defaults to 8.
+        weight_decay (float): The weight decay for the optimizer. Defaults to 1e-3.
+        epochs (int): The number of epochs for training. Defaults to 3.
+    """
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: PreTrainedTokenizerBase,
+        train_dataloader: DataLoader,
+        save_path: str,
+        eval_dataloader: DataLoader = None,
+        logging_dir: str = 'logs',
+        eval_steps: int = 1000,
+
+        lr: float = 1e-5,
+        batch_size: int = 8,
+        weight_decay: float = 1e-3,
+        epochs: int = 3
+    ) -> None:
+        super().__init__(
+            model,
+            tokenizer,
+            train_dataloader,
+            save_path,
+            logging_dir,
+            lr,
+            batch_size,
+            weight_decay,
+            epochs
+        )
+
+        if eval_dataloader is not None:
+            self.eval_dataloader = self.accelerator.prepare(eval_dataloader)
+            self.eval_steps = eval_steps
+        else:
+            self.eval_steps = -1
+
+    def _init_logger(self, lr, batch_size, epochs, ds_plugin):
+        if self.accelerator.is_main_process:
+            config = {
+                "lr": lr,
+                "batch_size": batch_size,
+                "epochs": epochs,
+                "mixed_precision": self.accelerator.mixed_precision,
+                "num_gpus": self.accelerator.num_processes,
+                "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps,
+                "gradient_clipping": ds_plugin.gradient_clipping,
+                "zero_stage": ds_plugin.zero_stage,
+            }
+
+            self.accelerator.init_trackers(
+                project_name=f'beautiful-prompt rm [{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}]',
+                config=config
+            )
+
+    def train(self):
+        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.epochs),
+                        desc='steps',
+                        disable=not self.accelerator.is_main_process)
+
+        current_step = 0
+        for epoch in range(self.epochs):
+
+            self.model.train()
+            for batch_id, batch in enumerate(self.train_dataloader):
+                if isinstance(self.train_dataloader.sampler, DistributedSampler):
+                    self.train_dataloader.sampler.set_epoch(epoch)
+
+                input_ids = batch["input_ids"].to(torch.cuda.current_device())
+                attention_mask = batch["attention_mask"].to(torch.cuda.current_device())
+                labels = torch.tensor(batch["labels"], dtype=self.model.dtype).to(torch.cuda.current_device())
+
+                with self.accelerator.accumulate(self.model):
+                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+                    loss = outputs.loss
+                    self.accelerator.backward(loss)
+                    self.optimizer.step()
+                    self.scheduler.step()
+                    self.optimizer.zero_grad()
+
+                if self.accelerator.sync_gradients:
+                    self.accelerator.log({
+                        "loss": loss.item(),
+                        "lr": self.scheduler.get_last_lr()[0],
+                        "epoch": epoch,
+                        "batch_id": batch_id
+                    }, step=current_step)
+
+                    current_step += 1
+                    step_bar.update()
+
+                    # run evaluation every `eval_steps` optimizer steps
+                    if self.eval_steps > 0 and current_step % self.eval_steps == 0:
+                        result = self.eval()
+                        self.accelerator.log(result, step=current_step)
+
+        step_bar.close()
+        self.save_model(self.save_path)
+
+    def eval(self):
+        self.model.eval()
+ total_loss = 0 + with torch.no_grad(): + for batch in self.eval_dataloader: + input_ids = batch["input_ids"].to(torch.cuda.current_device()) + attention_mask = batch["attention_mask"].to(torch.cuda.current_device()) + labels = torch.tensor(batch["labels"], dtype=self.model.dtype).to(torch.cuda.current_device()) + + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + loss = outputs.loss + total_loss += loss + total_loss = total_loss / len(self.eval_dataloader) + + self.model.train() + return { + "mse": total_loss, + } diff --git a/examples/BeautifulPrompt/beautiful_prompt/utils.py b/examples/BeautifulPrompt/beautiful_prompt/utils.py new file mode 100644 index 0000000..2ad23bc --- /dev/null +++ b/examples/BeautifulPrompt/beautiful_prompt/utils.py @@ -0,0 +1,90 @@ +import json +import random +import os + +from PIL import Image +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F + +def read_json(file_path): + with open(file_path, 'r', encoding='utf-8') as file: + return json.load(file) + +def save_json(data, file_path): + with open(file_path, 'w', encoding='utf-8') as file: + json.dump(data, file, ensure_ascii=False, indent=2) + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + +def set_seed(seed: int): + """ + Sets seeds across package dependencies for reproducibility. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=[ + "bias", "LayerNorm.weight" + ] +): + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + +def init_sd_pipeline(device = "cuda" if torch.cuda.is_available() else "cpu", **kwargs): + ''' + Initializes the Stable Diffusion pipeline + + Args: + device: The device to put the loaded model. + kwargs: Keyword arguments to be passed to the underlying DiffusionPipeline object. + This can include any of the arguments accepted by the from_pretrained() method of the DiffusionPipeline class. + + Returns: + A diffusion pipeline. 
+ + Example: + + ```python + >>> from beautiful_prompt.utils import init_sd_pipeline + >>> sd_pipeline = init_sd_pipeline() + >>> sd_pipeline.text2img(prompt, width=512, height=512, negative_prompt=neg_prompt, max_embeddings_multiples=3).images[0] + ''' + from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler + + sd_pipeline = DiffusionPipeline.from_pretrained( + pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5", + custom_pipeline="waifu-research-department/long-prompt-weighting-pipeline", + safety_checker=None, # Comment it for working safely + revision="fp16", + torch_dtype=torch.float16, + **kwargs + ) + + sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipeline.scheduler.config) + sd_pipeline.set_progress_bar_config(disable=True) + sd_pipeline.to(device) + + return sd_pipeline diff --git a/examples/BeautifulPrompt/config/ppo.yaml b/examples/BeautifulPrompt/config/ppo.yaml new file mode 100644 index 0000000..1e66022 --- /dev/null +++ b/examples/BeautifulPrompt/config/ppo.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 2 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: no +dynamo_config: {} +fsdp_config: {} +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +use_cpu: false diff --git a/examples/BeautifulPrompt/config/sft.yaml b/examples/BeautifulPrompt/config/sft.yaml new file mode 100644 index 0000000..a8c9a14 --- /dev/null +++ b/examples/BeautifulPrompt/config/sft.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: no +dynamo_config: {} +fsdp_config: {} +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +use_cpu: false diff --git a/examples/BeautifulPrompt/eval.py b/examples/BeautifulPrompt/eval.py new file mode 100644 index 0000000..131f563 --- /dev/null +++ b/examples/BeautifulPrompt/eval.py @@ -0,0 +1,279 @@ +import argparse +import json +import os +import time +import math +import logging +import random + +from prettytable import PrettyTable +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModel + +from beautiful_prompt.utils import ( + read_json, + save_json, + init_sd_pipeline, + set_seed +) +from beautiful_prompt.evaluator import ( + ImageReward, + CLIPScore, + AestheticScore, + PickScore, + HPS +) + +logging.getLogger('transformers').setLevel(logging.ERROR) + +def generate_prompts(args, data): + if args.method == 'raw': + for d in data: + d['generated_prompt'] = d['raw_prompt'] + + elif args.method == 'magic-prompt': + model = AutoModelForCausalLM.from_pretrained('Gustavosta/MagicPrompt-Stable-Diffusion') + tokenizer = AutoTokenizer.from_pretrained('Gustavosta/MagicPrompt-Stable-Diffusion') + + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + + model.resize_token_embeddings(len(tokenizer)) + model.to(args.device) + + 
raw_prompts = [x['raw_prompt'] for x in data] + generated_prompts = [] + + for i in tqdm(range(math.ceil(len(raw_prompts) / args.batch_size)), desc='Generating prompts...', disable=args.disable_tqdm): + batch_ixs = slice(i * args.batch_size, (i + 1) * args.batch_size) + + inputs = raw_prompts[batch_ixs] + + model_inputs = tokenizer( + inputs, + padding=True, + truncation=True, + max_length=args.max_length, + return_tensors='pt', + ).to(args.device) + + generations = model.generate( + **model_inputs, + max_length=args.max_length, + do_sample=False, + repetition_penalty=1.2 + ) + + generated_prompt = tokenizer.batch_decode(generations, skip_special_tokens=True) + generated_prompts.extend([x.strip() for x in generated_prompt]) + + for d, generated_prompt in zip(data, generated_prompts): + d['generated_prompt'] = generated_prompt + + del model + torch.cuda.empty_cache() + + elif args.method == 'chatgpt': + for d in data: + d['generated_prompt'] = d['chatgpt_prompt'] + + elif args.method == 'beautiful-prompt': + model = AutoModelForCausalLM.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + model.to(args.device) + + raw_prompts = [x['raw_prompt'] for x in data] + generated_prompts = [] + + for i in tqdm(range(math.ceil(len(raw_prompts) / args.batch_size)), desc='Generating prompts...', disable=args.disable_tqdm): + batch_ixs = slice(i * args.batch_size, (i + 1) * args.batch_size) + + inputs = [f'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompts}\nOutput:' for raw_prompts in raw_prompts[batch_ixs]] + + model_inputs = tokenizer( + inputs, + padding=True, + truncation=True, + max_length=args.max_length, + return_tensors='pt', + ).to(args.device) + + generations = model.generate( + **model_inputs, + max_length=args.max_length, + do_sample=False, + repetition_penalty=1.2 + ) + + generated_prompt = tokenizer.batch_decode(generations[:, model_inputs.input_ids.size(1):], skip_special_tokens=True) + generated_prompts.extend([x.strip() for x in generated_prompt]) + + for d, generated_prompt in zip(data, generated_prompts): + d['generated_prompt'] = generated_prompt + + del model + torch.cuda.empty_cache() + + else: + raise NotImplementedError() + + return data + +def generate_images(args, data): + generator = torch.Generator(device='cpu').manual_seed(args.seed) + prompts = [x['generated_prompt'] for x in data] + + os.makedirs(args.imgs_path, exist_ok=True) + + sd_pipeline = init_sd_pipeline() + sd_pipeline.set_progress_bar_config(disable=True) + + for i in tqdm(range(math.ceil(len(prompts) / args.batch_size)), desc='Generating images...', disable=args.disable_tqdm): + batch_ixs = slice(i * args.batch_size, (i + 1) * args.batch_size) + + images = sd_pipeline.text2img( + prompts[batch_ixs], + negative_prompt='nsfw, ((ugly)), (duplicate), morbid, mutilated, [out of frame], (extra fingers), mutated hands', + width=512, + height=512, + max_embeddings_multiples=6, + num_inference_steps=args.num_inference_steps, + generator=generator).images + + for j, img in enumerate(images): + idx = i * args.batch_size + j + img_path = os.path.join(args.imgs_path, f'{idx}.png') + img.save(img_path) + data[idx]['img_path'] = img_path + + del sd_pipeline + torch.cuda.empty_cache() + + return data + +def compute_score(args, data, score_type): + if score_type == 'ImageReward': + evaluator = ImageReward(args.imagereward_path, device=args.device) + elif score_type == 'CLIPScore': + evaluator = CLIPScore(device=args.device) + elif 
score_type == 'Aesthetic': + evaluator = AestheticScore(device=args.device) + elif score_type == 'PickScore': + evaluator = PickScore(processor_checkpoint='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + model_checkpoint=args.pickscore_path, + device=args.device) + elif score_type == 'HPS': + evaluator = HPS(model_checkpoint=args.hps_path, device=args.device) + else: + raise NotImplementedError(f'{score_type} is not implemented.') + + for i in tqdm(range(math.ceil(len(data) / args.batch_size)), desc=f'Computing {score_type}...', disable=args.disable_tqdm): + batch_ixs = slice(i * args.batch_size, (i + 1) * args.batch_size) + + # use raw_prompts for computing ImageReward, PickScore, etc... + prompts = [x['raw_prompt'] for x in data[batch_ixs]] + imgs = [x['img_path'] for x in data[batch_ixs]] + + scores = evaluator(prompts, imgs) + + for j, score in enumerate(scores): + idx = i * args.batch_size + j + data[idx][score_type] = score + + del evaluator + torch.cuda.empty_cache() + + return data + +def main(args): + data = read_json(args.data_path) + + if args.num_items is not None and len(data) > args.num_items: + data = random.sample(data, args.num_items) + + # 1. Generate prompts + data = generate_prompts(args, data) + save_json(data, args.result_path) + + # 2. Generate images and save them + data = generate_images(args, data) + save_json(data, args.result_path) + + # 3. Compute Scores + args.batch_size *= 2 # accelerate + if args.score_types == 'all': + scores = ['PickScore', 'Aesthetic', 'HPS', 'CLIPScore'] + else: + scores = args.score_types + + for score in scores: + data = compute_score(args, data, score_type=score) + save_json(data, args.result_path) + + # 4. Results + result = {'All': {}} + sums = {} + for score in scores: + result['All'][score] = round(sum([x[score] for x in data]) / len(data), 4) + + if 'type' in data[0]: + for x in data: + x_type = x['type'] + if x_type not in sums: + sums[x_type] = {'count': 0} + for score in scores: + sums[x_type][score] = 0 + + sums[x_type]['count'] += 1 + for score in scores: + sums[x_type][score] += x[score] + + for x_type, score_sum in sums.items(): + for score in scores: + if x_type not in result: + result[x_type] = {} + result[x_type][score] = round(score_sum[score] / score_sum['count'], 4) + + table = PrettyTable() + table.title = f'Method: {args.method}' + table.field_names = ['Type'] + scores + + for score, score_average in result.items(): + table.add_row([score] + [score_average[score] for score in scores]) + + print(table) + + result['items'] = data + save_json(result, args.result_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--method', type=str, default='beautiful-prompt', + choices=['raw', 'magic-prompt', 'chatgpt', 'beautiful-prompt'], + help='The method to generate prompts.') + parser.add_argument('--score_types', type=str, nargs='+', default='all') + + parser.add_argument('--data_path', type=str, default='data/test.json', help='Path to the data file.') + parser.add_argument('--result_path', type=str, default='data/eval-results.json', help='Path to the result file.') + parser.add_argument('--model_path', type=str, default='outputs/ppo', help='Path to the beautiful-prompt model dir.') + parser.add_argument('--imgs_path', type=str, default='data/imgs', help='Path to the images dir.') + parser.add_argument('--max_length', type=int, default=384) + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--num_inference_steps', type=int, default=20) + + 
parser.add_argument('--hps_path', type=str) + parser.add_argument('--pickscore_path', type=str, default='yuvalkirstain/PickScore_v1') + parser.add_argument('--imagereward_path', type=str, default='ImageReward-v1.0') + + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--num_items', type=int, default=None) + parser.add_argument('--disable_tqdm', action='store_true') + + + args = parser.parse_args() + + args.device = 'cuda' if torch.cuda.is_available() else 'cpu' + set_seed(args.seed) + main(args) diff --git a/examples/BeautifulPrompt/get_win_rate.py b/examples/BeautifulPrompt/get_win_rate.py new file mode 100644 index 0000000..198c7d1 --- /dev/null +++ b/examples/BeautifulPrompt/get_win_rate.py @@ -0,0 +1,44 @@ +from beautiful_prompt.utils import read_json + +data = read_json('to_pick_data.json') + +data = [d for d in data if 'pick' in d] + +methods = ['raw', 'magic-prompt', 'chatgpt', 'beautiful-prompt-sft'] + +win = [0, 0, 0, 0] +tie = [0, 0, 0, 0] +loss = [0, 0, 0, 0] +total = [0, 0, 0, 0] +for d in data: + if d['method1'] in methods: + index = methods.index(d['method1']) + if d['pick'] == 'img1': + loss[index] += 1 + elif d['pick'] == 'img2': + win[index] += 1 + elif d['pick'] in ['pie', 'tie']: + tie[index] += 1 + else: + print(d['pick']) + assert 0 + else: + index = methods.index(d['method2']) + if d['pick'] == 'img2': + loss[index] += 1 + elif d['pick'] == 'img1': + win[index] += 1 + elif d['pick'] in ['pie', 'tie']: + tie[index] += 1 + else: + print(d['pick']) + assert 0 + + total[index] += 1 + +print(win, tie, loss, total) +for i in range(len(methods)): + print(methods[i]) + print(win[i] / total[i]) + print(tie[i] / total[i]) + print(loss[i] / total[i]) diff --git a/examples/BeautifulPrompt/pick.py b/examples/BeautifulPrompt/pick.py new file mode 100644 index 0000000..3c0d039 --- /dev/null +++ b/examples/BeautifulPrompt/pick.py @@ -0,0 +1,60 @@ +import json +import random + +from PIL import Image +import streamlit as st + +data_file = "to_pick_data.json" + +if 'data' not in st.session_state or 'unpicked_indexs' not in st.session_state: + st.session_state['data'] = list(json.load(open(data_file, "r"))) + + unpicked_indexs = [] + for i, d in enumerate(st.session_state['data']): + if 'pick' not in d or d['pick'] not in ['img1', 'img2', 'tie']: + unpicked_indexs.append(i) + + st.session_state['unpicked_indexs'] = unpicked_indexs + +def click(pick, **kwargs): + st.session_state['data'][current_index]["pick"] = pick + st.session_state['unpicked_indexs'].remove(current_index) + + # save_data + json.dump(st.session_state['data'], open(data_file, "w"), indent=2, ensure_ascii=False) + +def show(item): + # st.write("### Pick a better picture") + st.write('#### raw_prompt: ' + item['raw_prompt']) + + img1 = Image.open(item['img1']) + img2 = Image.open(item['img2']) + + col1, col2 = st.columns(2) + + with col1: + st.image(img1) + + with col2: + st.image(img2) + + st.write("##### Based on the raw prompt, which picture is better?") + + _, col1, col2, col3 = st.columns(4) + + col1.button("left", key=f'{current_index}_left', on_click=click, kwargs={'pick':'img1'}) + col2.button("tie", key=f'{current_index}_tie', on_click=click, kwargs={'pick':'tie'}) + col3.button("right", key=f'{current_index}_right', on_click=click, kwargs={'pick':'img2'}) + + # st.button("skip", key=f'{current_index}_skip') + + st.write(f"Remaining {len(st.session_state['unpicked_indexs'])} pairs of images to be picked.") + +if len(st.session_state['unpicked_indexs']) > 0: + # current_index = 
st.session_state['unpicked_indexs'][0]
+    current_index = random.choice(st.session_state['unpicked_indexs'])
+
+    item = st.session_state['data'][current_index]
+    show(item)
+else:
+    st.write("## All picture pairs are picked. Thank you!")
diff --git a/examples/BeautifulPrompt/scripts/eval.sh b/examples/BeautifulPrompt/scripts/eval.sh
new file mode 100644
index 0000000..ee3ea5b
--- /dev/null
+++ b/examples/BeautifulPrompt/scripts/eval.sh
@@ -0,0 +1,41 @@
+python eval.py \
+    --method raw \
+    --num_inference_steps 20 \
+    --data_path data/test.json \
+    --result_path outputs/eval-results-raw.json \
+    --imgs_path data/imgs/raw/ \
+    --disable_tqdm
+
+python eval.py \
+    --method magic-prompt \
+    --num_inference_steps 20 \
+    --data_path data/test.json \
+    --result_path outputs/eval-results-magic-prompt.json \
+    --imgs_path data/imgs/magic-prompt/ \
+    --disable_tqdm
+
+python eval.py \
+    --method chatgpt \
+    --num_inference_steps 20 \
+    --data_path data/test.json \
+    --result_path outputs/eval-results-chatgpt.json \
+    --imgs_path data/imgs/chatgpt/ \
+    --disable_tqdm
+
+python eval.py \
+    --method beautiful-prompt \
+    --num_inference_steps 20 \
+    --data_path data/test.json \
+    --result_path outputs/eval-results-beautiful-prompt-sft.json \
+    --imgs_path data/imgs/beautiful-prompt-sft/ \
+    --model_path outputs/sft \
+    --disable_tqdm
+
+python eval.py \
+    --method beautiful-prompt \
+    --num_inference_steps 20 \
+    --data_path data/test.json \
+    --result_path outputs/eval-results-beautiful-prompt.json \
+    --imgs_path data/imgs/beautiful-prompt/ \
+    --model_path outputs/ppo/checkpoint_5000 \
+    --disable_tqdm
diff --git a/examples/BeautifulPrompt/scripts/ppo.sh b/examples/BeautifulPrompt/scripts/ppo.sh
new file mode 100644
index 0000000..b06c3ee
--- /dev/null
+++ b/examples/BeautifulPrompt/scripts/ppo.sh
@@ -0,0 +1,10 @@
+accelerate launch --config_file config/ppo.yaml train_ppo.py \
+    --data_path data/data.json \
+    --model_path outputs/sft \
+    --aes_model_path outputs/rm_aes \
+    --ps_model_path outputs/rm_ps \
+    --save_path outputs/ppo \
+    --num_layers_unfrozen 8 \
+    --total_steps 5000 \
+    --alpha 0.7 \
+    --batch_size 4
diff --git a/examples/BeautifulPrompt/scripts/rm_aes.sh b/examples/BeautifulPrompt/scripts/rm_aes.sh
new file mode 100644
index 0000000..02a28a0
--- /dev/null
+++ b/examples/BeautifulPrompt/scripts/rm_aes.sh
@@ -0,0 +1,9 @@
+accelerate launch --config_file config/sft.yaml train_rm.py \
+    --data_path data/rm_aes_data.json \
+    --save_path outputs/rm_aes \
+    --batch_size 16 \
+    --max_length 384 \
+    --model_path bigscience/bloom-1b1 \
+    --epochs 1 \
+    --lr 1e-5 \
+    --rm_type aes
diff --git a/examples/BeautifulPrompt/scripts/rm_ps.sh b/examples/BeautifulPrompt/scripts/rm_ps.sh
new file mode 100644
index 0000000..fcb81bf
--- /dev/null
+++ b/examples/BeautifulPrompt/scripts/rm_ps.sh
@@ -0,0 +1,9 @@
+accelerate launch --config_file config/sft.yaml train_rm.py \
+    --data_path data/data.json \
+    --save_path outputs/rm_ps \
+    --batch_size 16 \
+    --max_length 384 \
+    --model_path bigscience/bloom-1b1 \
+    --epochs 1 \
+    --lr 1e-5 \
+    --rm_type ps
diff --git a/examples/BeautifulPrompt/scripts/sft.sh b/examples/BeautifulPrompt/scripts/sft.sh
new file mode 100644
index 0000000..bca919a
--- /dev/null
+++ b/examples/BeautifulPrompt/scripts/sft.sh
@@ -0,0 +1,9 @@
+accelerate launch --config_file config/sft.yaml train_sft.py \
+    --data_path data/data.json \
+    --model_path bigscience/bloom-1b1 \
+    --save_path outputs/sft \
+    --batch_size 16 \
+    --max_length 384 \
+    --epochs 4 \
+    --lr 2e-5 \
+    --weight_decay 0
diff
--git a/examples/BeautifulPrompt/setup.cfg b/examples/BeautifulPrompt/setup.cfg new file mode 100644 index 0000000..591f98f --- /dev/null +++ b/examples/BeautifulPrompt/setup.cfg @@ -0,0 +1,70 @@ +[metadata] +name = trlx +author = Alex Havrilla +version = 0.6.0 +url = https://github.com/CarperAI/trlx +description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) +long_description = file: README.md +long_description_content_type = text/markdown +license = MIT + +[options] +packages = find: +install_requires = + accelerate>=0.17.1 + attrs>=22.1.0 + cattrs>=22.2.0 + datasets + deepspeed>=0.8.1 + einops>=0.4.1 + numpy>=1.23.2 + torchtyping + transformers>=4.27.1 + tqdm + rich + wandb>=0.13.5 + ray>=2.4.0 + tabulate>=0.9.0 + networkx + tritonclient + +[options.extras_require] +bnb = bitsandbytes +dev = + black + hypothesis + isort + flake8 + pre-commit + pytest + pytest-cov + +[options.packages.find] +exclude = + docs* + tests* + +[flake8] +max-complexity = 10 +max-line-length = 127 +# flake8 error codes: https://flake8.pycqa.org/en/latest/user/error-codes.html +# pycodestyle codes: https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes +# E203 # whitespace before ‘,’, ‘;’, or ‘:’ +# E741 # do not use variables named ‘l’, ‘O’, or ‘I’ +# F401 # module imported but unused +# F821 # undefined name name +# W503 # line break before binary operator +# W605 # invalid escape sequence ‘x’ +ignore = + E203 + E741 + F821 + W503 + W605 +per-file-ignores = __init__.py:F401,loading.py:F401 +exclude = + .git + __pycache__ + docs/source/conf.py + build + dist diff --git a/examples/BeautifulPrompt/setup.py b/examples/BeautifulPrompt/setup.py new file mode 100644 index 0000000..6068493 --- /dev/null +++ b/examples/BeautifulPrompt/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup + +setup() diff --git a/examples/BeautifulPrompt/to_pick_data.json b/examples/BeautifulPrompt/to_pick_data.json new file mode 100644 index 0000000..85f6689 --- /dev/null +++ b/examples/BeautifulPrompt/to_pick_data.json @@ -0,0 +1,6 @@ +[{ + "img1": "", + "img2": "", + "raw_prompt": "", + "pick": "" +}] diff --git a/examples/BeautifulPrompt/train_ppo.py b/examples/BeautifulPrompt/train_ppo.py new file mode 100644 index 0000000..ee5095c --- /dev/null +++ b/examples/BeautifulPrompt/train_ppo.py @@ -0,0 +1,213 @@ +import json +import math +import os +import io +import random +import argparse +import time + +import torch +from torch import nn +from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification + +import trlx +from trlx.data.default_configs import ( + ModelConfig, + OptimizerConfig, + PPOConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + +from beautiful_prompt.utils import read_json + +def create_reward_fn(args): # noqa: C901 + if os.environ.get("RANK", "0") == "0": + class RewardModel(nn.Module): + def __init__(self, checkpoint_path): + super().__init__() + self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, num_labels=1) + + def forward(self, input_ids: torch.LongTensor, attention_mask = None) -> torch.Tensor: + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) + value = outputs['logits'].squeeze(-1) + return value + + aes_model = RewardModel(args.aes_model_path).eval().half() + ps_model = RewardModel(args.ps_model_path).eval().half() + + aes_tokenizer = AutoTokenizer.from_pretrained(args.aes_model_path) + aes_tokenizer.truncation_side = "left" + 
aes_tokenizer.padding_side = "right" + + ps_tokenizer = AutoTokenizer.from_pretrained(args.ps_model_path) + ps_tokenizer.truncation_side = "left" + ps_tokenizer.padding_side = "right" + + reward_device = torch.cuda.device_count() - 1 + + aes_model = aes_model.to(reward_device) + ps_model = ps_model.to(reward_device) + + reward_batch_size = args.reward_batch_size + delta_reward = True + + @torch.no_grad() + def get_reward(raw_prompts, generated_prompts): + aes_input = aes_tokenizer( + [p + aes_tokenizer.eos_token for p in generated_prompts], + padding=True, + truncation=True, + max_length=384, + return_tensors="pt", + ) + + ps_input = ps_tokenizer( + [f"Input: {rp}\nOutput: {p}{ps_tokenizer.eos_token}" for rp, p in zip(raw_prompts, generated_prompts)], + padding=True, + truncation=True, + max_length=400, + return_tensors="pt", + ) + + mbs = reward_batch_size + aess = [] + irs = [] + for i in range(math.ceil(len(generated_prompts) / mbs)): + batch_ixs = slice(i * mbs, (i + 1) * mbs) + input_ids = aes_input.input_ids[batch_ixs].to(reward_device) + attention_mask = aes_input.attention_mask[batch_ixs].to(reward_device) + scores = aes_model(input_ids, attention_mask) + aess.extend(scores) + + batch_ixs = slice(i * mbs, (i + 1) * mbs) + input_ids = ps_input.input_ids[batch_ixs].to(reward_device) + attention_mask = ps_input.attention_mask[batch_ixs].to(reward_device) + scores = ps_model(input_ids, attention_mask) + irs.extend(scores) + + prompts_len = [max(len(p), 200) for p in generated_prompts] + + return (1-args.alpha) * torch.hstack(aess) + args.alpha * torch.hstack(irs) + 0.01 * torch.tensor(len(prompts_len)) + + def reward_fn(samples, prompts, original_output, **kwargs): + generated_prompts = [s.replace(p, '').strip().strip('') for p, s in zip(prompts, samples)] + + raw_prompts = [p.split('Input:')[1].split('Output:')[0].strip().strip('') for p in prompts] + rewards = get_reward(raw_prompts, generated_prompts) + + if not delta_reward: + return rewards + + original_rewards = get_reward(raw_prompts, original_output) + + return rewards - original_rewards + else: + reward_fn = True + + return reward_fn + +def main(args): + config = TRLConfig( + train=TrainConfig( + seq_length=args.max_length, + epochs=10000, + total_steps=args.total_steps, + batch_size=args.batch_size, + checkpoint_interval=args.checkpoint_interval, + eval_interval=args.eval_interval, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + checkpoint_dir=args.save_path, + save_optimizer=False, + tracker="tensorboard", + logging_dir=args.logging_dir, + project_name=f'beautilful-prompt ppo [{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}]', + save_best=False + ), + model=ModelConfig(model_path=args.model_path, num_layers_unfrozen=args.num_layers_unfrozen), + tokenizer=TokenizerConfig(tokenizer_path=args.model_path, truncation_side="left"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=args.lr, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=args.weight_decay)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=args.lr)), + method=PPOConfig( + name="PPOConfig", + num_rollouts=64, + chunk_size=16, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=0.5, + scale_reward="running", + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs=dict( + max_new_tokens=256, + top_k=0, + top_p=1.0, + do_sample=True, + ), + ), + ) + + dataset = read_json(args.data_path) + random.seed(42) + + 
random.shuffle(dataset) + dataset = dataset[:40000] + + prompts = [ + { + "prompt": f'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {x["raw_prompt"]}\nOutput:', + "original_output": x["prompt"] + } + for x in dataset[500:] + ] + eval_prompts = [ + { + "prompt": f'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {x["raw_prompt"]}\nOutput:', + "original_output": x["prompt"] + } + for x in dataset[:500] + ] + reward_fn = create_reward_fn(args) + + trlx.train( + prompts=prompts, + eval_prompts=eval_prompts, + reward_fn=reward_fn, + config=config, + stop_sequences=[], + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, default='outputs/sft') + parser.add_argument('--aes_model_path', type=str, default='outputs/rm_aes') + parser.add_argument('--ps_model_path', type=str, default='outputs/rm_ps') + parser.add_argument('--data_path', type=str, required=True) + + parser.add_argument('--lr', type=float, default=5e-6) + parser.add_argument('--weight_decay', type=float, default=1e-6) + parser.add_argument('--max_length', type=int, default=384) + parser.add_argument('--alpha', type=float, default=0.7) + + parser.add_argument('--save_path', type=str, default='outputs/ppo') + parser.add_argument('--logging_dir', type=str, default='logs') + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--reward_batch_size', type=int, default=32) + parser.add_argument('--total_steps', type=int, default=2000) + parser.add_argument('--checkpoint_interval', type=int, default=500) + parser.add_argument('--eval_interval', type=int, default=500) + parser.add_argument('--num_layers_unfrozen', type=int, default=8) + + args = parser.parse_args() + main(args) diff --git a/examples/BeautifulPrompt/train_rm.py b/examples/BeautifulPrompt/train_rm.py new file mode 100644 index 0000000..0b5b1e3 --- /dev/null +++ b/examples/BeautifulPrompt/train_rm.py @@ -0,0 +1,81 @@ +import argparse +import os + +from transformers import AutoModelForSequenceClassification, AutoTokenizer +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from beautiful_prompt.data import RMDatasetForAES, RMDatasetForPS, DataCollatorForRMDataset +from beautiful_prompt.utils import set_seed +from beautiful_prompt.trainer import RMTrainer + +def train(args): + set_seed(args.seed) + + model = AutoModelForSequenceClassification.from_pretrained(args.model_path, num_labels=1) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + tokenizer.truncation_side = 'left' + tokenizer.padding_side = 'right' + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if args.rm_type.lower() == 'aes': + train_dataset = RMDatasetForAES(args.data_path, tokenizer, max_length=args.max_length) + elif args.rm_type.lower() == 'ps': + train_dataset = RMDatasetForPS(args.data_path, tokenizer, max_length=args.max_length) + else: + raise NotImplementedError() + + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + + data_collator = DataCollatorForRMDataset(tokenizer=tokenizer) + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + 
collate_fn=data_collator, + pin_memory=True) + + trainer = RMTrainer( + model, + tokenizer, + train_dataloader, + save_path=args.save_path, + logging_dir=args.logging_dir, + + lr=args.lr, + batch_size=args.batch_size, + weight_decay=args.weight_decay, + epochs=args.epochs + ) + trainer.train() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--model_path', type=str, default='bigscience/bloom-1b1') + parser.add_argument('--data_path', type=str, required=True) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--weight_decay', type=float, default=1e-3) + parser.add_argument('--max_length', type=int, default=384) + + parser.add_argument('--save_path', type=str, default='outputs/rm') + parser.add_argument('--logging_dir', type=str, default='logs') + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=4) + + parser.add_argument('--rm_type', type=str, default='aes') + + args = parser.parse_args() + train(args) diff --git a/examples/BeautifulPrompt/train_sft.py b/examples/BeautifulPrompt/train_sft.py new file mode 100644 index 0000000..629783e --- /dev/null +++ b/examples/BeautifulPrompt/train_sft.py @@ -0,0 +1,73 @@ +import argparse +import os + +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from beautiful_prompt.data import SFTDataset, DataCollatorForSFTDataset +from beautiful_prompt.utils import set_seed +from beautiful_prompt.trainer import SFTTrainer + +def train(args): + set_seed(args.seed) + + model = AutoModelForCausalLM.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + + train_dataset = SFTDataset(args.data_path, tokenizer, max_length=args.max_length) + + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + + data_collator = DataCollatorForSFTDataset(tokenizer=tokenizer) + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + + trainer = SFTTrainer( + model, + tokenizer, + train_dataloader, + save_path=args.save_path, + logging_dir=args.logging_dir, + + lr=args.lr, + batch_size=args.batch_size, + weight_decay=args.weight_decay, + epochs=args.epochs + ) + trainer.train() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--model_path', type=str, default='bigscience/bloom-1b1') + parser.add_argument('--data_path', type=str, required=True) + parser.add_argument('--lr', type=float, default=1e-5) + + # weight_decay set to 0, it is easier to overfit, which is beneficial to PPO + parser.add_argument('--weight_decay', type=float, default=0) + parser.add_argument('--max_length', type=int, default=384) + + parser.add_argument('--save_path', type=str, default='outputs/sft') + parser.add_argument('--logging_dir', type=str, default='logs') + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, 
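Once a reward model has been trained this way, scoring a single prompt is one forward pass through the regression head (`num_labels=1`). A hedged sketch, assuming the checkpoint written to `--save_path` (for `--rm_type aes`, the PPO script expects `outputs/rm_aes` by default) can be reloaded with the standard Hugging Face classes; how `RMTrainer` saves it is not shown here:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

path = "outputs/rm_aes"  # assumed output of `python train_rm.py --rm_type aes ...`
model = AutoModelForSequenceClassification.from_pretrained(path, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompt = "a castle on a cliff, intricate, highly detailed, sharp focus"
inputs = tokenizer(prompt + tokenizer.eos_token, return_tensors="pt",
                   truncation=True, max_length=384)
with torch.no_grad():
    score = model(**inputs).logits.squeeze().item()  # scalar aesthetic score
```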
default=4) + + args = parser.parse_args() + train(args) diff --git a/examples/BeautifulPrompt/trlx/__init__.py b/examples/BeautifulPrompt/trlx/__init__.py new file mode 100644 index 0000000..7b26a92 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/__init__.py @@ -0,0 +1,2 @@ +from .trlx import train +from .utils import logging diff --git a/examples/BeautifulPrompt/trlx/data/__init__.py b/examples/BeautifulPrompt/trlx/data/__init__.py new file mode 100644 index 0000000..96d4675 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/__init__.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Iterable + +from torchtyping import TensorType + + +@dataclass +class GeneralElement: + """ + General element outputted by a data pipeline + """ + + pass + + +@dataclass +class RLElement: + """ + Batch element for RL model + """ + + state: Iterable[str] = None # Context/prompts + action: TensorType["N"] = None # Tokens generated by model given prompts + reward: float = None # Reward obtained for that generation + + +@dataclass +class BatchElement: + """ + General batch element for any transformer to use in its forward pass + """ + + tokens: TensorType["BATCH", "SEQ_LEN"] + masks: TensorType["BATCH", "SEQ_LEN"] diff --git a/examples/BeautifulPrompt/trlx/data/accelerate_base_datatypes.py b/examples/BeautifulPrompt/trlx/data/accelerate_base_datatypes.py new file mode 100644 index 0000000..838567e --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/accelerate_base_datatypes.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Iterable + +from torchtyping import TensorType + + +@dataclass +class PromptElement: + """ + Dataclass for a single prompt, containing its string and tokenized form. + + :param text: The prompt text. + :type text: str + + :param tokens: The prompt tokens. Should be a long tensor + :type tokens: torch.Tensor + """ + + text: str + tokens: TensorType["num_tokens"] + + +@dataclass +class PromptBatch: + """ + Batched PromptElement + + :param text: An iterable of prompt texts. + :type text: Iterable[str] + + :param tokens: A long tensor batch of prompt tokens. + :type tokens: torch.Tensor + """ + + text: Iterable[str] + tokens: TensorType["batch_size", "num_tokens"] + + +@dataclass +class AccelerateRLElement: + """ + Dataclass for RL elements, containing output tokens and rewards for each token. + + :param tokens: The output tokens. Should be a long tensor + :type tokens: torch.Tensor + + :param rewards: The rewards for each token. Should be a float tensor of same size as tokens. + :type rewards: torch.Tensor + """ + + output_tokens: TensorType["output_size"] + rewards: TensorType["output_size"] + + +@dataclass +class AccelerateRLBatchElement: + """ + Batched accelerate RL element + + :param tokens: Batches of long tensors of output tokens. + :type tokens: torch.Tensor + + :param rewards: Batches of float tensors of rewards for each output token. 
+ :type rewards: torch.Tensor + """ + + output_tokens: TensorType["batch_size", "output_size"] + rewards: TensorType["batch_size", "output_size"] diff --git a/examples/BeautifulPrompt/trlx/data/configs.py b/examples/BeautifulPrompt/trlx/data/configs.py new file mode 100644 index 0000000..8b2af9c --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/configs.py @@ -0,0 +1,334 @@ +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + +import yaml + +from trlx.data.method_configs import MethodConfig, get_method + + +def merge(base: Dict, update: Dict, updated: Set) -> Dict: + "Recursively updates a nested dictionary with new values" + for k, v in base.items(): + if k in update and isinstance(v, dict): + base[k] = merge(v, update[k], updated) + updated.add(k) + elif k in update: + base[k] = update[k] + updated.add(k) + + return base + + +def _merge_dicts(base: Dict, update: Dict) -> Dict: + "Merge two dictionaries recursively, returning a new dictionary." + + base = deepcopy(base) + + for k, v in update.items(): + if isinstance(v, dict): + base[k] = _merge_dicts(base.get(k, {}), v) + else: + base[k] = v + + return base + + +@dataclass +class ModelConfig: + """ + Config for a model. + + :param model_path: Path or name of the model (local or on huggingface hub) + :type model_path: str + + :param model_arch_type: Type of model architecture. Either "causal" or "seq2seq" + :type model_arch_type: str + + :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning. + -1 means all layers are unfrozen. + :type num_layers_unfrozen: int + + :param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning. + Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set + the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`) + {"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0} + or in YAML format: + delta_kwargs: + delta_type: lora + modified_modules: "all" + lora_r: 8 + lora_alpha: 16 + lora_dropout: 0.0 + See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig + :type delta_kwargs: Optional[Dict[str, Any]] + """ + + model_path: str + model_arch_type: str = "causal" + num_layers_unfrozen: int = -1 + delta_kwargs: Optional[Dict[str, Any]] = None + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +@dataclass +class TokenizerConfig: + """ + Config for a model. + + :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub) + :type tokenizer_path: str + + :param padding_side: Padding side + :type padding_path: str + + :param truncation_side: Truncation side + :type truncation_side: str + """ + + tokenizer_path: str + padding_side: str = "left" + truncation_side: str = "right" + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +@dataclass +class OptimizerConfig: + """ + Config for an optimizer. + + :param name: Name of the optimizer + :type name: str + + :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay) + :type kwargs: Dict[str, Any] + """ + + name: str + kwargs: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +@dataclass +class SchedulerConfig: + """ + Config for a learning rate scheduler. 
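`merge` updates a nested dictionary in place and records which keys it touched (it backs `TRLConfig.update` further down), while `_merge_dicts` returns a fresh deep-merged copy (used by `evolve`). A quick illustration with placeholder values:

```python
from trlx.data.configs import _merge_dicts

base = {"optimizer": {"kwargs": {"lr": 3e-5, "weight_decay": 1e-6}}, "train": {"batch_size": 32}}
override = {"optimizer": {"kwargs": {"lr": 5e-6}}}

merged = _merge_dicts(base, override)
# merged["optimizer"]["kwargs"] == {"lr": 5e-6, "weight_decay": 1e-6}; `base` is left unchanged.
```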
+ + :param name: Name of the scheduler + :type name: str + + :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max) + :type kwargs: Dict[str, Any] + """ + + name: str + kwargs: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +@dataclass +class TrainConfig: + """ + Config for train job on model. + + :param total_steps: Total number of training steps + :type total_steps: int + + :param seq_length: Number of tokens to use as context (max length for tokenizer) + :type seq_length: int + + :param epochs: Total number of passes through data + :type epochs: int + + :param batch_size: Batch size for training + :type batch_size: int + + :param tracker: Tracker to use for logging. Default: "wandb" + :type tracker: str + + :param checkpoint_interval: Save model every checkpoint_interval steps. + Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir` + directory in the format `checkpoint_dir/checkpoint_{step}`. + :type checkpoint_interval: int + + :param eval_interval: Evaluate model every eval_interval steps + :type eval_interval: int + + :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline + :type pipeline: str + + :param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer + :type trainer: str + + :param trainer_kwargs: Extra keyword arguments for the trainer + :type trainer: Dict[str, Any] + + :param project_name: Project name for wandb + :type project_name: str + + :param entity_name: Entity name for wandb + :type entity_name: str + + :param group_name: Group name for wandb (used for grouping runs) + :type group_name: str + + :param checkpoint_dir: Directory to save checkpoints + :type checkpoint_dir: str + + :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. + Only used by AcceleratePPOTrainer. + :type rollout_logging_dir: Optional[str] + + :param save_best: Save best model based on mean reward + :type save_best: bool + + :param seed: Random seed + :type seed: int + + :param minibatch_size: Size of model input during one forward pass. Must divide batch size + :type minibatch_size: int + """ + + total_steps: int + seq_length: int + epochs: int + batch_size: int + + checkpoint_interval: int + eval_interval: int + + pipeline: str # One of the pipelines in framework.pipeline + trainer: str # One of the trainers + trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer + + project_name: str = "trlx" + entity_name: Optional[str] = None + group_name: Optional[str] = None + + checkpoint_dir: str = "ckpts" + rollout_logging_dir: Optional[str] = None + save_best: bool = True + save_optimizer: bool = True + + tracker: Optional[str] = "wandb" + logging_dir: Optional[str] = None + tags: Optional[List[str]] = field(default_factory=list) + + seed: int = 1000 + + minibatch_size: Optional[int] = None + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +@dataclass +class TRLConfig: + """ + Top level config for trlX. Loads configs and can be converted to dictionary. + """ + + method: MethodConfig + model: ModelConfig + optimizer: OptimizerConfig + scheduler: SchedulerConfig + tokenizer: TokenizerConfig + train: TrainConfig + + @classmethod + def load_yaml(cls, yml_fp: str): + """ + Load yaml file as TRLConfig. 
+ + :param yml_fp: Path to yaml file + :type yml_fp: str + """ + with open(yml_fp, mode="r") as file: + config = yaml.safe_load(file) + return cls.from_dict(config) + + def to_dict(self): + """ + Convert TRLConfig to dictionary. + """ + data = { + "method": self.method.__dict__, + "model": self.model.__dict__, + "optimizer": self.optimizer.__dict__, + "scheduler": self.scheduler.__dict__, + "tokenizer": self.tokenizer.__dict__, + "train": self.train.__dict__, + } + + return data + + def evolve(self, **kwargs) -> "TRLConfig": + """ + Evolve TRLConfig with new parameters. Can update nested parameters. + >>> config = trlx.data.default_configs.default_ilql_config() + >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)) + >>> config.method.gamma + 0.99 + """ + return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs)) + + @classmethod + def from_dict(cls, config: Dict): + """ + Convert dictionary to TRLConfig. + """ + return cls( + method=get_method(config["method"]["name"]).from_dict(config["method"]), + model=ModelConfig.from_dict(config["model"]), + tokenizer=TokenizerConfig.from_dict(config["tokenizer"]), + optimizer=OptimizerConfig.from_dict(config["optimizer"]), + scheduler=SchedulerConfig.from_dict(config["scheduler"]), + train=TrainConfig.from_dict(config["train"]), + ) + + @classmethod + def update(cls, baseconfig: Dict, config: Dict): + update = {} + # unflatten a string variable name into a nested dictionary + # key1.key2.key3: value -> {key1: {key2: {key3: value}}} + for name, value in config.items(): + if isinstance(value, dict): + update[name] = value + else: + *layers, var = name.split(".") + if layers: + d = update.setdefault(layers[0], {}) + for layer in layers[1:]: + d = d.setdefault(layer, {}) + d[var] = value + + if not isinstance(baseconfig, Dict): + baseconfig = baseconfig.to_dict() + + updates = set() + merged = merge(baseconfig, update, updates) + + for param in update: + if param not in updates: + raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)") + + return cls.from_dict(merged) + + def __str__(self): + """Returns a human-readable string representation of the config.""" + import json + + return json.dumps(self.to_dict(), indent=4) diff --git a/examples/BeautifulPrompt/trlx/data/default_configs.py b/examples/BeautifulPrompt/trlx/data/default_configs.py new file mode 100644 index 0000000..b8ab22c --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/default_configs.py @@ -0,0 +1,119 @@ +from trlx.models.modeling_ilql import ILQLConfig +from trlx.models.modeling_ppo import PPOConfig +from trlx.trainer.accelerate_sft_trainer import SFTConfig + +from .configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + + +def default_ppo_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=10000, + batch_size=32, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + ), + model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)), + method=PPOConfig( + name="PPOConfig", + num_rollouts=128, + chunk_size=128, + ppo_epochs=4, + 
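`TRLConfig.update` accepts overrides whose keys use dotted paths, unflattens them before merging, and raises if an override does not correspond to an existing field. A short usage sketch built on the default PPO config defined just below:

```python
from trlx.data.configs import TRLConfig
from trlx.data.default_configs import default_ppo_config

overrides = {"train.batch_size": 8, "method.init_kl_coef": 0.05}
config = TRLConfig.update(default_ppo_config().to_dict(), overrides)
assert config.train.batch_size == 8 and config.method.init_kl_coef == 0.05
```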
init_kl_coef=0.001, + target=None, + horizon=10000, + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1, + scale_reward="ignored", + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs=dict( + max_new_tokens=40, + top_k=0, + top_p=1.0, + do_sample=True, + ), + ), + ) + + +def default_ilql_config(): + return TRLConfig( + train=TrainConfig( + seq_length=64, + batch_size=128, + epochs=100, + total_steps=1000, + checkpoint_interval=1000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateILQLTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=5.0e-5) # train.total_steps + ), + method=ILQLConfig( + name="ilqlconfig", + tau=0.7, + gamma=0.99, + cql_scale=0.1, + awac_scale=1, + alpha=0.001, + beta=0, + steps_for_target_q_sync=5, + two_qs=True, + gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=1, temperature=1.0), + ), + ) + + +def default_sft_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=8, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateSFTTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps + ), + method=SFTConfig( + name="sftconfig", + gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), + ), + ) diff --git a/examples/BeautifulPrompt/trlx/data/ilql_types.py b/examples/BeautifulPrompt/trlx/data/ilql_types.py new file mode 100644 index 0000000..cb83309 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/ilql_types.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass, fields + +from torchtyping import TensorType # type: ignore + + +def flatten_dataclass(cls: type): + """Return a function that flattens a dataclass into a list""" + cls_fields = [f.name for f in fields(cls)] + return lambda x: [getattr(x, f) for f in cls_fields] + + +def unflatten_dataclass(cls: type): + """Return a function that unflattens a list into a dataclass""" + cls_fields = [f.name for f in fields(cls)] + return lambda x: cls(**dict(zip(cls_fields, x))) + + +@dataclass +class ILQLElement: + """ + Data element for ILQL + + :param input_ids: Input tokens. Should be a long tensor. + :type input_ids: torch.Tensor + + :param attention_mask: Attention mask. Should be a long tensor. + :type attention_mask: torch.Tensor + + :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. + :type rewards: torch.Tensor + """ + + input_ids: TensorType["query_size"] + attention_mask: TensorType["query_size"] + rewards: TensorType["reward_size"] + states_ixs: TensorType["states_size"] + actions_ixs: TensorType["reward_size"] + dones: TensorType["states_size"] + + +@dataclass +class ILQLSeq2SeqElement: + """ + Data element for ILQL + + :param input_ids: Input tokens. Should be a long tensor. 
+ :type input_ids: torch.Tensor + + :param attention_mask: Attention mask. Should be a long tensor. + :type attention_mask: torch.Tensor + + :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. + :type rewards: torch.Tensor + """ + + input_ids: TensorType["query_size"] + attention_mask: TensorType["query_size"] + decoder_input_ids: TensorType["reward_size"] + rewards: TensorType["reward_size"] + states_ixs: TensorType["states_size"] + actions_ixs: TensorType["reward_size"] + dones: TensorType["states_size"] + + +@dataclass +class ILQLBatch: + """ + Batched ILQL data elements + + :param input_ids: Batch of input tokens. + :type input_ids: torch.Tensor + + :param attention_mask: Batch of attention masks. + :type attention_mask: torch.Tensor + + :param rewards: Batch of rewards for each token in each token batch. + :type rewards: torch.Tensor + """ + + input_ids: TensorType["batch_size", "query_size"] + attention_mask: TensorType["batch_size", "query_size"] + rewards: TensorType["batch_size", "reward_size"] + states_ixs: TensorType["batch_size", "states_size"] + actions_ixs: TensorType["batch_size", "reward_size"] + dones: TensorType["batch_size", "states_size"] + + +@dataclass +class ILQLSeq2SeqBatch: + """ + Batched ILQL data elements + + :param input_ids: Batch of input tokens. + :type input_ids: torch.Tensor + + :param attention_mask: Batch of attention masks. + :type attention_mask: torch.Tensor + + :param rewards: Batch of rewards for each token in each token batch. + :type rewards: torch.Tensor + """ + + input_ids: TensorType["batch_size", "query_size"] + attention_mask: TensorType["batch_size", "query_size"] + decoder_input_ids: TensorType["batch_size", "reward_size"] + rewards: TensorType["batch_size", "reward_size"] + states_ixs: TensorType["batch_size", "states_size"] + actions_ixs: TensorType["batch_size", "reward_size"] + dones: TensorType["batch_size", "states_size"] diff --git a/examples/BeautifulPrompt/trlx/data/method_configs.py b/examples/BeautifulPrompt/trlx/data/method_configs.py new file mode 100644 index 0000000..435ce13 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/method_configs.py @@ -0,0 +1,56 @@ +import sys +from dataclasses import dataclass +from typing import Any, Dict + +# specifies a dictionary of method configs +_METHODS: Dict[str, Any] = {} # registry + + +def register_method(name): + """Decorator used register a method config + Args: + name: Name of the method + """ + + def register_class(cls, name): + _METHODS[name] = cls + setattr(sys.modules[__name__], name, cls) + return cls + + if isinstance(name, str): + name = name.lower() + return lambda c: register_class(c, name) + + cls = name + name = cls.__name__ + register_class(cls, name.lower()) + + return cls + + +@dataclass +@register_method +class MethodConfig: + """ + Config for a certain RL method. 
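New method configs become visible to `TRLConfig.from_dict` by registering them in this module-level registry; `get_method` (defined just below) then resolves them by lower-cased class name. A hypothetical example, where the class and its extra field are invented purely for illustration:

```python
from dataclasses import dataclass
from trlx.data.method_configs import MethodConfig, register_method, get_method

@dataclass
@register_method
class MyMethodConfig(MethodConfig):
    some_coef: float = 1.0   # hypothetical hyperparameter

cfg = get_method("MyMethodConfig").from_dict({"name": "mymethodconfig", "some_coef": 0.5})
assert cfg.some_coef == 0.5
```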
+ + :param name: Name of the method + :type name: str + """ + + name: str + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + return cls(**config) + + +def get_method(name: str) -> MethodConfig: + """ + Return constructor for specified method config + """ + name = name.lower() + if name in _METHODS: + return _METHODS[name] + else: + raise Exception("Error: Trying to access a method that has not been registered") diff --git a/examples/BeautifulPrompt/trlx/data/ppo_types.py b/examples/BeautifulPrompt/trlx/data/ppo_types.py new file mode 100644 index 0000000..375d7d3 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/data/ppo_types.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass + +from torchtyping import TensorType + + +@dataclass +class PPORLElement: + """ + :param query_tensor: The query tensor i.e. the prompt tokens. + Should be a long tensor. + :type query_tensor: torch.Tensor + + :param response_tensor: The response tensor i.e. the output tokens. + Should be a long tensor. + :type response_tensor: torch.Tensor + + :param logprobs: The log probabilities over the response tokens generated + by the policy network (i.e. the autoregressive model). + Should be a float tensor of same size as tokens. + :type logprobs: torch.Tensor + + :param values: The values for each token generated from the value network or value head. + Should be a float tensor of same size as tokens. + :type values: torch.Tensor + + :param rewards: The rewards for each token outputted in response. + Should be a float tensor of same size as tokens. + :type rewards: torch.Tensor + """ + + query_tensor: TensorType["query_size"] + response_tensor: TensorType["response_size"] + logprobs: TensorType["response_size"] + values: TensorType["response_size"] + rewards: TensorType["response_size"] + + +@dataclass +class PPORLBatch: + """ + A batched version of the PPORLElement. See PPORLElement for more details on individual fields. + + :param query_tensors: A batch of query tensors. Should be a long tensor. + :type query_tensors: torch.Tensor + + :param response_tensors: A batch of response tensors. Should be a long tensor. + :type response_tensors: torch.Tensor + + :param logprobs: A batch of log probabilities from policy + :type logprobs: torch.Tensor + + :param values: A batch of values from value network + :type values: torch.Tensor + + :param rewards: A batch of rewards + :type rewards: torch.Tensor + """ + + query_tensors: TensorType["batch_size", "query_size"] + response_tensors: TensorType["batch_size", "response_size"] + logprobs: TensorType["batch_size", "response_size"] + values: TensorType["batch_size", "response_size"] + rewards: TensorType["batch_size", "response_size"] diff --git a/examples/BeautifulPrompt/trlx/models/README.md b/examples/BeautifulPrompt/trlx/models/README.md new file mode 100644 index 0000000..fdcb445 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/README.md @@ -0,0 +1,344 @@ +## Using pretrained NeMo models +To use a NeMo models in `.nemo` format, like [NeMo Megatron-GPT-20B](https://huggingface.co/nvidia/nemo-megatron-gpt-20B), download and un-tar it: +``` +tar xvf nemo_gpt20B_bf16_tp4.nemo +``` +This will extract the model weights and the model config. + +Then set `train.trainer_kwargs.pretrained_model` to the path to the directory containing the parameters. The model hyperparameters in the `train.trainer_kwargs.megatron_cfg` should match the ones in the model config. 
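For concreteness, one way to wire this up from Python rather than YAML. This is only a sketch: the trainer name and the paths below are assumptions for illustration, not values taken from this repository.

```python
from trlx.data.default_configs import default_ilql_config

config = default_ilql_config().evolve(
    train=dict(
        trainer="NeMoILQLTrainer",   # assumed name of the NeMo ILQL trainer
        trainer_kwargs=dict(
            pretrained_model="/path/to/nemo_gpt20B_bf16_tp4/",  # directory produced by `tar xvf`
            megatron_cfg="megatron_20b.yaml",                   # must match the model's hyperparameters
        ),
    ),
)
```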
+ +## Inference ILQL trained NeMo models +To load a checkpoint, run +``` +python examples/nemo_ilql_inference.py configs/nemo_configs/megatron_20b.yaml "/path/to/ilql_sentiments_logs/checkpoints" +``` +To save checkpoints, ensure the following is set in the NeMo config: +``` +exp_manager: + explicit_log_dir: ilql_sentiments_logs + create_checkpoint_callback: True +``` + +## Resume Training +To resume training, ensure the following is set in the NeMo config: +``` +exp_manager: + resume_if_exists: True +``` + +## NeMo Megatron setup +Clone https://github.com/NVIDIA/NeMo/tree/r1.15.0 (currently only up to `r1.15.0` is supoprted) and apex from https://github.com/NVIDIA/apex/. + +1) install conda (or mamba/micromamba) + +2) srun into a compute node with a gpu (if running on HPC cluster) +``` +srun --pty bash -i +``` + +3) copy the conda env export below and change the name and prefix +``` +conda env create -f env.yaml +``` + +4) install nemo +``` +git clone https://github.com/NVIDIA/NeMo/ +cd NeMo && pip install '.[all]' +``` + +6) install apex (or clone the github) +``` +git clone https://github.com/NVIDIA/apex/ +cd apex +pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./ +``` + + +# conda env export +``` +name: nemo-113 +prefix: /mnt/nvme/jobs/nemo/nemo-source +channels: + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - bzip2=1.0.8=h7f98852_4 + - c-ares=1.18.1=h7f8727e_0 + - ca-certificates=2022.9.24=ha878542_0 + - curl=7.84.0=h5eee18b_0 + - expat=2.4.4=h295c915_0 + - gettext=0.21.1=h27087fc_0 + - git=2.34.1=pl5262hc120c5b_0 + - krb5=1.19.2=hac12032_0 + - lame=3.100=h166bdaf_1003 + - ld_impl_linux-64=2.39=hcc3a1bd_1 + - libcurl=7.84.0=h91b91d3_0 + - libedit=3.1.20210910=h7f8727e_0 + - libev=4.33=h7f8727e_1 + - libffi=3.2.1=he1b5a44_1007 + - libflac=1.4.2=h27087fc_0 + - libgcc-ng=12.2.0=h65d4601_19 + - libgomp=12.2.0=h65d4601_19 + - libnghttp2=1.46.0=hce63b2e_0 + - libnsl=2.0.0=h7f98852_0 + - libogg=1.3.4=h7f98852_1 + - libopus=1.3.1=h7f98852_1 + - libsndfile=1.1.0=h27087fc_0 + - libsqlite=3.39.4=h753d276_0 + - libssh2=1.10.0=h8f2d780_0 + - libstdcxx-ng=12.2.0=h46fd767_19 + - libuuid=2.32.1=h7f98852_1000 + - libvorbis=1.3.7=h9c3ff4c_0 + - libzlib=1.2.12=h166bdaf_2 + - mpg123=1.30.2=h27087fc_1 + - ncurses=6.3=h27087fc_1 + - openssl=1.1.1q=h7f8727e_0 + - pcre2=10.37=he7ceb23_1 + - perl=5.26.2=h14c3975_0 + - pip=22.3.1=pyhd8ed1ab_0 + - python=3.8.2=he5300dc_7_cpython + - readline=8.1.2=h0f457ee_0 + - sqlite=3.39.4=h4ff8645_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.38.4=pyhd8ed1ab_0 + - xz=5.2.6=h166bdaf_0 + - zlib=1.2.12=h7f8727e_2 + - pip: + - absl-py==1.3.0 + - aiohttp==3.8.3 + - aiosignal==1.3.1 + - alabaster==0.7.12 + - aniso8601==9.0.1 + - antlr4-python3-runtime==4.9.3 + - appdirs==1.4.4 + - asttokens==2.1.0 + - async-timeout==4.0.2 + - attrdict==2.0.1 + - attrs==22.1.0 + - audioread==3.0.0 + - babel==2.11.0 + - backcall==0.2.0 + - beautifulsoup4==4.11.1 + - black==19.10b0 + - boto3==1.26.13 + - botocore==1.29.13 + - braceexpand==0.1.7 + - cachetools==5.2.0 + - certifi==2022.9.24 + - cffi==1.15.1 + - charset-normalizer==2.1.1 + - click==8.0.2 + - colorama==0.4.6 + - commonmark==0.9.1 + - contourpy==1.0.6 + - cycler==0.11.0 + - cython==0.29.32 + - debugpy==1.6.3 + - decorator==5.1.1 + - distance==0.1.3 + - docker-pycreds==0.4.0 + - docopt==0.6.2 
+ - docutils==0.19 + - editdistance==0.6.1 + - einops==0.6.0 + - entrypoints==0.4 + - exceptiongroup==1.0.4 + - executing==1.2.0 + - faiss-cpu==1.7.3 + - fasttext==0.9.2 + - filelock==3.8.0 + - flask==2.2.2 + - flask-restful==0.3.9 + - fonttools==4.38.0 + - frozenlist==1.3.3 + - fsspec==2022.11.0 + - ftfy==6.1.1 + - g2p-en==2.1.0 + - gdown==4.5.3 + - gitdb==4.0.9 + - gitpython==3.1.29 + - google-auth==2.14.1 + - google-auth-oauthlib==0.4.6 + - grpcio==1.50.0 + - h5py==3.7.0 + - huggingface-hub==0.11.0 + - hydra-core==1.2.0 + - idna==3.4 + - ijson==3.1.4 + - imagesize==1.4.1 + - importlib-metadata==5.0.0 + - importlib-resources==5.10.0 + - inflect==6.0.2 + - iniconfig==1.1.1 + - ipadic==1.0.0 + - ipykernel==6.17.1 + - ipython==8.6.0 + - ipywidgets==8.0.2 + - isort==4.3.21 + - itsdangerous==2.1.2 + - jedi==0.18.1 + - jieba==0.42.1 + - jinja2==3.1.2 + - jiwer==2.5.1 + - jmespath==1.0.1 + - joblib==1.2.0 + - jupyter-client==7.4.7 + - jupyter-core==5.0.0 + - jupyterlab-widgets==3.0.3 + - kaldi-python-io==1.2.2 + - kaldiio==2.17.2 + - kiwisolver==1.4.4 + - latexcodec==2.0.1 + - levenshtein==0.20.2 + - librosa==0.9.2 + - llvmlite==0.39.1 + - loguru==0.6.0 + - lxml==4.9.1 + - markdown==3.4.1 + - markupsafe==2.1.1 + - marshmallow==3.19.0 + - matplotlib==3.6.2 + - matplotlib-inline==0.1.6 + - mecab-python3==1.0.5 + - mpmath==1.2.1 + - multidict==6.0.2 + - nest-asyncio==1.5.6 + - nltk==3.7 + - numba==0.56.4 + - numpy==1.23.4 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - oauthlib==3.2.2 + - omegaconf==2.2.3 + - onnx==1.12.0 + - opencc==1.1.4 + - packaging==21.3 + - pandas==1.5.1 + - pangu==4.0.6.1 + - parameterized==0.8.1 + - parso==0.8.3 + - pathspec==0.10.2 + - pathtools==0.1.2 + - pesq==0.0.4 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==9.3.0 + - pip-api==0.0.30 + - pipreqs==0.4.11 + - plac==1.3.5 + - platformdirs==2.5.4 + - pluggy==1.0.0 + - pooch==1.6.0 + - portalocker==2.6.0 + - progress==1.6 + - promise==2.3 + - prompt-toolkit==3.0.32 + - protobuf==3.20.1 + - psutil==5.9.4 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyannote-core==4.5 + - pyannote-database==4.1.3 + - pyannote-metrics==3.2.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pybind11==2.10.1 + - pybtex==0.24.0 + - pybtex-docutils==1.0.2 + - pycparser==2.21 + - pydantic==1.10.2 + - pydeprecate==0.3.2 + - pydub==0.25.1 + - pygments==2.13.0 + - pynini==2.1.5 + - pyparsing==3.0.9 + - pypinyin==0.47.1 + - pysocks==1.7.1 + - pystoi==0.3.3 + - pytest==7.2.0 + - pytest-runner==6.0.0 + - python-dateutil==2.8.2 + - pytorch-lightning==1.7.7 + - pytz==2022.6 + - pyyaml==5.4.1 + - pyzmq==24.0.1 + - rapidfuzz==2.13.2 + - regex==2022.10.31 + - requests==2.28.1 + - requests-oauthlib==1.3.1 + - resampy==0.4.2 + - rich==12.6.0 + - rsa==4.9 + - ruamel-yaml==0.17.21 + - ruamel-yaml-clib==0.2.7 + - s3transfer==0.6.0 + - sacremoses==0.0.53 + - scikit-learn==1.1.3 + - scipy==1.9.3 + - sentence-transformers==2.2.2 + - sentencepiece==0.1.97 + - sentry-sdk==1.11.0 + - setproctitle==1.3.2 + - setuptools==59.5.0 + - shellingham==1.5.0 + - shortuuid==1.0.11 + - simplejson==3.18.0 + - six==1.16.0 + - smmap==5.0.0 + - snowballstemmer==2.2.0 + - sortedcontainers==2.4.0 + - soundfile==0.11.0 + - soupsieve==2.3.2.post1 + - sox==1.4.1 + - sphinx==5.3.0 + - sphinxcontrib-applehelp==1.0.2 + - sphinxcontrib-bibtex==2.5.0 + - sphinxcontrib-devhelp==1.0.2 + - sphinxcontrib-htmlhelp==2.0.0 + - sphinxcontrib-jsmath==1.0.1 + - sphinxcontrib-qthelp==1.0.3 + - 
sphinxcontrib-serializinghtml==1.1.5 + - stack-data==0.6.1 + - sympy==1.11.1 + - tabulate==0.9.0 + - tensorboard==2.11.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - termcolor==2.1.0 + - text-unidecode==1.3 + - textdistance==4.5.0 + - texterrors==0.4.4 + - threadpoolctl==3.1.0 + - tokenizers==0.12.1 + - toml==0.10.2 + - tomli==2.0.1 + - torch==1.13.0 + - torchaudio==0.13.0 + - torchmetrics==0.10.3 + - torchvision==0.14.0 + - tornado==6.2 + - tqdm==4.64.1 + - traitlets==5.5.0 + - transformers==4.21.2 + - typed-ast==1.5.4 + - typer==0.7.0 + - typing-extensions==4.4.0 + - urllib3==1.26.12 + - wandb==0.13.5 + - wcwidth==0.2.5 + - webdataset==0.1.62 + - werkzeug==2.2.2 + - wget==3.2 + - widgetsnbextension==4.0.3 + - wrapt==1.14.1 + - yarg==0.1.9 + - yarl==1.8.1 + - youtokentome==1.0.6 + - zipp==3.10.0 +``` diff --git a/examples/BeautifulPrompt/trlx/models/__init__.py b/examples/BeautifulPrompt/trlx/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/BeautifulPrompt/trlx/models/modeling_base.py b/examples/BeautifulPrompt/trlx/models/modeling_base.py new file mode 100644 index 0000000..7fa8dfb --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/modeling_base.py @@ -0,0 +1,225 @@ +# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from +# HuggingFace's `trl` library. The original source code can be found here: +# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py + +import inspect +import json +import os +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import transformers +from huggingface_hub import hf_hub_download + + +class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin): + """A wrapper around `transformers.PreTrainedModel` + + Reference: @younesbelkada's `PreTrainedModelWrapper` + https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2 + + Attributes: + _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel` + type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`. + _supported_modules (List[str]): A list of attribute names for modules of + the underlying architecture model. This is used, for example, to save + and load any additional modules by manipulating the state dict. + _supported_args (List[str]): A list of arguments specific to the underlying + architecture to separate from arguments that are supported by the + parent `AutoModel` class. Any arguments that are not supported by the + underlying model will be passed to the parent `AutoModel` class. 
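A hypothetical minimal subclass, to make the three class attributes concrete. The class name and the extra scalar head are invented for illustration; a real wrapper would also override `state_dict`/`post_init` defined further down so the extra module survives checkpointing.

```python
import torch.nn as nn
import transformers
from trlx.models.modeling_base import PreTrainedModelWrapper

class CausalLMWithScalarHead(PreTrainedModelWrapper):
    _auto_model_parent_class = transformers.AutoModelForCausalLM
    _supported_modules = ["score_head"]   # extra module saved/loaded next to the base model
    _supported_args = ["head_bias"]       # kwargs kept by the wrapper instead of from_pretrained

    def __init__(self, base_model: transformers.PreTrainedModel, *, head_bias: bool = True):
        super().__init__(base_model)
        self.score_head = nn.Linear(base_model.config.hidden_size, 1, bias=head_bias)
```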
+ """ + + _auto_model_parent_class: transformers.AutoModel = None + _supported_modules: List[str] = None + # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the + # specific underlying type similar to how config instances can be used to instantiate + # `transformers.PreTrainedModel`s. + _supported_args: List[str] = None + + def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs): + super().__init__() + self.base_model = base_model + # cache `forward` args for general use (avoids incompatible args across architectures) + self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args + + @classmethod + def _split_kwargs(cls, kwargs: Dict[str, Any]): + """Separates the kwargs from the supported arguments within `supported_args` + and those that are not + """ + supported_kwargs = {} + unsupported_kwargs = {} + for key, value in kwargs.items(): + if key in cls._supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + return supported_kwargs, unsupported_kwargs + + @classmethod + def from_config(cls, config: transformers.PretrainedConfig, **kwargs): + """Instantiate the pretrained pytorch model from a configuration. + + Args: + config (transformers.PretrainedConfig): The configuration to use to + instantiate the base model. + + NOTE: Loading a model from its configuration file does **not** load the + model weights. It only affects the model's configuration. Use + `~transformers.AutoModel.from_pretrained` to load the model weights. + """ + if kwargs is not None: + wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs) + else: + from_config_kwargs = {} + wrapped_model_kwargs = {} + base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs) + model = cls(base_model, **wrapped_model_kwargs) + return model + + @classmethod + def from_pretrained( # noqa: max-complexity + cls, + pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel], + revision=None, + *model_args, + **kwargs, + ): + """Instantiate a pretrained pytorch model from a pretrained model configuration. + This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`. + Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained` + for more information. + + Args: + pretrained_model_name_or_path (str or `transformers.PreTrainedModel`): + The identifier of the pretrained model to load or the pretrained model itself. + *model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the `_auto_model_parent_class`. + **kwargs (dict, *optional*): + Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class` + call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific + instance of the wrapped model. + + NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. + """ + if kwargs is not None: + wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs) + else: + from_pretrained_kwargs = {} + wrapped_model_kwargs = {} + + if isinstance(pretrained_model_name_or_path, str): + # Load the base model using the `transformers` AutoClass (e.g. 
AutoModelForCausalLM) + base_model = cls._auto_model_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, revision=revision, **from_pretrained_kwargs + ) + elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel): + base_model = pretrained_model_name_or_path + else: + raise ValueError( + f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}" + "Expected `str` or `transformers.PreTrainedModel`." + ) + + model = cls(base_model, **wrapped_model_kwargs) + + if isinstance(pretrained_model_name_or_path, str): + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + is_sharded = False + + if not os.path.exists(filename): + try: + filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin", revision=revision) + # Sharded + except Exception: + if os.path.exists(sharded_index_filename): + index_file_name = sharded_index_filename + else: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + "pytorch_model.bin.index.json", + revision=revision, + ) + with open(index_file_name, "r") as f: + index = json.load(f) + # Collect files containing weights from supported modules + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls._supported_modules]): + files_to_download.add(v) + is_sharded = True + + if is_sharded: + # Merge each shard into a state dict + # TODO: Optimize this to avoid wasting RAM + state_dict = {} + for shard_file in files_to_download: + filename = os.path.join(pretrained_model_name_or_path, shard_file) + # Download if shard file doesn't exist locally + if not os.path.exists(filename): + filename = hf_hub_download(pretrained_model_name_or_path, shard_file, revision=revision) + state_dict.update(torch.load(filename, map_location="cpu")) + else: + state_dict = torch.load(filename, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.post_init(state_dict=state_dict) + return model + + def save_pretrained(self, *args, **kwargs): + """Save the pretrained model to a directory. This method is a wrapper + around `transformers.PreTrainedModel.save_pretrained`. Please refer to + the documentation of `transformers.PreTrainedModel.save_pretrained` for + more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict", None) + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + return self.base_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """Return the state_dict of the pretrained model.""" + raise NotImplementedError + + def post_init(self, *args, **kwargs): + """Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. 
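In practice the ILQL wrapper defined further down is loaded through this method: wrapper-specific kwargs are split off by `_split_kwargs`, and everything else is forwarded to the underlying `from_pretrained`. A small usage sketch (downloads `gpt2` from the Hub, so it assumes network access):

```python
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads

# `two_qs` and `alpha` stay with the wrapper; any other keyword argument would be
# forwarded to transformers.AutoModelForCausalLM.from_pretrained.
model = AutoModelForCausalLMWithILQLHeads.from_pretrained("gpt2", two_qs=True, alpha=0.99)
```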
+ """ + raise NotImplementedError + + def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: + """Filter out arguments not supported by the specific instance of + `base_model.transformer.forward` + """ + # FIXME: This is a hack to get around the fact that the `transformers` + # architectures we use don't have a consistent API for `forward` parameters. + return {k: v for k, v in kwargs.items() if k in self.forward_kwargs} diff --git a/examples/BeautifulPrompt/trlx/models/modeling_ilql.py b/examples/BeautifulPrompt/trlx/models/modeling_ilql.py new file mode 100644 index 0000000..e7d0ec6 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/modeling_ilql.py @@ -0,0 +1,488 @@ +import gc +import os +from copy import deepcopy +from dataclasses import dataclass +from functools import reduce +from typing import Optional, Tuple + +import deepspeed # type: ignore +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from torch import nn +from torchtyping import TensorType + +from trlx.data.ilql_types import ILQLBatch +from trlx.data.method_configs import MethodConfig, register_method +from trlx.models.modeling_base import PreTrainedModelWrapper +from trlx.utils.modeling import ( + flatten_dict, + get_tensor_stats, + hf_get_hidden_size, + hf_get_lm_head, + make_head, +) + + +def topk_mask(xs: torch.FloatTensor, k: int): + if k > xs.shape[-1]: + return xs + mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) + return torch.where(xs < mintop, -np.inf * torch.ones_like(xs, dtype=xs.dtype), xs) + + +def batched_index_select( + x: TensorType["batch", "seq_len", "hidden"], + idxs: TensorType["batch", "index_len"], + dim: int, +) -> TensorType["batch", "index_len", "hidden"]: + """ + Gather vectors at idxs along dim from x + """ + idxs = idxs.unsqueeze(-1).expand(idxs.shape[0], idxs.shape[1], x.shape[-1]) + return x.gather(dim=dim, index=idxs) + + +@dataclass +@register_method +class ILQLConfig(MethodConfig): + tau: float + gamma: float + cql_scale: float + awac_scale: float + alpha: float + beta: float + steps_for_target_q_sync: float + two_qs: bool + gen_kwargs: dict + + def loss(self, outputs, labels): + logits, (qs, target_qs, vs) = outputs + terminal_mask = labels.dones[:, :-1] + n_nonterminal = max(1, terminal_mask.sum()) + # check type of labels + if isinstance(labels, ILQLBatch): + actions = labels.input_ids[:, 1:].gather(dim=1, index=labels.actions_ixs).unsqueeze(-1) + else: + actions = labels.decoder_input_ids[:, 1:].unsqueeze(-1) + nactions = actions.shape[1] + bsize, _, dsize = logits.shape + + Q = [q.gather(-1, actions).squeeze(-1) for q in qs] + targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs] + targetQ = reduce(torch.minimum, targetQs) + + # The loss_q assumes len(states) == len(rewards) + 1 + # values of current states + V = vs[:, :-1, 0] + # values of next states + Vnext = vs[:, 1:, 0] * labels.dones[:, 1:].to(vs.dtype) + # target to fit Q + Q_ = labels.rewards + self.gamma * Vnext.detach() + + loss_qs = [((Qi - Q_) * terminal_mask).pow(2).sum() / n_nonterminal for Qi in Q] + loss_q = sum(loss_qs) + + targetQ = targetQ.detach() + + loss_v = ( + ( + (targetQ >= V).int() * self.tau * (targetQ - V).pow(2) + + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2) + ) + * terminal_mask + ).sum() / n_nonterminal + + def cql_loss(q): + loss = F.cross_entropy(q.reshape(-1, dsize), actions.reshape(-1), reduction="none") + loss = loss.reshape(bsize, nactions) * terminal_mask + loss = loss.sum() / n_nonterminal + return 
loss + + loss_cql = sum(cql_loss(q) for q in qs) + + # select logits from continuations + action_logits = batched_index_select(logits, labels.actions_ixs, dim=1) + cross_entropy = F.cross_entropy( + action_logits.reshape(-1, dsize), + actions.reshape(-1), + reduction="none", + ).reshape(bsize, nactions) + + with torch.no_grad(): + awac_weight = torch.exp(self.beta * (targetQ - V)) + + loss_awac = torch.sum(cross_entropy * awac_weight * terminal_mask) / n_nonterminal + loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac + + stats = dict( + losses=dict( + loss=loss.item(), + loss_q=loss_q.item(), + loss_v=loss_v.item(), + loss_cql=loss_cql.item(), + loss_awac=loss_awac.item(), + ), + values=get_tensor_stats(V, terminal_mask, n_nonterminal), + qvalues={str(ix): get_tensor_stats(Q[ix], terminal_mask, n_nonterminal) for ix in range(len(Q))}, + awac_weight=get_tensor_stats(awac_weight, terminal_mask, n_nonterminal), + ) + + return loss, flatten_dict(stats) + + +class ILQLHeads(nn.Module): + def __init__( + self, + hidden_size: int, + vocab_size: int, + two_qs: bool, + alpha: float, + dtype: type, + ): + super().__init__() + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.two_qs = two_qs + self.alpha = alpha + self.v_head = make_head(self.hidden_size, 1, dtype) + + n_qs = 2 if self.two_qs else 1 + self.q_heads = nn.ModuleList(make_head(self.hidden_size, self.vocab_size, dtype) for _ in range(n_qs)) + self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) + + for target_q_head in self.target_q_heads: + target_q_head.requires_grad_(False) + + def forward( + self, + hs: TensorType["batch", "seq_len", "hidden"], + states_ixs: Optional[TensorType["batch", "states_seq_len"]] = None, + actions_ixs: Optional[TensorType["batch", "actions_seq_len"]] = None, + **kwargs, + ) -> Tuple[ + Tuple[TensorType["batch", "actions_seq_len", "hidden"]], + Tuple[TensorType["batch", "actions_seq_len", "hidden"]], + TensorType["batch", "states_seq_len", "hidden"], + ]: + if states_ixs is not None: + states_hs = batched_index_select(hs, states_ixs, 1) + actions_hs = batched_index_select(hs, actions_ixs, 1) + else: + states_hs = actions_hs = hs + + qs = tuple(q_head(actions_hs) for q_head in self.q_heads) + target_qs = tuple(q_head(actions_hs) for q_head in self.target_q_heads) + vs = self.v_head(states_hs) + + return qs, target_qs, vs + + def _sync_target_q_heads(self, alpha): + for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): + for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()): + target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data) + + def sync_target_q_heads(self): + if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") == "3": + with deepspeed.zero.GatheredParameters(list(self.parameters()), modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + self._sync_target_q_heads(self.alpha) + else: + self._sync_target_q_heads(self.alpha) + + +class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper): + """An `AutoModel` class wrapper for `transformers` causal models wtih a language + modeling head and ILQL heads. 
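The target Q-heads track the online Q-heads through the soft update above. A standalone sketch of the same rule applied to plain `nn.Linear` heads (shapes are arbitrary):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def soft_update(target: nn.Module, online: nn.Module, alpha: float) -> None:
    # target <- alpha * online + (1 - alpha) * target, as in ILQLHeads._sync_target_q_heads
    for t, o in zip(target.parameters(), online.parameters()):
        t.data.copy_(alpha * o.data + (1.0 - alpha) * t.data)

q_head, target_q_head = nn.Linear(8, 4), nn.Linear(8, 4)
soft_update(target_q_head, q_head, alpha=0.99)   # alpha=0.99 is the wrapper's default
```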
+ + References: + [1] Snell et al., "Offline RL for Natural Language Generation with Implicit Language Q Learning", + https://arxiv.org/abs/2206.11871, 2022 + """ + + _auto_model_parent_class = transformers.AutoModelForCausalLM + _supported_modules = ["ilql_heads"] + _supported_args = ["two_qs", "alpha"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + two_qs: bool = True, + alpha: float = 0.99, + ): + super().__init__(base_model) + hidden_size = hf_get_hidden_size(self.base_model.config) + vocab_size = self.base_model.config.vocab_size + dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype + self.two_qs = two_qs + self.alpha = alpha + self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype) + + def forward( + self, + input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + actions_ixs=None, + states_ixs=None, + ): + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + ) + forward_kwargs["output_hidden_states"] = True + + outputs = self.base_model(**forward_kwargs) + qs, target_qs, vs = self.ilql_heads(outputs.hidden_states[-1], states_ixs=states_ixs, actions_ixs=actions_ixs) + + return outputs.logits, qs, target_qs, vs, outputs.past_key_values + + def generate( + self, + input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + beta=1, + max_new_tokens=32, + max_length=1024, + temperature=1, + top_k=20, + logit_mask=None, + pad_token_id=None, + eos_token_id=None, + ): + """ + Generates samples akin to hf's `.generate` but with custom logp prepossessing: + changing token probabilities as to how advantageous they would be + according to value functions estimations. 
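A self-contained sketch of the per-step reweighting this `generate` applies, using random tensors as stand-ins for the model outputs; `beta=1`, `top_k=20`, and `temperature=1` mirror the defaults above:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 50257)   # next-token logits from the LM head
q = torch.randn(2, 50257)        # element-wise min over the target Q-heads
v = torch.randn(2, 1)            # value head output

adv = q - v
scores = F.log_softmax(logits, dim=-1) + 1.0 * adv          # log pi_beta + beta * advantage
kth = torch.topk(scores, 20, dim=-1).values[:, -1:]         # smallest of the top-20 (as in topk_mask)
scores = scores.masked_fill(scores < kth, float("-inf"))
next_tokens = torch.multinomial(F.softmax(scores / 1.0, dim=-1), num_samples=1)
```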
+ """ + pad_token_id = pad_token_id if pad_token_id is not None else self.base_model.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.base_model.config.eos_token_id + + if attention_mask is None: + attention_mask = input_ids.not_equal(pad_token_id) + + if position_ids is None: + position_ids = attention_mask.cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask.eq(0), 0) + + samples = input_ids.clone() + max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) + + finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) + for _ in range(max_new_tokens): + out = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + ) + + logits, _, target_qs, vs, past_key_values = out + if self.two_qs: + qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) + else: + qs = target_qs[:, -1, :] + + logits = logits[:, -1, :] + vs = vs[:, -1, :] + + if logit_mask is not None: + mask = logit_mask[input_ids[:, -1].squeeze().to(logit_mask.device)] + logits[torch.where(mask)] = -np.inf + + adv = qs - vs + pi_beta = F.log_softmax(logits, -1) + pi_top_k = topk_mask(pi_beta + beta * adv, top_k) + pi = F.softmax(pi_top_k / temperature, -1) + + input_ids = torch.multinomial(pi, num_samples=1) + input_ids = (1 - finished) * input_ids + finished * eos_token_id + finished = (input_ids == eos_token_id).long() + + samples = torch.hstack((samples, input_ids)) + attention_mask = torch.hstack((attention_mask, (input_ids != eos_token_id).long())) + position_ids = (position_ids[:, -1] + 1).view(-1, 1) + + if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished): + break + + return samples + + def sync_target_q_heads(self): + self.ilql_heads.sync_target_q_heads() + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. We add the state dictionary of the ilql heads + to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`. + """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs) + for k, v in ilql_heads_state_dict.items(): + base_model_state_dict[f"ilql_heads.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + We add the state dictionary of the ilql heads to the state dictionary of the wrapped model + by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "ilql_heads." 
in k: + state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k) + self.ilql_heads.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + + +class AutoModelForSeq2SeqLMWithILQLHeads(PreTrainedModelWrapper): + """This is a wrapper around huggingface AutoModelForSeq2Seq with two additional scalar heads""" + + _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM + _supported_modules = ["ilql_heads"] + _supported_args = ["two_qs", "alpha"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + two_qs: bool = True, + alpha: float = 0.99, + ): + super().__init__(base_model) + hidden_size = hf_get_hidden_size(self.base_model.config) + vocab_size = self.base_model.config.vocab_size + dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype + self.two_qs = two_qs + self.alpha = alpha + self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype) + + def sync_target_q_heads(self): + self.ilql_heads.sync_target_q_heads() + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. We add the state dictionary of the ilql heads + to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`. + """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs) + for k, v in ilql_heads_state_dict.items(): + base_model_state_dict[f"ilql_heads.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + We add the state dictionary of the ilql heads to the state dictionary of the wrapped model + by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "ilql_heads." in k: + state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k) + self.ilql_heads.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + + def forward( + self, + input_ids, + attention_mask=None, + decoder_input_ids=None, + past_key_values=None, + encoder_outputs=None, + actions_ixs=None, + states_ixs=None, + output_attentions=True, + output_hidden_states=True, + ): + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + past_key_values=past_key_values, + encoder_outputs=encoder_outputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + out = self.base_model(**forward_kwargs) + + hs = out.decoder_hidden_states[-1] + + logits = self.base_model.lm_head(hs) + qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs) + encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions) + return logits, qs, target_qs, vs, out.past_key_values, encoder_outputs + + def generate( + self, + input_ids, + attention_mask=None, + decoder_input_ids=None, + past_key_values=None, + encoder_outputs=None, + beta=1, + max_new_tokens=32, + max_length=1024, + temperature=1, + top_k=20, + logit_mask=None, + pad_token_id=None, + eos_token_id=None, + ): + """ + Generates samples akin to hf's `.generate` but with custom logp prepossessing: + changing token probabilities as to how advantageous they would be + according to value functions estimations. 
+ """ + + if eos_token_id is None or pad_token_id is None: + raise ValueError("eos_token_id and pad_token_id must be provided") + + if attention_mask is None: + attention_mask = input_ids.not_equal(pad_token_id) + + samples = input_ids.clone() + max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) + if decoder_input_ids is None: + decoder_input_ids = input_ids.new_zeros(input_ids.shape[0], 1) + + finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) + for _ in range(max_new_tokens): + out = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids[:, -1].unsqueeze(-1), + past_key_values=past_key_values, + encoder_outputs=encoder_outputs, + ) + logits, _, target_qs, vs, past_key_values, encoder_outputs = out + if self.two_qs: + qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) + else: + qs = target_qs[:, -1, :] + + logits = logits[:, -1, :] + vs = vs[:, -1, :] + adv = qs - vs + pi_beta = F.log_softmax(logits, -1) + pi_top_k = topk_mask(pi_beta + beta * adv, top_k) + pi = F.softmax(pi_top_k / temperature, -1) + next_tokens = torch.multinomial(pi, num_samples=1) + next_tokens = (1 - finished) * next_tokens + finished * eos_token_id + finished = (next_tokens == eos_token_id).long() | (next_tokens == pad_token_id).long() + decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + samples = decoder_input_ids + if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") != "3" and torch.all(finished): + break + + return samples diff --git a/examples/BeautifulPrompt/trlx/models/modeling_nemo_ilql.py b/examples/BeautifulPrompt/trlx/models/modeling_nemo_ilql.py new file mode 100644 index 0000000..31ac49a --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/modeling_nemo_ilql.py @@ -0,0 +1,786 @@ +# Extensible version of the GPT model +import sys +from copy import deepcopy +from functools import partial, reduce +from math import sqrt +from pathlib import Path +from typing import List, Mapping, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +from apex.transformer import parallel_state, tensor_parallel +from apex.transformer.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) +from einops import rearrange +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingBatchSampler, +) +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import ( + post_language_model_processing, +) +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import ( + MegatronGPTModel, +) +from nemo.collections.nlp.modules.common.megatron.module import ( + Float16Module, + MegatronModule, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import ( + LengthParam, + OutputType, + SamplingParam, +) +from nemo.collections.nlp.parts.utils_funcs import get_last_rank + +from trlx.data.ilql_types import ILQLBatch, unflatten_dataclass +from trlx.models.modeling_ilql import ILQLConfig, batched_index_select +from trlx.utils import to_device, tree_map + + +class ParallelLinear(nn.Module): + """Linear layer parallelized over the longer dimension.""" + + def __init__( + self, + in_size: int, + out_size: int, + init_method=partial(nn.init.kaiming_uniform_, 
a=sqrt(5), nonlinearity="relu"), + use_cpu_initialization=False, + bias=True, + sequence_parallel=False, + gradient_accumulation_fusion=False, + gather_output=True, + input_is_parallel=False, + ): + super().__init__() + + no_async_tensor_model_parallel_allreduce = ( + parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel + ) + + if in_size < out_size: + self.layer = tensor_parallel.ColumnParallelLinear( + in_size, + out_size, + gather_output=gather_output, + init_method=init_method, + skip_bias_add=False, + use_cpu_initialization=use_cpu_initialization, + bias=bias, + sequence_parallel_enabled=sequence_parallel, + no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + gradient_accumulation_fusion=gradient_accumulation_fusion, + ) + else: + self.layer = tensor_parallel.RowParallelLinear( + in_size, + out_size, + input_is_parallel=input_is_parallel, + init_method=init_method, + skip_bias_add=False, + use_cpu_initialization=use_cpu_initialization, + bias=bias, + sequence_parallel_enabled=sequence_parallel, + gradient_accumulation_fusion=gradient_accumulation_fusion, + ) + + def forward(self, x): + output, bias = self.layer(x) + if bias is not None: + return output + bias + return output + + +def make_parallel_head(n_embd: int, out: int, sequence_parallel=False) -> nn.Sequential: + """Returns a generic sequential model parallel MLP head.""" + parallel_intermediate = out < (n_embd * 2) + return nn.Sequential( + ParallelLinear( + n_embd, + n_embd * 2, + sequence_parallel=sequence_parallel, + gather_output=not parallel_intermediate, + ), + nn.ReLU(), + ParallelLinear( + n_embd * 2, + out, + sequence_parallel=sequence_parallel, + input_is_parallel=parallel_intermediate, + ), + ) + + +class ParallelILQLHeads(nn.Module): + def __init__( + self, + config: ILQLConfig, + hidden_size: int, + vocab_size: int, + sequence_parallel=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.v_head = make_parallel_head(hidden_size, 1, sequence_parallel=sequence_parallel) + self.config = config + + n_qs = 2 if self.config.two_qs else 1 + + self.q_heads = nn.ModuleList(make_parallel_head(self.hidden_size, self.vocab_size) for _ in range(n_qs)) + + self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) + self.target_q_heads.requires_grad_(False) + + def forward(self, hidden_states): + qs = tuple(q_head(hidden_states) for q_head in self.q_heads) + target_qs = tuple(q_head(hidden_states) for q_head in self.target_q_heads) + vs = self.v_head(hidden_states) + + qs, target_qs, vs = tree_map(lambda t: rearrange(t, "T N ... 
-> N T ..."), (qs, target_qs, vs)) + + return qs, target_qs, vs + + def _sync_target_q_heads(self, alpha: float): + for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): + for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()): + target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data) + + def sync_target_q_heads(self): + self._sync_target_q_heads(self.config.alpha) + + +class LMHeads(MegatronModule): + def __init__(self, language_model, other_heads): + super().__init__() + # must be this attribute name + self.pre_process = language_model.pre_process + self.post_process = language_model.post_process + self.language_model = language_model + + self.other_heads = other_heads + + if hasattr(language_model, "word_embeddings"): + self.word_embeddings = language_model.word_embeddings + + # The tensor from the previous pipeline rank arrives via this method + def set_input_tensor(self, input_tensor): + return self.language_model.set_input_tensor(input_tensor) + + def word_embeddings_weight(self): + return self.language_model.word_embeddings_weight() + + def load_state_dict(self, lm_state_dict, strict=True): + """Load GPTModel state dict.""" + self.language_model.language_model.load_state_dict(lm_state_dict, strict=strict) + + def forward( + self, + *args, + get_key_value=False, + forward_method_parallel_output=None, + **kwargs, + ): + lm_output = self.language_model(*args, get_key_value=get_key_value, **kwargs) + logits = post_language_model_processing( + lm_output, + labels=None, + logit_weights=self.language_model.word_embeddings_weight(), + get_key_value=get_key_value, + parallel_output=False, # self.language_model.parallel_output, + forward_method_parallel_output=forward_method_parallel_output, + fp16_lm_cross_entropy=self.language_model.fp16_lm_cross_entropy, + return_logits=True, + sequence_parallel=self.language_model.sequence_parallel, + gradient_accumulation_fusion=self.language_model.gradient_accumulation_fusion, + ) + + if get_key_value: + logits, presents = logits + lm_output, lm_output_presents = lm_output + + heads_output = self.other_heads(lm_output) + return logits, heads_output + + +def unwrap_float16_module(module): + if isinstance(module, Float16Module): + return module.module + return module + + +def reshard_for_pipeline_parallelism(num_layers, state_dict): + """Filter out the layers that are not in the current pipeline stage + and shift the layer ids to match the local stage layer ids.""" + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + + stage_layers = num_layers // pp_size + pp_offset = pp_rank * stage_layers + + encoder_layers_key = "model.language_model.encoder.layers." 
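+    # Note: this assumes `num_layers` divides evenly by the pipeline-parallel world
+    # size, so each stage owns the contiguous block of `stage_layers` layers starting
+    # at `pp_offset`; layers outside that block are dropped below and the remaining
+    # layer indices are shifted to start from 0 on every stage.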
+ + def filter_in_pp_rank(key): + if key.startswith(encoder_layers_key): + layer_idx = int(key.split(".")[4]) + return pp_offset <= layer_idx < (pp_offset + stage_layers) + elif key.startswith("model.language_model.encoder.final_layernorm") and not pp_rank == (pp_size - 1): + return False + else: + return True + + def shift_layer_idx(key): + """If the key is for a transformer layer, shift down the layer index to select the + correct layer for this pipeline stage.""" + if key.startswith(encoder_layers_key): + layer_idx = int(key.split(".")[4]) + return f"{encoder_layers_key}{str(layer_idx - pp_offset)}.{'.'.join(key.split('.')[5:])}" + else: + return key + + state_dict = {shift_layer_idx(k): v for k, v in state_dict.items() if filter_in_pp_rank(k)} + + return state_dict + + +class ILQLGPT(MegatronGPTModel): + ilql_config: ILQLConfig + + def __init__(self, ilql_config, metric_fn=None, **kwargs): + self.ilql_config = ilql_config + self.metric_fn = metric_fn + super().__init__(**kwargs) + if len(list(self.parameters())) == 0: + raise ValueError("No parameters in model") + + self._ori_activations_checkpoint_granularity = self.cfg.get("activations_checkpoint_granularity", None) + self._ori_activations_checkpoint_method = self.cfg.get("activations_checkpoint_method", None) + self._ori_activations_checkpoint_num_layers = self.cfg.get("activations_checkpoint_num_layers", None) + + @classmethod + def list_available_models(cls) -> Optional[Mapping[str, str]]: + return None + + def build_train_valid_test_datasets(self): + pass + + def build_data_loader(self, dataset, collate_fn, consumed_samples=0): + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + print( + f"Building data loader for {type(dataset)=} {len(dataset)=} {dp_rank=} {dp_size=}", + file=sys.stderr, + ) + batch_sampler = MegatronPretrainingBatchSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + global_batch_size=self.cfg.global_batch_size, + data_parallel_rank=dp_rank, + data_parallel_size=dp_size, + drop_last=True, + ) + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + # For some reason this causes a crash when using >0 workers + # with grad accumulation > 1 + num_workers=0, + pin_memory=True, + collate_fn=collate_fn, + ) + + def set_train_dataset(self, train_dataset, collate_fn): + self._train_dataset = train_dataset + self._train_collate_fn = collate_fn + + def set_valid_dataset(self, valid_dataset, collate_fn): + self._valid_dataset = valid_dataset + self._valid_collate_fn = collate_fn + + # Called by superclass to build data loaders + def setup_training_data(self, _): + if hasattr(self, "_train_dataset"): + self._train_dl = self.build_data_loader(self._train_dataset, self._train_collate_fn) + + def setup_validation_data(self, _): + if hasattr(self, "_valid_dataset"): + self._validation_dl = self.build_data_loader(self._valid_dataset, self._valid_collate_fn) + + def load_from_pretrained(self, checkpoint_dir): + mp_rank = parallel_state.get_tensor_model_parallel_rank() + rank_subfolder = f"mp_rank_{mp_rank:02d}" + rank_params = Path(checkpoint_dir) / rank_subfolder / "model_weights.ckpt" + print(f"Loading from {rank_params}") + state_dict = torch.load(rank_params) + + state_dict = reshard_for_pipeline_parallelism(self.cfg.num_layers, state_dict) + + def trim_key(key, prefix): + assert key.startswith(prefix), f"key {key} in state_dict does not start with {prefix}" + return 
key[len(prefix) :] + + lm_state_dict = {trim_key(k, "model.language_model."): v for k, v in state_dict.items()} + + encoder_state_dict = {trim_key(k, "encoder."): v for k, v in lm_state_dict.items() if k.startswith("encoder.")} + + lm_state_dict = {**lm_state_dict, "encoder": encoder_state_dict} + + unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True) + print(f"Loaded from pretrained {rank_params}") + + def model_provider_func(self, pre_process: bool, post_process: bool): + """ + Model construction for Apex Pipeline Parallelism. + Each rank will construct the model but inside the model, + only the relevant layers for that rank should be constructed. + On the first rank, pre_process will be True + On the last rank, post_process will be True + """ + gpt = super().model_provider_func(pre_process, post_process=post_process) + # This disables post-processing the lm output to the vocab + gpt.post_process = False + # This enables the final layernorm in the GPT model if there is one + gpt.language_model.post_process = post_process + # If running on the last pipeline stage, add the ILQL heads + if post_process: + parallel_ilql_heads = ParallelILQLHeads( + self.ilql_config, + self.cfg.hidden_size, + self.padded_vocab_size, + self.cfg.sequence_parallel, + ) + + return LMHeads( + gpt, + parallel_ilql_heads, + ) + else: + return gpt + + # Adapted from NeMo + # https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259 + def training_step(self, batch: ILQLBatch, batch_idx: int): # noqa: C901 + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + # we zero grads here because we also call backward in the apex fwd/bwd functions + self._optimizer.zero_grad() + + if parallel_state.is_pipeline_first_stage(ignore_virtual=True) or parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + # we prepare the micro batches for the apex fwd/bwd function + batch_for_pipeline = batch + else: + # The intermediate pipeline stages do not need any inputs from data loader + # GPT3 uses decoder with AttnMask:causal, thus doesn't need attention_mask + batch_for_pipeline = None + + # Pipeline stages will transfer this shape tensor to and from the + # previous and next stages + # The model must output a tensor of this shape if not the last pipeline + # stage. 
The model is given input of this shape if not the first pipeline + # stage via .set_input_tensor + tensor_shape = [ + self.cfg.encoder_seq_length, + self.cfg.micro_batch_size, + self.cfg.hidden_size, + ] + + # handle asynchronous grad reduction + if self.with_distributed_adam: + if self.megatron_amp_o2: + # copy grads to main grad + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=True) + + else: + # keep grad tensors around + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=False) + + else: + if self.megatron_amp_o2 and not self.cfg.get("sequence_parallel", False): + custom_sync_context_handler = self._optimizer.no_sync + else: + # TODO: enable async grad all reduce for O1/autocast mixed precision training + custom_sync_context_handler = None + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + # This gets the correct fwd/bwd pipeline step depending on the pipeline + # parallelism configuration + fwd_bwd_function = self._get_fwd_bwd_function() + + last_stage_output = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + batch=batch_for_pipeline, + model=self.model, + forward_only=False, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + custom_sync_context_handler=custom_sync_context_handler, + sequence_parallel_enabled=self.cfg.get("sequence_parallel", False), + sync_batch_comm=self.cfg.get("sync_batch_comm", False), + num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( + "num_micro_batches_with_partial_activation_checkpoints", None + ), + ) + + # only the last stages of the pipeline return losses + if last_stage_output: + # average loss across micro batches + outputs = {k: [output[k] for output in last_stage_output] for k in last_stage_output[0].keys()} + outputs = {k: torch.concat([torch.as_tensor(vi).unsqueeze(0) for vi in v]) for k, v in outputs.items()} + + mean_outputs = {k: v.mean() for k, v in outputs.items()} + loss_mean = mean_outputs["avg_loss"] + else: + mean_outputs = {} + loss_mean = torch.tensor(0.0).cuda() + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get("tensor_model_parallel_size", 1) > 1 and self.cfg.get("sequence_parallel", False): + self.allreduce_sequence_parallel_gradients() + if self.with_distributed_adam: + # launch grad reductions + # Note: grads in first pipeline stage have already been + # reduced + if not parallel_state.is_pipeline_first_stage(): + self.reduce_overlap_gradients() + elif self.megatron_amp_o2: + # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + if self.cfg.get("pipeline_model_parallel_size", 1) > 1 or self.cfg.get("sequence_parallel", False): + # main grads are stored in the MainParamsOptimizer wrapper + self._optimizer.allreduce_main_grads() + else: + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.get("pipeline_model_parallel_size", 1) > 1: + # when using pipeline parallelism the first and last stage must keep embeddings in sync + self.allreduce_first_last_embeddings() + + # we can only log on one rank if it is rank zero so we 
broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log("loss_scale", loss_scale) + + self.log( + "reduced_train_loss", + loss_mean, + prog_bar=True, + rank_zero_only=True, + ) + + for k, v in mean_outputs.items(): + if k != "avg_loss": + self.log(k, v) + + self.log( + "global_step", + float(self.trainer.global_step), + prog_bar=True, + rank_zero_only=True, + ) + + if self.trainer.global_step % self.ilql_config.steps_for_target_q_sync == 0 and self.trainer.global_step > 0: + if parallel_state.is_pipeline_last_stage(): + unwrap_float16_module(self.model).other_heads.sync_target_q_heads() + + return loss_mean + + def activation_checkpointing_(self, enable: bool): + def toggle_checkpointing(module): + if hasattr(module, "activations_checkpoint_granularity"): + if enable: + module.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + else: + module.activations_checkpoint_granularity = None + + if hasattr(module, "activations_checkpoint_method"): + if enable: + module.activations_checkpoint_method = self._ori_activations_checkpoint_method + else: + module.activations_checkpoint_method = None + + if hasattr(module, "activations_checkpoint_num_layers"): + if enable: + module.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + module.activations_checkpoint_num_layers = None + + self.model.apply(toggle_checkpointing) + + if enable: + self.cfg.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + self.cfg.activations_checkpoint_method = self._ori_activations_checkpoint_method + self.cfg.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + self.cfg.activations_checkpoint_granularity = None + self.cfg.activations_checkpoint_method = None + self.cfg.activations_checkpoint_num_layers = None + + # TODO: replace this with less magical code + def sequence_parallel_(self, enabled: bool): + self.cfg.sequence_parallel = enabled + + def toggle_sp(m): + if hasattr(m, "sequence_parallel"): + m.sequence_parallel = enabled + + # for the Row/ColumnParallelLinear layers + if hasattr(m, "sequence_parallel_enabled"): + if hasattr(m, "input_is_parallel"): + m.sequence_parallel_enabled = enabled and m.input_is_parallel + elif hasattr(m, "gather_output"): + m.sequence_parallel_enabled = enabled and not m.gather_output + else: + m.sequence_parallel_enabled = enabled + + self.model.apply(toggle_sp) + + def validation_step(self, batch: Tuple[List[int], List[int]], batch_idx: int): + if self.metric_fn is None: + raise ValueError("Must set metric_fn to use validation") + + sp_was_enabled = self.cfg.get("sequence_parallel", False) + if sp_was_enabled: + self.sequence_parallel_(False) + + activations_checkpointing_was_enabled = self.cfg.get("activations_checkpoint_granularity", None) is not None + + if activations_checkpointing_was_enabled: + self.activation_checkpointing_(False) + + input_ids, lengths = batch + input_ids, lengths = torch.as_tensor(input_ids), torch.as_tensor(lengths) + + input_ids, lengths = to_device((input_ids, lengths), torch.cuda.current_device(), non_blocking=True) + + max_new_tokens = self.ilql_config.gen_kwargs.get("max_new_tokens", 64) + + gen = self.generate((input_ids, lengths), dict(max_length=max_new_tokens, 
min_length=0)) + + metrics = self.metric_fn(gen["sentences"]) + + metric_keys, metric_values = zip(*metrics.items()) + + columns = ["sentences", *metric_keys] + rows = list(zip(gen["sentences"], *metric_values)) + + avg_metrics = {f"avg_{k}": torch.as_tensor(v).mean() for k, v in metrics.items()} + + if activations_checkpointing_was_enabled: + self.activation_checkpointing_(True) + + if sp_was_enabled: + self.sequence_parallel_(True) + + # NeMo generate resets the microbatch calculator + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + ) + from nemo.utils import AppState + + _reconfigure_microbatch_calculator( + rank=AppState().global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.global_batch_size, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_size=AppState().data_parallel_size, + ) + + return avg_metrics, (rows, columns) + + def validation_epoch_end(self, outputs: List[Tuple[dict, Tuple[List[str], List[str]]]]): + metrics, tables = zip(*outputs) + _, columns = tables[0] + rows = [r for trows, _ in tables for r in trows] + + self.logger.log_text(key="samples", columns=columns, data=rows) + + outputs_soa = {k: torch.as_tensor([d[k] for d in metrics]) for k in metrics[0].keys()} + # this assumes all validation microbatches are the same size + avg_outputs = {k: v.mean() for k, v in outputs_soa.items()} + for k, v in avg_outputs.items(): + self.log( + f"val_metrics/{k}", + v, + prog_bar=True, + rank_zero_only=True, + sync_dist=True, + ) + + # Need to override this otherwise distributed fused adam won't work + # with frozen layers + def parameters(self): + return (p for p in self.model.parameters() if p.requires_grad) + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(batch: List[torch.Tensor], model, checkpoint_activations_all_layers=None): + # On first and last pipeline stages, the input data is passed in + if batch is not None: + batch = unflatten_dataclass(ILQLBatch)(batch) + batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) + + inputs = batch.input_ids + pad_by = self.cfg.encoder_seq_length - inputs.shape[1] + inputs = torch.nn.functional.pad(inputs, (0, pad_by), value=self.tokenizer.eos_id) + + ( + attention_mask, + loss_mask, + position_ids, + ) = get_ltor_masks_and_position_ids( + data=inputs, + eod_token=self.tokenizer.eos_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + + model_output = model( + input_ids=inputs, + position_ids=position_ids.long(), + attention_mask=attention_mask, + ) + else: + # In-between stages are given data via the pipeline engine + # Still need to specify thes arguments to avoid errors + model_output = model(input_ids=None, position_ids=None, attention_mask=None) + + def gather_ntc(t: torch.Tensor): + """Gather sequence parallel tensor [batch, seq, hidden]""" + t = rearrange(t, "N T ... -> T N ...") + t = gather_from_sequence_parallel_region(t, to_model_parallel=False) + t = rearrange(t, "T N ... 
-> N T ...") + return t + + def loss_func(model_output): + # # TODO: implement this in a sequence parallel way + logits, (qs, target_qs, vs) = model_output + + if self.cfg.sequence_parallel: + qs, target_qs, vs = tree_map(gather_ntc, (qs, target_qs, vs)) + + qs = tree_map( + lambda t: batched_index_select(t, batch.actions_ixs, 1), + qs, + ) + + target_qs = tree_map( + lambda t: batched_index_select(t, batch.actions_ixs, 1), + target_qs, + ) + + vs = batched_index_select(vs, batch.states_ixs, 1) + + model_output = (logits, (qs, target_qs, vs)) + loss_for_mb, stats = self.ilql_config.loss(model_output, batch) + + reduced_loss = average_losses_across_data_parallel_group([loss_for_mb]) + + # TODO: figure out why this sync is needed (crashes otherwise) + torch.cuda.synchronize() + + return loss_for_mb, {"avg_loss": reduced_loss, **stats} + + return model_output, loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func( + self, + set_inference_key_value_memory=False, + inference_max_sequence_len=None, + checkpoint_activations_all_layers=None, + ): + def fwd_output_only_func( + batch: torch.Tensor, + model, + ): + if batch is not None: + batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) + + extra_arg = {} + + if len(batch) == 3: + tokens, attention_mask, position_ids = batch + else: + ( + tokens, + attention_mask, + position_ids, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + + extra_arg["set_inference_key_value_memory"] = set_inference_key_value_memory[0].item() + extra_arg["inference_max_sequence_len"] = inference_max_sequence_len[0].item() + + model_output = model( + input_ids=tokens, + position_ids=position_ids.long(), + attention_mask=attention_mask, + **extra_arg, + ) + else: + model_output = model(input_ids=None, position_ids=None, attention_mask=None) + + def ilql_postprocess(model_output): + model_output = tree_map(lambda t: t.float(), model_output) + + logits, (_, target_qs, vs) = model_output + + target_q = reduce(torch.minimum, target_qs) + advantage = target_q - vs + pi_beta = F.log_softmax(logits, -1) + beta = self.ilql_config.gen_kwargs.get("beta", 1.0) + + logits = pi_beta + beta * advantage + + return logits, {"logits": logits} + + return model_output, ilql_postprocess + + return fwd_output_only_func + + def generate( + self, + inputs: Union[List[str], torch.Tensor, List[dict]], + length_params: LengthParam, + sampling_params: SamplingParam = None, + ) -> OutputType: + if sampling_params is None: + sampling_params = { + "use_greedy": False, + "temperature": self.ilql_config.gen_kwargs.get("temperature", 1.0), + "top_k": self.ilql_config.gen_kwargs.get("top_k", 0), + "top_p": 0.9, + "repetition_penalty": 1.2, + "add_BOS": False, + "all_probs": False, + "compute_logprob": False, + } + + return super().generate(inputs, length_params, sampling_params) diff --git a/examples/BeautifulPrompt/trlx/models/modeling_nemo_sft.py b/examples/BeautifulPrompt/trlx/models/modeling_nemo_sft.py new file mode 100644 index 0000000..ac697ba --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/modeling_nemo_sft.py @@ -0,0 +1,513 @@ +# Extensible version of the GPT model +import logging +import sys +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.distributed +from apex.transformer import tensor_parallel +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingBatchSampler, +) +from 
nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
+    MegatronGPTModel,
+)
+from nemo.collections.nlp.modules.common.megatron.utils import (
+    average_losses_across_data_parallel_group,
+)
+from nemo.collections.nlp.modules.common.transformer.text_generation import (
+    LengthParam,
+    OutputType,
+    SamplingParam,
+)
+from nemo.collections.nlp.parts.utils_funcs import get_last_rank
+
+from trlx.models.modeling_nemo_ilql import (
+    reshard_for_pipeline_parallelism,
+    unwrap_float16_module,
+)
+from trlx.trainer.accelerate_sft_trainer import SFTConfig
+from trlx.utils import to_device
+
+try:
+    from apex.transformer import parallel_state
+
+    HAVE_APEX = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_APEX = False
+
+
+class SFTGPT(MegatronGPTModel):
+    sft_config: SFTConfig
+
+    def __init__(self, sft_config: SFTConfig, metric_fn: Optional[Callable[[List[str]], Any]] = None, **kwargs):
+        self.sft_config = sft_config
+        self.metric_fn = metric_fn
+        super().__init__(**kwargs)
+        if len(list(self.parameters())) == 0:
+            raise ValueError("No parameters in model")
+
+        self._ori_activations_checkpoint_granularity = self.cfg.get("activations_checkpoint_granularity", None)
+        self._ori_activations_checkpoint_method = self.cfg.get("activations_checkpoint_method", None)
+        self._ori_activations_checkpoint_num_layers = self.cfg.get("activations_checkpoint_num_layers", None)
+
+    def build_train_valid_test_datasets(self):
+        pass
+
+    def build_data_loader(self, dataset, collate_fn, consumed_samples=0):
+        dp_rank = parallel_state.get_data_parallel_rank()
+        dp_size = parallel_state.get_data_parallel_world_size()
+        # `logging.info` does not accept a `file` argument, so log the message directly
+        logging.info(
+            f"Building data loader for {type(dataset)=} {len(dataset)=} {dp_rank=} {dp_size=}",
+        )
+        batch_sampler = MegatronPretrainingBatchSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=self.cfg.micro_batch_size,
+            global_batch_size=self.cfg.global_batch_size,
+            data_parallel_rank=dp_rank,
+            data_parallel_size=dp_size,
+            drop_last=True,
+        )
+        return torch.utils.data.DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            # For some reason this causes a crash when using >0 workers
+            # with grad accumulation > 1
+            num_workers=0,
+            pin_memory=True,
+            collate_fn=collate_fn,
+        )
+
+    def set_train_dataset(self, train_dataset, collate_fn: Optional[callable] = None):
+        self._train_dataset = train_dataset
+        self._train_collate_fn = collate_fn
+
+    def set_valid_dataset(self, valid_dataset, collate_fn: Optional[callable] = None):
+        self._valid_dataset = valid_dataset
+        self._valid_collate_fn = collate_fn
+
+    def setup_training_data(self, _):
+        if hasattr(self, "_train_dataset"):
+            self._train_dl = self.build_data_loader(self._train_dataset, self._train_collate_fn)
+
+    def setup_validation_data(self, _):
+        if hasattr(self, "_valid_dataset"):
+            self._validation_dl = self.build_data_loader(self._valid_dataset, self._valid_collate_fn)
+
+    def load_from_pretrained(self, checkpoint_dir):
+        mp_rank = parallel_state.get_tensor_model_parallel_rank()
+        checkpoint_path = Path(checkpoint_dir)
+
+        # Check if there are rank subfolders
+        rank_subfolder = f"mp_rank_{mp_rank:02d}"
+        rank_params = checkpoint_path / rank_subfolder / "model_weights.ckpt"
+
+        print(f"Loading from {rank_params}")
+        state_dict = torch.load(rank_params)
+
+        state_dict = reshard_for_pipeline_parallelism(self.cfg.num_layers, state_dict)
+
+        def trim_key(key, prefix):
+            assert key.startswith(prefix), f"key {key} in state_dict does not start with {prefix}"
+            
return key[len(prefix) :] + + lm_state_dict = {trim_key(k, "model.language_model."): v for k, v in state_dict.items()} + + encoder_state_dict = {trim_key(k, "encoder."): v for k, v in lm_state_dict.items() if k.startswith("encoder.")} + + lm_state_dict = {**lm_state_dict, "encoder": encoder_state_dict} + + unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True) + print(f"Loaded from pretrained {rank_params}") + + # Adapted from NeMo + # https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259 + def training_step(self, batch: List[torch.Tensor], batch_idx: int): # noqa: C901 + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + # we zero grads here because we also call backward in the apex fwd/bwd functions + self._optimizer.zero_grad() + + if parallel_state.is_pipeline_first_stage(ignore_virtual=True) or parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + # we prepare the micro batches for the apex fwd/bwd function + batch_for_pipeline = batch + else: + # The intermediate pipeline stages do not need any inputs from data loader + # GPT3 uses decoder with AttnMask:causal, thus doesn't need attention_mask + batch_for_pipeline = None + + # Pipeline stages will transfer this shape tensor to and from the + # previous and next stages + # The model must output a tensor of this shape if not the last pipeline + # stage. 
The model is given input of this shape if not the first pipeline + # stage via .set_input_tensor + tensor_shape = [ + self.cfg.encoder_seq_length, + self.cfg.micro_batch_size, + self.cfg.hidden_size, + ] + + # handle asynchronous grad reduction + if self.with_distributed_adam: + if self.megatron_amp_o2: + # copy grads to main grad + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=True) + + else: + # keep grad tensors around + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=False) + + else: + if self.megatron_amp_o2 and not self.cfg.get("sequence_parallel", False): + custom_sync_context_handler = self._optimizer.no_sync + else: + # TODO: enable async grad all reduce for O1/autocast mixed precision training + custom_sync_context_handler = None + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + # This gets the correct fwd/bwd pipeline step depending on the pipeline + # parallelism configuration + fwd_bwd_function = self._get_fwd_bwd_function() + + last_stage_output = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + batch=batch_for_pipeline, + model=self.model, + forward_only=False, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + custom_sync_context_handler=custom_sync_context_handler, + sequence_parallel_enabled=self.cfg.get("sequence_parallel", False), + sync_batch_comm=self.cfg.get("sync_batch_comm", False), + num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( + "num_micro_batches_with_partial_activation_checkpoints", None + ), + ) + + # only the last stages of the pipeline return losses + if last_stage_output: + # average loss across micro batches + outputs = {k: [output[k] for output in last_stage_output] for k in last_stage_output[0].keys()} + outputs = {k: torch.concat([torch.as_tensor(vi).unsqueeze(0) for vi in v]) for k, v in outputs.items()} + + mean_outputs = {k: v.mean() for k, v in outputs.items()} + loss_mean = mean_outputs["avg_loss"] + else: + mean_outputs = {} + loss_mean = torch.tensor(0.0).cuda() + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get("tensor_model_parallel_size", 1) > 1 and self.cfg.get("sequence_parallel", False): + self.allreduce_sequence_parallel_gradients() + if self.with_distributed_adam: + # launch grad reductions + # Note: grads in first pipeline stage have already been + # reduced + if not parallel_state.is_pipeline_first_stage(): + self.reduce_overlap_gradients() + elif self.megatron_amp_o2: + # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + if self.cfg.get("pipeline_model_parallel_size", 1) > 1 or self.cfg.get("sequence_parallel", False): + # main grads are stored in the MainParamsOptimizer wrapper + self._optimizer.allreduce_main_grads() + else: + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.get("pipeline_model_parallel_size", 1) > 1: + # when using pipeline parallelism the first and last stage must keep embeddings in sync + self.allreduce_first_last_embeddings() + + # we can only log on one rank if it is rank zero so we 
broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log("loss_scale", loss_scale) + + self.log( + "reduced_train_loss", + loss_mean, + prog_bar=True, + rank_zero_only=True, + ) + + for k, v in mean_outputs.items(): + if k != "avg_loss": + self.log(k, v) + + self.log( + "global_step", + float(self.trainer.global_step), + prog_bar=True, + rank_zero_only=True, + ) + return loss_mean + + def activation_checkpointing_(self, enable: bool): + def toggle_checkpointing(module): + if hasattr(module, "activations_checkpoint_granularity"): + if enable: + module.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + else: + module.activations_checkpoint_granularity = None + + if hasattr(module, "activations_checkpoint_method"): + if enable: + module.activations_checkpoint_method = self._ori_activations_checkpoint_method + else: + module.activations_checkpoint_method = None + + if hasattr(module, "activations_checkpoint_num_layers"): + if enable: + module.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + module.activations_checkpoint_num_layers = None + + self.model.apply(toggle_checkpointing) + + if enable: + self.cfg.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + self.cfg.activations_checkpoint_method = self._ori_activations_checkpoint_method + self.cfg.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + self.cfg.activations_checkpoint_granularity = None + self.cfg.activations_checkpoint_method = None + self.cfg.activations_checkpoint_num_layers = None + + # TODO: replace this with less magical code + def sequence_parallel_(self, enabled: bool): + self.cfg.sequence_parallel = enabled + + def toggle_sp(m): + if hasattr(m, "sequence_parallel"): + m.sequence_parallel = enabled + + # for the Row/ColumnParallelLinear layers + if hasattr(m, "sequence_parallel_enabled"): + if hasattr(m, "input_is_parallel"): + m.sequence_parallel_enabled = enabled and m.input_is_parallel + elif hasattr(m, "gather_output"): + m.sequence_parallel_enabled = enabled and not m.gather_output + else: + m.sequence_parallel_enabled = enabled + + self.model.apply(toggle_sp) + + def validation_step(self, batch: Tuple[List[int], List[int]], batch_idx: int): + if self.metric_fn is None: + raise ValueError("Must set metric_fn to use validation") + + sp_was_enabled = self.cfg.get("sequence_parallel", False) + if sp_was_enabled: + self.sequence_parallel_(False) + + activations_checkpointing_was_enabled = self.cfg.get("activations_checkpoint_granularity", None) is not None + + if activations_checkpointing_was_enabled: + self.activation_checkpointing_(False) + + input_ids, lengths = batch + input_ids, lengths = torch.as_tensor(input_ids), torch.as_tensor(lengths) + + input_ids, lengths = to_device((input_ids, lengths), torch.cuda.current_device(), non_blocking=True) + + max_new_tokens = self.sft_config.gen_kwargs.get("max_new_tokens", 64) + + gen = self.generate((input_ids, lengths), dict(max_length=max_new_tokens, min_length=0)) + print(f"Generated {len(gen['sentences'])} samples:\n{gen['sentences']}") + + metrics = self.metric_fn(gen["sentences"]) + + metric_keys, metric_values = zip(*metrics.items()) + + columns = ["sentences", *metric_keys] + rows = 
list(zip(gen["sentences"], *metric_values)) + + avg_metrics = {f"avg_{k}": torch.as_tensor(v).mean() for k, v in metrics.items()} + + if activations_checkpointing_was_enabled: + self.activation_checkpointing_(True) + + if sp_was_enabled: + self.sequence_parallel_(True) + + # NeMo generate resets the microbatch calculator + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + ) + from nemo.utils import AppState + + _reconfigure_microbatch_calculator( + rank=AppState().global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.global_batch_size, + micro_batch_size=self.cfg.micro_batch_size, + data_parallel_size=AppState().data_parallel_size, + ) + + return avg_metrics, (rows, columns) + + def validation_epoch_end(self, outputs: List[Tuple[dict, Tuple[List[str], List[str]]]]): + metrics, tables = zip(*outputs) + _, columns = tables[0] + rows = [r for trows, _ in tables for r in trows] + + self.logger.log_text(key="samples", columns=columns, data=rows) + + outputs_soa = {k: torch.as_tensor([d[k] for d in metrics]) for k in metrics[0].keys()} + # this assumes all validation microbatches are the same size + avg_outputs = {k: v.mean() for k, v in outputs_soa.items()} + for k, v in avg_outputs.items(): + self.log( + f"val_metrics/{k}", + v, + prog_bar=True, + rank_zero_only=True, + sync_dist=True, + ) + + # Need to override this otherwise distributed fused adam won't work + # with frozen layers + def parameters(self): + return (p for p in self.model.parameters() if p.requires_grad) + + def build_attention_mask_and_position_ids( + self, data: torch.LongTensor + ) -> Tuple[torch.BoolTensor, torch.LongTensor]: + micro_batch_size, seq_length = data.size() + + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).repeat(micro_batch_size, 1) + + attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=data.device)).view( + 1, 1, seq_length, seq_length + ) + attention_mask = attention_mask < 0.5 + return attention_mask, position_ids + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(batch: List[torch.Tensor], model, checkpoint_activations_all_layers=None): + # On first and last pipeline stages, the input data is passed in + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + input_ids, loss_mask = [b.cuda(non_blocking=True) for b in batch] + attention_mask, position_ids = self.build_attention_mask_and_position_ids(data=input_ids) + else: + input_ids, loss_mask, attention_mask, position_ids = None, None, None, None + + output_tensor = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + checkpoint_activations_all_layers=checkpoint_activations_all_layers, + ) + + def loss_func(output_tensor): + # Shift logits and labels to align predictions + logits = output_tensor[:, :-1, :] + labels = input_ids[:, 1:] + _loss_mask = loss_mask[:, 1:] # Align loss mask with labels + + labels = labels.transpose(0, 1).contiguous() # [b s] -> [s b] + logits = logits.transpose(0, 1).contiguous() # [b s h] -> [s b h] + + if self.cfg.fp16_lm_cross_entropy: + assert logits.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(logits.float(), labels) + + _loss_mask = _loss_mask.contiguous().view(-1).float() + loss = loss.transpose(0, 1).contiguous().view(-1).float() + loss_for_mb = torch.sum(loss 
* _loss_mask) / _loss_mask.sum() + + reduced_loss = average_losses_across_data_parallel_group([loss_for_mb]) + + # TODO: figure out why this sync is needed (crashes otherwise) + torch.cuda.synchronize() + + return loss_for_mb, {"avg_loss": reduced_loss} + + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func( + self, + set_inference_key_value_memory=False, + inference_max_sequence_len=None, + checkpoint_activations_all_layers=None, + ): + def fwd_output_only_func(batch, model): + if batch is not None: + batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) + + extra_arg = {} + + if len(batch) == 3: + tokens, attention_mask, position_ids = batch + else: + ( + tokens, + attention_mask, + position_ids, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + + extra_arg["set_inference_key_value_memory"] = set_inference_key_value_memory[0].item() + extra_arg["inference_max_sequence_len"] = inference_max_sequence_len[0].item() + + output_tensor = model( + input_ids=tokens, + position_ids=position_ids.long(), + attention_mask=attention_mask, + **extra_arg, + ) + else: + output_tensor = model(input_ids=None, position_ids=None, attention_mask=None) + + def id_func(output_tensor): + return output_tensor, {"logits": output_tensor} + + return output_tensor, id_func + + return fwd_output_only_func + + def generate( + self, + inputs: Union[List[str], torch.Tensor, List[dict]], + length_params: LengthParam, + sampling_params: SamplingParam = None, + ) -> OutputType: + if sampling_params is None: + sampling_params = { + "use_greedy": self.sft_config.gen_kwargs.get("use_greedy", False), + "temperature": self.sft_config.gen_kwargs.get("temperature", 1.0), + "top_k": self.sft_config.gen_kwargs.get("top_k", 0), + "top_p": self.sft_config.gen_kwargs.get("top_p", 0.9), + "repetition_penalty": self.sft_config.gen_kwargs.get("repetition_penalty", 1.2), + "add_BOS": False, + "all_probs": False, + "compute_logprob": False, + } + + return super().generate(inputs, length_params, sampling_params) diff --git a/examples/BeautifulPrompt/trlx/models/modeling_ppo.py b/examples/BeautifulPrompt/trlx/models/modeling_ppo.py new file mode 100644 index 0000000..8fda460 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/models/modeling_ppo.py @@ -0,0 +1,1293 @@ +import gc +import inspect +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import transformers +from torchtyping import TensorType +from transformers.modeling_outputs import ModelOutput +from transformers.models.bloom import modeling_bloom +from transformers.models.opt import modeling_opt + +from trlx.data.method_configs import MethodConfig, register_method +from trlx.models.modeling_base import PreTrainedModelWrapper +from trlx.utils.modeling import ( + flatten_dict, + get_tensor_stats, + hf_get_decoder, + hf_get_decoder_blocks, + hf_get_decoder_final_norm, + hf_get_hidden_size, + hf_get_lm_head, + hf_get_num_hidden_layers, + make_head, + whiten, +) + +# KL Controllers + + +class AdaptiveKLController: + """Adaptive KL Controller as described in Ziegler et al. 
"Fine-Tuning Language Models from Human Preferences" + Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 + Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py + """ + + def __init__(self, init_kl_coef: float, target: float, horizon: int): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current: float, n_steps: int): + """Returns adaptively updated KL coefficient, βₜ₊₁. + Arguments: + current: The current KL value between the newest policy and the initial policy. + """ + proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult # βₜ₊₁ + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current: float, n_steps: int): + """Returns updated KL coefficient, βₜ₊₁. + Arguments: + current: The current KL value between the newest policy and the initial policy. + """ + pass + + +# PPO Configs + + +@dataclass +@register_method +class PPOConfig(MethodConfig): + """ + Config for PPO method + + :param ppo_epochs: Number of updates per batch + :type ppo_epochs: int + + :param num_rollouts: Number of experiences to observe before learning + :type num_rollouts: int + + :param init_kl_coef: Initial value for KL coefficient + :type init_kl_coef: float + + :param target: Target value for KL coefficient + :type target: float + + :param horizon: Number of steps for KL coefficient to reach target + :type horizon: int + + :param gamma: Discount factor + :type gamma: float + + :param lam: GAE lambda + :type lam: float + + :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange) + :type cliprange: float + + :param cliprange_value: Clipping range for predicted values + (observed values - cliprange_value, observed values + cliprange_value) + :type cliprange_value: float + + :param vf_coef: Value loss scale w.r.t policy loss + :type vf_coef: float + + :param gen_kwargs: Additional kwargs for the generation + :type gen_kwargs: Dict[str, Any] + + :param gen_experience_kwargs: if this is not None, then the experience is generated using this + :type gen_experience_kwargs: Dict[str, Any] + """ + + ppo_epochs: int + num_rollouts: int + chunk_size: int + init_kl_coef: float + target: float + horizon: int + gamma: float + lam: float + cliprange: float + cliprange_value: float + vf_coef: float + scale_reward: Optional[str] + ref_mean: Optional[float] + ref_std: Optional[float] + cliprange_reward: float + gen_kwargs: dict + gen_experience_kwargs: Optional[dict] = None + + def get_advantages_and_returns( + self, + values: TensorType["batch_size", "response_size"], + rewards: TensorType["batch_size", "response_size"], + response_length: int, + use_whitening: Optional[bool] = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Function that computes advantages and returns from rewards and values. + Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 + Note that rewards may include a KL divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... 
+ + Args: + values: Tensor of shape (batch_size, response_size) + rewards: Tensor of shape (batch_size, response_size) + response_length: Length of the response sequence + use_whitening: Whether to use whitening (ie. normalize advantages) or not + """ + lastgaelam = 0 + advantages_reversed = [] + for t in reversed(range(response_length)): + nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.gamma * self.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + returns = advantages + values + if use_whitening: + advantages = whiten(advantages) + return advantages.detach(), returns + + def loss( + self, + logprobs: TensorType["batch_size", "response_size"], + values: TensorType["batch_size", "response_size"], + old_logprobs: TensorType["batch_size", "response_size"], + old_values: TensorType["batch_size", "response_size"], + advantages: TensorType["batch_size", "response_size"], + returns: TensorType["batch_size", "response_size"], + mask: TensorType["batch_size", "response_size"], + ): + """PPO objective function. + References: + - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html + """ + values_clipped = torch.clamp( + values, + old_values - self.cliprange_value, + old_values + self.cliprange_value, + ) + n = mask.sum() + + vf_loss1 = (values - returns) ** 2 + vf_loss2 = (values_clipped - returns) ** 2 + vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n + vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n + + log_ratio = (logprobs - old_logprobs) * mask + ratio = torch.exp(log_ratio) + # Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + approx_kl = torch.mean((ratio - 1) - log_ratio) + + pg_loss1 = -advantages * ratio + pg_loss2 = -advantages * torch.clamp( + ratio, + 1.0 - self.cliprange, + 1.0 + self.cliprange, + ) + pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n + pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n + + loss = pg_loss + self.vf_coef * vf_loss + + stats = dict( + losses=dict( + total_loss=loss.item(), + policy_loss=pg_loss.item(), + value_loss=vf_loss.item(), + ), + values=dict( + get_tensor_stats(values, mask, n), + values_error=torch.sum(((values - returns) * mask) ** 2) / n, + clipfrac=vf_clipfrac, + ), + old_values=get_tensor_stats(old_values, mask, n), + returns=get_tensor_stats(returns, mask, n), + policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()), + ratio=(ratio * mask).sum() / n, + padding_percentage=n / mask.numel(), + ) + + return loss, flatten_dict(stats) + + +# CausalLM architectures + + +@dataclass +class CausalLMOutputWithValue(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + """An `AutoModel` class wrapper for `transformers` causal models that have a + language modeling head and a value head + """ + + _auto_model_parent_class = transformers.AutoModelForCausalLM + _supported_modules = ["v_head"] + _supported_args = [] + + def __init__( 
+ self, + base_model: transformers.PreTrainedModel, + ): + super().__init__(base_model) + self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + position_ids: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithValue]: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.base_model(**forward_kwargs) + value = self.v_head(outputs.hidden_states[-1]).squeeze(-1) + + if not return_dict: + outputs = (outputs.logits,) + outputs[1:] + (value,) + return outputs + + return CausalLMOutputWithValue(**outputs, value=value) + + def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: + return self.base_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + base_model_state_dict[f"v_head.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + Adds the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." 
in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() # noqa: E702 + + +class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead): + _supported_modules = ["v_head", "frozen_head"] + _supported_args = ["num_layers_unfrozen"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int = -1, + ): + super().__init__(base_model) + self.num_layers_unfrozen = num_layers_unfrozen + if self.num_layers_unfrozen > 0: + config = self.base_model.config + branch_class = hf_get_branch_class(config) + self.frozen_head = branch_class( + self.base_model, + num_layers_unfrozen=self.num_layers_unfrozen, + ).eval() + + def forward_hydra( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + position_ids: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return_dict = forward_kwargs.get("return_dict", True) + forward_kwargs["return_dict"] = True + forward_kwargs["output_hidden_states"] = True + + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)] + + output_shape = outputs.hidden_states[-1].size() + forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head + forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head + hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) + + if not return_dict: + return hydra_outputs.logits + return hydra_outputs + + +class ModelBranch(transformers.PreTrainedModel): + """Implements the frozen upper trunk of the pretrained reference model used + when computing the PPO KL-divergence penalty. 
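+
+    A rough usage sketch of how this frozen trunk is exercised through the
+    Hydra wrapper above during PPO (the checkpoint name and tensors below are
+    placeholders, and `from_pretrained` is assumed to come from the
+    `PreTrainedModelWrapper` base class):
+
+        model = AutoModelForCausalLMWithHydraValueHead.from_pretrained(
+            "bigscience/bloom-1b1", num_layers_unfrozen=2
+        )
+        out = model(input_ids, attention_mask=attention_mask, return_dict=True)
+        ref_logits = model.forward_hydra(
+            input_ids, attention_mask=attention_mask, return_dict=True
+        ).logits
+        # the PPO trainer estimates KL(policy || reference) from `out.logits`
+        # and `ref_logits` and folds it into the reward as a penalty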
+ """ + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int, + ): + """ + Args: + base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from + num_layers_unfrozen (int): The number of trainable layers + """ + super().__init__(base_model.config) + + # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model + decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model)) + self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:]) + self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model)) + self.lm_head = deepcopy(hf_get_lm_head(base_model)) + + self.hidden_size = hf_get_hidden_size(self.config) + self.model_parallel = False + self.device_map = None + self.last_device = None + self.gradient_checkpointing = False + + # Freeze the entire branch + for parameter in self.parameters(): + parameter.requires_grad_(False) + + +class GPTModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids + output_shape: torch.Tensor, # output_size given by main trunk + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743 # noqa: E501 + """ + batch_size, seq_length = hidden_states.shape[:2] + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + device = hidden_states.device + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.decoder_blocks)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length) + + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + if self.config.add_cross_attention and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + kwargs = dict( + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # Assumes we are never training the branch + block_params = inspect.getfullargspec(block.forward).args + if "encoder_hidden_states" not in block_params: + kwargs.pop("encoder_hidden_states") + kwargs.pop("encoder_attention_mask") + # Remove position_ids for GPT2Block + if "position_ids" not in block_params: + kwargs.pop("position_ids") + + outputs = block(hidden_states, **kwargs) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_norm(hidden_states) + + hidden_states = hidden_states.view(output_shape) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + outputs = (lm_logits,) + (None,) + (None,) + return outputs + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class OPTModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, + output_shape: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + 
"""Reference: + https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840 # noqa: E501 + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device) + + input_shape = hidden_states.size()[:-1] + combined_attention_mask = None + if input_shape[-1] > 1: + # `modeling_opt._make_causal_mask` @ transformers==4.27.1 doesn't have the `device` argument + if "device" in inspect.getfullargspec(modeling_opt._make_causal_mask).args: + kwargs = dict(device=hidden_state.device) + else: + kwargs = {} + + combined_attention_mask = modeling_opt._make_causal_mask( + input_shape, + hidden_states.dtype, + past_key_values_length=past_key_values_length, + **kwargs, + ).to(hidden_states.device) + + if attention_mask is not None: + expanded_attn_mask = modeling_opt._expand_mask( + attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] + ).to(hidden_states.device) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + attention_mask = combined_attention_mask + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.decoder_blocks)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.decoder_blocks): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_norm is not None: + hidden_states = self.final_norm(hidden_states) + + # TODO: Add output projection support + # https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499 # noqa: E501 + # if self.project_out is not None: + # hidden_states = self.project_out(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + lm_logits = self.lm_head(hidden_states).contiguous() + + if not return_dict: + return tuple( + v + for v in [ + lm_logits, + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class BloomModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids + output_shape: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623 # noqa: E501 + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = hidden_states.shape[:2] + + if past_key_values is None: + past_key_values = tuple([None] * len(self.decoder_blocks)) + + head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = 
torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) + + combined_attention_mask = None + device = attention_mask.device + input_shape = (batch_size, seq_length) + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = modeling_bloom._make_causal_mask( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) + + expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + causal_mask = combined_attention_mask + + for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return tuple( + v + for v in [ + lm_logits, + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class LlamaModelBranch(ModelBranch): + def _make_causal_mask(self, input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
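+
+        A small worked example (the values are chosen purely for illustration):
+
+            mask = torch.tensor([[1.0, 1.0, 0.0]])        # (bsz=1, src_len=3)
+            out = self._expand_mask(mask, torch.float32)  # shape (1, 1, 3, 3)
+            # visible positions (mask == 1) become 0.0, while masked positions
+            # (mask == 0) become torch.finfo(torch.float32).min, so `out` can be
+            # added directly to the attention scores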
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_states, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = self._make_causal_mask( + input_shape, hidden_states.dtype, past_key_values_length=past_key_values_length + ).to(hidden_states.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]).to( + hidden_states.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + return combined_attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + output_shape: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491 + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, seq_length = hidden_states.shape[:2] + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = hidden_states.device if hidden_states is not None else encoder_hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.decoder_blocks): + if output_hidden_states: + 
all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_norm(hidden_states) + hidden_states = hidden_states.view(output_shape) + lm_logits = self.lm_head(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + outputs = (lm_logits,) + (None,) + (None,) + return outputs + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Seq2Seq architectures + + +@dataclass +class Seq2SeqLMOutputWithValue(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + """An `AutoModel` class wrapper for `transformers` sequence-to-sequence + models that have a language modeling head and a value head + """ + + _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM + _supported_modules = ["v_head"] + _supported_args = [] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + ): + super().__init__(base_model) + self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = True, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = None, + ) -> Seq2SeqLMOutputWithValue: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + 
cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.base_model(**forward_kwargs) + last_hidden_state = outputs.decoder_hidden_states[-1] + value = self.v_head(last_hidden_state).squeeze(-1) + + return Seq2SeqLMOutputWithValue(**outputs, value=value) + + def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: + return self.base_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + base_model_state_dict[f"v_head.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() # noqa: E702 + + +class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead): + _supported_modules = ["v_head", "frozen_head"] + _supported_args = ["num_layers_unfrozen"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int = -1, + ): + super().__init__(base_model) + self.num_layers_unfrozen = num_layers_unfrozen + if self.num_layers_unfrozen > 0: + branch_class = T5Branch # TODO: Add support for other model branches + self.frozen_head = branch_class( + self.base_model, + num_layers_unfrozen=self.num_layers_unfrozen, + ).eval() + + def forward_hydra( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Seq2SeqLMOutputWithValue: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return_dict = forward_kwargs.get("return_dict", True) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)] + hydra_outputs = self.frozen_head( + hidden_states=input_hidden_state, + attention_mask=decoder_attention_mask, + encoder_hidden_states=outputs.encoder_last_hidden_state, + encoder_attention_mask=attention_mask, + use_cache=False, + output_attentions=False, + output_hidden_states=True, + return_dict=return_dict, + ) + + if not return_dict: + return hydra_outputs.logits + return hydra_outputs + + +class T5Branch(ModelBranch): + """Decoder only T5 branch""" + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int, + ): + super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen) + self.dropout = hf_get_decoder(base_model).dropout + self.is_decoder = True + + def forward( # noqa: max-complexity + self, + hidden_states: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899 # noqa: E501 + """ + batch_size, seq_length = hidden_states.shape[:2] + input_shape = (batch_size, seq_length) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long + ) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + position_bias = None + encoder_decoder_position_bias = None + + for _, layer_module in enumerate(self.decoder_blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + hidden_states = self.final_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + sequence_output = hidden_states + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + + return Seq2SeqLMOutputWithValue( + logits=lm_logits, + decoder_hidden_states=all_hidden_states, + decoder_attentions=all_attentions, + ) + + +# Branch class utils + + +def hf_get_branch_class( + config: transformers.PretrainedConfig, +) -> "ModelBranch": + """Returns the model branch class for the given config.""" + gpt_branch_supported_archs = [ + "GPTJForCausalLM", + "GPT2LMHeadModel", + "GPTNeoForCausalLM", + "GPTNeoXForCausalLM", + ] + opt_branch_supported_archs = ["OPTForCausalLM"] + bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"] + llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"] + arch = config.architectures[0] + if arch in gpt_branch_supported_archs: + return GPTModelBranch + elif arch in opt_branch_supported_archs: + return OPTModelBranch + elif arch in bloom_branch_supported_archs: + return BloomModelBranch + elif arch in llama_branch_supported_archs: + return LlamaModelBranch + else: + all_supported_archs = sum( + [ + gpt_branch_supported_archs, + opt_branch_supported_archs, + bloom_branch_supported_archs, + llama_branch_supported_archs, + ], + [], + ) + raise ValueError( + f"Unsupported architecture: `{arch}`. 
The following architectures are " + f"available for model branching:\n{all_supported_archs}" + ) diff --git a/examples/BeautifulPrompt/trlx/pipeline/__init__.py b/examples/BeautifulPrompt/trlx/pipeline/__init__.py new file mode 100644 index 0000000..b3fcbd5 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/pipeline/__init__.py @@ -0,0 +1,170 @@ +import random +import sys +from abc import abstractmethod, abstractstaticmethod +from dataclasses import is_dataclass +from typing import Any, Callable, Dict, Iterable + +from torch.utils.data import DataLoader, Dataset +from transformers.tokenization_utils_base import BatchEncoding + +from trlx.data import GeneralElement, RLElement +from trlx.utils import logging + +# specifies a dictionary of architectures +_DATAPIPELINE: Dict[str, any] = {} # registry + +logger = logging.get_logger(__name__) + + +def register_datapipeline(name): + """Decorator used register a CARP architecture + Args: + name: Name of the architecture + """ + + def register_class(cls, name): + _DATAPIPELINE[name] = cls + setattr(sys.modules[__name__], name, cls) + return cls + + if isinstance(name, str): + name = name.lower() + return lambda c: register_class(c, name) + + cls = name + name = cls.__name__ + register_class(cls, name.lower()) + + return cls + + +@register_datapipeline +class BasePipeline(Dataset): + def __init__(self, path: str = "dataset"): + super().__init__() + + @abstractmethod + def __getitem__(self, index: int) -> GeneralElement: + pass + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def create_loader( + self, + batch_size: int, + shuffle: bool, + prep_fn: Callable = None, + num_workers: int = 0, + ) -> DataLoader: + """ + Create a dataloader for the pipeline + + :param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation. + """ + pass + + +class BaseRolloutStore(Dataset): + def __init__(self, capacity=-1): + self.history: Iterable[Any] = None + self.capacity = capacity + + @abstractmethod + def push(self, exps: Iterable[Any]): + """ + Push experiences to rollout storage + """ + pass + + def __getitem__(self, index: int) -> RLElement: + return self.history[index] + + def __len__(self) -> int: + return len(self.history) + + @abstractmethod + def create_loader( + self, + batch_size: int, + shuffle: bool, + prep_fn: Callable = None, + num_workers: int = 0, + ) -> DataLoader: + """ + Create a dataloader for the rollout store + + :param prep_fn: Applied to RLElement after collation (typically tokenizer) + :type prep_fn: Callable + """ + pass + + +class MiniBatchIterator: + """ + A custom iterator for generating mini-batches from a PyTorch DataLoader. + """ + + def __init__(self, data_loader, mb_size, num_mb): + """ + Initializes the MiniBatchIterator. + + Args: + data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from. + mb_size (int): The size of each mini-batch. + num_mb (int): The number of mini-batches to generate for each iteration. 
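+
+        A usage sketch (the loader and sizes here are arbitrary placeholders):
+
+            # split each batch of 8 samples into 4 minibatches of 2
+            mb_iterator = MiniBatchIterator(data_loader, mb_size=2, num_mb=4)
+            for minibatches in mb_iterator:
+                for mb in minibatches:
+                    ...  # e.g. one gradient-accumulation micro-step per minibatch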
+ """ + self.data_loader = data_loader + self.data_loader_iter = iter(data_loader) + self.mb_size = mb_size + self.num_mb = num_mb + + def __iter__(self): + return self + + def __next__(self): + batch = next(self.data_loader_iter) + minibatches = [] + + for mbi in range(self.num_mb): + sliced_data = {} + batch_dict = batch + if is_dataclass(batch): + batch_dict = batch.__dict__ + for key, value in batch_dict.items(): + start_idx = mbi * self.mb_size + end_idx = (mbi + 1) * self.mb_size + sliced_data[key] = value[start_idx:end_idx] + + if len(sliced_data[key]) == 0: + logger.warning( + "WARNING: MiniBatchIterator generated a minibatch with 0 elements. " + "This may be due to the wrong mb_size and/or num_mb or the last batch" + "in the dataset being smaller." + ) + sliced_data.pop(key) + break + elif len(sliced_data[key]) < self.mb_size: + logger.warning( + "WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. " + "This may be due to the wrong mb_size and/or num_mb or the last batch in the dataset " + "being smaller." + ) + if not sliced_data: + break + + if isinstance(batch, BatchEncoding): + minibatch = BatchEncoding(sliced_data) + elif is_dataclass(batch): + minibatch = batch.__class__(**sliced_data) + # else: + # minibatch = sliced_data + + minibatches.append(minibatch) + + if not minibatches: + raise StopIteration + + return minibatches diff --git a/examples/BeautifulPrompt/trlx/pipeline/offline_pipeline.py b/examples/BeautifulPrompt/trlx/pipeline/offline_pipeline.py new file mode 100644 index 0000000..0e1acd4 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/pipeline/offline_pipeline.py @@ -0,0 +1,263 @@ +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Tuple, Union + +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from transformers import ( + DataCollatorWithPadding, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + +from trlx.data.ilql_types import ( + ILQLBatch, + ILQLElement, + ILQLSeq2SeqBatch, + ILQLSeq2SeqElement, +) +from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline + + +@dataclass +class DialogMessage: + is_output: bool + tokens: Tuple[int] + + +def tokenize_dialogue( # noqa: C901 + dialogue: Union[str, Iterable[str]], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048 +) -> List[DialogMessage]: + """ + Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...) 
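+
+    For illustration (the strings and `tokenizer` below are placeholders, not
+    taken from this repo):
+
+        msgs = tokenize_dialogue(["Describe a cat: ", "a fluffy cat, 4k"], tokenizer)
+        # -> [DialogMessage(is_output=False, tokens=(...)),  # prompt tokens
+        #     DialogMessage(is_output=True,  tokens=(...))]  # output tokens + EOS
+        # Passing a single string instead treats it as the output of a BOS-only
+        # prompt.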
+ """ + if isinstance(dialogue, str): + bos_token = tokenizer.bos_token or tokenizer.eos_token + dialogue = [bos_token, dialogue] + elif isinstance(dialogue, Iterable): + if len(dialogue) % 2 != 0: + raise ValueError("Dialogue must have an even number of phrases, alternating prompt and output") + dialogue = list(dialogue) + + if not dialogue[-1].endswith(tokenizer.eos_token): + dialogue[-1] = dialogue[-1] + tokenizer.eos_token + + tokenized = [ + DialogMessage(is_output=i % 2 == 1, tokens=tuple(tokenizer(dialogue[i], add_special_tokens=False).input_ids)) + for i in range(len(dialogue)) + ] + + # flip to truncate from the left + if tokenizer.truncation_side == "left": + tokenized = [DialogMessage(is_output=m.is_output, tokens=m.tokens[::-1]) for m in tokenized[::-1]] + + # truncate if necessary + lengths = [len(t.tokens) for t in tokenized] + cumsum_lengths = [sum(lengths[:i]) for i in range(len(lengths))] + truncated = [ + DialogMessage(is_output=t.is_output, tokens=t.tokens[: max(max_length - cl, 0)]) + for t, cl in zip(tokenized, cumsum_lengths) + ] + + # flip back if was fliped to left truncate + if tokenizer.truncation_side == "left": + truncated = [DialogMessage(is_output=m.is_output, tokens=m.tokens[::-1]) for m in truncated[::-1]] + + # remove empty messages + out = [t for t in truncated if len(t.tokens) > 0] + + if out[0].is_output: + if sum(map(lambda msg: len(msg.tokens), out)) == max_length: + if tokenizer.truncation_side == "left": + out[0].tokens = out[0].tokens[1:] + else: + out[-1].tokens = out[-1].tokens[:-1] + + out.insert(0, DialogMessage(False, (tokenizer.bos_token_id,))) + return out + + +class DialogStore(BaseRolloutStore): + def __init__(self, dialogs: List[List[DialogMessage]], tokenizer: PreTrainedTokenizer): + super().__init__() + self.tokenizer = tokenizer + attention_masks = [torch.ones(sum(len(m.tokens) for m in d), dtype=torch.bool) for d in dialogs] + input_ids = [torch.tensor([t for m in d for t in m.tokens], dtype=torch.long) for d in dialogs] + # -100 is the ignore index for CrossEntropyLoss + labels = [ + torch.tensor([t if m.is_output else -100 for m in d for t in m.tokens], dtype=torch.long) for d in dialogs + ] + self.history = [ + dict(input_ids=i, attention_mask=a, labels=l) for i, a, l in zip(input_ids, attention_masks, labels) + ] + + def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: + hf_collate_fn = DataCollatorWithPadding(self.tokenizer) + + def collate_fn(elems: Iterable[dict]): + batch = hf_collate_fn( + {"input_ids": [e["input_ids"] for e in elems], "attention_mask": [e["attention_mask"] for e in elems]} + ) + labels = hf_collate_fn([{"input_ids": e["labels"]} for e in elems])["input_ids"] + batch["labels"] = labels + return batch + + return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) + + +@register_datapipeline +class PromptPipeline(BasePipeline): + """ + Dataloader which is used to supply prompts for either training or evaluation + + Args: + prompts (`List[str]` or `List[Dict[str, Any]]`): list of raw text prompts or a dictionary with a required + key `"prompt"` and extra information, that would be passed along the generation for that prompt as a + keyword argument to a reward function. + max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated according to + tokenizer's truncation setting. + tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with. 
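+
+        Example (the prompt texts and lengths are illustrative only):
+
+            pipeline = PromptPipeline(
+                [{"prompt": "a cat", "style": "photo"},
+                 {"prompt": "a dog", "style": "sketch"}],
+                max_prompt_length=128,
+                tokenizer=tokenizer,
+            )
+            loader = pipeline.create_loader(batch_size=2)
+            # each batch carries padded `input_ids`/`attention_mask` plus any extra
+            # metadata keys (here `style`) as plain lists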
+ """ + + def __init__( + self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer + ): + super().__init__() + + if isinstance(prompts[0], dict): + metadata = prompts + prompts = [x.pop("prompt") for x in metadata] + else: + metadata = [{}] * len(prompts) + + model_inputs = tokenizer( + prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False + ) + + prompts_tokens = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] + + self.tokenizer = tokenizer + self.prompts = [ + {"input_ids": tokens, "attention_mask": mask, **metadata} + for tokens, mask, metadata in zip(prompts_tokens, attention_mask, metadata) + ] + + def __getitem__(self, ix: int): + return self.prompts[ix] + + def __len__(self) -> int: + return len(self.prompts) + + def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: + def collate_fn(xs): + out = self.tokenizer.pad([{"input_ids": x["input_ids"]} for x in xs], return_tensors="pt") + + for key in xs[0]: + if key != "input_ids" and key != "attention_mask": + out[key] = [x[key] for x in xs] + + return out + + return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) + + +def ilql_collate_fn(elems: Iterable[ILQLElement]): + return ILQLBatch( + pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), + pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), + ) + + +class ILQLRolloutStorage(BaseRolloutStore): + """ + Rollout storage for training ILQL + """ + + def __init__(self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones): + super().__init__() + + self.input_ids = input_ids + self.attention_mask = attention_mask + self.rewards = rewards + self.states_ixs = states_ixs + self.actions_ixs = actions_ixs + self.dones = dones + + def __getitem__(self, ix: int) -> ILQLElement: + return ILQLElement( + self.input_ids[ix], + self.attention_mask[ix], + self.rewards[ix], + self.states_ixs[ix], + self.actions_ixs[ix], + self.dones[ix], + ) + + def __len__(self) -> int: + return len(self.input_ids) + + def create_loader(self, batch_size: int, drop_last=True): + return DataLoader( + self, + batch_size=batch_size, + shuffle=True, + collate_fn=ilql_collate_fn, + drop_last=drop_last, + ) + + +def ilql_seq2seq_collate_fn(elems: Iterable[ILQLElement]): + return ILQLSeq2SeqBatch( + pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.decoder_input_ids for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), + pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), + pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), + ) + + +class ILQLSeq2SeqRolloutStorage(BaseRolloutStore): + """ + Rollout storage for training ILQL + """ + + def __init__(self, input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, 
dones): + super().__init__() + + self.input_ids = input_ids + self.attention_mask = attention_mask + self.decoder_input_ids = decoder_input_ids + self.rewards = rewards + self.states_ixs = states_ixs + self.actions_ixs = actions_ixs + self.dones = dones + + def __getitem__(self, ix: int) -> ILQLElement: + return ILQLSeq2SeqElement( + self.input_ids[ix], + self.attention_mask[ix], + self.decoder_input_ids[ix], + self.rewards[ix], + self.states_ixs[ix], + self.actions_ixs[ix], + self.dones[ix], + ) + + def __len__(self) -> int: + return len(self.input_ids) + + def create_loader(self, batch_size: int, drop_last=True): + return DataLoader( + self, + batch_size=batch_size, + shuffle=True, + collate_fn=ilql_seq2seq_collate_fn, + drop_last=drop_last, + ) diff --git a/examples/BeautifulPrompt/trlx/pipeline/ppo_pipeline.py b/examples/BeautifulPrompt/trlx/pipeline/ppo_pipeline.py new file mode 100644 index 0000000..7808f35 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/pipeline/ppo_pipeline.py @@ -0,0 +1,80 @@ +import json +import os +import time +from typing import Iterable + +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader + +from trlx.data.ppo_types import PPORLBatch, PPORLElement +from trlx.pipeline import BaseRolloutStore + + +class PPORolloutStorage(BaseRolloutStore): + """ + Rollout storage for training PPO + """ + + def __init__(self, pad_token_id): + super().__init__() + + self.pad_token_id = pad_token_id + self.history: Iterable[PPORLElement] = [None] + + def push(self, exps: Iterable[PPORLElement]): + self.history += exps + + def clear_history(self): + self.history = [] + + def export_history(self, location: str): + assert os.path.exists(location) + + fpath = os.path.join(location, f"epoch-{str(time.time())}.json") + + def exp_to_dict(exp): + {k: v.cpu().tolist() for k, v in exp.__dict__.items()} + + data = [exp_to_dict(exp) for exp in self.history] + with open(fpath, "w") as f: + f.write(json.dumps(data, indent=2)) + + def __getitem__(self, index: int) -> PPORLElement: + return self.history[index] + + def __len__(self) -> int: + return len(self.history) + + def create_loader( + self, + batch_size: int, + shuffle: bool, + ) -> DataLoader: + def collate_fn(elems: Iterable[PPORLElement]): + return PPORLBatch( + # Left padding of already left-padded queries + pad_sequence( + [elem.query_tensor.flip(0) for elem in elems], + padding_value=self.pad_token_id, + batch_first=True, + ).flip(1), + # Right pad the rest, to have a single horizontal query/response split + pad_sequence( + [elem.response_tensor for elem in elems], + padding_value=self.pad_token_id, + batch_first=True, + ), + pad_sequence( + [elem.logprobs for elem in elems], + padding_value=0.0, + batch_first=True, + ), + pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True), + pad_sequence( + [elem.rewards for elem in elems], + padding_value=0.0, + batch_first=True, + ), + ) + + return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) diff --git a/examples/BeautifulPrompt/trlx/reference.py b/examples/BeautifulPrompt/trlx/reference.py new file mode 100644 index 0000000..dab6b6d --- /dev/null +++ b/examples/BeautifulPrompt/trlx/reference.py @@ -0,0 +1,103 @@ +# python -m trlx.reference CarperAI/trlx:add-benchmark-tools --against CarperAI/trlx:main + +import argparse +import os +import subprocess + +import wandb +import wandb.apis.reports as wb + +parser = argparse.ArgumentParser() +parser.add_argument("branch", type=str, help="Git branch in the format 
`origin:branch`") +parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch") +parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs") +args = parser.parse_args() + +pr_origin = ref_origin = "CarperAI/trlx" +pr_branch = args.branch +ref_branch = args.against +if ":" in pr_branch: + pr_origin, pr_branch = pr_branch.rsplit(":", 1) +if ":" in ref_branch: + ref_origin, ref_branch = ref_branch.rsplit(":", 1) + +out = os.popen(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} --only_hash") +pr_hash, pr_git_hash = [x[:-1] for x in out.readlines()] + +out = os.popen(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} --only_hash") +ref_hash, ref_git_hash = [x[:-1] for x in out.readlines()] + +print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}") +print(f"{ref_origin}:{ref_branch} {ref_hash=} {ref_git_hash=}") + +api = wandb.Api() +project_name = "CarperAI/trlx-references" if args.public else "trlx-references" +public = "--public" if args.public else "" + +runs = api.runs(project_name, filters={"tags": {"$in": [ref_hash]}}) +if runs: + print(f"On {ref_branch} @{ref_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") +else: + print(f"Making runs on {ref_branch} @{ref_git_hash}") + subprocess.run(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} {public}".split()) + +runs = api.runs(project_name, filters={"tags": {"$in": [pr_hash]}}) +if runs: + print(f"On {pr_branch} @{pr_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") +else: + print(f"Making runs on {pr_branch} @{pr_git_hash}") + subprocess.run(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} {public}".split()) + +report = wb.Report( + project=project_name.split("/")[1] if args.public else project_name, + title=f"{pr_branch} v. 
{ref_branch}", + description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}", +) +blocks = [] + +experiment_names = set(x.name.split(":")[0] for x in api.runs(project_name)) +for name in experiment_names: + filters = {"$and": [{"display_name": {"$regex": f"^{name}"}}, {"tags": {"$in": [pr_hash, ref_hash]}}]} + + runs = api.runs(project_name, filters=filters) + metrics = set(sum([[metric for metric in run.history().columns if not metric.startswith("_")] for run in runs], [])) + + metrics_panels = [ + wb.LinePlot( + title=f"{metric}", + x="Step", + y=[metric], + title_x="Step", + smoothing_show_original=True, + max_runs_to_show=2, + plot_type="line", + font_size="auto", + legend_position="north", + ) + for metric in metrics + ] + + # sort the most important metrics to be shown first + major_metrics = set() + for metric in metrics: + if metric.startswith("reward") or metric.startswith("metric"): + major_metrics.add(metric) + metrics = metrics - major_metrics + + blocks.extend( + [ + wb.H1(text=name), + wb.PanelGrid( + panels=[panel for panel in metrics_panels if panel.title in major_metrics], + runsets=[wb.Runset(project=project_name, filters=filters)], + ), + wb.PanelGrid( + panels=[panel for panel in metrics_panels if panel.title in metrics], + runsets=[wb.Runset(project=project_name, filters=filters)], + ), + ] + ) + +report.blocks = blocks +report.save() +print(report.url) diff --git a/examples/BeautifulPrompt/trlx/sweep.py b/examples/BeautifulPrompt/trlx/sweep.py new file mode 100644 index 0000000..615cb73 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/sweep.py @@ -0,0 +1,348 @@ +# python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py +import argparse +import importlib +import json +from datetime import datetime + +import ray +import wandb +import wandb.apis.reports as wb +import yaml +from ray import tune +from ray.air import ScalingConfig +from ray.train.huggingface.accelerate import AccelerateTrainer +from ray.tune.logger import CSVLoggerCallback + + +def get_param_space(config: dict): # noqa: C901 + """Get the param space from the config file.""" + + def get_strategy(value): + """Get search space strategy from config. + A search space defines valid values for your hyperparameters and + can specify how these values are sampled. + + Refer to the documentation for more info: + https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs + + The user will have to define the search space in the config file by providing + the name of the `strategy` and the `values` to sample from. + + The valid strategies are: + - `uniform` (List) - Samples uniformly between the given bounds. + - `quniform` (List) - Samples uniformly between the given bounds, quantized. + - `loguniform` (List) - Samples uniformly between the given bounds on a log scale. + - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized. + - `randn` (List) - Samples from a normal distribution. + - `qrandn` (List) - Samples from a normal distribution, quantized. + - `randint` (List) - Samples uniformly between the given bounds, quantized to integers. + - `qrandint` (List) - Samples uniformly between the given bounds, quantized to integers. + - `lograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. + - `qlograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. + - `choice` (List) - Samples from a discrete set of values. 
+ - `qrandn` (List) - Samples from a normal distribution, quantized. + - `grid_search` (List) - Samples from the given list of values. + + """ + + strategy = value["strategy"] + if strategy == "uniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.uniform(*value["values"]) + elif strategy == "quniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.quniform(*value["values"]) + elif strategy == "loguniform": + assert isinstance(value["values"], list) + assert 2 <= len(value["values"]) <= 3 + return tune.loguniform(*value["values"]) + elif strategy == "qloguniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 4 + return tune.qloguniform(*value["values"]) + elif strategy == "randn": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.randn(*value["values"]) + elif strategy == "qrandn": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.qrandn(*value["values"]) + elif strategy == "randint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.randint(*value["values"]) + elif strategy == "qrandint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.qrandint(*value["values"]) + elif strategy == "lograndint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.lograndint(*value["values"]) + elif strategy == "qlograndint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 4 + return tune.qlograndint(*value["values"]) + elif strategy == "choice": + assert isinstance(value["values"], list) + return tune.choice(value["values"]) + elif strategy == "grid": + assert isinstance(value["values"], list) + return tune.grid_search(value["values"]) + + for k, v in config.items(): + if k != "tune_config": + config[k] = get_strategy(v) + + return config + + +def get_search_alg(tune_config: dict): + """Initialize the search algorithm and return it. + + Bayesian Optimization is currently supported. + """ + search_alg = tune_config["search_alg"] + + if search_alg == "bayesopt": + try: + from ray.tune.search.bayesopt import BayesOptSearch + except ImportError: + raise ImportError("Please pip install bayesian-optimization to use BayesOptSearch.") + + assert "metric" in tune_config.keys() and "mode" in tune_config.keys() + "Please specify metric and mode for BayesOptSearch." + + return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"]) + elif search_alg == "bohb": + try: + from ray.tune.search.bohb import TuneBOHB + except ImportError: + raise ImportError("Please pip install hpbandster and ConfigSpace to use TuneBOHB.") + + assert "metric" in tune_config.keys() and "mode" in tune_config.keys() + "Please specify metric and mode for TuneBOHB." + + return TuneBOHB() + elif search_alg == "random": + return None + else: + NotImplementedError("Search algorithm not supported.") + + +def get_scheduler(tune_config: dict): + """Initialize the scheduler and return it. + + The schedulers can early terminate bad trials, pause trials, + clone trials, and alter hyperparameters of a running trial. + + Refer to the documentation for more info: + https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers + + Currently available schedulers are: + - `hyperband` - Implements the HyperBand early stopping algorithm. 
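+    - `hyperbandforbohb` - The HyperBand variant required when the BOHB search algorithm is used.
+    - `fifo` - No early stopping; trials run in submission order (this function returns None).
+
+    For example (an illustrative call, not taken from any config in this repo):
+
+        scheduler = get_scheduler({"scheduler": "hyperband"})
+        # -> ray.tune.schedulers.HyperBandScheduler()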
+ + """ + scheduler = tune_config["scheduler"] + + if scheduler == "hyperband": + return tune.schedulers.HyperBandScheduler() + elif scheduler == "hyperbandforbohb": + return tune.schedulers.HyperBandForBOHB() + elif scheduler == "fifo": + return None + else: + NotImplementedError("Scheduler not supported.") + + +def get_tune_config(tune_config: dict): + """Get the tune config to initialized `tune.TuneConfig` + to be passed `tune.Tuner`. + """ + if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None: + tune_config["search_alg"] = get_search_alg(tune_config) + + if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None: + tune_config["scheduler"] = get_scheduler(tune_config) + + # Remove config keys with None values. + tune_config = {k: v for k, v in tune_config.items() if v is not None} + + return tune_config + + +def create_report(target_metric, column_names, entity_name, project_name, group_name, best_config): + report = wb.Report( + project=project_name, + title=f"Hyperparameter Optimization Report: {project_name}", + description=group_name, + ) + + report.blocks = [ + wb.PanelGrid( + panels=[ + wb.ParallelCoordinatesPlot( + columns=[wb.PCColumn(f"c::{column}") for column in column_names] + [wb.PCColumn(target_metric)], + layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, + ), + wb.ParameterImportancePlot( + with_respect_to=target_metric, + layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, + ), + wb.ScatterPlot( + # Get it from the metric name. + title=f"{target_metric} v. Index", + x="Index", + y=target_metric, + running_ymin=True, + font_size="small", + layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2}, + ), + ], + runsets=[ + wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{group_name}"'), + ], + ), + ] + + entity_project = f"{entity_name}/{project_name}" if entity_name else project_name + api = wandb.Api() + runs = api.runs(entity_project) + + for run in runs: + if run.group == group_name: + history = run.history() + metrics = history.columns + break + + metrics = [metric for metric in metrics if not metric.startswith("_")] + + line_plot_panels = [] + for metric in metrics: + line_plot_panels.append( + wb.LinePlot( + title=f"{metric}", + x="Step", + y=[f"{metric}"], + title_x="Step", + smoothing_show_original=True, + max_runs_to_show=100, + plot_type="line", + font_size="auto", + legend_position="north", + ) + ) + + report.blocks = report.blocks + [ + wb.H1(text="Metrics"), + wb.PanelGrid( + panels=line_plot_panels, + runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{group_name}"')], + ), + ] + + if best_config: + best_config = best_config["train_loop_config"] + config = {} + for name, value in best_config.items(): + *layers, var = name.split(".") + if layers: + d = config.setdefault(layers[0], {}) + for layer in layers[1:]: + d = d.setdefault(layer, {}) + d[var] = value + + report.blocks = report.blocks + [ + wb.H1(text="Best Config"), + wb.CodeBlock(code=[json.dumps(config, indent=4)], language="json"), + ] + + report.save() + print(report.url) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("script", type=str, help="Path to the example script") + parser.add_argument( + "--config", + type=str, + required=True, + help="The config file defining the param_space.", + ) + + parser.add_argument( + "--accelerate_config", + type=str, + required=False, + help="The default config file for the script.", + ) + parser.add_argument("--num_gpus", type=int, 
default=1, help="Number of GPUs (workers) to use per trial.") + parser.add_argument("--num_cpus", type=int, default=4, help="Number of CPUs to use per GPU (worker).") + parser.add_argument("-y", "--assume_yes", action="store_true", help="Don't ask for confirmation") + parser.add_argument( + "--server_address", + type=str, + default=None, + required=False, + help="The address of server to connect to if using Ray Client.", + ) + + args = parser.parse_args() + + with open(args.config) as f: + config = yaml.safe_load(f) + tune_config = get_tune_config(config.pop("tune_config")) + param_space = get_param_space(config) + column_names = list(param_space.keys()) + target_metric = tune_config["metric"] + + if args.server_address: + ray.init(address=f"ray://{args.server_address}") + else: + ray.init() + + print(f'WARNING: Importing main from "{args.script}" and everything along with it') + + if not args.assume_yes: + print("Please confirm y/n: ", end="") + if input() != "y": + print("Exiting") + exit(1) + + # convert a nested path to a module path + script_path = args.script.replace(".py", "").replace("/", ".") + script = importlib.import_module(script_path) + project_name = "sweep_" + script_path.split(".")[-1] + + param_space["train.project_name"] = project_name + param_space["train.group_name"] = datetime.now().replace(microsecond=0).isoformat() + param_space_train = {"train_loop_config": param_space} + + tuner = tune.Tuner( + AccelerateTrainer( + script.main, + # Mandatory arg. None means use Accelerate default path + accelerate_config=args.accelerate_config, + scaling_config=ScalingConfig( + trainer_resources={"CPU": 0}, + num_workers=args.num_gpus, + use_gpu=True, + resources_per_worker={"CPU": args.num_cpus, "GPU": 1}, + ), + ), + param_space=param_space_train, + tune_config=tune.TuneConfig(**tune_config), + run_config=ray.air.RunConfig(local_dir="ray_results", callbacks=[CSVLoggerCallback()]), + ) + + results = tuner.fit() + group_name = param_space["train.group_name"] + entity_name = param_space.get("train.entity_name", None) + + create_report(target_metric, column_names, entity_name, project_name, group_name, results.get_best_result().config) + + ray.shutdown() diff --git a/examples/BeautifulPrompt/trlx/trainer/__init__.py b/examples/BeautifulPrompt/trlx/trainer/__init__.py new file mode 100644 index 0000000..8e0d239 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/__init__.py @@ -0,0 +1,103 @@ +import sys +from abc import abstractmethod +from typing import Any, Callable, Dict, Iterable, Optional + +from trlx.data.configs import TRLConfig +from trlx.pipeline import BaseRolloutStore + +# specifies a dictionary of architectures +_TRAINERS: Dict[str, Any] = {} # registry + + +def register_trainer(name): + """Decorator used to register a trainer + Args: + name: Name of the trainer type to register + """ + + def register_class(cls, name): + _TRAINERS[name] = cls + setattr(sys.modules[__name__], name, cls) + return cls + + if isinstance(name, str): + name = name.lower() + return lambda c: register_class(c, name) + + cls = name + name = cls.__name__ + register_class(cls, name.lower()) + + return cls + + +@register_trainer +class BaseRLTrainer: + def __init__( + self, + config: TRLConfig, + reward_fn=None, + metric_fn=None, + logit_mask=None, + stop_sequences=None, + train_mode=False, + ): + self.store: BaseRolloutStore = None + self.config = config + self.reward_fn = reward_fn + self.metric_fn = metric_fn + self.train_mode = train_mode + self.logit_mask = logit_mask + self.stop_sequences 
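# Registry usage sketch (illustrative only, nothing here is executed): because
# `register_trainer` stores each class under its lower-cased name, a hypothetical
#
#     @register_trainer
#     class MyPPOTrainer(BaseRLTrainer):
#         ...
#
# becomes retrievable as `_TRAINERS["myppotrainer"]`.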
= stop_sequences + + def push_to_store(self, data): + self.store.push(data) + + def add_eval_pipeline(self, eval_pipeline): + """Adds pipeline for validation prompts""" + self.eval_pipeline = eval_pipeline + + @abstractmethod + def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]: + """ + Sample from the language. Takes prompts and maximum length to generate. + + :param prompts: List of prompts to tokenize and use as context + + :param length: How many new tokens to genrate for each prompt + :type length: int + + :param n_samples: Default behavior is to take number of prompts as this + """ + pass + + @abstractmethod + def learn( + self, + log_fn: Callable = None, + save_fn: Callable = None, + eval_fn: Callable = None, + ): + """ + Use experiences in RolloutStore to learn + + :param log_fn: Optional function that is called when logging and passed a dict of logging relevant values + :type log_fn: Callable[Dict[str, any]] + + :param save_fn: Optional function to call after saving. Is passed the components. + :type save_fn: Callable[Dict[str, any]] + + :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this. + :type eval_fn: Callable[BaseRLTrainer] + """ + pass + + @abstractmethod + def save(self, directory: Optional[str] = None): + """Creates a checkpoint of training states""" + pass + + @abstractmethod + def load(self, directory=None): + """Loads a checkpoint created from `save`""" + pass diff --git a/examples/BeautifulPrompt/trlx/trainer/accelerate_base_trainer.py b/examples/BeautifulPrompt/trlx/trainer/accelerate_base_trainer.py new file mode 100644 index 0000000..576f416 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/accelerate_base_trainer.py @@ -0,0 +1,637 @@ +import contextlib +import json +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from time import time +from typing import Dict, List, Optional, Tuple + +import ray +import torch +from accelerate import Accelerator # type: ignore +from ray.air import session +from rich.console import Console +from rich.table import Table +from transformers import AutoTokenizer + +import trlx.utils.logging as logging +from trlx.data.configs import TRLConfig +from trlx.pipeline import MiniBatchIterator +from trlx.trainer import BaseRLTrainer, register_trainer +from trlx.utils import ( + filter_non_scalars, + get_distributed_config, + get_git_tag, + get_optimizer_class, + get_scheduler_class, + significant, +) +from trlx.utils.modeling import ( + flatten_dict, + freeze_bottom_causal_layers, + freeze_bottom_seq2seq_layers, + gather_dict, + get_delta_model_class, + parse_delta_kwargs, +) +from trlx.trainer.utils import PPODecorators + +logger = logging.get_logger(__name__) + + +@register_trainer +class AccelerateRLTrainer(BaseRLTrainer): + """ + RL model trainer with an `accelerate` based backend + """ + + def __init__(self, config, **kwargs): # noqa: C901 + super().__init__(config, **kwargs) + self.max_length = config.train.seq_length + if config.train.minibatch_size: + assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size" + self.mb_size = config.train.minibatch_size + else: + self.mb_size = config.train.batch_size + self.num_mb = config.train.batch_size // self.mb_size + self.mb_count = 0 + self.accelerator = Accelerator(log_with=config.train.tracker, project_dir=config.train.logging_dir) + + if self.accelerator.state.deepspeed_plugin is not None: + # by accelerate's 
default, arguments in `model.forward` would be casted to half + if "fp16" in self.accelerator.state.deepspeed_plugin.deepspeed_config: + self.accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["auto_cast"] = False + + if int(os.environ.get("WORLD_SIZE", 1)) > 1: + torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))]) + + self.model = self.setup_model() + self.opt = self.setup_optimizer() + self.scheduler = self.setup_scheduler() + + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) + self.tokenizer.padding_side = config.tokenizer.padding_side + self.tokenizer.truncation_side = config.tokenizer.truncation_side + self.tokenizer.sep_token = "" + if config.model.model_arch_type != "seq2seq": + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] + if not isinstance(config.model.model_path, str): + model_name = str(config.model.model_path).split()[0] + else: + model_name = config.model.model_path.split("/")[-1] + + if self.accelerator.num_processes == 1: + num_gpus = "1gpu" + else: + num_gpus = f"{self.accelerator.num_processes}gpus" + branch = get_git_tag()[0] + + run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" + + if self.accelerator.is_main_process: + config_dict = self.config.to_dict() + dist_config = get_distributed_config(self.accelerator) + config_dict["distributed"] = dist_config + init_trackers_kwargs = {} + + if config.train.tracker == "wandb": + init_trackers_kwargs["wandb"] = { + "name": run_name, + "entity": self.config.train.entity_name, + "group": self.config.train.group_name, + "tags": self.config.train.tags + ["/".join(get_git_tag())], + "mode": "disabled" if os.environ.get("debug", False) else "online", + } + + self.accelerator.init_trackers( + project_name=self.config.train.project_name, + config=config_dict, + init_kwargs=init_trackers_kwargs, + ) + elif config.train.tracker == "tensorboard": + # flatten config for tensorboard, split list in hparams into flatten config + config_dict_flat = flatten_dict(config_dict) + config_dict_flat["optimizer/kwargs/beta_1"] = config_dict_flat["optimizer/kwargs/betas"][0] + config_dict_flat["optimizer/kwargs/beta_2"] = config_dict_flat["optimizer/kwargs/betas"][1] + config_dict_flat.pop("optimizer/kwargs/betas", None) + + for ix, tag in enumerate(config_dict_flat.pop("train/tags")): + config_dict_flat[f"train/tag_{ix}"] = tag + + self.accelerator.init_trackers( + project_name=self.config.train.project_name, + config=config_dict_flat, + ) + elif config.train.tracker is None: + self.accelerator.init_trackers(project_name=self.config.train.project_name) + else: + raise ValueError( + f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`. " + "Set `tracker` to `None` to disable tracking." 
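# (Illustrative) the tracker above comes from the TRL config; the field names are
# the TrainConfig fields used in this file, the values are assumptions:
#   train:
#     tracker: wandb              # or "tensorboard", or null to disable tracking
#     project_name: beautiful-prompt-rl
#     logging_dir: logs/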
+ ) + + self.nth_evaluation = 0 + self.generate_sweep_kwarg = None + for k, v in self.config.method.gen_kwargs.items(): + if isinstance(v, list): + if self.generate_sweep_kwarg is not None: + logger.info("Only a single sweep is allowed, {k} is going to be set to {v[0]}") + self.generate_kwargs[k] = v[0] + else: + self.generate_sweep_kwarg = (k, v) + + def setup_model(self): + """ + Returns a model derived from an instance's TRLConfig + """ + logger.info(f"Initializing model: {self.config.model.model_path}") + + # Retrieves model equipped for ppo, ilql, etc + model = self.get_arch(self.config) + if self.config.model.model_arch_type == "seq2seq": + freeze_bottom_seq2seq_layers(model.base_model, self.config.model.num_layers_unfrozen) + else: + freeze_bottom_causal_layers(model.base_model, self.config.model.num_layers_unfrozen) + # Set the delta tuning strategies + if self.config.model.delta_kwargs is not None: + delta_type, delta_kwargs = parse_delta_kwargs( + model.base_model.config, + self.config.model.delta_kwargs, + self.config.model.num_layers_unfrozen, + ) + delta_model_class = get_delta_model_class(delta_type) + delta_model = delta_model_class(model.base_model, **delta_kwargs) + delta_model.freeze_module(exclude=["deltas"], set_state_dict=True) + if self.accelerator.is_main_process: + delta_model.log() + return model + + def setup_optimizer(self): + """ + Returns an optimizer derived from an instance's TRLConfig + """ + optimizer_class = get_optimizer_class(self.config.optimizer.name) + optimizer = optimizer_class( + self.model.parameters(), + **self.config.optimizer.kwargs, + ) + + if "bitsandbytes" in optimizer.__class__.__module__: + # Force 32-bit `nn.Embedding` weights for stability. See discussion: + # https://github.com/huggingface/transformers/issues/14819#issuecomment-1016017746 + from bitsandbytes.optim import GlobalOptimManager + + manager = GlobalOptimManager.get_instance() + for module in self.model.modules(): + if isinstance(module, torch.nn.Embedding): + manager.register_module_override(module, "weight", {"optim_bits": 32}) + + return optimizer + + def setup_scheduler(self): + """ + Returns a learning rate scheduler derived from an instance's TRLConfig + """ + scheduler_class = get_scheduler_class(self.config.scheduler.name) + scheduler = scheduler_class(self.opt, **self.config.scheduler.kwargs) + return scheduler + + def decode( + self, + prompts: List[torch.LongTensor], + samples: List[torch.LongTensor], + prompt_sizes: torch.LongTensor = None, + append_eos_token: bool = False, + ) -> Tuple[List[str], List[str], List[str]]: + """ + Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str]) + """ + if prompt_sizes is None: + # Assuming prompts were left-padded + prompt_sizes = [prompts.shape[1]] * len(prompts) + + str_samples, str_prompts, str_outputs = [], [], [] + for prompt, sample, prompt_size in zip(prompts, samples, prompt_sizes): + if self.config.model.model_arch_type == "seq2seq": + output_start_ix = 0 + else: + output_start_ix = prompt_size + + str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) + str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) + # Trim outputs up to `self.stop_sequences` if any are present + trimmed = False + if self.stop_sequences: + for stop in self.stop_sequences: + stop_ix = str_output.find(stop) + if stop_ix >= 0: + str_output = str_output[:stop_ix].rstrip() + trimmed = True + + # Recover the last if it was present 
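# Illustrative example of the trimming above (strings assumed): with
# stop_sequences=["\nInput:"], a raw generation such as
# "a majestic castle, oil painting\nInput: something else" is cut back to
# "a majestic castle, oil painting" and marked as trimmed.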
in the original sample + # or add one if it was trimmed with `self.stop_sequences`. + # When a generation ended due to `max_new_tokens` exhaustion, + # only then or token would not be present in the original sample at the end + if append_eos_token and ( + trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id + ): + str_output += self.tokenizer.eos_token + + str_prompts.append(str_prompt) + str_outputs.append(str_output) + + if self.config.model.model_arch_type == "seq2seq": + sample = str_prompt + self.tokenizer.sep_token + str_output + else: + sample = str_prompt + str_output + + str_samples.append(sample) + + return str_samples, str_prompts, str_outputs + + def generate(self, input_ids, attention_mask=None, **kwargs): + """Wraps hf's `generate` adding some specific method's defaults""" + input_ids = input_ids.to(self.accelerator.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.accelerator.device) + if self.generate_experience_kwargs is not None: + kwargs = dict(self.generate_experience_kwargs, **kwargs) + else: + kwargs = dict(self.generate_kwargs, **kwargs) + + with torch.no_grad(): + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) + + def generate_eval(self, input_ids, attention_mask=None, **kwargs): + """Wraps hf's `generate` adding some specific method's defaults""" + input_ids = input_ids.to(self.accelerator.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.accelerator.device) + + kwargs = dict(self.generate_kwargs, **kwargs) + + with torch.no_grad(): + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) + + def save_pretrained(self, directory: Optional[str] = None, **kwargs): + """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for + later use. + + Args: + directory (str, *optional*): The directory to save the trainer files to. + NOTE: If not specified, the model will be saved to a directory named `hf_model` in the + checkpoint directory as specified by the Trainer's config. + **kwargs: Additional keyword arguments passed to the underlying Hugging Face model's + `save_pretrained` method. 
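
        Example (illustrative call; the path is hypothetical):
            trainer.save_pretrained("outputs/hf_model")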
+ """ + if directory is None: + directory = os.path.join(self.config.train.checkpoint_dir, "hf_model") + + self.accelerator.wait_for_everyone() + self.accelerator.unwrap_model(self.model).save_pretrained( + directory, + save_function=self.accelerator.save, + is_main_process=self.accelerator.is_main_process, + state_dict=self.accelerator.get_state_dict(self.model), + **kwargs, + ) + + if self.accelerator.is_main_process: + self.tokenizer.save_pretrained(directory) + + def save(self, directory: Optional[str] = None, **kwargs): + """Creates a checkpoint of the optimizer, scheduler and model""" + self.accelerator.save_state(directory or self.config.train.checkpoint_dir, **kwargs) + + def load(self, directory: Optional[str] = None, **kwargs): + """Load checkpoint of optimizer, scheduler and a model""" + self.accelerator.load_state(directory or self.config.train.checkpoint_dir, **kwargs) + + def add_eval_pipeline(self, eval_pipeline): + """Adds pipeline from with validation prompts""" + self.eval_pipeline = eval_pipeline + + def evaluate(self): # noqa: C901 + """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided""" + logger.info("Evaluating model") + + # Do multiple evaluations over a single list in `gen_kwargs` if present + if self.generate_sweep_kwarg is not None: + gen_sweep_arg, gen_sweep_values = self.generate_sweep_kwarg + else: + gen_sweep_values = [None] + + desc = [ + f"generation sweep 0/{len(gen_sweep_values)}", + f"eval batch 0/{len(self.eval_dataloader)}", + ] + tbar = logging.tqdm( + total=len(self.eval_dataloader) * len(gen_sweep_values), + desc=f"[{' | '.join(desc)}]", + disable=not self.accelerator.is_main_process, + position=0, + leave=True, + ) + + stats = {} + table = [] + + for i_sweep, gen_sweep_value in enumerate(gen_sweep_values): + # A dedicated suffix for wandb logging + if gen_sweep_value is not None: + sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}" + else: + sweep_suffix = "" + + all_samples = [] + all_prompts = [] + all_prompt_sizes = [] + all_metadata = [] + generate_time = time() + for i_prompt, prompts in enumerate(self.eval_dataloader): + metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"} + if self.generate_sweep_kwarg: + samples = self.generate_eval( + prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} + ) + else: + samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"]) + + # TODO(reciprocated): this should be moved into `decode` + # but that needs to be synced with indexing in `make_experience` + if self.config.model.model_arch_type == "seq2seq": + samples = samples[:, 1:].contiguous() + + prompt_sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(len(prompts.input_ids)) + prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics( + self.accelerator.pad_across_processes( + [prompts.input_ids, samples, prompt_sizes.to(samples.device)], + dim=1, + pad_index=self.tokenizer.pad_token_id, + ) + ) + all_samples.extend(samples.tolist()) + all_prompts.extend(prompts.tolist()) + all_prompt_sizes.extend(prompt_sizes.tolist()) + + metadata = gather_dict(metadata, self.accelerator.gradient_state) + all_metadata.append(metadata) + + desc = [ + f"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}", + f"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}", + ] + tbar.set_description(f"[{' | '.join(desc)}]") + tbar.update() + tbar.close() + + stats["time/generate"] = time() - generate_time + + if 
self.accelerator.is_main_process: + str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes) + + columns = ["prompt", "output"] + columns_data = [str_prompts, str_outputs] + + metadata, *xs = all_metadata + for k in metadata: + for x in xs: + metadata[k].extend(x[k]) + + # in online setting, compute the reward for validation + if self.reward_fn: + logger.info("Computing rewards") + rewards = torch.tensor( + self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata), + dtype=float, + ) + mean_reward = rewards.mean().item() + columns.append("reward") + if not isinstance(rewards, list): + rewards = rewards.tolist() + columns_data.append(rewards) + stats[f"reward/mean{sweep_suffix}"] = mean_reward + + # additionally log any other metrics + if self.metric_fn: + logger.info("Computing metrics") + metric_time = time() + metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) + stats["time/metric"] = time() - metric_time + + mean_metrics = { + f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items() + } + + stats.update(mean_metrics) + + for metric, values in metrics.items(): + # Skip metrics that are scalers since they represent aggregated values + if isinstance(values, float): + continue + columns.append(metric) + if not isinstance(values, list): + values = values.tolist() + columns_data.append(values) + + # Prepend the sweep argument along with samples + if self.generate_sweep_kwarg: + columns.insert(0, gen_sweep_arg) + columns_data.insert(0, [gen_sweep_value] * len(samples)) + + table.append(list(zip(*columns_data))) + + # Log and display evaluation metrics + logger.info("Summarizing evaluation") + if self.accelerator.is_main_process: + rows = sum(list(map(list, zip(*table))), []) + + # Add metrics/rewards to the table's title + table_title = f"Evaluation #{self.nth_evaluation}" + for k, x in stats.items(): + if k.startswith("reward") or k.startswith("metrics"): + table_title += f" {k}: {significant(x)}" + + rich_table = Table(*columns, title=table_title, show_lines=True) + for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): + rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) + Console().print(rich_table) + + if self.config.train.tracker == "wandb": + import wandb + + stats["samples"] = wandb.Table(columns, rows) + + self.nth_evaluation += 1 + return stats + + @contextmanager + def _accumulate(self): + # We can't use accelerator.accumulate() since that checks if the dataloader is exhausted + # and we do exhaust the eval dataloader right before each training loop + self.mb_count += 1 + assert self.mb_count // self.num_mb <= self.config.train.total_steps, "Beyond total steps, something is wrong" + if ( + self.mb_count % self.accelerator.gradient_accumulation_steps == 0 + or self.mb_count // self.num_mb >= self.config.train.total_steps + ): + context = contextlib.nullcontext + else: + context = self.accelerator.no_sync + with context(self.model): + yield + + @PPODecorators.empty_cuda_cache() + def learn(self): # noqa: C901 + """ + Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader` + """ + logger.info("Starting training") + + self.prepare_learning() + self.iter_count = 0 + self.nth_evaluation = 0 + + if ray.is_initialized(): + checkpoint = session.get_checkpoint() + if checkpoint: + with checkpoint.as_directory() as dir: + self.accelerator.load_state(dir) + + with 
open(os.path.join(dir, "state.json")) as f: + state = json.load(f) + self.iter_count = state["iter_count"] + else: + results = self.evaluate() + self.accelerator.log(results, step=self.iter_count) + + tbar = logging.tqdm( + initial=self.iter_count, + total=self.total_steps, + disable=not self.accelerator.is_local_main_process, + position=0, + leave=True, + ) + + best_reward = -float("inf") + + # For each epoch + for _ in range(self.config.train.epochs): + # For each batch + for mbs in MiniBatchIterator(self.train_dataloader, self.mb_size, self.num_mb): + # For each update per batch + for _ in range(self.n_updates_per_batch): + # Note that whereas standard policy gradient methods perform one + # gradient update per batch, PPO for example commonly performs + # multiple gradient updates on the same batch of data. + # https://arxiv.org/pdf/1707.06347.pdf + forward_time = 0 + backward_time = 0 + stats_accum = [] + for mb in mbs: + with self._accumulate(): + forward_time -= time() + loss, stats = self.loss(mb) + forward_time += time() + backward_time -= time() + self.accelerator.backward(loss) + backward_time += time() + stats_accum.append(stats) + + forward_time /= self.num_mb + backward_time /= self.num_mb + # TODO(Dahoas): Best way to combine stats between mbs? + # How does accelerate do it? + stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]} + + self.opt.step() + self.opt.zero_grad() + self.scheduler.step() + self.iter_count += 1 + + if ( + self.iter_count % self.config.train.checkpoint_interval == 0 + or self.iter_count >= self.total_steps + ): + subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}" + directory = os.path.join(self.config.train.checkpoint_dir, subfolder) + logger.info(f"Saving intermediate checkpoint into {directory}") + if self.config.train.save_optimizer: + self.save(directory) + else: + self.save_pretrained(directory) + + stats["time/forward"] = forward_time + stats["time/backward"] = backward_time + for group_number, lr in enumerate(self.scheduler.get_last_lr()): + stats[f"learning_rate_group_{group_number}"] = lr + + if self.iter_count % self.config.train.eval_interval == 0 or self.iter_count >= self.total_steps: + results = self.evaluate() + stats.update(results) + if ray.is_initialized(): + session.report(filter_non_scalars(stats), checkpoint=checkpoint) + + # always save checkpoint with the greatest mean reward + if self.config.train.save_best: + if stats.get("reward/mean", -float("inf")) > best_reward: + best_reward = stats.get("reward/mean") + do_save = True + # in case ILQL reports reward estimate as one of its metrics + elif stats.get("metrics/reward", -float("inf")) > best_reward: + best_reward = stats.get("metrics/reward") + do_save = True + else: + do_save = False + do_save = torch.tensor(do_save, device=self.accelerator.device) + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX) + if do_save: + directory = os.path.join(self.config.train.checkpoint_dir, "best_checkpoint") + logger.info(f"Saving the best state so far into {directory}") + if self.config.train.save_optimizer: + self.save(directory) + else: + self.save_pretrained(directory) + + desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) + tbar.set_description(f"[{desc}]") + tbar.update() + + self.accelerator.log(stats, step=self.iter_count) + + if self.iter_count >= self.total_steps: + return results + + self.post_backward_callback() + + 
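# Illustrative numbers for the checkpoint naming above (step counts assumed):
# with total_steps=1000 and iter_count=50, the zero-padded subfolder is
# "checkpoint_0050"; the final save lands in "checkpoint_1000".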
self.post_epoch_callback() + tbar.close() + + @abstractmethod + def get_arch(self, config: TRLConfig): + """Returns a specific wrapper of the decoder architecture""" + pass + + @abstractmethod + def loss(self, batch) -> Tuple[float, Dict]: + """Compute loss on a batch from `store` and return some statistics""" + pass + + @abstractmethod + def post_backward_callback(self): + """Do something after model update""" + pass + + @abstractmethod + def post_epoch_callback(self): + """Do something after exhausting/single pass over `self.store`""" + pass diff --git a/examples/BeautifulPrompt/trlx/trainer/accelerate_ilql_trainer.py b/examples/BeautifulPrompt/trlx/trainer/accelerate_ilql_trainer.py new file mode 100644 index 0000000..e41d1f9 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/accelerate_ilql_trainer.py @@ -0,0 +1,250 @@ +import os +from typing import Union, cast + +import numpy as np +import torch +import transformers +from rich.console import Console +from rich.table import Table + +import trlx.utils.logging as logging +from trlx.data.configs import TRLConfig +from trlx.data.ilql_types import ILQLBatch, ILQLSeq2SeqBatch +from trlx.models.modeling_ilql import ( + AutoModelForCausalLMWithILQLHeads, + AutoModelForSeq2SeqLMWithILQLHeads, + ILQLConfig, +) +from trlx.pipeline.offline_pipeline import ( + ILQLRolloutStorage, + ILQLSeq2SeqRolloutStorage, + tokenize_dialogue, +) +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils import to_device + +logger = logging.get_logger(__name__) + + +def make_experience(samples, rewards, tokenizer=None, max_length=2048, verbose=True): # noqa: C901 + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + + if verbose: + logger.info("Collecting rollouts") + if tokenizer is not None: + samples = [tokenize_dialogue(s, tokenizer, max_length) for s in samples] + + all_input_ids = [] + all_actions_ixs = [] + all_states_ixs = [] + all_dones = [] + for sample in samples: + length = 0 + all_input_ids.append(torch.tensor(sum((s.tokens for s in sample), ()))) + actions_ixs = [] + for dm in sample: + if dm.is_output: + actions_ixs.append(torch.arange(length - 1, length + len(dm.tokens) - 1)) + + length += len(dm.tokens) + + states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1))) + all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int)) + all_actions_ixs.append(torch.hstack(actions_ixs)) + all_states_ixs.append(states_ixs) + + if tokenizer is not None and os.environ.get("RANK", "0") == "0" and verbose: + logger.info("Logging sample example") + prompt = tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) + response = tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) + columns = ["Prompt", "Response", "Reward"] + table = Table(*columns, title="Sample Example", show_lines=True) + table.add_row(prompt, response, str(rewards[0])) + Console().print(table) + + sample_lengths = np.array(list(map(len, all_input_ids))) + output_lengths = np.array(list(map(len, all_actions_ixs))) + prompt_lengths = sample_lengths - output_lengths + returns = torch.tensor(rewards, dtype=float) + + if os.environ.get("RANK", "0") == "0" and verbose: + logger.info("Logging experience string statistics") + columns = ["Prompt Length", "Output Length", "Sample Length"] + table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) + row = [] + for lengths in 
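# Worked example for the indexing built above (token counts assumed): a sample
# with a 3-token prompt followed by a 2-token output gives
#   actions_ixs = [2, 3]     # positions at which the output tokens are predicted
#   states_ixs  = [2, 3, 4]  # the action positions plus the final state
#   dones       = [1, 1, 0]  # the trailing 0 marks the end of the episode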
[prompt_lengths, output_lengths, sample_lengths]: + row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") + table.add_row(*row) + Console().print(table) + + returns = returns - returns.mean() + std_returns = returns.std() + if not torch.isnan(std_returns): + returns = returns / (std_returns + torch.finfo(returns.dtype).eps) + rewards = [torch.zeros(len(x)) for x in all_actions_ixs] + for rs, ret in zip(rewards, returns): + rs[-1] = ret + + attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids] + + return ILQLRolloutStorage( + all_input_ids, + attention_mask, + rewards, + all_states_ixs, + all_actions_ixs, + all_dones, + ) + + +@register_trainer +class AccelerateILQLTrainer(AccelerateRLTrainer): + def __init__(self, config: TRLConfig, **kwargs): + super().__init__(config, **kwargs) + + if not isinstance(config.method, ILQLConfig): + raise ValueError("config.method must be ILQLConfig") + + self.ilql: ILQLConfig = cast(ILQLConfig, config.method) + + self.generate_kwargs = dict( + config.method.gen_kwargs, + max_length=self.max_length, + logit_mask=self.logit_mask, + eos_token_id=self.tokenizer.eos_token_id if self.tokenizer else 0, + pad_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0, + ) + + def get_arch(self, config): + if config.model.model_arch_type == "seq2seq": + from_fn = AutoModelForSeq2SeqLMWithILQLHeads.from_pretrained + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = AutoModelForSeq2SeqLMWithILQLHeads.from_config + else: + from_fn = AutoModelForCausalLMWithILQLHeads.from_pretrained + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = AutoModelForCausalLMWithILQLHeads.from_config + return from_fn( + config.model.model_path, + two_qs=config.method.two_qs, + alpha=config.method.alpha, + ) + + def post_backward_callback(self): + if self.iter_count % self.config.method.steps_for_target_q_sync == 0: + self.accelerator.unwrap_model(self.model).sync_target_q_heads() + + def loss(self, batch: Union[ILQLBatch, ILQLSeq2SeqBatch]): + batch = to_device(batch, self.accelerator.device) + if self.config.model.model_arch_type == "seq2seq": + logits, qs, target_qs, vs, _, _ = self.model( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask, + actions_ixs=batch.actions_ixs, + states_ixs=batch.states_ixs, + decoder_input_ids=batch.decoder_input_ids, + ) + else: + logits, qs, target_qs, vs, _ = self.model( + input_ids=batch.input_ids, + attention_mask=batch.attention_mask, + actions_ixs=batch.actions_ixs, + states_ixs=batch.states_ixs, + ) + + return self.ilql.loss((logits, (qs, target_qs, vs)), batch) + + def prepare_learning(self): + train_dataloader = self.store.create_loader(self.config.train.batch_size) + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + + ( + self.model, + self.opt, + self.train_dataloader, + self.eval_dataloader, + ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) + + self.n_updates_per_batch = 1 + self.total_steps = self.config.train.epochs * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def make_experience_seq2seq(self, samples, rewards, max_length=2048): + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + logger.info("Collecting rollouts") + if self.tokenizer: + samples = [tokenize_dialogue(s, self.tokenizer, max_length) for s in 
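# Worked example for the return shaping in `make_experience` above (reward values
# assumed): rewards [1.0, 3.0] are centered to [-1.0, 1.0], divided by their std
# (~1.41) to give ~[-0.71, 0.71], and each value is written onto the last action
# position of its sample while every earlier per-token reward stays zero.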
samples] + + all_input_ids = [] + all_output_ids = [] + all_actions_ixs = [] + all_states_ixs = [] + all_dones = [] + for sample in samples: + all_input_ids.append(torch.tensor(sample[0].tokens)) + all_output_ids.append(torch.tensor(sample[1].tokens)) + actions_ixs = [] + length = 0 + for phrase in sample: + if phrase.is_output: + length = len(phrase.tokens) + actions_ixs.append(torch.arange(0, length - 1)) + states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1))) + all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int)) + all_actions_ixs.append(torch.hstack(actions_ixs)) + all_states_ixs.append(states_ixs) + + if self.tokenizer and os.environ.get("RANK", "0") == "0": + logger.info("Logging sample example") + prompt = self.tokenizer.decode(all_input_ids[0]) + response = self.tokenizer.decode(all_output_ids[0]) + columns = ["Prompt", "Response", "Reward"] + table = Table(*columns, title="Sample Example", show_lines=True) + table.add_row(prompt, response, str(rewards[0])) + Console().print(table) + + sample_lengths = np.array(list(map(len, all_input_ids))) + np.array(list(map(len, all_output_ids))) + output_lengths = np.array(list(map(len, all_output_ids))) + prompt_lengths = sample_lengths - output_lengths + returns = torch.tensor(rewards, dtype=float) + + if os.environ.get("RANK", "0") == "0": + logger.info("Logging experience string statistics") + columns = ["Prompt Length", "Output Length", "Sample Length"] + table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) + row = [] + for lengths in [prompt_lengths, output_lengths, sample_lengths]: + row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") + table.add_row(*row) + Console().print(table) + + returns = (returns - returns.mean()) / (returns.std() + torch.finfo(returns.dtype).eps) + rewards = [torch.zeros(len(x)) for x in all_actions_ixs] + for rs, ret in zip(rewards, returns): + rs[-1] = ret + + attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids] + self.store = ILQLSeq2SeqRolloutStorage( + all_input_ids, + attention_mask, + all_output_ids, + rewards, + all_states_ixs, + all_actions_ixs, + all_dones, + ) + + def make_experience(self, samples, rewards, max_length=2048): + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + + if self.config.model.model_arch_type == "seq2seq": + return self.make_experience_seq2seq(samples, rewards, max_length) + + self.store = make_experience(samples, rewards, self.tokenizer, max_length=max_length, verbose=True) diff --git a/examples/BeautifulPrompt/trlx/trainer/accelerate_ppo_trainer.py b/examples/BeautifulPrompt/trlx/trainer/accelerate_ppo_trainer.py new file mode 100644 index 0000000..3640568 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/accelerate_ppo_trainer.py @@ -0,0 +1,495 @@ +import json +import os +import uuid +from time import time +from typing import Callable, List + +import torch +import torch.nn.functional as F +import transformers +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +import trlx.utils.logging as logging +from trlx.data.accelerate_base_datatypes import PromptBatch +from trlx.data.configs import TRLConfig +from trlx.data.ppo_types import PPORLBatch, PPORLElement +from trlx.models.modeling_ppo import ( + AdaptiveKLController, + AutoModelForCausalLMWithHydraValueHead, + AutoModelForSeq2SeqLMWithHydraValueHead, + FixedKLController, +) +from 
trlx.pipeline.offline_pipeline import PromptPipeline +from trlx.pipeline.ppo_pipeline import PPORolloutStorage +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils import Clock, infinite_dataloader +from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_labels +from trlx.trainer.utils import PPODecorators + +logger = logging.get_logger(__name__) + +@register_trainer +class AcceleratePPOTrainer(AccelerateRLTrainer): + """PPO Accelerate Trainer""" + + reward_fn: Callable[[List[str], List[str], List[str]], List[float]] + tokenizer: AutoTokenizer + + def __init__(self, config: TRLConfig, **kwargs): + """PPO Accelerate Trainer initialization + + Args: + config: Config + """ + super().__init__(config, **kwargs) + + # Setup rollout logging + if config.train.rollout_logging_dir is not None: + self.log_rollouts = True + self.setup_rollout_logging(config) + else: + self.log_rollouts = False + + # Setup the rollout store + # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout + self.store = PPORolloutStorage(self.tokenizer.pad_token_id) + + # Create the rollout store dataloader (for batching up rollouts) + # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future + rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + # Prepare multi-GPU acceleration + self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare( + self.model, self.opt, self.scheduler, rollout_loader + ) + + self.store.clear_history() # Clear the rollout store + + # Setup a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head"): + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) + self.ref_model.eval() + + # Setup the KL controller + # This helps prevent large divergences in the controller (policy) + if config.method.target is not None: + self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon) + else: + self.kl_ctl = FixedKLController(config.method.init_kl_coef) + + # Create the parameters for the Hugging Face language model's generator + # method (that generates new tokens from a prompt). 
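# (Illustrative) the KL controller choice above is driven by the PPO method
# config; init_kl_coef/target/horizon are the fields read here, the values are
# assumptions:
#   method:
#     init_kl_coef: 0.05
#     target: 6            # set to null to fall back to FixedKLController
#     horizon: 10000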
+ # https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate + if config.model.model_arch_type == "seq2seq": + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + self.generate_experience_kwargs = None + else: + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + if config.method.gen_experience_kwargs is not None: + self.generate_experience_kwargs = dict( + config.method.gen_experience_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + ) + else: + self.generate_experience_kwargs = None + + # Setup stats tracker + self.running_moments = RunningMoments() + self.ref_mean = self.config.method.ref_mean + self.ref_std = self.config.method.ref_std + + def get_arch(self, config: TRLConfig): + """Get the model""" + model_class = AutoModelForCausalLMWithHydraValueHead + if config.model.model_arch_type == "seq2seq": + model_class = AutoModelForSeq2SeqLMWithHydraValueHead + + from_fn = model_class.from_pretrained + # backward-compat: Try to create a randomly initialized architecture from a config + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = model_class.from_config + + return from_fn( + config.model.model_path, + num_layers_unfrozen=config.model.num_layers_unfrozen, + ) + + def loss(self, batch: PPORLBatch): + """Forward pass & loss + + Args: + batch: Previous batch of episodes + """ + # Move `batch` data to `accelerator` device + query_tensors = batch.query_tensors.to(self.accelerator.device) + response_tensors = batch.response_tensors.to(self.accelerator.device) + old_logprobs = batch.logprobs.to(self.accelerator.device) + old_values = batch.values.to(self.accelerator.device) + old_rewards = batch.rewards.to(self.accelerator.device) + response_length = old_rewards.shape[1] + + advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length) + + if self.config.model.model_arch_type == "seq2seq": + input_ids = query_tensors + decoder_input_ids = response_tensors + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + decoder_attention_mask = ( + decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + ) + decoder_attention_mask[:, 0] = 1 + + # Forward pass + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + + logits = outputs.logits + values_pred = outputs.value + logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:]) + mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device) + start = 0 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + mask[:, start:end], + ) + else: + tokens = torch.cat((query_tensors, response_tensors), dim=1) + attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) + outputs = self.model(tokens, attention_mask, 
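# Worked example for the slicing below (lengths assumed): with a 5-token query
# and a 3-token response, `tokens` has length 8, `logprobs_of_labels` yields 7
# per-position log-probs, and start = 5 - 1 = 4, end = 4 + 3 = 7, so
# logprobs[:, 4:7] covers exactly the three response tokens.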
return_dict=True) + logits = outputs.logits + values_pred = outputs.value + values_pred = values_pred[:, :-1] + logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:]) + + start = query_tensors.shape[1] - 1 + end = start + response_length + logprobs, values_pred, mask = ( + logprobs[:, start:end], + values_pred[:, start:end], + attention_mask[:, start:end], + ) + + loss, stats = self.config.method.loss( + logprobs=logprobs, + values=values_pred, + old_logprobs=old_logprobs, + old_values=old_values, + advantages=advantages, + returns=returns, + mask=mask, + ) + + return loss, stats + + def setup_rollout_logging(self, config): + # Make rollout logging dir for this run and store config + exists = os.path.exists(config.train.rollout_logging_dir) + isdir = os.path.isdir(config.train.rollout_logging_dir) + assert exists and isdir + + self.run_id = f"run-{uuid.uuid4()}" + self.rollout_logging_dir = os.path.join(config.train.rollout_logging_dir, self.run_id) + os.mkdir(self.rollout_logging_dir) + + with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f: + f.write(json.dumps(config.to_dict(), indent=2)) + + def post_epoch_callback(self): + """Post epoch callback + + Clears the store and creates `num_rollouts` new episodes. + """ + if self.log_rollouts: + self.store.export_history(location=self.rollout_logging_dir) + self.store.clear_history() + # Collect more rollouts for training + self.make_experience(self.config.method.num_rollouts, self.iter_count) + + def post_backward_callback(self): + self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size) + + def prepare_learning(self): + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader) + self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=True) + + self.n_updates_per_batch = self.config.method.ppo_epochs + self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def add_prompt_pipeline(self, pipeline: PromptPipeline): + """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) + self.prompt_iterator = infinite_dataloader(prompt_dataloader) + + @PPODecorators.empty_cuda_cache() + def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa: + """Make experiences + + Takes `chunk_size` number of prompts from `prompt_iterator`, samples + from the model and then computes the KL against a reference model. Finally it + then appends PPOElements to trainer's `store`. + + Args: + num_rollouts: Number of rollouts to generate + iter_count: Total number of updates run (i.e. number of updates run for all batches & epochs) + """ + logger.info("Collecting rollouts") + tbar = logging.tqdm( + total=num_rollouts, + disable=os.environ.get("RANK", 0) != "0", + desc=f"[rollout 0 / {num_rollouts}]", + # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress + # bars (e.g. 
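# (Illustrative sizing, values assumed): with num_rollouts=1024 and
# method.chunk_size=64, the loop below runs 16 generation rounds on each process
# before enough PPORLElements have been collected for one training phase.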
loss progress in trainers) + position=logging.get_verbosity() >= logging.WARNING, + # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels + leave=logging.get_verbosity() < logging.WARNING, + ) + + clock = Clock() + ppo_rl_elements = [] + accumulated_stats = [] + + while len(ppo_rl_elements) < num_rollouts: + stats = {} + # Get next batch in prompt dataset + batch: PromptBatch = next(self.prompt_iterator) + + rollout_generate_time = time() + + # Generate samples from the language model (similar to using HuggingFace `generate` method) + samples = self.generate(batch["input_ids"], batch["attention_mask"]) + stats["time/rollout_generate"] = time() - rollout_generate_time + + prompt_tensors = batch.input_ids + device = samples.device + + prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) + padded_samples = self.accelerator.pad_across_processes( + samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + padded_prompts = self.accelerator.pad_across_processes( + prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False + ) + gathered_samples = self.accelerator.gather(padded_samples) + gathered_prompts = self.accelerator.gather(padded_prompts) + gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) + metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + + if self.accelerator.is_main_process: + all_str_samples, all_str_prompts, all_str_outputs = self.decode( + gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True + ) + + rollout_score_time = time() + all_scores = torch.tensor( + self.reward_fn( + samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata + ), + dtype=torch.float, + device=device, + ) + stats["time/rollout_score"] = time() - rollout_score_time + + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + else: + all_scores = None + + if torch.distributed.is_initialized(): + scores = torch.empty(len(samples), device=device) + torch.distributed.scatter(scores, all_scores) + else: + scores = all_scores[0].clone().detach() + + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + + # Pad the sample outputs + outputs = self.tokenizer(str_outputs).input_ids + if self.config.model.model_arch_type == "seq2seq": + # add to the start of the output + for i in range(len(outputs)): + outputs[i] = [self.tokenizer.pad_token_id] + outputs[i] + + outputs = list(map(torch.LongTensor, outputs)) + maxsize = max(map(len, outputs)) + outputs = [ + F.pad( + output, + (0, maxsize - len(output)), + value=self.tokenizer.pad_token_id, + ) + for output in outputs + ] + sample_outputs = torch.vstack(outputs).to(device) + + if self.config.method.cliprange_reward: + scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward) + + # store statistics of the initial rollout as reference + if self.ref_mean is None: + self.ref_mean, self.ref_std = scores.mean(), scores.std() + all_scores_mean, all_scores_std = self.running_moments.update(scores) + stats["rollout_scores/mean"] = all_scores_mean.item() + stats["rollout_scores/std"] = all_scores_std.item() + stats["rollout_scores/running_mean"] = self.running_moments.mean.item() + stats["rollout_scores/running_std"] = self.running_moments.std.item() + + if self.config.method.scale_reward == "running": + scores /= 
self.running_moments.std + elif self.config.method.scale_reward == "ref": + scores /= self.ref_std + + # Precompute logprobs, values + if self.config.model.model_arch_type == "seq2seq": + attention_mask = batch.attention_mask.to(device) + prompt_tensors = batch.input_ids.to(device) + decoder_attention_mask = sample_outputs.not_equal(self.tokenizer.pad_token_id) + decoder_attention_mask[:, 0] = 1 + with torch.no_grad(): + outputs = self.model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + ) + logits = outputs.logits + values = outputs.value + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + input_ids=prompt_tensors, + attention_mask=attention_mask, + decoder_input_ids=sample_outputs, + decoder_attention_mask=decoder_attention_mask, + return_dict=True, + ).logits + else: + all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) + attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens, + attention_mask=attention_mask, + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + else: + ref_logits = self.ref_model( + all_tokens, + attention_mask=attention_mask, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) + else: + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) + + n_samples: int = samples.shape[0] + + # Estimate the KL divergence between the model and reference model + if self.config.model.model_arch_type == "seq2seq": + attention_mask = sample_outputs != self.tokenizer.pad_token_id + start = 0 + else: + start = prompt_tensors.shape[1] - 1 + + log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] + kl = log_ratio.exp() - 1 - log_ratio + mean_kl_per_token = kl.mean() + mean_kl = kl.sum(1).mean() + + logprobs = logprobs.cpu() + ref_logprobs = ref_logprobs.cpu() + prompt_tensors = prompt_tensors.cpu() + sample_outputs = sample_outputs.cpu() + values = values.cpu()[:, :-1] + + # Get the logprobs and values, for tokens that are not padding, + # from the start of the prompt up to the token, while also including the latter + # (these are taken from the student model and not the reference model) + ends = start + attention_mask[:, start:].sum(1) + 1 + all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] + all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] + + kl_penalty = self.kl_ctl.value * -log_ratio.cpu() + kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)] + + rollout_count = 0 + + for sample_idx in range(n_samples): + rewards = kl_penalty[sample_idx] + rewards[-1] += scores[sample_idx].cpu() + + ppo_rl_elements.append( + PPORLElement( + 
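# (Illustrative) each element built here carries per-token rewards equal to the
# KL penalty -kl_ctl.value * log_ratio over the response, with the scalar score
# added onto the final token only; e.g. (assumed numbers) a score of 0.8 and
# penalties [-0.01, -0.02, -0.03] give rewards [-0.01, -0.02, 0.77].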
query_tensor=prompt_tensors[sample_idx], + response_tensor=sample_outputs[sample_idx], + logprobs=all_logprobs[sample_idx], + values=all_values[sample_idx], + rewards=rewards, + ) + ) + + rollout_count += 1 + + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG) + + stats["time/rollout_time"] = clock.tick() + stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item() + stats["policy/kl_per_token"] = torch.sqrt(mean_kl_per_token).item() + accumulated_stats.append(stats) + + tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") + tbar.update(min(rollout_count, num_rollouts)) + tbar.close() + + stats = {k: sum([xs[k] for xs in accumulated_stats]) / len(accumulated_stats) for k in stats} + stats["kl_ctl_value"] = self.kl_ctl.value + self.mean_kl = stats["policy/sqrt_kl"] ** 2 + self.accelerator.log(stats, step=iter_count) + + # Push samples and rewards to trainer's rollout storage + self.push_to_store(ppo_rl_elements) diff --git a/examples/BeautifulPrompt/trlx/trainer/accelerate_sft_trainer.py b/examples/BeautifulPrompt/trlx/trainer/accelerate_sft_trainer.py new file mode 100644 index 0000000..cb471ca --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/accelerate_sft_trainer.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass + +from transformers import AutoModelForCausalLM + +from trlx.data.configs import TRLConfig +from trlx.data.method_configs import MethodConfig, register_method +from trlx.pipeline.offline_pipeline import ( + DialogStore, + PromptPipeline, + tokenize_dialogue, +) +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer + + +@dataclass +@register_method +class SFTConfig(MethodConfig): + """ + Config for SFT training + + :param gen_kwargs: kwargs for generation + :type gen_kwargs: Dict[str, Any] + """ + + gen_kwargs: dict + + +@register_trainer +class AccelerateSFTTrainer(AccelerateRLTrainer): + def __init__(self, config: TRLConfig, **kwargs): + super().__init__(config, **kwargs) + + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + def get_arch(self, config): + return AutoModelForCausalLM.from_pretrained(config.model.model_path) + + def loss(self, batch): + if "labels" in batch: + labels = batch.labels.clone() + else: + labels = batch.input_ids.clone() + labels[~batch.attention_mask.bool()] = -100 + + loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss + stats = {"loss": loss.item()} + + return loss, stats + + def prepare_learning(self): + train_dataloader = self.store.create_loader(self.config.train.batch_size) + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + + ( + self.model, + self.opt, + self.train_dataloader, + self.eval_dataloader, + ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) + + self.n_updates_per_batch = 1 + self.total_steps = self.config.train.epochs * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def make_experience(self, samples, seq_length): + if isinstance(samples[0], str): + self.store = PromptPipeline(samples, seq_length, self.tokenizer) + else: + dialogs = [tokenize_dialogue(d, self.tokenizer, seq_length) for d in samples] + self.store = DialogStore(dialogs, self.tokenizer) diff --git 
a/examples/BeautifulPrompt/trlx/trainer/nemo_ilql_trainer.py b/examples/BeautifulPrompt/trlx/trainer/nemo_ilql_trainer.py new file mode 100644 index 0000000..099c976 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/nemo_ilql_trainer.py @@ -0,0 +1,204 @@ +from pathlib import Path +from typing import Iterable, Sequence, cast + +import torch +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + PipelineMixedPrecisionPlugin, +) +from nemo.utils import get_rank, logging +from nemo.utils.exp_manager import StatelessTimer, exp_manager +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks.timer import Timer +from pytorch_lightning.trainer.connectors.checkpoint_connector import ( + CheckpointConnector, +) + +from trlx.data.configs import TRLConfig +from trlx.data.ilql_types import ILQLBatch, ILQLElement, flatten_dataclass +from trlx.models.modeling_ilql import ILQLConfig +from trlx.models.modeling_nemo_ilql import ILQLGPT +from trlx.pipeline.offline_pipeline import ILQLRolloutStorage, ilql_collate_fn +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_ilql_trainer import make_experience + +from . import BaseRLTrainer + + +def megatron_trainer(cfg): + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + seed_everything(cfg.model.get("seed", 1000)) + + megatron_amp_o2 = cfg.model.get("megatron_amp_O2", False) + with_distributed_adam = cfg.model.optim.get("name") == "distributed_fused_adam" + + plugins = [] + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, # we don't use DDP for async grad allreduce + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + if cfg.trainer.precision in [16, "bf16"]: + scaler = None + if cfg.trainer.precision == 16: + scaler = GradScaler( + init_scale=cfg.model.get("native_amp_init_scale", 2**32), + growth_interval=cfg.model.get("native_amp_growth_interval", 1000), + hysteresis=cfg.model.get("hysteresis", 2), + ) + if megatron_amp_o2 and not with_distributed_adam: + plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device="cuda", scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device="cuda", scaler=scaler)) + + trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) + try: + exp_manager(trainer, cfg.exp_manager) + except FileNotFoundError: + print(f"exp_manager failed to find git-rev, continuing anyway, {FileNotFoundError}") + # update resume from checkpoint found by exp_manager + if cfg.model.resume_from_checkpoint is not None: + resume_from_checkpoint = cfg.model.resume_from_checkpoint + else: + resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path + + logging.info(f"Resuming training from checkpoint: {resume_from_checkpoint}") + + trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint) + # Override timer callback to a stateless one + for idx, callback in enumerate(trainer.callbacks): + if isinstance(callback, Timer): + trainer.callbacks[idx] = StatelessTimer( + cfg.trainer.max_time, + ) + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + return trainer + + +class 
ShuffledCyclicSequence: + def __init__(self, new_length: int, data: Sequence, seed: int): + self.data = data + self.new_length = new_length + + rng = torch.Generator().manual_seed(seed) + self.perm = torch.randperm(new_length, generator=rng, device="cpu") + + def __len__(self): + return self.new_length + + def __getitem__(self, idx): + permuted_idx = self.perm[idx].item() + return self.data[permuted_idx % len(self.data)] + + +@register_trainer +class NeMoILQLTrainer(BaseRLTrainer): + store: ILQLRolloutStorage + + def __init__( + self, + config: TRLConfig, + reward_fn=None, + logit_mask=None, + metric_fn=None, + stop_sequences=None, + train_mode=True, + megatron_cfg=None, + pretrained_model=None, + ): + super().__init__(config, train_mode) + self.logit_mask = logit_mask + self.metric_fn = metric_fn + self.reward_fn = None + + if not isinstance(config.method, ILQLConfig): + raise ValueError("config.method must be ILQLConfig") + + self.ilql_config: ILQLConfig = cast(ILQLConfig, config.method) + if isinstance(megatron_cfg, str): + cfg_path = Path(__file__).parent.parent.parent / "configs" / "nemo_configs" / megatron_cfg + logging.info(f"Loading NeMo config from {cfg_path=}") + megatron_cfg = OmegaConf.load(cfg_path) + + elif megatron_cfg is None: + raise ValueError("megatron_cfg must be a path or a config") + + self.trainer = megatron_trainer(megatron_cfg) + self.model = ILQLGPT( + ilql_config=self.ilql_config, + metric_fn=self.metric_fn, + cfg=megatron_cfg.model, + trainer=self.trainer, + ) + + if pretrained_model is not None: + self.model.load_from_pretrained(pretrained_model) + + self.batch_size = megatron_cfg.model.global_batch_size + self.tokenizer = self.model.tokenizer.tokenizer + self.tokenizer.pad_token = self.tokenizer.eos_token + self.max_length = megatron_cfg.model.encoder_seq_length + + self.tokenizer.truncation_side = config.tokenizer.truncation_side + + if stop_sequences is not None and len(stop_sequences) > 0: + logging.warning(f"Ignoring stop_sequences {stop_sequences=}") + + def learn(self): + def collate_fn(elems: Iterable[ILQLElement]): + batch = ilql_collate_fn(elems) + return flatten_dataclass(ILQLBatch)(batch) + + train_samples = self.model.cfg.global_batch_size * self.trainer.max_steps + train_dataset = ShuffledCyclicSequence(train_samples, self.store, self.config.train.seed) + self.model.set_train_dataset(train_dataset, collate_fn=collate_fn) + + def add_bos_if_not_present(x): + if len(x) == 0: + return [self.tokenizer.bos_token_id] + elif x[0] != self.tokenizer.bos_token_id: + return [self.tokenizer.bos_token_id] + x + else: + return x + + def eval_collate(elems): + context_tokens = [e["input_ids"] for e in elems] + context_tokens = [add_bos_if_not_present(x) for x in context_tokens] + + max_new_tokens = self.ilql_config.gen_kwargs.get("max_new_tokens", 64) + + context_lengths = [len(x) for x in context_tokens] + max_context = max(context_lengths) + + pad_id = self.tokenizer.eos_token_id + padded = [x + [pad_id] * (max_context + max_new_tokens - len(x)) for x in context_tokens] + + return [ + torch.as_tensor(padded, device="cpu"), + torch.as_tensor(context_lengths, device="cpu"), + ] + + max_train_steps = self.trainer.max_steps + eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches + eval_samples = eval_iters * self.model.cfg.global_batch_size + + eval_dataset = ShuffledCyclicSequence(eval_samples, self.eval_pipeline, self.config.train.seed) + + self.model.set_valid_dataset(eval_dataset, collate_fn=eval_collate) + + 
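# Allow lower-precision float32 matmul kernels (e.g. TF32 / bfloat16) for faster
+        # training on recent GPUs, at a small cost in numerical accuracy.
+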
torch.set_float32_matmul_precision("medium") + self.trainer.fit(self.model) + + def make_experience(self, samples, rewards, max_length=2048): + """ + Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer + """ + verbose = get_rank.is_global_rank_zero() + self.store = make_experience(samples, rewards, self.tokenizer, max_length, verbose) diff --git a/examples/BeautifulPrompt/trlx/trainer/nemo_sft_trainer.py b/examples/BeautifulPrompt/trlx/trainer/nemo_sft_trainer.py new file mode 100644 index 0000000..8fa327c --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/nemo_sft_trainer.py @@ -0,0 +1,133 @@ +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union, cast + +import torch +import transformers +from nemo.utils import logging +from omegaconf.omegaconf import OmegaConf + +from trlx.data.configs import TRLConfig +from trlx.models.modeling_nemo_sft import SFTGPT +from trlx.trainer import BaseRLTrainer, register_trainer +from trlx.trainer.accelerate_sft_trainer import SFTConfig +from trlx.trainer.nemo_ilql_trainer import ShuffledCyclicSequence, megatron_trainer + + +@register_trainer +class NeMoSFTTrainer(BaseRLTrainer): + def __init__( + self, + config: TRLConfig, + metric_fn: Optional[Callable[[List[str]], Any]] = None, + megatron_cfg: Optional[Union[str, dict]] = None, + pretrained_model: Optional[str] = None, + **kwargs, + ): + super().__init__(config, metric_fn=metric_fn, **kwargs) + + if not isinstance(config.method, SFTConfig): + raise ValueError("config.method must be SFTConfig") + + self.sft_config: SFTConfig = cast(SFTConfig, config.method) + if isinstance(megatron_cfg, str): + cfg_path = Path(__file__).parent.parent.parent / "configs" / "nemo_configs" / megatron_cfg + logging.info(f"Loading NeMo config from {cfg_path=}") + megatron_cfg = OmegaConf.load(cfg_path) + elif megatron_cfg is None: + raise ValueError("megatron_cfg must be a path or a config") + + self.trainer = megatron_trainer(megatron_cfg) + self.model = SFTGPT( + sft_config=self.sft_config, + cfg=megatron_cfg.model, + trainer=self.trainer, + metric_fn=self.metric_fn, + ) + + if pretrained_model is not None: + self.model.load_from_pretrained(pretrained_model) + + self.batch_size = megatron_cfg.model.global_batch_size + self.tokenizer = self.model.tokenizer.tokenizer + self.tokenizer.truncation_side = config.tokenizer.truncation_side + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + self.tokenizer.pad_token = self.tokenizer.eos_token + self.max_length = megatron_cfg.model.encoder_seq_length + + def learn(self): + def add_special_token_ids(input_ids: List[int], add_bos: bool, add_eos: bool): + if add_bos: + input_ids = [self.tokenizer.bos_token_id] + input_ids + if add_eos: + input_ids = input_ids + [self.tokenizer.eos_token_id] + if len(input_ids) > self.max_length: + input_ids = input_ids[: self.max_length] + return input_ids + + def pad_batch_and_build_loss_mask( + input_ids: List[List[int]], batch_max_length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_loss_masks = [] + padded_input_ids = [] + for ids in input_ids: + input_length = len(ids) + padding_length = batch_max_length - input_length + padded_input_ids.append(ids + [self.tokenizer.pad_token_id] * padding_length) + loss_mask = [1.0] * input_length + [0.0] * padding_length + batch_loss_masks.append(torch.tensor(loss_mask, dtype=torch.float)) + padded_input_ids = 
torch.as_tensor(padded_input_ids, dtype=torch.long) + batch_loss_masks = torch.stack(batch_loss_masks, dim=0) + # NOTE: Un-build the loss mask if we're not going to mask eod tokens + if self.model.cfg.data.get("eod_mask_loss", False) is False: + loss_mask = torch.ones_like(loss_mask) + return padded_input_ids, batch_loss_masks + + def collate_fn(elems: List[transformers.BatchEncoding]): + context_tokens = [ + add_special_token_ids( + e["input_ids"], + self.model.cfg.data.get("add_bos", False), + self.model.cfg.data.get("add_eos", True), + ) + for e in elems + ] + input_ids, loss_mask = pad_batch_and_build_loss_mask(context_tokens, self.max_length) + return input_ids, loss_mask + + train_samples = self.model.cfg.global_batch_size * self.trainer.max_steps + train_dataset = ShuffledCyclicSequence(train_samples, self.store, self.config.train.seed) + self.model.set_train_dataset(train_dataset, collate_fn=collate_fn) + + def eval_collate(elems): + context_tokens = [ + add_special_token_ids(e["input_ids"], add_bos=self.model.cfg.data.get("add_bos", False), add_eos=False) + for e in elems + ] + max_new_tokens = self.sft_config.gen_kwargs.get("max_new_tokens", 64) + + context_lengths = [len(x) for x in context_tokens] + max_context = max(context_lengths) + + pad_id = self.tokenizer.eos_token_id + padded = [x + [pad_id] * (max_context + max_new_tokens - len(x)) for x in context_tokens] + + return [ + torch.as_tensor(padded, device="cpu"), + torch.as_tensor(context_lengths, device="cpu"), + ] + + max_train_steps = self.trainer.max_steps + eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches + eval_samples = eval_iters * self.model.cfg.global_batch_size + + eval_dataset = ShuffledCyclicSequence( + new_length=eval_samples, + data=self.eval_pipeline, + seed=self.config.train.seed, + ) + + self.model.set_valid_dataset(eval_dataset, collate_fn=eval_collate) + + torch.set_float32_matmul_precision("medium") + self.trainer.fit(self.model) diff --git a/examples/BeautifulPrompt/trlx/trainer/utils.py b/examples/BeautifulPrompt/trlx/trainer/utils.py new file mode 100644 index 0000000..2f3f6ac --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trainer/utils.py @@ -0,0 +1,16 @@ +from contextlib import contextmanager +import gc + +import torch + +class PPODecorators(object): + optimize_cuda_cache = False + + @classmethod + @contextmanager + def empty_cuda_cache(cls): + yield + if cls.optimize_cuda_cache and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() diff --git a/examples/BeautifulPrompt/trlx/trlx.py b/examples/BeautifulPrompt/trlx/trlx.py new file mode 100644 index 0000000..a97a674 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/trlx.py @@ -0,0 +1,125 @@ +import os +import warnings +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +from trlx.data.configs import TRLConfig +from trlx.data.default_configs import ( + default_ilql_config, + default_ppo_config, + default_sft_config, +) +from trlx.utils import set_seed +from trlx.utils.loading import get_pipeline, get_trainer + + +def train( # noqa: C901 + model_path: Optional[str] = None, + reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None, + dataset: Optional[Iterable[Tuple[str, float]]] = None, + samples: Optional[List[str]] = None, + rewards: Optional[List[float]] = None, + prompts: Optional[List[str]] = None, + eval_prompts: Optional[List[str]] = None, + metric_fn: Optional[Callable[[List[str], List[str], List[str]], 
Dict[str, List[float]]]] = None, + config: Optional[TRLConfig] = None, + stop_sequences: Optional[List[str]] = [], +): + """ + Dispatches online, offline reinforcement training or supervised finetuning + depending on whether a reward function or a list of samples & rewards, or only list of samples is given + + Args: + model_path (Optional[str]): Path to either huggingface checkpoint or a local directory + config (Optional[TRLConfig]): TRLX configuration object + reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]): + Function to rate batches of generated samples. Its arguments are + (`samples`, `prompts`, `outputs`) and the return is a list of `rewards` + dataset (List[Union[str, List[str]]], List[float]): + Lists of samples and rewards for offline training. (Use `samples` and `rewards` instead) + samples (List[Union[str, List[str]]]): + List of strings or a list of prompts (questions or environment states) and outputs which are + meant to be optimized. In the latter case the following form is expected: + (prompt_0: str, output_0: str, prompt_1: str, output_1: str ...). + Giving a single string `s` for the sample is a shorthand for (`tokenizer.bos_token`, `s`) + rewards (List[float]): + List of real numbers measuring the goodness of each sample + prompts (`List[str]` or `List[Dict[str, Any]]`): Prompts to use for generations during online training. + If a dict is passed as prompt, it must have a required key `"prompt"`, all the extra keys would be + passed along the generation for that prompt as a keyword argument to reward function. + eval_prompts (List[str] or `List[Dict[str, Any]]`): Prompts to use for periodical validation of training + metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]): + Function to compute statistics on batches of generated samples. Its arguments are the same + as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is dictionary with keys + as metric's name and values and lists of numeric values per each sample in batch + stop_sequences (Optional[List[str]]): + String sequences to trim generations (both for generating of experience and evaluation) up to its + encounter in them. Generations will not contain them and also will also be right-stripped + """ + if config is None: + warnings.warn( + "Passing the `config` argument implicitly is depreciated, use or" + "adapt some from `trlx/data/default_configs.py` instead" + ) + if reward_fn: + config = default_ppo_config() + elif rewards: + config = default_ilql_config() + else: + config = default_sft_config() + + set_seed(config.train.seed) + + if dataset: + warnings.warn("the `dataset` argument is being depreciated, split it into `samples` and `rewards` instead") + samples, rewards = dataset + + if model_path: + config.model.model_path = model_path + + trainer = get_trainer(config.train.trainer)( + config=config, + reward_fn=reward_fn, + metric_fn=metric_fn, + stop_sequences=stop_sequences, + **config.train.trainer_kwargs, + ) + + batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1)) + max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] + + # Online training against a reward function (e.g. 
PPO) + if reward_fn: + prompts = prompts or [trainer.tokenizer.bos_token] * batch_size + + if eval_prompts is None: + eval_prompts = prompts[:batch_size] + + pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer) + trainer.add_prompt_pipeline(pipeline) + + if eval_prompts is None: + eval_prompts = prompts[:batch_size] + + trainer.make_experience(config.method.num_rollouts) + + # Offline training from the collected samples (e.g. SFT, ILQL) + elif samples: + if rewards is not None: + if len(samples) != len(rewards): + raise ValueError(f"Number of samples {len(samples)} should match the number of rewards {len(rewards)}") + + if eval_prompts is None: + eval_prompts = [trainer.tokenizer.bos_token] * batch_size + + if rewards is not None: + trainer.make_experience(samples, rewards, config.train.seq_length) + else: + trainer.make_experience(samples, config.train.seq_length) + else: + raise ValueError("Either `samples` or `reward_fn` should be given for training") + + eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer) + trainer.add_eval_pipeline(eval_pipeline) + + trainer.learn() + return trainer diff --git a/examples/BeautifulPrompt/trlx/utils/__init__.py b/examples/BeautifulPrompt/trlx/utils/__init__.py new file mode 100644 index 0000000..abc4d54 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/utils/__init__.py @@ -0,0 +1,240 @@ +import math +import os +import random +import subprocess +import time +from dataclasses import is_dataclass +from enum import Enum +from itertools import repeat +from numbers import Number +from typing import Any, Dict, Iterable, Tuple + +import numpy as np +import torch +from accelerate import Accelerator +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR + + +def print_rank_0(*message): + """ + Print only once from the main rank + """ + if os.environ.get("RANK", "0") == "0": + print(*message) + + +def significant(x: Number, ndigits=2) -> Number: + """ + Cut the number up to its `ndigits` after the most significant + """ + if isinstance(x, torch.Tensor): + x = x.item() + + if not isinstance(x, Number) or math.isnan(x) or x == 0: + return x + + return round(x, ndigits - int(math.floor(math.log10(abs(x))))) + + +def set_seed(seed: int): + """ + Sets seeds across package dependencies for reproducibility. 
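+
+    Note: the seed is offset by the process rank (the ``RANK`` env var), so each
+    distributed worker draws from a different random stream, e.g. rank 1
+    effectively uses ``seed + 1``.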
+ """ + seed += int(os.environ.get("RANK", 0)) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +# Training utils + + +def get_distributed_config(accelerator: Accelerator): + """ + Return accelerator distributed config + """ + + dist_config = { + "mixed_precision": accelerator.mixed_precision, + "num_gpus": accelerator.num_processes, + } + + if accelerator.state.deepspeed_plugin is not None: + ds_plugin = accelerator.state.deepspeed_plugin + dist_config.update( + { + "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, + "gradient_clipping": ds_plugin.gradient_clipping, + "zero_stage": ds_plugin.zero_stage, + "offload_optimizer_device": ds_plugin.offload_optimizer_device, + "offload_param_device": ds_plugin.offload_param_device, + } + ) + + return dist_config + + +class OptimizerName(str, Enum): + """Supported optimizer names""" + + ADAM: str = "adam" + ADAMW: str = "adamw" + ADAM_8BIT_BNB: str = "adam_8bit_bnb" + ADAMW_8BIT_BNB: str = "adamw_8bit_bnb" + SGD: str = "sgd" + + +def get_optimizer_class(name: OptimizerName): + """ + Returns the optimizer class with the given name + + Args: + name (str): Name of the optimizer as found in `OptimizerNames` + """ + if name == OptimizerName.ADAM: + return torch.optim.Adam + if name == OptimizerName.ADAMW: + return torch.optim.AdamW + if name == OptimizerName.ADAM_8BIT_BNB.value: + try: + from bitsandbytes.optim import Adam8bit + + return Adam8bit + except ImportError: + raise ImportError( + "You must install the `bitsandbytes` package to use the 8-bit Adam. " + "Install with: `pip install bitsandbytes`" + ) + if name == OptimizerName.ADAMW_8BIT_BNB.value: + try: + from bitsandbytes.optim import AdamW8bit + + return AdamW8bit + except ImportError: + raise ImportError( + "You must install the `bitsandbytes` package to use 8-bit AdamW. " + "Install with: `pip install bitsandbytes`" + ) + if name == OptimizerName.SGD.value: + return torch.optim.SGD + supported_optimizers = [o.value for o in OptimizerName] + raise ValueError(f"`{name}` is not a supported optimizer. " f"Supported optimizers are: {supported_optimizers}") + + +class SchedulerName(str, Enum): + """Supported scheduler names""" + + COSINE_ANNEALING = "cosine_annealing" + LINEAR = "linear" + + +def get_scheduler_class(name: SchedulerName): + """ + Returns the scheduler class with the given name + """ + if name == SchedulerName.COSINE_ANNEALING: + return CosineAnnealingLR + if name == SchedulerName.LINEAR: + return LinearLR + supported_schedulers = [s.value for s in SchedulerName] + raise ValueError(f"`{name}` is not a supported scheduler. " f"Supported schedulers are: {supported_schedulers}") + + +class Clock: + """ + Helper object for keeping track of time for computations. + """ + + def __init__(self): + self.start = time.time() + self.total_time = 0 + self.total_samples = 0 + + def tick(self, samples: int = 0) -> float: + """ + Returns time (s) since last call to tick(). Also records samples processed since last call. + + :param samples: number of samples that have been processed since last call + """ + end = time.time() + delta = end - self.start + self.start = end + + if samples != 0: + self.total_time += delta + self.total_samples += samples + + return delta + + def get_stat(self, n_samp: int = 1000, reset: bool = False): + """ + Returns average time (s) per n_samp samples processed + + :param reset: Reset counts? 
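+        :param n_samp: number of samples to scale the average to (default: 1000)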
+ """ + sec_per_samp = self.total_time / self.total_samples + + if reset: + self.total_samples = 0 + self.total_time = 0 + + return sec_per_samp * n_samp + + +def tree_map(f, tree: Any) -> Any: + """ + Apply function f to all leaves in tree + """ + if is_dataclass(tree): + return tree.__class__(**{k: tree_map(f, v) for k, v in tree.__dict__.items()}) + elif isinstance(tree, dict): + return {k: tree_map(f, v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return tree.__class__(tree_map(f, v) for v in tree) + else: + return f(tree) + + +def to_device(tree, device, non_blocking=False): + """ + Move all tensors in tree to device + """ + return tree_map(lambda x: x.to(device, non_blocking=non_blocking), tree) + + +def filter_non_scalars(xs: Dict) -> Dict: + """ + Trims everything that can't be casted to float + """ + ys = {} + for k, v in xs.items(): + try: + ys[k] = float(v) + except TypeError: + continue + + return ys + + +def get_git_tag() -> Tuple[str, str]: + """ + Returns commit's short hash and date + """ + try: + output = subprocess.check_output("git log --format='%h/%as' -n1".split()) + branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD".split()) + return branch.decode()[:-1], output.decode()[1:-2] + except subprocess.CalledProcessError: + return "unknown", "unknown" + + +# Iter utils + + +def infinite_dataloader(dataloader: Iterable) -> Iterable: + """ + Returns a cyclic infinite dataloader from a finite dataloader + """ + for _ in repeat(dataloader): + yield from dataloader diff --git a/examples/BeautifulPrompt/trlx/utils/loading.py b/examples/BeautifulPrompt/trlx/utils/loading.py new file mode 100644 index 0000000..8f1722a --- /dev/null +++ b/examples/BeautifulPrompt/trlx/utils/loading.py @@ -0,0 +1,48 @@ +from typing import Callable, List + +# Register load pipelines via module import +from trlx.pipeline import _DATAPIPELINE +from trlx.pipeline.offline_pipeline import PromptPipeline + +# Register load trainers via module import +from trlx.trainer import _TRAINERS, register_trainer +from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer +from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer +from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer + +try: + from trlx.trainer.nemo_ilql_trainer import NeMoILQLTrainer + from trlx.trainer.nemo_sft_trainer import NeMoSFTTrainer +except ImportError: + # NeMo is not installed + def _trainers_unavailble(names: List[str]): + def log_error(*args, **kwargs): + raise ImportError("NeMo is not installed. 
Please install `nemo_toolkit` to use NeMo-based trainers.") + + # Register dummy trainers + for name in names: + register_trainer(name)(log_error) + + _trainers_unavailble(["NeMoILQLTrainer", "NeMoSFTTrainer"]) + + +def get_trainer(name: str) -> Callable: + """ + Return constructor for specified RL model trainer + """ + name = name.lower() + if name in _TRAINERS: + return _TRAINERS[name] + else: + raise Exception("Error: Trying to access a trainer that has not been registered") + + +def get_pipeline(name: str) -> Callable: + """ + Return constructor for specified pipeline + """ + name = name.lower() + if name in _DATAPIPELINE: + return _DATAPIPELINE[name] + else: + raise Exception("Error: Trying to access a pipeline that has not been registered") diff --git a/examples/BeautifulPrompt/trlx/utils/logging.py b/examples/BeautifulPrompt/trlx/utils/logging.py new file mode 100644 index 0000000..79badb4 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/utils/logging.py @@ -0,0 +1,340 @@ +# Copyright 2023 Optuna, Hugging Face, CarperAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +import torch +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.INFO + + +def _get_default_logging_level(): + """ + If `TRLX_VERBOSITY` env var is set to one of the valid choices, return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("TRLX_VERBOSITY", None) + if env_level_str: + if env_level_str.lower() in log_levels: + return log_levels[env_level_str.lower()] + else: + logging.getLogger().warning( + f"Unknown option TRLX_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. 
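+        # The default handler streams to stderr; propagation is disabled below so
+        # records are not emitted twice via the application's root logger.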
+ library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +class MultiProcessAdapter(logging.LoggerAdapter): + """A logger adapter for handling multi-process logging""" + + def log(self, level, msg, *args, **kwargs): + """ + Consumes an additional kwarg called `ranks` to determine which processes should log. + NOTE: To specify all processes, pass in an empty list `ranks=[]` + + Default: ["0"], i.e. only the main process logs + """ + # By default, silence all non-main processes + ranks = kwargs.pop("ranks", ["0"]) + should_log = os.environ.get("RANK", "0") in ranks or len(ranks) == 0 + if self.isEnabledFor(level) and should_log: + msg, kwargs = self.process(msg, kwargs) + self.logger._log(level, msg, args, **kwargs) + + def process(self, msg, kwargs): + this_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + return f"[RANK {this_rank}] {msg}", kwargs + + +def get_logger(name: Optional[str] = None) -> MultiProcessAdapter: + """ + Returns a `logging.Logger` for `name` that can handle multiple processes + + Args: + name: Name of the logger + + Usage: + >> logger = get_logger(__name__) + >> logger.debug("Check the...", ranks=["0", "1"]) # Only main and rank 1 log + """ + if name is None: + name = _get_library_name() + _configure_library_root_logger() + logger = logging.getLogger(name) + return MultiProcessAdapter(logger, {}) + + +def get_verbosity() -> int: + """ + Return the current level for trlx's root logger as an int. + Returns: + `int`: The logging level. + + trlx has following logging levels: + - 50: `trlx.logging.CRITICAL` or `trlx.logging.FATAL` + - 40: `trlx.logging.ERROR` + - 30: `trlx.logging.WARNING` or `trlx.logging.WARN` + - 20: `trlx.logging.INFO` + - 10: `trlx.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for trlX's root logger. 
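+
+    Example: ``set_verbosity(trlx.logging.INFO)`` enables info-level messages.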
+ Args: + verbosity (`int`): + Logging level, e.g., one of: + - `trlx.logging.CRITICAL` or `trlx.logging.FATAL` + - `trlx.logging.ERROR` + - `trlx.logging.WARNING` or `trlx.logging.WARN` + - `trlx.logging.INFO` + - `trlx.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def disable_default_handler() -> None: + """Disable the default handler of trlx's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of trlx's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """Adds a handler to trlx's root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """Removes given handler from the trlx's root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the trlx's default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every trlx's logger. The explicit formatter is as follows: + ``` + [ASCTIME] [LEVELNAME] [FILENAME:LINE NUMBER:FUNCNAME] MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter( + "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s" + ) + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for trlx's loggers. + All handlers currently bound to the root logger are affected by this method. 
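+    Afterwards each handler falls back to the ``logging`` module's default formatting.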
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var TRLX_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("TRLX_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +_tqdm_active = True + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False diff --git a/examples/BeautifulPrompt/trlx/utils/modeling.py b/examples/BeautifulPrompt/trlx/utils/modeling.py new file mode 100644 index 0000000..00528d2 --- /dev/null +++ b/examples/BeautifulPrompt/trlx/utils/modeling.py @@ -0,0 +1,552 @@ +import functools +from typing import Any, Dict, List, MutableMapping, Tuple, Union + +import accelerate +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import transformers + +try: + from opendelta import ( + AdapterModel, + BitFitModel, + LoraModel, + PrefixModel, + SoftPromptModel, + ) + + HAS_OPENDELTA = True +except ModuleNotFoundError: + HAS_OPENDELTA = False + + +def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential: + """Returns a generic sequential MLP head.""" + return nn.Sequential( + nn.Linear(n_embd, n_embd * 2, dtype=dtype), + nn.ReLU(), + nn.Linear(n_embd * 2, out, dtype=dtype), + ) + + +def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0): + """Freezes the bottom transformer block layers of the specified model.""" + hidden_layers = hf_get_decoder_blocks(model) + if num_layers_unfrozen == 0: + hidden_layers_to_freeze = list(hidden_layers) + elif num_layers_unfrozen > 0: + hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen] + else: + hidden_layers_to_freeze = [] + for layer in hidden_layers_to_freeze: + layer.requires_grad_(False) + + +def freeze_bottom_seq2seq_layers(model: nn.Module, num_layers_unfrozen: int = 0): + """Freezes the bottom transformer block layers of the specified model.""" + if num_layers_unfrozen == -1: + return + shared_embed = model.shared + decoder_embed = model.decoder.embed_tokens + 
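# Collect the embeddings, norms, the full encoder and all but the top
+    # `num_layers_unfrozen` decoder blocks, then disable gradients on them below.
+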
encoder_blocks = model.encoder.block + encoder_norm_layer = model.encoder.final_layer_norm + decoder_norm_layer = model.decoder.final_layer_norm + decoder_blocks = model.decoder.block[:-num_layers_unfrozen] + blocks_to_freeze = ( + list(encoder_blocks) + + list(decoder_blocks) + + [shared_embed] + + [encoder_norm_layer] + + [decoder_norm_layer] + + [decoder_embed] + ) + for block in blocks_to_freeze: + block.requires_grad_(False) + + +def rhasattr(obj, attr): + """A chain-able attribute version of hasattr. For example, to check if + `obj` has the attribute `foo.bar.baz`, you can use: + `rhasattr(obj, "foo.bar.baz")` + Reference: https://stackoverflow.com/a/67303315 + """ + _nested_attrs = attr.split(".") + _curr_obj = obj + for _a in _nested_attrs[:-1]: + if hasattr(_curr_obj, _a): + _curr_obj = getattr(_curr_obj, _a) + else: + return False + return hasattr(_curr_obj, _nested_attrs[-1]) + + +def rgetattr(obj, attr: str, *args) -> object: + """A chain-able attribute version of getattr. For example, to get the + attribute `foo.bar.baz` from `obj`, you can use: + `rgetattr(obj, "foo.bar.baz")` + Reference: https://stackoverflow.com/a/31174427 + """ + + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def findattr(obj, attrs: Tuple[str]) -> Union[object, None]: + for attr in attrs: + if rhasattr(obj, attr): + return rgetattr(obj, attr) + raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`") + + +def hf_get_decoder(model: nn.Module) -> nn.Module: + """Returns the causal decoder backbone of the specified HuggingFace transformers + model. + NOTE: Different model configurations have different causal decoder attribute + names. + - transformer: (GPT2LMHeadModel, GPTJConfig) + - model.decoder: (OPTConfig, BloomConfig) + - gpt_neox: (GPTNeoXConfig) + """ + decoder_attrs = ("transformer", "model.decoder", "gpt_neox", "decoder") + return findattr(model, decoder_attrs) + + +def hf_get_decoder_final_norm(model: nn.Module) -> float: + """Returns the final (layer) norm of the specified decoder. + NOTE: Different model configurations have different final norm attribute names. + - transformer.ln_f: (GPT2LMHeadModel, GPTJForCausalLM) + - model.decoder.final_layer_norm: (OPTForCausalLM) + - gpt_neox.layers.final_layer_norm: (GPTNeoXForCausalLM) + """ + norm_attrs = ( + "transformer.ln_f", + "model.decoder.final_layer_norm", + "model.norm", + "decoder.final_layer_norm", + "gpt_neox.final_layer_norm", + ) + return findattr(model, norm_attrs) + + +def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]: + """Returns the decoder hidden layers of the specified model. + NOTE: Different model configurations have different hidden layer attribute names. + - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM) + - model.decoder.layers: (OPTForCausalLM) + - gpt_neox.layers: (GPTNeoXForCausalLM) + - decoder.block: (T5ForConditionalGeneration) + """ + hidden_layers_attrs = ( + "h", + "layers", + "model.layers", + "decoder.layers", + "transformer.h", + "model.decoder.layers", + "gpt_neox.layers", + "decoder.block", + ) + return findattr(model, hidden_layers_attrs) + + +def hf_get_lm_head(model: nn.Module) -> nn.Module: + """Returns the language modeling (lm) head of the specified HuggingFace + transformers model. + NOTE: Different model configurations have different `lm_head` attribute names. 
+ - lm_head: (GPT2LMHeadModel, BloomForCausalLM) + - embed_out: (GPTNeoXForCausalLM) + """ + return model.get_output_embeddings() + + +def hf_get_hidden_size(config: transformers.PretrainedConfig) -> int: + """Returns the hidden layer dimensionality of the model architecture specified + by the HuggingFace transformers config. + NOTE: Different model configurations have different hidden size attribute names. + - hidden_size: (OPTConfig, BloomConfig) + - n_embd: (GPT2Config, GPTJConfig) + - d_model: (PegasusConfig, XLNetConfig) + """ + hidden_size_attrs = ("hidden_size", "n_embd", "d_model") + return findattr(config, hidden_size_attrs) + + +def hf_get_num_hidden_layers(config: transformers.PretrainedConfig) -> int: + """Returns the number of hidden layers in the model architecture specified + by the HuggingFace transformers config. + NOTE: Different model configurations have different number-of-layers attribute + names. + - num_hidden_layers: (GPTNeoXConfig, OPTConfig) + - n_layer: (GPT2Config, GPTJConfig, BloomConfig) + """ + num_hidden_layers_attrs = ("num_hidden_layers", "n_layer") + return findattr(config, num_hidden_layers_attrs) + + +def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]: + """ + Computes element-wise mean and variance of the tensor across processes + """ + sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device) + dist.all_reduce(sum_and_count, dist.ReduceOp.SUM) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum((xs - global_mean) ** 2) + dist.all_reduce(sum_var, dist.ReduceOp.SUM) + global_var = sum_var / count + return global_mean, global_var, count + + +def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor: + """Whitens values""" + if distributed and dist.is_initialized(): + mean, var, _ = get_global_statistics(xs) + else: + var, mean = torch.var_mean(xs) + + whitened = (xs - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def logprobs_of_labels(logits, labels): + """Log probabilities of the labels + + These are calculated from the logits.""" + logprobs = F.log_softmax(logits, dim=-1) + logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)) + return logprobs_labels.squeeze(-1) + + +def flatten_dict( + d: Union[dict, MutableMapping], + parent_key: str = "", + sep: str = "/", +) -> dict: + # From: https://stackoverflow.com/a/6027615 + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def gather_dict(obj: Dict, grad_state: accelerate.state.GradientState = None): + """ + Gather and concatenates key-values from a dictionary, optionally + trimming them if some of them were out of dataloader's padding + """ + if not torch.distributed.is_initialized(): + return obj + + objs = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(objs, obj) + + acc, *objs = objs + for obj in objs: + for k in obj: + acc[k].extend(obj[k]) + + if grad_state: + if grad_state.end_of_dataloader and grad_state.remainder > 0: + for k in acc: + acc[k] = acc[k][: grad_state.remainder] + + return acc + + +def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int): + if xs.numel() == 0: + return dict(mean=0, min=0, max=0, std=0) + + mean = (xs * mask).sum() / n + return dict( + mean=mean, + 
min=torch.where(mask.bool(), xs, np.inf).min(), + max=torch.where(mask.bool(), xs, -np.inf).max(), + std=torch.sqrt(((xs - mean) * mask).pow(2).sum() / n), + ) + + +class RunningMoments: + def __init__(self): + """ + Calculates the running mean and standard deviation of a data stream. Modified version of + https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """Updates running moments from batch's moments computed across ranks""" + if dist.is_initialized(): + xs_mean, xs_var, xs_count = get_global_statistics(xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += delta * xs_count / tot_count + self.var = tot_sum / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).sqrt() + self.count = tot_count + + return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt() + + +# OpenDelta utilities + + +MODIFIED_MODULES_DICT = { + "gptj": { + "attention": ["attn.q_proj", "attn.k_proj", "attn.v_proj"], + "mlp": ["mlp.fc_in", "mlp.fc_out"], + "all": [ + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "attn.out_proj", + "mlp.fc_in", + "mlp.fc_out", + ], + }, + "gpt_neox": { + "attention": ["attention.query_key_value"], + "mlp": ["mlp.dense_h_to_4h", "mlp.dense_4h_to_h"], + "all": [ + "attention.query_key_value", + "attention.dense", + "mlp.dense_h_to_4h", + "mlp.dense_4h_to_h", + ], + }, + "opt": { + "attention": [ + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.q_proj", + "self_attn.out_proj", + ], + "mlp": ["fc1", "fc2"], + "all": [ + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.q_proj", + "self_attn.out_proj", + "fc1", + "fc2", + ], + }, + "bloom": { + "attention": ["self_attention.query_key_value", "self_attention.dense"], + "mlp": ["mlp.dense_h_to_4h", "mlp.dense_4h_to_h"], + "all": [ + "self_attention.query_key_value", + "self_attention.dense", + "mlp.dense_h_to_4h", + "mlp.dense_4h_to_h", + ], + }, + "t5": { + "attention": [ + "layer.0.SelfAttention.q", + "layer.0.SelfAttention.k", + "layer.0.SelfAttention.v", + "layer.0.SelfAttention.o", + "layer.1.EncDecAttention.q", + "layer.1.EncDecAttention.k", + "layer.1.EncDecAttention.v", + "layer.1.EncDecAttention.o", + ], + "mlp": [ + "layer.2.DenseReluDense.wo", + "layer.2.DenseReluDense.wi_0", + "layer.2.DenseReluDense.wi_1", + ], + "all": [ + "layer.0.SelfAttention.q", + "layer.0.SelfAttention.k", + "layer.0.SelfAttention.v", + "layer.0.SelfAttention.o", + "layer.1.EncDecAttention.q", + "layer.1.EncDecAttention.k", + "layer.1.EncDecAttention.v", + "layer.1.EncDecAttention.o", + "layer.2.DenseReluDense.wo", + "layer.2.DenseReluDense.wi_0", + "layer.2.DenseReluDense.wi_1", + ], + }, +} + + +def generate_layer_regex(config: transformers.PretrainedConfig, num_layers_unfrozen: int = -1) -> str: + """Generates a regex range for the specified number of learnable layers.""" + if num_layers_unfrozen == -1: + return "(\d)+." 
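+    # Otherwise only the top `num_layers_unfrozen` layer indices should match,
+    # e.g. 2 unfrozen layers of a 12-layer model give a pattern matching 10 and 11.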
+ num_hidden_layers = hf_get_num_hidden_layers(config) + start_layer = num_hidden_layers - num_layers_unfrozen + if start_layer < 0: + raise Exception("Number of layers unfrozen cannot be greater than number of layers in the model") + pattern = f"(?:{regex_for_range(start_layer, num_hidden_layers - 1)})." + return f"{pattern}" + + +def get_delta_modified_modules( + config: transformers.PretrainedConfig, + modified_modules: List[str], + num_layers_unfrozen: int = -1, +) -> List[str]: + """Returns a list of module names to be modified for a given delta method with + the specified number of learnable layers.""" + unfrozen_layers_pattern = generate_layer_regex(config, num_layers_unfrozen) + + # [r] for regex as per https://github.com/thunlp/OpenDelta/blob/main/opendelta/utils/name_based_addressing.py#L20 + regex_prefix = "[r]" + # TODO (jon-tow): `decoder.block.` is hardcoded to support T5 layer naming. + decoder_prefix = "decoder.block." if config.is_encoder_decoder else "" + module_list = [regex_prefix + decoder_prefix + unfrozen_layers_pattern + module for module in modified_modules] + return module_list + + +def get_delta_model_class(model_type: str): + if not HAS_OPENDELTA: + raise ValueError("OpenDelta package required to train with delta models. https://github.com/thunlp/OpenDelta.") + delta_models = { + "bitfit": BitFitModel, + "adapter": AdapterModel, + "prefix": PrefixModel, + "lora": LoraModel, + "softprompt": SoftPromptModel, + } + return delta_models[model_type] + + +def parse_delta_kwargs( + config: transformers.PretrainedConfig, + delta_kwargs: Dict[str, Any], + num_layers_unfrozen: int = -1, +) -> Tuple[str, Dict[str, Any]]: + """Parses through delta kwargs to get delta type and proper modified modules.""" + # This function is needed to parse through the `delta_kwargs` in order to: + # 1) Get the `delta_type` method name to access the correct `delta_model_class` + # 2a) Accept user specified `modified_modules` and if not provided use the `trlx` default mapping + # 2b) Convert the list of `modified_modules` to a range of layers that fit within the range + # of learnable layers as specified by `num_layers_unfrozen` + + # Pop `delta_type` to allow passing the kwargs to the model constructor since + # `delta_type` is not a valid argument of the constructor + delta_type = delta_kwargs.pop("delta_type") + assert delta_type in ["lora"], "Only `LoRA` based delta models are supported" + + # Use `trlx` default modified modules if none are specified + modified_modules = delta_kwargs.get("modified_modules", "all") + if modified_modules in ["all", "attention", "mlp"]: + if config.model_type not in MODIFIED_MODULES_DICT: + raise ValueError( + f"Model type `{config.model_type}` is not currently supported for " + "delta training with default modified modules." + ) + modified_modules = MODIFIED_MODULES_DICT[config.model_type][modified_modules] + # Update the `modified_modules` with the correct layer ranges + delta_kwargs["modified_modules"] = get_delta_modified_modules( + config, modified_modules, num_layers_unfrozen=num_layers_unfrozen + ) + + return delta_type, delta_kwargs + + +def regex_for_range(min_: int, max_: int) -> str: # noqa + """Returns a regex that matches all numbers in the given range. + + Example: regex_for_range(12, 34) -> "1[2-9]|2\d|3[0-4]" + + Copyright (c) 2013, Dmitry Voronin. All rights reserved. 
+ Reference: https://github.com/voronind/range-regex + """ + + def split_to_patterns(min_, max_): + subpatterns = [] + start = min_ + for stop in split_to_ranges(min_, max_): + subpatterns.append(range_to_pattern(start, stop)) + start = stop + 1 + return subpatterns + + def split_to_ranges(min_, max_): + stops = {max_} + nines_count = 1 + stop = fill_by_nines(min_, nines_count) + while min_ <= stop < max_: + stops.add(stop) + nines_count += 1 + stop = fill_by_nines(min_, nines_count) + zeros_count = 1 + stop = fill_by_zeros(max_ + 1, zeros_count) - 1 + while min_ < stop <= max_: + stops.add(stop) + zeros_count += 1 + stop = fill_by_zeros(max_ + 1, zeros_count) - 1 + stops = list(stops) + stops.sort() + return stops + + def fill_by_nines(integer, nines_count): + return int(str(integer)[:-nines_count] + "9" * nines_count) + + def fill_by_zeros(integer, zeros_count): + return integer - integer % 10**zeros_count + + def range_to_pattern(start, stop): + pattern = "" + any_digit_count = 0 + for start_digit, stop_digit in zip(str(start), str(stop)): + if start_digit == stop_digit: + pattern += start_digit + elif start_digit != "0" or stop_digit != "9": + pattern += "[{}-{}]".format(start_digit, stop_digit) + else: + any_digit_count += 1 + if any_digit_count: + pattern += r"\d" + if any_digit_count > 1: + pattern += "{{{}}}".format(any_digit_count) + return pattern + + positive_subpatterns = [] + negative_subpatterns = [] + + if min_ < 0: + min__ = 1 + if max_ < 0: + min__ = abs(max_) + max__ = abs(min_) + negative_subpatterns = split_to_patterns(min__, max__) + min_ = 0 + if max_ >= 0: + positive_subpatterns = split_to_patterns(min_, max_) + + negative_only_subpatterns = ["-" + val for val in negative_subpatterns if val not in positive_subpatterns] + positive_only_subpatterns = [val for val in positive_subpatterns if val not in negative_subpatterns] + intersected_subpatterns = ["-?" + val for val in negative_subpatterns if val in positive_subpatterns] + subpatterns = negative_only_subpatterns + intersected_subpatterns + positive_only_subpatterns + return "|".join(subpatterns) diff --git a/examples/X-STA/.gitignore b/examples/X-STA/.gitignore new file mode 100644 index 0000000..c57c55d --- /dev/null +++ b/examples/X-STA/.gitignore @@ -0,0 +1,140 @@ +models/ +figures/ +outputs/ +analysis/ +download/ +/data/ +*.tsv +*.csv + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+outputs
diff --git a/examples/X-STA/README.md b/examples/X-STA/README.md
new file mode 100644
index 0000000..e90d517
--- /dev/null
+++ b/examples/X-STA/README.md
@@ -0,0 +1,23 @@
+# X-STA
+
+This project is implemented for the Findings of EMNLP 2023 paper: "Sharing, Teaching and Aligning: Knowledgeable Transfer Learning for Cross-Lingual Machine Reading Comprehension". Our code is based on pytorch and huggingface transformers.
+
+## Requirements
+```bash
+pip install -r requirement.txt
+```
+
+## Quick Start
+
+**NOTE**: Please make sure you have set up the environment correctly.
+
+1. Download data
+
+To download the three datasets, please run `bash scripts/download_data.sh`, which may take a while.
+
+We use the translated training data from the XTREME team. Please refer to their [repo](https://github.com/google-research/xtreme) or their [translation](https://console.cloud.google.com/storage/browser/xtreme_translations) directly.
+
+2. Model Training and Evaluation:
+```bash
+bash scripts/mbert/mlqa.sh
+```
diff --git a/examples/X-STA/run_xqa.py b/examples/X-STA/run_xqa.py
new file mode 100644
index 0000000..3078b5a
--- /dev/null
+++ b/examples/X-STA/run_xqa.py
@@ -0,0 +1,1091 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
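+#
+# Orientation notes (summary of the code below):
+#   - load_and_cache_examples() builds SQuAD-style features per language and caches them.
+#   - AlignDataset pairs every target-language training feature with an English feature
+#     that shares the same qas_id, so each batch element is an (English, target) pair.
+#   - train() stacks the paired features along an extra dimension and optimizes the model;
+#     evaluate() uses AlignDatasetForEval to build the same pairing for dev/test data.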
+""" Finetuning the library models for question-answering on SQuAD""" + +import argparse +import glob +import logging +import os +import random +import timeit +import itertools +import json +import pickle +import math + +import numpy as np +import torch +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset, ConcatDataset, TensorDataset +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange +from collections import defaultdict +from prettytable import PrettyTable + +from src.data.qa import squad_convert_examples_to_features +from src.bert import BertForQuestionAnswering +from src.xlmr import XLMRobertaForQuestionAnswering + +from transformers import ( + XLMRobertaConfig, + XLMRobertaTokenizer, + BertConfig, + BertTokenizer, + WEIGHTS_NAME, + AdamW, + get_linear_schedule_with_warmup +) + +from src.data.squad_metrics import ( + compute_predictions_log_probs, + compute_predictions_logits, + squad_evaluate, +) +from src.data.qa import ( + SquadResult, + MLQAProcessor, + TydiqaProcessor, + XquadProcessor +) +from src.data.mlqa_evaluation_v1 import evaluate_with_path + +from src.utils import BatchContinuousRandomSampler, BatchContinuousDistributedSampler + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + + +logger = logging.getLogger(__name__) + +MODEL_CLASSES = { + "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer), + "xlmr": (XLMRobertaConfig, XLMRobertaForQuestionAnswering, XLMRobertaTokenizer), +} + +PROCESSOR_MAP = { + "mlqa": MLQAProcessor, + "tydiqa": TydiqaProcessor, + "xquad": XquadProcessor, +} + +LANG_2_IDX = {} + +class AlignDataset(Dataset): + def __init__(self, datasets, examples, feature_qas_ids, en_index): + super(AlignDataset, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + + self.datasets = datasets + self.en_index = en_index + + self.eng_id2feat = {} + + for idx, qas_id in enumerate(feature_qas_ids[en_index]): + if qas_id not in self.eng_id2feat: + self.eng_id2feat[qas_id] = [] + self.eng_id2feat[qas_id].append(idx) + + self.indexes = [] + qas_id_cnt = defaultdict(int) + for lang_idx, qas_ids in enumerate(feature_qas_ids): + for feat_idx, qas_id in enumerate(qas_ids): + + if lang_idx == en_index: + self.indexes.append((lang_idx, feat_idx, qas_id)) + else: + # for training, only add aligned data + if qas_id in self.eng_id2feat: + self.indexes.append((lang_idx, feat_idx, self.eng_id2feat[qas_id][qas_id_cnt[qas_id]])) + qas_id_cnt[qas_id] = (qas_id_cnt[qas_id] + 1 ) % len(self.eng_id2feat[qas_id]) + def __len__(self): + return len(self.indexes) + + def __getitem__(self, idx): + lang_idx, feat_idx, idx3 = self.indexes[idx] + + if lang_idx == self.en_index: + return [self.datasets[self.en_index][feat_idx], self.datasets[lang_idx][feat_idx]] + else: + # other languages + return [self.datasets[self.en_index][idx3], self.datasets[lang_idx][feat_idx]] + + +class AlignDatasetForEval(Dataset): + def __init__(self, datasets, feature_qas_ids): + super(AlignDatasetForEval, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + assert len(datasets) == 2 + + self.datasets = datasets + self.indexes = [] + + self.en_qas_id_2_idx = {} + + for idx, qas_id in enumerate(feature_qas_ids[0]): + self.en_qas_id_2_idx[qas_id] = idx + + for idx, qas_id in enumerate(feature_qas_ids[1]): + # TODO + assert qas_id in self.en_qas_id_2_idx, "{} need English translation for 
inference".format(qas_id) + self.indexes.append([self.en_qas_id_2_idx[qas_id], idx]) + + def __len__(self): + return len(self.datasets[1]) + + def __getitem__(self, idx): + feat_idx_en, feat_idx = self.indexes[idx] + return [self.datasets[0][feat_idx_en], self.datasets[1][feat_idx]] + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_list(input): + if isinstance(input, list): + return [tensor.detach().cpu.tolist() for tensor in input] + else: + # tensor + return input.detach().cpu().tolist() + + +def get_max_steps(output_dir): + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(output_dir + "/**/" + WEIGHTS_NAME, recursive=True))) + max_step = None + for checkpoint in checkpoints: + if len(checkpoint.split("-")) > 1: + if max_step is None: + max_step = int(checkpoint.split("-")[-1]) + else: + max_step = max(max_step, int(checkpoint.split("-")[-1])) + return max_step + + +def save_model(args, model, tokenizer): + # Create output directory if needed + + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + +def train(args, train_dataset, model, tokenizer): + """ Train the model """ + if args.local_rank in [-1, 0]: + if args.log_dir: + tb_writer = SummaryWriter(args.log_dir) + else: + tb_writer = SummaryWriter() + log_writer = open(os.path.join(args.output_dir, "evaluate_logs.txt"), 'a') + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + if args.norm: + train_sampler = BatchContinuousRandomSampler(train_dataset, batch_size=args.per_gpu_train_batch_size) if args.local_rank == -1 else BatchContinuousDistributedSampler(train_dataset, batch_size=args.per_gpu_train_batch_size) + else: + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + + warmup_steps = args.warmup_steps if args.warmup_steps > 0 else math.ceil( + args.warmup_ratio * t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 
num_training_steps=t_total) + + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True + ) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size + * args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), + ) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 1 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + + + tr_loss, logging_loss, best_dev_score = 0.0, 0.0, 0.0 + model.zero_grad() + train_iterator = trange( + epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] + ) + + # Added here for reproductibility + set_seed(args) + + train_lang_ids = [] + train_langs = args.language.split(',') + + for lang in train_langs: + train_lang_ids.append(LANG_2_IDX[lang]) + + for epc in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], mininterval=10) + for step, batch in enumerate(epoch_iterator): + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + continue + + model.train() + + input_ids = torch.stack([d[0] for d in batch], dim=1).to(args.device) + attention_mask = torch.stack([d[1] for d in batch], dim=1).to(args.device) + token_type_ids = torch.stack([d[2] for d in batch], dim=1).to(args.device) + start_positions = torch.stack([d[3] for d in batch], dim=1).to(args.device) + end_positions = torch.stack([d[4] for d in batch], dim=1).to(args.device) + query_len = torch.stack([d[-2] for d in batch], dim=1).to(args.device) + lang_ids = torch.stack([d[-1] for d in batch], dim=1).to(args.device) + + inputs = {"input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": attention_mask, + "start_positions":start_positions, + "end_positions":end_positions, + "query_len": query_len, + "lang_ids": lang_ids} + + if args.model_type in ["xlmr"]: + del inputs["token_type_ids"] + + outputs = model(**inputs) + loss = outputs[0] + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + 
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) + tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) + + logging_loss = tr_loss + + # eval and save model checkpoint + if args.local_rank in [-1, 0] and args.eval_steps > 0 and global_step % args.eval_steps == 0 and global_step >= t_total * 0.3: + result = evaluate(args, model, tokenizer, prefix="", split='dev') + tb_writer.add_scalar( + "eval_exact", result['dev_avg']['exact_match'], global_step) + tb_writer.add_scalar( + "eval_f1", result['dev_avg']['f1'], global_step) + dev_score = (result['dev_avg']['exact_match'] + 2 * result['dev_avg']['f1']) / 3 + + log_writer.write("{0}\t{1}\n".format(global_step, json.dumps(result))) + log_writer.flush() + logger.info(result) + + if dev_score >= best_dev_score: + best_dev_score = dev_score + save_model(args, model, tokenizer) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + # eval + if args.local_rank in [-1, 0]: + result = evaluate(args, model, tokenizer, prefix="", split='dev') + tb_writer.add_scalar( + "eval_exact", result['dev_avg']['exact_match'], global_step) + tb_writer.add_scalar( + "eval_f1", result['dev_avg']['f1'], global_step) + dev_score = (result['dev_avg']['exact_match'] + result['dev_avg']['f1']) / 2 + + log_writer.write("{0}\t{1}\n".format(global_step, json.dumps(result))) + log_writer.flush() + logger.info(result) + + if dev_score >= best_dev_score: + best_dev_score = dev_score + save_model(args, model, tokenizer) + + if args.local_rank in [-1, 0]: + tb_writer.close() + log_writer.close() + + return global_step, tr_loss / global_step + +def evaluate(args, model, tokenizer, prefix="", split='dev'): + languages = list(args.language.split(',')) + + if args.task_name == 'xquad' and split == 'dev': + # languages in mlqa + languages = ['en', 'ar', 'de', 'es', 'hi', 'vi', 'zh'] + elif args.task_name == 'tydiqa' and split == 'dev': + # only use en for dev set + languages = ['en'] + + all_languages_results = {} + processor = PROCESSOR_MAP[args.task_name]() + + args.eval_batch_size = args.per_gpu_eval_batch_size + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Batch size = %d", args.eval_batch_size) + + for lang in tqdm(languages, desc="Evaluating"): + + logger.info("evaluating on {0} {1}".format(split, lang)) + + + dataset, examples, features = load_and_cache_examples(args, tokenizer, language=lang, split=split, output_examples=True) + + if args.task_name == 'tydiqa': + dataset_en, examples_en, features_en = load_and_cache_examples(args, tokenizer, language=lang, split='translate-test', output_examples=True) + else: + dataset_en, examples_en, features_en = load_and_cache_examples(args, tokenizer, language='en', split=split, output_examples=True) + + feature_qas_ids = [] + feature_qas_ids.append([feature.qas_id for feature in features_en]) + feature_qas_ids.append([feature.qas_id for feature in features]) + + dataset = AlignDatasetForEval([dataset_en, dataset], feature_qas_ids) + + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + 
# Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(dataset) + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # multi-gpu evaluate + if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): + model = torch.nn.DataParallel(model) + + all_results = [] + + for batch in eval_dataloader: + model.eval() + + input_ids = torch.stack([d[0] for d in batch], dim=1).to(args.device) + attention_mask = torch.stack([d[1] for d in batch], dim=1).to(args.device) + token_type_ids = torch.stack([d[2] for d in batch], dim=1).to(args.device) + lang_ids = torch.stack([d[-1] for d in batch], dim=1).to(args.device) + + + inputs = {"input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": attention_mask, + "lang_ids": lang_ids} + + example_indices = batch[1][3] + + with torch.no_grad(): + if args.model_type in ["xlmr"]: + del inputs["token_type_ids"] + + outputs = model(**inputs) + outputs = [outputs[0], outputs[1]] # [start_logits, end_logits] + + for i, example_index in enumerate(example_indices): + eval_feature = features[example_index.item()] + unique_id = int(eval_feature.unique_id) + + output = [to_list(output[i]) for output in outputs] + + start_logits, end_logits = output + result = SquadResult(unique_id, start_logits, end_logits) + + all_results.append(result) + + # Compute predictions + output_prediction_file = os.path.join(args.output_dir, "pred_{}_{}_{}.json".format(prefix, split, lang)) + output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}_{}_{}.json".format(prefix, split, lang)) + + if args.version_2_with_negative: + output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}_{}_{}.json".format(prefix, split, lang)) + else: + output_null_log_odds_file = None + + # XLNet and XLM use a more complex post-processing procedure + if args.model_type in ["xlnet", "xlm"]: + start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top + end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top + + predictions = compute_predictions_log_probs( + examples, + features, + all_results, + args.n_best_size, + args.max_answer_length, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + start_n_top, + end_n_top, + args.version_2_with_negative, + tokenizer, + args.verbose_logging, + ) + else: + predictions = compute_predictions_logits( + examples, + features, + all_results, + args.n_best_size, + args.max_answer_length, + args.do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + args.verbose_logging, + args.version_2_with_negative, + args.null_score_diff_threshold, + tokenizer, + map_to_origin=not (args.model_type == "xlmr" and (lang == 'zh' or lang == 'ko')) + ) + + # Compute the F1 and exact scores. 
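+            # XQuAD and TyDiQA predictions are scored with the standard SQuAD metric
+            # (squad_evaluate); MLQA uses its official evaluation script via
+            # evaluate_with_path, which applies language-specific answer normalization.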
+ if args.task_name in ['xquad', 'tydiqa']: + results = squad_evaluate(examples, predictions) + elif args.task_name == 'mlqa': + results = evaluate_with_path(processor.get_dataset_path(args.data_dir, split, lang), output_prediction_file, lang) + else: + raise ValueError("not support yet") + all_languages_results["{0}_{1}".format(split, lang)] = {'exact_match': results['exact_match'], 'f1': results['f1']} + + table = PrettyTable() + table.title = f"{args.task_name}-{split}" + table.add_column('lang', ['EM', 'F1']) + for lang in languages: + table.add_column(lang, [ + '%.2f' % (all_languages_results[f"{split}_{lang}"]['exact_match']), + '%.2f' % (all_languages_results[f"{split}_{lang}"]['f1']), + ]) + + table.add_column('Avg.', [ + '%.2f' % np.mean([all_languages_results[f"{split}_{lang}"]['exact_match'] for lang in languages]), + '%.2f' % np.mean([all_languages_results[f"{split}_{lang}"]['f1'] for lang in languages]) + ]) + + logger.info(table) + + all_languages_results["{0}_avg".format(split)] = average_dic([value for key, value in all_languages_results.items() if split in key]) + + return all_languages_results + +def average_dic(dic_list): + if len(dic_list) == 0: + return {} + dic_sum = {} + for dic in dic_list: + if len(dic_sum) == 0: + for key, value in dic.items(): + dic_sum[key] = value + else: + assert set(dic_sum.keys()) == set(dic.keys()), "sum_keys:{0}, dic_keys:{1}".format(set(dic_sum.keys()), set(dic.keys())) + for key, value in dic.items(): + dic_sum[key] += value + for key in dic_sum: + dic_sum[key] /= len(dic_list) + return dic_sum + +def filter_examples(examples, max_num, examples_dev=None): + exist_ids = {} + new_examples = [] + if examples_dev is not None: + for example in examples_dev: + if example.qas_id not in exist_ids: + exist_ids[example.qas_id] = 1 + + for example in examples: + if example.qas_id in exist_ids: + new_examples.append(example) + + else: + for example in examples: + if example.qas_id not in exist_ids: + exist_ids[example.qas_id] = 1 + new_examples.append(example) + + return new_examples[:max_num] + +def load_and_cache_examples(args, tokenizer, language, split="train", output_examples=False, use_barrier=True): + + if use_barrier and args.local_rank not in [-1, 0] and split == "train": + # Make sure only the first process in distributed training process the dataset, and the others will use the cache + torch.distributed.barrier() + + # Load data features from cache or dataset file + input_dir = args.cache_dir if args.cache_dir else "." 
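+    # Features are cached per (task, split, language, cache name, max_seq_length) key,
+    # so runs with different settings do not collide. Note that MLQA and XQuAD training
+    # features are cached under the 'squad' task name, and XQuAD dev features under
+    # 'mlqa' (see the task_name remapping below).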
+ data_cache_name = list(filter(None, args.model_name_or_path.split("/"))).pop() + if args.data_cache_name is not None: + data_cache_name = args.data_cache_name + + if args.task_name in ['mlqa', 'xquad'] and split == 'train': + task_name = 'squad' + elif args.task_name == 'xquad' and split == 'dev': + task_name = 'mlqa' + else: + task_name = args.task_name + cached_features_file = os.path.join( + input_dir, + "cached_{}_{}_{}_{}_{}".format( + task_name, + split, + language, + data_cache_name, + str(args.max_seq_length), + ), + ) + + # Init features and dataset from cache if it exists + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", cached_features_file) + features_and_dataset = torch.load(cached_features_file) + features, dataset, examples = ( + features_and_dataset["features"], + features_and_dataset["dataset"], + features_and_dataset["examples"], + ) + else: + logger.info("Creating features from dataset file at %s, language %s", input_dir, language) + + if not args.data_dir: + raise ValueError("data dir can't be empty") + processor = PROCESSOR_MAP[args.task_name]() + if split == "dev": + examples = processor.get_dev_examples_by_language(args.data_dir, language=language) + elif split == "test": + examples = processor.get_test_examples_by_language(args.data_dir, language=language) + elif split == "translate-test": + examples = processor.get_translate_test_examples_by_language(args.data_dir, language=language) + else: + examples = processor.get_train_examples_by_language(args.data_dir, language=language) + + features, dataset = squad_convert_examples_to_features( + examples=examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=split=="train", + return_dataset="pt", + threads=args.threads, + lang_id=LANG_2_IDX['en'] if split in ['translate-test', 'translate-dev'] else LANG_2_IDX[language], + ) + + if args.local_rank in [-1, 0]: + os.makedirs(input_dir, exist_ok=True) + logger.info("Saving features into cached file %s", cached_features_file) + torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file) + + if use_barrier and args.local_rank == 0 and split=="train": + # Make sure only the first process in distributed training process the dataset, and the others will use the cache + torch.distributed.barrier() + + if output_examples: + return dataset, examples, features + return dataset + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + # help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), + ) + parser.add_argument( + "--data_cache_name", + default=None, + type=str, + help="The name of cached data", + ) + + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model checkpoints and predictions will be written.", + ) + + # Other parameters + parser.add_argument( + "--log_dir", + default=None, + type=str, + help="The output log dir." 
+ ) + + parser.add_argument( + "--benchmark", default='xtreme', type=str, choices=['xglue', 'xtreme'], help="xglue/xtreme" + ) + parser.add_argument( + "--task_name", default='mlqa', type=str, help="task" + ) + parser.add_argument( + "--pkl_index", default="0", type=str, help="pickle index for teach student training" + ) + parser.add_argument( + "--use_squad_for_tydiqa", action='store_true', help="include squad english data for tydiqa training" + ) + parser.add_argument( + "--gpu_id", default=None, type=str, help="GPU id" + ) + + parser.add_argument( + "--data_dir", + default=None, + type=str, + help="The input data dir. Should contain the .json files for the task." + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", + ) + + parser.add_argument( + "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="", + ) + + parser.add_argument( + "--version_2_with_negative", + action="store_true", + help="If true, the SQuAD examples contain some that do not have an answer.", + ) + parser.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help="If null_score - best_non_null is greater than the threshold predict null.", + ) + + parser.add_argument( + "--max_seq_length", + default=384, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded.", + ) + parser.add_argument( + "--doc_stride", + default=128, + type=int, + help="When splitting up a long document into chunks, how much stride to take between chunks.", + ) + parser.add_argument( + "--hidden_dropout_prob", + default=0.1, + type=float, + help="When splitting up a long document into chunks, how much stride to take between chunks.", + ) + + parser.add_argument( + "--max_query_length", + default=64, + type=int, + help="The maximum number of tokens for the question. Questions longer than this will " + "be truncated to this length.", + ) + parser.add_argument("--do_train", action='store_true', help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument( + "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." + ) + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 
+ ) + + parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.") + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." + ) + parser.add_argument( + "--max_train_samples_per_epoch", default=None, type=int, help="Not use, for consistent usage with classification and tagging tasks" + ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument("--warmup_ratio", default=0, type=float, help="Linear warmup over warmup_ratio.") + parser.add_argument( + "--n_best_size", + default=20, + type=int, + help="The total number of n-best predictions to generate in the nbest_predictions.json output file.", + ) + parser.add_argument( + "--max_answer_length", + default=30, + type=int, + help="The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another.", + ) + parser.add_argument( + "--verbose_logging", + action="store_true", + help="If true, all of the warnings related to data processing will be printed. " + "A number of warnings are expected for a normal SQuAD evaluation.", + ) + parser.add_argument( + "--lang_id", + default=0, + type=int, + help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)", + ) + parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") + parser.add_argument("--eval_steps", type=int, default=200, help="Eval every X updates steps.") + parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") + parser.add_argument( + "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" + ) + parser.add_argument( + "--overwrite_training", action="store_true", help="Overwrite the cached training model" + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
+ "See details at https://nvidia.github.io/apex/amp.html", + ) + parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") + parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") + parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") + # cross-lingual part + parser.add_argument( + "--language", + default=None, + type=str, + required=True, + help="Train and evaluation language.", + ) + + parser.add_argument("--eval_split", default='test', type=str, help="") + parser.add_argument("--norm", action="store_true") + parser.add_argument("--mix_layer", default=7, type=int) + parser.add_argument("--mix_layers", default=None, type=str) + parser.add_argument("--alpha", default=0.2, type=float) + parser.add_argument("--teaching_weight", default=0.1, type=float) + parser.add_argument("--consist_weight", default=0.1, type=float) + parser.add_argument("--align_weight", default=0.01, type=float) + parser.add_argument("--temp", default=0.05, type=float) + parser.add_argument("--cl", action="store_true") + + + args = parser.parse_args() + + train_langs_t = args.language.split(',') + train_langs = [] + + for l in train_langs_t: + if l not in train_langs: + train_langs.append(l) + + for i, lang in enumerate(train_langs): + LANG_2_IDX[lang] = i + + if args.doc_stride >= args.max_seq_length - args.max_query_length: + logger.warning( + "WARNING - You've set a doc stride which may be superior to the document length in some " + "examples. This could result in errors when building features from the examples. Please reduce the doc " + "stride or increase the maximum length to ensure the features are correctly built." + ) + + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( + args.output_dir + ) + ) + + # Setup distant debugging if needed + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + + # Setup CUDA, GPU & distributed training + if args.gpu_id: + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, + ) + if args.local_rank in [-1, 0]: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARN) + + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, + device, + args.n_gpu, + bool(args.local_rank != -1), + args.fp16, + ) + logger.info("Training/evaluation parameters %s", args) + + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + args.model_type = args.model_type.lower() + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config = config_class.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path + ) + + + tokenizer = tokenizer_class.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case + ) + model = model_class.from_pretrained( + args.model_name_or_path, + config=config, + num_lang=len(args.language.split(',')), + args=args + ) + + + if args.local_rank == 0: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + model.to(args.device) + # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. + # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will + # remove the need for this code, but it is still valid. 
+ if args.fp16: + try: + import apex + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + # Training + if args.do_train: + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + train_langs = args.language.split(',') + + if args.local_rank not in [-1, 0]: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + logger.info(train_langs) + + datasets, feature_qas_ids = [], [] + examples, features = [], [] + for lang in train_langs: + + lg_train_dataset, lg_train_examples, lg_train_features = load_and_cache_examples(args, + tokenizer, + language=lang, + split="train", + output_examples=True, + use_barrier=True) + datasets.append(lg_train_dataset) + examples.append(lg_train_examples) + features.append(lg_train_features) + feature_qas_ids.append([feature.qas_id for feature in lg_train_features]) + + train_dataset = AlignDataset(datasets, examples, feature_qas_ids, en_index=train_langs.index('en')) + + if args.local_rank == 0: + # Make sure only the first process in distributed training will download model & vocab + torch.distributed.barrier() + + global_step, tr_loss = train(args, train_dataset, model, tokenizer) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + if args.do_eval and args.local_rank in [-1, 0]: + log_writer = open(os.path.join(args.output_dir, "evaluate_logs.txt"), 'a') + + # Load model from output_dir + checkpoints = [args.output_dir] + + results = {} + for checkpoint in checkpoints: + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + + model = model_class.from_pretrained( + checkpoint, + config=config, + num_lang=len(args.language.split(',')), + args=args + ) + model.to(args.device) + + result = evaluate(args, model, tokenizer, prefix=global_step, split=args.eval_split) + + filtered_result = {} + for k, v in result.items(): + filtered_result[k] = { key:val for key, val in v.items() if key in ['exact', 'exact_match', 'f1']} + log_writer.write("{}\t{}\n".format(checkpoint, json.dumps(filtered_result))) + + results[checkpoint] = filtered_result + + log_writer.close() + logger.info("Results: {}".format(results)) + + + logger.info("Task Finished!") + +if __name__ == "__main__": + main() diff --git a/examples/X-STA/scripts/download_data.sh b/examples/X-STA/scripts/download_data.sh new file mode 100644 index 0000000..003979e --- /dev/null +++ b/examples/X-STA/scripts/download_data.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Copyright 2020 Google and DeepMind. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
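+# This script fetches the QA datasets used by run_xqa.py into $REPO/data/:
+# SQuAD v1.1, XQuAD, MLQA and TyDiQA-GoldP. The XQuAD/MLQA/TyDiQA downloads are
+# additionally converted with $REPO/third_party/utils_preprocess.py (referenced
+# below) into the per-language SQuAD-style layout read by the data processors.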
+ +REPO=$PWD +DIR=$REPO/data/ +mkdir -p $DIR + +function download_squad { + echo "download squad" + base_dir=$DIR/squad/ + mkdir -p $base_dir && cd $base_dir + wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json -q --show-progress + wget https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json -q --show-progress + echo "Successfully downloaded data at $DIR/squad" >> $DIR/download.log +} + +function download_xquad { + echo "download xquad" + base_dir=$DIR/xquad/ + mkdir -p $base_dir && cd $base_dir + for lang in ar de el en es hi ru th tr vi zh; do + wget https://raw.githubusercontent.com/deepmind/xquad/master/xquad.${lang}.json -q --show-progress + done + python $REPO/third_party/utils_preprocess.py --data_dir $base_dir --output_dir $base_dir --task xquad + echo "Successfully downloaded data at $DIR/xquad" >> $DIR/download.log +} + +function download_mlqa { + echo "download mlqa" + base_dir=$DIR/mlqa/ + mkdir -p $base_dir && cd $base_dir + zip_file=MLQA_V1.zip + wget https://dl.fbaipublicfiles.com/MLQA/${zip_file} -q --show-progress + unzip -qq ${zip_file} + rm ${zip_file} + python $REPO/third_party/utils_preprocess.py --data_dir $base_dir/MLQA_V1/test --output_dir $base_dir --task mlqa + echo "Successfully downloaded data at $DIR/mlqa" >> $DIR/download.log +} + +function download_tydiqa { + echo "download tydiqa-goldp" + base_dir=$DIR/tydiqa/ + mkdir -p $base_dir && cd $base_dir + tydiqa_train_file=tydiqa-goldp-v1.1-train.json + tydiqa_dev_file=tydiqa-goldp-v1.1-dev.tgz + wget https://storage.googleapis.com/tydiqa/v1.1/${tydiqa_train_file} -q --show-progress + wget https://storage.googleapis.com/tydiqa/v1.1/${tydiqa_dev_file} -q --show-progress + tar -xf ${tydiqa_dev_file} + rm ${tydiqa_dev_file} + out_dir=$base_dir/tydiqa-goldp-v1.1-train + python $REPO/third_party/utils_preprocess.py --data_dir $base_dir --output_dir $out_dir --task tydiqa + mv $base_dir/$tydiqa_train_file $out_dir/ + echo "Successfully downloaded data at $DIR/tydiqa" >> $DIR/download.log +} + + +download_squad +download_xquad +download_mlqa +download_tydiqa diff --git a/examples/X-STA/scripts/mbert/mlqa.sh b/examples/X-STA/scripts/mbert/mlqa.sh new file mode 100644 index 0000000..f83fb9b --- /dev/null +++ b/examples/X-STA/scripts/mbert/mlqa.sh @@ -0,0 +1,51 @@ +GPU=0,1,2,3 + +REPO=$PWD + +DATA_DIR=$REPO/data +MODEL_NAME_OR_PATH='bert-base-multilingual-cased' + +n_gpu=4 +epoch=1 +bsz=8 +grad_acc=1 +wd=0.0001 + +lr=3e-5 + +alpha=0.2 +mix_layer=7 + +OUTPUT_DIR=$REPO/outputs/mbert_mlqa +mkdir -p $OUTPUT_DIR +CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \ + --task_name mlqa \ + --data_dir $DATA_DIR \ + --model_type bert \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --language en,ar,de,es,hi,vi,zh \ + --do_train \ + --do_eval \ + --per_gpu_train_batch_size $bsz \ + --gradient_accumulation_steps $grad_acc \ + --learning_rate ${lr} \ + --per_gpu_eval_batch_size 32 \ + --num_train_epochs $epoch \ + --eval_steps 500 \ + --max_steps 24000 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir $OUTPUT_DIR \ + --log_dir $OUTPUT_DIR \ + --threads 8 \ + --cache_dir $DATA_DIR/caches_mbert_mlqa \ + --overwrite_output_dir \ + --warmup_ratio 0.1 \ + --weight_decay $wd \ + --consist_weight 0.05 \ + --teaching_weight 0.1 \ + --align_weight 0.05 \ + --alpha $alpha \ + --mix_layer $mix_layer \ + --norm \ + --cl diff --git a/examples/X-STA/scripts/mbert/tydiqa.sh 
b/examples/X-STA/scripts/mbert/tydiqa.sh new file mode 100644 index 0000000..c639d92 --- /dev/null +++ b/examples/X-STA/scripts/mbert/tydiqa.sh @@ -0,0 +1,51 @@ +GPU=0,1,2,3 + +REPO=$PWD + +DATA_DIR=$REPO/data +MODEL_NAME_OR_PATH='bert-base-multilingual-cased' + +n_gpu=4 +epoch=1 +bsz=8 +grad_acc=1 +wd=0.0001 + +lr=3e-5 + +alpha=0.2 +mix_layer=7 + +OUTPUT_DIR=$REPO/outputs/mbert_tydiqa +mkdir -p $OUTPUT_DIR +CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \ + --task_name tydiqa \ + --data_dir $DATA_DIR \ + --model_type bert \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --language en,ar,bn,fi,id,ko,ru,sw,te \ + --do_train \ + --do_eval \ + --per_gpu_train_batch_size $bsz \ + --gradient_accumulation_steps $grad_acc \ + --learning_rate ${lr} \ + --per_gpu_eval_batch_size 32 \ + --num_train_epochs $epoch \ + --eval_steps 100 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir $OUTPUT_DIR \ + --log_dir $OUTPUT_DIR \ + --threads 16 \ + --cache_dir $DATA_DIR/caches_mbert_tydiqa \ + --overwrite_output_dir \ + --warmup_steps 200 \ + --weight_decay $wd \ + --consist_weight 0.05 \ + --teaching_weight 0.1 \ + --align_weight 0.05 \ + --alpha $alpha \ + --mix_layer $mix_layer \ + --norm \ + --cl + diff --git a/examples/X-STA/scripts/mbert/xquad.sh b/examples/X-STA/scripts/mbert/xquad.sh new file mode 100644 index 0000000..e487e03 --- /dev/null +++ b/examples/X-STA/scripts/mbert/xquad.sh @@ -0,0 +1,51 @@ +GPU=4,5,6,7 + +REPO=$PWD + +DATA_DIR=$REPO/data +MODEL_NAME_OR_PATH='bert-base-multilingual-cased' + +n_gpu=4 +epoch=1 +bsz=8 +grad_acc=1 +wd=0.0001 + +lr=3e-5 + +alpha=0.2 +mix_layer=7 + +OUTPUT_DIR=$REPO/outputs/mbert_xquad +mkdir -p $OUTPUT_DIR +CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \ + --task_name xquad \ + --data_dir $DATA_DIR \ + --model_type bert \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --language en,ar,de,el,es,hi,ru,th,tr,vi,zh \ + --do_train \ + --do_eval \ + --per_gpu_train_batch_size $bsz \ + --gradient_accumulation_steps $grad_acc \ + --learning_rate ${lr} \ + --per_gpu_eval_batch_size 32 \ + --num_train_epochs $epoch \ + --eval_steps 500 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --output_dir $OUTPUT_DIR \ + --log_dir $OUTPUT_DIR \ + --threads 8 \ + --cache_dir $DATA_DIR/caches_mbert_mlqa \ + --overwrite_output_dir \ + --warmup_ratio 0.1 \ + --weight_decay $wd \ + --consist_weight 0.05 \ + --teaching_weight 0.1 \ + --align_weight 0.05 \ + --alpha $alpha \ + --mix_layer $mix_layer \ + --norm \ + --cl + diff --git a/examples/X-STA/scripts/xlmr/mlqa.sh b/examples/X-STA/scripts/xlmr/mlqa.sh new file mode 100644 index 0000000..b1c3840 --- /dev/null +++ b/examples/X-STA/scripts/xlmr/mlqa.sh @@ -0,0 +1,51 @@ +GPU=0,1,2,3 + +REPO=$PWD + +DATA_DIR=$REPO/data +MODEL_NAME_OR_PATH='xlm-roberta-base' + +n_gpu=4 +epoch=1 +bsz=8 +grad_acc=1 +wd=0.0001 + +lr=3e-5 + +alpha=0.2 +mix_layer=7 + +OUTPUT_DIR=$REPO/outputs/xlmr_mlqa +mkdir -p $OUTPUT_DIR +CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \ + --task_name mlqa \ + --data_dir $DATA_DIR \ + --model_type xlmr \ + --model_name_or_path $MODEL_NAME_OR_PATH \ + --language en,ar,de,es,hi,vi,zh \ + --do_train \ + --do_eval \ + --per_gpu_train_batch_size $bsz \ + --gradient_accumulation_steps $grad_acc \ + --learning_rate ${lr} \ + --per_gpu_eval_batch_size 32 \ + --num_train_epochs $epoch \ + 
--eval_steps 500 \
+    --max_steps 24000 \
+    --max_seq_length 384 \
+    --doc_stride 128 \
+    --output_dir $OUTPUT_DIR \
+    --log_dir $OUTPUT_DIR \
+    --threads 8 \
+    --cache_dir $DATA_DIR/caches_xlmr_mlqa \
+    --overwrite_output_dir \
+    --warmup_ratio 0.1 \
+    --weight_decay $wd \
+    --consist_weight 0.05 \
+    --teaching_weight 0.1 \
+    --align_weight 0.05 \
+    --alpha $alpha \
+    --mix_layer $mix_layer \
+    --norm \
+    --cl
diff --git a/examples/X-STA/scripts/xlmr/tydiqa.sh b/examples/X-STA/scripts/xlmr/tydiqa.sh
new file mode 100644
index 0000000..9c6b6d4
--- /dev/null
+++ b/examples/X-STA/scripts/xlmr/tydiqa.sh
@@ -0,0 +1,47 @@
+GPU=0,1,2,3
+
+REPO=$PWD
+
+DATA_DIR=$REPO/data
+MODEL_NAME_OR_PATH='xlm-roberta-base'
+
+n_gpu=4
+epoch=1
+bsz=8
+grad_acc=1
+wd=0.0001
+
+lr=3e-5
+
+alpha=0.2
+mix_layer=7
+
+OUTPUT_DIR=$REPO/outputs/xlmr_tydiqa
+mkdir -p $OUTPUT_DIR
+CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \
+    --task_name tydiqa \
+    --data_dir $DATA_DIR \
+    --model_type xlmr \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --language en,ar,bn,fi,id,ko,ru,sw,te \
+    --do_train \
+    --do_eval \
+    --per_gpu_train_batch_size $bsz \
+    --gradient_accumulation_steps $grad_acc \
+    --learning_rate ${lr} \
+    --per_gpu_eval_batch_size 32 \
+    --num_train_epochs $epoch \
+    --eval_steps 100 \
+    --max_seq_length 384 \
+    --doc_stride 128 \
+    --output_dir $OUTPUT_DIR \
+    --log_dir $OUTPUT_DIR \
+    --threads 16 \
+    --cache_dir $DATA_DIR/caches_xlmr_tydiqa \
+    --overwrite_output_dir \
+    --warmup_ratio 0.1 \
+    --weight_decay $wd \
+    --consist_weight 0.05 \
+    --teaching_weight 0.1 \
+    --align_weight 0.05 \
+    --alpha $alpha \
+    --mix_layer $mix_layer \
+    --norm \
+    --cl
diff --git a/examples/X-STA/scripts/xlmr/xquad.sh b/examples/X-STA/scripts/xlmr/xquad.sh
new file mode 100644
index 0000000..4a081d0
--- /dev/null
+++ b/examples/X-STA/scripts/xlmr/xquad.sh
@@ -0,0 +1,52 @@
+GPU=4,5,6,7
+
+REPO=$PWD
+
+DATA_DIR=$REPO/data
+MODEL_NAME_OR_PATH='xlm-roberta-base'
+
+
+n_gpu=4
+epoch=1
+bsz=8
+grad_acc=1
+wd=0.0001
+
+lr=3e-5
+
+alpha=0.2
+mix_layer=7
+
+lr=3e-5
+
+OUTPUT_DIR=outputs/xlmr_xquad
+mkdir -p $OUTPUT_DIR
+CUDA_VISIBLE_DEVICES=$GPU python -m torch.distributed.launch --nproc_per_node=${n_gpu} --master_port=$RANDOM ./run_xqa.py \
+    --task_name xquad \
+    --data_dir $DATA_DIR \
+    --model_type xlmr \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --language en,ar,de,el,es,hi,ru,th,tr,vi,zh \
+    --do_eval \
+    --per_gpu_train_batch_size $bsz \
+    --gradient_accumulation_steps $grad_acc \
+    --learning_rate ${lr} \
+    --per_gpu_eval_batch_size 8 \
+    --num_train_epochs $epoch \
+    --eval_steps 500 \
+    --max_seq_length 384 \
+    --doc_stride 128 \
+    --output_dir $OUTPUT_DIR \
+    --log_dir $OUTPUT_DIR/wo_t \
+    --threads 8 \
+    --cache_dir $DATA_DIR/caches_xlmr_mlqa \
+    --overwrite_output_dir \
+    --warmup_ratio 0.1 \
+    --weight_decay $wd \
+    --consist_weight 0.05 \
+    --teaching_weight 0.1 \
+    --align_weight 0.05 \
+    --alpha $alpha \
+    --mix_layer $mix_layer \
+    --norm \
+    --cl
diff --git a/examples/X-STA/src/bert.py b/examples/X-STA/src/bert.py
new file mode 100644
index 0000000..27c3814
--- /dev/null
+++ b/examples/X-STA/src/bert.py
@@ -0,0 +1,629 @@
+from typing import *
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from transformers import BertPreTrainedModel
+from transformers.models.bert.modeling_bert import (
+    BertSelfAttention, BertIntermediate, BertOutput, BertEmbeddings, BertPooler, BertAttention
+)
+from transformers.modeling_outputs import (
+    
ModelOutput, + QuestionAnsweringModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions +) + +from src.utils import BatchNorm, AttentionTeacher, get_pair_entropy, ContrastiveLoss + +class BertMixSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor, ratio=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if ratio is not None: + hidden_states = self.LayerNorm(ratio * hidden_states + (1-ratio) * input_tensor) + else: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = BertSelfAttention(config) + self.output = BertMixSelfOutput(config) + self.pruned_heads = set() + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ratio=None + ): + + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states, ratio) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + +class BertMixLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = BertAttention(config) + self.add_cross_attention = config.add_cross_attention + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + add_attention=True, + add_ffn=True, + ratio=None + ): + + if add_attention: + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ratio=ratio + ) + if not add_ffn: + return attention_outputs + attention_output = attention_outputs[0] + else: + attention_output = hidden_states + attention_outputs = (attention_output, None) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + + + +class BertMixEncoder(nn.Module): + def __init__(self, config, mix_layer): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertMixLayer(config) for _ in range(config.num_hidden_layers)]) + self.mix_layer = mix_layer + + self.w = nn.Parameter(torch.tensor(1.0)) + self.b = nn.Parameter(torch.tensor(0.0)) + + self.f = nn.Sequential( + nn.Dropout(config.hidden_dropout_prob), + nn.Linear(config.hidden_size * 2, config.hidden_size) + ) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + + def forward( + self, + hidden_states, + attention_mask=None, + raw_attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + 
output_hidden_states=False, + return_dict=True, + lang_ids=None + ): + if isinstance(self.mix_layer, str): + mix_layers = [int(x) for x in self.mix_layer.split(',')] + else: + mix_layers = [self.mix_layer] + attention_entropy = None + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + all_attention_entropy = () + + next_decoder_cache = () if use_cache else None + + + attention_mask_en = attention_mask.view(-1, 2, *attention_mask.size()[1:])[:, 0] + attention_mask_trg = attention_mask.view(-1, 2, *attention_mask.size()[1:])[:, 1] + + raw_attention_mask = raw_attention_mask[:, None, None, :] + + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if i in mix_layers: + hidden_states_en = hidden_states.view(-1, 2, hidden_states.size(-2), hidden_states.size(-1))[:, 0] + hidden_states_trg = hidden_states.view(-1, 2, hidden_states.size(-2), hidden_states.size(-1))[:, 1] + + # trg self attention + self_attention_output = layer_module.attention.self( + hidden_states_trg, + attention_mask_trg, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions + )[0] + + src_lang_id = int(lang_ids[0]) + trg_lang_id = int(lang_ids[1]) + if src_lang_id == trg_lang_id: + hidden_states_en_convert = hidden_states_en + else: + hidden_states_en_convert = hidden_states_en.detach() + self.f(torch.cat([hidden_states_en.detach(), hidden_states_trg.detach()], dim=-1)) + + cross_attention_outputs = layer_module.attention.self( + hidden_states_trg, + attention_mask_trg, + layer_head_mask, + encoder_hidden_states=hidden_states_en_convert, + encoder_attention_mask=attention_mask_en, + past_key_value=past_key_value, + output_attentions=True + ) + + cross_attention_output = cross_attention_outputs[0] + cross_attention_score = cross_attention_outputs[1] + + attention_entropy = get_pair_entropy(cross_attention_score) + + ratio = self.w * 0.3 + self.b + + attention_output = layer_module.attention.output( + ratio * cross_attention_output + (1 - ratio) * self_attention_output, + hidden_states_trg + ) + + # trg ffn + ffn_layer_outputs_trg = layer_module( + attention_output, + attention_mask_trg, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + add_attention=False + ) + hidden_states_trg = ffn_layer_outputs_trg[0] + + # src + hidden_states_en = layer_module( + hidden_states_en, + attention_mask_en, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions + )[0] + + hidden_states = torch.stack([hidden_states_en, hidden_states_trg], dim=1) + hidden_states = hidden_states.view(-1, *hidden_states.size()[2:]) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if attention_entropy is not None: + 
all_attention_entropy = all_attention_entropy + (attention_entropy,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + all_attention_entropy + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + +class BertMixModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration + set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder` + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, mix_layer=7, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertMixEncoder(config, mix_layer) + # self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + lang_ids=None + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
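+        # `get_extended_attention_mask` reshapes the [batch_size, seq_length] mask into a
+        # broadcastable [batch_size, 1, 1, seq_length] tensor with large negative values at
+        # masked positions; the unexpanded mask is also forwarded as `raw_attention_mask`
+        # so the mix layers can split it per language.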
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + raw_attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + lang_ids=lang_ids + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + +class BertForQuestionAnswering(BertPreTrainedModel): + def __init__(self, config, args, num_lang=2): + super().__init__(config) + + self.num_labels = config.num_labels + + self.bert = BertMixModel(config, mix_layer=args.mix_layers if args.mix_layers is not None else args.mix_layer, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + + self.teaching_weight = args.teaching_weight + self.align_weight = args.align_weight + self.consist_weight = args.consist_weight + self.alpha = args.alpha + self.norm = args.norm + self.cl = args.cl + + if self.cl: + self.cl_loss = ContrastiveLoss(config, temp=args.temp) + else: + self.mse_loss = nn.MSELoss() + + if self.teaching_weight > 0: + self.attention_teacher = AttentionTeacher(config) + + if self.norm: + self.bn = nn.ModuleList([BatchNorm(config.hidden_size) for _ in range(num_lang)]) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + query_len=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + lang_ids=None, 
+ return_sequence_output=False + ): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.size(0) + # Number of lang in one instance + # num_lang = input_ids.size(1) + + # Flatten input + input_ids = input_ids.view((-1, input_ids.size(-1))) # (bsz * 2, len) + attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bsz * 2, len) + if lang_ids is not None: + lang_ids = lang_ids.view(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bsz * 2, len) + if query_len is not None: + query_len = query_len.view(-1) + + outputs = self.bert( + input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=False, + lang_ids=lang_ids + ) + + sequence_output = outputs[0] + + if return_sequence_output: + return sequence_output + attention_entropy = outputs[-1] + + if self.norm: + sequence_output = sequence_output.view(batch_size, 2, sequence_output.size(-2), sequence_output.size(-1)) + + attention_mask_src = attention_mask.view(batch_size, 2, -1)[:, 0] + attention_mask_trg = attention_mask.view(batch_size, 2, -1)[:, 1] + + src_lang_id = int(lang_ids[0]) + trg_lang_id = int(lang_ids[1]) + sequence_output_src = self.bn[src_lang_id](sequence_output[:, 0], attention_mask_src) + sequence_output_trg = self.bn[trg_lang_id](sequence_output[:, 1], attention_mask_trg) + sequence_output = torch.stack([sequence_output_src, sequence_output_trg], dim=1) + sequence_output = sequence_output.view(batch_size * 2, *sequence_output.size()[2:]) + + seq_rep = (sequence_output * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) + + extended_attention_mask = attention_mask[:, None, None, :] + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + sequence_output = sequence_output.view(batch_size, 2, sequence_output.size(-2), sequence_output.size(-1)) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + if self.teaching_weight > 0: + logits_teacher = self.attention_teacher( + query=sequence_output[:, 1], + key=sequence_output[:, 0], + value=logits[:, 0].detach(), + attention_mask=extended_attention_mask.view(batch_size, 2, *extended_attention_mask.size()[1:])[:, 0] + ) + + start_logits_t, end_logits_t = logits_teacher.split(1, dim=-1) + start_logits_t = start_logits_t.squeeze(-1).contiguous() + end_logits_t = end_logits_t.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + start_positions = start_positions.view(-1) + end_positions = end_positions.view(-1) + + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(-1) + start_positions = start_positions.view(-1).clamp(0, ignored_index) + end_positions = end_positions.view(-1).clamp(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss_src = loss_fct(start_logits.view(batch_size, 2, -1)[:, 0], start_positions.view(-1, 2)[:, 0]) + end_loss_src = loss_fct(end_logits.view(batch_size, 2, -1)[:, 
0], end_positions.view(-1, 2)[:, 0]) + + start_loss_trg = loss_fct(start_logits.view(batch_size, 2, -1)[:, 1], start_positions.view(-1, 2)[:, 1]) + end_loss_trg = loss_fct(end_logits.view(batch_size, 2, -1)[:, 1], end_positions.view(-1, 2)[:, 1]) + + + loss = self.alpha * (start_loss_src + end_loss_src) / 2 \ + + (1 - self.alpha) * (start_loss_trg + end_loss_trg) / 2 + + if self.teaching_weight > 0: + start_loss_t = loss_fct(start_logits_t, start_positions.view(-1, 2)[:, 1]) + end_loss_t = loss_fct(end_logits_t, end_positions.view(-1, 2)[:, 1]) + loss += self.teaching_weight * (start_loss_t + end_loss_t) / 2 + + loss += self.align_weight * attention_entropy[0].mean() + + seq_rep = seq_rep.view(batch_size, 2, -1) + if self.cl: + loss += self.consist_weight * self.cl_loss(seq_rep[:, 0], seq_rep[:, 1]) + else: + loss += self.consist_weight * self.mse_loss(seq_rep[:, 0], seq_rep[:, 1]) + + else: + # predict + start_logits = start_logits[:, 1] + end_logits = end_logits[:, 1] + if self.teaching_weight > 0: + start_logits += start_logits_t + end_logits += end_logits_t + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + else: + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + ) diff --git a/examples/X-STA/src/data/mlqa_evaluation_v1.py b/examples/X-STA/src/data/mlqa_evaluation_v1.py new file mode 100644 index 0000000..8c11a7f --- /dev/null +++ b/examples/X-STA/src/data/mlqa_evaluation_v1.py @@ -0,0 +1,161 @@ +# Copyright (c) 2019-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +""" Official evaluation script for the MLQA dataset. 
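+Answers are normalized per language (lower-casing, punctuation and article removal, and
+mixed segmentation for Chinese) before exact-match and F1 are computed.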
""" +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys +import unicodedata + + +PUNCT = {chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith('P')}.union(string.punctuation) +WHITESPACE_LANGS = ['en', 'es', 'hi', 'vi', 'de', 'ar'] +MIXED_SEGMENTATION_LANGS = ['zh'] + + +def whitespace_tokenize(text): + return text.split() + + +def mixed_segmentation(text): + segs_out = [] + temp_str = "" + for char in text: + if re.search(r'[\u4e00-\u9fa5]', char) or char in PUNCT: + if temp_str != "": + ss = whitespace_tokenize(temp_str) + segs_out.extend(ss) + temp_str = "" + segs_out.append(char) + else: + temp_str += char + + if temp_str != "": + ss = whitespace_tokenize(temp_str) + segs_out.extend(ss) + + return segs_out + + +def normalize_answer(s, lang): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text, lang): + if lang == 'en': + return re.sub(r'\b(a|an|the)\b', ' ', text) + elif lang == 'es': + return re.sub(r'\b(un|una|unos|unas|el|la|los|las)\b', ' ', text) + elif lang == 'hi': + return text # Hindi does not have formal articles + elif lang == 'vi': + return re.sub(r'\b(của|là|cái|chiếc|những)\b', ' ', text) + elif lang == 'de': + return re.sub(r'\b(ein|eine|einen|einem|eines|einer|der|die|das|den|dem|des)\b', ' ', text) + elif lang == 'ar': + return re.sub('\sال^|ال', ' ', text) + elif lang == 'zh': + return text # Chinese does not have formal articles + else: + raise Exception('Unknown Language {}'.format(lang)) + + def white_space_fix(text, lang): + if lang in WHITESPACE_LANGS: + tokens = whitespace_tokenize(text) + elif lang in MIXED_SEGMENTATION_LANGS: + tokens = mixed_segmentation(text) + else: + raise Exception('Unknown Language {}'.format(lang)) + return ' '.join([t for t in tokens if t.strip() != '']) + + def remove_punc(text): + return ''.join(ch for ch in text if ch not in PUNCT) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)), lang), lang) + + +def f1_score(prediction, ground_truth, lang): + prediction_tokens = normalize_answer(prediction, lang).split() + ground_truth_tokens = normalize_answer(ground_truth, lang).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth, lang): + return (normalize_answer(prediction, lang) == normalize_answer(ground_truth, lang)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, lang): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth, lang) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions, lang): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' 
+ print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths, lang) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths, lang) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {'exact_match': exact_match, 'f1': f1} + + +def evaluate_with_path(dataset_file, prediction_file, answer_language): + with open(dataset_file) as dataset_file_reader: + dataset_json = json.load(dataset_file_reader) + dataset = dataset_json['data'] + with open(prediction_file) as prediction_file_reader: + predictions = json.load(prediction_file_reader) + return evaluate(dataset, predictions, answer_language) + +if __name__ == '__main__': + expected_version = '1.0' + parser = argparse.ArgumentParser( + description='Evaluation for MLQA ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + parser.add_argument('answer_language', help='Language code of answer language') + + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (str(dataset_json['version']) != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions, args.answer_language))) diff --git a/examples/X-STA/src/data/qa.py b/examples/X-STA/src/data/qa.py new file mode 100644 index 0000000..dd233a7 --- /dev/null +++ b/examples/X-STA/src/data/qa.py @@ -0,0 +1,796 @@ +import json +import logging +import os +from functools import partial +from multiprocessing import Pool, cpu_count + +import numpy as np +from tqdm import tqdm + +from transformers.models.bert.tokenization_bert import whitespace_tokenize +from .utils import DataProcessor +import torch +from torch.utils.data import TensorDataset + +from src.data.squad_metrics import compute_f1 + +logger = logging.getLogger(__name__) + + +# def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): +# """Returns tokenized answer spans that better match the annotated answer.""" +# tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) + +# for new_start in range(input_start, input_end + 1): +# for new_end in range(input_end, new_start - 1, -1): +# text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) +# if text_span == tok_answer_text: +# return (new_start, new_end) + +# return (input_start, input_end) + +def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): + """Returns tokenized answer spans that better match the annotated answer.""" + tok_answer_text = tokenizer.convert_tokens_to_string(tokenizer.tokenize(orig_answer_text)) + + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = tokenizer.convert_tokens_to_string(doc_tokens[new_start:(new_end + 1)]) + if text_span.strip() == tok_answer_text.strip(): + return (new_start, new_end) + + max_f1 = 0 + max_start = -1 + max_end = -1 + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = 
tokenizer.convert_tokens_to_string(doc_tokens[new_start:(new_end + 1)]) + cur_f1 = compute_f1(tok_answer_text.strip(), text_span.strip()) + if cur_f1 > max_f1: + max_f1 = cur_f1 + max_start = new_start + max_end = new_end + + if max_start == -1 and max_end == -1: + max_start = input_start + max_end = input_end + + return (max_start, max_end) + + +def _check_is_max_context(doc_spans, cur_span_index, position): + """Check if this is the 'max context' doc span for the token.""" + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def _new_check_is_max_context(doc_spans, cur_span_index, position): + """Check if this is the 'max context' doc span for the token.""" + # if len(doc_spans) == 1: + # return True + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span["start"] + doc_span["length"] - 1 + if position < doc_span["start"]: + continue + if position > end: + continue + num_left_context = position - doc_span["start"] + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"] + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def _is_whitespace(c): + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + return True + return False + + +def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training): + features = [] + if is_training and not example.is_impossible: + # Get start and end position + start_position = example.start_position + end_position = example.end_position + + # If the answer cannot be found in the text, then skip this example. + actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) + cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) + if actual_text.find(cleaned_answer_text) == -1: + logger.warning("Could not find answer: '%s' vs. 
'%s'", actual_text, cleaned_answer_text) + return [] + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + if is_training and not example.is_impossible: + tok_start_position = orig_to_tok_index[example.start_position] + if example.end_position < len(example.doc_tokens) - 1: + tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 + else: + tok_end_position = len(all_doc_tokens) - 1 + + (tok_start_position, tok_end_position) = _improve_answer_span( + all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text + ) + + spans = [] + + truncated_query = tokenizer.encode( + example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length + ) + + sequence_added_tokens = ( + tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1 + if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer)) + else tokenizer.model_max_length - tokenizer.max_len_single_sentence + ) + sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair + + span_doc_tokens = all_doc_tokens + while len(spans) * doc_stride < len(all_doc_tokens): + + # Define the side we want to truncate / pad and the text/pair sorting + if tokenizer.padding_side == "right": + texts = truncated_query + pairs = span_doc_tokens + truncation = "only_second" + else: + texts = span_doc_tokens + pairs = truncated_query + truncation = "only_first" + + encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic + texts, + pairs, + truncation=truncation, + padding="max_length", + max_length=max_seq_length, + return_overflowing_tokens=True, + stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, + return_token_type_ids=True, + ) + + paragraph_len = min( + len(all_doc_tokens) - len(spans) * doc_stride, + max_seq_length - len(truncated_query) - sequence_pair_added_tokens, + ) + + if tokenizer.pad_token_id in encoded_dict["input_ids"]: + if tokenizer.padding_side == "right": + non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] + else: + last_padding_id_position = ( + len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id) + ) + non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :] + + else: + non_padded_ids = encoded_dict["input_ids"] + + tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) + + token_to_orig_map = {} + for i in range(paragraph_len): + index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i + token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] + + encoded_dict["paragraph_len"] = paragraph_len + encoded_dict["tokens"] = tokens + encoded_dict["token_to_orig_map"] = token_to_orig_map + encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens + encoded_dict["token_is_max_context"] = {} + encoded_dict["start"] = len(spans) * doc_stride + encoded_dict["length"] = paragraph_len + + spans.append(encoded_dict) + + if "overflowing_tokens" not in encoded_dict or ( + "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0 + ): + break + span_doc_tokens = 
encoded_dict["overflowing_tokens"] + + for doc_span_index in range(len(spans)): + for j in range(spans[doc_span_index]["paragraph_len"]): + is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) + index = ( + j + if tokenizer.padding_side == "left" + else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j + ) + spans[doc_span_index]["token_is_max_context"][index] = is_max_context + + for span in spans: + # Identify the position of the CLS token + cls_index = span["input_ids"].index(tokenizer.cls_token_id) + + # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) + # Original TF implem also keep the classification token (set to 0) (not sure why...) + p_mask = np.array(span["token_type_ids"]) + + p_mask = np.minimum(p_mask, 1) + + if tokenizer.padding_side == "right": + # Limit positive values to one + p_mask = 1 - p_mask + + p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1 + + # Set the CLS index to '0' + p_mask[cls_index] = 0 + + span_is_impossible = example.is_impossible + start_position = 0 + end_position = 0 + if is_training and not span_is_impossible: + # For training, if our document chunk does not contain an annotation + # we throw it out, since there is nothing to predict. + doc_start = span["start"] + doc_end = span["start"] + span["length"] - 1 + out_of_span = False + + if not (tok_start_position >= doc_start and tok_end_position <= doc_end): + out_of_span = True + + if out_of_span: + start_position = cls_index + end_position = cls_index + span_is_impossible = True + else: + if tokenizer.padding_side == "left": + doc_offset = 0 + else: + doc_offset = len(truncated_query) + sequence_added_tokens + + start_position = tok_start_position - doc_start + doc_offset + end_position = tok_end_position - doc_start + doc_offset + + features.append( + SquadFeatures( + span["input_ids"], + span["attention_mask"], + span["token_type_ids"], + cls_index, + p_mask.tolist(), + qas_id=example.qas_id, + example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing. + unique_id=0, + paragraph_len=span["paragraph_len"], + token_is_max_context=span["token_is_max_context"], + tokens=span["tokens"], + token_to_orig_map=span["token_to_orig_map"], + start_position=start_position, + end_position=end_position, + is_impossible=span_is_impossible, + query_len=len(truncated_query) + ) + ) + return features + + +def squad_convert_example_to_features_init(tokenizer_for_convert): + global tokenizer + tokenizer = tokenizer_for_convert + + +def squad_convert_examples_to_features( + examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False, threads=1, lang_id=0 +): + """ + Converts a list of examples into a list of features that can be directly given as input to a model. + It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. + + Args: + examples: list of :class:`~transformers.data.processors.squad.SquadExample` + tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` + max_seq_length: The maximum sequence length of the inputs. + doc_stride: The stride used when the context is too large and is split across several features. + max_query_length: The maximum length of the query. + is_training: whether to create features for model evaluation or model training. + return_dataset: Default False. Either 'pt' or 'tf'. 
+ if 'pt': returns a torch.data.TensorDataset, + if 'tf': returns a tf.data.Dataset + threads: multiple processing threadsa-smi + + + Returns: + list of :class:`~transformers.data.processors.squad.SquadFeatures` + + Example:: + + processor = SquadV2Processor() + examples = processor.get_dev_examples(data_dir) + + features = squad_convert_examples_to_features( + examples=examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=not evaluate, + ) + """ + + # Defining helper methods + features = [] + threads = min(threads, cpu_count()) + with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: + annotate_ = partial( + squad_convert_example_to_features, + max_seq_length=max_seq_length, + doc_stride=doc_stride, + max_query_length=max_query_length, + is_training=is_training, + ) + features = list( + tqdm( + p.imap(annotate_, examples, chunksize=32), + total=len(examples), + desc="convert squad examples to features", + ) + ) + new_features = [] + unique_id = 1000000000 + example_index = 0 + for example_features in tqdm(features, total=len(features), desc="add example index and unique id"): + if not example_features: + continue + for example_feature in example_features: + if example_feature.start_position >= max_seq_length or example_feature.end_position >= max_seq_length or example_feature.end_position < example_feature.start_position: + continue + + example_feature.example_index = example_index + example_feature.unique_id = unique_id + new_features.append(example_feature) + unique_id += 1 + example_index += 1 + features = new_features + del new_features + if return_dataset == "pt": + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) + all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) + all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) + all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) + all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float) + all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + + all_lang_id = torch.tensor([lang_id] * len(features), dtype=torch.long) + + if not is_training: + dataset = TensorDataset( + all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask, all_lang_id + ) + else: + all_query_len = torch.tensor([f.query_len for f in features], dtype=torch.long) + all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) + all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) + dataset = TensorDataset( + all_input_ids, + all_attention_masks, + all_token_type_ids, + all_start_positions, + all_end_positions, + all_cls_index, + all_p_mask, + all_is_impossible, + all_example_index, + all_query_len, + all_lang_id + ) + + return features, dataset + elif return_dataset == "tf": + raise NotImplementedError() + + return features + + +class SquadProcessor(DataProcessor): + """ + Processor for the SQuAD data set. + Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively. 
+ """ + + train_file = None + dev_file = None + + def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): + if not evaluate: + answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8") + answer_start = tensor_dict["answers"]["answer_start"][0].numpy() + answers = [] + else: + answers = [ + {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")} + for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"]) + ] + + answer = None + answer_start = None + + return SquadExample( + qas_id=tensor_dict["id"].numpy().decode("utf-8"), + question_text=tensor_dict["question"].numpy().decode("utf-8"), + context_text=tensor_dict["context"].numpy().decode("utf-8"), + answer_text=answer, + start_position_character=answer_start, + title=tensor_dict["title"].numpy().decode("utf-8"), + answers=answers, + ) + + def get_examples_from_dataset(self, dataset, evaluate=False): + """ + Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset. + + Args: + dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")` + evaluate: boolean specifying if in evaluation mode or in training mode + + Returns: + List of SquadExample + + Examples:: + + import tensorflow_datasets as tfds + dataset = tfds.load("squad") + + training_examples = get_examples_from_dataset(dataset, evaluate=False) + evaluation_examples = get_examples_from_dataset(dataset, evaluate=True) + """ + + if evaluate: + dataset = dataset["validation"] + else: + dataset = dataset["train"] + + examples = [] + for tensor_dict in tqdm(dataset): + examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) + + return examples + + def get_train_examples(self, data_dir, filename=None, language='en'): + """ + Returns the training examples from the data directory. + + Args: + data_dir: Directory containing the data files used for training and evaluating. + filename: None by default, specify this if the training file has a different name than the original one + which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. + + """ + if data_dir is None: + data_dir = "" + + if self.train_file is None and filename is None: + raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") + + with open( + os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8" + ) as reader: + input_data = json.load(reader)["data"] + return self._create_examples(input_data, "train", language) + + def get_dev_examples(self, data_dir, filename=None, language='en'): + """ + Returns the evaluation example from the data directory. + + Args: + data_dir: Directory containing the data files used for training and evaluating. + filename: None by default, specify this if the evaluation file has a different name than the original one + which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 
+ """ + if data_dir is None: + data_dir = "" + + if self.dev_file is None and filename is None: + raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") + + with open( + os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" + ) as reader: + input_data = json.load(reader)["data"] + return self._create_examples(input_data, "dev", language) + + def _create_examples(self, input_data, set_type, language): + is_training = set_type == "train" + examples = [] + for entry in tqdm(input_data): + title = entry["title"] if "title" in entry else None + for paragraph in entry["paragraphs"]: + context_text = paragraph["context"] + for qa in paragraph["qas"]: + qas_id = qa["id"] + + question_text = qa["question"] + start_position_character = None + answer_text = None + answers = [] + + if "is_impossible" in qa: + is_impossible = qa["is_impossible"] + else: + is_impossible = False + + if not is_impossible: + if is_training: + answer = qa["answers"][0] + answer_text = answer["text"] + start_position_character = answer["answer_start"] if answer["answer_start"] >=0 else None + else: + answers = qa["answers"] + + if is_training and start_position_character is None: + continue + + example = SquadExample( + qas_id=qas_id, + question_text=question_text, + context_text=context_text, + answer_text=answer_text, + start_position_character=start_position_character, + title=title, + is_impossible=is_impossible, + answers=answers, + language=language + ) + + examples.append(example) + return examples + + +class SquadV1Processor(SquadProcessor): + train_file = "train-v1.1.json" + dev_file = "dev-v1.1.json" + + +class SquadV2Processor(SquadProcessor): + train_file = "train-v2.0.json" + dev_file = "dev-v2.0.json" + +class MLQAProcessor(SquadProcessor): + + def get_dataset_path(self, data_dir, split, language): + assert split != "train" + return os.path.join(data_dir, "mlqa/MLQA_V1", split, "{0}-context-{1}-question-{1}.json".format(split, language)) + + def get_train_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_train_examples(data_dir, "squad/train-v1.1.json", language=language) + return self.get_train_examples(data_dir, f"xtreme_translations/SQuAD/translate-train/squad.translate.train.en-{language}.json", language) + + def get_dev_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, "mlqa/MLQA_V1/dev/dev-context-{0}-question-{0}.json".format(language), language) + + def get_test_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, "mlqa/MLQA_V1/test/test-context-{0}-question-{0}.json".format(language), language) + + def get_translate_test_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_dev_examples(data_dir, "mlqa/MLQA_V1/test/test-context-{0}-question-{0}.json".format(language), 'en') + return self.get_dev_examples(data_dir, f"xtreme_translations/MLQA/translate-test/mlqa.translate.test.{language}-en.json", 'en') + +class XquadProcessor(SquadProcessor): + def get_dataset_path(self, data_dir, split, language): + assert split == "dev" + return os.path.join(data_dir, "mlqa/MLQA_V1", split, "{0}-context-{1}-question-{1}.json".format(split, language)) + + def get_train_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_train_examples(data_dir, "squad/train-v1.1.json", language) + return self.get_train_examples(data_dir, 
f"xtreme_translations/SQuAD/translate-train/squad.translate.train.en-{language}.json", language) + + def get_dev_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, "mlqa/MLQA_V1/dev/dev-context-{0}-question-{0}.json".format(language), language) + + def get_test_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, f"xquad/xquad.{language}.json".format(language), language) + + def get_translate_test_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_dev_examples(data_dir, f"xquad/xquad.{language}.json".format(language), 'en') + return self.get_dev_examples(data_dir, f"xtreme_translations/XQuAD/translate-test/xquad.translate.test.{language}-en.json", 'en') + +class TydiqaProcessor(SquadProcessor): + + def get_train_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_train_examples(data_dir, "tydiqa/tydiqa-goldp-v1.1-train/tydiqa.goldp.en.train.json", language) + return self.get_train_examples(data_dir, f"xtreme_translations/TyDiQA-GoldP/translate-train/tydiqa.translate.train.en-{language}.json", language) + + def get_dev_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, f"tydiqa/tydiqa-goldp-v1.1-dev/tydiqa.goldp.{language}.dev.json", language) + + def get_test_examples_by_language(self, data_dir, language): + return self.get_dev_examples(data_dir, f"tydiqa/tydiqa-goldp-v1.1-dev/tydiqa.goldp.{language}.dev.json", language) + + def get_translate_test_examples_by_language(self, data_dir, language): + if language == 'en': + return self.get_dev_examples(data_dir, "tydiqa/tydiqa-goldp-v1.1-dev/tydiqa.goldp.en.dev.json", 'en') + return self.get_dev_examples(data_dir, f"translations/tydiqa/tydiqa.translate.test.{language}-en.json", 'en') + # return self.get_dev_examples(data_dir, f"xtreme_translations/TyDiQA-GoldP/translate-test/tydiqa.translate.test.{language}-en.json", 'en') + + +class SquadExample(object): + """ + A single training/test example for the Squad dataset, as loaded from disk. + + Args: + qas_id: The example's unique identifier + question_text: The question string + context_text: The context string + answer_text: The answer string + start_position_character: The character position of the start of the answer + title: The title of the example + answers: None by default, this is used during evaluation. Holds answers as well as their start positions. + is_impossible: False by default, set to True if the example has no possible answer. + """ + + def __init__( + self, + qas_id, + question_text, + context_text, + answer_text, + start_position_character, + title, + answers=[], + is_impossible=False, + language='en' + ): + self.qas_id = qas_id + self.question_text = question_text + self.context_text = context_text + self.answer_text = answer_text + self.title = title + self.is_impossible = is_impossible + self.answers = answers + self.language = language + + self.start_position, self.end_position = 0, 0 + + doc_tokens = [] + char_to_word_offset = [] + prev_is_whitespace = True + + # Split on whitespace so that different tokens may be attributed to their original position. 
+ for c in self.context_text: + if _is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + char_to_word_offset.append(len(doc_tokens) - 1) + + self.doc_tokens = doc_tokens + self.char_to_word_offset = char_to_word_offset + + # Start and end positions only has a value during evaluation. + if start_position_character is not None and not is_impossible: + self.start_position = char_to_word_offset[start_position_character] + self.end_position = char_to_word_offset[ + min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) + ] + + +class SquadFeatures(object): + """ + Single squad example features to be fed to a model. + Those features are model-specific and can be crafted from :class:`~transformers.data.processors.squad.SquadExample` + using the :method:`~transformers.data.processors.squad.squad_convert_examples_to_features` method. + + Args: + input_ids: Indices of input sequence tokens in the vocabulary. + attention_mask: Mask to avoid performing attention on padding token indices. + token_type_ids: Segment token indices to indicate first and second portions of the inputs. + cls_index: the index of the CLS token. + p_mask: Mask identifying tokens that can be answers vs. tokens that cannot. + Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer + example_index: the index of the example + unique_id: The unique Feature identifier + paragraph_len: The length of the context + token_is_max_context: List of booleans identifying which tokens have their maximum context in this feature object. + If a token does not have their maximum context in this feature object, it means that another feature object + has more information related to that token and should be prioritized over this feature for that token. + tokens: list of tokens corresponding to the input ids + token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. + start_position: start of the answer token index + end_position: end of the answer token index + """ + + def __init__( + self, + input_ids, + attention_mask, + token_type_ids, + cls_index, + p_mask, + example_index, + unique_id, + paragraph_len, + token_is_max_context, + tokens, + token_to_orig_map, + start_position, + end_position, + is_impossible, + start_logits=None, + end_logits=None, + query_len=None, + qas_id=None + ): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.token_type_ids = token_type_ids + self.cls_index = cls_index + self.p_mask = p_mask + + self.example_index = example_index + self.unique_id = unique_id + self.paragraph_len = paragraph_len + self.token_is_max_context = token_is_max_context + self.tokens = tokens + self.token_to_orig_map = token_to_orig_map + + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + self.start_soft_position = start_logits + self.end_soft_position = end_logits + + self.query_len = query_len + self.qas_id = qas_id + + +class SquadResult(object): + """ + Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. + + Args: + unique_id: The unique identifier corresponding to that example. 
+ start_logits: The logits corresponding to the start of the answer + end_logits: The logits corresponding to the end of the answer + """ + + def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): + self.start_logits = start_logits + self.end_logits = end_logits + self.unique_id = unique_id + + if start_top_index: + self.start_top_index = start_top_index + self.end_top_index = end_top_index + self.cls_logits = cls_logits diff --git a/examples/X-STA/src/data/squad_metrics.py b/examples/X-STA/src/data/squad_metrics.py new file mode 100644 index 0000000..b59fba3 --- /dev/null +++ b/examples/X-STA/src/data/squad_metrics.py @@ -0,0 +1,765 @@ +""" Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was +modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0 + +In addition to basic functionality, we also compute additional statistics and +plot precision-recall curves if an additional na_prob.json file is provided. +This file is expected to map question ID's to the model's predicted probability +that a question is unanswerable. +""" + + +import collections +import json +import logging +import math +import re +import string + +from transformers.models.bert.tokenization_bert import BasicTokenizer + + +logger = logging.getLogger(__name__) + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +def compute_exact(a_gold, a_pred): + return int(normalize_answer(a_gold) == normalize_answer(a_pred)) + + +def compute_f1(a_gold, a_pred): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def get_raw_scores(examples, preds): + """ + Computes the exact and f1 scores from the examples and the model predictions + """ + exact_scores = {} + f1_scores = {} + + for example in examples: + qas_id = example.qas_id + gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])] + + if not gold_answers: + # For unanswerable questions, only correct answer is empty string + gold_answers = [""] + + if qas_id not in preds: + print("Missing prediction for %s" % qas_id) + continue + + prediction = preds[qas_id] + exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers) + f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers) + + return exact_scores, f1_scores + + +def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): + new_scores = {} + for qid, s in scores.items(): + pred_na = na_probs[qid] > na_prob_thresh + if 
pred_na: + new_scores[qid] = float(not qid_to_has_ans[qid]) + else: + new_scores[qid] = s + return new_scores + + +def make_eval_dict(exact_scores, f1_scores, qid_list=None): + if not qid_list: + total = len(exact_scores) + return collections.OrderedDict( + [ + ("exact_match", 100.0 * sum(exact_scores.values()) / total), + ("f1", 100.0 * sum(f1_scores.values()) / total), + ("total", total), + ] + ) + else: + total = len(qid_list) + return collections.OrderedDict( + [ + ("exact_match", 100.0 * sum(exact_scores[k] for k in qid_list) / total), + ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), + ("total", total), + ] + ) + + +def merge_eval(main_eval, new_eval, prefix): + for k in new_eval: + main_eval["%s_%s" % (prefix, k)] = new_eval[k] + + +def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + + has_ans_score, has_ans_cnt = 0, 0 + for qid in qid_list: + if not qid_to_has_ans[qid]: + continue + has_ans_cnt += 1 + + if qid not in scores: + continue + has_ans_score += scores[qid] + + return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt + + +def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + main_eval["has_ans_exact"] = has_ans_exact + main_eval["has_ans_f1"] = has_ans_f1 + + +def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for _, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + return 100.0 * best_score / len(scores), best_thresh + + +def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) + + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + + +def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0): + qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples} + has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer] + no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer] + 
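+    # when no per-question no-answer probabilities are supplied, default them to 0.0 so every
+    # prediction is treated as answerable and the threshold below leaves the scores unchanged.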
+ if no_answer_probs is None: + no_answer_probs = {k: 0.0 for k in preds} + + exact, f1 = get_raw_scores(examples, preds) + + exact_threshold = apply_no_ans_threshold( + exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold + ) + f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold) + + evaluation = make_eval_dict(exact_threshold, f1_threshold) + + if has_answer_qids: + has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids) + merge_eval(evaluation, has_ans_eval, "HasAns") + + if no_answer_qids: + no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids) + merge_eval(evaluation, no_ans_eval, "NoAns") + + if no_answer_probs: + find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer) + + return evaluation + + +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + if verbose_logging: + logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) + return orig_text + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + if verbose_logging: + logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. 
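+    # invert `tok_ns_to_s_map`: map positions in `tok_text` (with spaces) to positions in its
+    # space-stripped form, which line up one-to-one with the stripped original text.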
+ tok_s_to_ns_map = {} + for (i, tok_index) in tok_ns_to_s_map.items(): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + if verbose_logging: + logger.info("Couldn't map start position") + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + if verbose_logging: + logger.info("Couldn't map end position") + return orig_text + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text + + +def _get_best_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def compute_predictions_logits( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + verbose_logging, + version_2_with_negative, + null_score_diff_threshold, + tokenizer, + map_to_origin=True, +): + """Write final predictions to the json file and log-odds of null if needed.""" + logger.info("Writing predictions to: %s" % (output_prediction_file)) + logger.info("Writing nbest to: %s" % (output_nbest_file)) + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"] + ) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + min_null_feature_index = 0 # the paragraph slice with min null score + null_start_logit = 0 # the start logit at the slice with min null score + null_end_logit = 0 # the end logit at the slice with min null score + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.unique_id] + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + # if we could have irrelevant answers, get the 
min score of irrelevant + if version_2_with_negative: + feature_null_score = result.start_logits[0] + result.end_logits[0] + if feature_null_score < score_null: + score_null = feature_null_score + min_null_feature_index = feature_index + null_start_logit = result.start_logits[0] + null_end_logit = result.end_logits[0] + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index], + ) + ) + if version_2_with_negative: + prelim_predictions.append( + _PrelimPrediction( + feature_index=min_null_feature_index, + start_index=0, + end_index=0, + start_logit=null_start_logit, + end_logit=null_end_logit, + ) + ) + prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"] + ) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + if pred.start_index > 0: # this is a non-null prediction + tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)] + + tok_text = tokenizer.convert_tokens_to_string(tok_tokens) + # tok_text = " ".join(tok_tokens) + # + # # De-tokenize WordPieces that have been split off. + # tok_text = tok_text.replace(" ##", "") + # tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + if not map_to_origin: #'zh' in output_prediction_file: + final_text = tok_text + else: + final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) + + # if not map_to_origin: + # final_text = tok_text + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit)) + # if we didn't include the empty option in the n-best, include it + if version_2_with_negative: + if "" not in seen_predictions: + nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit)) + + # In very rare edge cases we could only have single null prediction. + # So we just create a nonce prediction in this case to avoid failure. + if len(nbest) == 1: + nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + # In very rare edge cases we could have no valid predictions. 
So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + assert len(nbest) >= 1 + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + if entry.text: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + + assert len(nbest_json) >= 1 + + if not version_2_with_negative: + all_predictions[example.qas_id] = nbest_json[0]["text"] + else: + # predict "" iff the null score - the score of best non-null > threshold + score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit) + scores_diff_json[example.qas_id] = score_diff + if score_diff > null_score_diff_threshold: + all_predictions[example.qas_id] = "" + else: + all_predictions[example.qas_id] = best_non_null_entry.text + all_nbest_json[example.qas_id] = nbest_json + + with open(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + with open(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + if version_2_with_negative: + with open(output_null_log_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + +def compute_predictions_log_probs( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + start_n_top, + end_n_top, + version_2_with_negative, + tokenizer, + verbose_logging, +): + """ XLNet write prediction logic (more complex than Bert's). + Write final predictions to the json file and log-odds of null if needed. 
+ + Requires utils_squad_evaluate.py + """ + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"] + ) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_log_prob", "end_log_prob"] + ) + + logger.info("Writing predictions to: %s", output_prediction_file) + # logger.info("Writing nbest to: %s" % (output_nbest_file)) + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.unique_id] + + cur_null_score = result.cls_logits + + # if we could have irrelevant answers, get the min score of irrelevant + score_null = min(score_null, cur_null_score) + + for i in range(start_n_top): + for j in range(end_n_top): + start_log_prob = result.start_logits[i] + start_index = result.start_top_index[i] + + j_index = i * end_n_top + j + + end_log_prob = result.end_logits[j_index] + end_index = result.end_top_index[j_index] + + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= feature.paragraph_len - 1: + continue + if end_index >= feature.paragraph_len - 1: + continue + + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_log_prob=start_log_prob, + end_log_prob=end_log_prob, + ) + ) + + prelim_predictions = sorted( + prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True + ) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + + # XLNet un-tokenizer + # Let's keep it simple for now and see if we need all this later. 
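+            # The commented-out block below is XLNet's own character-offset mapping
+            # (tok_start_to_orig_index / tok_end_to_orig_index); this implementation
+            # instead reuses the BERT-style token_to_orig_map and, further down,
+            # keeps the detokenized tok_text as the final answer text.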
+ # + # tok_start_to_orig_index = feature.tok_start_to_orig_index + # tok_end_to_orig_index = feature.tok_end_to_orig_index + # start_orig_pos = tok_start_to_orig_index[pred.start_index] + # end_orig_pos = tok_end_to_orig_index[pred.end_index] + # paragraph_text = example.paragraph_text + # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() + + # Previously used Bert untokenizer + tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)] + tok_text = tokenizer.convert_tokens_to_string(tok_tokens) + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + + if hasattr(tokenizer, "do_lower_case"): + do_lower_case = tokenizer.do_lower_case + else: + do_lower_case = tokenizer.do_lowercase_and_remove_accent + + final_text = tok_text + #final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) + + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + + nbest.append( + _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob) + ) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6)) + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_log_prob + entry.end_log_prob) + if not best_non_null_entry: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_log_prob"] = entry.start_log_prob + output["end_log_prob"] = entry.end_log_prob + nbest_json.append(output) + + assert len(nbest_json) >= 1 + assert best_non_null_entry is not None + + score_diff = score_null + scores_diff_json[example.qas_id] = score_diff + # note(zhiliny): always predict best_non_null_entry + # and the evaluation script will search for the best threshold + all_predictions[example.qas_id] = best_non_null_entry.text + + all_nbest_json[example.qas_id] = nbest_json + + with open(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + with open(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + if version_2_with_negative: + with open(output_null_log_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions diff --git a/examples/X-STA/src/data/utils.py b/examples/X-STA/src/data/utils.py new file mode 100644 index 0000000..d4c250f --- /dev/null +++ b/examples/X-STA/src/data/utils.py @@ -0,0 +1,106 @@ +import csv +import logging +import json +import copy +from typing import List + +logger = logging.getLogger(__name__) + + +class DataProcessor: + """Base class for data converters for sequence classification data sets.""" + + def get_train_examples(self, data_dir): + """Gets a collection of :class:`InputExample` for the train set.""" + raise NotImplementedError() + + def get_dev_examples(self, data_dir): + """Gets a collection of :class:`InputExample` for the dev set.""" + raise 
NotImplementedError() + + def get_test_examples(self, data_dir): + """Gets a collection of :class:`InputExample` for the test set.""" + raise NotImplementedError() + + def get_labels(self) -> List[str]: + """Gets the list of all labels for this data set.""" + raise NotImplementedError() + + @classmethod + def _read_tsv(cls, input_file, quotechar=None): + """Reads a tab separated value file.""" + with open(input_file, "r") as f: + return list(csv.reader(f, delimiter="\t", quotechar=quotechar)) + +class InputExample(object): + """ + A single training/test example for simple sequence classification. + + Args: + guid: Unique id for the example. + text_a: string. The untokenized text of the first sequence. For single + sequence tasks, only this sequence must be specified. + text_b: (Optional) string. The untokenized text of the second sequence. + Only must be specified for sequence pair tasks. + label: (Optional) string. The label of the example. This should be + specified for train and dev examples, but not for test examples. + """ + + def __init__(self, guid, text_a, text_b=None, label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.label = label + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + +class InputFeatures(object): + """ + A single set of features of data. + + Args: + input_ids: Indices of input sequence tokens in the vocabulary. + attention_mask: Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. + token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
+ label: Label corresponding to the input + """ + + def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None, guid=None): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.token_type_ids = token_type_ids + self.label = label + self.guid = guid + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + +class XInputFeatures(InputFeatures): + def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None, lang_id=None, guid=None): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.token_type_ids = token_type_ids + self.label = label + self.guid = guid + self.lang_id = lang_id diff --git a/examples/X-STA/src/utils.py b/examples/X-STA/src/utils.py new file mode 100644 index 0000000..3fc0c18 --- /dev/null +++ b/examples/X-STA/src/utils.py @@ -0,0 +1,236 @@ +import os +import torch +import torch.nn as nn +from torch.utils.data import RandomSampler +from torch.utils.data.distributed import DistributedSampler +import logging +import math + +from sklearn.metrics import f1_score, average_precision_score +import numpy as np + + +logger = logging.getLogger(__name__) + +def save_model(args, model, tokenizer): + # Create output directory if needed + + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
+ # They can then be reloaded using `from_pretrained()` + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, "training_args.bin")) + + +class BatchContinuousRandomSampler(RandomSampler): + """ make sure examples with same language in batch """ + + def __init__(self, data_source, replacement: bool = False, + num_samples = None, generator=None, batch_size=None) -> None: + + super().__init__(data_source, replacement, num_samples, generator) + self.batch_size = batch_size + + def __iter__(self): + n = len(self.data_source) + if self.generator is None: + generator = torch.Generator() + generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) + else: + generator = self.generator + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + l = torch.randperm(n // self.batch_size, generator=self.generator).tolist() + l2 = [] + for x in l: + for i in range(self.batch_size): + l2.append(x * self.batch_size + i) + + yield from l2 + +class BatchContinuousDistributedSampler(DistributedSampler): + """ make sure examples with same language in batch """ + + def __init__(self, dataset, num_replicas = None, + rank = None, shuffle = True, + seed = 0, drop_last = False, batch_size=None) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.batch_size = batch_size + + n = len(dataset) + + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(n // self.batch_size, generator=g).tolist() # type: ignore + new_indices = [] + for x in indices: + for i in range(self.batch_size): + new_indices.append(x * self.batch_size + i) + + self.indices = new_indices[self.rank:self.total_size:self.num_replicas] + + def __iter__(self): + return iter(self.indices) + + def __len__(self) -> int: + return len(self.indices) + +class MLPLayer(nn.Module): + """ + Head for getting sentence representations over RoBERTa/BERT's CLS representation. 
+ """ + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, features, **kwargs): + x = self.dense(features) + x = self.activation(x) + + return x + +class AttentionTeacher(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + attention_mask=None, + + query=None, # target_hidden_states + key=None, # source_hidden_states + value=None, # source_logits + ): + + mixed_query_layer = self.query(query) + mixed_key_layer = self.key(key) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) # batch_size, num_head, seq_len, seq_len + + output = torch.matmul(attention_probs, value.unsqueeze(1).repeat(1, self.num_attention_heads, 1, 1)) + output = torch.mean(output, dim=1, keepdim=True).squeeze(1) + + return output + +class ContrastiveLoss(nn.Module): + def __init__(self, config, temp=0.05): + super().__init__() + self.mlp = MLPLayer(config) + self.temp = temp + + self.cos = nn.CosineSimilarity(dim=-1) + + def forward(self, x, y): + x = self.mlp(x) + y = self.mlp(y) + + cos_sim = self.cos(x.unsqueeze(1), y.unsqueeze(0)) / self.temp + + labels = torch.arange(cos_sim.size(0)).long().to(x.device) + loss_fct = nn.CrossEntropyLoss() + + loss = loss_fct(cos_sim, labels) + return loss + +class BatchNorm(nn.Module): + def __init__(self, hidden_size=768, eps=1e-8, momentum=0.1): + super().__init__() + self.eps = eps + self.momentum = momentum + self.hidden_size = hidden_size + + self.register_buffer('running_mean', torch.zeros(hidden_size)) + self.register_buffer('running_var', torch.ones(hidden_size)) + + def forward(self, input, attention_mask=None): + if self.training: + exponential_average_factor = self.momentum + + mean = input.mean((0, 1)) + var = input.var((0, 1), unbiased=False) + + if attention_mask is not None: + mean = ((input * attention_mask[:, :, None]).sum(1) / attention_mask.sum(-1)[:, None]).mean(0) + var = torch.pow(input * attention_mask[:, :, None] \ + - mean[None, None, :] * attention_mask[:, :, None], 2).sum((0, 1)) \ + / attention_mask.sum() + else: + mean = input.mean((0, 1)) + var = input.var((0, 1), unbiased=False) + + with torch.no_grad(): + self.running_mean = exponential_average_factor * mean \ + + (1 - exponential_average_factor) * self.running_mean + self.running_var = exponential_average_factor * var \ + + (1 - exponential_average_factor) * self.running_var + else: + mean = self.running_mean + var = self.running_var + + return (input - mean[None, None, :]) / torch.sqrt(var[None, None, :] + self.eps) + +def get_attention_entropy(attention_score): + """Get attention entropy based on attention score.""" + + bz, n_heads, seq_len_q, seq_len_k = attention_score.size() + attention_score = attention_score.mean(dim=1) + # (batch size, seq_len_q, seq_len_k) + attention_entropy = -(attention_score * torch.log(attention_score + 1e-8)) + attention_entropy = attention_entropy.sum(dim=-1) + + mean_attention_entropy = attention_entropy.mean(dim=-1) + + return mean_attention_entropy + +def get_pair_entropy(attention_score): + """Get pairwise attention entropy.""" + entropy1 = get_attention_entropy(attention_score) + attention_score = attention_score.permute(0, 1, 3, 2) + entropy2 = get_attention_entropy(attention_score) + + return entropy1 + entropy2 diff --git a/examples/X-STA/src/xlmr.py b/examples/X-STA/src/xlmr.py new file mode 100644 index 0000000..c6145ed --- /dev/null +++ b/examples/X-STA/src/xlmr.py @@ -0,0 +1,579 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.utils import BatchNorm, AttentionTeacher, get_pair_entropy, ContrastiveLoss + +from transformers.models.roberta.modeling_roberta import ( + RobertaSelfAttention, + RobertaIntermediate, + RobertaOutput, + RobertaEmbeddings, + RobertaPooler, + RobertaPreTrainedModel, + RobertaSelfOutput +) + +from transformers.modeling_outputs import ( + QuestionAnsweringModelOutput, +) + +from transformers import XLMRobertaConfig + +class RobertaMixAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = 
RobertaSelfAttention(config) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + +class RobertaMixLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = RobertaMixAttention(config) + self.add_cross_attention = config.add_cross_attention + + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + add_attention=True, + add_ffn=True + ): + + if add_attention: + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions + ) + if not add_ffn: + return attention_outputs + attention_output = attention_outputs[0] + else: + attention_output = hidden_states + attention_outputs = (attention_output, None) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + +class RobertaMixEncoder(nn.Module): + def __init__(self, config, mix_layer): + super().__init__() + self.config = config + + self.layer = nn.ModuleList([RobertaMixLayer(config) for _ in range(config.num_hidden_layers)]) + self.mix_layer = mix_layer + + self.sigmod = nn.Sigmoid() + self.w = nn.Parameter(torch.tensor(1.0)) + self.b = nn.Parameter(torch.tensor(0.0)) + + self.f = nn.Sequential( + nn.Dropout(config.hidden_dropout_prob), + nn.Linear(config.hidden_size * 2, config.hidden_size) + ) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + lang_ids=None + ): + + mix_layers = [self.mix_layer] + attention_entropy = None + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + all_attention_entropy = () + + next_decoder_cache = () if use_cache else None + + + attention_mask_en = attention_mask.view(-1, 2, *attention_mask.size()[1:])[:, 0] + attention_mask_trg = attention_mask.view(-1, 2, *attention_mask.size()[1:])[:, 1] + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if i in mix_layers: + hidden_states_en = hidden_states.view(-1, 2, hidden_states.size(-2), hidden_states.size(-1))[:, 0] + 
hidden_states_trg = hidden_states.view(-1, 2, hidden_states.size(-2), hidden_states.size(-1))[:, 1] + + # trg self attention + self_attention_output = layer_module.attention.self( + hidden_states_trg, + attention_mask_trg, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions + )[0] + + src_lang_id = int(lang_ids[0]) + trg_lang_id = int(lang_ids[1]) + if src_lang_id == trg_lang_id: + hidden_states_en_convert = hidden_states_en + else: + hidden_states_en_convert = hidden_states_en.detach() + self.f(torch.cat([hidden_states_en.detach(), hidden_states_trg.detach()], dim=-1)) + + cross_attention_outputs = layer_module.attention.self( + hidden_states_trg, + attention_mask_trg, + layer_head_mask, + encoder_hidden_states=hidden_states_en_convert, + encoder_attention_mask=attention_mask_en, + past_key_value=past_key_value, + output_attentions=True + ) + + cross_attention_output = cross_attention_outputs[0] + cross_attention_score = cross_attention_outputs[1] + + attention_entropy = get_pair_entropy(cross_attention_score) + + ratio = self.w * 0.3 + self.b + + attention_output = layer_module.attention.output( + ratio * cross_attention_output + (1 - ratio) * self_attention_output, + hidden_states_trg + ) + + # trg ffn + ffn_layer_outputs_trg = layer_module( + attention_output, + attention_mask_trg, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + add_attention=False + ) + hidden_states_trg = ffn_layer_outputs_trg[0] + + # src + hidden_states_en = layer_module( + hidden_states_en, + attention_mask_en, + layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions + )[0] + + hidden_states = torch.stack([hidden_states_en, hidden_states_trg], dim=1) + hidden_states = hidden_states.view(-1, *hidden_states.size()[2:]) + else: + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if attention_entropy is not None: + all_attention_entropy = all_attention_entropy + (attention_entropy,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + all_attention_entropy + ] + if v is not None + ) + +class RobertaMixModel(RobertaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config, mix_layer=7, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RobertaEmbeddings(config) + self.encoder = RobertaMixEncoder(config, mix_layer) + self.pooler = RobertaPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + lang_ids=None + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + lang_ids=lang_ids + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return (sequence_output, pooled_output) + encoder_outputs[1:] + +class XLMRobertaForQuestionAnswering(RobertaPreTrainedModel): + config_class = XLMRobertaConfig + + def __init__(self, config, args, num_lang=2): + super().__init__(config) + + self.num_labels = config.num_labels + + self.roberta = RobertaMixModel(config, mix_layer=args.mix_layer, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + self.teaching_weight = args.teaching_weight + self.align_weight = args.align_weight + self.consist_weight = args.consist_weight + self.alpha = args.alpha + self.norm = args.norm + + self.cl = args.cl + + if self.cl: + self.cl_loss = ContrastiveLoss(config, temp=args.temp) + else: + self.mse_loss = nn.MSELoss() + + if self.teaching_weight > 0: + self.attention_teacher = AttentionTeacher(config) + + if self.norm: + self.bn = nn.ModuleList([BatchNorm(config.hidden_size) for _ in range(num_lang)]) + + self.num_lang = num_lang + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + query_len=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + lang_ids=None, + return_sequence_output=False + ): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.size(0) + # Number of lang in one instance + # num_lang = input_ids.size(1) + + # Flatten input + input_ids = input_ids.view((-1, input_ids.size(-1))) # (bsz * 2, len) + attention_mask = 
attention_mask.view((-1, attention_mask.size(-1))) # (bsz * 2, len) + if lang_ids is not None: + lang_ids = lang_ids.view(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bsz * 2, len) + + + outputs = self.roberta( + input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=False, + lang_ids=lang_ids + ) + + sequence_output = outputs[0] + seq_rep = (sequence_output * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) + + if return_sequence_output: + return sequence_output + attention_entropy = outputs[-1] + + if self.norm: + sequence_output = sequence_output.view(batch_size, 2, sequence_output.size(-2), sequence_output.size(-1)) + + attention_mask_src = attention_mask.view(batch_size, 2, -1)[:, 0] + attention_mask_trg = attention_mask.view(batch_size, 2, -1)[:, 1] + + src_lang_id = int(lang_ids[0]) + trg_lang_id = int(lang_ids[1]) + sequence_output_src = self.bn[src_lang_id](sequence_output[:, 0], attention_mask_src) + sequence_output_trg = self.bn[trg_lang_id](sequence_output[:, 1], attention_mask_trg) + sequence_output = torch.stack([sequence_output_src, sequence_output_trg], dim=1) + + extended_attention_mask = attention_mask[:, None, None, :] + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + sequence_output = sequence_output.view(batch_size, 2, sequence_output.size(-2), sequence_output.size(-1)) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + if self.teaching_weight > 0: + logits_teacher = self.attention_teacher( + query=sequence_output[:, 1].detach(), + key=sequence_output[:, 0].detach(), + value=logits[:, 0].detach(), + attention_mask=extended_attention_mask.view(batch_size, 2, *extended_attention_mask.size()[1:])[:, 0] + ) + + start_logits_t, end_logits_t = logits_teacher.split(1, dim=-1) + start_logits_t = start_logits_t.squeeze(-1).contiguous() + end_logits_t = end_logits_t.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + start_positions = start_positions.view(-1) + end_positions = end_positions.view(-1) + query_len = query_len.view(-1) + + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(-1) + start_positions = start_positions.view(-1).clamp(0, ignored_index) + end_positions = end_positions.view(-1).clamp(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss_src = loss_fct(start_logits.view(batch_size, 2, -1)[:, 0], start_positions.view(-1, 2)[:, 0]) + end_loss_src = loss_fct(end_logits.view(batch_size, 2, -1)[:, 0], end_positions.view(-1, 2)[:, 0]) + + start_loss_trg = loss_fct(start_logits.view(batch_size, 2, -1)[:, 1], start_positions.view(-1, 2)[:, 1]) + end_loss_trg = loss_fct(end_logits.view(batch_size, 2, -1)[:, 1], end_positions.view(-1, 2)[:, 1]) + + + loss = self.alpha * (start_loss_src + end_loss_src) / 2 \ + + (1 - self.alpha) * (start_loss_trg + end_loss_trg) / 2 + + if self.teaching_weight > 0: + start_loss_t = 
loss_fct(start_logits_t, start_positions.view(-1, 2)[:, 1]) + end_loss_t = loss_fct(end_logits_t, end_positions.view(-1, 2)[:, 1]) + loss += self.teaching_weight * (start_loss_t + end_loss_t) / 2 + + + loss += self.align_weight * attention_entropy[0].mean() + + seq_rep = seq_rep.view(batch_size, 2, -1) + if self.cl: + loss += self.consist_weight * self.cl_loss(seq_rep[:, 0], seq_rep[:, 1]) + else: + loss += self.consist_weight * self.mse_loss(seq_rep[:, 0], seq_rep[:, 1]) + + else: + # predict + start_logits = start_logits[:, 1] + end_logits = end_logits[:, 1] + if self.teaching_weight > 0: + start_logits += start_logits_t + end_logits += end_logits_t + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + else: + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + ) diff --git a/examples/X-STA/third_party/ud-conversion-tools/conll.py b/examples/X-STA/third_party/ud-conversion-tools/conll.py new file mode 100644 index 0000000..af7a8e9 --- /dev/null +++ b/examples/X-STA/third_party/ud-conversion-tools/conll.py @@ -0,0 +1,390 @@ +import networkx as nx +from collections import Counter +import re + + +#TODO make these parse functions static methods of ConllReder +def parse_id(id_str): + if id_str == '_': + return None + if "." in id_str: + return None + ids = tuple(map(int, id_str.split("-"))) + if len(ids) == 1: + return ids[0] + else: + return ids + +def parse_feats(feats_str): + if feats_str == '_': + return {} + feat_pairs = [pair.split("=") for pair in feats_str.split("|")] + return {k: v for k, v in feat_pairs} + +def parse_deps(dep_str): + if dep_str == '_': + return [] + dep_pairs = [pair.split(":") for pair in dep_str.split("|")] + return [(int(pair[0]), pair[1]) for pair in dep_pairs if pair[0].isdigit()] + + + + +class DependencyTree(nx.DiGraph): + """ + A DependencyTree as networkx graph: + nodes store information about tokens + edges store edge related info, e.g. dependency relations + """ + + def __init__(self): + nx.DiGraph.__init__(self) + + def pathtoroot(self, child): + path = [] + newhead = self.head_of(self, child) + while newhead: + path.append(newhead) + newhead = self.head_of(self, newhead) + return path + + def head_of(self, n): + for u, v in self.edges(): + if v == n: + return u + return None + + def get_sentence_as_string(self,printid=False): + out = [] + for token_i in range(1, max(self.nodes()) + 1): + if printid: + out.append(str(token_i)+":"+self.node[token_i]['form']) + else: + out.append(self.node[token_i]['form']) + return u" ".join(out) + + def subsumes(self, head, child): + if head in self.pathtoroot(self, child): + return True + + def remove_arabic_diacritics(self): + # The following code is based on nltk.stem.isri + # It is equivalent to an interative application of isri.norm(word,num=1) + # i.e. 
we do not remove any hamza characters + + re_short_vowels = re.compile(r'[\u064B-\u0652]') + for n in self.nodes(): + self.node[n]["form"] = re_short_vowels.sub('', self.node[n]["form"]) + + + def get_highest_index_of_span(self, span): # retrieves the node index that is closest to root + #TODO: CANDIDATE FOR DEPRECATION + distancestoroot = [len(self.pathtoroot(self, x)) for x in span] + shortestdistancetoroot = min(distancestoroot) + spanhead = span[distancestoroot.index(shortestdistancetoroot)] + return spanhead + + def get_deepest_index_of_span(self, span): # retrieves the node index that is farthest from root + #TODO: CANDIDATE FOR DEPRECATION + distancestoroot = [len(self.pathtoroot(self, x)) for x in span] + longestdistancetoroot = max(distancestoroot) + lownode = span[distancestoroot.index(longestdistancetoroot)] + return lownode + + def span_makes_subtree(self, initidx, endidx): + G = nx.DiGraph() + span_nodes = list(range(initidx,endidx+1)) + span_words = [self.node[x]["form"] for x in span_nodes] + G.add_nodes_from(span_nodes) + for h,d in self.edges(): + if h in span_nodes and d in span_nodes: + G.add_edge(h,d) + return nx.is_tree(G) + + def _choose_spanhead_from_heuristics(self,span_nodes,pos_precedence_list): + distancestoroot = [len(nx.ancestors(self,x)) for x in span_nodes] + shortestdistancetoroot = min(distancestoroot) + distance_counter = Counter(distancestoroot) + + highest_nodes_in_span = [] + # Heuristic Nr 1: If there is one single highest node in the span, it becomes the head + # N.B. no need for the subspan to be a tree if there is one single highest element + if distance_counter[shortestdistancetoroot] == 1: + spanhead = span_nodes[distancestoroot.index(shortestdistancetoroot)] + return spanhead + + # Heuristic Nr 2: Choose by POS ranking the best head out of the highest nodes + for x in span_nodes: + if len(nx.ancestors(self,x)) == shortestdistancetoroot: + highest_nodes_in_span.append(x) + + best_rank = len(pos_precedence_list) + 1 + candidate_head = - 1 + span_upos = [self.node[x]["cpostag"]for x in highest_nodes_in_span] + for upos, idx in zip(span_upos,highest_nodes_in_span): + if pos_precedence_list.index(upos) < best_rank: + best_rank = pos_precedence_list.index(upos) + candidate_head = idx + return candidate_head + + def _remove_node_properties(self,fields): + for n in sorted(self.nodes()): + for fieldname in self.node[n].keys(): + if fieldname in fields: + self.node[n][fieldname]="_" + + def _remove_deprel_suffixes(self): + for h,d in self.edges(): + if ":" in self[h][d]["deprel"]: + self[h][d]["deprel"]=self[h][d]["deprel"].split(":")[0] + + def _keep_fused_form(self,posPreferenceDicts): + # For a span A,B and external tokens C, such as A > B > C, we have to + # Make A the head of the span + # Attach C-level tokens to A + #Remove B-level tokens, which are the subtokens of the fused form della: de la + + if self.graph["multi_tokens"] == {}: + return + + spanheads = [] + spanhead_fused_token_dict = {} + # This double iteration is overkill, one could skip the spanhead identification + # but in this way we avoid modifying the tree as we read it + for fusedform_idx in sorted(self.graph["multi_tokens"]): + fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"] + fuseform_span = list(range(fusedform_start,fusedform_end+1)) + spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) + #if not spanhead: + # spanhead = self._choose_spanhead_from_heuristics(fuseform_span,posPreferenceDicts) + 
spanheads.append(spanhead) + spanhead_fused_token_dict[spanhead] = fusedform_idx + + # try: + # order = list(nx.topological_sort(self)) + # except nx.NetworkXUnfeasible: + # msg = 'Circular dependency detected between hooks' + # problem_graph = ', '.join(f'{a} -> {b}' + # for a, b in nx.find_cycle(self)) + # print('nx.simple_cycles', list(nx.simple_cycles(self))) + # print(problem_graph) + # exit(0) + # for edge in list(nx.simple_cycles(self)): + # self.remove_edge(edge[0], edge[1]) + self = remove_all_cycle(self) + bottom_up_order = [x for x in nx.topological_sort(self) if x in spanheads] + for spanhead in bottom_up_order: + fusedform_idx = spanhead_fused_token_dict[spanhead] + fusedform = self.graph["multi_tokens"][fusedform_idx]["form"] + fusedform_start, fusedform_end = self.graph["multi_tokens"][fusedform_idx]["id"] + fuseform_span = list(range(fusedform_start,fusedform_end+1)) + + if spanhead: + #Step 1: Replace form of head span (A) with fusedtoken form -- in this way we keep the lemma and features if any + self.node[spanhead]["form"] = fusedform + # 2- Reattach C-level (external dependents) to A + #print(fuseform_span,spanhead) + + internal_dependents = set(fuseform_span) - set([spanhead]) + external_dependents = [nx.bfs_successors(self,x) for x in internal_dependents] + for depdict in external_dependents: + for localhead in depdict: + for ext_dep in depdict[localhead]: + if ext_dep in self[localhead]: + deprel = self[localhead][ext_dep]["deprel"] + self.remove_edge(localhead,ext_dep) + self.add_edge(spanhead,ext_dep,deprel=deprel) + + #3- Remove B-level tokens + for int_dep in internal_dependents: + self.remove_edge(self.head_of(int_dep),int_dep) + self.remove_node(int_dep) + + #4 reconstruct tree at the very end + new_index_dict = {} + for new_node_index, old_node_idex in enumerate(sorted(self.nodes())): + new_index_dict[old_node_idex] = new_node_index + + T = DependencyTree() # Transfer DiGraph, to replace self + + for n in sorted(self.nodes()): + T.add_node(new_index_dict[n],self.node[n]) + + for h, d in self.edges(): + T.add_edge(new_index_dict[h],new_index_dict[d],deprel=self[h][d]["deprel"]) + #4A Quick removal of edges and nodes + self.__init__() + + #4B Rewriting the Deptree in Self + # TODO There must a more elegant way to rewrite self -- self= T for instance? + for n in sorted(T.nodes()): + self.add_node(n,T.node[n]) + + for h,d in T.edges(): + self.add_edge(h,d,T[h][d]) + + # 5. 
remove all fused forms form the multi_tokens field + self.graph["multi_tokens"] = {} + + # if not nx.is_tree(self): + # print("Not a tree after fused-form heuristics:",self.get_sentence_as_string()) + + def filter_sentence_content(self,replace_subtokens_with_fused_forms=False, lang=None, posPreferenceDict=None,node_properties_to_remove=None,remove_deprel_suffixes=False,remove_arabic_diacritics=False): + if replace_subtokens_with_fused_forms: + self._keep_fused_form(posPreferenceDict) + if remove_deprel_suffixes: + self._remove_deprel_suffixes() + if node_properties_to_remove: + self._remove_node_properties(node_properties_to_remove) + if remove_arabic_diacritics: + self.remove_arabic_diacritics() + +def remove_all_cycle(G): + GC = nx.DiGraph(G.edges()) + edges = list(nx.simple_cycles(GC)) + for edge in edges: + for i in range(len(edge)-1): + for j in range(i+1, len(edge)): + a, b = edge[i], edge[j] + if G.has_edge(a, b): + # print('remove {} - {}'.format(a, b)) + G.remove_edge(a, b) + return G + + +class CoNLLReader(object): + """ + conll input/output + """ + + "" "Static properties""" + CONLL06_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('phead', str), ('pdeprel',str)] + #CONLL06_COLUMNS = ['id', 'form', 'lemma', 'cpostag', 'postag', 'feats', 'head', 'deprel', 'phead', 'pdeprel'] + CONLL06DENSE_COLUMNS = [('id',int), ('form',str), ('lemma',str), ('cpostag',str), ('postag',str), ('feats',str), ('head',int), ('deprel',str), ('edgew',str)] + CONLL_U_COLUMNS = [('id', parse_id), ('form', str), ('lemma', str), ('cpostag', str), + ('postag', str), ('feats', str), ('head', parse_id), ('deprel', str), + ('deps', parse_deps), ('misc', str)] + #CONLL09_COLUMNS = ['id','form','lemma','plemma','cpostag','pcpostag','feats','pfeats','head','phead','deprel','pdeprel'] + + + + def __init__(self): + pass + + def read_conll_2006(self, filename): + sentences = [] + sent = DependencyTree() + for line_num, conll_line in enumerate(open(filename)): + parts = conll_line.strip().split("\t") + if len(parts) in (8, 10): + token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06_COLUMNS, parts)} + + sent.add_node(token_dict['id'], token_dict) + sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) + elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): + sentences.append(sent) + sent = DependencyTree() + else: + raise Exception("Invalid input format in line nr: ", line_num, conll_line, filename) + + return sentences + + def read_conll_2006_dense(self, filename): + sentences = [] + sent = DependencyTree() + for conll_line in open(filename): + parts = conll_line.strip().split("\t") + if len(parts) == 9: + token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL06DENSE_COLUMNS, parts)} + + sent.add_node(token_dict['id'], token_dict) + sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel']) + elif len(parts) == 0 or (len(parts)==1 and parts[0]==""): + sentences.append(sent) + sent = DependencyTree() + else: + raise Exception("Invalid input format in line: ", conll_line, filename) + + return sentences + + + + def write_conll(self, list_of_graphs, conll_path,conllformat, print_fused_forms=False,print_comments=False): + # TODO add comment writing + if conllformat == "conllu": + columns = [colname for colname, fname in self.CONLL_U_COLUMNS] + else: + columns = [colname for colname, fname in self.CONLL06_COLUMNS] + + with conll_path.open('w') as 
out:
+            for sent_i, sent in enumerate(list_of_graphs):
+                if sent_i > 0:
+                    print("", file=out)
+                if print_comments:
+                    for c in sent.graph["comment"]:
+                        print(c, file=out)
+                for token_i in range(1, max(sent.nodes()) + 1):
+                    token_dict = dict(sent.node[token_i])
+                    head_i = sent.head_of(token_i)
+                    if head_i is None:
+                        token_dict['head'] = 0
+                        token_dict['deprel'] = ''
+                    else:
+                        token_dict['head'] = head_i
+                        token_dict['deprel'] = sent[head_i][token_i]['deprel']
+                    token_dict['id'] = token_i
+                    row = [str(token_dict.get(col, '_')) for col in columns]
+                    if print_fused_forms and token_i in sent.graph["multi_tokens"]:
+                        currentmulti = sent.graph["multi_tokens"][token_i]
+                        currentmulti["id"]=str(currentmulti["id"][0])+"-"+str(currentmulti["id"][1])
+                        currentmulti["feats"]="_"
+                        currentmulti["head"]="_"
+                        rowmulti = [str(currentmulti.get(col, '_')) for col in columns]
+                        print(u"\t".join(rowmulti),file=out)
+                    print(u"\t".join(row), file=out)
+
+            # empty line afterwards
+            print(u"", file=out)
+
+
+    def read_conll_u(self,filename,keepFusedForm=False, lang=None, posPreferenceDict=None):
+        sentences = []
+        sent = DependencyTree()
+        multi_tokens = {}
+
+        for line_no, line in enumerate(open(filename).readlines()):
+            line = line.strip("\n")
+            if not line:
+                # Add extra properties to the ROOT node if it exists
+                if 0 in sent:
+                    for key in ('form', 'lemma', 'cpostag', 'postag'):
+                        sent.node[0][key] = 'ROOT'
+
+                # Handle multi-tokens
+                sent.graph['multi_tokens'] = multi_tokens
+                multi_tokens = {}
+                sentences.append(sent)
+                sent = DependencyTree()
+            elif line.startswith("#"):
+                if 'comment' not in sent.graph:
+                    sent.graph['comment'] = [line]
+                else:
+                    sent.graph['comment'].append(line)
+            else:
+                parts = line.split("\t")
+                if len(parts) != len(self.CONLL_U_COLUMNS):
+                    error_msg = 'Invalid number of columns in line {} (found {}, expected {})'.format(line_no, len(parts), len(self.CONLL_U_COLUMNS))
+                    raise Exception(error_msg)
+
+                token_dict = {key: conv_fn(val) for (key, conv_fn), val in zip(self.CONLL_U_COLUMNS, parts)}
+                if isinstance(token_dict['id'], int):
+                    sent.add_edge(token_dict['head'], token_dict['id'], deprel=token_dict['deprel'])
+                    sent.node[token_dict['id']].update({k: v for (k, v) in token_dict.items()
+                                                        if k not in ('head', 'id', 'deprel', 'deps')})
+                    for head, deprel in token_dict['deps']:
+                        sent.add_edge(head, token_dict['id'], deprel=deprel, secondary=True)
+                elif token_dict['id'] is not None:
+                    #print(token_dict['id'])
+                    first_token_id = int(token_dict['id'][0])
+                    multi_tokens[first_token_id] = token_dict
+        return sentences
diff --git a/examples/X-STA/third_party/ud-conversion-tools/conllu_to_conll.py b/examples/X-STA/third_party/ud-conversion-tools/conllu_to_conll.py
new file mode 100644
index 0000000..689e37d
--- /dev/null
+++ b/examples/X-STA/third_party/ud-conversion-tools/conllu_to_conll.py
@@ -0,0 +1,53 @@
+from collections import defaultdict
+from itertools import islice
+from pathlib import Path
+import argparse
+import sys, copy
+
+from conll import CoNLLReader
+
+def main():
+    parser = argparse.ArgumentParser(description="""Convert CoNLL-U to CoNLL format""")
+    parser.add_argument('input', help="conllu file")
+    parser.add_argument('output', help="target file", type=Path)
+    parser.add_argument('--replace_subtokens_with_fused_forms', help="By default removes fused tokens", default=False, action="store_true")
+    parser.add_argument('--remove_deprel_suffixes', help="Restrict deprels to the common universal subset, e.g.
nmod:tmod becomes nmod", default=False, action="store_true") + parser.add_argument('--remove_node_properties', help="space-separated list of node properties to remove: form, lemma, cpostag, postag, feats", choices=['form', 'lemma', 'cpostag','postag','feats'], metavar='prop', type=str, nargs='+') + parser.add_argument('--lang', help="specify a language 2-letter code", default="default") + parser.add_argument('--output_format', choices=['conll2006', 'conll2009', 'conllu'], default="conll2006") + parser.add_argument('--remove_arabic_diacritics', help="remove Arabic short vowels", default=False, action="store_true") + parser.add_argument('--print_comments',default=False,action="store_true") + parser.add_argument('--print_fused_forms',default=False,action="store_true") + + args = parser.parse_args() + + if sys.version_info < (3,0): + print("Sorry, requires Python 3.x.") #suggestion: install anaconda python + sys.exit(1) + + POSRANKPRECEDENCEDICT = defaultdict(list) + POSRANKPRECEDENCEDICT["default"] = "VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART CCONJ SCONJ X PUNCT ".split(" ") + # POSRANKPRECEDENCEDICT["de"] = "PROPN ADP DET ".split(" ") + POSRANKPRECEDENCEDICT["es"] = "VERB AUX PRON ADP DET".split(" ") + POSRANKPRECEDENCEDICT["fr"] = "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" ") + POSRANKPRECEDENCEDICT["it"] = "VERB AUX ADV PRON ADP DET INTJ".split(" ") + + if args.lang in POSRANKPRECEDENCEDICT: + current_pos_precedence_list = POSRANKPRECEDENCEDICT[args.lang] + else: + current_pos_precedence_list = POSRANKPRECEDENCEDICT["default"] + + cio = CoNLLReader() + orig_treebank = cio.read_conll_u(args.input)#, args.keep_fused_forms, args.lang, POSRANKPRECEDENCEDICT) + modif_treebank = copy.copy(orig_treebank) + + # As per Dec 2015 the args.lang variable is redundant once you have current_pos_precedence_list + # We keep it for future modifications, i.e. any language-specific modules + for s in modif_treebank: + # print('sentence', s.get_sentence_as_string(printid=True)) + s.filter_sentence_content(args.replace_subtokens_with_fused_forms, args.lang, current_pos_precedence_list,args.remove_node_properties,args.remove_deprel_suffixes,args.remove_arabic_diacritics) + + cio.write_conll(modif_treebank,args.output, args.output_format,print_fused_forms=args.print_fused_forms, print_comments=args.print_comments) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/X-STA/third_party/utils_preprocess.py b/examples/X-STA/third_party/utils_preprocess.py new file mode 100644 index 0000000..3da0a5d --- /dev/null +++ b/examples/X-STA/third_party/utils_preprocess.py @@ -0,0 +1,541 @@ +# coding=utf-8 +# Copyright 2020 Google and DeepMind. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
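+
+# ---------------------------------------------------------------------------
+# Overview (descriptive comment): each `<task>_preprocess` function below
+# converts one raw download into the layout the training code expects, and is
+# selected via the --task flag parsed at the bottom of this file:
+#   * panx / panx_tokenize   - NER data; panx strips the "lang:" prefix from
+#     tokens and writes token<TAB>label TSVs, while panx_tokenize re-splits
+#     sentences so they fit the model's subword budget (over-long or empty
+#     tokens fall back to the tokenizer's unk_token).
+#   * udpos / udpos_tokenize - POS-tagging data; udpos reads the CoNLL files,
+#     udpos_tokenize applies the same subword-budget split as panx_tokenize.
+#   * pawsx / xnli           - sentence-pair TSVs; XNLI's "contradictory"
+#     label is normalized to "contradiction".
+#   * tatoeba                - copies each non-English file and writes the
+#     matching English side as <lg>-en.* pairs.
+#   * tydiqa                 - the gold-passage training file is split into
+#     per-language SQuAD-style JSON files and the dev files are renamed;
+#     xquad and mlqa are currently no-ops here.
+# ---------------------------------------------------------------------------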
+from __future__ import absolute_import, division, print_function + +import argparse +from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer +import os +from collections import defaultdict +import csv +import random +import os +import shutil +import json + + +TOKENIZERS = { + 'bert': BertTokenizer, + 'xlm': XLMTokenizer, + 'xlmr': XLMRobertaTokenizer, +} + +def panx_tokenize_preprocess(args): + def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): + if not os.path.exists(infile): + print(f'{infile} not exists') + return 0 + special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 + max_seq_len = max_len - special_tokens_count + subword_len_counter = idx = 0 + with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx: + for line in fin: + line = line.strip() + if not line: + fout.write('\n') + fidx.write('\n') + idx += 1 + subword_len_counter = 0 + continue + + items = line.split() + token = items[0].strip() + if len(items) == 2: + label = items[1].strip() + else: + label = 'O' + current_subwords_len = len(tokenizer.tokenize(token)) + + if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: + token = tokenizer.unk_token + current_subwords_len = 1 + + if (subword_len_counter + current_subwords_len) > max_seq_len: + fout.write(f"\n{token}\t{label}\n") + fidx.write(f"\n{idx}\n") + subword_len_counter = current_subwords_len + else: + fout.write(f"{token}\t{label}\n") + fidx.write(f"{idx}\n") + subword_len_counter += current_subwords_len + return 1 + + model_type = args.model_type + tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None) + for lang in args.languages.split(','): + out_dir = os.path.join(args.output_dir, lang) + if not os.path.exists(out_dir): + os.makedirs(out_dir) + if lang == 'en': + files = ['dev', 'test', 'train'] + else: + files = ['dev', 'test'] + for file in files: + infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv') + outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) + idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) + if os.path.exists(outfile) and os.path.exists(idxfile): + print(f'{outfile} and {idxfile} exist') + else: + code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) + if code > 0: + print(f'finish preprocessing {outfile}') + + +def panx_preprocess(args): + def _process_one_file(infile, outfile): + lines = open(infile, 'r').readlines() + if lines[-1].strip() == '': + lines = lines[:-1] + with open(outfile, 'w') as fout: + for l in lines: + items = l.strip().split('\t') + if len(items) == 2: + label = items[1].strip() + idx = items[0].find(':') + if idx != -1: + token = items[0][idx+1:].strip() + # if 'test' in infile: + # fout.write(f'{token}\n') + # else: + fout.write(f'{token}\t{label}\n') + else: + fout.write('\n') + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ') + for lg in langs: + for split in ['train', 'test', 'dev']: + infile = os.path.join(args.data_dir, f'{lg}-{split}') + outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv') + _process_one_file(infile, outfile) + +def udpos_tokenize_preprocess(args): + def _preprocess_one_file(infile, outfile, idxfile, 
tokenizer, max_len): + if not os.path.exists(infile): + print(f'{infile} does not exist') + return + subword_len_counter = idx = 0 + special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 + max_seq_len = max_len - special_tokens_count + with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx: + for line in fin: + line = line.strip() + if len(line) == 0 or line == '': + fout.write('\n') + fidx.write('\n') + idx += 1 + subword_len_counter = 0 + continue + + items = line.split() + if len(items) == 2: + label = items[1].strip() + else: + label = "X" + token = items[0].strip() + current_subwords_len = len(tokenizer.tokenize(token)) + + if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: + token = tokenizer.unk_token + current_subwords_len = 1 + + if (subword_len_counter + current_subwords_len) > max_seq_len: + fout.write(f"\n{token}\t{label}\n") + fidx.write(f"\n{idx}\n") + subword_len_counter = current_subwords_len + else: + fout.write(f"{token}\t{label}\n") + fidx.write(f"{idx}\n") + subword_len_counter += current_subwords_len + + model_type = args.model_type + tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None) + for lang in args.languages.split(','): + out_dir = os.path.join(args.output_dir, lang) + if not os.path.exists(out_dir): + os.makedirs(out_dir) + if lang == 'en': + files = ['dev', 'test', 'train'] + else: + files = ['dev', 'test'] + for file in files: + infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang)) + outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) + idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) + if os.path.exists(outfile) and os.path.exists(idxfile): + print(f'{outfile} and {idxfile} exist') + else: + _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) + print(f'finish preprocessing {outfile}') + +def udpos_preprocess(args): + def _read_one_file(file): + data = [] + sent, tag, lines = [], [], [] + for line in open(file, 'r'): + items = line.strip().split('\t') + if len(items) != 10: + empty = all(w == '_' for w in sent) + num_empty = sum([int(w == '_') for w in sent]) + if num_empty == 0 or num_empty < len(sent) - 1: + data.append((sent, tag, lines)) + sent, tag, lines = [], [], [] + else: + sent.append(items[1].strip()) + tag.append(items[3].strip()) + lines.append(line.strip()) + assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag) + return data + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def remove_empty_space(data): + new_data = {} + for split in data: + new_data[split] = [] + for sent, tag, lines in data[split]: + new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent] + lines = [line.replace('\u200c', '') for line in lines] + assert len(" ".join(new_sent).split(' ')) == len(tag) + new_data[split].append((new_sent, tag, lines)) + return new_data + + def check_file(file): + for i, l in enumerate(open(file)): + items = l.strip().split('\t') + assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l) + + def _write_files(data, output_dir, lang, suffix): + for split in data: + if len(data[split]) > 0: + prefix = os.path.join(output_dir, f'{split}-{lang}') + if suffix == 'mt': + with open(prefix + '.mt.tsv', 'w') as fout: + for idx, 
(sent, tag, _) in enumerate(data[split]): + newline = '\n' if idx != len(data[split]) - 1 else '' + # if split == 'test': + # fout.write('{}{}'.format(' '.join(sent, newline))) + # else: + fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline)) + check_file(prefix + '.mt.tsv') + print(' - finish checking ' + prefix + '.mt.tsv') + elif suffix == 'tsv': + with open(prefix + '.tsv', 'w') as fout: + for sidx, (sent, tag, _) in enumerate(data[split]): + for widx, (w, t) in enumerate(zip(sent, tag)): + newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n' + # if split == 'test': + # fout.write('{}{}'.format(w, newline)) + # else: + fout.write('{}\t{}{}'.format(w, t, newline)) + fout.write('\n') + elif suffix == 'conll': + with open(prefix + '.conll', 'w') as fout: + for _, _, lines in data[split]: + for l in lines: + fout.write(l.strip() + '\n') + fout.write('\n') + print(f'finish writing file to {prefix}.{suffix}') + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ') + for root, dirs, files in os.walk(args.data_dir): + lg = root.strip().split('/')[-1] + if root == args.data_dir or lg not in languages: + continue + + data = {k: [] for k in ['train', 'dev', 'test']} + for f in sorted(files): + if f.endswith('conll'): + file = os.path.join(root, f) + examples = _read_one_file(file) + if 'train' in f: + data['train'].extend(examples) + elif 'dev' in f: + data['dev'].extend(examples) + elif 'test' in f: + data['test'].extend(examples) + else: + print('split not found: ', file) + print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k,v in data.items()])) + + data = remove_empty_space(data) + for sub in ['tsv']: + _write_files(data, args.output_dir, lg, sub) + +def pawsx_preprocess(args): + def _preprocess_one_file(infile, outfile, remove_label=False): + data = [] + for i, line in enumerate(open(infile, 'r')): + if i == 0: + continue + items = line.strip().split('\t') + sent1 = ' '.join(items[1].strip().split(' ')) + sent2 = ' '.join(items[2].strip().split(' ')) + label = items[3] + data.append([sent1, sent2, label]) + + with open(outfile, 'w') as fout: + writer = csv.writer(fout, delimiter='\t') + for sent1, sent2, label in data: + if remove_label: + writer.writerow([sent1, sent2]) + else: + writer.writerow([sent1, sent2, label]) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + split2file = {'train': 'train', 'test': 'test_2k', 'dev': 'dev_2k'} + for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']: + for split in ['train', 'test', 'dev']: + if split == 'train' and lang != 'en': + continue + file = split2file[split] + infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file)) + outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang)) + # _preprocess_one_file(infile, outfile, remove_label=(split == 'test')) + _preprocess_one_file(infile, outfile) + print(f'finish preprocessing {outfile}') + +def xnli_preprocess(args): + def _preprocess_file(infile, output_dir, split): + all_langs = defaultdict(list) + for i, line in enumerate(open(infile, 'r')): + if i == 0: + continue + + items = line.strip().split('\t') + lang = items[0].strip() + label = "contradiction" if items[1].strip() == "contradictory" else items[1].strip() + sent1 = ' '.join(items[6].strip().split(' ')) + sent2 = ' '.join(items[7].strip().split(' ')) + all_langs[lang].append((sent1, 
sent2, label)) + print(f'# langs={len(all_langs)}') + for lang, pairs in all_langs.items(): + outfile = os.path.join(output_dir, '{}-{}.tsv'.format(split, lang)) + with open(outfile, 'w') as fout: + writer = csv.writer(fout, delimiter='\t') + for (sent1, sent2, label) in pairs: + # if split == 'test': + # writer.writerow([sent1, sent2]) + # else: + writer.writerow([sent1, sent2, label]) + print(f'finish preprocess {outfile}') + + def _preprocess_train_file(infile, outfile): + with open(outfile, 'w') as fout: + writer = csv.writer(fout, delimiter='\t') + for i, line in enumerate(open(infile, 'r')): + if i == 0: + continue + + items = line.strip().split('\t') + sent1 = ' '.join(items[0].strip().split(' ')) + sent2 = ' '.join(items[1].strip().split(' ')) + label = "contradiction" if items[2].strip() == "contradictory" else items[2].strip() + writer.writerow([sent1, sent2, label]) + print(f'finish preprocess {outfile}') + + for lg in "ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh".split(","): + infile = os.path.join(args.data_dir, f'XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv') + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + outfile = os.path.join(args.output_dir, f'train-{lg}.tsv') + _preprocess_train_file(infile, outfile) + + for split in ['test', 'dev']: + infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split)) + print(f'reading file {infile}') + _preprocess_file(infile, args.output_dir, split) + + +def tatoeba_preprocess(args): + lang3_dict = { + 'afr':'af', 'ara':'ar', 'bul':'bg', 'ben':'bn', + 'deu':'de', 'ell':'el', 'spa':'es', 'est':'et', + 'eus':'eu', 'pes':'fa', 'fin':'fi', 'fra':'fr', + 'heb':'he', 'hin':'hi', 'hun':'hu', 'ind':'id', + 'ita':'it', 'jpn':'ja', 'jav':'jv', 'kat':'ka', + 'kaz':'kk', 'kor':'ko', 'mal':'ml', 'mar':'mr', + 'nld':'nl', 'por':'pt', 'rus':'ru', 'swh':'sw', + 'tam':'ta', 'tel':'te', 'tha':'th', 'tgl':'tl', + 'tur':'tr', 'urd':'ur', 'vie':'vi', 'cmn':'zh', + 'eng':'en', + } + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + for sl3, sl2 in lang3_dict.items(): + if sl3 != 'eng': + src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}' + tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng' + src_out = f'{args.output_dir}/{sl2}-en.{sl2}' + tgt_out = f'{args.output_dir}/{sl2}-en.en' + shutil.copy(src_file, src_out) + tgts = [l.strip() for l in open(tgt_file)] + idx = range(len(tgts)) + data = zip(tgts, idx) + with open(tgt_out, 'w') as ftgt: + for t, i in sorted(data, key=lambda x: x[0]): + ftgt.write(f'{t}\n') + + +def xquad_preprocess(args): + pass + # Remove the test annotations to prevent accidental cheating + # remove_qa_test_annotations(args.data_dir) + + +def mlqa_preprocess(args): + pass + # Remove the test annotations to prevent accidental cheating + # remove_qa_test_annotations(args.data_dir) + + +def tydiqa_preprocess(args): + LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi', + 'indonesian': 'id', 'korean': 'ko', 'russian': 'ru', + 'swahili': 'sw', 'telugu': 'te'} + assert os.path.exists(args.data_dir) + train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json') + os.makedirs(args.output_dir, exist_ok=True) + + # Split the training file into language-specific files + lang2data = defaultdict(list) + with open(train_file, 'r') as f_in: + data = json.load(f_in) + version = data['version'] + for doc in data['data']: + for par in doc['paragraphs']: + context = par['context'] + for qa in par['qas']: + question = qa['question'] + question_id = 
qa['id'] + example_lang = question_id.split('-')[0] + q_id = question_id.split('-')[-1] + for answer in qa['answers']: + a_start, a_text = answer['answer_start'], answer['text'] + a_end = a_start + len(a_text) + assert context[a_start:a_end] == a_text + lang2data[example_lang].append({'paragraphs': [{ + 'context': context, + 'qas': [{'answers': qa['answers'], + 'question': question, + 'id': q_id}]}]}) + + for lang, data in lang2data.items(): + out_file = os.path.join( + args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang]) + with open(out_file, 'w') as f: + json.dump({'data': data, 'version': version}, f) + + # Rename the dev files + dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev') + assert os.path.exists(dev_dir) + for lang, iso in LANG2ISO.items(): + src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang) + dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso) + os.rename(src_file, dst_file) + + # Remove the test annotations to prevent accidental cheating + # remove_qa_test_annotations(dev_dir) + + +def remove_qa_test_annotations(test_dir): + assert os.path.exists(test_dir) + for file_name in os.listdir(test_dir): + new_data = [] + test_file = os.path.join(test_dir, file_name) + with open(test_file, 'r') as f: + data = json.load(f) + version = data['version'] + for doc in data['data']: + for par in doc['paragraphs']: + context = par['context'] + for qa in par['qas']: + question = qa['question'] + question_id = qa['id'] + for answer in qa['answers']: + a_start, a_text = answer['answer_start'], answer['text'] + a_end = a_start + len(a_text) + assert context[a_start:a_end] == a_text + new_data.append({'paragraphs': [{ + 'context': context, + 'qas': [{'answers': [{'answer_start': 0, 'text': ''}], + 'question': question, + 'id': question_id}]}]}) + with open(test_file, 'w') as f: + json.dump({'data': new_data, 'version': version}, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + ## Required parameters + parser.add_argument("--data_dir", default=None, type=str, required=True, + help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") + parser.add_argument("--output_dir", default=None, type=str, required=True, + help="The output data dir where any processed files will be written to.") + parser.add_argument("--task", default="panx", type=str, required=True, + help="The task name") + parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str, + help="The pre-trained model") + parser.add_argument("--model_type", default="bert", type=str, + help="model type") + parser.add_argument("--max_len", default=512, type=int, + help="the maximum length of sentences") + parser.add_argument("--do_lower_case", action='store_true', + help="whether to do lower case") + parser.add_argument("--cache_dir", default=None, type=str, + help="cache directory") + parser.add_argument("--languages", default="en", type=str, + help="process language") + parser.add_argument("--remove_last_token", action='store_true', + help="whether to remove the last token") + parser.add_argument("--remove_test_label", action='store_true', + help="whether to remove test set label") + args = parser.parse_args() + + if args.task == 'panx_tokenize': + panx_tokenize_preprocess(args) + if args.task == 'panx': + panx_preprocess(args) + if args.task == 'udpos_tokenize': + udpos_tokenize_preprocess(args) + if args.task == 'udpos': + udpos_preprocess(args) + if args.task == 'pawsx': + pawsx_preprocess(args) + if args.task == 'xnli': + xnli_preprocess(args) + if args.task == 'tatoeba': + tatoeba_preprocess(args) + if args.task == 'xquad': + xquad_preprocess(args) + if args.task == 'mlqa': + mlqa_preprocess(args) + if args.task == 'tydiqa': + tydiqa_preprocess(args)
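+
+# Example invocations (illustrative only; the paths below are placeholders and
+# every flag shown is defined in the argument parser above):
+#
+#   # 1) Convert the raw PAN-X downloads into per-language train/dev/test TSVs
+#   python utils_preprocess.py --task panx \
+#       --data_dir <raw_panx_dir> --output_dir <panx_tsv_dir>
+#
+#   # 2) Re-tokenize those TSVs so sentences fit the encoder's subword budget
+#   python utils_preprocess.py --task panx_tokenize \
+#       --data_dir <panx_tsv_dir> --output_dir <panx_processed_dir> \
+#       --model_type bert --model_name_or_path bert-base-multilingual-cased \
+#       --max_len 128 --languages en,de,fr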