Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Fix LightningCLI compatibility #288

Merged
merged 2 commits into from
Sep 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions lightning_transformers/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.utilities import rank_zero_warn
from transformers import AutoConfig, PreTrainedTokenizerBase
from transformers import AutoConfig, Pipeline, PreTrainedTokenizerBase
from transformers import pipeline as hf_transformers_pipeline
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.utilities.deepspeed import enable_transformers_pretrained_deepspeed_sharding
from lightning_transformers.utilities.imports import _ACCELERATE_AVAILABLE

if _ACCELERATE_AVAILABLE:
from accelerate import load_checkpoint_and_dispatch

if TYPE_CHECKING:
from transformers import AutoModel, Pipeline


class TaskTransformer(pl.LightningModule):
"""Base class for task specific transformers, wrapping pre-trained language models for downstream tasks. The
Expand All @@ -47,7 +45,7 @@ class TaskTransformer(pl.LightningModule):

def __init__(
self,
downstream_model_type: Type["AutoModel"],
downstream_model_type: Type[_BaseAutoModelClass],
pretrained_model_name_or_path: Optional[str] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
pipeline_kwargs: Optional[dict] = None,
Expand Down Expand Up @@ -127,7 +125,7 @@ def configure_metrics(self, stage: str) -> Optional[Any]:
pass

@property
def tokenizer(self) -> Optional["PreTrainedTokenizerBase"]:
def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
if (
self._tokenizer is None
and hasattr(self, "trainer") # noqa: W503
Expand All @@ -138,7 +136,7 @@ def tokenizer(self) -> Optional["PreTrainedTokenizerBase"]:
return self._tokenizer

@tokenizer.setter
def tokenizer(self, tokenizer: "PreTrainedTokenizerBase") -> None:
def tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
self._tokenizer = tokenizer

@property
Expand All @@ -150,7 +148,7 @@ def hf_pipeline_task(self) -> Optional[str]:
return None

@property
def hf_pipeline(self) -> "Pipeline":
def hf_pipeline(self) -> Pipeline:
if self._hf_pipeline is None:
if self.hf_pipeline_task is not None:
self._hf_pipeline = hf_transformers_pipeline(
Expand All @@ -161,7 +159,7 @@ def hf_pipeline(self) -> "Pipeline":
return self._hf_pipeline

@hf_pipeline.setter
def hf_pipeline(self, pipeline: "Pipeline") -> None:
def hf_pipeline(self, pipeline: Pipeline) -> None:
self._hf_pipeline = pipeline

def hf_predict(self, *args, **kwargs) -> Any:
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/language_modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Type
from typing import Any, Type

import torch
import transformers
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class LanguageModelingTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Language Modeling Task.
Expand All @@ -33,7 +31,7 @@ class LanguageModelingTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForCausalLM, **kwargs
self, *args, downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForCausalLM, **kwargs
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class MaskedLanguageModelingTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Masked Language Modeling Task.
Expand All @@ -31,7 +29,7 @@ class MaskedLanguageModelingTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForMaskedLM, **kwargs
self, *args, downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForMaskedLM, **kwargs
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
11 changes: 6 additions & 5 deletions lightning_transformers/task/nlp/multiple_choice/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class MultipleChoiceTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Multiple Choice Task.
Expand All @@ -34,7 +32,10 @@ class MultipleChoiceTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForMultipleChoice, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForMultipleChoice,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
11 changes: 6 additions & 5 deletions lightning_transformers/task/nlp/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Any, Type
from typing import Any, Type

import torch
import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer
from lightning_transformers.task.nlp.question_answering import QuestionAnsweringDataModule
from lightning_transformers.task.nlp.question_answering.datasets.squad.metric import SquadMetric

if TYPE_CHECKING:
from transformers import AutoModel


class QuestionAnsweringTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Question Answering Task.
Expand All @@ -36,7 +34,10 @@ class QuestionAnsweringTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForQuestionAnswering, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForQuestionAnswering,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from torchmetrics.text.rouge import ROUGEScore
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core.seq2seq.model import Seq2SeqTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class SummarizationTransformer(Seq2SeqTransformer):
"""Defines ``LightningModule`` for the Summarization Task.
Expand All @@ -36,7 +34,7 @@ class SummarizationTransformer(Seq2SeqTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSeq2SeqLM,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSeq2SeqLM,
use_stemmer: bool = True,
**kwargs
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Type
from typing import Any, Dict, Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class TextClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Text Classification Task.
Expand All @@ -36,7 +34,7 @@ class TextClassificationTransformer(TaskTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSequenceClassification,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSequenceClassification,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Type, Union
from typing import Any, Dict, List, Type, Union

import torch
import transformers
from torchmetrics import Accuracy, F1Score, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class TokenClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Token Classification Task.
Expand All @@ -37,7 +35,7 @@ def __init__(
self,
*args,
labels: Union[int, List[str]],
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForTokenClassification,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForTokenClassification,
**kwargs,
) -> None:
num_labels = labels if isinstance(labels, int) else len(labels)
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from torchmetrics.text.bleu import BLEUScore
from transformers import MBartTokenizer
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core.seq2seq.model import Seq2SeqTransformer
from lightning_transformers.task.nlp.translation import TranslationDataModule

if TYPE_CHECKING:
from transformers import AutoModel


class TranslationTransformer(Seq2SeqTransformer):
"""Defines ``LightningModule`` for the Translation Task.
Expand All @@ -39,7 +37,7 @@ class TranslationTransformer(Seq2SeqTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSeq2SeqLM,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSeq2SeqLM,
n_gram: int = 4,
smooth: bool = False,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Type
from typing import Any, Dict, Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel, Pipeline


class ImageClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Image Classification Task.
Expand All @@ -34,7 +32,10 @@ class ImageClassificationTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForImageClassification, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForImageClassification,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)
self.metrics = {}
Expand Down Expand Up @@ -83,6 +84,6 @@ def hf_pipeline_task(self) -> str:
return "image-classification"

@property
def hf_pipeline(self) -> "Pipeline":
def hf_pipeline(self) -> transformers.Pipeline:
self._hf_pipeline_kwargs["feature_extractor"] = self.tokenizer
return super().hf_pipeline
32 changes: 32 additions & 0 deletions tests/core/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from importlib.util import find_spec
from unittest import mock

import pytest
from pytorch_lightning.cli import LightningCLI

from lightning_transformers.task.vision.image_classification import (
ImageClassificationDataModule,
ImageClassificationTransformer,
)


@pytest.mark.skipif(find_spec("jsonargparse") is None, reason="jsonargparse is required")
def test_lightning_cli_image_classification():
    """Check that ``LightningCLI`` can build the image classification task from a config.

    The CLI is given only ``dataset_name`` and ``pretrained_model_name_or_path``; the test then
    verifies that both values end up in the parsed config and that the instantiated objects are
    an ``ImageClassificationDataModule`` and an ``ImageClassificationTransformer``.
    """
    import json  # local import: only needed to serialize the config for the command line

    config = {
        "data": {
            "dataset_name": "beans",  # Resolve from TransformerDataModule.__init__
        },
        "model": {
            "pretrained_model_name_or_path": "nateraw/tiny-vit-random",  # Resolve from TaskTransformer.__init__
        },
    }
    # Pass the config as explicit JSON rather than interpolating the dict's ``repr``: the repr is
    # Python syntax that is only incidentally parseable (single quotes, ``None``), whereas JSON is
    # a well-defined serialization of the same mapping.
    argv = ["any.py", f"--config={json.dumps(config)}"]

    # ``run=False`` only parses the arguments and instantiates the classes; no subcommand is run.
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(
            ImageClassificationTransformer,
            ImageClassificationDataModule,
            run=False,
        )

    # The values given on the command line must be reflected in the parsed config ...
    assert cli.config.data.dataset_name == "beans"
    assert cli.config.model.pretrained_model_name_or_path == "nateraw/tiny-vit-random"
    # ... and the instantiated objects must be of the classes handed to the CLI.
    assert isinstance(cli.config_init.data, ImageClassificationDataModule)
    assert isinstance(cli.config_init.model, ImageClassificationTransformer)