diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 024dc447f9..86957d139a 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -106,6 +106,11 @@ jobs: pip list shell: bash + - name: Install serve test dependencies + if: matrix.topic == 'serve' + run: | + pip install '.[all]' --pre --upgrade + - name: Cache datasets uses: actions/cache@v2 with: @@ -115,7 +120,8 @@ jobs: - name: Tests env: - FIFTYONE_DO_NOT_TRACK: true + FLASH_TEST_TOPIC: ${{ matrix.topic }} + FIFTYONE_DO_NOT_TRACK: true run: | # tox --sitepackages coverage run --source flash -m pytest flash tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml @@ -143,3 +149,8 @@ jobs: env_vars: OS,PYTHON name: codecov-umbrella fail_ci_if_error: false + + - name: Uninstall + run: | + pip uninstall lightning-flash -y + shell: bash diff --git a/CHANGELOG.md b/CHANGELOG.md index 4dae6cbb71..c5173d18c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389)) - Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390)) - Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401)) +- Added a sanity checking feature to flash.serve ([#423](https://github.com/PyTorchLightning/lightning-flash/pull/423)) ### Changed diff --git a/flash/__init__.py b/flash/__init__.py index a10576ff1f..7a13f9d20b 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -27,6 +27,7 @@ from flash.core.trainer import Trainer # noqa: E402 _PACKAGE_ROOT = os.path.dirname(__file__) + ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets") PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) _IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1" diff --git a/flash/assets/fish.jpg b/flash/assets/fish.jpg new file mode 100644 index 0000000000..76be7af0d7 Binary files /dev/null and b/flash/assets/fish.jpg differ diff --git a/flash/core/classification.py b/flash/core/classification.py index 350a6a0118..31459b14ce 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union import torch import torch.nn.functional as F @@ -24,7 +24,6 @@ from flash.core.utilities.imports import _FIFTYONE_AVAILABLE if _FIFTYONE_AVAILABLE: - import fiftyone as fo from fiftyone.core.labels import Classification, Classifications else: Classification, Classifications = None, None @@ -83,34 +82,43 @@ def multi_label(self) -> bool: return self._mutli_label -class Logits(ClassificationSerializer): +class PredsClassificationSerializer(ClassificationSerializer): + """A :class:`~flash.core.classification.ClassificationSerializer` which gets the + :attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample. 
+ """ + + def serialize(self, sample: Any) -> Any: + if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample: + sample = sample[DefaultDataKeys.PREDS] + if not isinstance(sample, torch.Tensor): + sample = torch.tensor(sample) + return sample + + +class Logits(PredsClassificationSerializer): """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list.""" def serialize(self, sample: Any) -> Any: - sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample - sample = torch.tensor(sample) - return sample.tolist() + return super().serialize(sample).tolist() -class Probabilities(ClassificationSerializer): +class Probabilities(PredsClassificationSerializer): """A :class:`.Serializer` which applies a softmax to the model outputs (assumed to be logits) and converts to a list.""" def serialize(self, sample: Any) -> Any: - sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample - sample = torch.tensor(sample) + sample = super().serialize(sample) if self.multi_label: return torch.sigmoid(sample).tolist() return torch.softmax(sample, -1).tolist() -class Classes(ClassificationSerializer): +class Classes(PredsClassificationSerializer): """A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and converts to a list. Args: multi_label: If true, treats outputs as multi label logits. - threshold: The threshold to use for multi_label classification. """ @@ -120,8 +128,7 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5): self.threshold = threshold def serialize(self, sample: Any) -> Union[int, List[int]]: - sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample - sample = torch.tensor(sample) + sample = super().serialize(sample) if self.multi_label: one_hot = (sample.sigmoid() > self.threshold).int().tolist() result = [] @@ -139,9 +146,7 @@ class Labels(Classes): Args: labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not provided, will attempt to get them from the :class:`.LabelsState`. - multi_label: If true, treats outputs as multi label logits. - threshold: The threshold to use for multi_label classification. """ @@ -153,8 +158,6 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False self.set_state(LabelsState(labels)) def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: - sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample - sample = torch.tensor(sample) labels = None if self._labels is not None: diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index ba422b0d22..f22bf2892c 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -119,6 +119,10 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> data_pipeline_state._initialized = True # TODO: Not sure we need this return data_pipeline_state + @property + def example_input(self) -> str: + return self._deserializer.example_input + @staticmethod def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool: """ diff --git a/flash/core/data/process.py b/flash/core/data/process.py index c461768a36..f1bde6e2b5 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from abc import ABC, abstractclassmethod, abstractmethod -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence import torch from pytorch_lightning.trainer.states import RunningStage @@ -569,6 +569,11 @@ class Deserializer(Properties): def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? raise NotImplementedError + @property + @abstractmethod + def example_input(self) -> str: + pass + def __call__(self, sample: Any) -> Any: return self.deserialize(sample) diff --git a/flash/core/model.py b/flash/core/model.py index a52de73dfb..f08658fe2c 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -41,8 +41,9 @@ ) from flash.core.registry import FlashRegistry from flash.core.schedulers import _SCHEDULERS_REGISTRY -from flash.core.serve import Composition, expose, ModelComponent +from flash.core.serve import Composition from flash.core.utilities.apply_func import get_callable_dict +from flash.core.utilities.imports import _SERVE_AVAILABLE class BenchmarkConvergenceCI(Callback): @@ -390,12 +391,18 @@ def build_data_pipeline( else: data_source = preprocess.data_source_of_name(data_source) - deserializer = deserializer or getattr(preprocess, "deserializer", None) + if deserializer is None or type(deserializer) == Deserializer: + deserializer = getattr(preprocess, "deserializer", deserializer) data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline + @torch.jit.unused + @property + def is_servable(self) -> bool: + return type(self.build_data_pipeline()._deserializer) != Deserializer + @torch.jit.unused @property def data_pipeline(self) -> DataPipeline: @@ -592,41 +599,39 @@ def configure_callbacks(self): if flash._IS_TESTING and torch.cuda.is_available(): return [BenchmarkConvergenceCI()] - def serve(self, host: str = "127.0.0.1", port: int = 8000) -> 'Composition': - from flash.core.serve.flash_components import FlashInputs, FlashOutputs + def run_serve_sanity_check(self): + if not _SERVE_AVAILABLE: + raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'") + if not self.is_servable: + raise NotImplementedError("This Task is not servable. 
Attach a Deserializer to enable serving.") - class FlashServeModelComponent(ModelComponent): + from fastapi.testclient import TestClient - def __init__(self, model): - self.model = model - self.model.eval() - self.data_pipeline = self.model.build_data_pipeline() - self.worker_preprocessor = self.data_pipeline.worker_preprocessor( - RunningStage.PREDICTING, is_serving=True - ) - self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING) - self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True) - # todo (tchaton) Remove this hack - self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3 - self.device = self.model.device - - @expose( - inputs={"inputs": FlashInputs(self.data_pipeline.deserialize_processor())}, - outputs={"outputs": FlashOutputs(self.data_pipeline.serialize_processor())}, - ) - def predict(self, inputs): - with torch.no_grad(): - inputs = self.worker_preprocessor(inputs) - if self.extra_arguments: - inputs = self.model.transfer_batch_to_device(inputs, self.device, 0) - else: - inputs = self.model.transfer_batch_to_device(inputs, self.device) - inputs = self.device_preprocessor(inputs) - preds = self.model.predict_step(inputs, 0) - preds = self.postprocessor(preds) - return preds - - comp = FlashServeModelComponent(self) - composition = Composition(predict=comp) + from flash.core.serve.flash_components import build_flash_serve_model_component + + print("Running serve sanity check") + comp = build_flash_serve_model_component(self) + composition = Composition(predict=comp, TESTING=True, DEBUG=True) + app = composition.serve(host="0.0.0.0", port=8000) + + with TestClient(app) as tc: + input_str = self.data_pipeline._deserializer.example_input + body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}} + resp = tc.post("http://0.0.0.0:8000/predict", json=body) + print(f"Sanity check response: {resp.json()}") + + def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition': + if not _SERVE_AVAILABLE: + raise ModuleNotFoundError("Please, pip install 'lightning-flash[serve]'") + if not self.is_servable: + raise NotImplementedError("This Task is not servable. 
Attach a Deserializer to enable serving.") + + from flash.core.serve.flash_components import build_flash_serve_model_component + + if sanity_check: + self.run_serve_sanity_check() + + comp = build_flash_serve_model_component(self) + composition = Composition(predict=comp, TESTING=flash._IS_TESTING) composition.serve(host=host, port=port) return composition diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 1f549be029..5e1f745c2a 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -1,9 +1,12 @@ +import inspect from typing import Any, Callable, Mapping, Optional import torch +from pytorch_lightning.trainer.states import RunningStage from flash import Task from flash.core.data.data_source import DefaultDataKeys +from flash.core.serve import expose, ModelComponent from flash.core.serve.core import FilePath, GridserveScriptLoader from flash.core.serve.types.base import BaseType @@ -54,3 +57,39 @@ class FlashServeScriptLoader(GridserveScriptLoader): def __init__(self, location: FilePath): self.location = location self.instance = self.model_cls.load_from_checkpoint(location) + + +def build_flash_serve_model_component(model): + + data_pipeline = model.build_data_pipeline() + + class FlashServeModelComponent(ModelComponent): + + def __init__(self, model): + self.model = model + self.model.eval() + self.data_pipeline = model.build_data_pipeline() + self.worker_preprocessor = self.data_pipeline.worker_preprocessor(RunningStage.PREDICTING, is_serving=True) + self.device_preprocessor = self.data_pipeline.device_preprocessor(RunningStage.PREDICTING) + self.postprocessor = self.data_pipeline.postprocessor(RunningStage.PREDICTING, is_serving=True) + # todo (tchaton) Remove this hack + self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3 + self.device = self.model.device + + @expose( + inputs={"inputs": FlashInputs(data_pipeline.deserialize_processor())}, + outputs={"outputs": FlashOutputs(data_pipeline.serialize_processor())}, + ) + def predict(self, inputs): + with torch.no_grad(): + inputs = self.worker_preprocessor(inputs) + if self.extra_arguments: + inputs = self.model.transfer_batch_to_device(inputs, self.device, 0) + else: + inputs = self.model.transfer_batch_to_device(inputs, self.device) + inputs = self.device_preprocessor(inputs) + preds = self.model.predict_step(inputs, 0) + preds = self.postprocessor(preds) + return preds + + return FlashServeModelComponent(model) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 646a6f7581..08884d4940 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
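Pieced together from the hunks above, the new sanity check builds the serve component, spins the composition up in testing mode, and posts the deserializer's `example_input` to `/predict` through FastAPI's `TestClient`. A sketch of that flow, using an `ImageClassifier` with a manually attached preprocess (the same workaround the new `test_serve` test uses) as one example of a servable task:

from fastapi.testclient import TestClient

from flash.core.serve import Composition
from flash.core.serve.flash_components import build_flash_serve_model_component
from flash.image import ImageClassifier
from flash.image.classification.data import ImageClassificationPreprocess

model = ImageClassifier(2)
model._preprocess = ImageClassificationPreprocess()  # attaches an ImageDeserializer, making the task servable
model.eval()

comp = build_flash_serve_model_component(model)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)  # returns the FastAPI app when TESTING=True

with TestClient(app) as tc:
    input_str = model.data_pipeline._deserializer.example_input
    body = {"session": "UUID", "payload": {"inputs": {"data": input_str}}}
    print(tc.post("http://0.0.0.0:8000/predict", json=body).json())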
import inspect -import os import warnings from argparse import ArgumentParser, Namespace from functools import wraps @@ -29,6 +28,7 @@ import flash from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks +from flash.core.utilities.imports import _SERVE_AVAILABLE def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -72,7 +72,7 @@ def insert_env_defaults(self, *args, **kwargs): class Trainer(PlTrainer): @_defaults_from_env_vars - def __init__(self, *args, **kwargs): + def __init__(self, *args, serve_sanity_check: bool = True, **kwargs): if flash._IS_TESTING: if torch.cuda.is_available(): kwargs["gpus"] = 1 @@ -85,6 +85,14 @@ def __init__(self, *args, **kwargs): kwargs["fast_dev_run"] = True super().__init__(*args, **kwargs) + self.serve_sanity_check = serve_sanity_check + + def run_sanity_check(self, ref_model): + super().run_sanity_check(ref_model) + + if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE: + ref_model.run_serve_sanity_check() + def fit( self, model: LightningModule, diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 75dc93e605..6b29f233fe 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities""" - import importlib import operator from importlib.util import find_spec diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 474037e176..0c1c7f1ddf 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
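With the `Trainer` hook above, the same check can also run automatically after Lightning's sanity check, provided the task is servable and the serve extras are installed. A small usage sketch (the model and `max_epochs` value are placeholders):

import flash
from flash.image import ImageClassifier

model = ImageClassifier(num_classes=2)

# serve_sanity_check defaults to True; pass False to opt out of the extra check
# that now runs right after Lightning's own sanity check when the task is servable.
trainer = flash.Trainer(max_epochs=1, serve_sanity_check=False)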
-import base64 -from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -25,43 +23,29 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE from flash.image.classification.transforms import default_transforms, train_default_transforms -from flash.image.data import ImageFiftyOneDataSource, ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource +from flash.image.data import ( + ImageDeserializer, + ImageFiftyOneDataSource, + ImageNumpyDataSource, + ImagePathsDataSource, + ImageTensorDataSource, +) if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: plt = None -if _TORCHVISION_AVAILABLE: - import torchvision - if _IMAGE_AVAILABLE: from PIL import Image - from PIL import Image as PILImage else: class Image: Image = None -class ImageClassificationDeserializer(Deserializer): - - def __init__(self): - - self.to_tensor = torchvision.transforms.ToTensor() - - def deserialize(self, data: str) -> Dict: - encoded_with_padding = (data + "===").encode("ascii") - img = base64.b64decode(encoded_with_padding) - buffer = BytesIO(img) - img = PILImage.open(buffer, mode="r") - return { - DefaultDataKeys.INPUT: img, - } - - class ImageClassificationPreprocess(Preprocess): def __init__( @@ -88,7 +72,7 @@ def __init__( DefaultDataSources.NUMPY: ImageNumpyDataSource(), DefaultDataSources.TENSORS: ImageTensorDataSource(), }, - deserializer=deserializer or ImageClassificationDeserializer(), + deserializer=deserializer or ImageDeserializer(), default_data_source=DefaultDataSources.FILES, ) diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 75a2dd0a49..618fb0f275 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -19,7 +19,7 @@ from torch import nn from torch.optim.lr_scheduler import _LRScheduler -from flash.core.classification import ClassificationTask +from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry @@ -91,7 +91,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer, + serializer=serializer or Labels(), ) self.save_hyperparameters() diff --git a/flash/image/data.py b/flash/image/data.py index 69fd25e657..4fbeb2cb1c 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 +from io import BytesIO +from pathlib import Path from typing import Any, Dict, Optional import torch +import flash from flash.core.data.data_source import ( DefaultDataKeys, FiftyOneDataSource, @@ -22,14 +26,44 @@ PathsDataSource, TensorDataSource, ) -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.data.process import Deserializer +from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: + import torchvision from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image else: IMG_EXTENSIONS = [] +if _IMAGE_AVAILABLE: + from PIL import Image as PILImage +else: + + class Image: + Image = None + + +class ImageDeserializer(Deserializer): + + def __init__(self): + super().__init__() + self.to_tensor = torchvision.transforms.ToTensor() + + def deserialize(self, data: str) -> Dict: + encoded_with_padding = (data + "===").encode("ascii") + img = base64.b64decode(encoded_with_padding) + buffer = BytesIO(img) + img = PILImage.open(buffer, mode="r") + return { + DefaultDataKeys.INPUT: img, + } + + @property + def example_input(self) -> str: + with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f: + return base64.b64encode(f.read()).decode("UTF-8") + class ImagePathsDataSource(PathsDataSource): diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index a21517704d..3edd64398b 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -38,6 +38,7 @@ ) from flash.core.data.process import Deserializer, Preprocess from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.image.data import ImageDeserializer from flash.image.segmentation.serialization import SegmentationLabels from flash.image.segmentation.transforms import default_transforms, train_default_transforms @@ -215,19 +216,13 @@ def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: return sample -class SemanticSegmentationDeserializer(Deserializer): - - def __init__(self): - - self.to_tensor = torchvision.transforms.ToTensor() +class SemanticSegmentationDeserializer(ImageDeserializer): def deserialize(self, data: str) -> torch.Tensor: - encoded_with_padding = (data + "===").encode("ascii") - img = base64.b64decode(encoded_with_padding) - buffer = BytesIO(img) - img = PILImage.open(buffer, mode="r") - img = self.to_tensor(img) - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.METADATA: {"size": img.shape}} + result = super().deserialize(data) + result[DefaultDataKeys.INPUT] = self.to_tensor(result[DefaultDataKeys.INPUT]) + result[DefaultDataKeys.METADATA] = {"size": result[DefaultDataKeys.INPUT].shape} + return result class SemanticSegmentationPreprocess(Preprocess): diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index bcf770f694..c2a60e24da 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
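The new `ImageDeserializer` above both decodes base64 image payloads and exposes a ready-made `example_input` built from the bundled `flash/assets/fish.jpg`. A small round-trip sketch (requires the image extras; the printed size is simply whatever the asset happens to be):

from flash.core.data.data_source import DefaultDataKeys
from flash.image.data import ImageDeserializer

deserializer = ImageDeserializer()

# example_input is the base64-encoded content of the fish.jpg asset ...
payload = deserializer.example_input

# ... and deserialize() turns such a payload back into a PIL image under DefaultDataKeys.INPUT.
sample = deserializer.deserialize(payload)
print(type(sample[DefaultDataKeys.INPUT]), sample[DefaultDataKeys.INPUT].size)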
+from io import StringIO from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -123,7 +124,7 @@ def __init__( classes: Optional[List[str]] = None, is_regression: bool = True ): - + super().__init__() self.cat_cols = cat_cols self.num_cols = num_cols self.target_col = target_col @@ -134,20 +135,8 @@ def __init__( self.classes = classes self.is_regression = is_regression - @staticmethod - def _convert_row(row): - _row = [] - for c in row: - try: - _row.append(float(c)) - except Exception: - _row.append(c) - return _row - def deserialize(self, data: str) -> Any: - columns = data.split("\n")[0].split(',') - df = pd.DataFrame([TabularDeserializer._convert_row(x.split(',')[1:]) for x in data.split('\n')[1:-1]], - columns=columns) + df = pd.read_csv(StringIO(data)) df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes)[0] @@ -159,6 +148,15 @@ def deserialize(self, data: str) -> Any: return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] + @property + def example_input(self) -> str: + row = {} + for cat_col in self.cat_cols: + row[cat_col] = ["test"] + for num_col in self.num_cols: + row[num_col] = [0] + return str(DataFrame.from_dict(row).to_csv()) + class TabularPreprocess(Preprocess): diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 543ee24b2b..5aad7a05d1 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -33,12 +33,17 @@ class TextDeserializer(Deserializer): def __init__(self, backbone: str, max_length: int, use_fast: bool = True): + super().__init__() self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=use_fast) self.max_length = max_length def deserialize(self, text: str) -> Tensor: return self.tokenizer(text, max_length=self.max_length, truncation=True, padding="max_length") + @property + def example_input(self) -> str: + return "An example input" + class TextDataSource(DataSource): diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 13b9210412..db4028770e 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -48,10 +48,10 @@ def fn_resnet(pretrained: bool = True): print(ImageClassifier.available_backbones()) # 4. Build the model -model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, serializer=Labels()) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 5. Create the trainer -trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) +trainer = flash.Trainer(max_epochs=3) # 6. Train the model trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py index 5b55e4fec2..64d2ea6cde 100644 --- a/flash_examples/finetuning/image_classification_multi_label.py +++ b/flash_examples/finetuning/image_classification_multi_label.py @@ -59,7 +59,7 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L ) # 4. Create the trainer -trainer = flash.Trainer(fast_dev_run=True) +trainer = flash.Trainer(max_epochs=1) # 5. 
Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index d7f75d1e46..7c5233287c 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -46,7 +46,7 @@ backbone="mobilenet_v3_large", head="fcn", num_classes=datamodule.num_classes, - serializer=SegmentationLabels(visualize=True), + serializer=SegmentationLabels(visualize=False), ) # 4. Create the trainer. diff --git a/tests/conftest.py b/tests/conftest.py index 868cd06df4..20827b48b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from flash.core.serve.decorators import uuid4 # noqa (used in mocker.patch) from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from tests.helpers.utils import _SERVE_TESTING if _TORCHVISION_AVAILABLE: import torchvision @@ -60,7 +61,7 @@ def global_datadir(tmp_path_factory, original_global_datadir): return prep_global_datadir(tmp_path_factory, original_global_datadir) -if _TORCHVISION_AVAILABLE: +if _SERVE_TESTING: @pytest.fixture(scope="session") def squeezenet1_1_model(): diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index eaaf819ccc..9df0fbb290 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -26,6 +26,7 @@ from flash.core.data.utils import _CALLBACK_FUNCS, _STAGES_PREFIX from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ImageClassificationData +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: from PIL import Image @@ -74,7 +75,7 @@ def check_reset(self): self.per_batch_transform_called = False -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestBaseViz: def test_base_viz(self, tmpdir): diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 2333167e5b..d8398b60ac 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -34,6 +34,7 @@ from flash.core.data.properties import ProcessState from flash.core.model import Task from flash.core.utilities.imports import _IMAGE_AVAILABLE +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: import torchvision.transforms as T @@ -764,7 +765,7 @@ def val_collate(self, *_): assert not DataPipeline._is_overriden_recursive("chocolate", preprocess, Preprocess) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index 43e166d4e3..cd1b8a87ad 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -2,11 +2,11 @@ import torch from flash.core.serve.types import Label -from flash.core.utilities.imports import _SERVE_AVAILABLE, _TORCHVISION_AVAILABLE from tests.core.serve.models import ClassificationInferenceComposable, LightningSqueezenet +from tests.helpers.utils import _SERVE_TESTING -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, 
reason="serve libraries aren't installed.") def test_model_compute_call_method(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) img = torch.arange(195075).reshape((1, 255, 255, 3)) @@ -15,7 +15,7 @@ def test_model_compute_call_method(lightning_squeezenet1_1_obj): assert out_res.item() == 753 -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_model_compute_dependencies(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -31,7 +31,7 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj): assert list(comp2._gridserve_meta_.connections) == [] -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -48,7 +48,7 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob assert list(comp1._gridserve_meta_.connections) == [] -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -83,7 +83,7 @@ def __init__(self): comp1.inputs["tag"] >> foo -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_component_initialization(lightning_squeezenet1_1_obj): with pytest.raises(TypeError): ClassificationInferenceComposable(wrongname=lightning_squeezenet1_1_obj) @@ -98,7 +98,7 @@ def test_component_initialization(lightning_squeezenet1_1_obj): assert "predicted_tag" in comp.outputs -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_component_parameters(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -118,7 +118,7 @@ def test_component_parameters(lightning_squeezenet1_1_obj): assert first_tag.connections == comp1._gridserve_meta_.connections -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_expose_inputs(): from flash.core.serve import expose, ModelComponent from flash.core.serve.types import Number @@ -178,7 +178,7 @@ def predict(self, param): _ = ComposeClassEmptyExposeInputsType(lr) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") 
+@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_connection_invalid_raises(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -194,7 +194,7 @@ class FakeParam: comp1.outputs.predicted_tag >> fake_param -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_name(lightning_squeezenet1_1_obj): from flash.core.serve import expose, ModelComponent from flash.core.serve.types import Number @@ -211,7 +211,7 @@ def predict(self, param): return param -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_config_args(lightning_squeezenet1_1_obj): from flash.core.serve import expose, ModelComponent from flash.core.serve.types import Number @@ -239,7 +239,7 @@ def predict(self, param): _ = SomeComponent(lightning_squeezenet1_1_obj, config={"key": lambda x: x}) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_model_args(lightning_squeezenet1_1_obj): from flash.core.serve import expose, ModelComponent from flash.core.serve.types import Number @@ -270,7 +270,7 @@ def predict(self, param): _ = SomeComponent({"first": lightning_squeezenet1_1_obj, "second": 233}) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_create_invalid_endpoint(lightning_squeezenet1_1_obj): from flash.core.serve import Endpoint diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py index 803448b511..53ab457667 100644 --- a/tests/core/serve/test_composition.py +++ b/tests/core/serve/test_composition.py @@ -4,13 +4,14 @@ import pytest from flash.core.serve import Composition, Endpoint -from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _FASTAPI_AVAILABLE +from tests.helpers.utils import _SERVE_TESTING if _FASTAPI_AVAILABLE: from fastapi.testclient import TestClient -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_composit_endpoint_data(lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -62,7 +63,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -142,7 +143,7 @@ def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj): _ = Composition(comp1=comp1, predict_ep=ep) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") 
+@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj): # no endpoints or components with pytest.raises(TypeError): @@ -158,7 +159,7 @@ def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj): _ = Composition(c1=comp1, c2=comp2) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_gridmodel_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel): from tests.core.serve.models import ClassificationInferenceModelSequence @@ -172,7 +173,7 @@ def test_gridmodel_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gr assert composit.components["callnum_1"].model2 == model_seq[1] -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_gridmodel_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel): from tests.core.serve.models import ClassificationInferenceModelMapping @@ -186,7 +187,7 @@ def test_gridmodel_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gri assert composit.components["callnum_1"].model2 == model_map["model_two"] -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_gridmodel_composition(tmp_path, lightning_squeezenet1_1_obj, squeezenet_gridmodel): from tests.core.serve.models import ClassificationInferenceModelMapping @@ -200,7 +201,7 @@ def test_invalid_gridmodel_composition(tmp_path, lightning_squeezenet1_1_obj, sq _ = ClassificationInferenceModelMapping(lambda x: x + 1) -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -251,7 +252,7 @@ def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -327,7 +328,7 @@ def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_start_server_from_composition(tmp_path, squeezenet_gridmodel, session_global_datadir): from tests.core.serve.models import ClassificationInferenceComposable diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index c0601eaf01..7d41873012 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -2,7 +2,8 @@ from flash.core.serve import expose, ModelComponent from 
flash.core.serve.types import Number -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE +from tests.helpers.utils import _SERVE_TESTING @pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library cytoolz is not installed.") @@ -173,7 +174,7 @@ def predict(self, param): return param -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_method_parameters( lightning_squeezenet1_1_obj, ): @@ -199,9 +200,7 @@ def predict(self, param): _ = FailedExposedDecorator(comp) -@pytest.mark.skipif( - not (_SERVE_AVAILABLE and _CYTOOLZ_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve is not installed." -) +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve is not installed.") def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_obj): """This occurs when the instance is being initialized. diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index c846dd1842..36595a0f0d 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -3,13 +3,14 @@ import pytest from flash.core.serve import Composition, Endpoint -from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _FASTAPI_AVAILABLE +from tests.helpers.utils import _SERVE_TESTING if _FASTAPI_AVAILABLE: from fastapi.testclient import TestClient -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference @@ -37,7 +38,7 @@ def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1 assert expected == resp.json() -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_start_server_with_repeated_exposed(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceRepeated @@ -63,7 +64,7 @@ def test_start_server_with_repeated_exposed(session_global_datadir, lightning_sq assert resp.json() == expected -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_serving_single_component_and_endpoint_no_composition(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference @@ -170,7 +171,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat } -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -235,7 
+236,7 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): assert resp.template.name == "dag.html" -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_composed_does_not_eliminate_endpoint_serialization(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -321,7 +322,7 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad assert resp.template.name == "dag.html" -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -461,7 +462,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ } -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -471,7 +472,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1 c1.outputs.cropped_img >> c1.inputs.img -@pytest.mark.skipif(not (_SERVE_AVAILABLE and _TORCHVISION_AVAILABLE), reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_composition_from_url_torchscript_gridmodel(tmp_path): from flash.core.serve import expose, GridModel, ModelComponent from flash.core.serve.types import Number diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 59953d295d..356ec27410 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -31,17 +31,13 @@ from flash.core.data.process import DefaultPreprocess, Postprocess from flash.core.utilities.imports import _IMAGE_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier +from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING if _TABULAR_AVAILABLE: from flash.tabular import TabularClassifier else: TabularClassifier = None -if _TEXT_AVAILABLE: - from flash.text import TextClassifier -else: - TextClassifier = None - if _IMAGE_AVAILABLE: from PIL import Image else: @@ -104,8 +100,8 @@ def test_classificationtask_task_predict(): assert pred0[0] == pred1[0] -@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_classification_task_predict_folder_path(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -168,7 +164,7 @@ def test_task_datapipeline_save(tmpdir): ImageClassifier, "image_classification_model.pt", marks=pytest.mark.skipif( - not _IMAGE_AVAILABLE, + not _IMAGE_TESTING, reason="image packages aren't installed", ) ), @@ -176,7 +172,7 @@ def test_task_datapipeline_save(tmpdir): TabularClassifier, "tabular_classification_model.pt", 
marks=pytest.mark.skipif( - not _TABULAR_AVAILABLE, + not _TABULAR_TESTING, reason="tabular packages aren't installed", ) ), @@ -188,7 +184,7 @@ def test_model_download(tmpdir, cls, filename): assert isinstance(task, cls) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_available_backbones(): backbones = ImageClassifier.available_backbones() assert "resnet152" in backbones @@ -232,11 +228,7 @@ def test_optimization(tmpdir): if _TEXT_AVAILABLE: from transformers.optimization import get_linear_schedule_with_warmup - assert task.available_schedulers() == [ - 'constant_schedule', 'constant_schedule_with_warmup', 'cosine_schedule_with_warmup', - 'cosine_with_hard_restarts_schedule_with_warmup', 'linear_schedule_with_warmup', - 'polynomial_decay_schedule_with_warmup' - ] + assert isinstance(task.available_schedulers(), list) optim = torch.optim.Adadelta(model.parameters()) with pytest.raises(MisconfigurationException, match="The LightningModule isn't attached to the trainer yet."): diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index 9c6669c3b8..061c6f4504 100644 --- a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -14,7 +14,6 @@ import logging import pytest -import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 9b440e6f6a..896178e7d0 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -12,26 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import subprocess -import sys from pathlib import Path -from typing import List, Optional, Tuple from unittest import mock import pytest -from flash.core.utilities.imports import ( - _IMAGE_AVAILABLE, - _PYSTICHE_GREATER_EQUAL_0_7_2, - _SKLEARN_AVAILABLE, - _TABULAR_AVAILABLE, - _TEXT_AVAILABLE, - _TORCHVISION_GREATER_EQUAL_0_9, - _VIDEO_AVAILABLE, -) +from flash.core.utilities.imports import _SKLEARN_AVAILABLE from tests.examples.utils import run_test - -_IMAGE_AVAILABLE = _IMAGE_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_9 +from tests.helpers.utils import ( + _IMAGE_STLYE_TRANSFER_TESTING, + _IMAGE_TESTING, + _TABULAR_TESTING, + _TEXT_TESTING, + _VIDEO_TESTING, +) root = Path(__file__).parent.parent.parent @@ -43,28 +37,28 @@ pytest.param( "finetuning", "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "finetuning", "image_classification_multi_label.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), # pytest.param("finetuning", "object_detection.py"), # TODO: takes too long. 
pytest.param( "finetuning", "semantic_segmentation.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "finetuning", "summarization.py", - marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "finetuning", "tabular_classification.py", - marks=pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed") + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed") ), # pytest.param("finetuning", "video_classification.py"), # pytest.param("finetuning", "text_classification.py"), # TODO: takes too long @@ -76,48 +70,48 @@ pytest.param( "finetuning", "translation.py", - marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "finetuning", "style_transfer.py", - marks=pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="pystiche is not installed") + marks=pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER_TESTING, reason="pystiche is not installed") ), pytest.param( "predict", "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "predict", "image_classification_multi_label.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "predict", "semantic_segmentation.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "predict", "tabular_classification.py", - marks=pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed") + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed") ), # pytest.param("predict", "text_classification.py"), pytest.param( "predict", "image_embedder.py", - marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "predict", "video_classification.py", - marks=pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="video libraries aren't installed") + marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed") ), pytest.param( "predict", "summarization.py", - marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "predict", @@ -127,7 +121,7 @@ pytest.param( "predict", "translation.py", - marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), ] ) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py new file mode 100644 index 
0000000000..241cb4d10d --- /dev/null +++ b/tests/helpers/utils.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from flash.core.utilities.imports import ( + _IMAGE_AVAILABLE, + _IMAGE_STLYE_TRANSFER, + _SERVE_AVAILABLE, + _TABULAR_AVAILABLE, + _TEXT_AVAILABLE, + _VIDEO_AVAILABLE, +) + +_IMAGE_TESTING = _IMAGE_AVAILABLE +_VIDEO_TESTING = _VIDEO_AVAILABLE +_TABULAR_TESTING = _TABULAR_AVAILABLE +_TEXT_TESTING = _TEXT_AVAILABLE +_IMAGE_STLYE_TRANSFER_TESTING = _IMAGE_STLYE_TRANSFER +_SERVE_TESTING = _SERVE_AVAILABLE + +if "FLASH_TEST_TOPIC" in os.environ: + topic = os.environ["FLASH_TEST_TOPIC"] + _IMAGE_TESTING = topic == "image" + _VIDEO_TESTING = topic == "video" + _TABULAR_TESTING = topic == "tabular" + _TEXT_TESTING = topic == "text" + _IMAGE_STLYE_TRANSFER_TESTING = topic == "image_style_transfer" + _SERVE_TESTING = topic == "serve" diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index d16a9c0d0a..3d6b52a0ce 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -23,9 +23,9 @@ from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ImageClassificationData +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: - import kornia as K import torchvision from PIL import Image @@ -44,7 +44,7 @@ def _rand_image(size: Tuple[int, int] = None): return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_smoke(tmpdir): tmpdir = Path(tmpdir) @@ -75,7 +75,7 @@ def test_from_filepaths_smoke(tmpdir): assert sorted(list(labels.numpy())) == [1, 2] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_list_image_paths(tmpdir): tmpdir = Path(tmpdir) @@ -122,7 +122,7 @@ def test_from_filepaths_list_image_paths(tmpdir): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -157,7 +157,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -193,7 +193,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): dm.show_val_batch("per_batch_transform") 
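The `tests/helpers/utils.py` module above derives the per-topic `_*_TESTING` flags from the `FLASH_TEST_TOPIC` variable that the CI matrix now exports, so each CI job only runs the tests for its own topic. A sketch of the intended effect, assuming a fresh interpreter so the module is evaluated after the variable is set:

import os

os.environ["FLASH_TEST_TOPIC"] = "serve"  # mirrors the CI matrix topic

from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING

assert _SERVE_TESTING       # serve-marked tests are enabled for this topic
assert not _IMAGE_TESTING   # tests gated on other topics are skipped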
-@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_splits(tmpdir): tmpdir = Path(tmpdir) @@ -238,7 +238,7 @@ def run(transform: Any = None): run(_to_tensor) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders_only_train(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -262,7 +262,7 @@ def test_from_folders_only_train(tmpdir): assert img_data.test_dataloader() is None -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders_train_val(tmpdir): train_dir = Path(tmpdir / "train") @@ -301,7 +301,7 @@ def test_from_folders_train_val(tmpdir): assert list(labels.numpy()) == [0, 0] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -343,7 +343,7 @@ def test_from_filepaths_multilabel(tmpdir): torch.testing.assert_allclose(labels, torch.tensor(test_labels)) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize( "data,from_function", [ @@ -386,7 +386,7 @@ def test_from_data(data, from_function): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_from_fiftyone(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py index 711bcc329f..befbf4f4e8 100644 --- a/tests/image/classification/test_data_model_integration.py +++ b/tests/image/classification/test_data_model_integration.py @@ -20,6 +20,7 @@ from flash import Trainer from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: from PIL import Image @@ -36,7 +37,7 @@ def _rand_image(): return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_classification(tmpdir): tmpdir = Path(tmpdir) @@ -61,7 +62,7 @@ def test_classification(tmpdir): trainer.finetune(model, datamodule=data, strategy="freeze") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_classification_fiftyone(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index a03bd16a54..2becff18d0 100644 --- a/tests/image/classification/test_model.py +++ 
b/tests/image/classification/test_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock import pytest import torch @@ -19,8 +20,9 @@ from flash import Trainer from flash.core.classification import Probabilities from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ImageClassifier +from flash.image.classification.data import ImageClassificationPreprocess +from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING # ======== Mock functions ======== @@ -55,7 +57,7 @@ def __len__(self) -> int: # ============================== -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize( "backbone", [ @@ -73,13 +75,13 @@ def test_init_train(tmpdir, backbone): trainer.finetune(model, train_dl, strategy="freeze_unfreeze") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_non_existent_backbone(): with pytest.raises(KeyError): ImageClassifier(2, "i am never going to implement this lol") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_freeze(): model = ImageClassifier(2) model.freeze() @@ -87,7 +89,7 @@ def test_freeze(): assert p.requires_grad is False -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_unfreeze(): model = ImageClassifier(2) model.unfreeze() @@ -95,7 +97,7 @@ def test_unfreeze(): assert p.requires_grad is True -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_multilabel(tmpdir): num_classes = 4 @@ -112,7 +114,7 @@ def test_multilabel(tmpdir): assert len(torch.unique(label)) <= 2 -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") @@ -128,3 +130,13 @@ def test_jit(tmpdir, jitter, args): out = model(torch.rand(1, 3, 32, 32)) assert isinstance(out, torch.Tensor) assert out.shape == torch.Size([1, 2]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = ImageClassifier(2) + # TODO: Currently only servable once a preprocess has been attached + model._preprocess = ImageClassificationPreprocess() + model.eval() + model.serve() diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index b87ba8dec5..2cd29b61f2 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -5,8 +5,9 @@ import pytest from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE +from 
flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image.detection.data import ObjectDetectionData +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: from PIL import Image @@ -120,7 +121,7 @@ def _create_synth_fiftyone_dataset(tmpdir): return dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing") def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) @@ -166,7 +167,7 @@ def test_image_detector_data_from_coco(tmpdir): assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_image_detector_data_from_fiftyone(tmpdir): diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index a20a4c06d3..2c54a4d0f0 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -19,6 +19,7 @@ from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: from PIL import Image @@ -32,7 +33,7 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing") @pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") @pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) def test_detection(tmpdir, model, backbone): @@ -56,7 +57,7 @@ def test_detection(tmpdir, model, backbone): model.predict(test_images) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed for testing") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed for testing") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) def test_detection_fiftyone(tmpdir, model, backbone): diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 925d37b6a2..1b9a0bf67c 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -19,8 +19,8 @@ from torch.utils.data import DataLoader, Dataset from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ObjectDetector +from tests.helpers.utils import _IMAGE_TESTING def collate_fn(samples): @@ -52,7 +52,7 @@ def __getitem__(self, idx): return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_init(): model = ObjectDetector(num_classes=2) model.eval() @@ -70,7 +70,7 @@ def 
test_init(): @pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_training(tmpdir, model): model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False) ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) @@ -79,7 +79,7 @@ def test_training(tmpdir, model): trainer.fit(model, dl) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_jit(tmpdir): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 0c43035451..72951d5e7b 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -16,11 +16,11 @@ import pytest import torch -from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ImageEmbedder +from tests.helpers.utils import _IMAGE_TESTING -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py index a4c850b0bc..0b2b452e17 100644 --- a/tests/image/segmentation/test_backbones.py +++ b/tests/image/segmentation/test_backbones.py @@ -15,7 +15,6 @@ import torch from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _BOLTS_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index b464c35ad5..8d536df2dd 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -11,6 +11,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess +from tests.helpers.utils import _IMAGE_TESTING if _IMAGE_AVAILABLE: from PIL import Image @@ -54,7 +55,7 @@ def test_smoke(self): assert prep is not None -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestSemanticSegmentationData: def test_smoke(self): diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 302cb99cfc..da8e390e67 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. 
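Across the detection, embedding and segmentation files above, the topic flag replaces only the general image guard; genuinely optional extras such as FiftyOne keep their own `skipif` marks stacked on top. The same stacking can also be written once per module; this is a stylistic alternative shown for illustration, not something the diff adopts:

```python
# Module-level equivalent of the stacked skipif decorators used throughout these files.
import pytest

from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
from tests.helpers.utils import _IMAGE_TESTING

pytestmark = [
    pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."),
    pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed."),
]
```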
import os from typing import Tuple +from unittest import mock import numpy as np import pytest @@ -21,9 +22,9 @@ from flash import Trainer from flash.core.data.data_pipeline import DataPipeline from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import SemanticSegmentation from flash.image.segmentation.data import SemanticSegmentationPreprocess +from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING # ======== Mock functions ======== @@ -45,13 +46,13 @@ def __len__(self) -> int: # ============================== -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_smoke(): model = SemanticSegmentation(num_classes=1) assert model is not None -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("num_classes", [8, 256]) @pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)]) def test_forward(num_classes, img_shape): @@ -68,7 +69,7 @@ def test_forward(num_classes, img_shape): assert out.shape == (B, num_classes, H, W) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_init_train(tmpdir): model = SemanticSegmentation(num_classes=10) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -76,13 +77,13 @@ def test_init_train(tmpdir): trainer.finetune(model, train_dl, strategy="freeze_unfreeze") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_non_existent_backbone(): with pytest.raises(KeyError): SemanticSegmentation(2, "i am never going to implement this lol") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_freeze(): model = SemanticSegmentation(2) model.freeze() @@ -90,7 +91,7 @@ def test_freeze(): assert p.requires_grad is False -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_unfreeze(): model = SemanticSegmentation(2) model.unfreeze() @@ -98,7 +99,7 @@ def test_unfreeze(): assert p.requires_grad is True -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_predict_tensor(): img = torch.rand(1, 3, 10, 20) model = SemanticSegmentation(2) @@ -109,7 +110,7 @@ def test_predict_tensor(): assert len(out[0][0]) == 20 -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_predict_numpy(): img = np.ones((1, 3, 10, 20)) model = SemanticSegmentation(2) @@ -120,7 +121,7 @@ def test_predict_numpy(): assert len(out[0][0]) == 20 -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), 
(torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") @@ -136,3 +137,13 @@ def test_jit(tmpdir, jitter, args): out = model(torch.rand(1, 3, 32, 32)) assert isinstance(out, torch.Tensor) assert out.shape == torch.Size([1, 2, 32, 32]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = SemanticSegmentation(2) + # TODO: Currently only servable once a preprocess has been attached + model._preprocess = SemanticSegmentationPreprocess() + model.eval() + model.serve() diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index 4cd1b05c9d..6c4345fb87 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -3,11 +3,12 @@ import pytest import torch -from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER, _PYSTICHE_GREATER_EQUAL_0_7_2 +from flash.core.utilities.imports import _IMAGE_STLYE_TRANSFER from flash.image.style_transfer import StyleTransfer +from tests.helpers.utils import _IMAGE_STLYE_TRANSFER_TESTING -@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER_TESTING, reason="image style transfer libraries aren't installed.") def test_style_transfer_task(): model = StyleTransfer( @@ -25,7 +26,7 @@ def test_style_transfer_task_import(): StyleTransfer() -@pytest.mark.skipif(not _PYSTICHE_GREATER_EQUAL_0_7_2, reason="image style transfer libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_STLYE_TRANSFER_TESTING, reason="image style transfer libraries aren't installed.") def test_jit(tmpdir): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index 2072dd40f5..349aeeaaba 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -16,6 +16,7 @@ from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.tabular import TabularClassifier, TabularData +from tests.helpers.utils import _TABULAR_TESTING if _TABULAR_AVAILABLE: import pandas as pd @@ -30,7 +31,7 @@ ) -@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") def test_classification(tmpdir): train_data_frame = TEST_DF_1.copy() diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index e2fbf11edb..6e2108162e 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
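The `test_serve` bodies added for the image classification and semantic segmentation tasks above, and for the tabular and text tasks further down, all share the same shape: attach a preprocess, switch to eval mode, and call `serve()` with `flash._IS_TESTING` patched to `True`, which appears to route the call through the new sanity-checking path rather than a long-running server. A hypothetical helper capturing that shared shape, not part of the diff:

```python
# Hypothetical consolidation of the repeated test_serve bodies.
from unittest import mock


def run_serve_sanity_check(model, preprocess, postprocess=None):
    # Patching flash._IS_TESTING presumably makes Task.serve() perform only its
    # sanity check instead of blocking on a real server.
    with mock.patch("flash._IS_TESTING", True):
        # Per the TODOs in the diff, tasks are currently only servable once a
        # preprocess (and, for text, a postprocess) has been attached.
        model._preprocess = preprocess
        if postprocess is not None:
            model._postprocess = postprocess
        model.eval()
        model.serve()
```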
import os +from unittest import mock +import pandas as pd import pytest import torch from pytorch_lightning import Trainer @@ -20,6 +22,8 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.tabular import TabularClassifier +from flash.tabular.classification.data import TabularData +from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING # ======== Mock functions ======== @@ -44,7 +48,7 @@ def __len__(self) -> int: # ============================== -@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") def test_init_train(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16) model = TabularClassifier(num_classes=10, num_features=16 + 16, embedding_sizes=16 * [(10, 32)]) @@ -52,7 +56,7 @@ def test_init_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") def test_init_train_no_num(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16) model = TabularClassifier(num_classes=10, num_features=16, embedding_sizes=16 * [(10, 32)]) @@ -60,7 +64,7 @@ def test_init_train_no_num(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") def test_init_train_no_cat(tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16) model = TabularClassifier(num_classes=10, num_features=16, embedding_sizes=[]) @@ -74,7 +78,7 @@ def test_module_import_error(tmpdir): TabularClassifier(num_classes=10, num_features=16, embedding_sizes=[]) -@pytest.mark.skipif(not _TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") def test_jit(tmpdir): model = TabularClassifier(num_classes=10, num_features=8, embedding_sizes=4 * [(10, 32)]) model.eval() @@ -90,3 +94,20 @@ def test_jit(tmpdir): out = model((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4))) assert isinstance(out, torch.Tensor) assert out.shape == torch.Size([1, 10]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} + datamodule = TabularData.from_data_frame( + "cat_col", + "num_col", + "target", + pd.DataFrame.from_dict(train_data), + ) + model = TabularClassifier.from_data(datamodule) + # TODO: Currently only servable once a preprocess has been attached + model._preprocess = datamodule.preprocess + model.eval() + model.serve() diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 0564355cd6..d5a3b680f9 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -25,6 +25,7 @@ TextJSONDataSource, TextSentencesDataSource, ) +from tests.helpers.utils import _TEXT_TESTING if _TEXT_AVAILABLE: from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -57,7 +58,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on 
Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = TextClassificationData.from_csv("sentence", "label", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) @@ -67,7 +68,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_test_valid(tmpdir): csv_path = csv_data(tmpdir) dm = TextClassificationData.from_csv( @@ -89,7 +90,7 @@ def test_test_valid(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) @@ -105,7 +106,7 @@ def test_text_module_not_found_error(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") @pytest.mark.parametrize( "cls, kwargs", [ diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index b628086657..18863b9637 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from unittest import mock import pytest import torch from flash import Trainer -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassifier +from flash.text.classification.data import TextClassificationPostprocess, TextClassificationPreprocess +from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING # ======== Mock functions ======== @@ -41,7 +43,7 @@ def __len__(self) -> int: @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_init_train(tmpdir): model = TextClassifier(2, TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -49,7 +51,7 @@ def test_init_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_jit(tmpdir): sample_input = {"input_ids": torch.randint(1000, size=(1, 100))} path = os.path.join(tmpdir, "test.pt") @@ -66,3 +68,14 @@ def test_jit(tmpdir): out = model(sample_input)["logits"] assert isinstance(out, torch.Tensor) assert out.shape == torch.Size([1, 2]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = TextClassifier(2, TEST_BACKBONE) + # TODO: Currently only servable once a preprocess and postprocess have been attached + model._preprocess = TextClassificationPreprocess(backbone=TEST_BACKBONE) + model._postprocess = TextClassificationPostprocess() + model.eval() + model.serve() diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py index bf63bb86fb..4f2144aa90 100644 --- a/tests/text/seq2seq/core/test_data.py +++ b/tests/text/seq2seq/core/test_data.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
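The text classification variant just above is the one case in this diff that wires up both ends of the pipeline before serving. Expressed with the hypothetical helper sketched earlier (names mirror the test; the helper itself is not part of this change):

```python
from flash.text import TextClassifier
from flash.text.classification.data import TextClassificationPostprocess, TextClassificationPreprocess

TEST_BACKBONE = "prajjwal1/bert-tiny"  # the tiny backbone used elsewhere in these tests

model = TextClassifier(2, TEST_BACKBONE)
run_serve_sanity_check(
    model,
    TextClassificationPreprocess(backbone=TEST_BACKBONE),
    postprocess=TextClassificationPostprocess(),
)
```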
import os -from pathlib import Path import pytest from flash.core.utilities.imports import _TEXT_AVAILABLE -from flash.text import TextClassificationData from flash.text.seq2seq.core.data import ( Seq2SeqBackboneState, Seq2SeqCSVDataSource, @@ -27,13 +25,14 @@ Seq2SeqPostprocess, Seq2SeqSentencesDataSource, ) +from tests.helpers.utils import _TEXT_TESTING if _TEXT_AVAILABLE: from transformers.tokenization_utils_base import PreTrainedTokenizerBase @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") @pytest.mark.parametrize( "cls, kwargs", [ diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 5abc9afe11..2ab09f3636 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -16,8 +16,8 @@ import pytest -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import SummarizationData +from tests.helpers.utils import _TEXT_TESTING TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing @@ -47,7 +47,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) @@ -57,7 +57,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv( @@ -78,7 +78,7 @@ def test_from_files(tmpdir): assert "input_ids" in batch -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_postprocess_tokenizer(tmpdir): """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different backbone is used. 
@@ -99,7 +99,7 @@ def test_postprocess_tokenizer(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = SummarizationData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) diff --git a/tests/text/seq2seq/summarization/test_metric.py b/tests/text/seq2seq/summarization/test_metric.py index 7b829333a3..9f17397b02 100644 --- a/tests/text/seq2seq/summarization/test_metric.py +++ b/tests/text/seq2seq/summarization/test_metric.py @@ -14,11 +14,11 @@ import pytest import torch -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text.seq2seq.summarization.metric import RougeMetric +from tests.helpers.utils import _TEXT_TESTING -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_rouge(): preds = "My name is John".split() target = "Is your name John".split() diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index 7e380e331e..69b030e402 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -17,8 +17,8 @@ import torch from flash import Trainer -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import SummarizationTask +from tests.helpers.utils import _TEXT_TESTING # ======== Mock functions ======== @@ -41,7 +41,7 @@ def __len__(self) -> int: @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_init_train(tmpdir): model = SummarizationTask(TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -49,7 +49,7 @@ def test_init_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_jit(tmpdir): sample_input = { "input_ids": torch.randint(1000, size=(1, 32)), diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 64e613d6a6..244cb27d4a 100644 --- a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -16,8 +16,8 @@ import pytest -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TranslationData +from tests.helpers.utils import _TEXT_TESTING TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing @@ -47,7 +47,7 @@ def json_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = TranslationData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) @@ -57,7 +57,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, 
reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) dm = TranslationData.from_csv( @@ -79,7 +79,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = TranslationData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index 808d0a0ada..dd0ac0979a 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -17,8 +17,8 @@ import torch from flash import Trainer -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TranslationTask +from tests.helpers.utils import _TEXT_TESTING # ======== Mock functions ======== @@ -41,7 +41,7 @@ def __len__(self) -> int: @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_init_train(tmpdir): model = TranslationTask(TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -49,7 +49,7 @@ def test_init_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_jit(tmpdir): sample_input = { "input_ids": torch.randint(128, size=(1, 4)), diff --git a/tests/text/test_data_model_integration.py b/tests/text/test_data_model_integration.py index f6c5137252..9f65b02639 100644 --- a/tests/text/test_data_model_integration.py +++ b/tests/text/test_data_model_integration.py @@ -17,8 +17,8 @@ import pytest from pytorch_lightning import Trainer -from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassificationData, TextClassifier +from tests.helpers.utils import _TEXT_TESTING TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing @@ -36,7 +36,7 @@ def csv_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_classification(tmpdir): csv_path = csv_data(tmpdir) diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index fcb1bc68bd..a15c0d0b52 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -22,6 +22,7 @@ import flash from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE +from tests.helpers.utils import _VIDEO_TESTING if _FIFTYONE_AVAILABLE: import fiftyone as fo @@ -123,7 +124,7 @@ def mock_encoded_video_dataset_folder(tmpdir): yield str(tmp_dir), video_duration -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune(tmpdir): with 
mock_encoded_video_dataset_file() as ( @@ -189,7 +190,7 @@ def test_video_classifier_finetune(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_video_classifier_finetune_fiftyone(tmpdir): @@ -259,7 +260,7 @@ def test_video_classifier_finetune_fiftyone(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_jit(tmpdir): sample_input = torch.rand(1, 3, 32, 256, 256) path = os.path.join(tmpdir, "test.pt")