From cf12de172f1408b92adeceddae28f8a255b5ba86 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Fri, 6 Jan 2023 14:52:16 -0800 Subject: [PATCH] Add support MLFlow plugin (#1274) * MLFlow plugin in progress Signed-off-by: Ketan Umare * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * update test Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * update readme Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * dwip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * change experiment name Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * Add mlflow to index.rst Signed-off-by: Kevin Su * use experiment name that user provided Signed-off-by: Kevin Su * update doc-requirements.txt Signed-off-by: Kevin Su * Add backend plugin deployment Signed-off-by: Kevin Su * generate doc for method Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * update docstring Signed-off-by: Niels Bantilan * update docstring Signed-off-by: Niels Bantilan * Update tracking.py Signed-off-by: Niels Bantilan Signed-off-by: Ketan Umare Signed-off-by: Kevin Su Signed-off-by: Niels Bantilan Co-authored-by: Kevin Su Co-authored-by: Niels Bantilan --- doc-requirements.in | 1 + doc-requirements.txt | 255 +++++++++++++--- docs/source/plugins/index.rst | 2 + docs/source/plugins/mlflow.rst | 9 + plugins/flytekit-mlflow/README.md | 22 ++ plugins/flytekit-mlflow/dev-requirements.in | 1 + plugins/flytekit-mlflow/dev-requirements.txt | 122 ++++++++ .../flytekitplugins/mlflow/__init__.py | 13 + .../flytekitplugins/mlflow/tracking.py | 140 +++++++++ plugins/flytekit-mlflow/requirements.in | 3 + plugins/flytekit-mlflow/requirements.txt | 274 ++++++++++++++++++ plugins/flytekit-mlflow/setup.py | 36 +++ plugins/flytekit-mlflow/tests/__init__.py | 0 .../tests/test_mlflow_tracking.py | 32 ++ 14 files changed, 863 insertions(+), 47 deletions(-) create mode 100644 docs/source/plugins/mlflow.rst create mode 100644 plugins/flytekit-mlflow/README.md create mode 100644 plugins/flytekit-mlflow/dev-requirements.in create mode 100644 plugins/flytekit-mlflow/dev-requirements.txt create mode 100644 plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py create mode 100644 plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py create mode 100644 plugins/flytekit-mlflow/requirements.in create mode 100644 plugins/flytekit-mlflow/requirements.txt create mode 100644 plugins/flytekit-mlflow/setup.py create mode 100644 plugins/flytekit-mlflow/tests/__init__.py create mode 100644 plugins/flytekit-mlflow/tests/test_mlflow_tracking.py diff --git a/doc-requirements.in b/doc-requirements.in index 20cdf76ff1..9fa7c50a1b 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -47,3 +47,4 @@ whylabs-client # whylogs ray # ray scikit-learn # scikit-learn vaex # vaex +mlflow # mlflow diff --git a/doc-requirements.txt b/doc-requirements.txt index da61aa525b..314c926e39 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -14,6 +14,8 @@ aiosignal==1.3.1 # via ray alabaster==0.7.12 # via sphinx +alembic==1.9.1 + # via mlflow altair==4.2.0 # via great-expectations ansiwrap==0.8.4 @@ -40,13 +42,13 @@ arrow==1.2.3 # via jinja2-time astroid==2.12.12 # via sphinx-autoapi -astropy==5.1.1 +astropy==5.2 # via vaex-astro asttokens==2.2.1 # via stack-data astunparse==1.6.3 # via tensorflow -attrs==22.1.0 +attrs==22.2.0 # via # jsonschema # ray @@ -63,11 +65,11 @@ beautifulsoup4==4.11.1 # sphinx-material binaryornot==0.4.4 # via cookiecutter -blake3==0.3.1 +blake3==0.3.3 # via vaex-core bleach==5.0.1 # via nbconvert -botocore==1.29.26 +botocore==1.29.44 # via -r doc-requirements.in bqplot==0.12.36 # via vaex-jupyter @@ -91,12 +93,15 @@ chardet==5.0.0 # via binaryornot charset-normalizer==2.1.1 # via requests -click==8.0.4 +click==8.1.3 # via # cookiecutter # dask + # databricks-cli + # flask # flytekit # great-expectations + # mlflow # papermill # ray # sphinx-click @@ -105,6 +110,8 @@ cloudpickle==2.2.0 # via # dask # flytekit + # mlflow + # shap # vaex-core colorama==0.4.6 # via great-expectations @@ -118,7 +125,11 @@ cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit +<<<<<<< HEAD cryptography==38.0.3 +======= +cryptography==39.0.0 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # -r doc-requirements.in # great-expectations @@ -127,8 +138,10 @@ css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib -dask==2022.12.0 +dask==2022.12.1 # via vaex-core +databricks-cli==0.17.4 + # via mlflow dataclasses-json==0.5.7 # via # dolt-integrations @@ -148,7 +161,9 @@ diskcache==5.4.0 distlib==0.3.6 # via virtualenv docker==6.0.1 - # via flytekit + # via + # flytekit + # mlflow docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 @@ -166,6 +181,7 @@ entrypoints==0.4 # via # altair # jupyter-client + # mlflow # papermill executing==1.2.0 # via stack-data @@ -173,14 +189,24 @@ fastapi==0.88.0 # via vaex-server fastjsonschema==2.16.2 # via nbformat +<<<<<<< HEAD filelock==3.8.0 +======= +filelock==3.9.0 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # ray # vaex-core # virtualenv -flatbuffers==22.12.6 +flask==2.2.2 + # via mlflow +flatbuffers==23.1.4 # via tensorflow +<<<<<<< HEAD flyteidl==1.2.9 +======= +flyteidl==1.3.2 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via flytekit fonttools==4.38.0 # via matplotlib @@ -203,11 +229,18 @@ future==0.18.2 # via vaex-core gast==0.5.3 # via tensorflow +<<<<<<< HEAD google-api-core[grpc]==2.8.2 +======= +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via mlflow +google-api-core[grpc]==2.11.0 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # -r doc-requirements.in # google-cloud-bigquery - # google-cloud-bigquery-storage # google-cloud-core google-auth==2.14.1 # via @@ -220,10 +253,12 @@ google-auth-oauthlib==0.4.6 # via tensorboard google-cloud==0.34.0 # via -r doc-requirements.in +<<<<<<< HEAD google-cloud-bigquery==3.1.0 +======= +google-cloud-bigquery==3.4.1 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via -r doc-requirements.in -google-cloud-bigquery-storage==2.16.2 - # via google-cloud-bigquery google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.5.0 @@ -232,12 +267,13 @@ google-pasta==0.2.0 # via tensorflow google-resumable-media==2.4.0 # via google-cloud-bigquery -googleapis-common-protos==1.57.0 +googleapis-common-protos==1.57.1 # via # flyteidl + # flytekit # google-api-core # grpcio-status -great-expectations==0.15.37 +great-expectations==0.15.42 # via -r doc-requirements.in greenlet==2.0.1 # via sqlalchemy @@ -256,6 +292,8 @@ grpcio-status==1.43.0 # -r doc-requirements.in # flytekit # google-api-core +gunicorn==20.1.0 + # via mlflow h11==0.14.0 # via uvicorn h5py==3.7.0 @@ -266,7 +304,7 @@ htmlmin==0.1.12 # via pandas-profiling httptools==0.5.0 # via uvicorn -identify==2.5.9 +identify==2.5.12 # via pre-commit idna==3.4 # via @@ -276,17 +314,23 @@ imagehash==4.3.1 # via visions imagesize==1.4.1 # via sphinx +<<<<<<< HEAD importlib-metadata==5.0.0 +======= +importlib-metadata==5.2.0 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via + # flask # flytekit # great-expectations # keyring # markdown + # mlflow # nbconvert # sphinx ipydatawidgets==4.3.2 # via pythreejs -ipykernel==6.19.2 +ipykernel==6.19.4 # via # ipywidgets # jupyter @@ -298,7 +342,7 @@ ipyleaflet==0.17.2 # via vaex-jupyter ipympl==0.9.2 # via vaex-jupyter -ipython==8.7.0 +ipython==8.8.0 # via # great-expectations # ipykernel @@ -319,7 +363,7 @@ ipyvuetify==1.8.4 # via vaex-jupyter ipywebrtc==0.6.0 # via ipyvolume -ipywidgets==8.0.3 +ipywidgets==8.0.4 # via # bqplot # great-expectations @@ -332,6 +376,8 @@ ipywidgets==8.0.3 # pythreejs isoduration==20.11.0 # via jsonschema +itsdangerous==2.1.2 + # via flask jaraco-classes==3.2.3 # via keyring jedi==0.18.1 @@ -341,9 +387,11 @@ jinja2==3.1.2 # altair # branca # cookiecutter + # flask # great-expectations # jinja2-time # jupyter-server + # mlflow # nbclassic # nbconvert # notebook @@ -385,7 +433,11 @@ jupyter-client==7.4.6 # qtconsole jupyter-console==6.4.4 # via jupyter +<<<<<<< HEAD jupyter-core==5.0.0 +======= +jupyter-core==5.1.2 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter-client # jupyter-server @@ -396,25 +448,33 @@ jupyter-core==5.0.0 # qtconsole jupyter-events==0.5.0 # via jupyter-server -jupyter-server==2.0.1 +jupyter-server==2.0.6 # via # nbclassic # notebook-shim +<<<<<<< HEAD jupyterlab-pygments==0.2.2 # via nbconvert jupyterlab-widgets==3.0.3 +======= +jupyter-server-terminals==0.4.3 + # via jupyter-server +jupyterlab-pygments==0.2.2 + # via nbconvert +jupyterlab-widgets==3.0.5 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via ipywidgets keras==2.9.0 # via tensorflow keras-preprocessing==1.1.2 # via tensorflow -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib kubernetes==25.3.0 # via -r doc-requirements.in -lazy-object-proxy==1.8.0 +lazy-object-proxy==1.9.0 # via astroid libclang==14.0.6 # via tensorflow @@ -422,17 +482,21 @@ llvmlite==0.39.1 # via numba locket==1.0.0 # via partd -lxml==4.9.1 +lxml==4.9.2 # via sphinx-material makefun==1.15.0 # via great-expectations +mako==1.2.4 + # via alembic markdown==3.4.1 # via # -r doc-requirements.in + # mlflow # tensorboard markupsafe==2.1.1 # via # jinja2 + # mako # nbconvert # werkzeug marshmallow==3.19.0 @@ -448,6 +512,7 @@ marshmallow-jsonschema==0.13.0 matplotlib==3.5.3 # via # ipympl + # mlflow # pandas-profiling # phik # seaborn @@ -462,13 +527,19 @@ mistune==2.0.4 # via # great-expectations # nbconvert +<<<<<<< HEAD modin==0.17.0 +======= +mlflow==2.1.1 + # via -r doc-requirements.in +modin==0.18.0 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes msgpack==1.0.4 # via ray -multimethod==1.9 +multimethod==1.9.1 # via # pandas-profiling # visions @@ -482,13 +553,17 @@ nbclient==0.7.0 # via # nbconvert # papermill +<<<<<<< HEAD nbconvert==7.2.5 +======= +nbconvert==7.2.7 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter # jupyter-server # nbclassic # notebook -nbformat==5.7.0 +nbformat==5.7.1 # via # great-expectations # jupyter-server @@ -516,13 +591,16 @@ notebook==6.5.2 notebook-shim==0.2.2 # via nbclassic numba==0.56.4 - # via vaex-ml + # via + # shap + # vaex-ml numpy==1.23.5 # via # altair # astropy # bqplot # contourpy + # flytekit # great-expectations # h5py # imagehash @@ -531,7 +609,11 @@ numpy==1.23.5 # ipyvolume # keras-preprocessing # matplotlib +<<<<<<< HEAD # missingno +======= + # mlflow +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # modin # numba # opt-einsum @@ -548,7 +630,11 @@ numpy==1.23.5 # scikit-learn # scipy # seaborn +<<<<<<< HEAD # skl2onnx +======= + # shap +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # statsmodels # tensorboard # tensorflow @@ -556,6 +642,7 @@ numpy==1.23.5 # visions # xarray oauthlib==3.2.2 +<<<<<<< HEAD # via requests-oauthlib onnx==1.12.0 # via @@ -564,6 +651,11 @@ onnx==1.12.0 # tf2onnx onnxconverter-common==1.13.0 # via skl2onnx +======= + # via + # databricks-cli + # requests-oauthlib +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) opt-einsum==3.3.0 # via tensorflow packaging==21.3 @@ -577,11 +669,13 @@ packaging==21.3 # jupyter-server # marshmallow # matplotlib + # mlflow # modin # nbconvert # onnxconverter-common # pandera # qtpy + # shap # sphinx # statsmodels # xarray @@ -592,16 +686,18 @@ pandas==1.5.2 # dolt-integrations # flytekit # great-expectations + # mlflow # modin # pandas-profiling # pandera # phik # seaborn + # shap # statsmodels # vaex-core # visions # xarray -pandas-profiling==3.5.0 +pandas-profiling==3.6.2 # via -r doc-requirements.in pandera==0.13.4 # via -r doc-requirements.in @@ -621,7 +717,7 @@ phik==0.12.2 # via pandas-profiling pickleshare==0.7.5 # via ipython -pillow==9.3.0 +pillow==9.4.0 # via # imagehash # ipympl @@ -629,13 +725,17 @@ pillow==9.3.0 # matplotlib # vaex-viz # visions +<<<<<<< HEAD platformdirs==2.5.4 +======= +platformdirs==2.6.2 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter-core # virtualenv plotly==5.11.0 # via -r doc-requirements.in -pre-commit==2.20.0 +pre-commit==2.21.0 # via sphinx-tags progressbar2==4.2.0 # via vaex-core @@ -649,20 +749,28 @@ prompt-toolkit==3.0.32 # ipython # jupyter-console proto-plus==1.22.1 +<<<<<<< HEAD # via # google-cloud-bigquery # google-cloud-bigquery-storage protobuf==3.19.6 +======= + # via google-cloud-bigquery +protobuf==4.21.12 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # flyteidl # flytekit # google-api-core # google-cloud-bigquery - # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status +<<<<<<< HEAD # onnx # onnxconverter-common +======= + # mlflow +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # proto-plus # protoc-gen-swagger # ray @@ -689,7 +797,7 @@ py4j==0.10.9.5 pyarrow==6.0.1 # via # flytekit - # google-cloud-bigquery + # mlflow # vaex-core pyasn1==0.4.8 # via @@ -699,7 +807,7 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.10.2 +pydantic==1.10.4 # via # fastapi # great-expectations @@ -708,7 +816,7 @@ pydantic==1.10.2 # vaex-core pyerfa==2.0.0.1 # via astropy -pygments==2.13.0 +pygments==2.14.0 # via # furo # ipython @@ -718,14 +826,16 @@ pygments==2.13.0 # rich # sphinx # sphinx-prompt -pyopenssl==22.1.0 +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via # great-expectations # matplotlib # packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema pyspark==3.3.1 # via -r doc-requirements.in @@ -756,11 +866,12 @@ pythreejs==2.4.1 # via ipyvolume pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7 # via # babel # flytekit # great-expectations + # mlflow # pandas pytz-deprecation-shim==0.1.0.post0 # via tzlocal @@ -773,6 +884,7 @@ pyyaml==6.0 # dask # flytekit # kubernetes + # mlflow # pandas-profiling # papermill # pre-commit @@ -792,13 +904,16 @@ qtconsole==5.4.0 # via jupyter qtpy==2.3.0 # via qtconsole -ray==2.1.0 +querystring-parser==1.2.4 + # via mlflow +ray==2.2.0 # via -r doc-requirements.in regex==2022.10.31 # via docker-image-py requests==2.28.1 # via # cookiecutter + # databricks-cli # docker # flytekit # google-api-core @@ -806,6 +921,7 @@ requests==2.28.1 # great-expectations # ipyvolume # kubernetes + # mlflow # pandas-profiling # papermill # ray @@ -826,7 +942,7 @@ rfc3339-validator==0.1.4 # via jsonschema rfc3986-validator==0.1.1 # via jsonschema -rich==12.6.0 +rich==13.0.0 # via vaex-core rsa==4.9 # via google-auth @@ -835,29 +951,43 @@ ruamel-yaml==0.17.17 ruamel-yaml-clib==0.2.7 # via ruamel-yaml scikit-learn==1.2.0 - # via -r doc-requirements.in + # via + # -r doc-requirements.in + # mlflow + # shap scipy==1.9.3 # via # great-expectations # imagehash +<<<<<<< HEAD # missingno # pandas-profiling # phik # scikit-learn # skl2onnx +======= + # mlflow + # pandas-profiling + # phik + # scikit-learn + # shap +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # statsmodels -seaborn==0.12.1 +seaborn==0.12.2 # via pandas-profiling send2trash==1.8.0 # via # jupyter-server # nbclassic # notebook +shap==0.41.0 + # via mlflow six==1.16.0 # via # asttokens # astunparse # bleach + # databricks-cli # google-auth # google-pasta # grpcio @@ -865,9 +995,18 @@ six==1.16.0 # kubernetes # patsy # python-dateutil +<<<<<<< HEAD +======= + # querystring-parser + # rfc3339-validator +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # sphinx-code-include # tensorflow # vaex-core +slicer==0.0.7 + # via shap +smmap==5.0.0 + # via gitdb sniffio==1.3.0 # via anyio snowballstemmer==2.2.0 @@ -928,9 +1067,20 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in +<<<<<<< HEAD sqlalchemy==1.4.44 # via -r doc-requirements.in stack-data==0.6.1 +======= +sqlalchemy==1.4.46 + # via + # -r doc-requirements.in + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow +stack-data==0.6.2 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via ipython starlette==0.22.0 # via fastapi @@ -939,7 +1089,9 @@ statsd==3.3.0 statsmodels==0.13.5 # via pandas-profiling tabulate==0.9.0 - # via vaex-core + # via + # databricks-cli + # vaex-core tangled-up-in-unicode==0.2.0 # via visions tenacity==8.1.0 @@ -956,6 +1108,7 @@ tensorflow==2.9.0 # via -r doc-requirements.in tensorflow-estimator==2.9.0 # via tensorflow +<<<<<<< HEAD tensorflow-io-gcs-filesystem==0.27.0 # via tensorflow termcolor==2.0.1 @@ -963,6 +1116,13 @@ termcolor==2.0.1 # great-expectations # tensorflow terminado==0.17.0 +======= +tensorflow-io-gcs-filesystem==0.29.0 + # via tensorflow +termcolor==2.2.0 + # via tensorflow +terminado==0.17.1 +>>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter-server # nbclassic @@ -978,15 +1138,13 @@ threadpoolctl==3.1.0 tinycss2==1.2.1 # via nbconvert toml==0.10.2 - # via - # pre-commit - # responses + # via responses toolz==0.12.0 # via # altair # dask # partd -torch==1.13.0 +torch==1.13.1 # via -r doc-requirements.in tornado==6.2 # via @@ -1002,7 +1160,8 @@ tqdm==4.64.1 # great-expectations # pandas-profiling # papermill -traitlets==5.7.0 + # shap +traitlets==5.8.0 # via # bqplot # comm @@ -1119,7 +1278,9 @@ websocket-client==1.4.2 websockets==10.4 # via uvicorn werkzeug==2.2.2 - # via tensorboard + # via + # flask + # tensorboard wheel==0.38.4 # via # astunparse @@ -1127,11 +1288,11 @@ wheel==0.38.4 # tensorboard whylabs-client==0.4.0 # via -r doc-requirements.in -whylogs==1.1.13 +whylogs==1.1.20 # via -r doc-requirements.in whylogs-sketching==3.4.1.dev3 # via whylogs -widgetsnbextension==4.0.3 +widgetsnbextension==4.0.5 # via ipywidgets wrapt==1.14.1 # via diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index bf0b03fb95..008f2b4bbe 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -29,6 +29,7 @@ Plugin API reference * :ref:`Ray ` - Ray API reference * :ref:`DBT ` - DBT API reference * :ref:`Vaex ` - Vaex API reference +* :ref:`MLflow ` - MLflow API reference .. toctree:: :maxdepth: 2 @@ -59,3 +60,4 @@ Plugin API reference Ray DBT Vaex + MLflow diff --git a/docs/source/plugins/mlflow.rst b/docs/source/plugins/mlflow.rst new file mode 100644 index 0000000000..60d1a7c66b --- /dev/null +++ b/docs/source/plugins/mlflow.rst @@ -0,0 +1,9 @@ +.. _mlflow: + +################################################### +MLflow API reference +################################################### + +.. tags:: Integration, MachineLearning, Tracking + +.. automodule:: flytekitplugins.mlflow diff --git a/plugins/flytekit-mlflow/README.md b/plugins/flytekit-mlflow/README.md new file mode 100644 index 0000000000..6cbee9cf59 --- /dev/null +++ b/plugins/flytekit-mlflow/README.md @@ -0,0 +1,22 @@ +# Flytekit MLflow Plugin + +MLflow enables us to log parameters, code, and results in machine learning experiments and compare them using an interactive UI. +This MLflow plugin enables seamless use of MLFlow within Flyte, and render the metrics and parameters on Flyte Deck. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-mlflow +``` + +Example +```python +from flytekit import task, workflow +from flytekitplugins.mlflow import mlflow_autolog +import mlflow + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(): + ... +``` diff --git a/plugins/flytekit-mlflow/dev-requirements.in b/plugins/flytekit-mlflow/dev-requirements.in new file mode 100644 index 0000000000..0f57144081 --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.in @@ -0,0 +1 @@ +tensorflow diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt new file mode 100644 index 0000000000..6ad9be49bb --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -0,0 +1,122 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile dev-requirements.in +# +absl-py==1.3.0 + # via + # tensorboard + # tensorflow +astunparse==1.6.3 + # via tensorflow +cachetools==5.2.0 + # via google-auth +certifi==2022.9.24 + # via requests +charset-normalizer==2.1.1 + # via requests +flatbuffers==22.10.26 + # via tensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.14.1 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +grpcio==1.50.0 + # via + # tensorboard + # tensorflow +h5py==3.7.0 + # via tensorflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via markdown +keras==2.10.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +libclang==14.0.6 + # via tensorflow +markdown==3.4.1 + # via tensorboard +markupsafe==2.1.1 + # via werkzeug +numpy==1.23.4 + # via + # h5py + # keras-preprocessing + # opt-einsum + # tensorboard + # tensorflow +oauthlib==3.2.2 + # via requests-oauthlib +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via tensorflow +protobuf==3.19.6 + # via + # tensorboard + # tensorflow +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pyparsing==3.0.9 + # via packaging +requests==2.28.1 + # via + # requests-oauthlib + # tensorboard +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # tensorflow +tensorboard==2.10.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.10.0 + # via -r dev-requirements.in +tensorflow-estimator==2.10.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.27.0 + # via tensorflow +termcolor==2.1.0 + # via tensorflow +typing-extensions==4.4.0 + # via tensorflow +urllib3==1.26.12 + # via requests +werkzeug==2.2.2 + # via tensorboard +wheel==0.38.3 + # via + # astunparse + # tensorboard +wrapt==1.14.1 + # via tensorflow +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py new file mode 100644 index 0000000000..98e84547e0 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.mlflow + +This plugin enables seamless integration between Flyte and mlflow. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + mlflow_autolog +""" + +from .tracking import mlflow_autolog diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py new file mode 100644 index 0000000000..b58aa4a120 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -0,0 +1,140 @@ +import typing +from functools import partial, wraps + +import mlflow +import pandas +import pandas as pd +import plotly.graph_objects as go +from mlflow import MlflowClient +from mlflow.entities.metric import Metric +from plotly.subplots import make_subplots + +import flytekit +from flytekit import FlyteContextManager +from flytekit.bin.entrypoint import get_one_of +from flytekit.core.context_manager import ExecutionState +from flytekit.deck import TopFrameRenderer + + +def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: + """ + Converts mlflow Metric object to a dataframe of 2 columns ['timestamp', 'value'] + """ + t = [] + v = [] + for m in metrics: + t.append(m.timestamp) + v.append(m.value) + return pd.DataFrame(list(zip(t, v)), columns=["timestamp", "value"]) + + +def get_run_metrics(c: MlflowClient, run_id: str) -> typing.Dict[str, pandas.DataFrame]: + """ + Extracts all metrics and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + metrics = {} + for k in r.data.metrics.keys(): + metrics[k] = metric_to_df(metrics=c.get_metric_history(run_id=run_id, key=k)) + return metrics + + +def get_run_params(c: MlflowClient, run_id: str) -> typing.Optional[pd.DataFrame]: + """ + Extracts all parameters and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + name = [] + value = [] + if r.data.params == {}: + return None + for k, v in r.data.params.items(): + name.append(k) + value.append(v) + return pd.DataFrame(list(zip(name, value)), columns=["name", "value"]) + + +def plot_metrics(metrics: typing.Dict[str, pandas.DataFrame]) -> typing.Optional[go.Figure]: + v = len(metrics) + if v == 0: + return None + + # Initialize figure with subplots + fig = make_subplots(rows=v, cols=1, subplot_titles=list(metrics.keys())) + + # Add traces + row = 1 + for k, v in metrics.items(): + v["timestamp"] = (v["timestamp"] - v["timestamp"][0]) / 1000 + fig.add_trace(go.Scatter(x=v["timestamp"], y=v["value"], name=k), row=row, col=1) + row = row + 1 + + fig.update_xaxes(title_text="Time (s)") + fig.update_layout(height=700, width=900) + return fig + + +def mlflow_autolog(fn=None, *, framework=mlflow.sklearn, experiment_name: typing.Optional[str] = None): + """MLFlow decorator to enable autologging of training metrics. + + This decorator can be used as a nested decorator for a ``@task`` and it will automatically enable mlflow autologging, + for the given ``framework``. By default autologging is enabled for ``sklearn``. + + .. code-block:: python + + @task + @mlflow_autolog(framework=mlflow.tensorflow) + def my_tensorflow_trainer(): + ... + + One benefit of doing so is that the mlflow metrics are then rendered inline using FlyteDecks and can be viewed + in jupyter notebook, as well as in hosted Flyte environment: + + .. code-block:: python + + # jupyter notebook cell + with flytekit.new_context() as ctx: + my_tensorflow_trainer() + ctx.get_deck() # IPython.display + + When the task is called in a Flyte backend, the decorator starts a new MLFlow run using the Flyte execution name + by default, or a user-provided ``experiment_name`` in the decorator. + + :param fn: Function to generate autologs for. + :param framework: The mlflow module to use for autologging + :param experiment_name: The MLFlow experiment name. If not provided, uses the Flyte execution name. + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + framework.autolog() + params = FlyteContextManager.current_context().user_space_params + ctx = FlyteContextManager.current_context() + + experiment = experiment_name or "local workflow" + run_name = None # MLflow will generate random name if value is None + + if ctx.execution_state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + experiment = experiment_name and f"{get_one_of('FLYTE_INTERNAL_EXECUTION_WORKFLOW', '_F_WF')}" + run_name = f"{params.execution_id.name}.{params.task_id.name.split('.')[-1]}" + + mlflow.set_experiment(experiment) + with mlflow.start_run(run_name=run_name): + out = fn(*args, **kwargs) + run = mlflow.active_run() + if run is not None: + client = MlflowClient() + run_id = run.info.run_id + metrics = get_run_metrics(client, run_id) + figure = plot_metrics(metrics) + if figure: + flytekit.Deck("mlflow metrics", figure.to_html()) + params = get_run_params(client, run_id) + if params is not None: + flytekit.Deck("mlflow params", TopFrameRenderer(max_rows=10).to_html(params)) + return out + + if fn is None: + return partial(mlflow_autolog, framework=framework, experiment_name=experiment_name) + + return wrapper diff --git a/plugins/flytekit-mlflow/requirements.in b/plugins/flytekit-mlflow/requirements.in new file mode 100644 index 0000000000..cbe58e3885 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.in @@ -0,0 +1,3 @@ +. +-e file:.#egg=flytekitplugins-mlflow +grpcio-status<1.49.0 diff --git a/plugins/flytekit-mlflow/requirements.txt b/plugins/flytekit-mlflow/requirements.txt new file mode 100644 index 0000000000..03873c05f5 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.txt @@ -0,0 +1,274 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-mlflow + # via -r requirements.in +alembic==1.8.1 + # via mlflow +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.9.24 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # databricks-cli + # flask + # flytekit + # mlflow +cloudpickle==2.2.0 + # via + # flytekit + # mlflow +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.7 + # via flytekit +cryptography==38.0.3 + # via pyopenssl +databricks-cli==0.17.3 + # via mlflow +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.1 + # via + # flytekit + # mlflow +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +entrypoints==0.4 + # via mlflow +flask==2.2.2 + # via + # mlflow + # prometheus-flask-exporter +flyteidl==1.1.22 + # via flytekit +flytekit==1.2.3 + # via flytekitplugins-mlflow +gitdb==4.0.9 + # via gitpython +gitpython==3.1.29 + # via mlflow +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +greenlet==2.0.1 + # via sqlalchemy +grpcio==1.50.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.48.2 + # via + # -r requirements.in + # flytekit +gunicorn==20.1.0 + # via mlflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via + # flask + # flytekit + # keyring + # mlflow +itsdangerous==2.1.2 + # via flask +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # flask + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.11.0 + # via flytekit +mako==1.2.3 + # via alembic +markupsafe==2.1.1 + # via + # jinja2 + # mako + # werkzeug +marshmallow==3.18.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mlflow==1.30.0 + # via flytekitplugins-mlflow +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.2.0 + # via flytekit +numpy==1.23.4 + # via + # mlflow + # pandas + # pyarrow + # scipy +oauthlib==3.2.2 + # via databricks-cli +packaging==21.3 + # via + # docker + # marshmallow + # mlflow +pandas==1.5.1 + # via + # flytekit + # mlflow +plotly==5.11.0 + # via flytekitplugins-mlflow +prometheus-client==0.15.0 + # via prometheus-flask-exporter +prometheus-flask-exporter==0.20.3 + # via mlflow +protobuf==3.20.3 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # mlflow + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==22.1.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.6 + # via + # flytekit + # mlflow + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # mlflow +querystring-parser==1.2.4 + # via mlflow +regex==2022.10.31 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # databricks-cli + # docker + # flytekit + # mlflow + # responses +responses==0.22.0 + # via flytekit +retry==0.9.2 + # via flytekit +scipy==1.9.3 + # via mlflow +six==1.16.0 + # via + # databricks-cli + # grpcio + # python-dateutil + # querystring-parser +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +sqlalchemy==1.4.43 + # via + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow +statsd==3.3.0 + # via flytekit +tabulate==0.9.0 + # via databricks-cli +tenacity==8.1.0 + # via plotly +text-unidecode==1.3 + # via python-slugify +toml==0.10.2 + # via responses +types-toml==0.10.8 + # via responses +typing-extensions==4.4.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # docker + # flytekit + # requests + # responses +websocket-client==1.4.2 + # via docker +werkzeug==2.2.2 + # via flask +wheel==0.38.3 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/setup.py b/plugins/flytekit-mlflow/setup.py new file mode 100644 index 0000000000..2033ce5d27 --- /dev/null +++ b/plugins/flytekit-mlflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "mlflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables seamless use of MLFlow within Flyte", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-mlflow/tests/__init__.py b/plugins/flytekit-mlflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py new file mode 100644 index 0000000000..b196327d8d --- /dev/null +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -0,0 +1,32 @@ +import mlflow +import tensorflow as tf +from flytekitplugins.mlflow import mlflow_autolog + +import flytekit +from flytekit import task + + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(epochs: int): + fashion_mnist = tf.keras.datasets.fashion_mnist + (train_images, train_labels), (_, _) = fashion_mnist.load_data() + train_images = train_images / 255.0 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + model.fit(train_images, train_labels, epochs=epochs) + + +def test_local_exec(): + train_model(epochs=1) + assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output