This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Add big model support (#263)
Sean Naren authored Jun 22, 2022
1 parent e64b9f6 commit 9026a1e
Showing 13 changed files with 265 additions and 5 deletions.
4 changes: 3 additions & 1 deletion docs/source/index.rst
@@ -30,9 +30,11 @@ Lightning Transformers
.. toctree::
    :maxdepth: 1
    :name: optimization
    :caption: Transformer Optimizations
    :caption: Optimizations

    optimizations/sparseml
    optimizations/hf_save
    optimizations/large_model

.. toctree::
    :maxdepth: 1
47 changes: 47 additions & 0 deletions docs/source/optimizations/hf_save.rst
@@ -0,0 +1,47 @@
.. _hf_save:

HuggingFace Hub Checkpoints
===========================

By default, Lightning Transformers saves standard PyTorch checkpoints.

HuggingFace Transformers provides a separate API for saving checkpoints. Below we describe two ways to save HuggingFace checkpoints: manually, or automatically every time a Lightning checkpoint is saved during training.

To manually save checkpoints from your model:

.. code-block:: python

    from lightning_transformers.task.nlp.text_classification import TextClassificationTransformer

    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")

    # saves a HF checkpoint to this path.
    model.save_hf_checkpoint("checkpoint")
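The saved directory uses the standard HuggingFace layout, so it can be loaded back with the regular ``transformers`` API. A minimal sketch (assuming ``transformers`` is installed and the task uses a sequence-classification head):

.. code-block:: python

    from transformers import AutoModelForSequenceClassification

    # reload the weights and config written by ``save_hf_checkpoint`` above
    hf_model = AutoModelForSequenceClassification.from_pretrained("checkpoint")
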
To save an additional HF checkpoint every time the checkpoint callback saves, pass in the ``HFSaveCheckpoint`` plugin:

.. code-block:: python

    import pytorch_lightning as pl
    from transformers import AutoTokenizer

    from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint
    from lightning_transformers.task.nlp.text_classification import (
        TextClassificationDataConfig,
        TextClassificationDataModule,
        TextClassificationTransformer,
    )

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
    dm = TextClassificationDataModule(
        cfg=TextClassificationDataConfig(
            batch_size=1,
            dataset_name="glue",
            dataset_config_name="sst2",
            max_length=512,
        ),
        tokenizer=tokenizer,
    )
    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")

    trainer = pl.Trainer(plugins=HFSaveCheckpoint(model=model))
    trainer.fit(model, dm)
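With the plugin enabled, every time the checkpoint callback writes ``<name>.ckpt``, an additional HuggingFace-format checkpoint is saved to a ``<name>_huggingface`` directory next to it; the suffix can be changed via the plugin's ``suffix`` argument.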
49 changes: 49 additions & 0 deletions docs/source/optimizations/large_model.rst
@@ -0,0 +1,49 @@
.. _large_model:

Inference for Big Transformers
==============================

Lightning Transformers provides out-of-the-box support for running inference with very large, billion-parameter models. Under the hood, we use HF Accelerate's big model support to auto-select devices for optimal throughput and memory usage.

This allows the model to be split across GPUs and CPUs, and even offloaded to disk, to reduce memory usage.

Below is an example of how you can run generation with a large 6B-parameter transformer model using Lightning Transformers.


First, install HF Accelerate:

.. code-block:: bash

    pip install accelerate
Download the sharded checkpoint weights that we'll be using:

.. code-block:: bash

    git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
    cd sharded-gpt-j-6B
    git-lfs install
    git pull
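With the weights downloaded, initialize the model with empty weights, load the sharded checkpoint with automatic device placement, and run generation: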
.. code-block:: python

    import torch
    from accelerate import init_empty_weights
    from transformers import AutoTokenizer

    from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

    # initializes an empty model for us to load the checkpoint into
    with init_empty_weights():
        model = LanguageModelingTransformer(
            pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
            tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
            load_weights=False,
        )

    # automatically selects the best devices (cpu/gpu) to load model layers based on available memory
    model.load_checkpoint_and_dispatch("sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"])

    output = model.generate("Hello, my name is", device=torch.device("cuda"))
    print(model.tokenizer.decode(output[0].tolist()))
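After dispatching, the ``hf_device_map`` property exposes where Accelerate placed each part of the model. A minimal sketch for inspecting it (the mapping shown in the comment is illustrative):

.. code-block:: python

    # dict mapping submodule names to devices chosen by Accelerate,
    # e.g. {"transformer.h.0": 0, "transformer.h.27": "cpu", ...}
    print(model.hf_device_map)
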
For more details about the underlying API, see the `Accelerate big model documentation <https://huggingface.co/docs/accelerate/big_modeling>`__.
28 changes: 28 additions & 0 deletions examples/generate.py
@@ -0,0 +1,28 @@
"""Example to use a large language model to generate text. Accelerate will automatically place weights on
appropriate devices to fit into CPU and GPU memory.
Please download the sharded weights manually before running:
git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
cd sharded-gpt-j-6B
git-lfs install
git pull
"""

import torch
from accelerate import init_empty_weights
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.language_modeling import LanguageModelingTransformer

with init_empty_weights():
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="EleutherAI/gpt-j-6B",
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"),
        load_weights=False,
    )

model.load_checkpoint_and_dispatch("sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"])

output = model.generate("Hello, my name is", device=torch.device("cuda"))
print(model.tokenizer.decode(output[0].tolist()))
38 changes: 36 additions & 2 deletions lightning_transformers/core/model.py
@@ -11,15 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.utilities import rank_zero_warn
from transformers import PreTrainedTokenizerBase
from transformers import AutoConfig, PreTrainedTokenizerBase
from transformers import pipeline as hf_transformers_pipeline

from lightning_transformers.utilities.imports import _ACCELERATE_AVAILABLE

if _ACCELERATE_AVAILABLE:
    from accelerate import load_checkpoint_and_dispatch

if TYPE_CHECKING:
    from transformers import AutoModel, Pipeline

@@ -45,11 +51,16 @@ def __init__(
        pretrained_model_name_or_path: Optional[str] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        pipeline_kwargs: Optional[dict] = None,
        load_weights: bool = True,
        **model_data_kwargs,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = downstream_model_type.from_pretrained(pretrained_model_name_or_path, **model_data_kwargs)
        if load_weights:
            self.model = downstream_model_type.from_pretrained(pretrained_model_name_or_path, **model_data_kwargs)
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **model_data_kwargs)
            self.model = downstream_model_type.from_config(config)
        self._tokenizer = tokenizer  # necessary for hf_pipeline
        self._hf_pipeline = None
        self._hf_pipeline_kwargs = pipeline_kwargs or {}
@@ -152,3 +163,26 @@ def load_from_checkpoint(
        if hf_pipeline_kwargs is not None:
            model._hf_pipeline_kwargs.update(hf_pipeline_kwargs)
        return model

    def load_checkpoint_and_dispatch(self, *args, **kwargs) -> None:
        """Use when loading a checkpoint via Accelerate for large model support.

        Useful when loading sharded checkpoints.
        """
        self.model = load_checkpoint_and_dispatch(self.model, *args, **kwargs)

    @property
    def hf_device_map(self) -> Dict:
        """
        Returns: Device map as defined when using `load_checkpoint_and_dispatch`.
        """
        return self.model.hf_device_map

    def save_hf_checkpoint(self, path: Union[str, Path]) -> None:
        """Save the model using the original HF AutoModel.

        This is useful when you'd like to export the model to the HuggingFace Hub.

        Args:
            path: Path to save the model to.
        """
        self.model.save_pretrained(path)
1 change: 1 addition & 0 deletions lightning_transformers/plugins/__init__.py
@@ -0,0 +1 @@
from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint # noqa: F401
21 changes: 21 additions & 0 deletions lightning_transformers/plugins/checkpoint.py
@@ -0,0 +1,21 @@
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.utilities.types import _PATH

from lightning_transformers.core import TaskTransformer


class HFSaveCheckpoint(TorchCheckpointIO):
    """Allows you to save an additional HuggingFace Hub compatible checkpoint."""

    def __init__(self, model: TaskTransformer, suffix: Union[str, Path] = "_huggingface"):
        self._model = model
        self._suffix = suffix

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        super().save_checkpoint(checkpoint, path, storage_options)
        base_path = os.path.splitext(path)[0] + self._suffix
        self._model.save_hf_checkpoint(base_path)
14 changes: 13 additions & 1 deletion lightning_transformers/task/nlp/language_modeling/model.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING, Any, Type

import torch
import transformers
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from lightning_transformers.core import TaskTransformer

@@ -60,3 +62,13 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
    @property
    def hf_pipeline_task(self) -> str:
        return "text-generation"

    def generate(self, text: str, device: torch.device = torch.device("cpu")) -> Any:
        if self.tokenizer is None:
            raise MisconfigurationException(
                "A tokenizer is required to use the `generate` function. "
                "Please pass a tokenizer, e.g. `LanguageModelingTransformer(tokenizer=...)`."
            )
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = inputs.to(device)
        return self.model.generate(inputs["input_ids"])
1 change: 1 addition & 0 deletions lightning_transformers/utilities/imports.py
@@ -19,3 +19,4 @@
_BOLTS_AVAILABLE = _module_available("pl_bolts") and _compare_version("pl_bolts", operator.ge, "0.4.0")
_BOLTS_GREATER_EQUAL_0_5_0 = _module_available("pl_bolts") and _compare_version("pl_bolts", operator.ge, "0.5.0")
_WANDB_AVAILABLE = _module_available("wandb")
_ACCELERATE_AVAILABLE = _module_available("accelerate")
1 change: 1 addition & 0 deletions requirements/extra.txt
@@ -1,3 +1,4 @@
# extensions
lightning-bolts>=0.5.0
deepspeed
accelerate>=0.8.0
30 changes: 30 additions & 0 deletions tests/task/nlp/test_language_modeling.py
@@ -1,8 +1,10 @@
import os
import sys
from unittest.mock import MagicMock

import pytest
import pytorch_lightning as pl
import torch
import transformers
from transformers import AutoTokenizer

@@ -11,6 +13,10 @@
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)
from lightning_transformers.utilities.imports import _ACCELERATE_AVAILABLE

if _ACCELERATE_AVAILABLE:
    from accelerate import init_empty_weights


@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
@@ -54,3 +60,27 @@ def test_datamodule_has_correct_cfg():
    dm = LanguageModelingDataModule(tokenizer)
    assert isinstance(dm.cfg, LanguageModelingDataConfig)
    assert dm.tokenizer is tokenizer


@pytest.mark.skipif(not _ACCELERATE_AVAILABLE, reason="Accelerate not installed.")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a GPU to run.")
def test_generate_inference(tmpdir):
    model = LanguageModelingTransformer(
        pretrained_model_name_or_path="sshleifer/tiny-gpt2",
        tokenizer=AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2"),
    )
    ckpt_path = os.path.join(tmpdir, "checkpoint.ckpt")
    torch.save(model.model.state_dict(), ckpt_path)

    with init_empty_weights():
        model = LanguageModelingTransformer(
            pretrained_model_name_or_path="sshleifer/tiny-gpt2",
            tokenizer=AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2"),
            load_weights=False,
        )

    model.load_checkpoint_and_dispatch(ckpt_path, device_map="auto")

    output = model.generate("Hello, my name is", device=torch.device("cuda"))
    output = model.tokenizer.decode(output[0].tolist())
    assert "Hello, my name is" in output
33 changes: 33 additions & 0 deletions tests/task/nlp/test_text_classification.py
@@ -1,11 +1,14 @@
import os
import sys
from unittest.mock import MagicMock

import pytest
import pytorch_lightning as pl
import transformers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoTokenizer

from lightning_transformers.plugins.checkpoint import HFSaveCheckpoint
from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataConfig,
    TextClassificationDataModule,
@@ -78,3 +81,33 @@ def test_datamodule_has_correct_cfg():
    dm = TextClassificationDataModule(tokenizer)
    assert isinstance(dm.cfg, TextClassificationDataConfig)
    assert dm.tokenizer is tokenizer


@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
def test_huggingface_checkpoint_train(hf_cache_path, tmpdir):
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
    dm = TextClassificationDataModule(
        cfg=TextClassificationDataConfig(
            batch_size=1,
            dataset_name="glue",
            dataset_config_name="sst2",
            max_length=512,
            limit_test_samples=64,
            limit_val_samples=64,
            limit_train_samples=64,
            cache_dir=hf_cache_path,
        ),
        tokenizer=tokenizer,
    )
    ckpt_path = os.path.join(tmpdir, "checkpoints")
    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        plugins=HFSaveCheckpoint(model=model),
        callbacks=ModelCheckpoint(save_last=True, dirpath=ckpt_path),
    )
    trainer.fit(model, dm)
    assert os.path.exists(os.path.join(ckpt_path, "last_huggingface"))
3 changes: 2 additions & 1 deletion tests/task/vision/test_image_classification.py
@@ -14,10 +14,11 @@


@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
@pytest.mark.skipif(sys.platform == "darwin", reason="Currently darwin is not working")
def test_smoke_train(hf_cache_path):
    feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path="nateraw/tiny-vit-random")
    dm = ImageClassificationDataModule(
        cfg=ImageClassificationDataConfig(batch_size=2, dataset_name="beans"),
        cfg=ImageClassificationDataConfig(batch_size=1, dataset_name="beans"),
        feature_extractor=feature_extractor,
    )
    model = ImageClassificationTransformer(
