diff --git a/dev-requirements.in b/dev-requirements.in
index 09f3b90c46..b2e6cf74a1 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -11,3 +11,4 @@ codespell
 google-cloud-bigquery
 google-cloud-bigquery-storage
 IPython
+torch
diff --git a/dev-requirements.txt b/dev-requirements.txt
index be495a35e2..3705a578c0 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -444,6 +444,8 @@ tomli==2.0.1
     #   coverage
     #   mypy
     #   pytest
+torch==1.11.0
+    # via -r dev-requirements.in
 traitlets==5.3.0
     # via
     #   ipython
@@ -458,6 +460,7 @@ typing-extensions==4.3.0
     #   importlib-metadata
     #   mypy
     #   responses
+    #   torch
     #   typing-inspect
 typing-inspect==0.7.1
     # via
diff --git a/doc-requirements.in b/doc-requirements.in
index 4d60a6919b..760d3903dc 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
 jupyter # papermill
 pyspark # spark
 sqlalchemy # sqlalchemy
+torch # pytorch
diff --git a/doc-requirements.txt b/doc-requirements.txt
index c3c934c80f..e9f16c89a0 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -664,6 +664,8 @@ tinycss2==1.1.1
     # via nbconvert
 toolz==0.11.2
     # via altair
+torch==1.11.0
+    # via -r doc-requirements.in
 tornado==6.2
     # via
     #   ipykernel
@@ -703,6 +705,7 @@ typing-extensions==4.3.0
     #   pandera
     #   pydantic
     #   responses
+    #   torch
     #   typing-inspect
 typing-inspect==0.7.1
     # via
diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst
new file mode 100644
index 0000000000..12fd3d62d9
--- /dev/null
+++ b/docs/source/extras.pytorch.rst
@@ -0,0 +1,7 @@
+############
+PyTorch Type
+############
+.. automodule:: flytekit.extras.pytorch
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:
diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst
index f1b15455dd..f0cdff28dc 100644
--- a/docs/source/types.extend.rst
+++ b/docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
    types.builtins.structured
    types.builtins.file
    types.builtins.directory
+   extras.pytorch
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 3856bbb806..c67a8a04b4 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -182,6 +182,7 @@
 from flytekit.core.workflow import ImperativeWorkflow as Workflow
 from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
 from flytekit.deck import Deck
+from flytekit.extras import pytorch
 from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
 from flytekit.loggers import logger
 from flytekit.models.common import Annotations, AuthRole, Labels
diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py
new file mode 100644
index 0000000000..ae077d9755
--- /dev/null
+++ b/flytekit/extras/pytorch/__init__.py
@@ -0,0 +1,20 @@
+"""
+Flytekit PyTorch
+=========================================
+.. currentmodule:: flytekit.extras.pytorch
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   PyTorchCheckpoint
+"""
+from flytekit.loggers import logger
+
+try:
+    from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
+    from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
+except ImportError:
+    logger.info(
+        "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
+    )
diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py
new file mode 100644
index 0000000000..c7561f13f4
--- /dev/null
+++ b/flytekit/extras/pytorch/checkpoint.py
@@ -0,0 +1,137 @@
+import pathlib
+import typing
+from dataclasses import asdict, dataclass, fields, is_dataclass
+from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union
+
+import torch
+from dataclasses_json import dataclass_json
+from typing_extensions import Protocol
+
+from flytekit.core.context_manager import FlyteContext
+from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
+from flytekit.models.core import types as _core_types
+from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
+from flytekit.models.types import LiteralType
+
+
+class IsDataclass(Protocol):
+    __dataclass_fields__: Dict
+    __dataclass_params__: Dict
+    __post_init__: Optional[Callable]
+
+
+@dataclass_json
+@dataclass
+class PyTorchCheckpoint:
+    """
+    Dataclass that bundles the contents of a checkpoint: the module (model), its optimizer, and the training hyperparameters.
+    """
+
+    module: Optional[torch.nn.Module] = None
+    hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
+    optimizer: Optional[torch.optim.Optimizer] = None
+
+    def __post_init__(self):
+        if not (
+            isinstance(self.hyperparameters, dict)
+            or (is_dataclass(self.hyperparameters) and not isinstance(self.hyperparameters, type))
+            or (isinstance(self.hyperparameters, tuple) and hasattr(self.hyperparameters, "_fields"))
+            or (self.hyperparameters is None)
+        ):
+            raise TypeTransformerFailedError(
+                f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}"
+            )
+
+        if not (self.module or self.hyperparameters or self.optimizer):
+            raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer")
+
+
+class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
+    """
+    TypeTransformer that serializes and deserializes a PyTorchCheckpoint to and from a single blob.
+ """ + + PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint" + + def __init__(self): + super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint) + + def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: PyTorchCheckpoint, + python_type: Type[PyTorchCheckpoint], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".pt" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + to_save = {} + for field in fields(python_val): + value = getattr(python_val, field.name) + + if value and field.name in ["module", "optimizer"]: + to_save[field.name + "_state_dict"] = getattr(value, "state_dict")() + elif value and field.name == "hyperparameters": + if isinstance(value, dict): + to_save.update(value) + elif isinstance(value, tuple): + to_save.update(value._asdict()) + elif is_dataclass(value): + to_save.update(asdict(value)) + + if not to_save: + raise TypeTransformerFailedError(f"Cannot save empty {python_val}") + + # save checkpoint to a file + torch.save(to_save, local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint] + ) -> PyTorchCheckpoint: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # cpu <-> gpu conversion + if torch.cuda.is_available(): + map_location = "cuda:0" + else: + map_location = torch.device("cpu") + + # load checkpoint from a file + return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location)) + + def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT + ): + return PyTorchCheckpoint + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(PyTorchCheckpointTransformer()) diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py new file mode 100644 index 0000000000..4cf37871fb --- /dev/null +++ b/flytekit/extras/pytorch/native.py @@ -0,0 +1,92 @@ +import pathlib +from typing import Generic, Type, TypeVar + +import torch + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + +T = TypeVar("T") + + +class PyTorchTypeTransformer(TypeTransformer, Generic[T]): + def get_literal_type(self, t: Type[T]) -> LiteralType: + return LiteralType( + 
blob=_core_types.BlobType( + format=self.PYTORCH_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: T, + python_type: Type[T], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.PYTORCH_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".pt" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save pytorch tensor/module to a file + torch.save(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # cpu <-> gpu conversion + if torch.cuda.is_available(): + map_location = "cuda:0" + else: + map_location = torch.device("cpu") + + # load pytorch tensor/module from a file + return torch.load(local_path, map_location=map_location) + + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_FORMAT + ): + return T + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): + PYTORCH_FORMAT = "PyTorchTensor" + + def __init__(self): + super().__init__(name="PyTorch Tensor", t=torch.Tensor) + + +class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]): + PYTORCH_FORMAT = "PyTorchModule" + + def __init__(self): + super().__init__(name="PyTorch Module", t=torch.nn.Module) + + +TypeEngine.register(PyTorchTensorTransformer()) +TypeEngine.register(PyTorchModuleTransformer()) diff --git a/flytekit/types/numpy/__init__.py b/flytekit/types/numpy/__init__.py index 83771c1152..ec20e87970 100644 --- a/flytekit/types/numpy/__init__.py +++ b/flytekit/types/numpy/__init__.py @@ -1,13 +1 @@ -""" -Flytekit Numpy -============== -.. currentmodule:: flytekit.types.numpy - -.. 
autosummary:: - :template: custom.rst - :toctree: generated/ - - NumpyArrayTransformer -""" - from .ndarray import NumpyArrayTransformer diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 3dd8b06235..52577a650d 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -10,7 +10,6 @@ StructuredDataset StructuredDatasetEncoder StructuredDatasetDecoder - StructuredDatasetTransformerEngine """ diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 0f6985b1ed..c93c56435c 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -142,7 +142,6 @@ pyparsing==3.0.9 # packaging python-dateutil==2.8.2 # via - # arrow # croniter # flytekit # matplotlib diff --git a/tests/flytekit/unit/extras/pytorch/__init__.py b/tests/flytekit/unit/extras/pytorch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/pytorch/test_checkpoint.py b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py new file mode 100644 index 0000000000..49ad083285 --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py @@ -0,0 +1,105 @@ +from dataclasses import asdict, dataclass +from typing import NamedTuple + +import pytest +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from dataclasses_json import dataclass_json + +from flytekit import task, workflow +from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.extras.pytorch import PyTorchCheckpoint + + +@dataclass_json +@dataclass +class Hyperparameters: + epochs: int + loss: float + + +class TupleHyperparameters(NamedTuple): + epochs: int + loss: float + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +@task +def generate_model_dict(hyperparameters: Hyperparameters) -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=asdict(hyperparameters), optimizer=optimizer) + + +@task +def generate_model_tuple() -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=TupleHyperparameters(epochs=5, loss=0.4), optimizer=optimizer) + + +@task +def generate_model_dataclass(hyperparameters: Hyperparameters) -> PyTorchCheckpoint: + bn = Net() + optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9) + return PyTorchCheckpoint(module=bn, hyperparameters=hyperparameters, optimizer=optimizer) + + +@task +def generate_model_only_module() -> PyTorchCheckpoint: + bn = Net() + return PyTorchCheckpoint(module=bn) + + +@task +def empty_checkpoint(): + with pytest.raises(TypeTransformerFailedError): + return PyTorchCheckpoint() + + +@task +def t1(checkpoint: PyTorchCheckpoint): + new_bn = Net() + 
new_bn.load_state_dict(checkpoint["module_state_dict"]) + optimizer = optim.SGD(new_bn.parameters(), lr=0.001, momentum=0.9) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + assert checkpoint["epochs"] == 5 + assert checkpoint["loss"] == 0.4 + + +@workflow +def wf(): + checkpoint_dict = generate_model_dict(hyperparameters=Hyperparameters(epochs=5, loss=0.4)) + checkpoint_tuple = generate_model_tuple() + checkpoint_dataclass = generate_model_dataclass(hyperparameters=Hyperparameters(epochs=5, loss=0.4)) + t1(checkpoint=checkpoint_dict) + t1(checkpoint=checkpoint_tuple) + t1(checkpoint=checkpoint_dataclass) + generate_model_only_module() + empty_checkpoint() + + +@workflow +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/pytorch/test_native.py b/tests/flytekit/unit/extras/pytorch/test_native.py new file mode 100644 index 0000000000..9d44ed1c1f --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_native.py @@ -0,0 +1,73 @@ +import torch + +from flytekit import task, workflow + + +@task +def generate_tensor_1d() -> torch.Tensor: + return torch.zeros(5, dtype=torch.int32) + + +@task +def generate_tensor_2d() -> torch.Tensor: + return torch.tensor([[1.0, -1.0, 2], [1.0, -1.0, 9], [0, 7.0, 3]]) + + +@task +def generate_module() -> torch.nn.Module: + bn = torch.nn.BatchNorm1d(3, track_running_stats=True) + return bn + + +class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.l0 = torch.nn.Linear(4, 2) + self.l1 = torch.nn.Linear(2, 1) + + def forward(self, input): + out0 = self.l0(input) + out0_relu = torch.nn.functional.relu(out0) + return self.l1(out0_relu) + + +@task +def generate_model() -> torch.nn.Module: + return MyModel() + + +@task +def t1(tensor: torch.Tensor) -> torch.Tensor: + assert tensor.dtype == torch.int32 + tensor[0] = 1 + return tensor + + +@task +def t2(tensor: torch.Tensor) -> torch.Tensor: + # convert 2D to 3D + tensor.unsqueeze_(-1) + return tensor.expand(3, 3, 2) + + +@task +def t3(model: torch.nn.Module) -> torch.Tensor: + return model.weight + + +@task +def t4(model: torch.nn.Module) -> torch.nn.Module: + return model.l1 + + +@workflow +def wf(): + t1(tensor=generate_tensor_1d()) + t2(tensor=generate_tensor_2d()) + t3(model=generate_module()) + t4(model=MyModel()) + + +@workflow +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py new file mode 100644 index 0000000000..1a3a83ab93 --- /dev/null +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -0,0 +1,130 @@ +from collections import OrderedDict + +import pytest +import torch + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.pytorch import ( + PyTorchCheckpoint, + PyTorchCheckpointTransformer, + PyTorchModuleTransformer, + PyTorchTensorTransformer, +) +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +@pytest.mark.parametrize( + "transformer,python_type,format", + [ + 
(PyTorchTensorTransformer(), torch.Tensor, PyTorchTensorTransformer.PYTORCH_FORMAT), + (PyTorchModuleTransformer(), torch.nn.Module, PyTorchModuleTransformer.PYTORCH_FORMAT), + (PyTorchCheckpointTransformer(), PyTorchCheckpoint, PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT), + ], +) +def test_get_literal_type(transformer, python_type, format): + tf = transformer + lt = tf.get_literal_type(python_type) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + +@pytest.mark.parametrize( + "transformer,python_type,format,python_val", + [ + ( + PyTorchTensorTransformer(), + torch.Tensor, + PyTorchTensorTransformer.PYTORCH_FORMAT, + torch.tensor([[1, 2], [3, 4]]), + ), + ( + PyTorchModuleTransformer(), + torch.nn.Module, + PyTorchModuleTransformer.PYTORCH_FORMAT, + torch.nn.Linear(2, 2), + ), + ( + PyTorchCheckpointTransformer(), + PyTorchCheckpoint, + PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT, + PyTorchCheckpoint( + module=torch.nn.Linear(2, 2), + hyperparameters={"epochs": 10, "batch_size": 32}, + optimizer=torch.optim.Adam(torch.nn.Linear(2, 2).parameters()), + ), + ), + ], +) +def test_to_python_value_and_literal(transformer, python_type, format, python_val): + ctx = context_manager.FlyteContext.current_context() + tf = transformer + python_val = python_val + lt = tf.get_literal_type(python_type) + + lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=format, + dimensionality=BlobType.BlobDimensionality.SINGLE, + ) + ) + assert lv.scalar.blob.uri is not None + + output = tf.to_python_value(ctx, lv, python_type) + if isinstance(python_val, torch.Tensor): + assert torch.equal(output, python_val) + elif isinstance(python_val, torch.nn.Module): + for p1, p2 in zip(output.parameters(), python_val.parameters()): + if p1.data.ne(p2.data).sum() > 0: + assert False + assert True + else: + assert isinstance(output, dict) + + +def test_example_tensor(): + @task + def t1(array: torch.Tensor) -> torch.Tensor: + return torch.flatten(array) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is PyTorchTensorTransformer.PYTORCH_FORMAT + + +def test_example_module(): + @task + def t1() -> torch.nn.Module: + return torch.nn.BatchNorm1d(3, track_running_stats=True) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is PyTorchModuleTransformer.PYTORCH_FORMAT + + +def test_example_checkpoint(): + @task + def t1() -> PyTorchCheckpoint: + return PyTorchCheckpoint( + module=torch.nn.Linear(2, 2), + hyperparameters={"epochs": 10, "batch_size": 32}, + optimizer=torch.optim.Adam(torch.nn.Linear(2, 2).parameters()), + ) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert ( + task_spec.template.interface.outputs["o0"].type.blob.format + is PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT + )
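
A minimal usage sketch of the new transformers (illustrative, not part of the diff; the task and workflow names are hypothetical and it assumes flytekit is installed with torch available). With this change, tasks can pass torch.Tensor, torch.nn.Module, and PyTorchCheckpoint values natively, and a checkpoint deserializes to a dict keyed by "module_state_dict", "optimizer_state_dict", and the saved hyperparameters:

import torch
import torch.nn as nn
import torch.optim as optim

from flytekit import task, workflow
from flytekit.extras.pytorch import PyTorchCheckpoint


@task
def train() -> PyTorchCheckpoint:
    model = nn.Linear(2, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # hyperparameters may be a dict, a dataclass, or a NamedTuple
    return PyTorchCheckpoint(module=model, hyperparameters={"epochs": 10}, optimizer=optimizer)


@task
def predict(checkpoint: PyTorchCheckpoint) -> torch.Tensor:
    model = nn.Linear(2, 2)
    # the materialized checkpoint is a dict of state dicts plus hyperparameters
    model.load_state_dict(checkpoint["module_state_dict"])
    return model(torch.ones(2))


@workflow
def wf() -> torch.Tensor:
    return predict(checkpoint=train())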