Skip to content

Commit

Permalink
ONNX Plugin (#804)
Browse files Browse the repository at this point in the history
* fix isort errors

Signed-off-by: Samhita Alla <[email protected]>

* modified the plugin structure

Signed-off-by: Samhita Alla <[email protected]>

* resolve merge conflicts

Signed-off-by: Samhita Alla <[email protected]>

* added more parameters and cleaned up code

Signed-off-by: Samhita Alla <[email protected]>

* add readme

Signed-off-by: Samhita Alla <[email protected]>

* modified package names

Signed-off-by: Samhita Alla <[email protected]>

* update pythonbuild; add ONNXFile

Signed-off-by: Samhita Alla <[email protected]>

* wip

Signed-off-by: Samhita Alla <[email protected]>

* update requirements

Signed-off-by: Samhita Alla <[email protected]>

* exclude tests on python 3.10

Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored and wild-endeavor committed Aug 2, 2022
1 parent fc133a0 commit 8eac743
Show file tree
Hide file tree
Showing 28 changed files with 1,564 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ jobs:
- flytekit-kf-pytorch
- flytekit-kf-tensorflow
- flytekit-modin
- flytekit-onnx-pytorch
- flytekit-onnx-scikitlearn
- flytekit-onnx-tensorflow
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
Expand All @@ -92,6 +95,14 @@ jobs:
# https://github.com/great-expectations/great_expectations/blob/develop/setup.py#L87-L89
- python-version: 3.10
plugin-names: "flytekit-greatexpectations"
# onnxruntime does not support python 3.10 yet
# https://github.com/microsoft/onnxruntime/issues/9782
- python-version: 3.10
plugin-names: "flytekit-onnx-pytorch"
- python-version: 3.10
plugin-names: "flytekit-onnx-scikitlearn"
- python-version: 3.10
plugin-names: "flytekit-onnx-tensorflow"
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ docs/source/plugins/generated/
.pytest_flyte
htmlcov
*.ipynb
*dat
5 changes: 5 additions & 0 deletions flytekit/types/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@
#: Can be used to receive or return a CSVFile. The underlying type is a FlyteFile type. This is just a
#: decoration and useful for attaching content type information with the file and automatically documenting code.
CSVFile = FlyteFile[csv]

onnx = typing.TypeVar("onnx")
#: Can be used to receive or return an ONNXFile. The underlying type is a FlyteFile type. This is just a
#: decoration and useful for attaching content type information with the file and automatically documenting code.
ONNXFile = FlyteFile[onnx]
9 changes: 9 additions & 0 deletions plugins/flytekit-onnx-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit ONNX PyTorch Plugin

This plugin allows you to generate ONNX models from your PyTorch models.

To install the plugin, run the following command:

```
pip install flytekitplugins-onnxpytorch
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .schema import PyTorch2ONNX, PyTorch2ONNXConfig
124 changes: 124 additions & 0 deletions plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from dataclasses_json import dataclass_json
from torch.onnx import OperatorExportTypes, TrainingMode
from typing_extensions import Annotated, get_args, get_origin

from flytekit import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import ONNXFile


@dataclass_json
@dataclass
class PyTorch2ONNXConfig:
args: Union[Tuple, torch.Tensor]
export_params: bool = True
verbose: bool = False
training: TrainingMode = TrainingMode.EVAL
opset_version: int = 9
input_names: List[str] = field(default_factory=list)
output_names: List[str] = field(default_factory=list)
operator_export_type: Optional[OperatorExportTypes] = None
do_constant_folding: bool = False
dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = field(default_factory=dict)
keep_initializers_as_inputs: Optional[bool] = None
custom_opsets: Dict[str, int] = field(default_factory=dict)
export_modules_as_functions: Union[bool, set[Type]] = False


@dataclass_json
@dataclass
class PyTorch2ONNX:
model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction] = field(default=None)


def extract_config(t: Type[PyTorch2ONNX]) -> Tuple[Type[PyTorch2ONNX], PyTorch2ONNXConfig]:
config = None
if get_origin(t) is Annotated:
base_type, config = get_args(t)
if isinstance(config, PyTorch2ONNXConfig):
return base_type, config
else:
raise TypeTransformerFailedError(f"{t}'s config isn't of type PyTorch2ONNXConfig")
return t, config


def to_onnx(ctx, model, config):
local_path = ctx.file_access.get_random_local_path()

torch.onnx.export(
model,
**config,
f=local_path,
)

return local_path


class PyTorch2ONNXTransformer(TypeTransformer[PyTorch2ONNX]):
ONNX_FORMAT = "onnx"

def __init__(self):
super().__init__(name="PyTorch ONNX", t=PyTorch2ONNX)

def get_literal_type(self, t: Type[PyTorch2ONNX]) -> LiteralType:
return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE))

def to_literal(
self,
ctx: FlyteContext,
python_val: PyTorch2ONNX,
python_type: Type[PyTorch2ONNX],
expected: LiteralType,
) -> Literal:
python_type, config = extract_config(python_type)

if config:
local_path = to_onnx(ctx, python_val.model, config.__dict__.copy())
remote_path = ctx.file_access.get_random_remote_path()
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
else:
raise TypeTransformerFailedError(f"{python_type}'s config is None")

return Literal(
scalar=Scalar(
blob=Blob(
uri=remote_path,
metadata=BlobMetadata(
type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)
),
)
)
)

def to_python_value(
self,
ctx: FlyteContext,
lv: Literal,
expected_python_type: Type[ONNXFile],
) -> ONNXFile:
if not (lv.scalar.blob.uri and lv.scalar.blob.metadata.format == self.ONNX_FORMAT):
raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}")

return ONNXFile(path=lv.scalar.blob.uri)

def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorch2ONNX]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.ONNX_FORMAT
):
return PyTorch2ONNX

raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(PyTorch2ONNXTransformer())
5 changes: 5 additions & 0 deletions plugins/flytekit-onnx-pytorch/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.
-e file:.#egg=flytekitplugins-onnxpytorch
onnxruntime
pillow
torchvision>=0.12.0
199 changes: 199 additions & 0 deletions plugins/flytekit-onnx-pytorch/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
#
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# pip-compile requirements.in
#
-e file:.#egg=flytekitplugins-onnxpytorch
# via -r requirements.in
arrow==1.2.2
# via jinja2-time
binaryornot==0.4.4
# via cookiecutter
certifi==2022.6.15
# via requests
cffi==1.15.1
# via cryptography
chardet==5.0.0
# via binaryornot
charset-normalizer==2.1.0
# via requests
click==8.1.3
# via
# cookiecutter
# flytekit
cloudpickle==2.1.0
# via flytekit
cookiecutter==2.1.1
# via flytekit
croniter==1.3.5
# via flytekit
cryptography==37.0.4
# via pyopenssl
dataclasses-json==0.5.7
# via flytekit
decorator==5.1.1
# via retry
deprecated==1.2.13
# via flytekit
diskcache==5.4.0
# via flytekit
docker==5.0.3
# via flytekit
docker-image-py==0.1.12
# via flytekit
docstring-parser==0.14.1
# via flytekit
flatbuffers==2.0
# via onnxruntime
flyteidl==1.1.8
# via flytekit
flytekit==1.1.0
# via flytekitplugins-onnxpytorch
googleapis-common-protos==1.56.3
# via
# flyteidl
# grpcio-status
grpcio==1.47.0
# via
# flytekit
# grpcio-status
grpcio-status==1.47.0
# via flytekit
idna==3.3
# via requests
importlib-metadata==4.12.0
# via
# flytekit
# keyring
jinja2==3.1.2
# via
# cookiecutter
# jinja2-time
jinja2-time==0.2.0
# via cookiecutter
keyring==23.6.0
# via flytekit
markupsafe==2.1.1
# via jinja2
marshmallow==3.17.0
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
mypy-extensions==0.4.3
# via typing-inspect
natsort==8.1.0
# via flytekit
numpy==1.23.0
# via
# onnxruntime
# pandas
# pyarrow
# torchvision
onnxruntime==1.11.1
# via -r requirements.in
packaging==21.3
# via marshmallow
pandas==1.4.3
# via flytekit
pillow==9.2.0
# via
# -r requirements.in
# torchvision
protobuf==3.20.1
# via
# flyteidl
# flytekit
# googleapis-common-protos
# grpcio-status
# onnxruntime
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via flyteidl
py==1.11.0
# via retry
pyarrow==6.0.1
# via flytekit
pycparser==2.21
# via cffi
pyopenssl==22.0.0
# via flytekit
pyparsing==3.0.9
# via packaging
python-dateutil==2.8.2
# via
# arrow
# croniter
# flytekit
# pandas
python-json-logger==2.0.2
# via flytekit
python-slugify==6.1.2
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
pytz==2022.1
# via
# flytekit
# pandas
pyyaml==6.0
# via
# cookiecutter
# flytekit
regex==2022.6.2
# via docker-image-py
requests==2.28.1
# via
# cookiecutter
# docker
# flytekit
# responses
# torchvision
responses==0.21.0
# via flytekit
retry==0.9.2
# via flytekit
six==1.16.0
# via
# grpcio
# python-dateutil
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
# via flytekit
text-unidecode==1.3
# via python-slugify
torch==1.12.0
# via
# flytekitplugins-onnxpytorch
# torchvision
torchvision==0.13.0
# via -r requirements.in
typing-extensions==4.3.0
# via
# flytekit
# torch
# torchvision
# typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.10
# via
# flytekit
# requests
# responses
websocket-client==1.3.3
# via docker
wheel==0.37.1
# via flytekit
wrapt==1.14.1
# via
# deprecated
# flytekit
zipp==3.8.0
# via importlib-metadata
Loading

0 comments on commit 8eac743

Please sign in to comment.