TypeTransformers for PyTorch Tensor, Module, and Checkpoint #1032

Merged
merged 19 commits on Jul 7, 2022
Changes from 18 commits
1 change: 1 addition & 0 deletions dev-requirements.in
@@ -11,3 +11,4 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
torch
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -436,6 +436,8 @@ tomli==2.0.1
# coverage
# mypy
# pytest
torch==1.11.0
# via -r dev-requirements.in
traitlets==5.3.0
# via
# ipython
@@ -450,6 +452,7 @@ typing-extensions==4.3.0
# importlib-metadata
# mypy
# responses
# torch
# typing-inspect
typing-inspect==0.7.1
# via
1 change: 1 addition & 0 deletions doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
jupyter # papermill
pyspark # spark
sqlalchemy # sqlalchemy
torch # pytorch
3 changes: 3 additions & 0 deletions doc-requirements.txt
@@ -661,6 +661,8 @@ tinycss2==1.1.1
# via nbconvert
toolz==0.11.2
# via altair
torch==1.11.0
# via -r doc-requirements.in
tornado==6.1
# via
# ipykernel
@@ -700,6 +702,7 @@ typing-extensions==4.3.0
# pandera
# pydantic
# responses
# torch
# typing-inspect
typing-inspect==0.7.1
# via
7 changes: 7 additions & 0 deletions docs/source/extras.pytorch.rst
@@ -0,0 +1,7 @@
############
PyTorch Type
############
.. automodule:: flytekit.extras.pytorch
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
types.builtins.structured
types.builtins.file
types.builtins.directory
extras.pytorch
1 change: 1 addition & 0 deletions flytekit/__init__.py
@@ -182,6 +182,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
20 changes: 20 additions & 0 deletions flytekit/extras/pytorch/__init__.py
@@ -0,0 +1,20 @@
"""
Flytekit PyTorch
=========================================
.. currentmodule:: flytekit.extras.pytorch

.. autosummary::
:template: custom.rst
:toctree: generated/

PyTorchCheckpoint
"""
from flytekit.loggers import logger

try:
from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
Contributor: We should stick with the full import in the future, just for consistency. Merge as is; I'll update it in the future.
from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
logger.info(
"We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
)
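
For orientation (not part of this diff): with the guarded import above, import flytekit keeps working when torch is absent, and when torch is installed the three transformers register with the TypeEngine on import, so torch types resolve to single-blob literal types. A minimal sketch, assuming torch is installed:

# Illustrative sketch, not part of this PR's diff.
import torch

from flytekit.core.type_engine import TypeEngine
from flytekit.extras.pytorch import PyTorchCheckpoint

print(TypeEngine.to_literal_type(torch.Tensor))       # blob format "PyTorchTensor"
print(TypeEngine.to_literal_type(torch.nn.Module))    # blob format "PyTorchModule"
print(TypeEngine.to_literal_type(PyTorchCheckpoint))  # blob format "PyTorchCheckpoint"
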
141 changes: 141 additions & 0 deletions flytekit/extras/pytorch/checkpoint.py
@@ -0,0 +1,141 @@
import pathlib
import typing
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union

import torch
from dataclasses_json import dataclass_json

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
Contributor: We can always use typing_extensions, right?

Contributor (author): Yep! I'll merge this now but will make sure to modify the import to use typing_extensions in a different PR.

Contributor (author): Modified the import — I had to resolve a merge conflict.



class IsDataclass(Protocol):
__dataclass_fields__: Dict
__dataclass_params__: Dict
__post_init__: Optional[Callable]


@dataclass_json
@dataclass
class PyTorchCheckpoint:
"""
Bundles a torch module, optimizer, and hyperparameters so they can be saved together as a single checkpoint.
"""

module: Optional[torch.nn.Module] = None
hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
optimizer: Optional[torch.optim.Optimizer] = None

def __post_init__(self):
if not (
isinstance(self.hyperparameters, dict)
or (is_dataclass(self.hyperparameters) and not isinstance(self.hyperparameters, type))
or (isinstance(self.hyperparameters, tuple) and hasattr(self.hyperparameters, "_fields"))
or (self.hyperparameters is None)
):
raise TypeTransformerFailedError(
f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}"
)

if not (self.module or self.hyperparameters or self.optimizer):
raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer")


class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
"""
TypeTransformer that serializes and deserializes a PyTorchCheckpoint.
"""

PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint"

def __init__(self):
super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint)

def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

def to_literal(
self,
ctx: FlyteContext,
python_val: PyTorchCheckpoint,
python_type: Type[PyTorchCheckpoint],
expected: LiteralType,
) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

local_path = ctx.file_access.get_random_local_path() + ".pt"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

to_save = {}
for field in fields(python_val):
value = getattr(python_val, field.name)

if value and field.name in ["module", "optimizer"]:
to_save[field.name + "_state_dict"] = value.state_dict()
elif value and field.name == "hyperparameters":
if isinstance(value, dict):
to_save.update(value)
elif isinstance(value, tuple):
to_save.update(value._asdict())
elif is_dataclass(value):
to_save.update(asdict(value))

if not to_save:
raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

# save checkpoint to a file
torch.save(to_save, local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint]
) -> PyTorchCheckpoint:
try:
uri = lv.scalar.blob.uri
except AttributeError:
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)

# cpu <-> gpu conversion
if torch.cuda.is_available():
map_location = "cuda:0"
else:
map_location = torch.device("cpu")

# load checkpoint from a file
return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location))

def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT
):
return PyTorchCheckpoint

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(PyTorchCheckpointTransformer())
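
A hedged usage sketch (not part of this diff) of how PyTorchCheckpoint is meant to flow between tasks; the task names and the Hyperparameters dataclass are illustrative. Because to_python_value hands back the dict produced by torch.load, downstream tasks read the saved state by key:

# Illustrative sketch, not part of this PR's diff.
from dataclasses import dataclass

import torch
import torch.nn as nn

from flytekit import task
from flytekit.extras.pytorch import PyTorchCheckpoint


@dataclass
class Hyperparameters:  # illustrative container; a dict or NamedTuple also works
    epochs: int = 10
    lr: float = 1e-3


@task
def train() -> PyTorchCheckpoint:
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # module and optimizer are saved via their state_dict()s; hyperparameters are flattened in
    return PyTorchCheckpoint(module=model, hyperparameters=Hyperparameters(), optimizer=optimizer)


@task
def resume(checkpoint: PyTorchCheckpoint) -> None:
    # the transformer returns the torch.load result, so the state dicts live under
    # "module_state_dict" / "optimizer_state_dict", with hyperparameters as top-level keys
    model = nn.Linear(4, 2)
    model.load_state_dict(checkpoint["module_state_dict"])
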
92 changes: 92 additions & 0 deletions flytekit/extras/pytorch/native.py
@@ -0,0 +1,92 @@
import pathlib
from typing import Generic, Type, TypeVar

import torch

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

T = TypeVar("T")


class PyTorchTypeTransformer(TypeTransformer, Generic[T]):
def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.PYTORCH_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

def to_literal(
self,
ctx: FlyteContext,
python_val: T,
python_type: Type[T],
expected: LiteralType,
) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTORCH_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

local_path = ctx.file_access.get_random_local_path() + ".pt"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

# save pytorch tensor/module to a file
torch.save(python_val, local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
try:
uri = lv.scalar.blob.uri
except AttributeError:
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)

# cpu <-> gpu conversion
if torch.cuda.is_available():
map_location = "cuda:0"
else:
map_location = torch.device("cpu")

# load pytorch tensor/module from a file
return torch.load(local_path, map_location=map_location)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.PYTORCH_FORMAT
):
return T

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]):
PYTORCH_FORMAT = "PyTorchTensor"

def __init__(self):
super().__init__(name="PyTorch Tensor", t=torch.Tensor)


class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]):
PYTORCH_FORMAT = "PyTorchModule"

def __init__(self):
super().__init__(name="PyTorch Module", t=torch.nn.Module)


TypeEngine.register(PyTorchTensorTransformer())
TypeEngine.register(PyTorchModuleTransformer())
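
Similarly, a sketch (illustrative, assuming torch is installed) of passing tensors and modules directly between tasks now that these transformers are registered:

# Illustrative sketch, not part of this PR's diff.
import torch
import torch.nn as nn

from flytekit import task, workflow


@task
def create_model() -> nn.Module:
    # serialized by PyTorchModuleTransformer as a single-part blob ("PyTorchModule")
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))


@task
def make_input() -> torch.Tensor:
    # serialized by PyTorchTensorTransformer ("PyTorchTensor")
    return torch.randn(1, 8)


@task
def predict(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model(x)


@workflow
def wf() -> torch.Tensor:
    return predict(model=create_model(), x=make_input())
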
12 changes: 0 additions & 12 deletions flytekit/types/numpy/__init__.py
@@ -1,13 +1 @@
"""
Contributor: Is removing this from the docs intentional?

Contributor (author): Yeah. I don't think we'd want to have the transformer in the API reference, because the methods within the TypeTransformer class remain the same.

Flytekit Numpy
==============
.. currentmodule:: flytekit.types.numpy

.. autosummary::
:template: custom.rst
:toctree: generated/

NumpyArrayTransformer
"""

from .ndarray import NumpyArrayTransformer
1 change: 0 additions & 1 deletion flytekit/types/structured/__init__.py
@@ -10,7 +10,6 @@
StructuredDataset
StructuredDatasetEncoder
StructuredDatasetDecoder
StructuredDatasetTransformerEngine
"""


@@ -135,7 +135,6 @@ pyparsing==3.0.9
# packaging
python-dateutil==2.8.2
# via
# arrow
# croniter
# flytekit
# matplotlib
Empty file.