TypeTransformers for PyTorch Tensor, Module, and Checkpoint (#1032)
* TypeTransformers for PyTorch Tensor and Module

Signed-off-by: Samhita Alla <[email protected]>

* add torch to requirements

Signed-off-by: Samhita Alla <[email protected]>

* add module as a native type and PyTorchCheckpoint

Signed-off-by: Samhita Alla <[email protected]>

* update requirements

Signed-off-by: Samhita Alla <[email protected]>

* procedural to OOP approach

Signed-off-by: Samhita Alla <[email protected]>

* nit

Signed-off-by: Samhita Alla <[email protected]>

* verify device conversion

Signed-off-by: Samhita Alla <[email protected]>

* verify device conversion

Signed-off-by: Samhita Alla <[email protected]>

* hyperparameters can be None

Signed-off-by: Samhita Alla <[email protected]>

* device conversion

Signed-off-by: Samhita Alla <[email protected]>

* device conversion

Signed-off-by: Samhita Alla <[email protected]>

* checkpoint code cleanup

Signed-off-by: Samhita Alla <[email protected]>

* resolve merge conflict

Signed-off-by: Samhita Alla <[email protected]>

* fix pytorch api reference; resolve merge conflict

Signed-off-by: Samhita Alla <[email protected]>

* fix pytorch import

Signed-off-by: Samhita Alla <[email protected]>
samhita-alla authored and wild-endeavor committed Aug 2, 2022
1 parent edbd900 commit e5f9d88
Showing 17 changed files with 574 additions and 14 deletions.
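
Taken together, these changes register type transformers so that tasks can pass torch.Tensor and torch.nn.Module values natively. A minimal sketch of what the commit enables (the task names and bodies below are illustrative, not part of the diff):

import torch

from flytekit import task, workflow


@task
def generate() -> torch.Tensor:
    # serialized by PyTorchTensorTransformer via torch.save
    return torch.rand(2, 3)


@task
def double(t: torch.Tensor) -> torch.Tensor:
    # deserialized via torch.load with a device-aware map_location
    return t * 2


@workflow
def wf() -> torch.Tensor:
    return double(t=generate())
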
1 change: 1 addition & 0 deletions dev-requirements.in
@@ -11,3 +11,4 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
torch
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -444,6 +444,8 @@ tomli==2.0.1
    #   coverage
    #   mypy
    #   pytest
torch==1.11.0
    # via -r dev-requirements.in
traitlets==5.3.0
    # via
    #   ipython
@@ -458,6 +460,7 @@ typing-extensions==4.3.0
    #   importlib-metadata
    #   mypy
    #   responses
    #   torch
    #   typing-inspect
typing-inspect==0.7.1
    # via
1 change: 1 addition & 0 deletions doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
jupyter # papermill
pyspark # spark
sqlalchemy # sqlalchemy
torch # pytorch
3 changes: 3 additions & 0 deletions doc-requirements.txt
@@ -664,6 +664,8 @@ tinycss2==1.1.1
    # via nbconvert
toolz==0.11.2
    # via altair
torch==1.11.0
    # via -r doc-requirements.in
tornado==6.2
    # via
    #   ipykernel
@@ -703,6 +705,7 @@ typing-extensions==4.3.0
    #   pandera
    #   pydantic
    #   responses
    #   torch
    #   typing-inspect
typing-inspect==0.7.1
    # via
7 changes: 7 additions & 0 deletions docs/source/extras.pytorch.rst
@@ -0,0 +1,7 @@
############
PyTorch Type
############
.. automodule:: flytekit.extras.pytorch
    :no-members:
    :no-inherited-members:
    :no-special-members:
1 change: 1 addition & 0 deletions docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
   types.builtins.structured
   types.builtins.file
   types.builtins.directory
   extras.pytorch
1 change: 1 addition & 0 deletions flytekit/__init__.py
@@ -182,6 +182,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
20 changes: 20 additions & 0 deletions flytekit/extras/pytorch/__init__.py
@@ -0,0 +1,20 @@
"""
Flytekit PyTorch
=========================================
.. currentmodule:: flytekit.extras.pytorch
.. autosummary::
:template: custom.rst
:toctree: generated/
PyTorchCheckpoint
"""
from flytekit.loggers import logger

try:
from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
logger.info(
"We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
)
137 changes: 137 additions & 0 deletions flytekit/extras/pytorch/checkpoint.py
@@ -0,0 +1,137 @@
import pathlib
import typing
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union

import torch
from dataclasses_json import dataclass_json
from typing_extensions import Protocol

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class IsDataclass(Protocol):
    __dataclass_fields__: Dict
    __dataclass_params__: Dict
    __post_init__: Optional[Callable]


@dataclass_json
@dataclass
class PyTorchCheckpoint:
    """
    Dataclass for saving a PyTorch checkpoint: a module, its hyperparameters, and an optimizer.
    At least one of the three fields must be set.
    """

    module: Optional[torch.nn.Module] = None
    hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
    optimizer: Optional[torch.optim.Optimizer] = None

    def __post_init__(self):
        if not (
            isinstance(self.hyperparameters, dict)
            or (is_dataclass(self.hyperparameters) and not isinstance(self.hyperparameters, type))
            or (isinstance(self.hyperparameters, tuple) and hasattr(self.hyperparameters, "_fields"))
            or (self.hyperparameters is None)
        ):
            raise TypeTransformerFailedError(
                f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}"
            )

        if not (self.module or self.hyperparameters or self.optimizer):
            raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer")


class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
    """
    TypeTransformer that serializes and deserializes a PyTorchCheckpoint.
    """

    PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint"

    def __init__(self):
        super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint)

    def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save = {}
        for field in fields(python_val):
            value = getattr(python_val, field.name)

            if value and field.name in ["module", "optimizer"]:
                to_save[field.name + "_state_dict"] = value.state_dict()
            elif value and field.name == "hyperparameters":
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))

        if not to_save:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # save checkpoint to a file
        torch.save(to_save, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint]
    ) -> PyTorchCheckpoint:
        try:
            uri = lv.scalar.blob.uri
        except AttributeError:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=False)

        # cpu <-> gpu conversion
        if torch.cuda.is_available():
            map_location = "cuda:0"
        else:
            map_location = torch.device("cpu")

        # load checkpoint from a file
        return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location))

    def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]:
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT
        ):
            return PyTorchCheckpoint

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(PyTorchCheckpointTransformer())
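
For context, a hedged sketch of how PyTorchCheckpoint is meant to be produced by a task; the Net class and hyperparameter values below are illustrative, not part of this commit:

import torch.nn as nn
import torch.optim as optim

from flytekit import task
from flytekit.extras.pytorch import PyTorchCheckpoint


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


@task
def train() -> PyTorchCheckpoint:
    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # the transformer merges the module/optimizer state_dicts and the
    # hyperparameters into a single dict and saves it with torch.save
    return PyTorchCheckpoint(module=model, hyperparameters={"lr": 0.01}, optimizer=optimizer)
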
92 changes: 92 additions & 0 deletions flytekit/extras/pytorch/native.py
@@ -0,0 +1,92 @@
import pathlib
from typing import Generic, Type, TypeVar

import torch

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

T = TypeVar("T")


class PyTorchTypeTransformer(TypeTransformer, Generic[T]):
    def get_literal_type(self, t: Type[T]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTORCH_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: T,
        python_type: Type[T],
        expected: LiteralType,
    ) -> Literal:
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        # save pytorch tensor/module to a file
        torch.save(python_val, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        try:
            uri = lv.scalar.blob.uri
        except AttributeError:
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=False)

        # cpu <-> gpu conversion
        if torch.cuda.is_available():
            map_location = "cuda:0"
        else:
            map_location = torch.device("cpu")

        # load pytorch tensor/module from a file
        return torch.load(local_path, map_location=map_location)

    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and literal_type.blob.format == self.PYTORCH_FORMAT
        ):
            # return the concrete type this transformer was registered with,
            # rather than the unbound TypeVar T
            return self.python_type

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]):
    PYTORCH_FORMAT = "PyTorchTensor"

    def __init__(self):
        super().__init__(name="PyTorch Tensor", t=torch.Tensor)


class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]):
    PYTORCH_FORMAT = "PyTorchModule"

    def __init__(self):
        super().__init__(name="PyTorch Module", t=torch.nn.Module)


TypeEngine.register(PyTorchTensorTransformer())
TypeEngine.register(PyTorchModuleTransformer())
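
Because the module transformer is registered against torch.nn.Module, user-defined subclasses should flow through it as well, and the map_location logic above lets a module saved on a GPU machine deserialize on a CPU-only one. A sketch under those assumptions (MyModel is illustrative):

import torch

from flytekit import task


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 1)


@task
def get_model() -> torch.nn.Module:
    # torch.save pickles the whole module (not just its state_dict),
    # so the class definition must be importable at load time
    return MyModel()
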
12 changes: 0 additions & 12 deletions flytekit/types/numpy/__init__.py
@@ -1,13 +1 @@
"""
Flytekit Numpy
==============
.. currentmodule:: flytekit.types.numpy
.. autosummary::
:template: custom.rst
:toctree: generated/
NumpyArrayTransformer
"""

from .ndarray import NumpyArrayTransformer
1 change: 0 additions & 1 deletion flytekit/types/structured/__init__.py
@@ -10,7 +10,6 @@
   StructuredDataset
   StructuredDatasetEncoder
   StructuredDatasetDecoder
   StructuredDatasetTransformerEngine
"""


@@ -142,7 +142,6 @@ pyparsing==3.0.9
    #   packaging
python-dateutil==2.8.2
    # via
    #   arrow
    #   croniter
    #   flytekit
    #   matplotlib
Empty file.
