Commit
Sean Naren authored on Jun 22, 2022 · 1 parent e64b9f6 · commit 9026a1e
Showing 13 changed files with 265 additions and 5 deletions.
@@ -0,0 +1,47 @@
.. _hf_save:

HuggingFace Hub Checkpoints
===========================

By default, Lightning Transformers saves PyTorch-based checkpoints.

HuggingFace Transformers provides a separate API for saving checkpoints. Below we describe two ways to save HuggingFace checkpoints: manually, or automatically during training.

To manually save a checkpoint from your model:

.. code-block:: python

    from lightning_transformers.task.nlp.text_classification import TextClassificationTransformer

    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")

    # saves a HF checkpoint to this path.
    model.save_hf_checkpoint("checkpoint")
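The saved directory can then be loaded back with the standard HuggingFace ``from_pretrained`` API. A minimal sketch, assuming ``save_hf_checkpoint`` wrote the usual HF ``config.json`` and weight files to ``checkpoint``:

.. code-block:: python

    from transformers import AutoModelForSequenceClassification

    # reload the HF-format checkpoint saved above
    model = AutoModelForSequenceClassification.from_pretrained("checkpoint")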
To save an additional HF checkpoint every time the checkpoint callback saves, pass in the ``HFSaveCheckpoint`` plugin:

.. code-block:: python

    import pytorch_lightning as pl
    from transformers import AutoTokenizer

    from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint
    from lightning_transformers.task.nlp.text_classification import (
        TextClassificationDataConfig,
        TextClassificationDataModule,
        TextClassificationTransformer,
    )

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
    dm = TextClassificationDataModule(
        cfg=TextClassificationDataConfig(
            batch_size=1,
            dataset_name="glue",
            dataset_config_name="sst2",
            max_length=512,
        ),
        tokenizer=tokenizer,
    )
    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")

    trainer = pl.Trainer(plugins=HFSaveCheckpoint(model=model))
    trainer.fit(model, dm)
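With the plugin enabled, each Lightning checkpoint gains a sibling HuggingFace checkpoint: for example (illustrative filename), saving ``epoch=0-step=100.ckpt`` also writes an ``epoch=0-step=100_huggingface`` directory next to it.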
@@ -0,0 +1,49 @@
.. _large_model:

Inference for Big Transformers
==============================

Lightning Transformers provides out-of-the-box support for running inference with very large, billion-parameter models. Under the hood we use HF Accelerate's big model support to auto-select devices for optimal throughput and memory usage.

This allows the model to be split across GPUs and CPUs, and even offloaded to disk, to reduce memory usage.

Below is an example of how you can run generation with a large 6B parameter transformer model using Lightning Transformers.

First, install HF Accelerate:

.. code-block:: bash

    pip install accelerate

Download the sharded checkpoint weights that we'll be using:

.. code-block:: bash

    git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
    cd sharded-gpt-j-6B
    git-lfs install
    git pull

Then load the sharded checkpoint and run generation:

.. code-block:: python

    import torch
    from accelerate import init_empty_weights
    from transformers import AutoTokenizer

    from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

    # initializes an empty model for us to load the checkpoint into.
    with init_empty_weights():
        model = LanguageModelingTransformer(
            pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
            tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
            load_weights=False,
        )

    # automatically selects the best devices (cpu/gpu) to load model layers based on available memory
    model.load_checkpoint_and_dispatch("sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"])

    output = model.generate("Hello, my name is", device=torch.device("cuda"))
    print(model.tokenizer.decode(output[0].tolist()))

For more details about the API, see `here <https://huggingface.co/docs/accelerate/big_modeling>`__.
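For readers who want direct control, below is a minimal sketch of the equivalent calls using HF Accelerate and ``transformers`` directly. The ``max_memory`` and ``offload_folder`` values are illustrative assumptions, not requirements:

.. code-block:: python

    import torch
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    from transformers import AutoConfig, AutoModelForCausalLM

    # build the model skeleton on the meta device (no weights allocated yet)
    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # place checkpoint shards across devices, capping per-device memory and
    # spilling any remainder to disk
    model = load_checkpoint_and_dispatch(
        model,
        "sharded-gpt-j-6B",
        device_map="auto",
        no_split_module_classes=["GPTJBlock"],
        max_memory={0: "12GiB", "cpu": "30GiB"},  # illustrative memory caps
        offload_folder="offload",  # illustrative disk-offload directory
    )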
@@ -0,0 +1,28 @@
"""Example of using a large language model to generate text. Accelerate will automatically place weights on
appropriate devices to fit into CPU and GPU memory.

Please download the sharded weights manually before running:

    git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
    cd sharded-gpt-j-6B
    git-lfs install
    git pull
"""
import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

# initialize the model on the meta device so no real weights are allocated yet.
with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        load_weights=False,
    )

# automatically place checkpoint shards on the best available devices (GPU/CPU/disk).
model.load_checkpoint_and_dispatch("sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"])

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))
@@ -0,0 +1 @@
from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint  # noqa: F401
@@ -0,0 +1,21 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.utilities.types import _PATH

from lightning_transformers.core import TaskTransformer


class HFSaveCheckpoint(TorchCheckpointIO):
    """Allows you to save an additional HuggingFace Hub compatible checkpoint."""

    def __init__(self, model: TaskTransformer, suffix: Union[str, Path] = "_huggingface"):
        self._model = model
        self._suffix = suffix

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        # save the regular Lightning checkpoint first.
        super().save_checkpoint(checkpoint, path, storage_options)
        # then write a HF-format checkpoint next to it: "name.ckpt" -> "name<suffix>".
        # str() guards against a Path suffix, which cannot be concatenated to a str.
        base_path = os.path.splitext(path)[0] + str(self._suffix)
        self._model.save_hf_checkpoint(base_path)
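# Note: the suffix is configurable, e.g. (illustrative values)
#   HFSaveCheckpoint(model=model, suffix="_hf")  # "name.ckpt" -> "name_hf"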
@@ -1,3 +1,4 @@
# extensions
lightning-bolts>=0.5.0
deepspeed
accelerate>=0.8.0