From edfec7e62179aaa259a72c3298eff4b4327f4955 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 13 Jul 2022 12:48:09 +0530 Subject: [PATCH] ONNX Plugin (#804) * fix isort errors Signed-off-by: Samhita Alla * modified the plugin structure Signed-off-by: Samhita Alla * resolve merge conflicts Signed-off-by: Samhita Alla * added more parameters and cleaned up code Signed-off-by: Samhita Alla * add readme Signed-off-by: Samhita Alla * modified package names Signed-off-by: Samhita Alla * update pythonbuild; add ONNXFile Signed-off-by: Samhita Alla * wip Signed-off-by: Samhita Alla * update requirements Signed-off-by: Samhita Alla * exclude tests on python 3.10 Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 11 + .gitignore | 1 + flytekit/types/file/__init__.py | 5 + plugins/flytekit-onnx-pytorch/README.md | 9 + .../flytekitplugins/onnxpytorch/__init__.py | 1 + .../flytekitplugins/onnxpytorch/schema.py | 124 ++++++++ plugins/flytekit-onnx-pytorch/requirements.in | 5 + .../flytekit-onnx-pytorch/requirements.txt | 199 ++++++++++++ plugins/flytekit-onnx-pytorch/setup.py | 34 +++ .../flytekit-onnx-pytorch/tests/__init__.py | 0 .../tests/test_onnx_pytorch.py | 128 ++++++++ plugins/flytekit-onnx-scikitlearn/README.md | 9 + .../onnxscikitlearn/__init__.py | 1 + .../flytekitplugins/onnxscikitlearn/schema.py | 141 +++++++++ .../flytekit-onnx-scikitlearn/requirements.in | 3 + .../requirements.txt | 212 +++++++++++++ plugins/flytekit-onnx-scikitlearn/setup.py | 36 +++ .../tests/__init__.py | 0 .../tests/test_onnx_scikitlearn.py | 113 +++++++ plugins/flytekit-onnx-tensorflow/README.md | 9 + .../onnxtensorflow/__init__.py | 1 + .../flytekitplugins/onnxtensorflow/schema.py | 116 +++++++ .../flytekit-onnx-tensorflow/requirements.in | 4 + .../flytekit-onnx-tensorflow/requirements.txt | 285 ++++++++++++++++++ plugins/flytekit-onnx-tensorflow/setup.py | 36 +++ .../tests/__init__.py | 0 .../tests/test_onnx_tf.py | 78 +++++ plugins/setup.py | 3 + 28 files changed, 1564 insertions(+) create mode 100644 plugins/flytekit-onnx-pytorch/README.md create mode 100644 plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py create mode 100644 plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py create mode 100644 plugins/flytekit-onnx-pytorch/requirements.in create mode 100644 plugins/flytekit-onnx-pytorch/requirements.txt create mode 100644 plugins/flytekit-onnx-pytorch/setup.py create mode 100644 plugins/flytekit-onnx-pytorch/tests/__init__.py create mode 100644 plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py create mode 100644 plugins/flytekit-onnx-scikitlearn/README.md create mode 100644 plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py create mode 100644 plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py create mode 100644 plugins/flytekit-onnx-scikitlearn/requirements.in create mode 100644 plugins/flytekit-onnx-scikitlearn/requirements.txt create mode 100644 plugins/flytekit-onnx-scikitlearn/setup.py create mode 100644 plugins/flytekit-onnx-scikitlearn/tests/__init__.py create mode 100644 plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py create mode 100644 plugins/flytekit-onnx-tensorflow/README.md create mode 100644 plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py create mode 100644 plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py create mode 100644 plugins/flytekit-onnx-tensorflow/requirements.in create mode 100644 plugins/flytekit-onnx-tensorflow/requirements.txt create mode 100644 plugins/flytekit-onnx-tensorflow/setup.py create mode 100644 plugins/flytekit-onnx-tensorflow/tests/__init__.py create mode 100644 plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index ca6c85bb3e..88001991f7 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -77,6 +77,9 @@ jobs: - flytekit-kf-pytorch - flytekit-kf-tensorflow - flytekit-modin + - flytekit-onnx-pytorch + - flytekit-onnx-scikitlearn + - flytekit-onnx-tensorflow - flytekit-pandera - flytekit-papermill - flytekit-polars @@ -92,6 +95,14 @@ jobs: # https://github.com/great-expectations/great_expectations/blob/develop/setup.py#L87-L89 - python-version: 3.10 plugin-names: "flytekit-greatexpectations" + # onnxruntime does not support python 3.10 yet + # https://github.com/microsoft/onnxruntime/issues/9782 + - python-version: 3.10 + plugin-names: "flytekit-onnx-pytorch" + - python-version: 3.10 + plugin-names: "flytekit-onnx-scikitlearn" + - python-version: 3.10 + plugin-names: "flytekit-onnx-tensorflow" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index b3566d8032..2bd9a0161f 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ docs/source/plugins/generated/ .pytest_flyte htmlcov *.ipynb +*dat diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 44841d7e35..9e8fca1971 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -75,3 +75,8 @@ #: Can be used to receive or return a CSVFile. The underlying type is a FlyteFile type. This is just a #: decoration and useful for attaching content type information with the file and automatically documenting code. CSVFile = FlyteFile[csv] + +onnx = typing.TypeVar("onnx") +#: Can be used to receive or return an ONNXFile. The underlying type is a FlyteFile type. This is just a +#: decoration and useful for attaching content type information with the file and automatically documenting code. +ONNXFile = FlyteFile[onnx] diff --git a/plugins/flytekit-onnx-pytorch/README.md b/plugins/flytekit-onnx-pytorch/README.md new file mode 100644 index 0000000000..48bc736854 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX PyTorch Plugin + +This plugin allows you to generate ONNX models from your PyTorch models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxpytorch +``` diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py new file mode 100644 index 0000000000..384fb1cab3 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/__init__.py @@ -0,0 +1 @@ +from .schema import PyTorch2ONNX, PyTorch2ONNXConfig diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py new file mode 100644 index 0000000000..7031867c8d --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +from dataclasses_json import dataclass_json +from torch.onnx import OperatorExportTypes, TrainingMode +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class PyTorch2ONNXConfig: + args: Union[Tuple, torch.Tensor] + export_params: bool = True + verbose: bool = False + training: TrainingMode = TrainingMode.EVAL + opset_version: int = 9 + input_names: List[str] = field(default_factory=list) + output_names: List[str] = field(default_factory=list) + operator_export_type: Optional[OperatorExportTypes] = None + do_constant_folding: bool = False + dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = field(default_factory=dict) + keep_initializers_as_inputs: Optional[bool] = None + custom_opsets: Dict[str, int] = field(default_factory=dict) + export_modules_as_functions: Union[bool, set[Type]] = False + + +@dataclass_json +@dataclass +class PyTorch2ONNX: + model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction] = field(default=None) + + +def extract_config(t: Type[PyTorch2ONNX]) -> Tuple[Type[PyTorch2ONNX], PyTorch2ONNXConfig]: + config = None + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, PyTorch2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type PyTorch2ONNXConfig") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + torch.onnx.export( + model, + **config, + f=local_path, + ) + + return local_path + + +class PyTorch2ONNXTransformer(TypeTransformer[PyTorch2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="PyTorch ONNX", t=PyTorch2ONNX) + + def get_literal_type(self, t: Type[PyTorch2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: PyTorch2ONNX, + python_type: Type[PyTorch2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + remote_path = ctx.file_access.get_random_remote_path() + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not (lv.scalar.blob.uri and lv.scalar.blob.metadata.format == self.ONNX_FORMAT): + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorch2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return PyTorch2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(PyTorch2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-pytorch/requirements.in b/plugins/flytekit-onnx-pytorch/requirements.in new file mode 100644 index 0000000000..7632919db4 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/requirements.in @@ -0,0 +1,5 @@ +. +-e file:.#egg=flytekitplugins-onnxpytorch +onnxruntime +pillow +torchvision>=0.12.0 diff --git a/plugins/flytekit-onnx-pytorch/requirements.txt b/plugins/flytekit-onnx-pytorch/requirements.txt new file mode 100644 index 0000000000..fa8ac445af --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/requirements.txt @@ -0,0 +1,199 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxpytorch + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==2.0 + # via onnxruntime +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxpytorch +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.6.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # onnxruntime + # pandas + # pyarrow + # torchvision +onnxruntime==1.11.1 + # via -r requirements.in +packaging==21.3 + # via marshmallow +pandas==1.4.3 + # via flytekit +pillow==9.2.0 + # via + # -r requirements.in + # torchvision +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnxruntime + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses + # torchvision +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +six==1.16.0 + # via + # grpcio + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +torch==1.12.0 + # via + # flytekitplugins-onnxpytorch + # torchvision +torchvision==0.13.0 + # via -r requirements.in +typing-extensions==4.3.0 + # via + # flytekit + # torch + # torchvision + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.10 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-onnx-pytorch/setup.py b/plugins/flytekit-onnx-pytorch/setup.py new file mode 100644 index 0000000000..74e3b940ec --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxpytorch" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "torch>=1.11.0"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX PyTorch Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-onnx-pytorch/tests/__init__.py b/plugins/flytekit-onnx-pytorch/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py b/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py new file mode 100644 index 0000000000..3a704a4780 --- /dev/null +++ b/plugins/flytekit-onnx-pytorch/tests/test_onnx_pytorch.py @@ -0,0 +1,128 @@ +# Some standard imports +from pathlib import Path + +import numpy as np +import onnxruntime +import requests +import torch.nn.init as init +import torch.onnx +import torch.utils.model_zoo as model_zoo +import torchvision.transforms as transforms +from flytekitplugins.onnxpytorch import PyTorch2ONNX, PyTorch2ONNXConfig +from PIL import Image +from torch import nn +from typing_extensions import Annotated + +import flytekit +from flytekit import task, workflow +from flytekit.types.file import JPEGImageFile, ONNXFile + + +class SuperResolutionNet(nn.Module): + def __init__(self, upscale_factor, inplace=False): + super(SuperResolutionNet, self).__init__() + + self.relu = nn.ReLU(inplace=inplace) + self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) + self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) + self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) + self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1)) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor) + + self._initialize_weights() + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.relu(self.conv3(x)) + x = self.pixel_shuffle(self.conv4(x)) + return x + + def _initialize_weights(self): + init.orthogonal_(self.conv1.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv2.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv3.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv4.weight) + + +def test_onnx_pytorch(): + @task + def train() -> Annotated[ + PyTorch2ONNX, + PyTorch2ONNXConfig( + args=torch.randn(1, 1, 224, 224, requires_grad=True), + export_params=True, # store the trained parameter weights inside + opset_version=10, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=["input"], # the model's input names + output_names=["output"], # the model's output names + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, # variable length axes + ), + ]: + # Create the super-resolution model by using the above model definition. + torch_model = SuperResolutionNet(upscale_factor=3) + + # Load pretrained model weights + model_url = "https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth" + + # Initialize model with the pretrained weights + map_location = lambda storage, loc: storage # noqa: E731 + if torch.cuda.is_available(): + map_location = None + torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location)) + + return PyTorch2ONNX(model=torch_model) + + @task + def onnx_predict(model_file: ONNXFile) -> JPEGImageFile: + ort_session = onnxruntime.InferenceSession(model_file.download()) + + img = Image.open( + requests.get( + "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg", stream=True + ).raw + ) + + resize = transforms.Resize([224, 224]) + img = resize(img) + + img_ycbcr = img.convert("YCbCr") + img_y, img_cb, img_cr = img_ycbcr.split() + + to_tensor = transforms.ToTensor() + img_y = to_tensor(img_y) + img_y.unsqueeze_(0) + + # compute ONNX Runtime output prediction + ort_inputs = { + ort_session.get_inputs()[0].name: img_y.detach().cpu().numpy() + if img_y.requires_grad + else img_y.cpu().numpy() + } + ort_outs = ort_session.run(None, ort_inputs) + img_out_y = ort_outs[0] + + img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L") + + # get the output image follow post-processing step from PyTorch implementation + final_img = Image.merge( + "YCbCr", + [ + img_out_y, + img_cb.resize(img_out_y.size, Image.BICUBIC), + img_cr.resize(img_out_y.size, Image.BICUBIC), + ], + ).convert("RGB") + + img_path = Path(flytekit.current_context().working_directory) / "cat_superres_with_ort.jpg" + final_img.save(img_path) + + # Save the image, we will compare this with the output image from mobile device + return JPEGImageFile(path=str(img_path)) + + @workflow + def wf() -> JPEGImageFile: + model = train() + return onnx_predict(model_file=model) + + print(wf()) diff --git a/plugins/flytekit-onnx-scikitlearn/README.md b/plugins/flytekit-onnx-scikitlearn/README.md new file mode 100644 index 0000000000..220a157090 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX ScikitLearn Plugin + +This plugin allows you to generate ONNX models from your ScikitLearn models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxscikitlearn +``` diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py new file mode 100644 index 0000000000..d09c317c07 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/__init__.py @@ -0,0 +1 @@ +from .schema import ScikitLearn2ONNX, ScikitLearn2ONNXConfig diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py new file mode 100644 index 0000000000..db50986b5e --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import skl2onnx.common.data_types +from dataclasses_json import dataclass_json +from skl2onnx import convert_sklearn +from sklearn.base import BaseEstimator +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class ScikitLearn2ONNXConfig: + initial_types: List[Tuple[str, Type]] + name: Optional[str] = None + doc_string: str = "" + target_opset: Optional[int] = None + custom_conversion_functions: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + custom_shape_calculators: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + custom_parsers: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict) + options: Dict[Any, Any] = field(default_factory=dict) + intermediate: bool = False + naming: Union[str, Callable[..., Any]] = None + white_op: Optional[Set[str]] = None + black_op: Optional[Set[str]] = None + verbose: int = 0 + final_types: Optional[List[Tuple[str, Type]]] = None + + def __post_init__(self): + validate_initial_types = [ + True for item in self.initial_types if item in inspect.getmembers(skl2onnx.common.data_types) + ] + if not all(validate_initial_types): + raise ValueError("All types in initial_types must be in skl2onnx.common.data_types") + + if self.final_types: + validate_final_types = [ + True for item in self.final_types if item in inspect.getmembers(skl2onnx.common.data_types) + ] + if not all(validate_final_types): + raise ValueError("All types in final_types must be in skl2onnx.common.data_types") + + +@dataclass_json +@dataclass +class ScikitLearn2ONNX: + model: BaseEstimator = field(default=None) + + +def extract_config(t: Type[ScikitLearn2ONNX]) -> Tuple[Type[ScikitLearn2ONNX], ScikitLearn2ONNXConfig]: + config = None + + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, ScikitLearn2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type ScikitLearn2ONNXConfig") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + onx = convert_sklearn(model, **config) + + with open(local_path, "wb") as f: + f.write(onx.SerializeToString()) + + return local_path + + +class ScikitLearn2ONNXTransformer(TypeTransformer[ScikitLearn2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="ScikitLearn ONNX", t=ScikitLearn2ONNX) + + def get_literal_type(self, t: Type[ScikitLearn2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: ScikitLearn2ONNX, + python_type: Type[ScikitLearn2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + remote_path = ctx.file_access.get_random_remote_path() + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not lv.scalar.blob.uri: + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[ScikitLearn2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return ScikitLearn2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(ScikitLearn2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.in b/plugins/flytekit-onnx-scikitlearn/requirements.in new file mode 100644 index 0000000000..bdf7848c4a --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/requirements.in @@ -0,0 +1,3 @@ +. +-e file:.#egg=flytekitplugins-onnxscikitlearn +onnxruntime diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.txt b/plugins/flytekit-onnx-scikitlearn/requirements.txt new file mode 100644 index 0000000000..7ff826a50b --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/requirements.txt @@ -0,0 +1,212 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxscikitlearn + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==2.0 + # via onnxruntime +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxscikitlearn +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.1.0 + # via scikit-learn +keyring==23.6.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # onnx + # onnxconverter-common + # onnxruntime + # pandas + # pyarrow + # scikit-learn + # scipy + # skl2onnx +onnx==1.12.0 + # via + # onnxconverter-common + # skl2onnx +onnxconverter-common==1.9.0 + # via skl2onnx +onnxruntime==1.11.1 + # via -r requirements.in +packaging==21.3 + # via marshmallow +pandas==1.4.3 + # via flytekit +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnx + # onnxconverter-common + # onnxruntime + # protoc-gen-swagger + # skl2onnx +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +scikit-learn==1.1.1 + # via skl2onnx +scipy==1.8.1 + # via + # scikit-learn + # skl2onnx +six==1.16.0 + # via + # grpcio + # python-dateutil +skl2onnx==1.11.2 + # via flytekitplugins-onnxscikitlearn +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +threadpoolctl==3.1.0 + # via scikit-learn +typing-extensions==4.3.0 + # via + # flytekit + # onnx + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.10 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py new file mode 100644 index 0000000000..9815bedaf2 --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxscikitlearn" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "skl2onnx>=1.10.3"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX ScikitLearn Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-onnx-scikitlearn/tests/__init__.py b/plugins/flytekit-onnx-scikitlearn/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py b/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py new file mode 100644 index 0000000000..d6f1617ece --- /dev/null +++ b/plugins/flytekit-onnx-scikitlearn/tests/test_onnx_scikitlearn.py @@ -0,0 +1,113 @@ +from typing import List, NamedTuple + +import numpy +import onnxruntime as rt +import pandas as pd +from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig +from skl2onnx.common._apply_operation import apply_mul +from skl2onnx.common.data_types import FloatTensorType +from skl2onnx.proto import onnx_proto +from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.types.file import ONNXFile + + +def test_onnx_scikitlearn_simple(): + TrainOutput = NamedTuple( + "TrainOutput", + [ + ( + "model", + Annotated[ + ScikitLearn2ONNX, + ScikitLearn2ONNXConfig( + initial_types=[("float_input", FloatTensorType([None, 4]))], + target_opset=12, + ), + ], + ), + ("test", pd.DataFrame), + ], + ) + + @task + def train() -> TrainOutput: + iris = load_iris(as_frame=True) + X, y = iris.data, iris.target + X_train, X_test, y_train, _ = train_test_split(X, y) + model = RandomForestClassifier() + model.fit(X_train, y_train) + + return TrainOutput(test=X_test, model=ScikitLearn2ONNX(model)) + + @task + def predict( + model: ONNXFile, + X_test: pd.DataFrame, + ) -> List[int]: + sess = rt.InferenceSession(model.download()) + input_name = sess.get_inputs()[0].name + label_name = sess.get_outputs()[0].name + pred_onx = sess.run([label_name], {input_name: X_test.to_numpy(dtype=numpy.float32)})[0] + return pred_onx.tolist() + + @workflow + def wf() -> List[int]: + train_output = train() + return predict(model=train_output.model, X_test=train_output.test) + + print(wf()) + + +class CustomTransform(BaseEstimator, TransformerMixin): + def __init__(self): + TransformerMixin.__init__(self) + BaseEstimator.__init__(self) + + def fit(self, X, y, sample_weight=None): + pass + + def transform(self, X): + return X * numpy.array([[0.5, 0.1, 10], [0.5, 0.1, 10]]).T + + +def custom_transform_shape_calculator(operator): + operator.outputs[0].type = FloatTensorType([3, 2]) + + +def custom_tranform_converter(scope, operator, container): + input = operator.inputs[0] + output = operator.outputs[0] + + weights_name = scope.get_unique_variable_name("weights") + atype = onnx_proto.TensorProto.FLOAT + weights = [0.5, 0.1, 10] + shape = [len(weights), 1] + container.add_initializer(weights_name, atype, shape, weights) + apply_mul(scope, [input.full_name, weights_name], output.full_name, container) + + +def test_onnx_scikitlearn(): + @task + def get_model() -> Annotated[ + ScikitLearn2ONNX, + ScikitLearn2ONNXConfig( + initial_types=[("input", FloatTensorType([None, numpy.array([[1, 2], [3, 4], [4, 5]]).shape[1]]))], + custom_shape_calculators={CustomTransform: custom_transform_shape_calculator}, + custom_conversion_functions={CustomTransform: custom_tranform_converter}, + target_opset=12, + ), + ]: + model = CustomTransform() + return ScikitLearn2ONNX(model) + + @workflow + def wf() -> ONNXFile: + return get_model() + + print(wf()) diff --git a/plugins/flytekit-onnx-tensorflow/README.md b/plugins/flytekit-onnx-tensorflow/README.md new file mode 100644 index 0000000000..cd29ede0e1 --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/README.md @@ -0,0 +1,9 @@ +# Flytekit ONNX TensorFlow Plugin + +This plugin allows you to generate ONNX models from your TensorFlow Keras models. + +To install the plugin, run the following command: + +``` +pip install flytekitplugins-onnxtensorflow +``` diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py new file mode 100644 index 0000000000..c359a7893d --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/__init__.py @@ -0,0 +1 @@ +from .schema import TensorFlow2ONNX, TensorFlow2ONNXConfig diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py new file mode 100644 index 0000000000..184083f90a --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import tensorflow as tf +import tf2onnx +from dataclasses_json import dataclass_json +from typing_extensions import Annotated, get_args, get_origin + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core.types import BlobType +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.file import ONNXFile + + +@dataclass_json +@dataclass +class TensorFlow2ONNXConfig: + input_signature: Union[tf.TensorSpec, np.ndarray] + custom_ops: Optional[Dict[str, Any]] = None + target: Optional[List[Any]] = None + custom_op_handlers: Optional[Dict[Any, Tuple]] = None + custom_rewriter: Optional[List[Any]] = None + opset: Optional[int] = None + extra_opset: Optional[List[int]] = None + shape_override: Optional[Dict[str, List[Any]]] = None + inputs_as_nchw: Optional[List[str]] = None + large_model: bool = False + + +@dataclass_json +@dataclass +class TensorFlow2ONNX: + model: tf.keras = field(default=None) + + +def extract_config(t: Type[TensorFlow2ONNX]) -> Tuple[Type[TensorFlow2ONNX], TensorFlow2ONNXConfig]: + config = None + if get_origin(t) is Annotated: + base_type, config = get_args(t) + if isinstance(config, TensorFlow2ONNXConfig): + return base_type, config + else: + raise TypeTransformerFailedError(f"{t}'s config isn't of type TensorFlow2ONNX") + return t, config + + +def to_onnx(ctx, model, config): + local_path = ctx.file_access.get_random_local_path() + + tf2onnx.convert.from_keras(model, **config, output_path=local_path) + + return local_path + + +class TensorFlow2ONNXTransformer(TypeTransformer[TensorFlow2ONNX]): + ONNX_FORMAT = "onnx" + + def __init__(self): + super().__init__(name="TensorFlow ONNX", t=TensorFlow2ONNX) + + def get_literal_type(self, t: Type[TensorFlow2ONNX]) -> LiteralType: + return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)) + + def to_literal( + self, + ctx: FlyteContext, + python_val: TensorFlow2ONNX, + python_type: Type[TensorFlow2ONNX], + expected: LiteralType, + ) -> Literal: + python_type, config = extract_config(python_type) + + if config: + remote_path = ctx.file_access.get_random_remote_path() + local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + else: + raise TypeTransformerFailedError(f"{python_type}'s config is None") + + return Literal( + scalar=Scalar( + blob=Blob( + uri=remote_path, + metadata=BlobMetadata( + type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE) + ), + ) + ) + ) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: Type[ONNXFile], + ) -> ONNXFile: + if not lv.scalar.blob.uri: + raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}") + + return ONNXFile(path=lv.scalar.blob.uri) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TensorFlow2ONNX]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.ONNX_FORMAT + ): + return TensorFlow2ONNX + + raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlow2ONNXTransformer()) diff --git a/plugins/flytekit-onnx-tensorflow/requirements.in b/plugins/flytekit-onnx-tensorflow/requirements.in new file mode 100644 index 0000000000..0752c85f88 --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/requirements.in @@ -0,0 +1,4 @@ +. +-e file:.#egg=flytekitplugins-onnxtensorflow +onnxruntime +pillow diff --git a/plugins/flytekit-onnx-tensorflow/requirements.txt b/plugins/flytekit-onnx-tensorflow/requirements.txt new file mode 100644 index 0000000000..51f649a4ed --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/requirements.txt @@ -0,0 +1,285 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-onnxtensorflow + # via -r requirements.in +absl-py==1.1.0 + # via + # tensorboard + # tensorflow +arrow==1.2.2 + # via jinja2-time +astunparse==1.6.3 + # via tensorflow +binaryornot==0.4.4 + # via cookiecutter +cachetools==5.2.0 + # via google-auth +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +flatbuffers==1.12 + # via + # onnxruntime + # tensorflow + # tf2onnx +flyteidl==1.1.8 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-onnxtensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.9.0 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +googleapis-common-protos==1.56.3 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status + # tensorboard + # tensorflow +grpcio-status==1.47.0 + # via flytekit +h5py==3.7.0 + # via tensorflow +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring + # markdown +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keras==2.9.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +keyring==23.6.0 + # via flytekit +libclang==14.0.1 + # via tensorflow +markdown==3.3.7 + # via tensorboard +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.0 + # via + # h5py + # keras-preprocessing + # onnx + # onnxruntime + # opt-einsum + # pandas + # pyarrow + # tensorboard + # tensorflow + # tf2onnx +oauthlib==3.2.0 + # via requests-oauthlib +onnx==1.12.0 + # via tf2onnx +onnxruntime==1.11.1 + # via -r requirements.in +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via + # marshmallow + # tensorflow +pandas==1.4.3 + # via flytekit +pillow==9.2.0 + # via -r requirements.in +protobuf==3.19.4 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # onnx + # onnxruntime + # protoc-gen-swagger + # tensorboard + # tensorflow +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit +regex==2022.6.2 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # requests-oauthlib + # responses + # tensorboard + # tf2onnx +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +rsa==4.8 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # python-dateutil + # tensorflow + # tf2onnx +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +tensorboard==2.9.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.9.1 + # via flytekitplugins-onnxtensorflow +tensorflow-estimator==2.9.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.26.0 + # via tensorflow +termcolor==1.1.0 + # via tensorflow +text-unidecode==1.3 + # via python-slugify +tf2onnx==1.11.1 + # via flytekitplugins-onnxtensorflow +typing-extensions==4.3.0 + # via + # flytekit + # onnx + # tensorflow + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +werkzeug==2.1.2 + # via tensorboard +wheel==0.37.1 + # via + # astunparse + # flytekit + # tensorboard +wrapt==1.14.1 + # via + # deprecated + # flytekit + # tensorflow +zipp==3.8.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-onnx-tensorflow/setup.py b/plugins/flytekit-onnx-tensorflow/setup.py new file mode 100644 index 0000000000..d2865b083d --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "onnxtensorflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "tf2onnx>=1.9.3", "tensorflow>=2.7.0"] + +__version__ = "0.0.0+develop" + +setup( + name=f"flytekitplugins-{PLUGIN_NAME}", + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="ONNX TensorFlow Plugin for Flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-onnx-tensorflow/tests/__init__.py b/plugins/flytekit-onnx-tensorflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py b/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py new file mode 100644 index 0000000000..259113828a --- /dev/null +++ b/plugins/flytekit-onnx-tensorflow/tests/test_onnx_tf.py @@ -0,0 +1,78 @@ +import urllib +from io import BytesIO +from typing import List, NamedTuple + +import numpy as np +import onnxruntime as rt +import tensorflow as tf +from flytekitplugins.onnxtensorflow import TensorFlow2ONNX, TensorFlow2ONNXConfig +from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input +from tensorflow.keras.preprocessing import image +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.types.file import ONNXFile + + +def test_tf_onnx(): + @task + def load_test_img() -> np.ndarray: + with urllib.request.urlopen( + "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/ade20k.jpg" + ) as url: + img = image.load_img( + BytesIO(url.read()), + target_size=(224, 224), + ) + + x = image.img_to_array(img) + x = np.expand_dims(x, axis=0) + x = preprocess_input(x) + return x + + TrainPredictOutput = NamedTuple( + "TrainPredictOutput", + [ + ("predictions", np.ndarray), + ( + "model", + Annotated[ + TensorFlow2ONNX, + TensorFlow2ONNXConfig( + input_signature=(tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),), opset=13 + ), + ], + ), + ], + ) + + @task + def train_and_predict(img: np.ndarray) -> TrainPredictOutput: + model = ResNet50(weights="imagenet") + + preds = model.predict(img) + return TrainPredictOutput(predictions=preds, model=TensorFlow2ONNX(model)) + + @task + def onnx_predict( + model: ONNXFile, + img: np.ndarray, + ) -> List[np.ndarray]: + m = rt.InferenceSession(model.download(), providers=["CPUExecutionProvider"]) + onnx_pred = m.run([n.name for n in m.get_outputs()], {"input": img}) + + return onnx_pred + + WorkflowOutput = NamedTuple( + "WorkflowOutput", [("keras_predictions", np.ndarray), ("onnx_predictions", List[np.ndarray])] + ) + + @workflow + def wf() -> WorkflowOutput: + img = load_test_img() + train_predict_output = train_and_predict(img=img) + onnx_preds = onnx_predict(model=train_predict_output.model, img=img) + return WorkflowOutput(keras_predictions=train_predict_output.predictions, onnx_predictions=onnx_preds) + + predictions = wf() + np.testing.assert_allclose(predictions.keras_predictions, predictions.onnx_predictions[0], rtol=1e-5) diff --git a/plugins/setup.py b/plugins/setup.py index 8f3cc5c299..bc15144ee1 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -21,6 +21,9 @@ "flytekitplugins-kfpytorch": "flytekit-kf-pytorch", "flytekitplugins-kftensorflow": "flytekit-kf-tensorflow", "flytekitplugins-modin": "flytekit-modin", + "flytekitplugins-onnxscikitlearn": "flytekit-onnx-scikitlearn", + "flytekitplugins-onnxtensorflow": "flytekit-onnx-tensorflow", + "flytekitplugins-onnxpytorch": "flytekit-onnx-pytorch", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", "flytekitplugins-snowflake": "flytekit-snowflake",