Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Commit

Permalink
Fix LightningCLI compatibility (#288)
Browse files Browse the repository at this point in the history
Change AutoModel type hint to fix LightningCLI compatibility.
  • Loading branch information
mauvilsa authored Sep 24, 2022
1 parent 5cb1d5f commit 461ad92
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 57 deletions.
18 changes: 8 additions & 10 deletions lightning_transformers/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
import transformers
from pytorch_lightning.utilities import rank_zero_warn
from transformers import AutoConfig, PreTrainedTokenizerBase
from transformers import AutoConfig, Pipeline, PreTrainedTokenizerBase
from transformers import pipeline as hf_transformers_pipeline
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.utilities.deepspeed import enable_transformers_pretrained_deepspeed_sharding
from lightning_transformers.utilities.imports import _ACCELERATE_AVAILABLE

if _ACCELERATE_AVAILABLE:
from accelerate import load_checkpoint_and_dispatch

if TYPE_CHECKING:
from transformers import AutoModel, Pipeline


class TaskTransformer(pl.LightningModule):
"""Base class for task specific transformers, wrapping pre-trained language models for downstream tasks. The
Expand All @@ -47,7 +45,7 @@ class TaskTransformer(pl.LightningModule):

def __init__(
self,
downstream_model_type: Type["AutoModel"],
downstream_model_type: Type[_BaseAutoModelClass],
pretrained_model_name_or_path: Optional[str] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
pipeline_kwargs: Optional[dict] = None,
Expand Down Expand Up @@ -127,7 +125,7 @@ def configure_metrics(self, stage: str) -> Optional[Any]:
pass

@property
def tokenizer(self) -> Optional["PreTrainedTokenizerBase"]:
def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
if (
self._tokenizer is None
and hasattr(self, "trainer") # noqa: W503
Expand All @@ -138,7 +136,7 @@ def tokenizer(self) -> Optional["PreTrainedTokenizerBase"]:
return self._tokenizer

@tokenizer.setter
def tokenizer(self, tokenizer: "PreTrainedTokenizerBase") -> None:
def tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
self._tokenizer = tokenizer

@property
Expand All @@ -150,7 +148,7 @@ def hf_pipeline_task(self) -> Optional[str]:
return None

@property
def hf_pipeline(self) -> "Pipeline":
def hf_pipeline(self) -> Pipeline:
if self._hf_pipeline is None:
if self.hf_pipeline_task is not None:
self._hf_pipeline = hf_transformers_pipeline(
Expand All @@ -161,7 +159,7 @@ def hf_pipeline(self) -> "Pipeline":
return self._hf_pipeline

@hf_pipeline.setter
def hf_pipeline(self, pipeline: "Pipeline") -> None:
def hf_pipeline(self, pipeline: Pipeline) -> None:
self._hf_pipeline = pipeline

def hf_predict(self, *args, **kwargs) -> Any:
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/language_modeling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Type
from typing import Any, Type

import torch
import transformers
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class LanguageModelingTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Language Modeling Task.
Expand All @@ -33,7 +31,7 @@ class LanguageModelingTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForCausalLM, **kwargs
self, *args, downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForCausalLM, **kwargs
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class MaskedLanguageModelingTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Masked Language Modeling Task.
Expand All @@ -31,7 +29,7 @@ class MaskedLanguageModelingTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForMaskedLM, **kwargs
self, *args, downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForMaskedLM, **kwargs
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
11 changes: 6 additions & 5 deletions lightning_transformers/task/nlp/multiple_choice/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class MultipleChoiceTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Multiple Choice Task.
Expand All @@ -34,7 +32,10 @@ class MultipleChoiceTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForMultipleChoice, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForMultipleChoice,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
11 changes: 6 additions & 5 deletions lightning_transformers/task/nlp/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Any, Type
from typing import Any, Type

import torch
import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer
from lightning_transformers.task.nlp.question_answering import QuestionAnsweringDataModule
from lightning_transformers.task.nlp.question_answering.datasets.squad.metric import SquadMetric

if TYPE_CHECKING:
from transformers import AutoModel


class QuestionAnsweringTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Question Answering Task.
Expand All @@ -36,7 +34,10 @@ class QuestionAnsweringTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForQuestionAnswering, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForQuestionAnswering,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)

Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from torchmetrics.text.rouge import ROUGEScore
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core.seq2seq.model import Seq2SeqTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class SummarizationTransformer(Seq2SeqTransformer):
"""Defines ``LightningModule`` for the Summarization Task.
Expand All @@ -36,7 +34,7 @@ class SummarizationTransformer(Seq2SeqTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSeq2SeqLM,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSeq2SeqLM,
use_stemmer: bool = True,
**kwargs
) -> None:
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/text_classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Type
from typing import Any, Dict, Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class TextClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Text Classification Task.
Expand All @@ -36,7 +34,7 @@ class TextClassificationTransformer(TaskTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSequenceClassification,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSequenceClassification,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/token_classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Type, Union
from typing import Any, Dict, List, Type, Union

import torch
import transformers
from torchmetrics import Accuracy, F1Score, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel


class TokenClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Text Classification Task.
Expand All @@ -37,7 +35,7 @@ def __init__(
self,
*args,
labels: Union[int, List[str]],
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForTokenClassification,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForTokenClassification,
**kwargs,
) -> None:
num_labels = labels if isinstance(labels, int) else len(labels)
Expand Down
8 changes: 3 additions & 5 deletions lightning_transformers/task/nlp/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Type
from typing import Type

import transformers
from torchmetrics.text.bleu import BLEUScore
from transformers import MBartTokenizer
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core.seq2seq.model import Seq2SeqTransformer
from lightning_transformers.task.nlp.translation import TranslationDataModule

if TYPE_CHECKING:
from transformers import AutoModel


class TranslationTransformer(Seq2SeqTransformer):
"""Defines ``LightningModule`` for the Translation Task.
Expand All @@ -39,7 +37,7 @@ class TranslationTransformer(Seq2SeqTransformer):
def __init__(
self,
*args,
downstream_model_type: Type["AutoModel"] = transformers.AutoModelForSeq2SeqLM,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForSeq2SeqLM,
n_gram: int = 4,
smooth: bool = False,
**kwargs,
Expand Down
13 changes: 7 additions & 6 deletions lightning_transformers/task/vision/image_classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Type
from typing import Any, Dict, Type

import torch
import transformers
from torchmetrics import Accuracy, Precision, Recall
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from lightning_transformers.core import TaskTransformer

if TYPE_CHECKING:
from transformers import AutoModel, Pipeline


class ImageClassificationTransformer(TaskTransformer):
"""Defines ``LightningModule`` for the Text Classification Task.
Expand All @@ -34,7 +32,10 @@ class ImageClassificationTransformer(TaskTransformer):
"""

def __init__(
self, *args, downstream_model_type: Type["AutoModel"] = transformers.AutoModelForImageClassification, **kwargs
self,
*args,
downstream_model_type: Type[_BaseAutoModelClass] = transformers.AutoModelForImageClassification,
**kwargs,
) -> None:
super().__init__(downstream_model_type, *args, **kwargs)
self.metrics = {}
Expand Down Expand Up @@ -83,6 +84,6 @@ def hf_pipeline_task(self) -> str:
return "image-classification"

@property
def hf_pipeline(self) -> "Pipeline":
def hf_pipeline(self) -> transformers.Pipeline:
self._hf_pipeline_kwargs["feature_extractor"] = self.tokenizer
return super().hf_pipeline
32 changes: 32 additions & 0 deletions tests/core/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from importlib.util import find_spec
from unittest import mock

import pytest
from pytorch_lightning.cli import LightningCLI

from lightning_transformers.task.vision.image_classification import (
ImageClassificationDataModule,
ImageClassificationTransformer,
)


@pytest.mark.skipif(find_spec("jsonargparse") is None, reason="jsonargparse is required")
def test_lightning_cli_image_classification():
config = {
"data": {
"dataset_name": "beans", # Resolve from TransformerDataModule.__init__
},
"model": {
"pretrained_model_name_or_path": "nateraw/tiny-vit-random", # Resolve from TaskTransformer.__init__
},
}
with mock.patch("sys.argv", ["any.py", f"--config={config}"]):
cli = LightningCLI(
ImageClassificationTransformer,
ImageClassificationDataModule,
run=False,
)
assert cli.config.data.dataset_name == "beans"
assert cli.config.model.pretrained_model_name_or_path == "nateraw/tiny-vit-random"
assert isinstance(cli.config_init.data, ImageClassificationDataModule)
assert isinstance(cli.config_init.model, ImageClassificationTransformer)
2 changes: 1 addition & 1 deletion tests/task/nlp/test_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_smoke_train(hf_cache_path):
dataset_name="conll2003",
preprocessing_num_workers=1,
label_all_tokens=False,
revision="master",
revision="main",
limit_test_samples=64,
limit_val_samples=64,
limit_train_samples=64,
Expand Down

0 comments on commit 461ad92

Please sign in to comment.