diff --git a/.gitignore b/.gitignore
index fc76e7d07c..a4fe02503e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,4 +31,4 @@ docs/source/plugins/generated/
 htmlcov
 *.ipynb
 *dat
-source/_tags/
+docs/source/_tags/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 39470b7370..1fd6e6b648 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
     hooks:
       - id: black
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile", "black"]
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 1c0f039d3a..19b1898e94 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -5,12 +5,17 @@
 # Required
 version: 2
 
+build:
+  os: ubuntu-20.04
+  tools:
+    python: "3.9"
+
 # Build documentation in the docs/ directory with Sphinx
 sphinx:
   configuration: docs/source/conf.py
 
 # Optionally set the version of Python and requirements required to build your docs
 python:
-  version: 3.8
   install:
-    - requirements: doc-requirements.txt
+    - requirements: docs/requirements.txt
diff --git a/CODEOWNERS b/CODEOWNERS
index 9389524869..a9aab29ffd 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,3 +1,3 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence.
-* @wild-endeavor @kumare3 @eapolinario @pingsutw
+* @wild-endeavor @kumare3 @eapolinario @pingsutw @cosmicBboy
diff --git a/Dockerfile b/Dockerfile
index 82f4fe5366..6c3228ad2f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,6 +4,10 @@ FROM python:${PYTHON_VERSION}-slim-buster
 MAINTAINER Flyte Team
 LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
 
+RUN useradd -u 1000 flytekit
+RUN chown flytekit: /root
+USER flytekit
+
 WORKDIR /root
 ENV PYTHONPATH /root
diff --git a/dev-requirements.in b/dev-requirements.in
index e9726227fb..c2a0a9bdd5 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -12,6 +12,8 @@ codespell
 google-cloud-bigquery
 google-cloud-bigquery-storage
 IPython
+tensorflow
+grpcio-status<1.49.0
 # Newer versions of torch bring in nvidia dependencies that are not present in windows, so
 # we put this constraint while we do not have per-environment requirements files
 torch<=1.12.1
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 5a73546c0d..9e4eba39fd 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.7
+# This file is autogenerated by pip-compile with python 3.9
 # To update, run:
 #
 #    make dev-requirements.txt
@@ -8,12 +8,18 @@
     # via
     #   -c requirements.txt
     #   pytest-flyte
+absl-py==1.4.0
+    # via
+    #   tensorboard
+    #   tensorflow
 appnope==0.1.3
     # via ipython
 arrow==1.2.3
     # via
     #   -c requirements.txt
     #   jinja2-time
+astunparse==1.6.3
+    # via tensorflow
 attrs==20.3.0
     # via
     #   -c requirements.txt
@@ -30,11 +36,14 @@ binaryornot==0.4.4
     # via
    #   cookiecutter
 cached-property==1.5.2
     # via docker-compose
-cachetools==5.2.0
-    # via google-auth
+cachetools==5.3.0
+    # via
+    #   -c requirements.txt
+    #   google-auth
-certifi==2022.9.24
+certifi==2022.12.7
     # via
     #   -c requirements.txt
+    #   kubernetes
     #   requests
 cffi==1.15.1
     # via
@@ -47,7 +56,7 @@ chardet==5.0.0
     # via
     #   -c requirements.txt
     #   binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
     # via
     #   -c requirements.txt
     #   requests
 click==8.1.3
     # via
     #   -c requirements.txt
     #   cookiecutter
     #   flytekit
-cloudpickle==2.2.0
+cloudpickle==2.2.1
     # via
     #   -c requirements.txt
     #   flytekit
@@ -66,7 +75,7 @@ cookiecutter==2.1.1
     # via
     #   -c requirements.txt
     #   flytekit
-coverage[toml]==6.5.0
+coverage[toml]==7.1.0
     # via
     #   -r dev-requirements.in
     #   pytest-cov
@@ -74,7 +83,7 @@ croniter==1.3.7
     # via
     #   -c requirements.txt
     #   flytekit
-cryptography==38.0.3
+cryptography==39.0.0
     # via
     #   -c requirements.txt
     #   paramiko
@@ -119,39 +128,58 @@ docstring-parser==0.15
     # via
     #   -c requirements.txt
     #   flytekit
-exceptiongroup==1.0.4
+exceptiongroup==1.1.0
     # via pytest
-filelock==3.8.0
+filelock==3.9.0
     # via virtualenv
-flyteidl==1.2.5
+flatbuffers==23.1.21
+    # via tensorflow
+flyteidl==1.2.9
     # via
     #   -c requirements.txt
     #   flytekit
-google-api-core[grpc]==2.10.2
+gast==0.5.3
+    # via tensorflow
+gitdb==4.0.10
+    # via
+    #   -c requirements.txt
+    #   gitpython
+gitpython==3.1.30
+    # via
+    #   -c requirements.txt
+    #   flytekit
+google-api-core[grpc]==2.11.0
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.14.1
+google-auth==2.16.0
     # via
+    #   -c requirements.txt
     #   google-api-core
+    #   google-auth-oauthlib
     #   google-cloud-core
-google-cloud-bigquery==3.4.0
+    #   kubernetes
+    #   tensorboard
+google-auth-oauthlib==0.4.6
+    # via tensorboard
+google-cloud-bigquery==3.4.2
+    # via -r dev-requirements.in
+google-cloud-bigquery-storage==2.18.1
     # via -r dev-requirements.in
-google-cloud-bigquery-storage==2.16.2
-    # via
-    #   -r dev-requirements.in
-    #   google-cloud-bigquery
 google-cloud-core==2.3.2
     # via google-cloud-bigquery
 google-crc32c==1.5.0
     # via google-resumable-media
-google-resumable-media==2.4.0
+google-pasta==0.2.0
+    # via tensorflow
+google-resumable-media==2.4.1
     # via google-cloud-bigquery
-googleapis-common-protos==1.57.0
+googleapis-common-protos==1.58.0
     # via
     #   -c requirements.txt
     #   flyteidl
+    #   flytekit
     #   google-api-core
     #   grpcio-status
 grpcio==1.48.2
@@ -161,29 +189,37 @@ grpcio==1.48.2
     # via
     #   -c requirements.txt
     #   google-api-core
     #   google-cloud-bigquery
     #   grpcio-status
+    #   tensorboard
+    #   tensorflow
 grpcio-status==1.48.2
     # via
     #   -c requirements.txt
+    #   -r dev-requirements.in
     #   flytekit
     #   google-api-core
-identify==2.5.9
+h5py==3.8.0
+    # via tensorflow
+identify==2.5.17
     # via pre-commit
 idna==3.4
     # via
     #   -c requirements.txt
     #   requests
-importlib-metadata==5.0.0
+importlib-metadata==6.0.0
     # via
     #   -c requirements.txt
-    #   click
     #   flytekit
-    #   jsonschema
     #   keyring
+    #   markdown
     #   pluggy
     #   pre-commit
     #   pytest
     #   virtualenv
-iniconfig==1.1.1
+importlib-resources==5.10.2
+    # via
+    #   -c requirements.txt
+    #   keyring
+iniconfig==2.0.0
     # via pytest
 ipython==7.34.0
     # via -r dev-requirements.in
@@ -213,14 +249,27 @@ jsonschema==3.2.0
     # via
     #   -c requirements.txt
     #   docker-compose
-keyring==23.11.0
+keras==2.8.0
+    # via tensorflow
+keras-preprocessing==1.1.2
+    # via tensorflow
+keyring==23.13.1
+    # via
+    #   -c requirements.txt
+    #   flytekit
+kubernetes==25.3.0
     # via
     #   -c requirements.txt
     #   flytekit
-markupsafe==2.1.1
+libclang==15.0.6.1
+    # via tensorflow
+markdown==3.4.1
+    # via tensorboard
+markupsafe==2.1.2
     # via
     #   -c requirements.txt
     #   jinja2
+    #   werkzeug
 marshmallow==3.19.0
     # via
     #   -c requirements.txt
@@ -237,7 +286,7 @@ marshmallow-jsonschema==0.13.0
     #   flytekit
 matplotlib-inline==0.1.6
     # via ipython
-mock==4.0.3
+mock==5.0.1
     # via -r dev-requirements.in
 more-itertools==9.0.0
     # via
@@ -259,9 +308,19 @@ nodeenv==1.7.0
 numpy==1.21.6
     # via
     #   -c requirements.txt
-    #   flytekit
+    #   h5py
+    #   keras-preprocessing
+    #   opt-einsum
     #   pandas
     #   pyarrow
+    #   tensorboard
+    #   tensorflow
+oauthlib==3.2.2
+    # via
+    #   -c requirements.txt
+    #   requests-oauthlib
+opt-einsum==3.3.0
+    # via tensorflow
     #   scikit-learn
     #   scipy
 packaging==21.3
@@ -275,7 +334,7 @@ pandas==1.3.5
     # via
     #   -c requirements.txt
     #   flytekit
-paramiko==2.12.0
+paramiko==3.0.0
     # via docker
 parso==0.8.3
     # via jedi
@@ -283,15 +342,15 @@ pexpect==4.8.0
     # via ipython
 pickleshare==0.7.5
     # via ipython
-platformdirs==2.5.4
+platformdirs==2.6.2
     # via virtualenv
 pluggy==1.0.0
     # via pytest
-pre-commit==2.20.0
+pre-commit==2.21.0
     # via -r dev-requirements.in
 prompt-toolkit==3.0.32
     # via ipython
-proto-plus==1.22.1
+proto-plus==1.22.2
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
@@ -307,6 +366,8 @@ protobuf==3.20.3
     #   grpcio-status
     #   proto-plus
     #   protoc-gen-swagger
+    #   tensorboard
+    #   tensorflow
 protoc-gen-swagger==0.1.0
     # via
     #   -c requirements.txt
@@ -321,22 +382,24 @@ pyarrow==10.0.0
     # via
     #   -c requirements.txt
     #   flytekit
-    #   google-cloud-bigquery
 pyasn1==0.4.8
     # via
+    #   -c requirements.txt
     #   pyasn1-modules
     #   rsa
 pyasn1-modules==0.2.8
-    # via google-auth
+    # via
+    #   -c requirements.txt
+    #   google-auth
 pycparser==2.21
     # via
     #   -c requirements.txt
     #   cffi
-pygments==2.13.0
+pygments==2.14.0
     # via ipython
 pynacl==1.5.0
     # via paramiko
-pyopenssl==22.1.0
+pyopenssl==23.0.0
     # via
     #   -c requirements.txt
     #   flytekit
@@ -344,11 +407,11 @@ pyparsing==3.0.9
     # via
     #   -c requirements.txt
     #   packaging
-pyrsistent==0.19.2
+pyrsistent==0.19.3
     # via
     #   -c requirements.txt
     #   jsonschema
-pytest==7.2.0
+pytest==7.2.1
     # via
     #   -r dev-requirements.in
     #   pytest-cov
@@ -367,14 +430,15 @@ python-dateutil==2.8.2
     #   croniter
     #   flytekit
     #   google-cloud-bigquery
+    #   kubernetes
     #   pandas
-python-dotenv==0.21.0
+python-dotenv==0.21.1
     # via docker-compose
 python-json-logger==2.0.4
     # via
     #   -c requirements.txt
     #   flytekit
-python-slugify==7.0.0
+python-slugify==8.0.0
     # via
     #   -c requirements.txt
     #   cookiecutter
@@ -382,7 +446,7 @@ pytimeparse==1.1.8
     # via
     #   -c requirements.txt
     #   flytekit
-pytz==2022.6
+pytz==2022.7.1
     # via
     #   -c requirements.txt
     #   flytekit
@@ -393,12 +457,13 @@ pyyaml==5.4.1
     #   cookiecutter
     #   docker-compose
     #   flytekit
+    #   kubernetes
     #   pre-commit
 regex==2022.10.31
     # via
     #   -c requirements.txt
     #   docker-image-py
-requests==2.28.1
+requests==2.28.2
     # via
     #   -c requirements.txt
     #   cookiecutter
@@ -407,7 +472,15 @@ requests==2.28.1
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
+    #   kubernetes
+    #   requests-oauthlib
     #   responses
+    #   tensorboard
+requests-oauthlib==1.3.1
+    # via
+    #   -c requirements.txt
+    #   google-auth-oauthlib
+    #   kubernetes
 responses==0.22.0
     # via
     #   -c requirements.txt
@@ -417,7 +490,9 @@ retry==0.9.2
     #   -c requirements.txt
     #   flytekit
 rsa==4.9
-    # via google-auth
+    # via
+    #   -c requirements.txt
+    #   google-auth
 scikit-learn==1.0.2
     # via -r dev-requirements.in
 scipy==1.7.3
@@ -429,13 +504,21 @@ singledispatchmethod==1.0
 six==1.16.0
     # via
     #   -c requirements.txt
+    #   astunparse
     #   dockerpty
     #   google-auth
+    #   google-pasta
     #   grpcio
     #   jsonschema
-    #   paramiko
+    #   keras-preprocessing
+    #   kubernetes
     #   python-dateutil
+    #   tensorflow
     #   websocket-client
+smmap==5.0.0
+    # via
+    #   -c requirements.txt
+    #   gitdb
 sortedcontainers==2.4.0
     # via
     #   -c requirements.txt
@@ -444,6 +527,20 @@ statsd==3.3.0
     # via
     #   -c requirements.txt
     #   flytekit
+tensorboard==2.8.0
+    # via tensorflow
+tensorboard-data-server==0.6.1
+    # via tensorboard
+tensorboard-plugin-wit==1.8.1
+    # via tensorboard
+tensorflow==2.8.1
+    # via -r dev-requirements.in
+tensorflow-estimator==2.8.0
+    # via tensorflow
+tensorflow-io-gcs-filesystem==0.30.0
+    # via tensorflow
+termcolor==2.2.0
+    # via tensorflow
 text-unidecode==1.3
     # via
     #   -c requirements.txt
@@ -455,7 +552,6 @@ threadpoolctl==3.1.0
 toml==0.10.2
     # via
     #   -c requirements.txt
-    #   pre-commit
     #   responses
 tomli==2.0.1
     # via
@@ -464,7 +560,7 @@ tomli==2.0.1
     #   pytest
 torch==1.12.1
     # via -r dev-requirements.in
-traitlets==5.5.0
+traitlets==5.9.0
     # via
     #   ipython
     #   matplotlib-inline
@@ -477,46 +573,56 @@ types-toml==0.10.8.1
 typing-extensions==4.4.0
     # via
     #   -c requirements.txt
-    #   arrow
     #   flytekit
+    #   gitpython
     #   importlib-metadata
     #   mypy
+    #   platformdirs
     #   responses
+    #   tensorflow
     #   torch
     #   typing-inspect
 typing-inspect==0.8.0
     # via
     #   -c requirements.txt
     #   dataclasses-json
-urllib3==1.26.12
+urllib3==1.26.14
     # via
     #   -c requirements.txt
     #   docker
     #   flytekit
+    #   kubernetes
     #   requests
     #   responses
 virtualenv==20.16.7
     # via pre-commit
-wcwidth==0.2.5
+wcwidth==0.2.6
     # via prompt-toolkit
 websocket-client==0.59.0
     # via
     #   -c requirements.txt
     #   docker
     #   docker-compose
+    #   kubernetes
+werkzeug==2.2.2
+    # via tensorboard
 wheel==0.38.4
     # via
     #   -c requirements.txt
+    #   astunparse
     #   flytekit
+    #   tensorboard
 wrapt==1.14.1
     # via
     #   -c requirements.txt
     #   deprecated
     #   flytekit
-zipp==3.10.0
+    #   tensorflow
+zipp==3.12.0
     # via
     #   -c requirements.txt
     #   importlib-metadata
+    #   importlib-resources
 
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
diff --git a/doc-requirements.in b/doc-requirements.in
index d61955fbff..1ac56c629f 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -12,19 +12,18 @@ sphinx-copybutton
 sphinx_fontawesome
 sphinx-panels
 sphinxcontrib-yt
-grpcio==1.43.0
-grpcio-status==1.43.0
 cryptography
-google-api-core[grpc]==2.8.2
-scikit-learn==1.1.1
+google-api-core[grpc]
+scikit-learn
 sphinx-tags
+sphinx-click
 
 # Packages for Plugin docs
 # Package name           Plugin needing it
 botocore                 # fsspec
 fsspec                   # fsspec
 google-cloud             # bigquery
-google-cloud-bigquery==3.1.0  # bigquery
+google-cloud-bigquery    # bigquery
 markdown                 # deck
 plotly                   # deck
 pandas_profiling         # deck
@@ -40,8 +39,10 @@ sqlalchemy               # sqlalchemy
 torch                    # pytorch
 skl2onnx                 # onnxscikitlearn
 tf2onnx                  # onnxtensorflow
-tensorflow==2.9.0        # onnxtensorflow
+tensorflow==2.8.1        # onnxtensorflow
 whylogs                  # whylogs
 whylabs-client           # whylogs
 ray                      # ray
 scikit-learn             # scikit-learn
+vaex                     # vaex
+mlflow                   # mlflow
diff --git a/doc-requirements.txt b/doc-requirements.txt
index 4991bd4adc..2616e043df 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -1,29 +1,32 @@
 #
-# This file is autogenerated by pip-compile with python 3.9
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.9
+# by the following command:
 #
-#    make doc-requirements.txt
+#    pip-compile doc-requirements.in
 #
 -e file:.#egg=flytekit
     # via -r doc-requirements.in
-absl-py==1.3.0
+absl-py==1.4.0
     # via
     #   tensorboard
     #   tensorflow
 aiosignal==1.3.1
     # via ray
-alabaster==0.7.12
+alabaster==0.7.13
     # via sphinx
-altair==4.2.0
+alembic==1.9.2
+    # via mlflow
+altair==4.2.2
     # via great-expectations
 ansiwrap==0.8.4
     # via papermill
 anyio==3.6.2
-    # via jupyter-server
-appnope==0.1.3
     # via
-    #   ipykernel
-    #   ipython
+    #   jupyter-server
+    #   starlette
+    #   watchfiles
+aplus==0.11.0
+    # via vaex-core
 argon2-cffi==21.3.0
     # via
     #   jupyter-server
@@ -32,14 +35,18 @@ argon2-cffi==21.3.0
 argon2-cffi-bindings==21.2.0
     # via argon2-cffi
 arrow==1.2.3
-    # via jinja2-time
-astroid==2.12.12
+    # via
+    #   isoduration
+    #   jinja2-time
+astroid==2.14.1
     # via sphinx-autoapi
-asttokens==2.1.0
+astropy==5.2.1
+    # via vaex-astro
+asttokens==2.2.1
     # via stack-data
 astunparse==1.6.3
     # via tensorflow
-attrs==22.1.0
+attrs==22.2.0
     # via
     #   jsonschema
     #   ray
@@ -48,7 +55,7 @@ babel==2.11.0
     # via sphinx
 backcall==0.2.0
     # via ipython
-beautifulsoup4==4.11.1
+beautifulsoup4==4.11.2
     # via
     #   furo
     #   nbconvert
@@ -56,13 +63,23 @@ beautifulsoup4==4.11.1
     #   sphinx-material
 binaryornot==0.4.4
     # via cookiecutter
-bleach==5.0.1
+blake3==0.3.3
+    # via vaex-core
+bleach==6.0.0
     # via nbconvert
-botocore==1.29.10
+botocore==1.29.61
     # via -r doc-requirements.in
-cachetools==5.2.0
-    # via google-auth
-certifi==2022.9.24
+bqplot==0.12.36
+    # via
+    #   ipyvolume
+    #   vaex-jupyter
+branca==0.6.0
+    # via ipyleaflet
+cachetools==5.3.0
+    # via
+    #   google-auth
+    #   vaex-server
+certifi==2022.12.7
     # via
     #   kubernetes
     #   requests
@@ -74,37 +91,62 @@ cfgv==3.3.1
     # via pre-commit
 chardet==5.0.0
     # via binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
     # via requests
-click==8.0.4
+click==8.1.3
     # via
     #   cookiecutter
+    #   dask
+    #   databricks-cli
+    #   distributed
+    #   flask
     #   flytekit
     #   great-expectations
+    #   mlflow
     #   papermill
     #   ray
-cloudpickle==2.2.0
-    # via flytekit
+    #   sphinx-click
+    #   uvicorn
+cloudpickle==2.2.1
+    # via
+    #   dask
+    #   distributed
+    #   flytekit
+    #   mlflow
+    #   shap
+    #   vaex-core
 colorama==0.4.6
     # via great-expectations
+comm==0.1.2
+    # via ipykernel
+contourpy==1.0.7
+    # via matplotlib
 cookiecutter==2.1.1
     # via flytekit
 croniter==1.3.7
     # via flytekit
-cryptography==38.0.3
+cryptography==39.0.0
     # via
     #   -r doc-requirements.in
     #   great-expectations
     #   pyopenssl
+    #   secretstorage
 css-html-js-minify==2.5.5
     # via sphinx-material
 cycler==0.11.0
     # via matplotlib
+dask[distributed]==2023.1.1
+    # via
+    #   -r doc-requirements.in
+    #   distributed
+    #   vaex-core
+databricks-cli==0.17.4
+    # via mlflow
 dataclasses-json==0.5.7
     # via
     #   dolt-integrations
     #   flytekit
-debugpy==1.6.3
+debugpy==1.6.6
     # via ipykernel
 decorator==5.1.1
     # via
@@ -118,8 +160,12 @@ diskcache==5.4.0
     # via flytekit
 distlib==0.3.6
     # via virtualenv
+distributed==2023.1.1
+    # via dask
 docker==6.0.1
-    # via flytekit
+    # via
+    #   flytekit
+    #   mlflow
 docker-image-py==0.1.12
     # via flytekit
 docstring-parser==0.15
@@ -127,6 +173,7 @@ docstring-parser==0.15
 docutils==0.17.1
     # via
     #   sphinx
+    #   sphinx-click
     #   sphinx-panels
 dolt-integrations==0.1.5
     # via -r doc-requirements.in
@@ -135,43 +182,60 @@ doltcli==0.1.17
 entrypoints==0.4
     # via
     #   altair
-    #   jupyter-client
+    #   mlflow
     #   papermill
 executing==1.2.0
     # via stack-data
+fastapi==0.89.1
+    # via vaex-server
 fastjsonschema==2.16.2
     # via nbformat
-filelock==3.8.0
+filelock==3.9.0
     # via
     #   ray
+    #   vaex-core
     #   virtualenv
-flatbuffers==1.12
+flask==2.2.2
+    # via mlflow
+flatbuffers==2.0.7
     # via
     #   tensorflow
     #   tf2onnx
-flyteidl==1.2.5
+flyteidl==1.2.9
     # via flytekit
 fonttools==4.38.0
     # via matplotlib
+fqdn==1.5.1
+    # via jsonschema
+frozendict==2.3.4
+    # via vaex-core
 frozenlist==1.3.3
     # via
     #   aiosignal
     #   ray
-fsspec==2022.11.0
+fsspec==2023.1.0
     # via
     #   -r doc-requirements.in
+    #   dask
     #   modin
 furo @ git+https://github.com/flyteorg/furo@main
     # via -r doc-requirements.in
-gast==0.4.0
+future==0.18.3
+    # via vaex-core
+gast==0.5.3
     # via tensorflow
-google-api-core[grpc]==2.8.2
+gitdb==4.0.10
+    # via gitpython
+gitpython==3.1.30
+    # via
+    #   flytekit
+    #   mlflow
+google-api-core[grpc]==2.11.0
     # via
     #   -r doc-requirements.in
     #   google-cloud-bigquery
-    #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.14.1
+google-auth==2.16.0
     # via
     #   google-api-core
     #   google-auth-oauthlib
@@ -182,30 +246,27 @@ google-auth-oauthlib==0.4.6
     # via tensorboard
 google-cloud==0.34.0
     # via -r doc-requirements.in
-google-cloud-bigquery==3.1.0
+google-cloud-bigquery==3.5.0
     # via -r doc-requirements.in
-google-cloud-bigquery-storage==2.16.2
-    # via google-cloud-bigquery
 google-cloud-core==2.3.2
     # via google-cloud-bigquery
 google-crc32c==1.5.0
     # via google-resumable-media
 google-pasta==0.2.0
     # via tensorflow
-google-resumable-media==2.4.0
+google-resumable-media==2.4.1
     # via google-cloud-bigquery
-googleapis-common-protos==1.57.0
+googleapis-common-protos==1.58.0
     # via
     #   flyteidl
     #   google-api-core
     #   grpcio-status
-great-expectations==0.15.32
+great-expectations==0.15.46
     # via -r doc-requirements.in
-greenlet==2.0.1
+greenlet==2.0.2
     # via sqlalchemy
-grpcio==1.43.0
+grpcio==1.48.2
     # via
-    #   -r doc-requirements.in
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
@@ -215,32 +276,47 @@ grpcio==1.43.0
     #   tensorflow
 grpcio-status==1.43.0
     # via
-    #   -r doc-requirements.in
     #   flytekit
     #   google-api-core
-h5py==3.7.0
-    # via tensorflow
+gunicorn==20.1.0
+    # via mlflow
+h11==0.14.0
+    # via uvicorn
+h5py==3.8.0
+    # via
+    #   tensorflow
+    #   vaex-hdf5
+heapdict==1.0.1
+    # via zict
 htmlmin==0.1.12
-    # via pandas-profiling
-identify==2.5.8
+    # via ydata-profiling
+httptools==0.5.0
+    # via uvicorn
+identify==2.5.17
     # via pre-commit
 idna==3.4
     # via
     #   anyio
+    #   jsonschema
     #   requests
 imagehash==4.3.1
     # via visions
 imagesize==1.4.1
     # via sphinx
-importlib-metadata==5.0.0
+importlib-metadata==5.2.0
     # via
+    #   flask
     #   flytekit
     #   great-expectations
+    #   jupyter-client
     #   keyring
     #   markdown
+    #   mlflow
     #   nbconvert
     #   sphinx
-ipykernel==6.17.1
+ipydatawidgets==4.3.2
+    # via pythreejs
+ipykernel==6.20.2
     # via
     #   ipywidgets
     #   jupyter
@@ -248,38 +324,76 @@ ipykernel==6.17.1
     #   nbclassic
     #   notebook
     #   qtconsole
-ipython==8.6.0
+ipyleaflet==0.17.2
+    # via vaex-jupyter
+ipympl==0.9.2
+    # via vaex-jupyter
+ipython==8.9.0
     # via
     #   great-expectations
     #   ipykernel
+    #   ipympl
     #   ipywidgets
     #   jupyter-console
 ipython-genutils==0.2.0
     # via
+    #   ipympl
     #   nbclassic
     #   notebook
     #   qtconsole
-ipywidgets==8.0.2
+ipyvolume==0.6.0
+    # via vaex-jupyter
+ipyvue==1.8.0
     # via
+    #   ipyvolume
+    #   ipyvuetify
+ipyvuetify==1.8.4
+    # via
+    #   ipyvolume
+    #   vaex-jupyter
+ipywebrtc==0.6.0
+    # via ipyvolume
+ipywidgets==8.0.4
+    # via
+    #   bqplot
     #   great-expectations
+    #   ipydatawidgets
+    #   ipyleaflet
+    #   ipympl
+    #   ipyvolume
+    #   ipyvue
     #   jupyter
+    #   pythreejs
+isoduration==20.11.0
+    # via jsonschema
+itsdangerous==2.1.2
+    # via flask
 jaraco-classes==3.2.3
     # via keyring
 jedi==0.18.1
     # via ipython
+jeepney==0.8.0
+    # via
+    #   keyring
+    #   secretstorage
 jinja2==3.1.2
     # via
     #   altair
+    #   branca
     #   cookiecutter
+    #   distributed
+    #   flask
     #   great-expectations
     #   jinja2-time
     #   jupyter-server
+    #   mlflow
     #   nbclassic
     #   nbconvert
     #   notebook
-    #   pandas-profiling
     #   sphinx
     #   sphinx-autoapi
+    #   vaex-ml
+    #   ydata-profiling
 jinja2-time==0.2.0
     # via cookiecutter
 jmespath==1.0.1
@@ -292,16 +406,19 @@ joblib==1.2.0
 jsonpatch==1.32
     # via great-expectations
 jsonpointer==2.3
-    # via jsonpatch
-jsonschema==4.7.2
+    # via
+    #   jsonpatch
+    #   jsonschema
+jsonschema[format-nongpl]==4.17.3
     # via
     #   altair
     #   great-expectations
+    #   jupyter-events
     #   nbformat
     #   ray
 jupyter==1.0.0
     # via -r doc-requirements.in
-jupyter-client==7.4.6
+jupyter-client==8.0.2
     # via
     #   ipykernel
     #   jupyter-console
@@ -312,7 +429,7 @@ jupyter-client==7.4.6
     #   qtconsole
 jupyter-console==6.4.4
     # via jupyter
-jupyter-core==5.0.0
+jupyter-core==5.2.0
     # via
     #   jupyter-client
     #   jupyter-server
@@ -321,39 +438,57 @@ jupyter-core==5.0.0
     #   nbformat
     #   notebook
     #   qtconsole
-jupyter-server==1.23.2
+jupyter-events==0.6.3
+    # via jupyter-server
+jupyter-server==2.2.0
     # via
     #   nbclassic
     #   notebook-shim
+jupyter-server-terminals==0.4.4
+    # via jupyter-server
 jupyterlab-pygments==0.2.2
     # via nbconvert
-jupyterlab-widgets==3.0.3
+jupyterlab-widgets==3.0.5
     # via ipywidgets
-keras==2.9.0
+keras==2.8.0
     # via tensorflow
 keras-preprocessing==1.1.2
     # via tensorflow
-keyring==23.11.0
+keyring==23.13.1
     # via flytekit
 kiwisolver==1.4.4
     # via matplotlib
 kubernetes==25.3.0
-    # via -r doc-requirements.in
-lazy-object-proxy==1.8.0
+    # via
+    #   -r doc-requirements.in
+    #   flytekit
+lazy-object-proxy==1.9.0
     # via astroid
-libclang==14.0.6
+libclang==15.0.6.1
     # via tensorflow
-lxml==4.9.1
+llvmlite==0.39.1
+    # via numba
+locket==1.0.0
+    # via
+    #   distributed
+    #   partd
+lxml==4.9.2
     # via sphinx-material
 makefun==1.15.0
     # via great-expectations
+mako==1.2.4
+    # via alembic
 markdown==3.4.1
     # via
     #   -r doc-requirements.in
+    #   mlflow
     #   tensorboard
-markupsafe==2.1.1
+markdown-it-py==2.1.0
+    # via rich
+markupsafe==2.1.2
     # via
     #   jinja2
+    #   mako
     #   nbconvert
     #   werkzeug
 marshmallow==3.19.0
@@ -366,49 +501,56 @@ marshmallow-enum==1.5.1
     # via dataclasses-json
 marshmallow-jsonschema==0.13.0
     # via flytekit
-matplotlib==3.5.3
+matplotlib==3.6.3
     # via
-    #   missingno
-    #   pandas-profiling
+    #   ipympl
+    #   ipyvolume
+    #   mlflow
     #   phik
     #   seaborn
+    #   vaex-viz
+    #   ydata-profiling
 matplotlib-inline==0.1.6
     # via
     #   ipykernel
     #   ipython
-missingno==0.5.1
-    # via pandas-profiling
+mdurl==0.1.2
+    # via markdown-it-py
 mistune==2.0.4
     # via
     #   great-expectations
     #   nbconvert
-modin==0.17.0
+mlflow==2.1.1
+    # via -r doc-requirements.in
+modin==0.18.1
     # via -r doc-requirements.in
 more-itertools==9.0.0
     # via jaraco-classes
 msgpack==1.0.4
-    # via ray
-multimethod==1.9
     # via
-    #   pandas-profiling
+    #   distributed
+    #   ray
+multimethod==1.9.1
+    # via
     #   visions
+    #   ydata-profiling
 mypy-extensions==0.4.3
     # via typing-inspect
 natsort==8.2.0
     # via flytekit
-nbclassic==0.4.8
+nbclassic==0.5.1
     # via notebook
 nbclient==0.7.0
     # via
     #   nbconvert
     #   papermill
-nbconvert==7.2.5
+nbconvert==7.2.9
     # via
     #   jupyter
     #   jupyter-server
     #   nbclassic
     #   notebook
-nbformat==5.7.0
+nbformat==5.7.3
     # via
     #   great-expectations
     #   jupyter-server
@@ -420,11 +562,11 @@ nbformat==5.7.0
 nest-asyncio==1.5.6
     # via
     #   ipykernel
-    #   jupyter-client
     #   nbclassic
     #   nbclient
     #   notebook
-networkx==2.8.8
+    #   vaex-core
+networkx==3.0
     # via visions
 nodeenv==1.7.0
     # via pre-commit
@@ -434,38 +576,67 @@ notebook==6.5.2
     #   jupyter
 notebook-shim==0.2.2
     # via nbclassic
-numpy==1.23.4
+numba==0.56.4
+    # via
+    #   shap
+    #   vaex-ml
+numpy==1.23.5
     # via
     #   altair
+    #   astropy
+    #   bqplot
+    #   contourpy
+    #   flytekit
     #   great-expectations
     #   h5py
     #   imagehash
+    #   ipydatawidgets
+    #   ipympl
+    #   ipyvolume
     #   keras-preprocessing
     #   matplotlib
-    #   missingno
+    #   mlflow
     #   modin
+    #   numba
     #   onnx
     #   onnxconverter-common
     #   opt-einsum
     #   pandas
-    #   pandas-profiling
     #   pandera
     #   patsy
     #   phik
     #   pyarrow
+    #   pyerfa
+    #   pythreejs
     #   pywavelets
     #   ray
     #   scikit-learn
     #   scipy
     #   seaborn
+    #   shap
     #   skl2onnx
     #   statsmodels
     #   tensorboard
     #   tensorflow
     #   tf2onnx
+    #   vaex-core
     #   visions
+    #   xarray
+    #   ydata-profiling
+nvidia-cublas-cu11==11.10.3.66
+    # via
+    #   nvidia-cudnn-cu11
+    #   torch
+nvidia-cuda-nvrtc-cu11==11.7.99
+    # via torch
+nvidia-cuda-runtime-cu11==11.7.99
+    # via torch
+nvidia-cudnn-cu11==8.5.0.96
+    # via torch
 oauthlib==3.2.2
-    # via requests-oauthlib
+    # via
+    #   databricks-cli
+    #   requests-oauthlib
 onnx==1.12.0
     # via
     #   onnxconverter-common
@@ -475,8 +646,11 @@ onnxconverter-common==1.13.0
     # via skl2onnx
 opt-einsum==3.3.0
     # via tensorflow
-packaging==21.3
+packaging==22.0
     # via
+    #   astropy
+    #   dask
+    #   distributed
     #   docker
     #   google-cloud-bigquery
     #   great-expectations
@@ -484,28 +658,35 @@ packaging==21.3
     #   jupyter-server
     #   marshmallow
     #   matplotlib
+    #   mlflow
     #   modin
     #   nbconvert
     #   onnxconverter-common
     #   pandera
     #   qtpy
+    #   shap
     #   sphinx
     #   statsmodels
-    #   tensorflow
-pandas==1.5.1
+    #   xarray
+pandas==1.5.3
     # via
     #   altair
+    #   bqplot
     #   dolt-integrations
     #   flytekit
     #   great-expectations
+    #   mlflow
     #   modin
-    #   pandas-profiling
     #   pandera
     #   phik
     #   seaborn
+    #   shap
     #   statsmodels
+    #   vaex-core
     #   visions
-pandas-profiling==3.4.0
+    #   xarray
+    #   ydata-profiling
+pandas-profiling==3.6.6
     # via -r doc-requirements.in
 pandera==0.13.4
     # via -r doc-requirements.in
@@ -515,28 +696,35 @@ papermill==2.4.0
     # via -r doc-requirements.in
 parso==0.8.3
     # via jedi
+partd==1.3.0
+    # via dask
 patsy==0.5.3
     # via statsmodels
 pexpect==4.8.0
     # via ipython
-phik==0.12.2
-    # via pandas-profiling
+phik==0.12.3
+    # via ydata-profiling
 pickleshare==0.7.5
     # via ipython
-pillow==9.3.0
+pillow==9.4.0
     # via
     #   imagehash
+    #   ipympl
+    #   ipyvolume
     #   matplotlib
+    #   vaex-viz
     #   visions
-platformdirs==2.5.4
+platformdirs==2.6.2
     # via
     #   jupyter-core
     #   virtualenv
-plotly==5.11.0
+plotly==5.13.0
     # via -r doc-requirements.in
-pre-commit==2.20.0
+pre-commit==3.0.2
     # via sphinx-tags
-prometheus-client==0.15.0
+progressbar2==4.2.0
+    # via vaex-core
+prometheus-client==0.16.0
     # via
     #   jupyter-server
     #   nbclassic
@@ -545,19 +733,17 @@ prompt-toolkit==3.0.32
     # via
     #   ipython
     #   jupyter-console
-proto-plus==1.22.1
-    # via
-    #   google-cloud-bigquery
-    #   google-cloud-bigquery-storage
+proto-plus==1.22.2
+    # via google-cloud-bigquery
 protobuf==3.19.6
     # via
     #   flyteidl
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
-    #   google-cloud-bigquery-storage
     #   googleapis-common-protos
     #   grpcio-status
+    #   mlflow
     #   onnx
     #   onnxconverter-common
     #   proto-plus
@@ -571,6 +757,7 @@ protoc-gen-swagger==0.1.0
     # via flyteidl
 psutil==5.9.4
     # via
+    #   distributed
     #   ipykernel
     #   modin
 ptyprocess==0.7.0
@@ -586,7 +773,8 @@ py4j==0.10.9.5
 pyarrow==6.0.1
     # via
     #   flytekit
-    #   google-cloud-bigquery
+    #   mlflow
+    #   vaex-core
 pyasn1==0.4.8
     # via
     #   pyasn1-modules
@@ -595,27 +783,34 @@ pyasn1-modules==0.2.8
     # via google-auth
 pycparser==2.21
     # via cffi
-pydantic==1.10.2
+pydantic==1.10.4
     # via
-    #   pandas-profiling
+    #   fastapi
+    #   great-expectations
     #   pandera
-pygments==2.13.0
+    #   vaex-core
+    #   ydata-profiling
+pyerfa==2.0.0.1
+    # via astropy
+pygments==2.14.0
     # via
     #   furo
     #   ipython
     #   jupyter-console
     #   nbconvert
     #   qtconsole
+    #   rich
     #   sphinx
     #   sphinx-prompt
-pyopenssl==22.1.0
+pyjwt==2.6.0
+    # via databricks-cli
+pyopenssl==23.0.0
     # via flytekit
 pyparsing==3.0.9
     # via
     #   great-expectations
     #   matplotlib
-    #   packaging
-pyrsistent==0.19.2
+pyrsistent==0.19.3
     # via jsonschema
 pyspark==3.3.1
     # via -r doc-requirements.in
@@ -632,19 +827,28 @@ python-dateutil==2.8.2
     #   matplotlib
     #   pandas
     #   whylabs-client
+python-dotenv==0.21.1
+    # via uvicorn
 python-json-logger==2.0.4
-    # via flytekit
-python-slugify[unidecode]==6.1.2
+    # via
+    #   flytekit
+    #   jupyter-events
+python-slugify[unidecode]==8.0.0
     # via
     #   cookiecutter
     #   sphinx-material
+python-utils==3.4.5
+    # via progressbar2
+pythreejs==2.4.1
+    # via ipyvolume
 pytimeparse==1.1.8
     # via flytekit
-pytz==2022.6
+pytz==2022.7.1
     # via
     #   babel
     #   flytekit
     #   great-expectations
+    #   mlflow
     #   pandas
 pytz-deprecation-shim==0.1.0.post0
     # via tzlocal
@@ -652,15 +856,22 @@ pywavelets==1.4.1
     # via imagehash
 pyyaml==6.0
     # via
+    #   astropy
     #   cookiecutter
+    #   dask
+    #   distributed
     #   flytekit
+    #   jupyter-events
     #   kubernetes
-    #   pandas-profiling
+    #   mlflow
     #   papermill
     #   pre-commit
     #   ray
     #   sphinx-autoapi
-pyzmq==24.0.1
+    #   uvicorn
+    #   vaex-core
+    #   ydata-profiling
+pyzmq==25.0.0
     # via
     #   ipykernel
     #   jupyter-client
@@ -672,20 +883,24 @@ qtconsole==5.4.0
     # via jupyter
 qtpy==2.3.0
     # via qtconsole
-ray==2.1.0
+querystring-parser==1.2.4
+    # via mlflow
+ray==2.2.0
     # via -r doc-requirements.in
 regex==2022.10.31
     # via docker-image-py
-requests==2.28.1
+requests==2.28.2
     # via
     #   cookiecutter
+    #   databricks-cli
     #   docker
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
     #   great-expectations
+    #   ipyvolume
     #   kubernetes
-    #   pandas-profiling
+    #   mlflow
     #   papermill
     #   ray
     #   requests-oauthlib
@@ -693,6 +908,8 @@ requests==2.28.1
     #   sphinx
     #   tensorboard
     #   tf2onnx
+    #   vaex-core
+    #   ydata-profiling
 requests-oauthlib==1.3.1
     # via
     #   google-auth-oauthlib
@@ -701,6 +918,16 @@ responses==0.22.0
     # via flytekit
 retry==0.9.2
     # via flytekit
+rfc3339-validator==0.1.4
+    # via
+    #   jsonschema
+    #   jupyter-events
+rfc3986-validator==0.1.1
+    # via
+    #   jsonschema
+    #   jupyter-events
+rich==13.3.1
+    # via vaex-core
 rsa==4.9
     # via google-auth
 ruamel-yaml==0.17.17
@@ -710,31 +937,37 @@ ruamel-yaml-clib==0.2.7
 scikit-learn==1.1.1
     # via
     #   -r doc-requirements.in
+    #   mlflow
+    #   shap
     #   skl2onnx
 scipy==1.9.3
     # via
     #   great-expectations
     #   imagehash
-    #   missingno
-    #   pandas-profiling
+    #   mlflow
     #   phik
     #   scikit-learn
+    #   shap
     #   skl2onnx
     #   statsmodels
-seaborn==0.12.1
-    # via
-    #   missingno
-    #   pandas-profiling
+    #   ydata-profiling
+seaborn==0.12.2
+    # via ydata-profiling
+secretstorage==3.3.3
+    # via keyring
 send2trash==1.8.0
     # via
     #   jupyter-server
     #   nbclassic
     #   notebook
+shap==0.41.0
+    # via mlflow
 six==1.16.0
     # via
     #   asttokens
     #   astunparse
     #   bleach
+    #   databricks-cli
     #   google-auth
     #   google-pasta
     #   grpcio
@@ -742,17 +975,26 @@ six==1.16.0
     #   kubernetes
     #   patsy
     #   python-dateutil
+    #   querystring-parser
+    #   rfc3339-validator
     #   sphinx-code-include
     #   tensorflow
     #   tf2onnx
+    #   vaex-core
 skl2onnx==1.13
     # via -r doc-requirements.in
+slicer==0.0.7
+    # via shap
+smmap==5.0.0
+    # via gitdb
 sniffio==1.3.0
     # via anyio
 snowballstemmer==2.2.0
     # via sphinx
 sortedcontainers==2.4.0
-    # via flytekit
+    # via
+    #   distributed
+    #   flytekit
 soupsieve==2.3.2.post1
     # via beautifulsoup4
 sphinx==4.5.0
@@ -761,6 +1003,7 @@ sphinx==4.5.0
     #   furo
     #   sphinx-autoapi
     #   sphinx-basic-ng
+    #   sphinx-click
     #   sphinx-code-include
     #   sphinx-copybutton
     #   sphinx-fontawesome
@@ -770,10 +1013,12 @@ sphinx==4.5.0
     #   sphinx-prompt
     #   sphinx-tags
     #   sphinxcontrib-yt
-sphinx-autoapi==2.0.0
+sphinx-autoapi==2.0.1
     # via -r doc-requirements.in
 sphinx-basic-ng==1.0.0b1
     # via furo
+sphinx-click==4.4.0
+    # via -r doc-requirements.in
 sphinx-code-include==1.1.1
     # via -r doc-requirements.in
 sphinx-copybutton==0.5.1
@@ -788,13 +1033,13 @@ sphinx-panels==0.6.0
     # via -r doc-requirements.in
 sphinx-prompt==1.5.0
     # via -r doc-requirements.in
-sphinx-tags==0.1.6
+sphinx-tags==0.2.0
     # via -r doc-requirements.in
-sphinxcontrib-applehelp==1.0.2
+sphinxcontrib-applehelp==1.0.4
     # via sphinx
 sphinxcontrib-devhelp==1.0.2
     # via sphinx
-sphinxcontrib-htmlhelp==2.0.0
+sphinxcontrib-htmlhelp==2.0.1
     # via sphinx
 sphinxcontrib-jsmath==1.0.1
     # via sphinx
@@ -804,39 +1049,51 @@ sphinxcontrib-serializinghtml==1.1.5
     # via sphinx
 sphinxcontrib-yt==0.2.2
     # via -r doc-requirements.in
-sqlalchemy==1.4.44
-    # via -r doc-requirements.in
-stack-data==0.6.1
+sqlalchemy==1.4.46
+    # via
+    #   -r doc-requirements.in
+    #   alembic
+    #   mlflow
+sqlparse==0.4.3
+    # via mlflow
+stack-data==0.6.2
     # via ipython
+starlette==0.22.0
+    # via fastapi
 statsd==3.3.0
     # via flytekit
 statsmodels==0.13.5
-    # via pandas-profiling
+    # via ydata-profiling
+tabulate==0.9.0
+    # via
+    #   databricks-cli
+    #   vaex-core
 tangled-up-in-unicode==0.2.0
     # via visions
+tblib==1.7.0
+    # via distributed
 tenacity==8.1.0
     # via
     #   papermill
     #   plotly
-tensorboard==2.9.1
+tensorboard==2.8.0
     # via tensorflow
 tensorboard-data-server==0.6.1
     # via tensorboard
 tensorboard-plugin-wit==1.8.1
     # via tensorboard
-tensorflow==2.9.0
+tensorflow==2.8.1
     # via -r doc-requirements.in
-tensorflow-estimator==2.9.0
+tensorflow-estimator==2.8.0
     # via tensorflow
-tensorflow-io-gcs-filesystem==0.27.0
+tensorflow-io-gcs-filesystem==0.30.0
     # via tensorflow
-termcolor==2.0.1
-    # via
-    #   great-expectations
-    #   tensorflow
-terminado==0.17.0
+termcolor==2.2.0
+    # via tensorflow
+terminado==0.17.1
     # via
     #   jupyter-server
+    #   jupyter-server-terminals
     #   nbclassic
     #   notebook
 text-unidecode==1.3
@@ -850,33 +1107,43 @@ threadpoolctl==3.1.0
 tinycss2==1.2.1
     # via nbconvert
 toml==0.10.2
-    # via
-    #   pre-commit
-    #   responses
+    # via responses
 toolz==0.12.0
-    # via altair
-torch==1.13.0
+    # via
+    #   altair
+    #   dask
+    #   distributed
+    #   partd
+torch==1.13.1
     # via -r doc-requirements.in
 tornado==6.2
     # via
+    #   distributed
     #   ipykernel
     #   jupyter-client
     #   jupyter-server
     #   nbclassic
     #   notebook
     #   terminado
+    #   vaex-server
 tqdm==4.64.1
     # via
     #   great-expectations
-    #   pandas-profiling
     #   papermill
-traitlets==5.5.0
+    #   shap
+    #   ydata-profiling
+traitlets==5.9.0
     # via
+    #   bqplot
+    #   comm
     #   ipykernel
+    #   ipympl
     #   ipython
+    #   ipyvolume
     #   ipywidgets
     #   jupyter-client
     #   jupyter-core
+    #   jupyter-events
     #   jupyter-server
     #   matplotlib-inline
     #   nbclassic
@@ -884,7 +1151,18 @@ traitlets==5.5.0
     #   nbconvert
     #   nbformat
     #   notebook
+    #   pythreejs
     #   qtconsole
+    #   traittypes
+    #   vaex-ml
+traittypes==0.2.1
+    # via
+    #   bqplot
+    #   ipydatawidgets
+    #   ipyleaflet
+    #   ipyvolume
+typeguard==2.13.3
+    # via ydata-profiling
 types-toml==0.10.8.1
     # via responses
 typing-extensions==4.4.0
@@ -894,6 +1172,7 @@ typing-extensions==4.4.0
     #   great-expectations
     #   onnx
     #   pydantic
+    #   starlette
     #   tensorflow
     #   torch
     #   typing-inspect
@@ -910,9 +1189,12 @@ unidecode==1.3.6
     # via
     #   python-slugify
     #   sphinx-autoapi
-urllib3==1.26.12
+uri-template==1.2.0
+    # via jsonschema
+urllib3==1.26.14
     # via
     #   botocore
+    #   distributed
     #   docker
     #   flytekit
     #   great-expectations
@@ -920,37 +1202,76 @@ urllib3==1.26.12
     #   requests
     #   responses
     #   whylabs-client
-virtualenv==20.16.7
+uvicorn[standard]==0.20.0
+    # via vaex-server
+uvloop==0.17.0
+    # via uvicorn
+vaex==4.16.0
+    # via -r doc-requirements.in
+vaex-astro==0.9.3
+    # via vaex
+vaex-core==4.16.1
+    # via
+    #   vaex
+    #   vaex-astro
+    #   vaex-hdf5
+    #   vaex-jupyter
+    #   vaex-ml
+    #   vaex-server
+    #   vaex-viz
+vaex-hdf5==0.14.1
+    # via vaex
+vaex-jupyter==0.8.1
+    # via vaex
+vaex-ml==0.18.1
+    # via vaex
+vaex-server==0.8.1
+    # via vaex
+vaex-viz==0.5.4
+    # via
+    #   vaex
+    #   vaex-jupyter
+virtualenv==20.17.1
     # via
     #   pre-commit
     #   ray
 visions[type_image_path]==0.7.5
-    # via pandas-profiling
-wcwidth==0.2.5
+    # via ydata-profiling
+watchfiles==0.18.1
+    # via uvicorn
+wcwidth==0.2.6
     # via prompt-toolkit
+webcolors==1.12
+    # via jsonschema
 webencodings==0.5.1
     # via
     #   bleach
     #   tinycss2
-websocket-client==1.4.2
+websocket-client==1.5.0
     # via
     #   docker
     #   jupyter-server
     #   kubernetes
+websockets==10.4
+    # via uvicorn
 werkzeug==2.2.2
-    # via tensorboard
+    # via
+    #   flask
+    #   tensorboard
 wheel==0.38.4
     # via
     #   astunparse
     #   flytekit
+    #   nvidia-cublas-cu11
+    #   nvidia-cuda-runtime-cu11
     #   tensorboard
-whylabs-client==0.4.0
+whylabs-client==0.4.3
     # via -r doc-requirements.in
-whylogs==1.1.13
+whylogs==1.1.24
     # via -r doc-requirements.in
 whylogs-sketching==3.4.1.dev3
     # via whylogs
-widgetsnbextension==4.0.3
+widgetsnbextension==4.0.5
     # via ipywidgets
 wrapt==1.14.1
     # via
@@ -959,7 +1280,15 @@ wrapt==1.14.1
     #   flytekit
     #   pandera
     #   tensorflow
-zipp==3.10.0
+xarray==2023.1.0
+    # via vaex-jupyter
+xyzservices==2022.9.0
+    # via ipyleaflet
+ydata-profiling==4.0.0
+    # via pandas-profiling
+zict==2.2.0
+    # via distributed
+zipp==3.12.0
     # via importlib-metadata
 
 # The following packages are considered to be unsafe in a requirements file:
diff --git a/docs/Makefile b/docs/Makefile
index e61723ad76..afa73807cb 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -18,3 +18,7 @@ help:
 # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+
+clean:
+	rm -rf ./build ./source/generated
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000000..1fb1b91359
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,7 @@
+# TODO: Remove after buf migration is done and packages updated, see doc-requirements.in
+# skl2onnx and tf2onnx added here so that the plugin API reference is rendered,
+# with the caveat that the docs build environment has the environment variable
+# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set so that protobuf can be parsed
+# using Python, which is acceptable for docs building.
+skl2onnx
+tf2onnx
diff --git a/docs/source/clients.rst b/docs/source/clients.rst
new file mode 100644
index 0000000000..f67ebf6a3a
--- /dev/null
+++ b/docs/source/clients.rst
@@ -0,0 +1,4 @@
+.. automodule:: flytekit.clients
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6aba967ae9..1745e56efe 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -28,8 +28,6 @@
 sys.path.insert(0, flytekit_src_dir)
 sys.path.insert(0, flytekit_dir)
 
-print(sys.path)
-
 # -- Project information -----------------------------------------------------
 
 project = "Flytekit"
@@ -62,6 +60,7 @@
     "sphinx_panels",
     "sphinxcontrib.yt",
     "sphinx_tags",
+    "sphinx_click",
 ]
 
 # build the templated autosummary files
diff --git a/docs/source/extras.tensorflow.rst b/docs/source/extras.tensorflow.rst
new file mode 100644
index 0000000000..699dd44da0
--- /dev/null
+++ b/docs/source/extras.tensorflow.rst
@@ -0,0 +1,7 @@
+###############
+TensorFlow Type
+###############
+.. automodule:: flytekit.extras.tensorflow
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 13630e9e58..db5902391b 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -76,6 +76,7 @@ Expected output:
    flytekit
    configuration
    remote
+   clients
    testing
    extend
    deck
@@ -83,4 +84,5 @@ Expected output:
    tasks.extend
    types.extend
    data.extend
+   pyflyte
    contributing
diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst
index dd1f59238d..008f2b4bbe 100644
--- a/docs/source/plugins/index.rst
+++ b/docs/source/plugins/index.rst
@@ -26,8 +26,10 @@ Plugin API reference
 * :ref:`ONNX PyTorch ` - ONNX PyTorch API reference
 * :ref:`ONNX TensorFlow ` - ONNX TensorFlow API reference
 * :ref:`ONNX ScikitLearn ` - ONNX ScikitLearn API reference
-* :ref:`Ray ` - Ray
+* :ref:`Ray ` - Ray API reference
 * :ref:`DBT ` - DBT API reference
+* :ref:`Vaex ` - Vaex API reference
+* :ref:`MLflow ` - MLflow API reference
 
 .. toctree::
    :maxdepth: 2
@@ -57,3 +59,5 @@ Plugin API reference
    ONNX ScikitLearn
    Ray
    DBT
+   Vaex
+   MLflow
diff --git a/docs/source/plugins/mlflow.rst b/docs/source/plugins/mlflow.rst
new file mode 100644
index 0000000000..60d1a7c66b
--- /dev/null
+++ b/docs/source/plugins/mlflow.rst
@@ -0,0 +1,9 @@
+.. _mlflow:
+
+###################################################
+MLflow API reference
+###################################################
+
+.. tags:: Integration, MachineLearning, Tracking
+
+.. automodule:: flytekitplugins.mlflow
diff --git a/docs/source/pyflyte.rst b/docs/source/pyflyte.rst
new file mode 100644
index 0000000000..cbbf657bc3
--- /dev/null
+++ b/docs/source/pyflyte.rst
@@ -0,0 +1,27 @@
+###########
+Pyflyte CLI
+###########
+
+.. click:: flytekit.clis.sdk_in_container.init:init
+   :prog: pyflyte init
+   :nested: full
+
+.. click:: flytekit.clis.sdk_in_container.local_cache:local_cache
+   :prog: pyflyte local-cache
+   :nested: full
+
+.. click:: flytekit.clis.sdk_in_container.package:package
+   :prog: pyflyte package
+   :nested: full
+
+.. click:: flytekit.clis.sdk_in_container.register:register
+   :prog: pyflyte register
+   :nested: full
+
+.. click:: flytekit.clis.sdk_in_container.run:run
+   :prog: pyflyte run
+   :nested: none
+
+.. click:: flytekit.clis.sdk_in_container.serialize:serialize
+   :prog: pyflyte serialize
+   :nested: full
diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst
index b7382c5993..db1cb8dfff 100644
--- a/docs/source/types.extend.rst
+++ b/docs/source/types.extend.rst
@@ -15,3 +15,5 @@ Refer to the :ref:`extensibility contribution guide str: """
diff --git a/flytekit/clients/__init__.py b/flytekit/clients/__init__.py
index e69de29bb2..1b08e1c567 100644
--- a/flytekit/clients/__init__.py
+++ b/flytekit/clients/__init__.py
@@ -0,0 +1,19 @@
+"""
+=====================
+Clients
+=====================
+
+.. currentmodule:: flytekit.clients
+
+This module provides lower level access to a Flyte backend.
+
+.. _clients_module:
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+   :nosignatures:
+
+   ~friendly.SynchronousFlyteClient
+   ~raw.RawSynchronousFlyteClient
+"""
diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index bf4c1e2d0c..6c8f54e9ce 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -10,11 +10,13 @@
 import grpc
 import requests as _requests
 from flyteidl.admin.project_pb2 import ProjectListRequest
+from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
 from flyteidl.service import admin_pb2_grpc as _admin_service
 from flyteidl.service import auth_pb2
 from flyteidl.service import auth_pb2_grpc as auth_service
 from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
 from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service
+from flyteidl.service import signal_pb2_grpc as signal_service
 from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub
 from google.protobuf.json_format import MessageToJson as _MessageToJson
 
@@ -79,7 +81,6 @@ def handler(self, create_request):
             cli_logger.error(_MessageToJson(create_request))
             cli_logger.error("Details returned from the flyte admin: ")
             cli_logger.error(e.details)
-            e.details += "create_request: " + _MessageToJson(create_request)
             # Re-raise since we're not handling the error here and add the create_request details
             raise e
 
@@ -146,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
         )
         self._stub = _admin_service.AdminServiceStub(self._channel)
         self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
+        self._signal = signal_service.SignalServiceStub(self._channel)
         try:
             resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
             self._public_client_config = resp
@@ -260,7 +262,6 @@ def _refresh_credentials_from_command(self):
        :param self: RawSynchronousFlyteClient
        :return:
        """
-
        command = self._cfg.command
        if not command:
            raise FlyteAuthenticationException("No command specified in configuration for command authentication")
@@ -408,6 +409,20 @@ def get_task(self, get_object_request):
         """
         return self._stub.GetTask(get_object_request, metadata=self._metadata)
 
+    @_handle_rpc_error(retry=True)
+    def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse:
+        """
+        This sets a signal
+        """
+        return self._signal.SetSignal(signal_set_request, metadata=self._metadata)
+
+    @_handle_rpc_error(retry=True)
+    def list_signals(self, signal_list_request: SignalListRequest) -> SignalList:
+        """
+        This lists signals
+        """
+        return self._signal.ListSignals(signal_list_request, metadata=self._metadata)
+
     ####################################################################################################################
     #
     #  Workflow Endpoints
diff --git a/flytekit/clis/sdk_in_container/backfill.py b/flytekit/clis/sdk_in_container/backfill.py
new file mode 100644
index 0000000000..80a799b600
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/backfill.py
@@ -0,0 +1,178 @@
+import typing
+from datetime import datetime, timedelta
+
+import click
+
+from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
+from flytekit.clis.sdk_in_container.run import DateTimeType, DurationParamType
+
+_backfill_help = """
+The backfill command generates and registers a new workflow based on the input launchplan to run an
+automated backfill. The workflow can be managed using the Flyte UI and can be canceled, relaunched, and recovered.
+
+- launchplan refers to the name of the launchplan
+- launchplan_version is optional and should be a valid version of the launchplan.
+"""
+
+
+def resolve_backfill_window(
+    from_date: datetime = None,
+    to_date: datetime = None,
+    backfill_window: timedelta = None,
+) -> typing.Tuple[datetime, datetime]:
+    """
+    Resolves the from_date -> to_date
+    """
+    if from_date and to_date and backfill_window:
+        raise click.BadParameter("Setting from-date, to-date and backfill_window at the same time is not allowed.")
+    if not (from_date or to_date):
+        raise click.BadParameter(
+            "One of the following pairs is required -> (from-date, to-date) | (from-date, backfill_window) |"
+            " (to-date, backfill_window)"
+        )
+    if from_date and to_date:
+        pass
+    elif not backfill_window:
+        raise click.BadParameter("Either from-date or to-date is needed along with backfill-window")
+    elif from_date:
+        to_date = from_date + backfill_window
+    else:
+        from_date = to_date - backfill_window
+    return from_date, to_date
+
+
+@click.command("backfill", help=_backfill_help)
+@click.option(
+    "-p",
+    "--project",
+    required=False,
+    type=str,
+    default="flytesnacks",
+    help="Project to register and run this workflow in",
+)
+@click.option(
+    "-d",
+    "--domain",
+    required=False,
+    type=str,
+    default="development",
+    help="Domain to register and run this workflow in",
+)
+@click.option(
+    "-v",
+    "--version",
+    required=False,
+    type=str,
+    default=None,
+    help="Version for the registered workflow. If not specified it is auto-derived using the start and end date",
+)
+@click.option(
+    "-n",
+    "--execution-name",
+    required=False,
+    type=str,
+    default=None,
+    help="Create a named execution for the backfill. This can prevent launching multiple executions.",
+)
+@click.option(
+    "--dry-run",
+    required=False,
+    type=bool,
+    is_flag=True,
+    default=False,
+    show_default=True,
+    help="Just generate the workflow - do not register or execute",
+)
+@click.option(
+    "--parallel/--serial",
+    required=False,
+    type=bool,
+    is_flag=True,
+    default=False,
+    show_default=True,
+    help="All backfill steps can be run in parallel (limited by max-parallelism), if using --parallel."
+    " Else all steps will be run sequentially [--serial].",
+)
+@click.option(
+    "--execute/--do-not-execute",
+    required=False,
+    type=bool,
+    is_flag=True,
+    default=True,
+    show_default=True,
+    help="Generate the workflow and register it; use --do-not-execute to skip launching the execution",
+)
+@click.option(
+    "--from-date",
+    required=False,
+    type=DateTimeType(),
+    default=None,
+    help="Date from which the backfill should begin. Start date is inclusive.",
+)
+@click.option(
+    "--to-date",
+    required=False,
+    type=DateTimeType(),
+    default=None,
+    help="Date until which the backfill should run. End date is inclusive",
+)
+@click.option(
+    "--backfill-window",
+    required=False,
+    type=DurationParamType(),
+    default=None,
+    help="Timedelta of days, minutes, or hours after the from-date or before the to-date to compute the "
+    "backfills between. This is needed with either from-date or to-date. Optional if both from-date and to-date are "
+    "provided",
+)
+@click.argument(
+    "launchplan",
+    required=True,
+    type=str,
+)
+@click.argument(
+    "launchplan-version",
+    required=False,
+    type=str,
+    default=None,
+)
+@click.pass_context
+def backfill(
+    ctx: click.Context,
+    project: str,
+    domain: str,
+    from_date: datetime,
+    to_date: datetime,
+    backfill_window: timedelta,
+    launchplan: str,
+    launchplan_version: str,
+    dry_run: bool,
+    execute: bool,
+    parallel: bool,
+    execution_name: str,
+    version: str,
+):
+    from_date, to_date = resolve_backfill_window(from_date, to_date, backfill_window)
+    remote = get_and_save_remote_with_click_context(ctx, project, domain)
+    try:
+        entity = remote.launch_backfill(
+            project=project,
+            domain=domain,
+            from_date=from_date,
+            to_date=to_date,
+            launchplan=launchplan,
+            launchplan_version=launchplan_version,
+            execution_name=execution_name,
+            version=version,
+            dry_run=dry_run,
+            execute=execute,
+            parallel=parallel,
+        )
+        if entity:
+            console_url = remote.generate_console_url(entity)
+            if execute:
+                click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green")
+                return
+            click.secho(f"\n Workflow registered at {console_url}", fg="green")
+    except StopIteration as e:
+        click.secho(f"{e.value}", fg="red")
diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py
index 72246bcba4..6ac451be92 100644
--- a/flytekit/clis/sdk_in_container/helpers.py
+++ b/flytekit/clis/sdk_in_container/helpers.py
@@ -4,7 +4,7 @@
 import click
 
 from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
-from flytekit.configuration import Config, ImageConfig
+from flytekit.configuration import Config, ImageConfig, get_config_file
 from flytekit.loggers import cli_logger
 from flytekit.remote.remote import FlyteRemote
 
@@ -25,10 +25,15 @@ def get_and_save_remote_with_click_context(
     :return: FlyteRemote instance
     """
     cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE)
-    cfg_obj = Config.auto(cfg_file_location)
-    cli_logger.info(
-        f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
-    )
+    cfg_file = get_config_file(cfg_file_location)
+    if cfg_file is None:
+        cfg_obj = Config.for_sandbox()
+        cli_logger.info("No config files found, creating remote with sandbox config")
+    else:
+        cfg_obj = Config.auto(cfg_file_location)
+        cli_logger.info(
+            f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
+        )
     r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain)
     if save:
         ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py
index 76777c5663..1f843450ed 100644
--- a/flytekit/clis/sdk_in_container/pyflyte.py
+++ b/flytekit/clis/sdk_in_container/pyflyte.py
@@ -1,6 +1,7 @@
 import click
 
 from flytekit import configuration
+from flytekit.clis.sdk_in_container.backfill import backfill
 from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES
 from flytekit.clis.sdk_in_container.init import init
 from flytekit.clis.sdk_in_container.local_cache import local_cache
@@ -70,6 +71,7 @@ def main(ctx, pkgs=None, config=None):
 main.add_command(init)
 main.add_command(run)
 main.add_command(register)
+main.add_command(backfill)
 
 if __name__ == "__main__":
     main()
diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
index 1556e343bf..2a167e9d0e 100644
--- a/flytekit/clis/sdk_in_container/register.py
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -17,7 +17,7 @@
 and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl
 like service account here.
 
-Note: This command runs "fast" register by default. Future work to come to add a non-fast version.
+Note: This command runs "fast" register by default.
 This means that a zip is created from the detected root of the packages given, and uploaded. Just like with
 pyflyte run, tasks registered from this command will download and unzip that code package before running.
@@ -67,7 +67,7 @@
     help="Directory to write the output zip file containing the protobuf definitions",
 )
 @click.option(
-    "-d",
+    "-D",
     "--destination-dir",
     required=False,
     type=str,
@@ -107,6 +107,12 @@
     is_flag=True,
     help="Enables to skip zipping and uploading the package",
 )
+@click.option(
+    "--dry-run",
+    default=False,
+    is_flag=True,
+    help="Execute registration in dry-run mode. Skips actual registration to remote",
+)
 @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
 @click.pass_context
 def register(
@@ -122,6 +128,7 @@ def register(
     deref_symlinks: bool,
     non_fast: bool,
     package_or_module: typing.Tuple[str],
+    dry_run: bool,
 ):
     """
     see help
@@ -156,6 +163,7 @@ def register(
 
     # Create and save FlyteRemote,
     remote = get_and_save_remote_with_click_context(ctx, project, domain)
+    click.secho(f"Registering against {remote.config.platform.endpoint}")
     try:
         repo.register(
             project,
@@ -170,6 +178,7 @@ def register(
             fast=not non_fast,
             package_or_module=package_or_module,
             remote=remote,
+            dry_run=dry_run,
         )
     except Exception as e:
         raise e
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index 0121094805..793c15c911 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -105,12 +105,32 @@ def convert(
             raise click.BadParameter(f"parameter should be a valid file path, {value}")
 
 
+class DateTimeType(click.DateTime):
+
+    _NOW_FMT = "now"
+    _ADDITONAL_FORMATS = [_NOW_FMT]
+
+    def __init__(self):
+        super().__init__()
+        self.formats.extend(self._ADDITONAL_FORMATS)
+
+    def convert(
+        self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
+    ) -> typing.Any:
+        if value in self._ADDITONAL_FORMATS:
+            if value == self._NOW_FMT:
+                return datetime.datetime.now()
+        return super().convert(value, param, ctx)
+
+
 class DurationParamType(click.ParamType):
-    name = "timedelta"
+    name = "[1:24 | :22 | 1 minute | 10 days | ...]"
 
     def convert(
         self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
     ) -> typing.Any:
+        if value is None:
+            raise click.BadParameter("None value cannot be converted to a Duration type.")
         return datetime.timedelta(seconds=parse(value))
 
 
@@ -303,6 +323,9 @@ def convert_to_literal(
         if self._literal_type.simple or self._literal_type.enum_type:
             if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT:
                 if self._python_type == dict:
+                    if type(value) != str:
+                        # The type of default value is dict, so we have to convert it to json string
+                        value = json.dumps(value)
                     o = json.loads(value)
                 elif type(value) != self._python_type:
                     o = cast(DataClassJsonMixin, self._python_type).from_json(value)
@@ -644,7 +667,8 @@ def list_commands(self, ctx):
         return [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"]
 
     def get_command(self, ctx, filename):
-        ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params
+        if ctx.obj:
+            ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params
         return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode")
diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index c9e15c2eba..220f9209ea 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -53,6 +53,7 @@
     ~Image
     ~ImageConfig
     ~SerializationSettings
+    ~FastSerializationSettings
 
 .. _configuration-execution-time-settings:
 
@@ -299,7 +300,7 @@ class PlatformConfig(object):
     :param endpoint: DNS for Flyte backend
     :param insecure: Whether or not to use SSL
     :param insecure_skip_verify: Wether to skip SSL certificate verification
-    :param console_endpoint: endpoint for console if differenet than Flyte backend
+    :param console_endpoint: endpoint for console if different than Flyte backend
     :param command: This command is executed to return a token using an external process.
     :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment.
       More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
@@ -310,7 +311,7 @@ class PlatformConfig(object):
     :param auth_mode: The OAuth mode to use. Defaults to pkce flow.
     """
 
-    endpoint: str = "localhost:30081"
+    endpoint: str = "localhost:30080"
     insecure: bool = False
     insecure_skip_verify: bool = False
     console_endpoint: typing.Optional[str] = None
@@ -462,7 +463,7 @@ class GCSConfig(object):
     gsutil_parallelism: bool = False
 
     @classmethod
-    def auto(self, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
+    def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
         config_file = get_config_file(config_file)
         kwargs = {}
         kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file))
@@ -528,7 +529,7 @@ def with_params(
         )
 
     @classmethod
-    def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> Config:
+    def auto(cls, config_file: typing.Union[str, ConfigFile, None] = None) -> Config:
         """
         Automatically constructs the Config Object. The order of precedence is as follows
           1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format.
@@ -557,9 +558,9 @@ def for_sandbox(cls) -> Config:
         :return: Config
         """
         return Config(
-            platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True),
+            platform=PlatformConfig(endpoint="localhost:30080", auth_mode="Pkce", insecure=True),
             data_config=DataConfig(
-                s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage")
+                s3=S3Config(endpoint="http://localhost:30002", access_key_id="minio", secret_access_key="miniostorage")
             ),
         )
 
@@ -646,6 +647,7 @@ class SerializationSettings(object):
     domain: typing.Optional[str] = None
     version: typing.Optional[str] = None
    env: Optional[Dict[str, str]] = None
+    git_repo: Optional[str] = None
     python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER
     flytekit_virtualenv_root: Optional[str] = None
     fast_serialization_settings: Optional[FastSerializationSettings] = None
@@ -718,6 +720,7 @@ def new_builder(self) -> Builder:
             version=self.version,
             image_config=self.image_config,
             env=self.env.copy() if self.env else None,
+            git_repo=self.git_repo,
             flytekit_virtualenv_root=self.flytekit_virtualenv_root,
             python_interpreter=self.python_interpreter,
             fast_serialization_settings=self.fast_serialization_settings,
@@ -767,6 +770,7 @@ class Builder(object):
         version: str
         image_config: ImageConfig
         env: Optional[Dict[str, str]] = None
+        git_repo: Optional[str] = None
         flytekit_virtualenv_root: Optional[str] = None
         python_interpreter: Optional[str] = None
         fast_serialization_settings: Optional[FastSerializationSettings] = None
@@ -782,6 +786,7 @@ def build(self) -> SerializationSettings:
                 version=self.version,
                 image_config=self.image_config,
                 env=self.env,
+                git_repo=self.git_repo,
                 flytekit_virtualenv_root=self.flytekit_virtualenv_root,
                 python_interpreter=self.python_interpreter,
                 fast_serialization_settings=self.fast_serialization_settings,
diff --git a/flytekit/configuration/default_images.py b/flytekit/configuration/default_images.py
index 33520e544f..8c01041eed 100644
--- a/flytekit/configuration/default_images.py
+++ b/flytekit/configuration/default_images.py
@@ -16,10 +16,10 @@ class DefaultImages(object):
     """
 
     _DEFAULT_IMAGE_PREFIXES = {
-        PythonVersion.PYTHON_3_7: "ghcr.io/flyteorg/flytekit:py3.7-",
-        PythonVersion.PYTHON_3_8: "ghcr.io/flyteorg/flytekit:py3.8-",
-        PythonVersion.PYTHON_3_9: "ghcr.io/flyteorg/flytekit:py3.9-",
-        PythonVersion.PYTHON_3_10: "ghcr.io/flyteorg/flytekit:py3.10-",
+        PythonVersion.PYTHON_3_7: "cr.flyte.org/flyteorg/flytekit:py3.7-",
+        PythonVersion.PYTHON_3_8: "cr.flyte.org/flyteorg/flytekit:py3.8-",
+        PythonVersion.PYTHON_3_9: "cr.flyte.org/flyteorg/flytekit:py3.9-",
+        PythonVersion.PYTHON_3_10: "cr.flyte.org/flyteorg/flytekit:py3.10-",
     }
 
     @classmethod
diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py
index 467f660d42..23210e95f1 100644
--- a/flytekit/configuration/file.py
+++ b/flytekit/configuration/file.py
@@ -18,6 +18,11 @@
 FLYTECTL_CONFIG_ENV_VAR = "FLYTECTL_CONFIG"
 
 
+def _exists(val: typing.Any) -> bool:
+    """Check if a value is defined."""
+    return isinstance(val, bool) or bool(val is not None and val)
+
+
 @dataclass
 class LegacyConfigEntry(object):
     """
@@ -63,7 +68,7 @@ def read_from_file(
 
 @dataclass
 class YamlConfigEntry(object):
     """
-    Creates a record for the config entry. contains
+    Creates a record for the config entry.
     Args:
         switch: dot-delimited string that should match flytectl args. Leaving it as dot-delimited instead of
           a list of strings because it's easier to maintain alignment with flytectl.
@@ -80,10 +85,11 @@ def read_from_file( return None try: v = cfg.get(self) - if v: + if _exists(v): return transform(v) if transform else v except Exception: ... + return None @@ -224,7 +230,7 @@ def legacy_config(self) -> _configparser.ConfigParser: return self._legacy_config @property - def yaml_config(self) -> typing.Dict[str, Any]: + def yaml_config(self) -> typing.Dict[str, typing.Any]: return self._yaml_config @@ -273,7 +279,7 @@ def set_if_exists(d: dict, k: str, v: typing.Any) -> dict: The input dictionary ``d`` will be mutated. """ - if v: + if _exists(v): d[k] = v return d diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index d2e4838ed8..7fcdc15a50 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -41,7 +41,7 @@ def __init__( task_config=task_config, **kwargs, ) - self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t") + self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() @property def query_template(self) -> str: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dccbaec803..2cf8032a6f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -45,6 +45,7 @@ from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext @@ -84,6 +85,7 @@ class TaskMetadata(object): timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task should run. The execution will be terminated if the runtime exceeds the given timeout (approximately) + pod_template_name (Optional[str]): the name of an existing PodTemplate resource in the cluster which will be used for this task.
""" cache: bool = False @@ -93,6 +95,7 @@ class TaskMetadata(object): deprecated: str = "" retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None + pod_template_name: Optional[str] = None def __post_init__(self): if self.timeout: @@ -126,6 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: discovery_version=self.cache_version, deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, + pod_template_name=self.pod_template_name, ) @@ -156,6 +160,7 @@ def __init__( metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, + docs: Optional[Documentation] = None, **kwargs, ): self._task_type = task_type @@ -164,6 +169,7 @@ def __init__( self._metadata = metadata if metadata else TaskMetadata() self._task_type_version = task_type_version self._security_ctx = security_ctx + self._docs = docs FlyteEntities.entities.append(self) @@ -195,6 +201,10 @@ def task_type_version(self) -> int: def security_context(self) -> SecurityContext: return self._security_ctx + @property + def docs(self) -> Documentation: + return self._docs + def get_type_for_input_var(self, k: str, v: Any) -> type: """ Returns the python native type for the given input variable @@ -248,7 +258,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The cache returns None iff the key does not exist in the cache if outputs_literal_map is None: logger.info("Cache miss, task will be executed now") - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) logger.info( @@ -258,10 +268,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() - ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + # This code should mirror the call to `sandbox_execute` in the above cache case. + # Code is simpler with duplication and less metaprogramming, but introduces regressions + # if one is changed and not the other. + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -316,6 +326,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def sandbox_execute( + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, + ) -> _literal_models.LiteralMap: + """ + Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. 
+ """ + es = ctx.execution_state + b = es.user_space_params.with_task_sandbox() + ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() + return self.dispatch_execute(ctx, input_literal_map) + @abstractmethod def dispatch_execute( self, @@ -390,6 +413,17 @@ def __init__( self._environment = environment if environment else {} self._task_config = task_config self._disable_deck = disable_deck + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs.long_description = Description(value=self._python_interface.docstring.long_description) # TODO lets call this interface and the other as flyte_interface? @property diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 848c1d2524..d470fb54fe 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -10,6 +10,7 @@ from flytekit.models.security import Secret, SecurityContext +# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py new file mode 100644 index 0000000000..b6cb7ca2b6 --- /dev/null +++ b/flytekit/core/gate.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import datetime +import typing +from typing import Tuple, Union + +import click + +from flytekit.core import interface as flyte_interface +from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler +from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions.user import FlyteDisapprovalException +from flytekit.interaction.parse_stdin import parse_stdin_to_literal +from flytekit.models.core import workflow as _workflow_model +from flytekit.models.types import LiteralType + +DEFAULT_TIMEOUT = datetime.timedelta(hours=1) + + +class Gate(object): + """ + A node type that waits for user input before proceeding with a workflow. + A gate is a type of node that behaves like a task, but instead of running code, it either needs to wait + for user input to proceed or wait for a timer to complete running. + """ + + def __init__( + self, + name: str, + input_type: typing.Optional[typing.Type] = None, + upstream_item: typing.Optional[typing.Any] = None, + sleep_duration: typing.Optional[datetime.timedelta] = None, + timeout: typing.Optional[datetime.timedelta] = None, + ): + self._name = name + self._input_type = input_type + self._sleep_duration = sleep_duration + self._timeout = timeout or DEFAULT_TIMEOUT + self._upstream_item = upstream_item + self._literal_type = TypeEngine.to_literal_type(input_type) if input_type else None + + # Determine the python interface if we can + if self._sleep_duration: + # Just a sleep so there is no interface + self._python_interface = flyte_interface.Interface() + elif input_type: + # Waiting for user input, so the output of the node is whatever input the user provides. 
+ self._python_interface = flyte_interface.Interface( + outputs={ + "o0": self.input_type, + } + ) + else: + # We don't know how to find the python interface here; approve() sets it below, see the code. + self._python_interface = None + + @property + def name(self) -> str: + # Part of SupportsNodeCreation interface + return self._name + + @property + def input_type(self) -> typing.Optional[typing.Type]: + return self._input_type + + @property + def literal_type(self) -> typing.Optional[LiteralType]: + return self._literal_type + + @property + def sleep_duration(self) -> typing.Optional[datetime.timedelta]: + return self._sleep_duration + + @property + def python_interface(self) -> flyte_interface.Interface: + """ + This will not be valid during local execution + Part of SupportsNodeCreation interface + """ + # If this is just a sleep node, or user input node, then it will have a Python interface upon construction. + if self._python_interface: + return self._python_interface + + raise ValueError("You can't check for a Python interface for an approval node outside of compilation") + + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: + # Part of SupportsNodeCreation interface + return _workflow_model.NodeMetadata( + name=self.name, + timeout=self._timeout, + ) + + # This is to satisfy the LocallyExecutable protocol + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + if self.sleep_duration: + print(f"Mock sleeping for {self.sleep_duration}") + return VoidPromise(self.name) + + # Trigger stdin + if self.input_type: + msg = f"Execution stopped for gate {self.name}...\n" + literal = parse_stdin_to_literal(ctx, self.input_type, msg) + p = Promise(var="o0", val=literal) + return p + + # Assume this is an approval operation since that's the only remaining option. + msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?" + proceed = click.confirm(msg, default=True) + if proceed: + # We need to return a promise here, and a promise is what should've been passed in by the call in approve() + # Only one element should be in this map. Rely on kwargs instead of the stored _upstream_item even though + # they should be the same, to be cleaner + output_name = list(kwargs.keys())[0] + return kwargs[output_name] + else: + raise FlyteDisapprovalException(f"User did not approve the transaction for gate node {self.name}") + + +def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type): + """Create a Gate object that waits for user input of the specified type. + + Create a Gate object. This object will function like a task. Note that unlike a task, + each time this function is called, a new Python object is created. If a workflow + calls a subworkflow twice, and the subworkflow has a signal, then two Gate + objects are created. This shouldn't be a problem as long as the objects are identical. + + :param name: The name of the gate node. + :param timeout: How long to wait before Flyte fails the workflow. + :param expected_type: The type of the value that the user will be inputting. + :return: + """ + + g = Gate(name, input_type=expected_type, timeout=timeout) + + return flyte_entity_call_handler(g) + + +def sleep(duration: datetime.timedelta): + """Create a sleep Gate object.
+ + :param duration: How long to sleep for + :return: + """ + g = Gate("sleep-gate", sleep_duration=duration) + + return flyte_entity_call_handler(g) + + +def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: str, timeout: datetime.timedelta): + """Create a Gate object for binary approval. + + Create a Gate object. This object will function like a task. Note that unlike a task, + each time this function is called, a new Python object is created. If a workflow + calls a subworkflow twice, and the subworkflow has a signal, then two Gate + objects are created. This shouldn't be a problem as long as the objects are identical. + + :param upstream_item: This should be a single output of a previous task that you want to gate execution + on. This is the value that you want a human to check before moving on. + :param name: The name of the gate node. + :param timeout: How long to wait before Flyte fails the workflow. + :return: + """ + g = Gate(name, upstream_item=upstream_item, timeout=timeout) + + if upstream_item is None or isinstance(upstream_item, VoidPromise): + raise ValueError("You can't use approval on a task that doesn't return anything.") + + ctx = FlyteContextManager.current_context() + if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: + if not upstream_item.ref.node.flyte_entity.python_interface: + raise ValueError( + f"Upstream node doesn't have a Python interface. Node entity is: " + f"{upstream_item.ref.node.flyte_entity}" + ) + + # We have to reach back up to the entity that this promise came from, to get the python type, since + # the approve function itself doesn't have a python interface. + io_type = upstream_item.ref.node.flyte_entity.python_interface.outputs[upstream_item.var] + io_var_name = upstream_item.var + else: + # We don't know the python type here. In local execution, downstream doesn't really use the type + # so we should be okay. But use None instead of type() so that errors are hopefully more obvious. + io_type = None + io_var_name = "o0" + + # In either case, we need a python interface + g._python_interface = flyte_interface.Interface( + inputs={ + io_var_name: io_type, + }, + outputs={ + io_var_name: io_type, + }, + ) + kwargs = {io_var_name: upstream_item} + + return flyte_entity_call_handler(g, **kwargs) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 63d7c8106f..954c1ae409 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -207,7 +207,6 @@ def transform_interface_to_typed_interface( """ if interface is None: return None - if interface.docstring is None: input_descriptions = output_descriptions = {} else: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 0d143e5fe8..550dc1919e 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -455,8 +455,15 @@ def reference_launch_plan( ) -> Callable[[Callable[..., Any]], ReferenceLaunchPlan]: """ A reference launch plan is a pointer to a launch plan that already exists on your Flyte installation. This - object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. + object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface + via the function definition. + If at registration time the interface provided causes an issue with compilation, an error will be returned.
+ + :param project: Flyte project name of the launch plan + :param domain: Flyte domain name of the launch plan + :param name: launch plan name + :param version: specific version of the launch plan to use """ def wrapper(fn) -> ReferenceLaunchPlan: diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d849ef5397..617790746f 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -6,6 +6,7 @@ from flytekit.core.resources import Resources from flytekit.core.utils import _dnsify +from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -51,6 +52,10 @@ def __rshift__(self, other: Node): self.runs_before(other) return other + @property + def name(self) -> str: + return self._id + @property def outputs(self): if self._outputs is None: @@ -110,6 +115,12 @@ def with_overrides(self, *args, **kwargs): self._metadata._interruptible = kwargs["interruptible"] if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: + logger.warning("This override is beta. We may want to revisit this in the future.") + new_task_config = kwargs["task_config"] + if not isinstance(new_task_config, type(self.flyte_entity._task_config)): + raise ValueError("can't change the type of the task config") + self.flyte_entity._task_config = new_task_config return self diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py new file mode 100644 index 0000000000..5e9c746911 --- /dev/null +++ b/flytekit/core/pod_template.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Dict, Optional + +from kubernetes.client.models import V1PodSpec + +from flytekit.exceptions import user as _user_exceptions + +PRIMARY_CONTAINER_DEFAULT_NAME = "primary" + + +@dataclass +class PodTemplate(object): + """Custom PodTemplate specification for a Task.""" + + pod_spec: V1PodSpec = V1PodSpec(containers=[]) + primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + + def __post_init__(self): + if not self.primary_container_name: + raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 0fb404fb9e..3a851a50ea 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,9 +1,8 @@ from __future__ import annotations import collections -import typing from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from typing_extensions import Protocol, get_args @@ -70,7 +69,6 @@ def extract_value( val_type: type, flyte_literal_type: _type_models.LiteralType, ) -> _literal_models.Literal: - if isinstance(input_val, list): lt = flyte_literal_type python_type = val_type @@ -143,17 +141,16 @@ def extract_value( def get_primitive_val(prim: Primitive) -> Any: - if prim.integer: - return prim.integer - if prim.datetime: - return prim.datetime - if prim.boolean: - return prim.boolean - if prim.duration: - return prim.duration - if prim.string_value: - return prim.string_value - return prim.float_value + for value in [ + prim.integer, + prim.float_value, + prim.string_value, + prim.boolean, + prim.datetime, + prim.duration, + ]: + if value is not None: + return value class ConjunctionOps(Enum): @@ -350,7 
+347,7 @@ def __init__(self, var: str, val: Union[NodeOutput, _literal_models.Literal]): def __hash__(self): return hash(id(self)) - def __rshift__(self, other: typing.Union[Promise, VoidPromise]): + def __rshift__(self, other: Union[Promise, VoidPromise]): if not self.is_ready: self.ref.node.runs_before(other.ref.node) return other @@ -458,7 +455,7 @@ def __str__(self): def create_native_named_tuple( ctx: FlyteContext, - promises: Optional[Union[Promise, typing.List[Promise]]], + promises: Optional[Union[Promise, List[Promise]]], entity_interface: Interface, ) -> Optional[Tuple]: """ @@ -578,7 +575,7 @@ def binding_from_flyte_std( ctx: _flyte_context.FlyteContext, var_name: str, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, ) -> _literals_models.Binding: binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type=None) return _literals_models.Binding(var=var_name, binding=binding_data) @@ -587,7 +584,7 @@ def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, t_value_type: Optional[type] = None, ) -> _literals_models.BindingData: # This handles the case where the given value is the output of another task @@ -654,7 +651,7 @@ def binding_from_python_std( ctx: _flyte_context.FlyteContext, var_name: str, expected_literal_type: _type_models.LiteralType, - t_value: typing.Any, + t_value: Any, t_value_type: type, ) -> _literals_models.Binding: binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type) @@ -671,7 +668,7 @@ class VoidPromise(object): VoidPromise cannot be interacted with and does not allow comparisons or any operations """ - def __init__(self, task_name: str, ref: typing.Optional[NodeOutput] = None): + def __init__(self, task_name: str, ref: Optional[NodeOutput] = None): self._task_name = task_name self._ref = ref @@ -682,10 +679,10 @@ def runs_before(self, *args, **kwargs): """ @property - def ref(self) -> typing.Optional[NodeOutput]: + def ref(self) -> Optional[NodeOutput]: return self._ref - def __rshift__(self, other: typing.Union[Promise, VoidPromise]): + def __rshift__(self, other: Union[Promise, VoidPromise]): if self.ref: self.ref.node.runs_before(other.ref.node) return other @@ -811,10 +808,26 @@ def extract_obj_name(name: str) -> str: def create_and_link_node_from_remote( ctx: FlyteContext, entity: HasFlyteInterface, + _inputs_not_allowed: Optional[Set[str]] = None, + _ignorable_inputs: Optional[Set[str]] = None, **kwargs, -): +) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ - This method is used to generate a node with bindings. This is not used in the execution path. + This method is used to generate a node with bindings especially when using remote entities, like FlyteWorkflow, + FlyteTask and FlyteLaunchplan. + + This method is kept separate from the similarly named method `create_and_link_node` as remote entities have to be + handled differently. The major difference arises from the fact that the remote entities do not have a python + interface, so all comparisons need to happen using the Literals. + + :param ctx: FlyteContext + :param entity: RemoteEntity + :param _inputs_not_allowed: Set of all variable names that should not be provided when using this entity. + Useful for Launchplans with `fixed` inputs + :param _ignorable_inputs: Set of all variable names that are optional, but if provided will be overridden.
Useful + for launchplans with `default` inputs + :param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises. + :return: Optional[Union[Tuple[Promise], Promise, VoidPromise]] """ if ctx.compilation_state is None: raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") @@ -824,9 +837,19 @@ def create_and_link_node_from_remote( typed_interface = entity.interface + if _inputs_not_allowed: + inputs_not_allowed_specified = _inputs_not_allowed.intersection(kwargs.keys()) + if inputs_not_allowed_specified: + raise _user_exceptions.FlyteAssertion( + f"Fixed inputs cannot be specified. Please remove the following inputs - {inputs_not_allowed_specified}" + ) + for k in sorted(typed_interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: + if _inputs_not_allowed and _ignorable_inputs: + if k in _ignorable_inputs or k in _inputs_not_allowed: + continue # TODO to improve the error message, should we show python equivalent types for var.type? raise _user_exceptions.FlyteAssertion("Missing input `{}` type `{}`".format(k, var.type)) v = kwargs[k] @@ -854,7 +877,8 @@ def create_and_link_node_from_remote( extra_inputs = used_inputs ^ set(kwargs.keys()) if len(extra_inputs) > 0: raise _user_exceptions.FlyteAssertion( - "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs) + f"Too many inputs for [{entity.name}] Expected inputs: {typed_interface.inputs.keys()} " + f"- extra inputs: {extra_inputs}" ) # Detect upstream nodes @@ -870,7 +894,6 @@ def create_and_link_node_from_remote( ) flytekit_node = Node( - # TODO: Better naming, probably a derivative of the function name. id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}", metadata=entity.construct_node_metadata(), bindings=sorted(bindings, key=lambda b: b.var), @@ -896,7 +919,13 @@ def create_and_link_node( **kwargs, ) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ - This method is used to generate a node with bindings. This is not used in the execution path. + This method is used to generate a node with bindings within a flytekit workflow. This is useful for traversing the + workflow using the regular Python interpreter and generating nodes and promises whenever an execution is encountered + + :param ctx: FlyteContext + :param entity: RemoteEntity + :param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises. + :return: Optional[Union[Tuple[Promise], Promise, VoidPromise]] """ if ctx.compilation_state is None: raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") @@ -994,7 +1023,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr ... -def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwargs): +def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow).
The logic is the same for all three, but we did not want to create a base class, hence this separate @@ -1024,7 +1053,6 @@ def flyte_entity_call_handler(entity: Union[SupportsNodeCreation], *args, **kwar ) ctx = FlyteContextManager.current_context() - if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return create_and_link_node(ctx, entity=entity, **kwargs) elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 06133d9784..2d05df3c3d 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -4,11 +4,16 @@ import re from abc import ABC from types import ModuleType -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.configuration import ImageConfig, SerializationSettings -from flytekit.core.base_task import PythonTask, TaskResolverMixin +from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module @@ -18,6 +23,11 @@ from flytekit.models.security import Secret, SecurityContext T = TypeVar("T") +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + + +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): @@ -40,6 +50,8 @@ def __init__( environment: Optional[Dict[str, str]] = None, task_resolver: Optional[TaskResolverMixin] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): """ @@ -64,6 +76,8 @@ def __init__( - `Confidant `__ - `Kube secrets `__ - `AWS Parameter store `__ + :param pod_template: Custom PodTemplate for this task. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
""" sec_ctx = None if secret_requests: @@ -71,6 +85,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metedata.pod_template_name + kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() + kwargs["metadata"].pod_template_name = pod_template_name + super().__init__( task_type=task_type, name=name, @@ -98,6 +117,8 @@ def __init__( self._task_resolver = task_resolver or default_task_resolver self._get_command_fn = self.get_default_command + self.pod_template = pod_template + @property def task_resolver(self) -> Optional[TaskResolverMixin]: return self._task_resolver @@ -157,6 +178,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._get_command_fn(settings) def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + else: + return self._get_container(settings) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {} for elem in (settings.env, self.environment): if elem: @@ -179,6 +207,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe memory_limit=self.resources.limits.mem, ) + def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: + containers = self.pod_template.pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == self.pod_template.primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=self.pod_template.primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the default values used in the regular Python task. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == self.pod_template.primary_container_name: + sdk_default_container = self._get_container(settings) + container.image = sdk_default_container.image + # clear existing commands + container.command = sdk_default_container.command + # also clear existing args + container.args = sdk_default_container.args + limits, requests = {}, {} + for resource in sdk_default_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in sdk_default_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. 
+ container.resources = resource_requirements + container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + self.pod_template.pod_spec.containers = final_containers + + return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=self._serialize_pod_spec(settings), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + ) + + # subclasses that override this method need to call super().get_config(settings) + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} + class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): """ diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index bcb80f34ca..81f6739a39 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -217,12 +217,7 @@ def compile_into_workflow( for entity, model in model_entities.items(): # We only care about gathering tasks here. Launch plans are handled by # propeller. Subworkflows should already be in the workflow spec. - if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate): - continue - - # Handle FlyteTask - if isinstance(entity, task_models.TaskTemplate): - tts.append(entity) + if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskSpec): continue # We are currently not supporting reference tasks since these will diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 77b96e6892..7247457d86 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -43,6 +43,8 @@ def resource_type(self) -> int: @dataclass class TaskReference(Reference): + """A reference object containing metadata that points to a remote task.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.TASK @@ -50,6 +52,8 @@ def resource_type(self) -> int: @dataclass class LaunchPlanReference(Reference): + """A reference object containing metadata that points to a remote launch plan.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.LAUNCH_PLAN @@ -57,6 +61,8 @@ def resource_type(self) -> int: @dataclass class WorkflowReference(Reference): + """A reference object containing metadata that points to a remote workflow.""" + @property def resource_type(self) -> int: return _identifier_model.ResourceType.WORKFLOW diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6e5b0a6b6a..28c5b5def7 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -4,9 +4,11 @@ from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -89,7 +91,10 @@ def task( secret_requests: Optional[List[Secret]] = None, execution_mode:
Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + docs: Optional[Documentation] = None, disable_deck: bool = True, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, ) -> Union[Callable, PythonFunctionTask]: """ This is the core decorator to use for any task type in flytekit. @@ -179,6 +184,9 @@ def foo2(): :param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically. :param task_resolver: Provide a custom task resolver. :param disable_deck: If true, this task will not output a deck html file + :param docs: Documentation about this task + :param pod_template: Custom PodTemplate for this task. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ def wrapper(fn) -> PythonFunctionTask: @@ -204,6 +212,9 @@ def wrapper(fn) -> PythonFunctionTask: execution_mode=execution_mode, task_resolver=task_resolver, disable_deck=disable_deck, + docs=docs, + pod_template=pod_template, + pod_template_name=pod_template_name, ) update_wrapper(task_instance, fn) return task_instance diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 98969e41b3..264b6e3860 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +import copy import dataclasses import datetime as _datetime import enum @@ -320,7 +321,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) schema = None try: - s = cast(DataClassJsonMixin, t).schema() + s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() for _, v in s.fields.items(): # marshmallow-jsonschema only supports enums loaded by name. # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 @@ -352,6 +353,46 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) ) + def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: + # dataclass will try to hash the python type when calling dataclass.schema(), but some types in the annotation are + # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin + # type from Annotated. + if get_origin(python_type) is list: + return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])] # type: ignore + elif get_origin(python_type) is dict: + return typing.Dict[ # type: ignore + self._get_origin_type_in_annotation(get_args(python_type)[0]), + self._get_origin_type_in_annotation(get_args(python_type)[1]), + ] + elif get_origin(python_type) is Annotated: + return get_args(python_type)[0] + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(copy.deepcopy(python_type)): + field.type = self._get_origin_type_in_annotation(field.type) + return python_type + + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: + # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, + # so here we convert it back to a StructuredDataset.
+ from flytekit import StructuredDataset + + if python_type == StructuredDataset and type(python_val) == dict: + return StructuredDataset(**python_val) + elif get_origin(python_type) is list: + return [self._fix_structured_dataset_type(get_args(python_type)[0], v) for v in python_val] # type: ignore + elif get_origin(python_type) is dict: + return { # type: ignore + self._fix_structured_dataset_type(get_args(python_type)[0], k): self._fix_structured_dataset_type( + get_args(python_type)[1], v + ) + for k, v in python_val.items() + } + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(python_type): + val = python_val.__getattribute__(field.name) + python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val)) + return python_val + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is a flyte type, we should use the flyte type transformer for that field. @@ -361,6 +402,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset + # Handle Optional + if get_origin(python_type) is typing.Union and type(None) in get_args(python_type): + if python_val is None: + return None + return self._serialize_flyte_type(python_val, get_args(python_type)[0]) + if hasattr(python_type, "__origin__") and python_type.__origin__ is list: return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] @@ -388,7 +435,9 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory): return python_type(path=lv.scalar.blob.uri) elif issubclass(python_type, StructuredDataset): - return python_type(uri=lv.scalar.structured_dataset.uri) + sd = python_type(uri=lv.scalar.structured_dataset.uri) + sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format + return sd else: return python_val else: @@ -398,12 +447,18 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type)) return python_val - def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: + def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> Optional[T]: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine + # Handle Optional + if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type): + if python_val is None: + return None + return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0]) + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore @@ -533,8 +588,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be " f"serialized correctly" ) - - dc = cast(DataClassJsonMixin,
expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic)) + json_str = _json_format.MessageToJson(lv.scalar.generic) + dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) + dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) # This ensures that calls with the same literal type return the same dataclass. For example, `pyflyte run` @@ -1123,7 +1179,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)] # type: ignore + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] raise ValueError(f"Union transformer cannot reverse {literal_type}") diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 468a5aa7ea..8ba307b767 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -39,6 +39,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -168,6 +169,7 @@ def __init__( workflow_metadata: WorkflowMetadata, workflow_metadata_defaults: WorkflowMetadataDefaults, python_interface: Interface, + docs: Optional[Documentation] = None, **kwargs, ): self._name = name @@ -179,6 +181,20 @@ def __init__( self._unbound_inputs = set() self._nodes = [] self._output_bindings: List[_literal_models.Binding] = [] + self._docs = docs + + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs.long_description = Description(value=self._python_interface.docstring.long_description) + FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -186,6 +202,10 @@ def __init__( def name(self) -> str: return self._name + @property + def docs(self): + return self._docs + @property def short_name(self) -> str: return extract_obj_name(self._name) @@ -208,10 +228,12 @@ def interface(self) -> _interface_models.TypedInterface: @property def output_bindings(self) -> List[_literal_models.Binding]: + self.compile() return self._output_bindings @property def nodes(self) -> List[Node]: + self.compile() return self._nodes def __repr__(self): @@ -235,11 +257,15 @@ def __call__(self, *args, **kwargs): # Get default arguments and override with kwargs passed in input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) + self.compile() return flyte_entity_call_handler(self, *args, **input_kwargs) def execute(self, **kwargs): raise Exception("Should not be called") + def compile(self, **kwargs): + pass + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow
local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. @@ -250,6 +276,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. + self.compile() function_outputs = self.execute(**kwargs) # First handle the empty return case. @@ -571,7 +598,8 @@ def __init__( workflow_function: Callable, metadata: Optional[WorkflowMetadata], default_metadata: Optional[WorkflowMetadataDefaults], - docstring: Docstring = None, + docstring: Optional[Docstring] = None, + docs: Optional[Documentation] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function @@ -586,7 +614,9 @@ def __init__( workflow_metadata=metadata, workflow_metadata_defaults=default_metadata, python_interface=native_interface, + docs=docs, ) + self.compiled = False @property def function(self): @@ -600,6 +630,9 @@ def compile(self, **kwargs): Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics a 'closure' in the traditional sense of the word. """ + if self.compiled: + return + self.compiled = True ctx = FlyteContextManager.current_context() self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface) all_nodes = [] @@ -690,7 +723,8 @@ def workflow( _workflow_function=None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, -): + docs: Optional[Documentation] = None, +) -> WorkflowBase: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. @@ -718,6 +752,7 @@ def workflow( :param _workflow_function: This argument is implicitly passed and represents the decorated function. :param failure_policy: Use the options in flytekit.WorkflowFailurePolicy :param interruptible: Whether or not tasks launched from this workflow are by default interruptible + :param docs: Description entity for the workflow """ def wrapper(fn): @@ -730,8 +765,8 @@ def wrapper(fn): metadata=workflow_metadata, default_metadata=workflow_metadata_defaults, docstring=Docstring(callable_=fn), + docs=docs, ) - workflow_instance.compile() update_wrapper(workflow_instance, fn) return workflow_instance diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index a93510cae1..6575218666 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -62,6 +62,10 @@ class FlyteValidationException(FlyteAssertion): _ERROR_CODE = "USER:ValidationError" +class FlyteDisapprovalException(FlyteAssertion): + _ERROR_CODE = "USER:ResultNotApproved" + + class FlyteEntityAlreadyExistsException(FlyteAssertion): _ERROR_CODE = "USER:EntityAlreadyExists" diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py index 770fe11b73..a29d8e89e6 100644 --- a/flytekit/extras/pytorch/__init__.py +++ b/flytekit/extras/pytorch/__init__.py @@ -1,6 +1,4 @@ """ -Flytekit PyTorch -========================================= .. currentmodule:: flytekit.extras.pytorch .. 
autosummary:: @@ -8,6 +6,9 @@ :toctree: generated/ PyTorchCheckpoint + PyTorchCheckpointTransformer + PyTorchModuleTransformer + PyTorchTensorTransformer """ from flytekit.loggers import logger diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py index 4cf37871fb..cbaa0c80f0 100644 --- a/flytekit/extras/pytorch/native.py +++ b/flytekit/extras/pytorch/native.py @@ -1,5 +1,5 @@ import pathlib -from typing import Generic, Type, TypeVar +from typing import Type, TypeVar import torch @@ -12,7 +12,7 @@ T = TypeVar("T") -class PyTorchTypeTransformer(TypeTransformer, Generic[T]): +class PyTorchTypeTransformer(TypeTransformer[T]): def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( @@ -63,30 +63,40 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # load pytorch tensor/module from a file return torch.load(local_path, map_location=map_location) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + +class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): + PYTORCH_FORMAT = "PyTorchTensor" + + def __init__(self): + super().__init__(name="PyTorch Tensor", t=torch.Tensor) + + def guess_python_type(self, literal_type: LiteralType) -> Type[torch.Tensor]: if ( literal_type.blob is not None and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE and literal_type.blob.format == self.PYTORCH_FORMAT ): - return T + return torch.Tensor raise ValueError(f"Transformer {self} cannot reverse {literal_type}") -class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]): - PYTORCH_FORMAT = "PyTorchTensor" - - def __init__(self): - super().__init__(name="PyTorch Tensor", t=torch.Tensor) - - class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]): PYTORCH_FORMAT = "PyTorchModule" def __init__(self): super().__init__(name="PyTorch Module", t=torch.nn.Module) + def guess_python_type(self, literal_type: LiteralType) -> Type[torch.nn.Module]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.PYTORCH_FORMAT + ): + return torch.nn.Module + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + TypeEngine.register(PyTorchTensorTransformer()) TypeEngine.register(PyTorchModuleTransformer()) diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py index 0a1bf2dda5..1d16f6080f 100644 --- a/flytekit/extras/sklearn/__init__.py +++ b/flytekit/extras/sklearn/__init__.py @@ -1,11 +1,11 @@ """ -Flytekit Sklearn -========================================= .. currentmodule:: flytekit.extras.sklearn .. autosummary:: :template: custom.rst :toctree: generated/ + + SklearnEstimatorTransformer """ from flytekit.loggers import logger diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 45c23da0ae..8e7d8b3b29 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -92,6 +92,8 @@ def __init__( container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, + # Sanitize query by removing the newlines at the end of the query. Keep in mind + # that the query can be a multiline string. 
query_template=query_template, inputs=inputs, outputs=outputs, diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py new file mode 100644 index 0000000000..b5699906fb --- /dev/null +++ b/flytekit/extras/tensorflow/__init__.py @@ -0,0 +1,31 @@ +""" +.. currentmodule:: flytekit.extras.tensorflow + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + TensorFlowRecordFileTransformer + TensorFlowRecordsDirTransformer +""" + +from flytekit.loggers import logger + +# TODO: abstract this out so that there's an established pattern for registering plugins +# that have soft dependencies +try: + # isolate the exception to the tensorflow import + import tensorflow + + _tensorflow_installed = True +except (ImportError, OSError): + _tensorflow_installed = False + + +if _tensorflow_installed: + from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer +else: + logger.info( + "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer " + "because tensorflow is not installed." + ) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py new file mode 100644 index 0000000000..d5d750b521 --- /dev/null +++ b/flytekit/extras/tensorflow/record.py @@ -0,0 +1,188 @@ +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Type, Union + +import tensorflow as tf +from dataclasses_json import dataclass_json +from tensorflow.python.data.ops.readers import TFRecordDatasetV2 +from typing_extensions import Annotated, get_args, get_origin + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.directory import TFRecordsDirectory +from flytekit.types.file import TFRecordFile + + +@dataclass_json +@dataclass +class TFRecordDatasetConfig: + """ + TFRecordDatasetConfig can be used while creating tf.data.TFRecordDataset comprising + records from one or more TFRecord files. + + Args: + compression_type: A scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP". + buffer_size: The number of bytes in the read buffer. If None, a sensible default for both local and remote file systems is used. + num_parallel_reads: The number of files to read in parallel. If greater than one, the records from files read in parallel are output in an interleaved order. + name: A name for the operation.
+ """ + + compression_type: Optional[str] = None + buffer_size: Optional[int] = None + num_parallel_reads: Optional[int] = None + name: Optional[str] = None + + +def extract_metadata_and_uri( + lv: Literal, t: Type[Union[TFRecordFile, TFRecordsDirectory]] +) -> Tuple[Union[TFRecordFile, TFRecordsDirectory], TFRecordDatasetConfig]: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {t}") + metadata = TFRecordDatasetConfig() + if get_origin(t) is Annotated: + _, metadata = get_args(t) + if isinstance(metadata, TFRecordDatasetConfig): + return uri, metadata + else: + raise TypeTransformerFailedError(f"{t}'s metadata needs to be of type TFRecordDatasetConfig") + return uri, metadata + + +class TensorFlowRecordFileTransformer(TypeTransformer[TFRecordFile]): + """ + TypeTransformer that supports serialising and deserialising to and from TFRecord file. + https://www.tensorflow.org/tutorials/load_data/tfrecord + """ + + TENSORFLOW_FORMAT = "TensorFlowRecord" + + def __init__(self): + super().__init__(name="TensorFlow Record File", t=TFRecordFile) + + def get_literal_type(self, t: Type[TFRecordFile]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, ctx: FlyteContext, python_val: TFRecordFile, python_type: Type[TFRecordFile], expected: LiteralType + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + local_dir = ctx.file_access.get_random_local_directory() + remote_path = ctx.file_access.get_random_remote_path() + local_path = os.path.join(local_dir, "0000.tfrecord") + with tf.io.TFRecordWriter(local_path) as writer: + writer.write(python_val.SerializeToString()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordFile] + ) -> TFRecordDatasetV2: + uri, metadata = extract_metadata_and_uri(lv, expected_python_type) + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + filenames = [local_path] + return tf.data.TFRecordDataset( + filenames=filenames, + compression_type=metadata.compression_type, + buffer_size=metadata.buffer_size, + num_parallel_reads=metadata.num_parallel_reads, + name=metadata.name, + ) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordFile]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return TFRecordFile + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +class TensorFlowRecordsDirTransformer(TypeTransformer[TFRecordsDirectory]): + """ + TypeTransformer that supports serialising and deserialising to and from TFRecord directory. 
+ https://www.tensorflow.org/tutorials/load_data/tfrecord + """ + + TENSORFLOW_FORMAT = "TensorFlowRecord" + + def __init__(self): + super().__init__(name="TensorFlow Record Directory", t=TFRecordsDirectory) + + def get_literal_type(self, t: Type[TFRecordsDirectory]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: TFRecordsDirectory, + python_type: Type[TFRecordsDirectory], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + local_dir = ctx.file_access.get_random_local_directory() + remote_path = ctx.file_access.get_random_remote_directory() + for i, val in enumerate(python_val): + local_path = f"{local_dir}/part_{i}.tfrecord" + with tf.io.TFRecordWriter(local_path) as writer: + writer.write(val.SerializeToString()) + ctx.file_access.upload_directory(local_dir, remote_path) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory] + ) -> TFRecordDatasetV2: + + uri, metadata = extract_metadata_and_uri(lv, expected_python_type) + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(uri, local_dir, is_multipart=True) + files = os.scandir(local_dir) + filenames = [os.path.join(local_dir, f.name) for f in files] + return tf.data.TFRecordDataset( + filenames=filenames, + compression_type=metadata.compression_type, + buffer_size=metadata.buffer_size, + num_parallel_reads=metadata.num_parallel_reads, + name=metadata.name, + ) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordsDirectory]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return TFRecordsDirectory + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlowRecordsDirTransformer()) +TypeEngine.register(TensorFlowRecordFileTransformer()) diff --git a/flytekit/interaction/__init__.py b/flytekit/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/interaction/parse_stdin.py b/flytekit/interaction/parse_stdin.py new file mode 100644 index 0000000000..ec051d73ce --- /dev/null +++ b/flytekit/interaction/parse_stdin.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import typing + +import click + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine +from flytekit.loggers import logger +from flytekit.models.literals import Literal + + +# TODO: Move the improved click parsing here. 
https://github.com/flyteorg/flyte/issues/3124 +def parse_stdin_to_literal(ctx: FlyteContext, t: typing.Type, message_prefix: typing.Optional[str]) -> Literal: + + message = message_prefix or "" + message += f"Please enter value for type {t} to continue" + if issubclass(t, bool): + user_input = click.prompt(message, type=bool) + l = TypeEngine.to_literal(ctx, user_input, bool, TypeEngine.to_literal_type(bool)) # noqa + elif issubclass(t, int): + user_input = click.prompt(message, type=int) + l = TypeEngine.to_literal(ctx, user_input, int, TypeEngine.to_literal_type(int)) # noqa + elif issubclass(t, float): + user_input = click.prompt(message, type=float) + l = TypeEngine.to_literal(ctx, user_input, float, TypeEngine.to_literal_type(float)) # noqa + elif issubclass(t, str): + user_input = click.prompt(message, type=str) + l = TypeEngine.to_literal(ctx, user_input, str, TypeEngine.to_literal_type(str)) # noqa + else: + # Todo: We should implement the rest by way of importing the code in pyflyte run + # that parses text from the command line + raise Exception("Only bool, int/float, or strings are accepted for now.") + + logger.debug(f"Parsed literal {l} from user input {user_input}") + return l diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 0c8c2e035a..f047348de0 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -13,12 +13,6 @@ # By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning logger = logging.getLogger("flytekit") -# Root logger control -flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT" -if os.getenv(flytekit_root_env_var) is not None: - logger.setLevel(int(os.getenv(flytekit_root_env_var))) -else: - logger.setLevel(logging.DEBUG) # Stop propagation so that configuration is isolated to this file (so that it doesn't matter what the # global Python root logger is set to). @@ -40,22 +34,33 @@ # create console handler ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +# Root logger control # Don't want to import the configuration library since that will cause all sorts of circular imports, let's # just use the environment variable if it's defined. Decide in the future when we implement better controls # if we should control with the channel or with the logger level. 
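+# Precedence (summarising the logic below): the f"{LOGGING_ENV_VAR}_ROOT" override, if set, wins for the
+# root flytekit logger; otherwise the plain LOGGING_ENV_VAR value applies; otherwise WARNING is the default.
+# Child loggers honour their own f"{LOGGING_ENV_VAR}_{NAME}" overrides.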
-# The handler log level controls whether log statements will actually print to the screen
+# The logger level controls whether log statements will actually print to the screen
+flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT"
 level_from_env = os.getenv(LOGGING_ENV_VAR)
-if level_from_env is not None:
-    ch.setLevel(int(level_from_env))
+root_level_from_env = os.getenv(flytekit_root_env_var)
+if root_level_from_env is not None:
+    logger.setLevel(int(root_level_from_env))
+elif level_from_env is not None:
+    logger.setLevel(int(level_from_env))
 else:
-    ch.setLevel(logging.WARNING)
+    logger.setLevel(logging.WARNING)
 
 for log_name, child_logger in child_loggers.items():
     env_var = f"{LOGGING_ENV_VAR}_{log_name.upper()}"
     level_from_env = os.getenv(env_var)
     if level_from_env is not None:
         child_logger.setLevel(int(level_from_env))
+    else:
+        if child_logger is user_space_logger:
+            child_logger.setLevel(logging.INFO)
+        else:
+            child_logger.setLevel(logging.WARNING)
 
 # create formatter
 formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py
index f34e692123..e40307b6ba 100644
--- a/flytekit/models/admin/workflow.py
+++ b/flytekit/models/admin/workflow.py
@@ -1,13 +1,21 @@
+import typing
+
 from flyteidl.admin import workflow_pb2 as _admin_workflow
 
 from flytekit.models import common as _common
 from flytekit.models.core import compiler as _compiler_models
 from flytekit.models.core import identifier as _identifier
 from flytekit.models.core import workflow as _core_workflow
+from flytekit.models.documentation import Documentation
 
 
 class WorkflowSpec(_common.FlyteIdlEntity):
-    def __init__(self, template, sub_workflows):
+    def __init__(
+        self,
+        template: _core_workflow.WorkflowTemplate,
+        sub_workflows: typing.List[_core_workflow.WorkflowTemplate],
+        docs: typing.Optional[Documentation] = None,
+    ):
         """
         This object fully encapsulates the specification of a workflow
         :param flytekit.models.core.workflow.WorkflowTemplate template:
@@ -15,6 +23,7 @@ def __init__(self, template, sub_workflows):
         """
         self._template = template
         self._sub_workflows = sub_workflows
+        self._docs = docs
 
     @property
     def template(self):
@@ -30,6 +39,13 @@ def sub_workflows(self):
         """
         return self._sub_workflows
 
+    @property
+    def docs(self):
+        """
+        :rtype: Description entity for the workflow
+        """
+        return self._docs
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec
@@ -37,6 +53,7 @@ def to_flyte_idl(self):
         return _admin_workflow.WorkflowSpec(
             template=self._template.to_flyte_idl(),
             sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows],
+            description=self._docs.to_flyte_idl() if self._docs else None,
         )
 
     @classmethod
@@ -48,6 +65,7 @@ def from_flyte_idl(cls, pb2_object):
         return cls(
             _core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template),
             [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows],
+            Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None,
         )
diff --git a/flytekit/models/common.py b/flytekit/models/common.py
index 7236dd15ce..62018c1eef 100644
--- a/flytekit/models/common.py
+++ b/flytekit/models/common.py
@@ -414,8 +414,10 @@ def from_flyte_idl(cls, pb):
 class AuthRole(FlyteIdlEntity):
     def __init__(self, assumable_iam_role=None, kubernetes_service_account=None):
-        """
+        """Auth configuration for IAM or K8s service account.
+
+        One or both of the assumable IAM role and the K8s service account can be set.
+
         :param Text assumable_iam_role: IAM identity with set permissions policies.
:param Text kubernetes_service_account: Provides an identity for workflow execution resources. Flyte deployment administrators are responsible for handling permissions as they diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index bf46ace349..8a45232e38 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -1,13 +1,13 @@ -from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flyteidl.core import identifier_pb2 as identifier_pb2 from flytekit.models import common as _common_models class ResourceType(object): - UNSPECIFIED = _identifier_pb2.UNSPECIFIED - TASK = _identifier_pb2.TASK - WORKFLOW = _identifier_pb2.WORKFLOW - LAUNCH_PLAN = _identifier_pb2.LAUNCH_PLAN + UNSPECIFIED = identifier_pb2.UNSPECIFIED + TASK = identifier_pb2.TASK + WORKFLOW = identifier_pb2.WORKFLOW + LAUNCH_PLAN = identifier_pb2.LAUNCH_PLAN class Identifier(_common_models.FlyteIdlEntity): @@ -34,7 +34,7 @@ def resource_type(self): return self._resource_type def resource_type_name(self) -> str: - return _identifier_pb2.ResourceType.Name(self.resource_type) + return identifier_pb2.ResourceType.Name(self.resource_type) @property def project(self): @@ -68,7 +68,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.Identifier """ - return _identifier_pb2.Identifier( + return identifier_pb2.Identifier( resource_type=self.resource_type, project=self.project, domain=self.domain, @@ -133,7 +133,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier """ - return _identifier_pb2.WorkflowExecutionIdentifier( + return identifier_pb2.WorkflowExecutionIdentifier( project=self.project, domain=self.domain, name=self.name, @@ -179,7 +179,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier """ - return _identifier_pb2.NodeExecutionIdentifier( + return identifier_pb2.NodeExecutionIdentifier( node_id=self.node_id, execution_id=self.execution_id.to_flyte_idl(), ) @@ -232,7 +232,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.TaskExecutionIdentifier """ - return _identifier_pb2.TaskExecutionIdentifier( + return identifier_pb2.TaskExecutionIdentifier( task_id=self.task_id.to_flyte_idl(), node_execution_id=self.node_execution_id.to_flyte_idl(), retry_attempt=self.retry_attempt, @@ -249,3 +249,34 @@ def from_flyte_idl(cls, proto): node_execution_id=NodeExecutionIdentifier.from_flyte_idl(proto.node_execution_id), retry_attempt=proto.retry_attempt, ) + + +class SignalIdentifier(_common_models.FlyteIdlEntity): + def __init__(self, signal_id: str, execution_id: WorkflowExecutionIdentifier): + """ + :param signal_id: User provided name for the gate node. + :param execution_id: The workflow execution id this signal is for. 
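+
+        A construction sketch (all values are illustrative):
+
+        .. code-block:: python
+
+            sid = SignalIdentifier(
+                signal_id="approve-deploy",
+                execution_id=WorkflowExecutionIdentifier(project="proj", domain="dev", name="f123abc"),
+            )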
+        """
+        self._signal_id = signal_id
+        self._execution_id = execution_id
+
+    @property
+    def signal_id(self) -> str:
+        return self._signal_id
+
+    @property
+    def execution_id(self) -> WorkflowExecutionIdentifier:
+        return self._execution_id
+
+    def to_flyte_idl(self) -> identifier_pb2.SignalIdentifier:
+        return identifier_pb2.SignalIdentifier(
+            signal_id=self.signal_id,
+            execution_id=self.execution_id.to_flyte_idl(),
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, proto: identifier_pb2.SignalIdentifier) -> "SignalIdentifier":
+        return cls(
+            signal_id=proto.signal_id,
+            execution_id=WorkflowExecutionIdentifier.from_flyte_idl(proto.execution_id),
+        )
diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py
index e9e68ce485..1af53b3a53 100644
--- a/flytekit/models/core/workflow.py
+++ b/flytekit/models/core/workflow.py
@@ -5,7 +5,7 @@
 
 from flytekit.models import common as _common
 from flytekit.models import interface as _interface
-from flytekit.models import types as _types
+from flytekit.models import types as type_models
 from flytekit.models.core import condition as _condition
 from flytekit.models.core import identifier as _identifier
 from flytekit.models.literals import Binding as _Binding
@@ -61,7 +61,7 @@ def __init__(self, case, other=None, else_node=None, error=None):
         :param IfBlock case:
         :param list[IfBlock] other:
         :param Node else_node:
-        :param _types.Error error:
+        :param type_models.Error error:
         """
         self._case = case
         self._other = other
@@ -121,7 +121,7 @@ def from_flyte_idl(cls, pb2_object):
             case=IfBlock.from_flyte_idl(pb2_object.case),
             other=[IfBlock.from_flyte_idl(a) for a in pb2_object.other],
             else_node=Node.from_flyte_idl(pb2_object.else_node) if pb2_object.HasField("else_node") else None,
-            error=_types.Error.from_flyte_idl(pb2_object.error) if pb2_object.HasField("error") else None,
+            error=type_models.Error.from_flyte_idl(pb2_object.error) if pb2_object.HasField("error") else None,
         )
 
 
@@ -219,6 +219,127 @@ def from_flyte_idl(cls, pb2_object):
         )
 
 
+class SignalCondition(_common.FlyteIdlEntity):
+    def __init__(self, signal_id: str, type: type_models.LiteralType, output_variable_name: str):
+        """
+        Represents a dependency on a signal from a user.
+        :param signal_id: The node id of the signal, also the signal name.
+        :param type: The expected literal type of the value supplied with the signal.
+        :param output_variable_name: The name under which the received value is surfaced as node output.
+        """
+        self._signal_id = signal_id
+        self._type = type
+        self._output_variable_name = output_variable_name
+
+    @property
+    def signal_id(self) -> str:
+        return self._signal_id
+
+    @property
+    def type(self) -> type_models.LiteralType:
+        return self._type
+
+    @property
+    def output_variable_name(self) -> str:
+        return self._output_variable_name
+
+    def to_flyte_idl(self) -> _core_workflow.SignalCondition:
+        return _core_workflow.SignalCondition(
+            signal_id=self.signal_id, type=self.type.to_flyte_idl(), output_variable_name=self.output_variable_name
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _core_workflow.SignalCondition):
+        return cls(
+            signal_id=pb2_object.signal_id,
+            type=type_models.LiteralType.from_flyte_idl(pb2_object.type),
+            output_variable_name=pb2_object.output_variable_name,
+        )
+
+
+class ApproveCondition(_common.FlyteIdlEntity):
+    def __init__(self, signal_id: str):
+        """
+        Represents a dependency on a signal from a user.
+        :param signal_id: The node id of the signal, also the signal name.
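+
+        A round-trip sketch (the signal name is illustrative):
+
+        .. code-block:: python
+
+            cond = ApproveCondition(signal_id="approve-deploy")
+            assert ApproveCondition.from_flyte_idl(cond.to_flyte_idl()).signal_id == "approve-deploy"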
+        """
+        self._signal_id = signal_id
+
+    @property
+    def signal_id(self) -> str:
+        return self._signal_id
+
+    def to_flyte_idl(self) -> _core_workflow.ApproveCondition:
+        return _core_workflow.ApproveCondition(signal_id=self.signal_id)
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _core_workflow.ApproveCondition):
+        return cls(signal_id=pb2_object.signal_id)
+
+
+class SleepCondition(_common.FlyteIdlEntity):
+    def __init__(self, duration: datetime.timedelta):
+        """
+        A sleep condition.
+        """
+        self._duration = duration
+
+    @property
+    def duration(self) -> datetime.timedelta:
+        return self._duration
+
+    def to_flyte_idl(self) -> _core_workflow.SleepCondition:
+        sc = _core_workflow.SleepCondition()
+        sc.duration.FromTimedelta(self.duration)
+        return sc
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _core_workflow.SleepCondition) -> "SleepCondition":
+        return cls(duration=pb2_object.duration.ToTimedelta())
+
+
+class GateNode(_common.FlyteIdlEntity):
+    def __init__(
+        self,
+        signal: typing.Optional[SignalCondition] = None,
+        sleep: typing.Optional[SleepCondition] = None,
+        approve: typing.Optional[ApproveCondition] = None,
+    ):
+        self._signal = signal
+        self._sleep = sleep
+        self._approve = approve
+
+    @property
+    def signal(self) -> typing.Optional[SignalCondition]:
+        return self._signal
+
+    @property
+    def sleep(self) -> typing.Optional[SleepCondition]:
+        return self._sleep
+
+    @property
+    def approve(self) -> typing.Optional[ApproveCondition]:
+        return self._approve
+
+    @property
+    def condition(self) -> typing.Union[SignalCondition, SleepCondition, ApproveCondition]:
+        return self.signal or self.sleep or self.approve
+
+    def to_flyte_idl(self) -> _core_workflow.GateNode:
+        return _core_workflow.GateNode(
+            signal=self.signal.to_flyte_idl() if self.signal else None,
+            sleep=self.sleep.to_flyte_idl() if self.sleep else None,
+            approve=self.approve.to_flyte_idl() if self.approve else None,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":
+        return cls(
+            signal=SignalCondition.from_flyte_idl(pb2_object.signal) if pb2_object.HasField("signal") else None,
+            sleep=SleepCondition.from_flyte_idl(pb2_object.sleep) if pb2_object.HasField("sleep") else None,
+            approve=ApproveCondition.from_flyte_idl(pb2_object.approve) if pb2_object.HasField("approve") else None,
+        )
+
+
 class Node(_common.FlyteIdlEntity):
     def __init__(
         self,
@@ -230,6 +351,7 @@ def __init__(
         task_node=None,
         workflow_node=None,
         branch_node=None,
+        gate_node: typing.Optional[GateNode] = None,
     ):
         """
         A Workflow graph Node. One unit of execution in the graph.
        Each node can be linked to a Task,
@@ -260,6 +382,7 @@ def __init__(
         self._task_node = task_node
         self._workflow_node = workflow_node
         self._branch_node = branch_node
+        self._gate_node = gate_node
 
     @property
     def id(self):
@@ -331,6 +454,10 @@ def branch_node(self):
         """
         return self._branch_node
 
+    @property
+    def gate_node(self) -> typing.Optional[GateNode]:
+        return self._gate_node
+
     @property
     def target(self):
         """
@@ -351,6 +478,7 @@ def to_flyte_idl(self):
             task_node=self.task_node.to_flyte_idl() if self.task_node is not None else None,
             workflow_node=self.workflow_node.to_flyte_idl() if self.workflow_node is not None else None,
             branch_node=self.branch_node.to_flyte_idl() if self.branch_node is not None else None,
+            gate_node=self.gate_node.to_flyte_idl() if self.gate_node else None,
         )
 
     @classmethod
@@ -372,6 +500,7 @@ def from_flyte_idl(cls, pb2_object):
             branch_node=BranchNode.from_flyte_idl(pb2_object.branch_node)
             if pb2_object.HasField("branch_node")
             else None,
+            gate_node=GateNode.from_flyte_idl(pb2_object.gate_node) if pb2_object.HasField("gate_node") else None,
         )
diff --git a/flytekit/models/documentation.py b/flytekit/models/documentation.py
new file mode 100644
index 0000000000..e1bae8122e
--- /dev/null
+++ b/flytekit/models/documentation.py
@@ -0,0 +1,93 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+from flyteidl.admin import description_entity_pb2
+
+from flytekit.models import common as _common_models
+
+
+@dataclass
+class Description(_common_models.FlyteIdlEntity):
+    """
+    Full user description with formatting preserved. This can be rendered
+    by clients, such as the console or command line tools, with intact
+    formatting.
+    """
+
+    class DescriptionFormat(Enum):
+        UNKNOWN = 0
+        MARKDOWN = 1
+        HTML = 2
+        RST = 3
+
+    value: Optional[str] = None
+    uri: Optional[str] = None
+    icon_link: Optional[str] = None
+    format: DescriptionFormat = DescriptionFormat.RST
+
+    def to_flyte_idl(self):
+        return description_entity_pb2.Description(
+            value=self.value if self.value else None,
+            uri=self.uri if self.uri else None,
+            format=self.format.value,
+            icon_link=self.icon_link,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: description_entity_pb2.Description) -> "Description":
+        return cls(
+            value=pb2_object.value if pb2_object.value else None,
+            uri=pb2_object.uri if pb2_object.uri else None,
+            format=Description.DescriptionFormat(pb2_object.format),
+            icon_link=pb2_object.icon_link if pb2_object.icon_link else None,
+        )
+
+
+@dataclass
+class SourceCode(_common_models.FlyteIdlEntity):
+    """
+    Link to the source code used to define this task or workflow.
+    """
+
+    link: Optional[str] = None
+
+    def to_flyte_idl(self):
+        return description_entity_pb2.SourceCode(link=self.link)
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: description_entity_pb2.SourceCode) -> "SourceCode":
+        return cls(link=pb2_object.link) if pb2_object.link else None
+
+
+@dataclass
+class Documentation(_common_models.FlyteIdlEntity):
+    """
+    DescriptionEntity contains a detailed description for the task/workflow/launch plan.
+    Documentation can provide insight into the algorithms, business use case, etc.
+
+    Args:
+        short_description (str): One-liner overview of the entity.
+        long_description (Optional[Description]): Full user description with formatting preserved.
+        source_code (Optional[SourceCode]): Link to the source code used to define this entity.
+    """
+
+    short_description: Optional[str] = None
+    long_description: Optional[Description] = None
+    source_code: Optional[SourceCode] = None
+
+    def to_flyte_idl(self):
+        return description_entity_pb2.DescriptionEntity(
+            short_description=self.short_description,
+            long_description=self.long_description.to_flyte_idl() if self.long_description else None,
+            source_code=self.source_code.to_flyte_idl() if self.source_code else None,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: description_entity_pb2.DescriptionEntity) -> "Documentation":
+        return cls(
+            short_description=pb2_object.short_description,
+            long_description=Description.from_flyte_idl(pb2_object.long_description)
+            if pb2_object.long_description
+            else None,
+            source_code=SourceCode.from_flyte_idl(pb2_object.source_code) if pb2_object.source_code else None,
+        )
diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py
index 75e040891b..9c2d5ba2ec 100644
--- a/flytekit/models/execution.py
+++ b/flytekit/models/execution.py
@@ -175,6 +175,7 @@ def __init__(
         raw_output_data_config=None,
         max_parallelism=None,
         security_context: typing.Optional[security.SecurityContext] = None,
+        overwrite_cache: typing.Optional[bool] = None,
     ):
         """
         :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
@@ -200,6 +201,7 @@ def __init__(
         self._raw_output_data_config = raw_output_data_config
         self._max_parallelism = max_parallelism
         self._security_context = security_context
+        self.overwrite_cache = overwrite_cache
 
     @property
     def launch_plan(self):
@@ -283,6 +285,7 @@ def to_flyte_idl(self):
             else None,
             max_parallelism=self.max_parallelism,
             security_context=self.security_context.to_flyte_idl() if self.security_context else None,
+            overwrite_cache=self.overwrite_cache,
         )
 
     @classmethod
@@ -306,6 +309,7 @@ def from_flyte_idl(cls, p):
             security_context=security.SecurityContext.from_flyte_idl(p.security_context)
             if p.security_context
             else None,
+            overwrite_cache=p.overwrite_cache,
         )
 
 
@@ -441,6 +445,8 @@ def __init__(
         error: typing.Optional[flytekit.models.core.execution.ExecutionError] = None,
         outputs: typing.Optional[LiteralMapBlob] = None,
         abort_metadata: typing.Optional[AbortMetadata] = None,
+        created_at: typing.Optional[datetime.datetime] = None,
+        updated_at: typing.Optional[datetime.datetime] = None,
     ):
         """
         :param phase: From the flytekit.models.core.execution.WorkflowExecutionPhase enum
@@ -456,6 +462,8 @@ def __init__(
         self._error = error
         self._outputs = outputs
         self._abort_metadata = abort_metadata
+        self._created_at = created_at
+        self._updated_at = updated_at
 
     @property
     def error(self) -> flytekit.models.core.execution.ExecutionError:
@@ -476,6 +484,14 @@ def started_at(self) -> datetime.datetime:
     def duration(self) -> datetime.timedelta:
         return self._duration
 
+    @property
+    def created_at(self) -> typing.Optional[datetime.datetime]:
+        return self._created_at
+
+    @property
+    def updated_at(self) -> typing.Optional[datetime.datetime]:
+        return self._updated_at
+
     @property
     def outputs(self) -> LiteralMapBlob:
         return self._outputs
@@ -496,6 +512,10 @@ def to_flyte_idl(self):
         )
         obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None))
         obj.duration.FromTimedelta(self.duration)
+        if self.created_at:
+            obj.created_at.FromDatetime(self.created_at.astimezone(_pytz.UTC).replace(tzinfo=None))
+        if self.updated_at:
+
obj.updated_at.FromDatetime(self.updated_at.astimezone(_pytz.UTC).replace(tzinfo=None)) return obj @classmethod @@ -520,6 +540,12 @@ def from_flyte_idl(cls, pb2_object): started_at=pb2_object.started_at.ToDatetime().replace(tzinfo=_pytz.UTC), duration=pb2_object.duration.ToTimedelta(), abort_metadata=abort_metadata, + created_at=pb2_object.created_at.ToDatetime().replace(tzinfo=_pytz.UTC) + if pb2_object.HasField("created_at") + else None, + updated_at=pb2_object.updated_at.ToDatetime().replace(tzinfo=_pytz.UTC) + if pb2_object.HasField("updated_at") + else None, ) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 220db5cc5f..335a793db6 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -1,3 +1,4 @@ +import datetime import typing import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 @@ -96,6 +97,8 @@ def __init__( error=None, workflow_node_metadata: typing.Optional[WorkflowNodeMetadata] = None, task_node_metadata: typing.Optional[TaskNodeMetadata] = None, + created_at: typing.Optional[datetime.datetime] = None, + updated_at: typing.Optional[datetime.datetime] = None, ): """ :param int phase: @@ -113,6 +116,8 @@ def __init__( self._workflow_node_metadata = workflow_node_metadata self._task_node_metadata = task_node_metadata # TODO: Add output_data field as well. + self._created_at = created_at + self._updated_at = updated_at @property def phase(self): @@ -135,6 +140,14 @@ def duration(self): """ return self._duration + @property + def created_at(self) -> typing.Optional[datetime.datetime]: + return self._created_at + + @property + def updated_at(self) -> typing.Optional[datetime.datetime]: + return self._updated_at + @property def output_uri(self): """ @@ -184,6 +197,10 @@ def to_flyte_idl(self): ) obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None)) obj.duration.FromTimedelta(self.duration) + if self.created_at: + obj.created_at.FromDatetime(self.created_at.astimezone(_pytz.UTC).replace(tzinfo=None)) + if self.updated_at: + obj.updated_at.FromDatetime(self.updated_at.astimezone(_pytz.UTC).replace(tzinfo=None)) return obj @classmethod @@ -205,6 +222,8 @@ def from_flyte_idl(cls, p): task_node_metadata=TaskNodeMetadata.from_flyte_idl(p.task_node_metadata) if p.HasField("task_node_metadata") else None, + created_at=p.created_at.ToDatetime().replace(tzinfo=_pytz.UTC) if p.HasField("created_at") else None, + updated_at=p.updated_at.ToDatetime().replace(tzinfo=_pytz.UTC) if p.HasField("updated_at") else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f2ff5efd89..fc79c87a2d 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -13,6 +13,7 @@ from flytekit.models import literals as _literals from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier +from flytekit.models.documentation import Documentation class Resources(_common.FlyteIdlEntity): @@ -176,6 +177,7 @@ def __init__( discovery_version, deprecated_error_message, cache_serializable, + pod_template_name, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -195,6 +197,7 @@ def __init__( receive deprecation warnings. :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. 
+ :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ self._discoverable = discoverable self._runtime = runtime @@ -204,6 +207,7 @@ def __init__( self._discovery_version = discovery_version self._deprecated_error_message = deprecated_error_message self._cache_serializable = cache_serializable + self._pod_template_name = pod_template_name @property def discoverable(self): @@ -273,6 +277,14 @@ def cache_serializable(self): """ return self._cache_serializable + @property + def pod_template_name(self): + """ + The name of the existing PodTemplate resource which will be used in this task. + :rtype: Text + """ + return self._pod_template_name + def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskMetadata @@ -285,6 +297,7 @@ def to_flyte_idl(self): discovery_version=self.discovery_version, deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, + pod_template_name=self.pod_template_name, ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) @@ -305,6 +318,7 @@ def from_flyte_idl(cls, pb2_object): discovery_version=pb2_object.discovery_version, deprecated_error_message=pb2_object.deprecated_error_message, cache_serializable=pb2_object.cache_serializable, + pod_template_name=pb2_object.pod_template_name, ) @@ -480,11 +494,13 @@ def from_flyte_idl(cls, pb2_object): class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): + def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None): """ :param TaskTemplate template: + :param Documentation docs: """ self._template = template + self._docs = docs @property def template(self): @@ -493,11 +509,20 @@ def template(self): """ return self._template + @property + def docs(self): + """ + :rtype: Description entity for the task + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) + return _admin_task.TaskSpec( + template=self.template.to_flyte_idl(), description=self.docs.to_flyte_idl() if self.docs else None + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -505,7 +530,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: :rtype: TaskSpec """ - return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) + return cls( + TaskTemplate.from_flyte_idl(pb2_object.template), + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, + ) class Task(_common.FlyteIdlEntity): diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 643d613231..4d6f172586 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -51,9 +51,9 @@ :toctree: generated/ :nosignatures: - ~task.FlyteTask - ~workflow.FlyteWorkflow - ~launch_plan.FlyteLaunchPlan + ~entities.FlyteTask + ~entities.FlyteWorkflow + ~entities.FlyteLaunchPlan .. _remote-flyte-entity-components: @@ -65,9 +65,9 @@ :toctree: generated/ :nosignatures: - ~nodes.FlyteNode - ~component_nodes.FlyteTaskNode - ~component_nodes.FlyteWorkflowNode + ~entities.FlyteNode + ~entities.FlyteTaskNode + ~entities.FlyteWorkflowNode .. 
_remote-flyte-execution-objects:
@@ -85,10 +85,14 @@
 
 """
 
-from flytekit.remote.component_nodes import FlyteTaskNode, FlyteWorkflowNode
+from flytekit.remote.entities import (
+    FlyteBranchNode,
+    FlyteLaunchPlan,
+    FlyteNode,
+    FlyteTask,
+    FlyteTaskNode,
+    FlyteWorkflow,
+    FlyteWorkflowNode,
+)
 from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
-from flytekit.remote.launch_plan import FlyteLaunchPlan
-from flytekit.remote.nodes import FlyteNode
 from flytekit.remote.remote import FlyteRemote
-from flytekit.remote.task import FlyteTask
-from flytekit.remote.workflow import FlyteWorkflow
diff --git a/flytekit/remote/backfill.py b/flytekit/remote/backfill.py
new file mode 100644
index 0000000000..154bf4d1b4
--- /dev/null
+++ b/flytekit/remote/backfill.py
@@ -0,0 +1,97 @@
+import logging
+import typing
+from datetime import datetime, timedelta
+
+from croniter import croniter
+
+from flytekit import LaunchPlan
+from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase
+from flytekit.remote.entities import FlyteLaunchPlan
+
+
+def create_backfill_workflow(
+    start_date: datetime,
+    end_date: datetime,
+    for_lp: typing.Union[LaunchPlan, FlyteLaunchPlan],
+    parallel: bool = False,
+    per_node_timeout: timedelta = None,
+    per_node_retries: int = 0,
+) -> typing.Tuple[WorkflowBase, datetime, datetime]:
+    """
+    Generates a new imperative workflow that can be used to backfill the given launch plan.
+    A backfill workflow can only be generated for schedulable launch plans.
+
+    The backfill plan covers the interval (start_date exclusive, end_date inclusive).
+
+    .. code-block:: python
+       :caption: Correct usage for dates example
+
+       lp = LaunchPlan.get_or_create(...)
+       start_date = datetime.datetime(2023, 1, 1)
+       end_date = start_date + datetime.timedelta(days=10)
+       wf = create_backfill_workflow(start_date, end_date, for_lp=lp)
+
+
+    .. code-block:: python
+       :caption: Incorrect date example
+
+       wf = create_backfill_workflow(end_date, start_date, for_lp=lp)  # end_date is before start_date
+       # OR
+       wf = create_backfill_workflow(start_date, start_date, for_lp=lp)  # start and end date are same
+
+
+    :param start_date: datetime generate a backfill starting at this datetime (exclusive)
+    :param end_date: datetime generate a backfill ending at this datetime (inclusive)
+    :param for_lp: typing.Union[LaunchPlan, FlyteLaunchPlan] the backfill is generated for this launch plan
+    :param parallel: if the backfill should be run in parallel. False (default) will run each backfill sequentially
+    :param per_node_timeout: timedelta Timeout to use per node
+    :param per_node_retries: int Retries to use per node
+    :return: WorkflowBase, datetime, datetime -> the newly generated workflow, the datetime of the first backfill instance, the datetime of the last backfill instance
+    """
+    if not for_lp:
+        raise ValueError("Launch plan is required!")
+
+    if start_date >= end_date:
+        raise ValueError(
+            f"For a backfill, the start date must be earlier than the end date. Received {start_date} -> {end_date}"
+        )
+
+    schedule = for_lp.entity_metadata.schedule if isinstance(for_lp, FlyteLaunchPlan) else for_lp.schedule
+
+    if schedule is None:
+        raise ValueError("Backfill can only be created for scheduled launch plans")
+
+    if schedule.cron_schedule is not None:
+        cron_schedule = schedule.cron_schedule
+    else:
+        raise NotImplementedError("Currently backfilling only supports cron schedules.")
+
+    logging.info(f"Generating backfill from {start_date} -> {end_date}.
Parallel?[{parallel}]") + wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}") + date_iter = croniter(cron_schedule.schedule, start_time=start_date, ret_type=datetime) + prev_node = None + actual_start = None + actual_end = None + while True: + next_start_date = date_iter.get_next() + if not actual_start: + actual_start = next_start_date + if next_start_date >= end_date: + break + actual_end = next_start_date + next_node = wf.add_launch_plan(for_lp, t=next_start_date) + next_node = next_node.with_overrides( + name=f"b-{next_start_date}", retries=per_node_retries, timeout=per_node_timeout + ) + if not parallel: + if prev_node: + prev_node.runs_before(next_node) + prev_node = next_node + + if actual_end is None: + raise StopIteration( + f"The time window is too small for any backfill instances, first instance after start" + f" date is {actual_start}" + ) + + return wf, actual_start, actual_end diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py deleted file mode 100644 index bdf5fab38a..0000000000 --- a/flytekit/remote/component_nodes.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Dict - -from flytekit.exceptions import system as _system_exceptions -from flytekit.loggers import remote_logger -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_model - - -class FlyteTaskNode(_workflow_model.TaskNode): - """ - A class encapsulating a task that a Flyte node needs to execute. - """ - - def __init__(self, flyte_task: "flytekit.remote.task.FlyteTask"): - self._flyte_task = flyte_task - super(FlyteTaskNode, self).__init__(None) - - @property - def reference_id(self) -> id_models.Identifier: - """ - A globally unique identifier for the task. - """ - return self._flyte_task.id - - @property - def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask": - return self._flyte_task - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.TaskNode, - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteTaskNode": - """ - Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the - FlyteTask control plane. - - :param base_model: - :param tasks: - """ - from flytekit.remote.task import FlyteTask - - if base_model.reference_id in tasks: - task = tasks[base_model.reference_id] - remote_logger.debug(f"Found existing task template for {task.id}, will not retrieve from Admin") - flyte_task = FlyteTask.promote_from_model(task) - return cls(flyte_task) - - raise _system_exceptions.FlyteSystemException(f"Task template {base_model.reference_id} not found.") - - -class FlyteWorkflowNode(_workflow_model.WorkflowNode): - """A class encapsulating a workflow that a Flyte node needs to execute.""" - - def __init__( - self, - flyte_workflow: "flytekit.remote.workflow.FlyteWorkflow" = None, - flyte_launch_plan: "flytekit.remote.launch_plan.FlyteLaunchPlan" = None, - ): - if flyte_workflow and flyte_launch_plan: - raise _system_exceptions.FlyteSystemException( - "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " - f"one. 
workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", - ) - - self._flyte_workflow = flyte_workflow - self._flyte_launch_plan = flyte_launch_plan - super(FlyteWorkflowNode, self).__init__( - launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, - sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, - ) - - def __repr__(self) -> str: - if self.flyte_workflow is not None: - return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" - return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" - - @property - def launchplan_ref(self) -> id_models.Identifier: - """A globally unique identifier for the launch plan, which should map to Admin.""" - return self._flyte_launch_plan.id if self._flyte_launch_plan else None - - @property - def sub_workflow_ref(self): - return self._flyte_workflow.id if self._flyte_workflow else None - - @property - def flyte_launch_plan(self) -> "flytekit.remote.launch_plan.FlyteLaunchPlan": - return self._flyte_launch_plan - - @property - def flyte_workflow(self) -> "flytekit.remote.workflow.FlyteWorkflow": - return self._flyte_workflow - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.WorkflowNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteWorkflowNode": - from flytekit.remote import launch_plan as _launch_plan - from flytekit.remote import workflow as _workflow - - if base_model.launchplan_ref is not None: - return cls( - flyte_launch_plan=_launch_plan.FlyteLaunchPlan.promote_from_model( - base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] - ) - ) - elif base_model.sub_workflow_ref is not None: - # the workflow templates for sub-workflows should have been included in the original response - if base_model.reference in sub_workflows: - return cls( - flyte_workflow=_workflow.FlyteWorkflow.promote_from_model( - sub_workflows[base_model.reference], - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - ) - raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") - - raise _system_exceptions.FlyteSystemException( - "Bad workflow node model, neither subworkflow nor launchplan specified." 
- ) - - -class FlyteBranchNode(_workflow_model.BranchNode): - def __init__(self, if_else: _workflow_model.IfElseBlock): - super().__init__(if_else) - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.BranchNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteBranchNode": - - from flytekit.remote.nodes import FlyteNode - - block = base_model.if_else - - else_node = None - if block.else_node: - else_node = FlyteNode.promote_from_model(block.else_node, sub_workflows, node_launch_plans, tasks) - - block.case._then_node = FlyteNode.promote_from_model( - block.case.then_node, sub_workflows, node_launch_plans, tasks - ) - - for o in block.other: - o._then_node = FlyteNode.promote_from_model(o.then_node, sub_workflows, node_launch_plans, tasks) - - new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error) - - return cls(new_if_else_block) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py new file mode 100644 index 0000000000..624e01661d --- /dev/null +++ b/flytekit/remote/entities.py @@ -0,0 +1,815 @@ +"""This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. +The goal is to enable easy access, manipulation of these entities. """ +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple, Union + +from flytekit import FlyteContext +from flytekit.core import constants as _constants +from flytekit.core import hash as _hash_mixin +from flytekit.core import hash as hash_mixin +from flytekit.core.promise import create_and_link_node_from_remote +from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import remote_logger +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_model +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import launch_plan as launch_plan_models +from flytekit.models import task as _task_model +from flytekit.models import task as _task_models +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.core import compiler as compiler_models +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import identifier as id_models +from flytekit.models.core import workflow as _workflow_model +from flytekit.models.core import workflow as _workflow_models +from flytekit.models.core.identifier import Identifier +from flytekit.models.core.workflow import Node, WorkflowMetadata, WorkflowMetadataDefaults +from flytekit.models.interface import TypedInterface +from flytekit.models.literals import Binding +from flytekit.models.task import TaskSpec +from flytekit.remote import interface as _interface +from flytekit.remote import interface as _interfaces +from flytekit.remote.remote_callable import RemoteEntity + + +class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, TaskSpec): + """A class encapsulating a remote Flyte task.""" + + def __init__( + self, + id, + type, + metadata, + interface, + custom, + container=None, + task_type_version: int = 0, + config=None, + should_register: bool = False, + ): + super(FlyteTask, self).__init__( + template=_task_model.TaskTemplate( + id, + type, + metadata, 
+                interface,
+                custom,
+                container=container,
+                task_type_version=task_type_version,
+                config=config,
+            )
+        )
+        self._should_register = should_register
+
+    @property
+    def id(self):
+        """
+        This is generated by the system and uniquely identifies the task.
+        :rtype: flytekit.models.core.identifier.Identifier
+        """
+        return self.template.id
+
+    @property
+    def type(self):
+        """
+        This is used to identify additional extensions for use by Propeller or SDK.
+        :rtype: Text
+        """
+        return self.template.type
+
+    @property
+    def metadata(self):
+        """
+        This contains information needed at runtime to determine behavior such as whether or not outputs are
+        discoverable, timeouts, and retries.
+        :rtype: TaskMetadata
+        """
+        return self.template.metadata
+
+    @property
+    def interface(self):
+        """
+        The interface definition for this task.
+        :rtype: flytekit.models.interface.TypedInterface
+        """
+        return self.template.interface
+
+    @property
+    def custom(self):
+        """
+        Arbitrary dictionary containing metadata for custom plugins.
+        :rtype: dict[Text, T]
+        """
+        return self.template.custom
+
+    @property
+    def task_type_version(self):
+        return self.template.task_type_version
+
+    @property
+    def container(self):
+        """
+        If not None, the target of execution should be a container.
+        :rtype: Container
+        """
+        return self.template.container
+
+    @property
+    def config(self):
+        """
+        Arbitrary dictionary containing metadata for parsing and handling custom plugins.
+        :rtype: dict[Text, T]
+        """
+        return self.template.config
+
+    @property
+    def security_context(self):
+        return self.template.security_context
+
+    @property
+    def k8s_pod(self):
+        return self.template.k8s_pod
+
+    @property
+    def sql(self):
+        return self.template.sql
+
+    @property
+    def should_register(self) -> bool:
+        return self._should_register
+
+    @property
+    def name(self) -> str:
+        return self.template.id.name
+
+    @property
+    def resource_type(self) -> _identifier_model.ResourceType:
+        return _identifier_model.ResourceType.TASK
+
+    @property
+    def entity_type_text(self) -> str:
+        return "Task"
+
+    @classmethod
+    def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask:
+        t = cls(
+            id=base_model.id,
+            type=base_model.type,
+            metadata=base_model.metadata,
+            interface=_interfaces.TypedInterface.promote_from_model(base_model.interface),
+            custom=base_model.custom,
+            container=base_model.container,
+            task_type_version=base_model.task_type_version,
+        )
+        # Override the newly generated name if one exists in the base model
+        if not base_model.id.is_empty:
+            t._id = base_model.id
+
+        return t
+
+
+class FlyteTaskNode(_workflow_model.TaskNode):
+    """
+    A class encapsulating a task that a Flyte node needs to execute.
+    """
+
+    def __init__(self, flyte_task: FlyteTask):
+        super(FlyteTaskNode, self).__init__(None)
+        self._flyte_task = flyte_task
+
+    @property
+    def reference_id(self) -> id_models.Identifier:
+        """
+        A globally unique identifier for the task.
+        """
+        return self._flyte_task.id
+
+    @property
+    def flyte_task(self) -> FlyteTask:
+        return self._flyte_task
+
+    @classmethod
+    def promote_from_model(cls, task: FlyteTask) -> FlyteTaskNode:
+        """
+        Takes an already-hydrated FlyteTask and wraps it in a FlyteTaskNode; nothing needs to be
+        fetched from the control plane.
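+
+        A sketch of the intended use (``some_flyte_task`` is illustrative):
+
+        .. code-block:: python
+
+            task_node = FlyteTaskNode.promote_from_model(some_flyte_task)
+            assert task_node.flyte_task is some_flyte_task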
+ """ + return cls(flyte_task=task) + + +class FlyteWorkflowNode(_workflow_model.WorkflowNode): + """A class encapsulating a workflow that a Flyte node needs to execute.""" + + def __init__( + self, + flyte_workflow: FlyteWorkflow = None, + flyte_launch_plan: FlyteLaunchPlan = None, + ): + if flyte_workflow and flyte_launch_plan: + raise _system_exceptions.FlyteSystemException( + "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " + f"one. workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", + ) + + self._flyte_workflow = flyte_workflow + self._flyte_launch_plan = flyte_launch_plan + super(FlyteWorkflowNode, self).__init__( + launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, + sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, + ) + + def __repr__(self) -> str: + if self.flyte_workflow is not None: + return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" + return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" + + @property + def launchplan_ref(self) -> id_models.Identifier: + """A globally unique identifier for the launch plan, which should map to Admin.""" + return self._flyte_launch_plan.id if self._flyte_launch_plan else None + + @property + def sub_workflow_ref(self): + return self._flyte_workflow.id if self._flyte_workflow else None + + @property + def flyte_launch_plan(self) -> FlyteLaunchPlan: + return self._flyte_launch_plan + + @property + def flyte_workflow(self) -> FlyteWorkflow: + return self._flyte_workflow + + @classmethod + def _promote_workflow( + cls, + wf: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[Identifier, FlyteTask]] = None, + node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ) -> FlyteWorkflow: + return FlyteWorkflow.promote_from_model( + wf, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.WorkflowNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]: + if base_model.launchplan_ref is not None: + return ( + cls( + flyte_launch_plan=FlyteLaunchPlan.promote_from_model( + base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] + ) + ), + converted_sub_workflows, + ) + elif base_model.sub_workflow_ref is not None: + # the workflow templates for sub-workflows should have been included in the original response + if base_model.reference in sub_workflows: + wf = None + if base_model.reference not in converted_sub_workflows: + wf = cls._promote_workflow( + sub_workflows[base_model.reference], + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + converted_sub_workflows[base_model.reference] = wf + else: + wf = converted_sub_workflows[base_model.reference] + return cls(flyte_workflow=wf), converted_sub_workflows + raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") + + raise _system_exceptions.FlyteSystemException( + "Bad workflow node model, neither subworkflow nor launchplan specified." 
+        )
+
+
+class FlyteBranchNode(_workflow_model.BranchNode):
+    def __init__(self, if_else: _workflow_model.IfElseBlock):
+        super().__init__(if_else)
+
+    @classmethod
+    def promote_from_model(
+        cls,
+        base_model: _workflow_model.BranchNode,
+        sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate],
+        node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec],
+        tasks: Dict[id_models.Identifier, FlyteTask],
+        converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow],
+    ) -> Tuple[FlyteBranchNode, Dict[id_models.Identifier, FlyteWorkflow]]:
+
+        block = base_model.if_else
+        block.case._then_node, converted_sub_workflows = FlyteNode.promote_from_model(
+            block.case.then_node,
+            sub_workflows,
+            node_launch_plans,
+            tasks,
+            converted_sub_workflows,
+        )
+
+        for o in block.other:
+            o._then_node, converted_sub_workflows = FlyteNode.promote_from_model(
+                o.then_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows
+            )
+
+        else_node = None
+        if block.else_node:
+            else_node, converted_sub_workflows = FlyteNode.promote_from_model(
+                block.else_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows
+            )
+
+        new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error)
+
+        return cls(new_if_else_block), converted_sub_workflows
+
+
+class FlyteGateNode(_workflow_model.GateNode):
+    @classmethod
+    def promote_from_model(cls, model: _workflow_model.GateNode):
+        return cls(model.signal, model.sleep, model.approve)
+
+
+class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
+    """A class encapsulating a remote Flyte node."""
+
+    def __init__(
+        self,
+        id,
+        upstream_nodes,
+        bindings,
+        metadata,
+        task_node: Optional[FlyteTaskNode] = None,
+        workflow_node: Optional[FlyteWorkflowNode] = None,
+        branch_node: Optional[FlyteBranchNode] = None,
+        gate_node: Optional[FlyteGateNode] = None,
+    ):
+        if not task_node and not workflow_node and not branch_node and not gate_node:
+            raise _user_exceptions.FlyteAssertion(
+                "A Flyte node must have exactly one of task|workflow|branch|gate entity specified"
+            )
+        # TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead
+        # of a node?
+        if task_node:
+            self._flyte_entity = task_node.flyte_task
+        elif workflow_node:
+            self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan
+        else:
+            self._flyte_entity = branch_node or gate_node
+
+        super(FlyteNode, self).__init__(
+            id=id,
+            metadata=metadata,
+            inputs=bindings,
+            upstream_node_ids=[n.id for n in upstream_nodes],
+            output_aliases=[],
+            task_node=task_node,
+            workflow_node=workflow_node,
+            branch_node=branch_node,
+            gate_node=gate_node,
+        )
+        self._upstream = upstream_nodes
+
+    @property
+    def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode]:
+        return self._flyte_entity
+
+    @classmethod
+    def _promote_task_node(cls, t: FlyteTask) -> FlyteTaskNode:
+        return FlyteTaskNode.promote_from_model(t)
+
+    @classmethod
+    def _promote_workflow_node(
+        cls,
+        wn: _workflow_model.WorkflowNode,
+        sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate],
+        node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec],
+        tasks: Dict[Identifier, FlyteTask],
+        converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow],
+    ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]:
+        return FlyteWorkflowNode.promote_from_model(
+            wn,
+            sub_workflows,
+            node_launch_plans,
+            tasks,
+            converted_sub_workflows,
+        )
+
+    @classmethod
+    def promote_from_model(
+        cls,
+        model: _workflow_model.Node,
+        sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]],
+        node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]],
+        tasks: Dict[id_models.Identifier, FlyteTask],
+        converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow],
+    ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]:
+        node_model_id = model.id
+        # TODO: Consider removing
+        if node_model_id in {_constants.START_NODE_ID, _constants.END_NODE_ID}:
+            remote_logger.warning(f"Should not call promote from model on a start node or end node {model}")
+            return None, converted_sub_workflows
+
+        flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None
+        if model.task_node is not None:
+            if model.task_node.reference_id not in tasks:
+                raise RuntimeError(
+                    f"Remote Workflow closure does not have task with id {model.task_node.reference_id}."
+                )
+            flyte_task_node = cls._promote_task_node(tasks[model.task_node.reference_id])
+        elif model.workflow_node is not None:
+            flyte_workflow_node, converted_sub_workflows = cls._promote_workflow_node(
+                model.workflow_node,
+                sub_workflows,
+                node_launch_plans,
+                tasks,
+                converted_sub_workflows,
+            )
+        elif model.branch_node is not None:
+            flyte_branch_node, converted_sub_workflows = FlyteBranchNode.promote_from_model(
+                model.branch_node,
+                sub_workflows,
+                node_launch_plans,
+                tasks,
+                converted_sub_workflows,
+            )
+        elif model.gate_node is not None:
+            flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node)
+        else:
+            raise _system_exceptions.FlyteSystemException(
+                f"Bad Node model, neither task nor workflow detected, node: {model}"
+            )
+
+        # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a
+        # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out.
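+        # (bindings that reference the compiler-injected start node are re-pointed at the global-input pseudo-node)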
+ # TODO: Consider removing + for model_input in model.inputs: + if ( + model_input.binding.promise is not None + and model_input.binding.promise.node_id == _constants.START_NODE_ID + ): + model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID + + return ( + cls( + id=node_model_id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + task_node=flyte_task_node, + workflow_node=flyte_workflow_node, + branch_node=flyte_branch_node, + gate_node=flyte_gate_node, + ), + converted_sub_workflows, + ) + + @property + def upstream_nodes(self) -> List[FlyteNode]: + return self._upstream + + @property + def upstream_node_ids(self) -> List[str]: + return list(sorted(n.id for n in self.upstream_nodes)) + + def __repr__(self) -> str: + return f"Node(ID: {self.id})" + + +class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, WorkflowSpec): + """A class encapsulating a remote Flyte workflow.""" + + def __init__( + self, + id: id_models.Identifier, + nodes: List[FlyteNode], + interface, + output_bindings, + metadata, + metadata_defaults, + subworkflows: Optional[List[FlyteWorkflow]] = None, + tasks: Optional[List[FlyteTask]] = None, + launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, + should_register: bool = False, + ): + # TODO: Remove check + for node in nodes: + for upstream in node.upstream_nodes: + if upstream.id is None: + raise _user_exceptions.FlyteAssertion( + "Some nodes contained in the workflow were not found in the workflow description. Please " + "ensure all nodes are either assigned to attributes within the class or an element in a " + "list, dict, or tuple which is stored as an attribute in the class." + ) + + self._flyte_sub_workflows = subworkflows + template_subworkflows = [] + if subworkflows: + template_subworkflows = [swf.template for swf in subworkflows] + + super(FlyteWorkflow, self).__init__( + template=_workflow_models.WorkflowTemplate( + id=id, + metadata=metadata, + metadata_defaults=metadata_defaults, + interface=interface, + nodes=nodes, + outputs=output_bindings, + ), + sub_workflows=template_subworkflows, + ) + self._flyte_nodes = nodes + + # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure + self._tasks = tasks + self._launch_plans = launch_plans + self._compiled_closure = compiled_closure + self._node_map = None + self._name = id.name + self._should_register = should_register + + @property + def name(self) -> str: + return self._name + + @property + def flyte_tasks(self) -> Optional[List[FlyteTask]]: + return self._tasks + + @property + def should_register(self) -> bool: + return self._should_register + + @property + def flyte_sub_workflows(self) -> List[FlyteWorkflow]: + return self._flyte_sub_workflows + + @property + def entity_type_text(self) -> str: + return "Workflow" + + @property + def resource_type(self): + return id_models.ResourceType.WORKFLOW + + @property + def flyte_nodes(self) -> List[FlyteNode]: + return self._flyte_nodes + + @property + def id(self) -> Identifier: + """ + This is an autogenerated id by the system. The id is globally unique across Flyte. + """ + return self.template.id + + @property + def metadata(self) -> WorkflowMetadata: + """ + This contains information on how to run the workflow. 
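+        :rtype: WorkflowMetadata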
+        """
+        return self.template.metadata
+
+    @property
+    def metadata_defaults(self) -> WorkflowMetadataDefaults:
+        """
+        This contains the default metadata (e.g. the default interruptible behavior) applied to the nodes of the
+        workflow.
+        :rtype: WorkflowMetadataDefaults
+        """
+        return self.template.metadata_defaults
+
+    @property
+    def interface(self) -> TypedInterface:
+        """
+        Defines a strongly typed interface for the Workflow (inputs, outputs). This can include some optional
+        parameters.
+        """
+        return self.template.interface
+
+    @property
+    def nodes(self) -> List[Node]:
+        """
+        A list of nodes. In addition, "globals" is a special reserved node id that can be used to consume
+        workflow inputs.
+        """
+        return self.template.nodes
+
+    @property
+    def outputs(self) -> List[Binding]:
+        """
+        A list of output bindings that specify how to construct workflow outputs. Bindings can
+        pull node outputs or specify literals. All workflow outputs specified in the interface field must be bound
+        in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes
+        to execute successfully in order to bind final outputs.
+        """
+        return self.template.outputs
+
+    @property
+    def failure_node(self) -> Node:
+        """
+        A catch-all node. This node is executed whenever the execution engine determines the
+        workflow has failed. The interface of this node must match the Workflow interface with an additional input
+        named "error" of type pb.lyft.flyte.core.Error.
+        """
+        return self.template.failure_node
+
+    @classmethod
+    def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]:
+        return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}]
+
+    @classmethod
+    def _promote_node(
+        cls,
+        model: _workflow_model.Node,
+        sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]],
+        node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]],
+        tasks: Dict[id_models.Identifier, FlyteTask],
+        converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow],
+    ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]:
+        return FlyteNode.promote_from_model(model, sub_workflows, node_launch_plans, tasks, converted_sub_workflows)
+
+    @classmethod
+    def promote_from_model(
+        cls,
+        base_model: _workflow_models.WorkflowTemplate,
+        sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None,
+        tasks: Optional[Dict[Identifier, FlyteTask]] = None,
+        node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None,
+    ) -> FlyteWorkflow:
+
+        base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes)
+
+        node_map = {}
+        converted_sub_workflows = {}
+        for node in base_model_non_system_nodes:
+            flyte_node, converted_sub_workflows = cls._promote_node(
+                node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows
+            )
+            node_map[node.id] = flyte_node
+
+        # Set upstream nodes for each node
+        for n in base_model_non_system_nodes:
+            current = node_map[n.id]
+            for upstream_id in n.upstream_node_ids:
+                upstream_node = node_map[upstream_id]
+                current._upstream.append(upstream_node)
+
+        subworkflow_list = []
+        if converted_sub_workflows:
+            subworkflow_list = [v for _, v in converted_sub_workflows.items()]
+
+        task_list = []
+        if tasks:
+            task_list = [t for _, t in tasks.items()]
+
+        # No inputs/outputs specified, see the constructor for more information on the overrides.
+        wf = cls(
+            id=base_model.id,
+            nodes=list(node_map.values()),
+            metadata=base_model.metadata,
+            metadata_defaults=base_model.metadata_defaults,
+            interface=_interfaces.TypedInterface.promote_from_model(base_model.interface),
+            output_bindings=base_model.outputs,
+            subworkflows=subworkflow_list,
+            tasks=task_list,
+            launch_plans=node_launch_plans,
+        )
+
+        wf._node_map = node_map
+
+        return wf
+
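As an illustrative aside, a minimal sketch of what the promotion path above yields at the user level, assuming a reachable backend and a registered workflow named "my_wf" (project, domain, and names are hypothetical); fetch_workflow drives promote_from_model under the hood:

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")

    # fetch_workflow returns a FlyteWorkflow promoted from the Admin closure.
    wf = remote.fetch_workflow(name="my_wf", version="v1")
    for node in wf.flyte_nodes:
        print(node.id, type(node.flyte_entity).__name__)
    for swf in wf.flyte_sub_workflows or []:
        print("subworkflow:", swf.name)
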
+    @classmethod
+    def _promote_task(cls, t: _task_models.TaskTemplate) -> FlyteTask:
+        return FlyteTask.promote_from_model(t)
+
+    @classmethod
+    def promote_from_closure(
+        cls,
+        closure: compiler_models.CompiledWorkflowClosure,
+        node_launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None,
+    ):
+        """
+        Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane.
+
+        :param closure: This is the closure returned by Admin
+        :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans;
+            it only has subworkflows and tasks. Why that is so is unclear. If supplied, this map of launch plans will be
+            used to resolve any launch-plan nodes in the workflow.
+        :return:
+        """
+        sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows}
+        tasks = {}
+        if closure.tasks:
+            tasks = {t.template.id: cls._promote_task(t.template) for t in closure.tasks}
+
+        flyte_wf = cls.promote_from_model(
+            base_model=closure.primary.template,
+            sub_workflows=sub_workflows,
+            node_launch_plans=node_launch_plans,
+            tasks=tasks,
+        )
+        flyte_wf._compiled_closure = closure
+        return flyte_wf
+
+
+class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec):
+    """A class encapsulating a remote Flyte launch plan."""
+
+    def __init__(self, id, *args, **kwargs):
+        super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
+        # Set all the attributes we expect this class to have
+        self._id = id
+        self._name = id.name
+
+        # The interface is not set explicitly unless fetched in an engine context
+        self._interface = None
+        # If fetched when creating this object, can store it here.
+        self._flyte_workflow = None
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def flyte_workflow(self) -> Optional[FlyteWorkflow]:
+        return self._flyte_workflow
+
+    @classmethod
+    def promote_from_model(cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec) -> FlyteLaunchPlan:
+        lp = cls(
+            id=id,
+            workflow_id=model.workflow_id,
+            default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
+            fixed_inputs=model.fixed_inputs,
+            entity_metadata=model.entity_metadata,
+            labels=model.labels,
+            annotations=model.annotations,
+            auth_role=model.auth_role,
+            raw_output_data_config=model.raw_output_data_config,
+            max_parallelism=model.max_parallelism,
+            security_context=model.security_context,
+        )
+        return lp
+
+    @property
+    def id(self) -> id_models.Identifier:
+        return self._id
+
+    @property
+    def is_scheduled(self) -> bool:
+        if self.entity_metadata.schedule.cron_expression:
+            return True
+        elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
+            return True
+        elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
+            return True
+        else:
+            return False
+
+    @property
+    def workflow_id(self) -> id_models.Identifier:
+        return self._workflow_id
+
+    @property
+    def interface(self) -> Optional[_interface.TypedInterface]:
+        """
+        The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
+        from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the
+        object and get a node.
+        """
+        return self._interface
+
+    @property
+    def resource_type(self) -> id_models.ResourceType:
+        return id_models.ResourceType.LAUNCH_PLAN
+
+    @property
+    def entity_type_text(self) -> str:
+        return "Launch Plan"
+
+    def compile(self, ctx: FlyteContext, *args, **kwargs):
+        fixed_input_lits = self.fixed_inputs.literals or {}
+        default_input_params = self.default_inputs.parameters or {}
+        return create_and_link_node_from_remote(
+            ctx,
+            entity=self,
+            _inputs_not_allowed=set(fixed_input_lits.keys()),
+            _ignorable_inputs=set(default_input_params.keys()),
+            **kwargs,
+        )  # noqa
+
+    def __repr__(self) -> str:
+        return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})"
diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py
index 607b15c889..292b6f0218 100644
--- a/flytekit/remote/executions.py
+++ b/flytekit/remote/executions.py
@@ -9,8 +9,7 @@
 from flytekit.models import node_execution as node_execution_models
 from flytekit.models.admin import task_execution as admin_task_execution_models
 from flytekit.models.core import execution as core_execution_models
-from flytekit.remote.task import FlyteTask
-from flytekit.remote.workflow import FlyteWorkflow
+from flytekit.remote.entities import FlyteTask, FlyteWorkflow
 
 
 class RemoteExecutionBase(object):
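A hedged usage sketch of the consolidated FlyteLaunchPlan above: because it now implements compile(), a fetched launch plan can be composed into a locally defined workflow. The launch plan name and its assumed (x: int) -> int interface are illustrative only:

    from flytekit import workflow
    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    remote_lp = remote.fetch_launch_plan(name="my_lp", version="v1")  # hypothetical launch plan


    @workflow
    def wrapper_wf(x: int) -> int:
        # compile() forbids re-binding fixed inputs (via _inputs_not_allowed) and
        # lets default inputs be omitted (via _ignorable_inputs).
        return remote_lp(x=x)
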
diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py
deleted file mode 100644
index b6c8e1f9e6..0000000000
--- a/flytekit/remote/launch_plan.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-from flytekit.core import hash as hash_mixin
-from flytekit.models import interface as _interface_models
-from flytekit.models import launch_plan as _launch_plan_models
-from flytekit.models.core import identifier as id_models
-from flytekit.remote import interface as _interface
-from flytekit.remote.remote_callable import RemoteEntity
-
-
-class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec):
-    """A class encapsulating a remote Flyte launch plan."""
-
-    def __init__(self, id, *args, **kwargs):
-        super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
-        # Set all the attributes we expect this class to have
-        self._id = id
-        self._name = id.name
-
-        # The interface is not set explicitly unless fetched in an engine context
-        self._interface = None
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    # If fetched when creating this object, can store it here.
-    self._flyte_workflow = None
-
-    @property
-    def flyte_workflow(self) -> Optional["FlyteWorkflow"]:
-        return self._flyte_workflow
-
-    @classmethod
-    def promote_from_model(
-        cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec
-    ) -> "FlyteLaunchPlan":
-        lp = cls(
-            id=id,
-            workflow_id=model.workflow_id,
-            default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
-            fixed_inputs=model.fixed_inputs,
-            entity_metadata=model.entity_metadata,
-            labels=model.labels,
-            annotations=model.annotations,
-            auth_role=model.auth_role,
-            raw_output_data_config=model.raw_output_data_config,
-            max_parallelism=model.max_parallelism,
-            security_context=model.security_context,
-        )
-        return lp
-
-    @property
-    def id(self) -> id_models.Identifier:
-        return self._id
-
-    @property
-    def is_scheduled(self) -> bool:
-        if self.entity_metadata.schedule.cron_expression:
-            return True
-        elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
-            return True
-        elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
-            return True
-        else:
-            return False
-
-    @property
-    def workflow_id(self) -> id_models.Identifier:
-        return self._workflow_id
-
-    @property
-    def interface(self) -> Optional[_interface.TypedInterface]:
-        """
-        The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
-        from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the=
-        object and get a node.
-        """
-        return self._interface
-
-    @property
-    def resource_type(self) -> id_models.ResourceType:
-        return id_models.ResourceType.LAUNCH_PLAN
-
-    @property
-    def entity_type_text(self) -> str:
-        return "Launch Plan"
-
-    def __repr__(self) -> str:
-        return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})"
diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py
new file mode 100644
index 0000000000..4755aad99d
--- /dev/null
+++ b/flytekit/remote/lazy_entity.py
@@ -0,0 +1,67 @@
+import typing
+from threading import Lock
+
+from flytekit import FlyteContext
+from flytekit.remote.remote_callable import RemoteEntity
+
+T = typing.TypeVar("T", bound=RemoteEntity)
+
+
+class LazyEntity(RemoteEntity, typing.Generic[T]):
+    """
+    Fetches the entity when the entity is called or when the entity is retrieved.
+    The entity is derived from RemoteEntity so that it behaves exactly like the mimicked entity.
+    """
+
+    def __init__(self, name: str, getter: typing.Callable[[], T], *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._entity = None
+        self._getter = getter
+        self._name = name
+        if not self._getter:
+            raise ValueError("getter method is required to create a lazy-loadable remote entity.")
+        self._mutex = Lock()
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def entity_fetched(self) -> bool:
+        with self._mutex:
+            return self._entity is not None
+
+    @property
+    def entity(self) -> T:
+        """
+        If not already fetched / available, then the entity will be force fetched.
+        """
+        with self._mutex:
+            if self._entity is None:
+                try:
+                    self._entity = self._getter()
+                except AttributeError as e:
+                    raise RuntimeError(
+                        f"Error downloading the entity {self._name}; see the chained exception for the original error."
+                    ) from e
+            return self._entity
+
+    def __getattr__(self, item: str) -> typing.Any:
+        """
+        Forwards all other attributes to the entity, causing the entity to be fetched!
+        """
+        return getattr(self.entity, item)
+
+    def compile(self, ctx: FlyteContext, *args, **kwargs):
+        return self.entity.compile(ctx, *args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        """
+        Forwards the call to the underlying entity. The entity will be fetched if not already present.
+        """
+        return self.entity(*args, **kwargs)
+
+    def __repr__(self) -> str:
+        return str(self)
+
+    def __str__(self) -> str:
+        return f"Promise for entity [{self._name}]"
diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py
deleted file mode 100644
index 0d73678b7e..0000000000
--- a/flytekit/remote/nodes.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from __future__ import annotations
-
-from typing import Dict, List, Optional, Union
-
-from flytekit.core import constants as _constants
-from flytekit.core import hash as _hash_mixin
-from flytekit.core.promise import NodeOutput
-from flytekit.exceptions import system as _system_exceptions
-from flytekit.exceptions import user as _user_exceptions
-from flytekit.loggers import remote_logger
-from flytekit.models import launch_plan as _launch_plan_model
-from flytekit.models import task as _task_model
-from flytekit.models.core import identifier as id_models
-from flytekit.models.core import workflow as _workflow_model
-from flytekit.remote import component_nodes as _component_nodes
-
-
-class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
-    """A class encapsulating a remote Flyte node."""
-
-    def __init__(
-        self,
-        id,
-        upstream_nodes,
-        bindings,
-        metadata,
-        flyte_task: Optional["FlyteTask"] = None,
-        flyte_workflow: Optional["FlyteWorkflow"] = None,
-        flyte_launch_plan: Optional["FlyteLaunchPlan"] = None,
-        flyte_branch_node: Optional["FlyteBranchNode"] = None,
-    ):
-        # todo: flyte_branch_node is the only non-entity here, feels wrong, it should probably be a Condition
-        # or the other ones changed.
-        non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch_node]))
-        if len(non_none_entities) != 1:
-            raise _user_exceptions.FlyteAssertion(
-                "An Flyte node must have one underlying entity specified at once. Received the following "
-                "entities: {}".format(non_none_entities)
-            )
-        # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from
-        # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it.
- self._flyte_entity = flyte_task or flyte_workflow or flyte_launch_plan or flyte_branch_node - - workflow_node = None - if flyte_workflow is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_workflow=flyte_workflow) - elif flyte_launch_plan is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) - - task_node = None - if flyte_task: - task_node = _component_nodes.FlyteTaskNode(flyte_task) - - super(FlyteNode, self).__init__( - id=id, - metadata=metadata, - inputs=bindings, - upstream_node_ids=[n.id for n in upstream_nodes], - output_aliases=[], - task_node=task_node, - workflow_node=workflow_node, - branch_node=flyte_branch_node, - ) - self._upstream = upstream_nodes - - @property - def flyte_entity(self) -> Union["FlyteTask", "FlyteWorkflow", "FlyteLaunchPlan"]: - return self._flyte_entity - - @classmethod - def promote_from_model( - cls, - model: _workflow_model.Node, - sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], - node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], - tasks: Optional[Dict[id_models.Identifier, _task_model.TaskTemplate]], - ) -> FlyteNode: - node_model_id = model.id - # TODO: Consider removing - if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: - remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") - return None - - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None - if model.task_node is not None: - flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks) - elif model.workflow_node is not None: - flyte_workflow_node = _component_nodes.FlyteWorkflowNode.promote_from_model( - model.workflow_node, - sub_workflows, - node_launch_plans, - tasks, - ) - elif model.branch_node is not None: - flyte_branch_node = _component_nodes.FlyteBranchNode.promote_from_model( - model.branch_node, sub_workflows, node_launch_plans, tasks - ) - else: - raise _system_exceptions.FlyteSystemException( - f"Bad Node model, neither task nor workflow detected, node: {model}" - ) - - # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a - # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
- # TODO: Consider removing - for model_input in model.inputs: - if ( - model_input.binding.promise is not None - and model_input.binding.promise.node_id == _constants.START_NODE_ID - ): - model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID - - if flyte_task_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_task=flyte_task_node.flyte_task, - ) - elif flyte_workflow_node is not None: - if flyte_workflow_node.flyte_workflow is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_workflow=flyte_workflow_node.flyte_workflow, - ) - elif flyte_workflow_node.flyte_launch_plan is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_launch_plan=flyte_workflow_node.flyte_launch_plan, - ) - raise _system_exceptions.FlyteSystemException( - "Bad FlyteWorkflowNode model, both launch plan and workflow are None" - ) - elif flyte_branch_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_branch_node=flyte_branch_node, - ) - raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty") - - @property - def upstream_nodes(self) -> List[FlyteNode]: - return self._upstream - - @property - def upstream_node_ids(self) -> List[str]: - return list(sorted(n.id for n in self.upstream_nodes)) - - @property - def outputs(self) -> Dict[str, NodeOutput]: - return self._outputs - - def __repr__(self) -> str: - return f"Node(ID: {self.id})" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 00227f88f0..93badd5374 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -17,6 +17,7 @@ from dataclasses import asdict, dataclass from datetime import datetime, timedelta +from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 from flytekit import Literal @@ -40,11 +41,12 @@ from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models from flytekit.models import task as task_models +from flytekit.models import types as type_models from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.common import Sort from flytekit.models.core import workflow as workflow_model -from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier +from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( ExecutionMetadata, @@ -53,13 +55,12 @@ NotificationList, WorkflowExecutionGetDataResponse, ) +from flytekit.remote.backfill import create_backfill_workflow +from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.interface import 
TypedInterface
-from flytekit.remote.launch_plan import FlyteLaunchPlan
-from flytekit.remote.nodes import FlyteNode
+from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote_callable import RemoteEntity
-from flytekit.remote.task import FlyteTask
-from flytekit.remote.workflow import FlyteWorkflow
 from flytekit.tools.fast_registration import fast_package
 from flytekit.tools.script_mode import fast_register_single_script, hash_file
 from flytekit.tools.translator import (
@@ -75,6 +76,14 @@
 MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING)
 
 
+class RegistrationSkipped(Exception):
+    """
+    RegistrationSkipped error is raised when trying to register an entity that is not registrable.
+    """
+
+    pass
+
+
 @dataclass
 class ResolvedIdentifiers:
     project: str
@@ -113,6 +122,22 @@
     )
 
 
+def _get_git_repo_url(source_path):
+    """
+    Get the git repo URL from remote.origin.url
+    """
+    try:
+        from git import Repo
+
+        return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1]
+    except ImportError:
+        remote_logger.warning("Could not import git. Is the git executable installed?")
+    except Exception:
+        # If the file isn't in the git repo, we can't get the url from git config
+        remote_logger.debug(f"{source_path} is not a git repo.")
+    return ""
+
+
 class FlyteRemote(object):
     """Main entrypoint for programmatically accessing a Flyte remote backend.
 
@@ -140,7 +165,8 @@ def __init__(
         if config is None or config.platform is None or config.platform.endpoint is None:
             raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")
 
-        self._client = SynchronousFlyteClient(config.platform, **kwargs)
+        self._kwargs = kwargs
+        self._client_initialized = False
         self._config = config
         # read config files, env vars, host, ssl options for admin client
         self._default_project = default_project
@@ -162,6 +188,9 @@ def context(self) -> FlyteContext:
     @property
     def client(self) -> SynchronousFlyteClient:
         """Return a SynchronousFlyteClient for additional operations."""
+        if not self._client_initialized:
+            self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs)
+            self._client_initialized = True
         return self._client
 
     @property
@@ -190,6 +219,20 @@ def remote_context(self):
             FlyteContextManager.current_context().with_file_access(self.file_access)
         )
 
+    def fetch_task_lazy(
+        self, project: str = None, domain: str = None, name: str = None, version: str = None
+    ) -> LazyEntity:
+        """
+        Similar to fetch_task, just that it returns a LazyEntity, which will fetch the task lazily.
+        """
+        if name is None:
+            raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")
+
+        def _fetch():
+            return self.fetch_task(project=project, domain=domain, name=name, version=version)
+
+        return LazyEntity(name=name, getter=_fetch)
+
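A small sketch of the lazy-fetch path just added (the task name is hypothetical); no Admin call happens until the first attribute access forces one through LazyEntity.__getattr__:

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    lazy_task = remote.fetch_task_lazy(name="my_task", version="v1")
    assert not lazy_task.entity_fetched()  # nothing has been downloaded yet

    print(lazy_task.interface)  # first access triggers the fetch under the mutex
    assert lazy_task.entity_fetched()
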
     def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask:
         """Fetch a task entity from flyte admin.
@@ -213,14 +256,28 @@ def fetch_task(self, project: str = None, domain: str = None, name: str = None,
         )
         admin_task = self.client.get_task(task_id)
         flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template)
-        flyte_task._id = task_id
+        flyte_task.template._id = task_id
         return flyte_task
 
+    def fetch_workflow_lazy(
+        self, project: str = None, domain: str = None, name: str = None, version: str = None
+    ) -> LazyEntity[FlyteWorkflow]:
+        """
+        Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily.
+        """
+        if name is None:
+            raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")
+
+        def _fetch():
+            return self.fetch_workflow(project, domain, name, version)
+
+        return LazyEntity(name=name, getter=_fetch)
+
     def fetch_workflow(
         self, project: str = None, domain: str = None, name: str = None, version: str = None
     ) -> FlyteWorkflow:
-        """Fetch a workflow entity from flyte admin.
-
+        """
+        Fetch a workflow entity from flyte admin.
         :param project: fetch entity from this project. If None, uses the default_project attribute.
         :param domain: fetch entity from this domain. If None, uses the default_domain attribute.
         :param name: fetch entity with matching name.
@@ -237,6 +294,7 @@
             name,
             version,
         )
+
         admin_workflow = self.client.get_workflow(workflow_id)
         compiled_wf = admin_workflow.closure.compiled_workflow
 
@@ -315,6 +373,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N
     # Listing Entities #
     ######################
 
+    def list_signals(
+        self,
+        execution_name: str,
+        project: typing.Optional[str] = None,
+        domain: typing.Optional[str] = None,
+        limit: int = 100,
+        filters: typing.Optional[typing.List[filter_models.Filter]] = None,
+    ) -> typing.List[Signal]:
+        """
+        :param execution_name: The name of the execution. This is the tail end of the URL when looking at the workflow execution.
+        :param project: The execution project, will default to the Remote's default project.
+        :param domain: The execution domain, will default to the Remote's default domain.
+        :param limit: The number of signals to fetch
+        :param filters: Optional list of filters
+        """
+        wf_exec_id = WorkflowExecutionIdentifier(
+            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+        )
+        req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters)
+        resp = self.client.list_signals(req)
+        s = resp.signals
+        return s
+
+    def set_signal(
+        self,
+        signal_id: str,
+        execution_name: str,
+        value: typing.Union[literal_models.Literal, typing.Any],
+        project: typing.Optional[str] = None,
+        domain: typing.Optional[str] = None,
+        python_type: typing.Optional[typing.Type] = None,
+        literal_type: typing.Optional[type_models.LiteralType] = None,
+    ):
+        """
+        :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
+        :param execution_name: The name of the execution. This is the tail end of the URL when looking
+            at the workflow execution.
+        :param value: This is either a Literal or a Python value, which FlyteRemote will convert into a Literal by
+            invoking the TypeEngine. This argument is only valid for wait_for_input type signals.
+        :param project: The execution project, will default to the Remote's default project.
+        :param domain: The execution domain, will default to the Remote's default domain.
+ :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal. + :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided + is not a Literal + """ + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + if isinstance(value, Literal): + remote_logger.debug(f"Using provided {value} as existing Literal value") + lit = value + else: + lt = literal_type or ( + TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value)) + ) + lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt) + remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}") + + req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl()) + + # Response is empty currently, nothing to give back to the user. + self.client.set_signal(req) + def recent_executions( self, project: typing.Optional[str] = None, @@ -359,8 +480,8 @@ def list_tasks_by_version( def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: ident = Identifier( resource_type=t, - project=ss.project or self.default_project if ss else self.default_project, - domain=ss.domain or self.default_domain if ss else self.default_domain, + project=ss.project if ss and ss.project else self.default_project, + domain=ss.domain if ss and ss.domain else self.default_domain, name=name, version=version or ss.version, ) @@ -374,7 +495,7 @@ def _resolve_identifier(self, t: int, name: str, version: str, ss: Serialization def raw_register( self, cp_entity: FlyteControlPlaneEntity, - settings: typing.Optional[SerializationSettings], + settings: SerializationSettings, version: str, create_default_launchplan: bool = True, options: Options = None, @@ -393,6 +514,15 @@ def raw_register( :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true :return: Identifier of the created entity """ + if isinstance(cp_entity, RemoteEntity): + if isinstance(cp_entity, (FlyteWorkflow, FlyteTask)): + if not cp_entity.should_register: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + else: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + if isinstance( cp_entity, ( @@ -410,6 +540,8 @@ def raw_register( return None if isinstance(cp_entity, task_models.TaskSpec): + if isinstance(cp_entity, FlyteTask): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings) try: self.client.create_task(task_identifer=ident, task_spec=cp_entity) @@ -418,6 +550,8 @@ def raw_register( return ident if isinstance(cp_entity, admin_workflow_models.WorkflowSpec): + if isinstance(cp_entity, FlyteWorkflow): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings) try: self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) @@ -484,10 +618,6 @@ def _serialize_and_register( ident = None for entity, cp_entity in m.items(): - if isinstance(entity, RemoteEntity): - remote_logger.debug(f"Skipping registration of 
remote entity: {entity.name}") - continue - if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: # Only in the case of workflows can we use the dummy serialization settings. raise user_exceptions.FlyteValueException( @@ -495,14 +625,17 @@ def _serialize_and_register( f"No serialization settings set, but workflow contains entities that need to be registered. {cp_entity.id.name}", ) - ident = self.raw_register( - cp_entity, - settings=settings, - version=version, - create_default_launchplan=True, - options=options, - og_entity=entity, - ) + try: + ident = self.raw_register( + cp_entity, + settings=settings, + version=version, + create_default_launchplan=True, + options=options, + og_entity=entity, + ) + except RegistrationSkipped: + pass return ident @@ -602,7 +735,7 @@ def _upload_file( filename=to_upload.name, ) self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) - remote_logger.warning( + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) @@ -678,11 +811,11 @@ def register_script( filename="scriptmode.tar.gz", ), ) - serialization_settings = SerializationSettings( project=project, domain=domain, image_config=image_config, + git_repo=_get_git_repo_url(source_path), fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, @@ -744,6 +877,7 @@ def _execute( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. @@ -755,6 +889,9 @@ def _execute( :param wait: if True, waits for execution to complete :param type_hints: map of python types to inputs so that the TypeEngine knows how to convert the input values into Flyte Literals. + :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten + for a single execution. If enabled, all calculations are performed even if cached results would + be available, overwriting the stored data once execution finishes successfully. :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ execution_name = execution_name or "f" + uuid.uuid4().hex[:19] @@ -810,6 +947,7 @@ def _execute( "placeholder", # Admin replaces this from oidc token if auth is enabled. 0, ), + overwrite_cache=overwrite_cache, notifications=notifications, disable_all=options.disable_notifications, labels=options.labels, @@ -873,6 +1011,7 @@ def execute( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity. @@ -906,6 +1045,9 @@ def execute( using the type engine, and then to ``type(v)``. Providing the correct Python types is particularly important if the inputs are containers like lists or maps, or if the Python type is one of the more complex Flyte provided classes (like a StructuredDataset that's annotated with columns). + :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten + for a single execution. If enabled, all calculations are performed even if cached results would + be available, overwriting the stored data once execution finishes successfully. .. 
note: @@ -924,6 +1066,7 @@ def execute( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) if isinstance(entity, FlyteWorkflow): return self.execute_remote_wf( @@ -935,6 +1078,7 @@ def execute( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) if isinstance(entity, PythonTask): return self.execute_local_task( @@ -947,6 +1091,7 @@ def execute( execution_name=execution_name, image_config=image_config, wait=wait, + overwrite_cache=overwrite_cache, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -960,6 +1105,7 @@ def execute( image_config=image_config, options=options, wait=wait, + overwrite_cache=overwrite_cache, ) if isinstance(entity, LaunchPlan): return self.execute_local_launch_plan( @@ -971,6 +1117,7 @@ def execute( execution_name=execution_name, options=options, wait=wait, + overwrite_cache=overwrite_cache, ) raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -987,6 +1134,7 @@ def execute_remote_task_lp( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. @@ -1001,6 +1149,7 @@ def execute_remote_task_lp( wait=wait, options=options, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) def execute_remote_wf( @@ -1013,6 +1162,7 @@ def execute_remote_wf( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. @@ -1028,6 +1178,7 @@ def execute_remote_wf( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) # Flytekit Entities @@ -1044,6 +1195,7 @@ def execute_local_task( execution_name: str = None, image_config: typing.Optional[ImageConfig] = None, wait: bool = False, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute an @task-decorated function or TaskTemplate task. @@ -1058,6 +1210,7 @@ def execute_local_task( :param execution_name: :param image_config: :param wait: + :param overwrite_cache: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1084,6 +1237,7 @@ def execute_local_task( execution_name=execution_name, wait=wait, type_hints=entity.python_interface.inputs, + overwrite_cache=overwrite_cache, ) def execute_local_workflow( @@ -1098,6 +1252,7 @@ def execute_local_workflow( image_config: typing.Optional[ImageConfig] = None, options: typing.Optional[Options] = None, wait: bool = False, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. 
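To make the new overwrite_cache flag and the signal client methods concrete, a hedged end-to-end sketch; the workflow name is illustrative, and the workflow is assumed (hypothetically) to contain an approve() gate named "approve-step":

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")

    wf = remote.fetch_workflow(name="gated_wf", version="v1")
    # Recompute all tasks for this run, even where cached outputs exist.
    execution = remote.execute(wf, inputs={"size": 10}, overwrite_cache=True)

    # Inspect pending gates, then release the approve gate.
    for signal in remote.list_signals(execution.id.name):
        print(signal.id.signal_id)
    remote.set_signal("approve-step", execution.id.name, True)
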
@@ -1111,6 +1266,7 @@ def execute_local_workflow(
         :param image_config:
         :param options:
         :param wait:
+        :param overwrite_cache:
         :return:
         """
         resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
@@ -1155,6 +1311,7 @@
             wait=wait,
             options=options,
             type_hints=entity.python_interface.inputs,
+            overwrite_cache=overwrite_cache,
         )
 
     def execute_local_launch_plan(
@@ -1167,6 +1324,7 @@
         execution_name: typing.Optional[str] = None,
         options: typing.Optional[Options] = None,
         wait: bool = False,
+        overwrite_cache: bool = None,
     ) -> FlyteWorkflowExecution:
         """
 
@@ -1178,6 +1336,7 @@
         :param execution_name: If specified, will be used as the execution name instead of randomly generating.
         :param options:
         :param wait:
+        :param overwrite_cache:
         :return:
         """
         try:
@@ -1203,6 +1362,7 @@
             options=options,
             wait=wait,
             type_hints=entity.python_interface.inputs,
+            overwrite_cache=overwrite_cache,
         )
 
     ###################################
@@ -1305,7 +1465,7 @@
                     upstream_nodes=[],
                     bindings=[],
                     metadata=NodeMetadata(name=""),
-                    flyte_task=flyte_entity,
+                    task_node=FlyteTaskNode(flyte_entity),
                 )
             }
             if len(task_node_exec) >= 1
@@ -1562,6 +1722,88 @@
         return protocol + f"://{endpoint}"
 
     def generate_console_url(
-        self, execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution]
+        self,
+        entity: typing.Union[
+            FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflow, FlyteTask, FlyteLaunchPlan
+        ],
     ):
-        return f"{self.generate_console_http_domain()}/console/projects/{execution.id.project}/domains/{execution.id.domain}/executions/{execution.id.name}"
+        """
+        Generate a Flyte Console URL for the given entity on this remote endpoint.
+        This will automatically determine whether the argument is an execution or a registered entity and build
+        the URL accordingly.
+        """
+        if isinstance(entity, (FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution)):
+            return f"{self.generate_console_http_domain()}/console/projects/{entity.id.project}/domains/{entity.id.domain}/executions/{entity.id.name}"  # noqa
+
+        if not isinstance(entity, (FlyteWorkflow, FlyteTask, FlyteLaunchPlan)):
+            raise ValueError(f"Only remote entities can be looked at in the console, got type {type(entity)}")
+        rt = "workflow"
+        if entity.id.resource_type == ResourceType.TASK:
+            rt = "task"
+        elif entity.id.resource_type == ResourceType.LAUNCH_PLAN:
+            rt = "launch_plan"
+        return f"{self.generate_console_http_domain()}/console/projects/{entity.id.project}/domains/{entity.id.domain}/{rt}/{entity.name}/version/{entity.id.version}"  # noqa
+
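Similarly, a short sketch of the backfill entrypoint defined next; the cron-scheduled launch plan name is hypothetical, and, as the docstring notes, from_date is exclusive while to_date is inclusive:

    from datetime import datetime

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    execution = remote.launch_backfill(
        project="flytesnacks",
        domain="development",
        from_date=datetime(2023, 1, 1),
        to_date=datetime(2023, 1, 7),
        launchplan="daily_report_lp",  # hypothetical scheduled launch plan
        parallel=False,  # run the per-instance launch plans sequentially (the default)
    )
    print(remote.generate_console_url(execution))
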
+    def launch_backfill(
+        self,
+        project: str,
+        domain: str,
+        from_date: datetime,
+        to_date: datetime,
+        launchplan: str,
+        launchplan_version: str = None,
+        execution_name: str = None,
+        version: str = None,
+        dry_run: bool = False,
+        execute: bool = True,
+        parallel: bool = False,
+    ) -> typing.Optional[typing.Union[FlyteWorkflowExecution, FlyteWorkflow, WorkflowBase]]:
+        """
+        Creates and launches a backfill workflow for the given launchplan. If the launchplan version is not
+        specified, the latest launchplan is retrieved.
+        The from_date is exclusive and the to_date is inclusive; the backfill runs for all instances in between:
+        -> (start_date - exclusive, end_date inclusive)
+
+        If dry_run is specified, the workflow is created and returned.
+        If execute==False is specified, the workflow is created and registered.
+        In the default case, the workflow is created, registered, and executed.
+
+        The `parallel` flag can be used to generate a workflow where all launchplans can be run in parallel. By
+        default, the backfill is run sequentially.
+
+        :param project: str project name
+        :param domain: str domain name
+        :param from_date: datetime generate a backfill starting at this datetime (exclusive)
+        :param to_date: datetime generate a backfill ending at this datetime (inclusive)
+        :param launchplan: str launchplan name in the flyte backend
+        :param launchplan_version: str (optional) version for the launchplan. If not specified the most recent will be retrieved
+        :param execution_name: str (optional) the generated execution will be given this name. This can help in ensuring idempotency
+        :param version: str (optional) version to be used for the newly created workflow.
+        :param dry_run: bool do not register or execute the workflow
+        :param execute: bool register and execute the workflow.
+        :param parallel: if the backfill should be run in parallel. False (default) will run each backfill sequentially
+        :return: In case of dry_run, return a WorkflowBase; if execute==False, return a FlyteWorkflow; otherwise, in the
+         default case, return a FlyteWorkflowExecution
+        """
+        lp = self.fetch_launch_plan(project=project, domain=domain, name=launchplan, version=launchplan_version)
+        wf, start, end = create_backfill_workflow(start_date=from_date, end_date=to_date, for_lp=lp, parallel=parallel)
+        if dry_run:
+            remote_logger.warning("Dry Run enabled. Workflow will not be registered or executed.")
+            return wf
+
+        unique_fingerprint = f"{start}-{end}-{launchplan}-{launchplan_version}"
+        h = hashlib.md5()
+        h.update(unique_fingerprint.encode("utf-8"))
+        unique_fingerprint_encoded = base64.urlsafe_b64encode(h.digest()).decode("ascii")
+        if not version:
+            version = unique_fingerprint_encoded
+        ss = SerializationSettings(
+            image_config=ImageConfig.auto(),
+            project=project,
+            domain=domain,
+            version=version,
+        )
+        remote_wf = self.register_workflow(wf, serialization_settings=ss)
+
+        if not execute:
+            return remote_wf
+
+        return self.execute(remote_wf, inputs={}, project=project, domain=domain, execution_name=execution_name)
diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py
index c04ec75f66..9adfd4846f 100644
--- a/flytekit/remote/remote_callable.py
+++ b/flytekit/remote/remote_callable.py
@@ -63,10 +63,10 @@ def __call__(self, *args, **kwargs):
         return self.execute(**kwargs)
 
     def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
-        raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.")
+        return self.execute(**kwargs)
 
     def execute(self, **kwargs) -> Any:
-        raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.")
+        raise AssertionError(f"Remotely fetched entities cannot be run locally. 
Please mock the {self.name}.execute.") @property def python_interface(self) -> Optional[Dict[str, Type]]: diff --git a/flytekit/remote/task.py b/flytekit/remote/task.py deleted file mode 100644 index 3c2c8f8d92..0000000000 --- a/flytekit/remote/task.py +++ /dev/null @@ -1,51 +0,0 @@ -from flytekit.core import hash as hash_mixin -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as _identifier_model -from flytekit.remote import interface as _interfaces -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, _task_model.TaskTemplate): - """A class encapsulating a remote Flyte task.""" - - def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): - super(FlyteTask, self).__init__( - id, - type, - metadata, - interface, - custom, - container=container, - task_type_version=task_type_version, - config=config, - ) - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def resource_type(self) -> _identifier_model.ResourceType: - return _identifier_model.ResourceType.TASK - - @property - def entity_type_text(self) -> str: - return "Task" - - @classmethod - def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask": - t = cls( - id=base_model.id, - type=base_model.type, - metadata=base_model.metadata, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - custom=base_model.custom, - container=base_model.container, - task_type_version=base_model.task_type_version, - ) - # Override the newly generated name if one exists in the base model - if not base_model.id.is_empty: - t._id = base_model.id - - return t diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py deleted file mode 100644 index 3133f8a1fe..0000000000 --- a/flytekit/remote/workflow.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from typing import Dict, List, Optional - -from flytekit.core import constants as _constants -from flytekit.core import hash as _hash_mixin -from flytekit.exceptions import user as _user_exceptions -from flytekit.models import launch_plan as launch_plan_models -from flytekit.models import task as _task_models -from flytekit.models.core import compiler as compiler_models -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_models -from flytekit.remote import interface as _interfaces -from flytekit.remote import nodes as _nodes -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, _workflow_models.WorkflowTemplate): - """A class encapsulating a remote Flyte workflow.""" - - def __init__( - self, - id: id_models.Identifier, - nodes: List[_nodes.FlyteNode], - interface, - output_bindings, - metadata, - metadata_defaults, - subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, - tasks: Optional[Dict[id_models.Identifier, _task_models.TaskTemplate]] = None, - launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, - compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, - ): - # TODO: Remove check - for node in nodes: - for upstream in node.upstream_nodes: - if upstream.id is None: - raise _user_exceptions.FlyteAssertion( - "Some nodes contained in the workflow were not found in the workflow 
description. Please " - "ensure all nodes are either assigned to attributes within the class or an element in a " - "list, dict, or tuple which is stored as an attribute in the class." - ) - super(FlyteWorkflow, self).__init__( - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - interface=interface, - nodes=nodes, - outputs=output_bindings, - ) - self._flyte_nodes = nodes - - # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure - self._subworkflows = subworkflows - self._tasks = tasks - self._launch_plans = launch_plans - self._compiled_closure = compiled_closure - self._node_map = None - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def sub_workflows(self) -> Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]]: - return self._subworkflows - - @property - def entity_type_text(self) -> str: - return "Workflow" - - @property - def resource_type(self): - return id_models.ResourceType.WORKFLOW - - @property - def flyte_nodes(self) -> List[_nodes.FlyteNode]: - return self._flyte_nodes - - @classmethod - def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: - return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_models.WorkflowTemplate, - sub_workflows: Optional[Dict[id_models, _workflow_models.WorkflowTemplate]] = None, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - tasks: Optional[Dict[id_models, _task_models.TaskTemplate]] = None, - ) -> FlyteWorkflow: - base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) - sub_workflows = sub_workflows or {} - tasks = tasks or {} - node_map = { - node.id: _nodes.FlyteNode.promote_from_model(node, sub_workflows, node_launch_plans, tasks) - for node in base_model_non_system_nodes - } - - # Set upstream nodes for each node - for n in base_model_non_system_nodes: - current = node_map[n.id] - for upstream_id in n.upstream_node_ids: - upstream_node = node_map[upstream_id] - current._upstream.append(upstream_node) - - # No inputs/outputs specified, see the constructor for more information on the overrides. - wf = cls( - id=base_model.id, - nodes=list(node_map.values()), - metadata=base_model.metadata, - metadata_defaults=base_model.metadata_defaults, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - output_bindings=base_model.outputs, - subworkflows=sub_workflows, - tasks=tasks, - launch_plans=node_launch_plans, - ) - - wf._node_map = node_map - - return wf - - @classmethod - def promote_from_closure( - cls, - closure: compiler_models.CompiledWorkflowClosure, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - ): - """ - Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. - - :param closure: This is the closure returned by Admin - :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. - It only has subworkflows and tasks. Why this is is unclear. 
If supplied, this map of launch plans will be
-        :return:
-        """
-        sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows}
-        tasks = {t.template.id: t.template for t in closure.tasks}
-
-        flyte_wf = FlyteWorkflow.promote_from_model(
-            base_model=closure.primary.template,
-            sub_workflows=sub_workflows,
-            node_launch_plans=node_launch_plans,
-            tasks=tasks,
-        )
-        flyte_wf._compiled_closure = closure
-        return flyte_wf
diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py
index 870299e5ad..3c9fe64068 100644
--- a/flytekit/tools/repo.py
+++ b/flytekit/tools/repo.py
@@ -10,7 +10,9 @@
 from flytekit.core.context_manager import FlyteContextManager
 from flytekit.loggers import logger
 from flytekit.models import launch_plan
+from flytekit.models.core.identifier import Identifier
 from flytekit.remote import FlyteRemote
+from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url
 from flytekit.tools import fast_registration, module_loader
 from flytekit.tools.script_mode import _find_project_root
 from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities
@@ -160,7 +162,7 @@
     :param options:
     :return: The common detected root path, the output of _find_project_root
     """
-
+    ss.git_repo = _get_git_repo_url(project_root)
     pkgs_and_modules = []
     for pm in pkgs_or_mods:
         p = Path(pm).resolve()
@@ -179,6 +181,27 @@
     return registrable_entities
 
 
+def secho(i: Identifier, state: str = "success", reason: str = None):
+    state_ind = "[ ]"
+    fg = "white"
+    nl = False
+    if state == "success":
+        state_ind = "\r[✔]"
+        fg = "green"
+        nl = True
+        reason = f"successful with version {i.version}" if not reason else reason
+    elif state == "failed":
+        state_ind = "\r[x]"
+        fg = "red"
+        nl = True
+        reason = "skipped!"
+    click.secho(
+        click.style(f"{state_ind}", fg=fg) + f" Registration {i.name} type {i.resource_type_name()} {reason or ''}",
+        dim=True,
+        nl=nl,
+    )
+
+
 def register(
     project: str,
     domain: str,
@@ -192,6 +215,7 @@
     fast: bool,
     package_or_module: typing.Tuple[str],
     remote: FlyteRemote,
+    dry_run: bool = False,
 ):
     detected_root = find_common_root(package_or_module)
     click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow")
@@ -234,11 +258,18 @@
     if len(serializable_entities) == 0:
         click.secho("No Flyte entities were detected. 
Aborting!", fg="red") return - click.secho(f"Found and serialized {len(serializable_entities)} entities") for cp_entity in serializable_entities: - name = cp_entity.id.name if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id.name - click.secho(f" Registering {name}....", dim=True, nl=False) - i = remote.raw_register(cp_entity, serialization_settings, version=version, create_default_launchplan=False) - click.secho(f"done, {i.resource_type_name()} with version {i.version}.", dim=True) + og_id = cp_entity.id if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id + secho(og_id, "") + try: + if not dry_run: + i = remote.raw_register( + cp_entity, serialization_settings, version=version, create_default_launchplan=False + ) + secho(i) + else: + secho(og_id, reason="Dry run Mode!") + except RegistrationSkipped: + secho(og_id, "failed") click.secho(f"Successfully registered {len(serializable_entities)} entities", fg="green") diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 57f863f055..5ec249fa4b 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,13 +1,15 @@ +import sys import typing from collections import OrderedDict from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode +from flytekit.core.gate import Gate from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.map_task import MapPythonTask from flytekit.core.node import Node @@ -21,13 +23,15 @@ from flytekit.models import interface as interface_models from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security -from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model +from flytekit.models.core.workflow import ApproveCondition from flytekit.models.core.workflow import BranchNode as BranchNodeModel -from flytekit.models.core.workflow import TaskNodeOverrides +from flytekit.models.core.workflow import GateNode, SignalCondition, SleepCondition, TaskNodeOverrides +from flytekit.models.task import TaskSpec, TaskTemplate FlyteLocalEntity = Union[ PythonTask, @@ -41,7 +45,7 @@ ReferenceEntity, ] FlyteControlPlaneEntity = Union[ - task_models.TaskSpec, + TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec, workflow_model.Node, @@ -152,10 +156,9 @@ def fn(settings: SerializationSettings) -> List[str]: def get_serializable_task( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, -) -> task_models.TaskSpec: +) -> TaskSpec: task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, settings.project, @@ -195,7 +198,7 @@ def get_serializable_task( pod = entity.get_k8s_pod(settings) entity.reset_command_fn() - tt = task_models.TaskTemplate( + tt = TaskTemplate( id=task_id, type=entity.task_type, metadata=entity.metadata.to_taskmetadata_model(), @@ -210,7 +213,8 @@ def 
get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return task_models.TaskSpec(template=tt) + + return TaskSpec(template=tt, docs=entity.docs) def get_serializable_workflow( @@ -219,18 +223,19 @@ def get_serializable_workflow( entity: WorkflowBase, options: Optional[Options] = None, ) -> admin_workflow_models.WorkflowSpec: - # TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.workflow import FlyteWorkflow - - # Get node models - upstream_node_models = [ - get_serializable(entity_mapping, settings, n, options) - for n in entity.nodes - if n.id != _common_constants.GLOBAL_INPUT_NODE_ID - ] - + # Serialize all nodes + serialized_nodes = [] sub_wfs = [] for n in entity.nodes: + # Ignore start nodes + if n.id == _common_constants.GLOBAL_INPUT_NODE_ID: + continue + + # Recursively serialize the node + serialized_nodes.append(get_serializable(entity_mapping, settings, n, options)) + + # If the node is workflow Node or Branch node, we need to handle it specially, to extract all subworkflows, + # so that they can be added to the workflow being serialized if isinstance(n.flyte_entity, WorkflowBase): # We are currently not supporting reference workflows since these will # require a network call to flyteadmin to populate the WorkflowTemplate @@ -247,10 +252,14 @@ def get_serializable_workflow( sub_wfs.append(sub_wf_spec.template) sub_wfs.extend(sub_wf_spec.sub_workflows) + from flytekit.remote import FlyteWorkflow + if isinstance(n.flyte_entity, FlyteWorkflow): - get_serializable(entity_mapping, settings, n.flyte_entity, options) - sub_wfs.append(n.flyte_entity) - sub_wfs.extend([s for s in n.flyte_entity.sub_workflows.values()]) + for swf in n.flyte_entity.flyte_sub_workflows: + sub_wf = get_serializable(entity_mapping, settings, swf, options) + sub_wfs.append(sub_wf.template) + main_wf = get_serializable(entity_mapping, settings, n.flyte_entity, options) + sub_wfs.append(main_wf.template) if isinstance(n.flyte_entity, BranchNode): if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block @@ -286,11 +295,12 @@ def get_serializable_workflow( metadata=entity.workflow_metadata.to_flyte_model(), metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(), interface=entity.interface, - nodes=upstream_node_models, + nodes=serialized_nodes, outputs=entity.output_bindings, ) + return admin_workflow_models.WorkflowSpec( - template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()) + template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()), docs=entity.docs ) @@ -374,12 +384,7 @@ def get_serializable_node( if entity.flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") - # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - - upstream_sdk_nodes = [ + upstream_nodes = [ get_serializable(entity_mapping, settings, n, options=options) for n in entity.upstream_nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID @@ -393,7 +398,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], ) if 
ref_template.resource_type == _identifier_model.ResourceType.TASK:
@@ -408,13 +413,15 @@ def get_serializable_node(
         )
         return node_model
 
+    from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow
+
     if isinstance(entity.flyte_entity, PythonTask):
         task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
         node_model = workflow_model.Node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=entity.bindings,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             task_node=workflow_model.TaskNode(
                 reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
@@ -429,7 +436,7 @@ def get_serializable_node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=entity.bindings,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
         )
@@ -439,7 +446,7 @@ def get_serializable_node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=entity.bindings,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             branch_node=get_serializable(entity_mapping, settings, entity.flyte_entity, options=options),
         )
@@ -457,11 +464,32 @@ def get_serializable_node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=node_input,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
         )
 
+    elif isinstance(entity.flyte_entity, Gate):
+        if entity.flyte_entity.sleep_duration:
+            gn = GateNode(sleep=SleepCondition(duration=entity.flyte_entity.sleep_duration))  # sleep gate
+        elif entity.flyte_entity.input_type:
+            output_name = list(entity.flyte_entity.python_interface.outputs.keys())[0]  # should be o0
+            gn = GateNode(
+                signal=SignalCondition(
+                    entity.flyte_entity.name, type=entity.flyte_entity.literal_type, output_variable_name=output_name
+                )
+            )  # signal gate, i.e. wait_for_input
+        else:
+            gn = GateNode(approve=ApproveCondition(entity.flyte_entity.name))  # approval gate
+        node_model = workflow_model.Node(
+            id=_dnsify(entity.id),
+            metadata=entity.metadata,
+            inputs=entity.bindings,
+            upstream_node_ids=[n.id for n in upstream_nodes],
+            output_aliases=[],
+            gate_node=gn,
+        )
+
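# For context: the three GateNode variants above map one-to-one onto the user-facing gate
# helpers. A hedged sketch, assuming the helpers live in flytekit.core.gate in this release
# and follow the sleep/wait_for_input/approve signatures used there:
from datetime import timedelta

from flytekit import task, workflow
from flytekit.core.gate import approve, sleep, wait_for_input  # assumed import path

@task
def produce() -> int:
    return 42

@workflow
def gated_wf() -> int:
    x = produce()
    sleep(timedelta(minutes=5))  # -> GateNode(sleep=SleepCondition(...))
    threshold = wait_for_input("threshold", timeout=timedelta(hours=1), expected_type=int)  # -> SignalCondition
    approve(x, "deploy-gate", timeout=timedelta(hours=1))  # -> ApproveCondition
    return threshold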
    elif isinstance(entity.flyte_entity, FlyteTask):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
@@ -469,23 +497,23 @@ def get_serializable_node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=entity.bindings,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             task_node=workflow_model.TaskNode(
                 reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources)
             ),
         )
     elif isinstance(entity.flyte_entity, FlyteWorkflow):
-        wf_template = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
-        for _, sub_wf in entity.flyte_entity.sub_workflows.items():
+        wf_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
+        for sub_wf in entity.flyte_entity.flyte_sub_workflows:
             get_serializable(entity_mapping, settings, sub_wf, options=options)
         node_model = workflow_model.Node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=entity.bindings,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
-            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id),
+            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.id),
         )
     elif isinstance(entity.flyte_entity, FlyteLaunchPlan):
         # Recursive call doesn't do anything except put the entity on the map.
@@ -500,7 +528,7 @@ def get_serializable_node(
             id=_dnsify(entity.id),
             metadata=entity.metadata,
             inputs=node_input,
-            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
+            upstream_node_ids=[n.id for n in upstream_nodes],
             output_aliases=[],
             workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id),
         )
@@ -540,6 +568,54 @@ def get_reference_spec(
     return ReferenceSpec(template)
 
 
+def get_serializable_flyte_workflow(
+    entity: "FlyteWorkflow", settings: SerializationSettings
+) -> FlyteControlPlaneEntity:
+    """
+    TODO replace with deep copy
+    """
+
+    def _mutate_task_node(tn: workflow_model.TaskNode):
+        tn.reference_id._project = settings.project
+        tn.reference_id._domain = settings.domain
+
+    def _mutate_branch_node_task_ids(bn: workflow_model.BranchNode):
+        _mutate_node(bn.if_else.case.then_node)
+        for c in bn.if_else.other:
+            _mutate_node(c.then_node)
+        if bn.if_else.else_node:
+            _mutate_node(bn.if_else.else_node)
+
+    def _mutate_workflow_node(wn: workflow_model.WorkflowNode):
+        wn.sub_workflow_ref._project = settings.project
+        wn.sub_workflow_ref._domain = settings.domain
+
+    def _mutate_node(n: workflow_model.Node):
+        if n.task_node:
+            _mutate_task_node(n.task_node)
+        elif n.branch_node:
+            _mutate_branch_node_task_ids(n.branch_node)
+        elif n.workflow_node:
+            _mutate_workflow_node(n.workflow_node)
+
+    for n in entity.flyte_nodes:
+        _mutate_node(n)
+
+    entity.id._project = settings.project
+    entity.id._domain = settings.domain
+
+    return entity
+
+
+def get_serializable_flyte_task(entity: "FlyteTask", settings: SerializationSettings) -> FlyteControlPlaneEntity:
+    """
+    TODO replace with deep copy
+    """
+    entity.id._project = settings.project
+    entity.id._domain = settings.domain
+    return entity
+
+
 def get_serializable(
     entity_mapping: OrderedDict,
     settings: SerializationSettings,
@@ -563,19 +639,16 @@ def get_serializable(
     :return: The resulting control plane entity, in addition to being added to the mutable entity_mapping
      parameter, is also returned.
""" - # TODO: Try to replace following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - if entity in entity_mapping: return entity_mapping[entity] + from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow + if isinstance(entity, ReferenceEntity): cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(entity_mapping, settings, entity) + cp_entity = get_serializable_task(settings, entity) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) @@ -589,12 +662,41 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) - elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow) or isinstance(entity, FlyteLaunchPlan): + elif isinstance(entity, GateNode): + import ipdb + + ipdb.set_trace() + + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): + if entity.should_register: + if isinstance(entity, FlyteTask): + cp_entity = get_serializable_flyte_task(entity, settings) + else: + if entity.should_register: + # We only add the tasks if the should register flag is set. This is to avoid adding + # unnecessary tasks to the registrable list. + for t in entity.flyte_tasks: + get_serializable(entity_mapping, settings, t, options) + cp_entity = get_serializable_flyte_workflow(entity, settings) + else: + cp_entity = entity + + elif isinstance(entity, FlyteLaunchPlan): cp_entity = entity else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") + if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): + # 1. Check if the size of long description exceeds 16KB + # 2. Extract the repo URL from the git config, and assign it to the link of the source code of the description entity + if entity.docs and entity.docs.long_description: + if entity.docs.long_description.value: + if sys.getsizeof(entity.docs.long_description.value) > 16 * 1024 * 1024: + raise ValueError( + "Long Description of the flyte entity exceeds the 16KB size limit. Please specify the uri in the long description instead." + ) + entity.docs.source_code = SourceCode(link=settings.git_repo) # This needs to be at the bottom not the top - i.e. dependent tasks get added before the workflow containing it entity_mapping[entity] = cp_entity return cp_entity @@ -603,7 +705,7 @@ def get_serializable( def gather_dependent_entities( serialized: OrderedDict, ) -> Tuple[ - Dict[_identifier_model.Identifier, task_models.TaskTemplate], + Dict[_identifier_model.Identifier, TaskTemplate], Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec], Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec], ]: @@ -616,12 +718,12 @@ def gather_dependent_entities( :param serialized: This should be the filled in OrderedDict used in the get_serializable function above. 
@@ -603,7 +703,7 @@ def gather_dependent_entities(
     serialized: OrderedDict,
 ) -> Tuple[
-    Dict[_identifier_model.Identifier, task_models.TaskTemplate],
+    Dict[_identifier_model.Identifier, TaskTemplate],
     Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec],
     Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec],
 ]:
@@ -616,12 +716,12 @@ def gather_dependent_entities(
     :param serialized: This should be the filled in OrderedDict used in the get_serializable function above.
     :return:
     """
-    task_templates: Dict[_identifier_model.Identifier, task_models.TaskTemplate] = {}
+    task_templates: Dict[_identifier_model.Identifier, TaskTemplate] = {}
     workflow_specs: Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec] = {}
     launch_plan_specs: Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec] = {}
 
     for cp_entity in serialized.values():
-        if isinstance(cp_entity, task_models.TaskSpec):
+        if isinstance(cp_entity, TaskSpec):
             task_templates[cp_entity.template.id] = cp_entity.template
         elif isinstance(cp_entity, _launch_plan_models.LaunchPlan):
             launch_plan_specs[cp_entity.id] = cp_entity.spec
diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py
index 6aa193d7ce..c2ab8fd438 100644
--- a/flytekit/types/directory/__init__.py
+++ b/flytekit/types/directory/__init__.py
@@ -11,6 +11,7 @@
     FlyteDirectory
     TensorboardLogs
+    TFRecordsDirectory
 """
 
 import typing
@@ -26,3 +27,11 @@
 This is usually the SummaryWriter output in PyTorch or Keras callbacks
 which record the history readable by TensorBoard.
 """
+
+tfrecords_dir = typing.TypeVar("tfrecord")
+TFRecordsDirectory = FlyteDirectory[tfrecords_dir]
+"""
+    This type can be used to denote that the output is a folder that contains TensorFlow record files.
+    This is usually the TFRecordWriter output in TensorFlow, which writes serialized tf.train.Example
+    messages (protobufs) to tfrecord files.
+"""
diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py
index 34bf834a4a..871c48d4c6 100644
--- a/flytekit/types/file/__init__.py
+++ b/flytekit/types/file/__init__.py
@@ -109,3 +109,8 @@ def check_and_convert_to_str(item: typing.Union[typing.Type, str]) -> str:
 #: Can be used to receive or return an ONNXFile. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 ONNXFile = FlyteFile[onnx]
+
+tfrecords_file = Annotated[str, FileExt("tfrecord")]
+#: Can be used to receive or return a TFRecordFile. The underlying type is a FlyteFile type. This is just a
+#: decoration and useful for attaching content type information with the file and automatically documenting code.
+TFRecordFile = FlyteFile[tfrecords_file]
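A hedged usage sketch for the new TFRecord types (assumes TensorFlow is installed; mirrors the TFRecordWriter flow the docstrings describe):

import tensorflow as tf

from flytekit import task
from flytekit.types.file import TFRecordFile

@task
def write_records() -> TFRecordFile:
    path = "/tmp/sample.tfrecord"
    with tf.io.TFRecordWriter(path) as writer:
        example = tf.train.Example(
            features=tf.train.Features(
                feature={"x": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3]))}
            )
        )
        writer.write(example.SerializeToString())  # one serialized tf.train.Example per record
    return TFRecordFile(path)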
""" + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -233,7 +234,6 @@ def __init__( supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, downloader: typing.Callable[[str, os.PathLike], None] = None, ): - if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") if ( diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 71dff61c5e..39f8d11e24 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -101,10 +101,10 @@ def decode( return pq.read_table(local_dir) -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler()) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler()) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler()) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler()) +StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 99fcb49d7b..0e4649203a 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -34,6 +34,7 @@ # Storage formats PARQUET: StructuredDatasetFormat = "parquet" +GENERIC_FORMAT: StructuredDatasetFormat = "" @dataclass_json @@ -45,9 +46,7 @@ class (that is just a model, a Python class representation of the protobuf). """ uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) - file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String())) - - DEFAULT_FILE_FORMAT = PARQUET + file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -68,6 +67,8 @@ def __init__( # Make these fields public, so that the dataclass transformer can set a value for it # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 self.uri = uri + # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string + self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT # This is a special attribute that indicates if the data was either downloaded or uploaded self._metadata = metadata # This is not for users to set, the transformer will set this. 
@@ -128,14 +129,14 @@ def extract_cols_and_format(
     optional str for the format,
     optional pyarrow Schema
     """
-    fmt = None
+    fmt = ""
     ordered_dict_cols = None
     pa_schema = None
     if get_origin(t) is Annotated:
         base_type, *annotate_args = get_args(t)
         for aa in annotate_args:
             if isinstance(aa, StructuredDatasetFormat):
-                if fmt is not None:
+                if fmt != "":
                     raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
                 fmt = aa
             elif isinstance(aa, collections.OrderedDict):
@@ -334,21 +335,44 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
     Renderers: Dict[Type, Renderable] = {}
 
-    @staticmethod
-    def _finder(handler_map, df_type: Type, protocol: str, format: str):
-        try:
-            return handler_map[df_type][protocol][format]
-        except KeyError:
+    @classmethod
+    def _finder(cls, handler_map, df_type: Type, protocol: str, format: str):
+        # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler;
+        # if missing, see if there's a generic-format handler. Error if missing.
+        # If the incoming format requested is the generic format (""), then see if it's present;
+        # if not, look to see if there is a default format for the df_type and a handler for that format;
+        # if still missing, look to see if there's only _one_ handler for that type, and if so, use that.
+        if format != GENERIC_FORMAT:
             try:
-                hh = handler_map[df_type][protocol][""]
-                logger.info(
-                    f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
-                    f" format {format}, using default instead."
-                )
-                return hh
+                return handler_map[df_type][protocol][format]
+            except KeyError:
+                try:
+                    return handler_map[df_type][protocol][GENERIC_FORMAT]
+                except KeyError:
+                    ...
+        else:
+            try:
+                return handler_map[df_type][protocol][GENERIC_FORMAT]
             except KeyError:
-                ...
-        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")
+                if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]:
+                    hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]]
+                    logger.debug(
+                        f"Didn't find a generic handler for {df_type} and protocol {protocol},"
+                        f" using the default format handler {hh} instead."
+                    )
+                    return hh
+                if len(handler_map[df_type][protocol]) == 1:
+                    hh = list(handler_map[df_type][protocol].values())[0]
+                    logger.debug(
+                        f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}"
+                    )
+                    return hh
+                else:
+                    logger.warning(
+                        f"Did not automatically pick a handler for {df_type},"
+                        f" more than one detected {handler_map[df_type][protocol].keys()}"
+                    )
+        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|")
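To make the fallback order above concrete, a small illustration with a hypothetical handler map (plain dicts standing in for the real registry; the handler names are made up):

handler_map = {
    "pandas.DataFrame": {
        "s3": {
            "parquet": "ParquetEncoder",  # format-specific handler
            "": "GenericEncoder",  # generic-format handler
        }
    }
}
# Requested "parquet" -> direct hit             -> ParquetEncoder
# Requested "avro"    -> miss, falls back to "" -> GenericEncoder
# Requested ""        -> direct hit on ""       -> GenericEncoder
# Without a "" entry, "" falls back to DEFAULT_FORMATS[df_type] if one is registered,
# then to the sole remaining handler if exactly one exists; otherwise ValueError.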
 
     @classmethod
     def get_encoder(cls, df_type: Type, protocol: str, format: str):
@@ -381,7 +405,14 @@ def register_renderer(cls, python_type: Type, renderer: Renderable):
         cls.Renderers[python_type] = renderer
 
     @classmethod
-    def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
+    def register(
+        cls,
+        h: Handlers,
+        default_for_type: bool = False,
+        override: bool = False,
+        default_format_for_type: bool = False,
+        default_storage_for_type: bool = False,
+    ):
         """
         Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does
         not specify a protocol (e.g. s3, gs, etc.) field, then
@@ -395,6 +426,10 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
         In these cases, the protocol is determined by the raw output data prefix set in the active context.
         :param override: Override any previous registrations. If default_for_type is also set, this will also
           override the default.
+        :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage
+          as the default, this will only set the format. If a default format is already set for the type, the new one
+          is ignored unless override is specified.
+        :param default_storage_for_type: Same as above, but only for the storage protocol. If a default is already
+          set, it is kept unless override is specified.
         """
         if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
             raise TypeError(f"We don't support this type of handler {h}")
@@ -409,17 +444,29 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
             stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
             logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
             try:
-                cls.register_for_protocol(h, stripped, False, override)
+                cls.register_for_protocol(
+                    h, stripped, False, override, default_format_for_type, default_storage_for_type
+                )
             except DuplicateHandlerError:
                 logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate")
         elif h.protocol == "":
             raise ValueError(f"Use None instead of empty string for registering handler {h}")
         else:
-            cls.register_for_protocol(h, h.protocol, default_for_type, override)
+            cls.register_for_protocol(
+                h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
+            )
 
     @classmethod
-    def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
+    def register_for_protocol(
+        cls,
+        h: Handlers,
+        protocol: str,
+        default_for_type: bool,
+        override: bool,
+        default_format_for_type: bool,
+        default_storage_for_type: bool,
+    ):
         """
         See the main register function instead.
         """
@@ -434,12 +481,25 @@ def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: boo
         lowest_level[h.supported_format] = h
         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")
 
-        if default_for_type:
-            logger.debug(
-                f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
-            )
-            cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
-            cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
+        if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
+            if h.python_type in cls.DEFAULT_FORMATS and not override:
+                if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
+                    logger.info(
+                        f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                    )
+            else:
+                logger.debug(
+                    f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
+                )
+                cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+        if default_storage_for_type or default_for_type:
+            if h.python_type in cls.DEFAULT_PROTOCOLS and not override:
+                logger.debug(
+                    f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified."
+ ) + else: + logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}") + cls.DEFAULT_PROTOCOLS[h.python_type] = protocol # Register with the type engine as well # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as @@ -461,7 +521,7 @@ def to_literal( # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema python_type, *attrs = extract_cols_and_format(python_type) # In case it's a FlyteSchema - sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None)) + sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) if expected and expected.structured_dataset_type: sdt = StructuredDatasetType( @@ -514,16 +574,12 @@ def to_literal( python_val, df_type, protocol, - sdt.format or typing.cast(StructuredDataset, python_val).DEFAULT_FILE_FORMAT, + sdt.format, sdt, ) # Otherwise assume it's a dataframe instance. Wrap it with some defaults - if python_type in self.DEFAULT_FORMATS: - fmt = self.DEFAULT_FORMATS[python_type] - else: - logger.debug(f"No default format for type {python_type}, using system default.") - fmt = StructuredDataset.DEFAULT_FILE_FORMAT + fmt = self.DEFAULT_FORMATS.get(python_type, "") protocol = self._protocol_from_type_or_prefix(ctx, python_type) meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None) @@ -760,18 +816,9 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any] # Get the column information converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) - - # Get the format - default_format = ( - original_python_type.DEFAULT_FILE_FORMAT - if issubclass(original_python_type, StructuredDataset) - else self.DEFAULT_FORMATS.get(original_python_type, PARQUET) - ) - fmt = storage_format or default_format - return StructuredDatasetType( columns=converted_cols, - format=fmt, + format=storage_format, external_schema_type="arrow" if pa_schema else None, external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None, ) diff --git a/plugins/flytekit-aws-athena/requirements.txt b/plugins/flytekit-aws-athena/requirements.txt index 0659434eb4..30900d53ff 100644 --- a/plugins/flytekit-aws-athena/requirements.txt +++ b/plugins/flytekit-aws-athena/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-athena # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via 
flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-athena -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 8787e70011..0e67b2e50b 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_config(self, settings: SerializationSettings) -> Dict[str, str]: # Parameters in taskTemplate config will be used to create aws job definition. 
# More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html - return {"platformCapabilities": self._task_config.platformCapabilities} + return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities} def get_command(self, settings: SerializationSettings) -> List[str]: container_args = [ diff --git a/plugins/flytekit-aws-batch/requirements.txt b/plugins/flytekit-aws-batch/requirements.txt index b73cef90da..7fa8195e20 100644 --- a/plugins/flytekit-aws-batch/requirements.txt +++ b/plugins/flytekit-aws-batch/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-awsbatch # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-awsbatch -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter 
pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-aws-sagemaker/requirements.txt b/plugins/flytekit-aws-sagemaker/requirements.txt index 82be0fa00b..64d03c18f2 100644 --- a/plugins/flytekit-aws-sagemaker/requirements.txt +++ b/plugins/flytekit-aws-sagemaker/requirements.txt @@ -1,45 +1,44 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-awssagemaker # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time -bcrypt==3.2.2 +bcrypt==4.0.1 # via paramiko binaryornot==0.4.4 # via cookiecutter -boto3==1.24.22 +boto3==1.26.77 # via sagemaker-training -botocore==1.27.22 +botocore==1.29.77 # via # boto3 # s3transfer -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via - # bcrypt # cryptography # pynacl -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # paramiko # pyopenssl @@ -52,39 +51,43 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-awssagemaker -gevent==21.12.0 +gevent==22.10.2 # via sagemaker-training -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -greenlet==1.1.2 +greenlet==2.0.2 # via gevent -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring inotify-simple==1.2.1 # via sagemaker-training +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -99,11 +102,15 @@ jmespath==1.0.1 # via # boto3 # botocore -keyring==23.6.0 +joblib==1.2.0 + # via 
flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 - # via jinja2 -marshmallow==3.17.0 +markupsafe==2.1.2 + # via + # jinja2 + # werkzeug +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -112,9 +119,11 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via @@ -123,13 +132,15 @@ numpy==1.21.6 # pyarrow # sagemaker-training # scipy -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -paramiko==2.11.0 +paramiko==3.0.0 # via sagemaker-training -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -139,20 +150,18 @@ protobuf==3.20.2 # sagemaker-training protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.1 +psutil==5.9.4 # via sagemaker-training py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi pynacl==1.5.0 # via paramiko -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow @@ -160,13 +169,13 @@ python-dateutil==2.8.2 # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -174,19 +183,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -retrying==1.3.3 +retrying==1.3.4 # via sagemaker-training s3transfer==0.6.0 # via boto3 @@ -194,14 +203,13 @@ sagemaker-training==3.9.2 # via flytekitplugins-awssagemaker scipy==1.7.3 # via sagemaker-training -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via # grpcio - # paramiko # python-dateutil # retrying # sagemaker-training @@ -211,36 +219,43 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via # botocore + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -werkzeug==2.1.2 +werkzeug==2.2.3 # via sagemaker-training -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata -zope-event==4.5.0 +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources +zope-event==4.6 # via gevent -zope-interface==5.4.0 +zope-interface==5.5.2 # via gevent # The following packages are considered to be unsafe in a requirements file: diff --git a/plugins/flytekit-bigquery/requirements.txt b/plugins/flytekit-bigquery/requirements.txt index 6ae47dbecf..a9bacca60d 100644 --- a/plugins/flytekit-bigquery/requirements.txt +++ b/plugins/flytekit-bigquery/requirements.txt @@ -1,36 +1,36 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated 
by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-bigquery # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -cachetools==5.2.0 +cachetools==5.3.0 # via google-auth -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -42,57 +42,58 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-bigquery -google-api-core[grpc]==2.8.2 +google-api-core[grpc]==2.11.0 # via # google-cloud-bigquery - # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.9.0 +google-auth==2.16.1 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.2.0 +google-cloud-bigquery==3.6.0 # via flytekitplugins-bigquery -google-cloud-bigquery-storage==2.13.2 +google-cloud-core==2.3.2 # via google-cloud-bigquery -google-cloud-core==2.3.1 - # via google-cloud-bigquery -google-crc32c==1.3.0 +google-crc32c==1.5.0 # via google-resumable-media -google-resumable-media==2.3.3 +google-resumable-media==2.4.1 # via google-cloud-bigquery -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # google-api-core # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via # flytekit # google-api-core -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -103,11 +104,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -116,32 +119,32 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 +packaging==23.0 # via + # docker # google-cloud-bigquery # marshmallow pandas==1.3.5 # via flytekit -proto-plus==1.20.6 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage -protobuf==3.20.2 +proto-plus==1.22.2 + # via google-cloud-bigquery +protobuf==3.20.3 # via # flyteidl # flytekit # google-api-core # google-cloud-bigquery - # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus @@ -150,10 +153,8 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 - # via - # flytekit - # google-cloud-bigquery +pyarrow==10.0.1 + # 
via flytekit pyasn1==0.4.8 # via # pyasn1-modules @@ -162,10 +163,8 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow @@ -173,13 +172,13 @@ python-dateutil==2.8.2 # flytekit # google-cloud-bigquery # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -187,9 +186,9 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker @@ -197,13 +196,13 @@ requests==2.28.1 # google-api-core # google-cloud-bigquery # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -rsa==4.8 +rsa==4.9 # via google-auth -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -218,27 +217,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-data-fsspec/requirements.txt b/plugins/flytekit-data-fsspec/requirements.txt index 86d5df921f..74ab73760d 100644 --- a/plugins/flytekit-data-fsspec/requirements.txt +++ b/plugins/flytekit-data-fsspec/requirements.txt @@ -1,36 +1,36 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-data-fsspec # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -botocore==1.27.22 +botocore==1.29.77 # via flytekitplugins-data-fsspec -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -42,35 +42,39 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-data-fsspec -fsspec==2022.5.0 +fsspec==2023.1.0 # via flytekitplugins-data-fsspec -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status 
-grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -83,11 +87,13 @@ jinja2-time==0.2.0 # via cookiecutter jmespath==1.0.1 # via botocore -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -96,20 +102,26 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 - # via flytekit -protobuf==3.20.2 + # via + # flytekit + # flytekitplugins-data-fsspec +protobuf==3.20.3 # via # flyteidl # flytekit @@ -120,14 +132,12 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow @@ -135,13 +145,13 @@ python-dateutil==2.8.2 # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -149,19 +159,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -175,28 +185,35 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via # botocore + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-dbt/requirements.txt b/plugins/flytekit-dbt/requirements.txt index ea09d84eb1..6001a4f957 100644 --- a/plugins/flytekit-dbt/requirements.txt +++ b/plugins/flytekit-dbt/requirements.txt @@ -66,9 +66,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.13 # via flytekit -flyteidl==1.1.12 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.1 +flytekit==1.2.7 # via flytekitplugins-dbt future==0.18.2 # via parsedatetime diff --git a/plugins/flytekit-deck-standard/requirements.txt b/plugins/flytekit-deck-standard/requirements.txt index 7ec0d09f73..fd8205c8ad 100644 --- a/plugins/flytekit-deck-standard/requirements.txt +++ 
b/plugins/flytekit-deck-standard/requirements.txt @@ -1,52 +1,46 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.7 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-deck-standard # via -r requirements.in -appnope==0.1.3 - # via - # ipykernel - # ipython -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time -asttokens==2.2.1 - # via stack-data -attrs==21.4.0 +attrs==22.2.0 # via visions backcall==0.2.0 # via ipython binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit -comm==0.1.2 - # via ipykernel cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 # via flytekit -debugpy==1.6.4 +debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via @@ -56,81 +50,86 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit entrypoints==0.4 # via jupyter-client -executing==1.2.0 - # via stack-data -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-deck-standard -fonttools==4.33.3 +fonttools==4.38.0 # via matplotlib -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit htmlmin==0.1.12 - # via pandas-profiling -idna==3.3 + # via ydata-profiling +idna==3.4 # via requests -imagehash==4.2.1 +imagehash==4.3.1 # via visions -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via + # click # flytekit # keyring # markdown -ipykernel==6.19.4 +importlib-resources==5.12.0 + # via keyring +ipykernel==6.16.2 # via ipywidgets -ipython==8.7.0 +ipython==7.34.0 # via # ipykernel # ipywidgets ipywidgets==8.0.4 # via flytekitplugins-deck-standard +jaraco-classes==3.2.3 + # via keyring jedi==0.18.2 # via ipython +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time - # pandas-profiling + # ydata-profiling jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 +joblib==1.2.0 # via - # pandas-profiling + # flytekit # phik -jupyter-client==7.4.8 +jupyter-client==7.4.9 # via ipykernel -jupyter-core==5.1.1 +jupyter-core==4.12.0 # via jupyter-client jupyterlab-widgets==3.0.5 # via ipywidgets -keyring==23.6.0 +keyring==23.13.1 # via flytekit -kiwisolver==1.4.3 +kiwisolver==1.4.4 # via matplotlib -markdown==3.3.7 +markdown==3.4.1 # via flytekitplugins-deck-standard -markupsafe==2.1.1 - # via - # jinja2 - # pandas-profiling -marshmallow==3.17.0 +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -139,25 +138,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.2 +matplotlib==3.5.3 # via - # missingno - # pandas-profiling # phik # seaborn + # ydata-profiling matplotlib-inline==0.1.6 # via # ipykernel # ipython 
-missingno==0.5.1 - # via pandas-profiling -multimethod==1.8 +more-itertools==9.0.0 + # via jaraco-classes +multimethod==1.9.1 # via - # pandas-profiling # visions -mypy-extensions==0.4.3 + # ydata-profiling +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit nest-asyncio==1.5.6 # via @@ -167,51 +165,55 @@ networkx==2.6.3 # via visions numpy==1.21.6 # via + # flytekit # imagehash # matplotlib - # missingno # pandas - # pandas-profiling + # patsy # phik # pyarrow - # pywavelets # scipy # seaborn + # statsmodels # visions -packaging==21.3 + # ydata-profiling +packaging==23.0 # via + # docker # ipykernel # marshmallow # matplotlib + # statsmodels pandas==1.3.5 # via # flytekit - # pandas-profiling # phik # seaborn + # statsmodels # visions -pandas-profiling==3.2.0 + # ydata-profiling +pandas-profiling==3.6.6 # via flytekitplugins-deck-standard parso==0.8.3 # via jedi +patsy==0.5.3 + # via statsmodels pexpect==4.8.0 # via ipython -phik==0.12.2 - # via pandas-profiling +phik==0.12.3 + # via ydata-profiling pickleshare==0.7.5 # via ipython -pillow==9.2.0 +pillow==9.4.0 # via # imagehash # matplotlib # visions -platformdirs==2.6.0 - # via jupyter-core -plotly==5.9.0 +plotly==5.13.0 # via flytekitplugins-deck-standard -prompt-toolkit==3.0.36 +prompt-toolkit==3.0.37 # via ipython -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -224,24 +226,20 @@ psutil==5.9.4 # via ipykernel ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.2 - # via stack-data py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pydantic==1.9.1 - # via pandas-profiling -pygments==2.13.0 +pydantic==1.10.5 + # via ydata-profiling +pygments==2.14.0 # via ipython -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 - # via - # matplotlib - # packaging + # via matplotlib python-dateutil==2.8.2 # via # arrow @@ -250,13 +248,13 @@ python-dateutil==2.8.2 # jupyter-client # matplotlib # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -266,89 +264,98 @@ pyyaml==6.0 # via # cookiecutter # flytekit - # pandas-profiling -pyzmq==24.0.1 + # ydata-profiling +pyzmq==25.0.0 # via # ipykernel # jupyter-client -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit - # pandas-profiling # responses -responses==0.21.0 + # ydata-profiling +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit scipy==1.7.3 # via # imagehash - # missingno - # pandas-profiling # phik - # seaborn -seaborn==0.11.2 - # via - # missingno - # pandas-profiling + # statsmodels + # ydata-profiling +seaborn==0.12.2 + # via ydata-profiling +secretstorage==3.3.3 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via - # asttokens # grpcio - # imagehash + # patsy # python-dateutil sortedcontainers==2.4.0 # via flytekit -stack-data==0.6.2 - # via ipython statsd==3.3.0 # via flytekit +statsmodels==0.13.5 + # via ydata-profiling tangled-up-in-unicode==0.2.0 - # via - # pandas-profiling - # visions -tenacity==8.0.1 + # via visions +tenacity==8.2.1 # via plotly text-unidecode==1.3 # via python-slugify +toml==0.10.2 + # via responses tornado==6.2 # via # ipykernel # jupyter-client -tqdm==4.64.0 - # via pandas-profiling -traitlets==5.8.0 +tqdm==4.64.1 + # via ydata-profiling 
+traitlets==5.9.0 # via - # comm # ipykernel # ipython # ipywidgets # jupyter-client # jupyter-core # matplotlib-inline -typing-extensions==4.3.0 +typeguard==2.13.3 + # via ydata-profiling +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via + # arrow # flytekit + # importlib-metadata + # kiwisolver # pydantic + # responses + # seaborn # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -visions[type_image_path]==0.7.4 - # via pandas-profiling -wcwidth==0.2.5 +visions[type_image_path]==0.7.5 + # via ydata-profiling +wcwidth==0.2.6 # via prompt-toolkit -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit widgetsnbextension==4.0.5 # via ipywidgets @@ -356,8 +363,12 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +ydata-profiling==4.0.0 + # via pandas-profiling +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py index 270d53b12e..79eb7e877d 100644 --- a/plugins/flytekit-deck-standard/tests/test_renderer.py +++ b/plugins/flytekit-deck-standard/tests/test_renderer.py @@ -7,7 +7,7 @@ def test_frame_profiling_renderer(): renderer = FrameProfilingRenderer() - assert "Profile Report Generated With The `Pandas-Profiling`" in renderer.to_html(df).title() + assert "Profile Report Generated By Ydata!" in renderer.to_html(df).title() def test_markdown_renderer(): diff --git a/plugins/flytekit-dolt/requirements.txt b/plugins/flytekit-dolt/requirements.txt index b619ca9520..e5f0595edc 100644 --- a/plugins/flytekit-dolt/requirements.txt +++ b/plugins/flytekit-dolt/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-dolt # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -42,37 +42,41 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit dolt-integrations==0.1.5 # via flytekitplugins-dolt -doltcli==0.1.17 +doltcli==0.1.18 # via dolt-integrations -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-dolt -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 
+importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -83,11 +87,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -96,22 +102,26 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via # dolt-integrations # flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -122,27 +132,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -150,19 +158,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -176,27 +184,36 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +typed-ast==1.5.4 + # via doltcli +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-greatexpectations/requirements.txt b/plugins/flytekit-greatexpectations/requirements.txt index 57fd7b20d3..1e4051a41c 100644 --- a/plugins/flytekit-greatexpectations/requirements.txt +++ b/plugins/flytekit-greatexpectations/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # @@ -8,13 +8,18 @@ # via -r requirements.in altair==4.2.0 # via great-expectations +anyio==3.6.2 + # via jupyter-server argon2-cffi==21.3.0 - # via notebook + # via + # jupyter-server + # nbclassic + # notebook argon2-cffi-bindings==21.2.0 # via argon2-cffi -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time -attrs==21.4.0 +attrs==22.2.0 # via jsonschema backcall==0.2.0 # via 
ipython @@ -22,43 +27,43 @@ backports-zoneinfo==0.2.1 # via # pytz-deprecation-shim # tzlocal -beautifulsoup4==4.11.1 +beautifulsoup4==4.11.2 # via nbconvert binaryornot==0.4.4 # via cookiecutter -bleach==5.0.1 +bleach==6.0.0 # via nbconvert -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via # argon2-cffi-bindings # cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit # great-expectations -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit -colorama==0.4.5 +colorama==0.4.6 # via great-expectations cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # great-expectations # pyopenssl # secretstorage dataclasses-json==0.5.7 # via flytekit -debugpy==1.6.0 +debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via @@ -70,58 +75,71 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit entrypoints==0.4 # via # altair # jupyter-client - # nbconvert -fastjsonschema==2.15.3 +fastjsonschema==2.16.2 # via nbformat -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-great-expectations -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status great-expectations==0.15.12 - # via flytekitplugins-great-expectations -greenlet==1.1.2 + # via + # -r requirements.in + # flytekitplugins-great-expectations +greenlet==2.0.2 # via sqlalchemy -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 - # via requests -importlib-metadata==4.12.0 +idna==3.4 + # via + # anyio + # requests +importlib-metadata==6.0.0 # via # click # flytekit # great-expectations # jsonschema # keyring + # nbconvert + # nbformat # sqlalchemy -importlib-resources==5.8.0 - # via jsonschema -ipykernel==6.15.0 - # via notebook +importlib-resources==5.12.0 + # via + # jsonschema + # keyring +ipykernel==6.16.2 + # via + # nbclassic + # notebook ipython==7.34.0 # via # great-expectations # ipykernel ipython-genutils==0.2.0 - # via notebook -jedi==0.18.1 + # via + # nbclassic + # notebook +jaraco-classes==3.2.3 + # via keyring +jedi==0.18.2 # via ipython jeepney==0.8.0 # via @@ -133,39 +151,52 @@ jinja2==3.1.2 # cookiecutter # great-expectations # jinja2-time + # jupyter-server + # nbclassic # nbconvert # notebook jinja2-time==0.2.0 # via cookiecutter +joblib==1.2.0 + # via flytekit jsonpatch==1.32 # via great-expectations jsonpointer==2.3 # via jsonpatch -jsonschema==4.6.1 +jsonschema==4.17.3 # via # altair # great-expectations # nbformat -jupyter-client==7.3.4 +jupyter-client==7.4.9 # via # ipykernel + # jupyter-server + # nbclassic # nbclient # notebook -jupyter-core==4.11.2 +jupyter-core==4.12.0 # via # jupyter-client + # jupyter-server + # nbclassic + # nbclient # nbconvert # nbformat # notebook +jupyter-server==1.23.6 + # via + # nbclassic + # notebook-shim jupyterlab-pygments==0.2.2 # via nbconvert -keyring==23.6.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via # jinja2 # nbconvert -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -174,36 +205,47 @@ marshmallow-enum==1.5.1 # via dataclasses-json 
marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib-inline==0.1.3 +matplotlib-inline==0.1.6 # via # ipykernel # ipython -mistune==2.0.3 +mistune==2.0.5 # via # great-expectations # nbconvert -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit -nbclient==0.6.6 - # via nbconvert -nbconvert==7.0.0rc2 +nbclassic==0.5.2 # via notebook -nbformat==5.4.0 +nbclient==0.7.2 + # via nbconvert +nbconvert==7.2.9 + # via + # jupyter-server + # nbclassic + # notebook +nbformat==5.7.3 # via # great-expectations + # jupyter-server + # nbclassic # nbclient # nbconvert # notebook -nest-asyncio==1.5.5 +nest-asyncio==1.5.6 # via # ipykernel # jupyter-client - # nbclient + # nbclassic # notebook -notebook==6.4.12 +notebook==6.5.2 # via great-expectations +notebook-shim==0.2.2 + # via nbclassic numpy==1.21.6 # via # altair @@ -212,10 +254,12 @@ numpy==1.21.6 # pandas # pyarrow # scipy -packaging==21.3 +packaging==23.0 # via + # docker # great-expectations # ipykernel + # jupyter-server # marshmallow # nbconvert pandas==1.3.5 @@ -231,11 +275,16 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prometheus-client==0.14.1 - # via notebook -prompt-toolkit==3.0.30 +pkgutil-resolve-name==1.3.10 + # via jsonschema +prometheus-client==0.16.0 + # via + # jupyter-server + # nbclassic + # notebook +prompt-toolkit==3.0.37 # via ipython -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -244,7 +293,7 @@ protobuf==3.20.2 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.1 +psutil==5.9.4 # via ipykernel ptyprocess==0.7.0 # via @@ -252,21 +301,19 @@ ptyprocess==0.7.0 # terminado py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pygments==2.12.0 +pygments==2.14.0 # via # ipython # nbconvert -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit pyparsing==2.4.7 - # via - # great-expectations - # packaging -pyrsistent==0.18.1 + # via great-expectations +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via @@ -276,13 +323,13 @@ python-dateutil==2.8.2 # great-expectations # jupyter-client # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # great-expectations @@ -293,34 +340,39 @@ pyyaml==6.0 # via # cookiecutter # flytekit -pyzmq==23.2.0 +pyzmq==25.0.0 # via # ipykernel # jupyter-client + # jupyter-server + # nbclassic # notebook -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # great-expectations # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit ruamel-yaml==0.17.17 # via great-expectations -ruamel-yaml-clib==0.2.6 +ruamel-yaml-clib==0.2.7 # via ruamel-yaml scipy==1.7.3 # via great-expectations -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring send2trash==1.8.0 - # via notebook + # via + # jupyter-server + # nbclassic + # notebook singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -328,9 +380,11 @@ six==1.16.0 # bleach # grpcio # python-dateutil +sniffio==1.3.0 + # via anyio sortedcontainers==2.4.0 # via flytekit -soupsieve==2.3.2.post1 +soupsieve==2.4 # via beautifulsoup4 sqlalchemy==1.4.39 # via @@ -338,37 +392,49 @@ sqlalchemy==1.4.39 # flytekitplugins-great-expectations statsd==3.3.0 # via 
flytekit -termcolor==1.1.0 +termcolor==2.2.0 # via great-expectations -terminado==0.15.0 - # via notebook +terminado==0.17.1 + # via + # jupyter-server + # nbclassic + # notebook text-unidecode==1.3 # via python-slugify -tinycss2==1.1.1 +tinycss2==1.2.1 # via nbconvert -toolz==0.11.2 +toml==0.10.2 + # via responses +toolz==0.12.0 # via altair tornado==6.2 # via # ipykernel # jupyter-client + # jupyter-server + # nbclassic # notebook # terminado -tqdm==4.64.0 +tqdm==4.64.1 # via great-expectations -traitlets==5.3.0 +traitlets==5.9.0 # via # ipykernel # ipython # jupyter-client # jupyter-core + # jupyter-server # matplotlib-inline + # nbclassic # nbclient # nbconvert # nbformat # notebook -typing-extensions==4.3.0 +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via + # anyio # argon2-cffi # arrow # flytekit @@ -377,33 +443,36 @@ typing-extensions==4.3.0 # jsonschema # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -tzdata==2022.1 +tzdata==2022.7 # via pytz-deprecation-shim tzlocal==4.2 # via great-expectations -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # great-expectations # requests # responses -wcwidth==0.2.5 +wcwidth==0.2.6 # via prompt-toolkit webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.3 - # via docker -wheel==0.38.0 +websocket-client==1.5.1 + # via + # docker + # jupyter-server +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.14.0 # via # importlib-metadata # importlib-resources diff --git a/plugins/flytekit-hive/requirements.txt b/plugins/flytekit-hive/requirements.txt index b6a0b8b490..50dfef8807 100644 --- a/plugins/flytekit-hive/requirements.txt +++ b/plugins/flytekit-hive/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-hive # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-hive -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 
# via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py index 0690179bb1..579efd366c 100644 --- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py +++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py @@ -1,3 +1,4 @@ +import os import typing import datasets @@ -59,12 +60,11 @@ def decode( ) -> datasets.Dataset: local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) - path = f"{local_dir}/00000" - + files = [item.path for item in os.scandir(local_dir)] if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return datasets.Dataset.from_parquet(path, columns=columns) - return datasets.Dataset.from_parquet(path) + return datasets.Dataset.from_parquet(files, columns=columns) + return datasets.Dataset.from_parquet(files) StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler()) diff --git 
a/plugins/flytekit-huggingface/requirements.txt b/plugins/flytekit-huggingface/requirements.txt index 8cabda2e96..07cefbae2c 100644 --- a/plugins/flytekit-huggingface/requirements.txt +++ b/plugins/flytekit-huggingface/requirements.txt @@ -1,32 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-huggingface # via -r requirements.in -aiohttp==3.8.1 +aiohttp==3.8.4 # via # datasets # fsspec -aiosignal==1.2.0 +aiosignal==1.3.1 # via aiohttp -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time async-timeout==4.0.2 # via aiohttp -attrs==21.4.0 +asynctest==0.13.0 + # via aiohttp +attrs==22.2.0 # via aiohttp binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via # aiohttp # requests @@ -34,77 +36,92 @@ click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit -datasets==2.4.0 +datasets==2.10.0 # via flytekitplugins-huggingface decorator==5.1.1 # via retry deprecated==1.2.13 # via flytekit -dill==0.3.5.1 +dill==0.3.6 # via # datasets # multiprocess diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -filelock==3.7.1 +filelock==3.9.0 # via huggingface-hub -flyteidl==1.1.9 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-huggingface -frozenlist==1.3.0 +frozenlist==1.3.3 # via # aiohttp # aiosignal -fsspec[http]==2022.7.0 +fsspec[http]==2023.1.0 # via datasets -googleapis-common-protos==1.56.4 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -huggingface-hub==0.8.1 +huggingface-hub==0.12.1 # via datasets -idna==3.3 +idna==3.4 # via # requests # yarl -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via + # click + # datasets # flytekit + # huggingface-hub # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.7.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -113,31 +130,35 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -multidict==6.0.2 +more-itertools==9.0.0 + # via jaraco-classes +multidict==6.0.4 # via # aiohttp # yarl -multiprocess==0.70.13 +multiprocess==0.70.14 # via datasets -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit -numpy==1.23.1 +numpy==1.21.6 # via # datasets + # flytekit # pandas # pyarrow -packaging==21.3 +packaging==23.0 # via # datasets + # docker # huggingface-hub # marshmallow 
-pandas==1.4.3 +pandas==1.3.5 # via # datasets # flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -148,40 +169,39 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via # datasets # flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.4 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas pyyaml==6.0 # via # cookiecutter + # datasets # flytekit # huggingface-hub -regex==2022.7.25 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # datasets @@ -196,6 +216,10 @@ responses==0.18.0 # flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.3 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via # grpcio @@ -206,33 +230,41 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -tqdm==4.64.0 +tqdm==4.64.1 # via # datasets # huggingface-hub -typing-extensions==4.3.0 +typing-extensions==4.5.0 # via + # aiohttp + # arrow + # async-timeout # flytekit # huggingface-hub + # importlib-metadata # typing-inspect -typing-inspect==0.7.1 + # yarl +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.11 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -xxhash==3.0.0 +xxhash==3.2.0 # via datasets -yarl==1.7.2 +yarl==1.8.2 # via aiohttp -zipp==3.8.1 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py index 477c7a1a7c..acdbc20810 100644 --- a/plugins/flytekit-huggingface/setup.py +++ b/plugins/flytekit-huggingface/setup.py @@ -38,4 +38,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py index 170fdc3789..5b65b2511c 100644 --- a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py +++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py @@ -68,3 +68,15 @@ def test_datasets_renderer(): df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) dataset = datasets.Dataset.from_pandas(df) assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "
") + + +def test_parquet_to_datasets(): + df = pd.DataFrame({"name": ["Alice"], "age": [10]}) + + @task + def create_sd() -> StructuredDataset: + return StructuredDataset(dataframe=df) + + sd = create_sd() + dataset = sd.open(datasets.Dataset).all() + assert dataset.data == datasets.Dataset.from_pandas(df).data diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index e81728ddb4..9e8e5ef937 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple, Union from flyteidl.core import tasks_pb2 as _core_task @@ -12,56 +13,38 @@ from flytekit.models import task as _task_models _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" +PRIMARY_CONTAINER_DEFAULT_NAME = "primary" def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> str: return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") +@dataclass class Pod(object): """ Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod. This plugin helps expose a fully modifiable Kubernetes pod spec to customize the task execution runtime. To use pod tasks: (1) Define a pod spec, and (2) Specify the primary container name. :param V1PodSpec pod_spec: Kubernetes pod spec. https://kubernetes.io/docs/concepts/workloads/pods - :param str primary_container_name: the primary container name + :param str primary_container_name: the primary container name. If provided the pod-spec can contain a container whose name matches the primary_container_name. This will force Flyte to give up control of the primary + container and will expect users to control setting up the container. If you expect your python function to run as is, simply create containers that do not match the default primary-container-name and Flyte will auto-inject a + container for the python function based on the default image provided during serialization. :param Optional[Dict[str, str]] labels: Labels are key/value pairs that are attached to pod spec :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec. 
""" - def __init__( - self, - pod_spec: V1PodSpec, - primary_container_name: str, - labels: Optional[Dict[str, str]] = None, - annotations: Optional[Dict[str, str]] = None, - ): - if not pod_spec: + pod_spec: V1PodSpec + primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + + def __post_init__(self): + if not self.pod_spec: raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined") - if not primary_container_name: + if not self.primary_container_name: raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") - self._pod_spec = pod_spec - self._primary_container_name = primary_container_name - self._labels = labels - self._annotations = annotations - - @property - def pod_spec(self) -> V1PodSpec: - return self._pod_spec - - @property - def primary_container_name(self) -> str: - return self._primary_container_name - - @property - def labels(self) -> Optional[Dict[str, str]]: - return self._labels - - @property - def annotations(self) -> Optional[Dict[str, str]]: - return self._annotations - class PodFunctionTask(PythonFunctionTask[Pod]): def __init__(self, task_config: Pod, task_function: Callable, **kwargs): @@ -114,7 +97,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] final_containers.append(container) - self.task_config._pod_spec.containers = final_containers + self.task_config.pod_spec.containers = final_containers return ApiClient().sanitize_for_serialization(self.task_config.pod_spec) diff --git a/plugins/flytekit-k8s-pod/requirements.txt b/plugins/flytekit-k8s-pod/requirements.txt index 75036fceb9..9c911b1263 100644 --- a/plugins/flytekit-k8s-pod/requirements.txt +++ b/plugins/flytekit-k8s-pod/requirements.txt @@ -1,38 +1,38 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-pod # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -cachetools==5.2.0 +cachetools==5.3.0 # via google-auth -certifi==2022.6.15 +certifi==2022.12.7 # via # kubernetes # requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -44,35 +44,39 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-pod -google-auth==2.9.0 +google-auth==2.16.1 # via kubernetes -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring 
+jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -83,13 +87,15 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -kubernetes==24.2.0 +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 # via flytekitplugins-pod -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -98,22 +104,26 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -oauthlib==3.2.1 +oauthlib==3.2.2 # via requests-oauthlib -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -124,7 +134,7 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pyasn1==0.4.8 # via @@ -134,10 +144,8 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow @@ -145,13 +153,13 @@ python-dateutil==2.8.2 # flytekit # kubernetes # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -160,9 +168,9 @@ pyyaml==6.0 # cookiecutter # flytekit # kubernetes -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker @@ -172,13 +180,13 @@ requests==2.28.1 # responses requests-oauthlib==1.3.1 # via kubernetes -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -rsa==4.8 +rsa==4.9 # via google-auth -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -194,33 +202,40 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # kubernetes # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via # docker # kubernetes -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 716190b4df..0d6788ac92 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -12,6 +12,7 @@ from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions import user from flytekit.extend import ExecutionState from 
flytekit.tools.translator import get_serializable @@ -473,3 +474,32 @@ def dynamic_task_with_pod_subtask(dummy_input: str) -> str: assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["requests"]["gpu"] == "1" assert context_manager.FlyteContextManager.size() == 1 + + +def test_pod_config(): + with pytest.raises(user.FlyteValidationException): + Pod(pod_spec=None) + + with pytest.raises(user.FlyteValidationException): + Pod(pod_spec=V1PodSpec(containers=[]), primary_container_name=None) + + selector = {"node_group": "memory"} + + @task( + task_config=Pod( + pod_spec=V1PodSpec( + containers=[], + node_selector=selector, + ), + ), + requests=Resources( + mem="1G", + ), + ) + def my_pod_task(): + print("hello world") + time.sleep(30000) + + assert my_pod_task.task_config + assert isinstance(my_pod_task.task_config, Pod) + assert my_pod_task.task_config.pod_spec.node_selector == selector diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index 5c3a8a8efa..bbf6b17f0d 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -46,11 +46,11 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via # flytekit # flytekitplugins-kfmpi -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-kfmpi googleapis-common-protos==1.56.3 # via diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index fc354d83bd..96fa577a3e 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfpytorch # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-kfpytorch -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 
-marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index d1dc578fdf..60f87a8ac1 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kftensorflow # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 
# via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-kftensorflow -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-mlflow/README.md b/plugins/flytekit-mlflow/README.md new file mode 100644 index 0000000000..6cbee9cf59 --- /dev/null +++ b/plugins/flytekit-mlflow/README.md @@ -0,0 +1,22 @@ +# Flytekit MLflow Plugin + +MLflow enables us to log parameters, code, and results in machine learning experiments and compare them using an interactive UI. +This MLflow plugin enables seamless use of MLflow within Flyte, and renders the metrics and parameters on Flyte Deck.
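+
+When a decorated task runs locally, the recorded metrics and parameters can be rendered from a
+Flyte context (a minimal sketch, assuming the ``train_model`` task from the example below):
+
+```python
+import flytekit
+
+# Run the task locally and fetch the generated deck (e.g. in a Jupyter notebook)
+with flytekit.new_context() as ctx:
+    train_model()
+    ctx.get_deck()  # rendered via IPython.display
+```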
+ +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-mlflow +``` + +Example +```python +from flytekit import task, workflow +from flytekitplugins.mlflow import mlflow_autolog +import mlflow + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(): + ... +``` diff --git a/plugins/flytekit-mlflow/dev-requirements.in b/plugins/flytekit-mlflow/dev-requirements.in new file mode 100644 index 0000000000..0f57144081 --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.in @@ -0,0 +1 @@ +tensorflow diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt new file mode 100644 index 0000000000..6ad9be49bb --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -0,0 +1,122 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile dev-requirements.in +# +absl-py==1.3.0 + # via + # tensorboard + # tensorflow +astunparse==1.6.3 + # via tensorflow +cachetools==5.2.0 + # via google-auth +certifi==2022.9.24 + # via requests +charset-normalizer==2.1.1 + # via requests +flatbuffers==22.10.26 + # via tensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.14.1 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +grpcio==1.50.0 + # via + # tensorboard + # tensorflow +h5py==3.7.0 + # via tensorflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via markdown +keras==2.10.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +libclang==14.0.6 + # via tensorflow +markdown==3.4.1 + # via tensorboard +markupsafe==2.1.1 + # via werkzeug +numpy==1.23.4 + # via + # h5py + # keras-preprocessing + # opt-einsum + # tensorboard + # tensorflow +oauthlib==3.2.2 + # via requests-oauthlib +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via tensorflow +protobuf==3.19.6 + # via + # tensorboard + # tensorflow +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pyparsing==3.0.9 + # via packaging +requests==2.28.1 + # via + # requests-oauthlib + # tensorboard +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # tensorflow +tensorboard==2.10.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.10.0 + # via -r dev-requirements.in +tensorflow-estimator==2.10.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.27.0 + # via tensorflow +termcolor==2.1.0 + # via tensorflow +typing-extensions==4.4.0 + # via tensorflow +urllib3==1.26.12 + # via requests +werkzeug==2.2.2 + # via tensorboard +wheel==0.38.3 + # via + # astunparse + # tensorboard +wrapt==1.14.1 + # via tensorflow +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py new file mode 100644 index 0000000000..98e84547e0 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.mlflow + +This plugin enables seamless integration between Flyte and mlflow. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + mlflow_autolog +""" + +from .tracking import mlflow_autolog diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py new file mode 100644 index 0000000000..b58aa4a120 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -0,0 +1,140 @@ +import typing +from functools import partial, wraps + +import mlflow +import pandas +import pandas as pd +import plotly.graph_objects as go +from mlflow import MlflowClient +from mlflow.entities.metric import Metric +from plotly.subplots import make_subplots + +import flytekit +from flytekit import FlyteContextManager +from flytekit.bin.entrypoint import get_one_of +from flytekit.core.context_manager import ExecutionState +from flytekit.deck import TopFrameRenderer + + +def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: + """ + Converts mlflow Metric object to a dataframe of 2 columns ['timestamp', 'value'] + """ + t = [] + v = [] + for m in metrics: + t.append(m.timestamp) + v.append(m.value) + return pd.DataFrame(list(zip(t, v)), columns=["timestamp", "value"]) + + +def get_run_metrics(c: MlflowClient, run_id: str) -> typing.Dict[str, pandas.DataFrame]: + """ + Extracts all metrics and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + metrics = {} + for k in r.data.metrics.keys(): + metrics[k] = metric_to_df(metrics=c.get_metric_history(run_id=run_id, key=k)) + return metrics + + +def get_run_params(c: MlflowClient, run_id: str) -> typing.Optional[pd.DataFrame]: + """ + Extracts all parameters and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + name = [] + value = [] + if r.data.params == {}: + return None + for k, v in r.data.params.items(): + name.append(k) + value.append(v) + return pd.DataFrame(list(zip(name, value)), columns=["name", "value"]) + + +def plot_metrics(metrics: typing.Dict[str, pandas.DataFrame]) -> typing.Optional[go.Figure]: + v = len(metrics) + if v == 0: + return None + + # Initialize figure with subplots + fig = make_subplots(rows=v, cols=1, subplot_titles=list(metrics.keys())) + + # Add traces + row = 1 + for k, v in metrics.items(): + v["timestamp"] = (v["timestamp"] - v["timestamp"][0]) / 1000 + fig.add_trace(go.Scatter(x=v["timestamp"], y=v["value"], name=k), row=row, col=1) + row = row + 1 + + fig.update_xaxes(title_text="Time (s)") + fig.update_layout(height=700, width=900) + return fig + + +def mlflow_autolog(fn=None, *, framework=mlflow.sklearn, experiment_name: typing.Optional[str] = None): + """MLFlow decorator to enable autologging of training metrics. + + This decorator can be used as a nested decorator for a ``@task`` and it will automatically enable mlflow autologging, + for the given ``framework``. By default autologging is enabled for ``sklearn``. + + .. code-block:: python + + @task + @mlflow_autolog(framework=mlflow.tensorflow) + def my_tensorflow_trainer(): + ... + + One benefit of doing so is that the mlflow metrics are then rendered inline using FlyteDecks and can be viewed + in jupyter notebook, as well as in hosted Flyte environment: + + .. 
+
+        # jupyter notebook cell
+        with flytekit.new_context() as ctx:
+            my_tensorflow_trainer()
+            ctx.get_deck()  # IPython.display
+
+    When the task is called in a Flyte backend, the decorator starts a new MLFlow run using the Flyte execution name
+    by default, or a user-provided ``experiment_name`` in the decorator.
+
+    :param fn: Function to generate autologs for.
+    :param framework: The mlflow module to use for autologging.
+    :param experiment_name: The MLFlow experiment name. If not provided, uses the Flyte execution name.
+    """
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        framework.autolog()
+        ctx = FlyteContextManager.current_context()
+        params = ctx.user_space_params
+
+        experiment = experiment_name or "local workflow"
+        run_name = None  # MLflow will generate a random name if the value is None
+
+        if ctx.execution_state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
+            # Prefer the user-provided experiment name; fall back to the workflow name.
+            experiment = experiment_name or f"{get_one_of('FLYTE_INTERNAL_EXECUTION_WORKFLOW', '_F_WF')}"
+            run_name = f"{params.execution_id.name}.{params.task_id.name.split('.')[-1]}"
+
+        mlflow.set_experiment(experiment)
+        with mlflow.start_run(run_name=run_name):
+            out = fn(*args, **kwargs)
+            run = mlflow.active_run()
+            if run is not None:
+                client = MlflowClient()
+                run_id = run.info.run_id
+                metrics = get_run_metrics(client, run_id)
+                figure = plot_metrics(metrics)
+                if figure:
+                    flytekit.Deck("mlflow metrics", figure.to_html())
+                run_params = get_run_params(client, run_id)
+                if run_params is not None:
+                    flytekit.Deck("mlflow params", TopFrameRenderer(max_rows=10).to_html(run_params))
+        return out
+
+    if fn is None:
+        return partial(mlflow_autolog, framework=framework, experiment_name=experiment_name)
+
+    return wrapper
diff --git a/plugins/flytekit-mlflow/requirements.in b/plugins/flytekit-mlflow/requirements.in
new file mode 100644
index 0000000000..cbe58e3885
--- /dev/null
+++ b/plugins/flytekit-mlflow/requirements.in
@@ -0,0 +1,3 @@
+.
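The deck content is assembled from the three helpers defined in ``tracking.py`` above; they can also be used on their own against any finished run. A minimal sketch, assuming an MLflow run has already been recorded in the local tracking store (for example by the autologger):

```python
import mlflow
from mlflow import MlflowClient

from flytekitplugins.mlflow.tracking import get_run_metrics, get_run_params, plot_metrics

# Assumes a run was just logged in this process; any known run_id works in its place.
run = mlflow.last_active_run()
if run is not None:
    client = MlflowClient()
    metrics = get_run_metrics(client, run.info.run_id)  # {metric name -> history dataframe}
    figure = plot_metrics(metrics)  # one subplot per metric, or None if there are no metrics
    if figure is not None:
        figure.write_html("metrics.html")
    params = get_run_params(client, run.info.run_id)  # dataframe of name/value pairs, or None
```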
+-e file:.#egg=flytekitplugins-mlflow +grpcio-status<1.49.0 diff --git a/plugins/flytekit-mlflow/requirements.txt b/plugins/flytekit-mlflow/requirements.txt new file mode 100644 index 0000000000..03873c05f5 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.txt @@ -0,0 +1,274 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-mlflow + # via -r requirements.in +alembic==1.8.1 + # via mlflow +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.9.24 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # databricks-cli + # flask + # flytekit + # mlflow +cloudpickle==2.2.0 + # via + # flytekit + # mlflow +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.7 + # via flytekit +cryptography==38.0.3 + # via pyopenssl +databricks-cli==0.17.3 + # via mlflow +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.1 + # via + # flytekit + # mlflow +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +entrypoints==0.4 + # via mlflow +flask==2.2.2 + # via + # mlflow + # prometheus-flask-exporter +flyteidl==1.1.22 + # via flytekit +flytekit==1.2.3 + # via flytekitplugins-mlflow +gitdb==4.0.9 + # via gitpython +gitpython==3.1.29 + # via mlflow +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +greenlet==2.0.1 + # via sqlalchemy +grpcio==1.50.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.48.2 + # via + # -r requirements.in + # flytekit +gunicorn==20.1.0 + # via mlflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via + # flask + # flytekit + # keyring + # mlflow +itsdangerous==2.1.2 + # via flask +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # flask + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.11.0 + # via flytekit +mako==1.2.3 + # via alembic +markupsafe==2.1.1 + # via + # jinja2 + # mako + # werkzeug +marshmallow==3.18.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mlflow==1.30.0 + # via flytekitplugins-mlflow +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.2.0 + # via flytekit +numpy==1.23.4 + # via + # mlflow + # pandas + # pyarrow + # scipy +oauthlib==3.2.2 + # via databricks-cli +packaging==21.3 + # via + # docker + # marshmallow + # mlflow +pandas==1.5.1 + # via + # flytekit + # mlflow +plotly==5.11.0 + # via flytekitplugins-mlflow +prometheus-client==0.15.0 + # via prometheus-flask-exporter +prometheus-flask-exporter==0.20.3 + # via mlflow +protobuf==3.20.3 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # mlflow + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==22.1.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter 
+pytimeparse==1.1.8 + # via flytekit +pytz==2022.6 + # via + # flytekit + # mlflow + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # mlflow +querystring-parser==1.2.4 + # via mlflow +regex==2022.10.31 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # databricks-cli + # docker + # flytekit + # mlflow + # responses +responses==0.22.0 + # via flytekit +retry==0.9.2 + # via flytekit +scipy==1.9.3 + # via mlflow +six==1.16.0 + # via + # databricks-cli + # grpcio + # python-dateutil + # querystring-parser +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +sqlalchemy==1.4.43 + # via + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow +statsd==3.3.0 + # via flytekit +tabulate==0.9.0 + # via databricks-cli +tenacity==8.1.0 + # via plotly +text-unidecode==1.3 + # via python-slugify +toml==0.10.2 + # via responses +types-toml==0.10.8 + # via responses +typing-extensions==4.4.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # docker + # flytekit + # requests + # responses +websocket-client==1.4.2 + # via docker +werkzeug==2.2.2 + # via flask +wheel==0.38.3 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/setup.py b/plugins/flytekit-mlflow/setup.py new file mode 100644 index 0000000000..2033ce5d27 --- /dev/null +++ b/plugins/flytekit-mlflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "mlflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables seamless use of MLFlow within Flyte", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-mlflow/tests/__init__.py b/plugins/flytekit-mlflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py new file mode 100644 index 0000000000..b196327d8d --- /dev/null +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -0,0 +1,32 @@ +import mlflow +import tensorflow as tf +from flytekitplugins.mlflow import mlflow_autolog + +import flytekit +from flytekit import task + + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(epochs: int): + fashion_mnist = tf.keras.datasets.fashion_mnist + (train_images, train_labels), (_, _) = fashion_mnist.load_data() + train_images = train_images / 255.0 + + 
model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + model.fit(train_images, train_labels, epochs=epochs) + + +def test_local_exec(): + train_model(epochs=1) + assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output diff --git a/plugins/flytekit-modin/requirements.txt b/plugins/flytekit-modin/requirements.txt index ffec3bce45..c96c995db2 100644 --- a/plugins/flytekit-modin/requirements.txt +++ b/plugins/flytekit-modin/requirements.txt @@ -1,42 +1,44 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-modin # via -r requirements.in -aiosignal==1.2.0 +aiosignal==1.3.1 # via ray -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time -attrs==21.4.0 +attrs==22.2.0 # via # jsonschema # ray binaryornot==0.4.4 # via cookiecutter -certifi==2021.10.8 +certifi==2022.12.7 # via requests -cffi==1.15.0 +cffi==1.15.1 # via cryptography -chardet==4.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==3.0.1 # via requests -click==8.1.2 +click==8.1.3 # via # cookiecutter # flytekit # ray -cloudpickle==2.0.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.4 +croniter==1.3.8 # via flytekit -cryptography==36.0.2 - # via secretstorage +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -45,31 +47,31 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -distlib==0.3.4 +distlib==0.3.6 # via virtualenv -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.13 +docstring-parser==0.15 # via flytekit -filelock==3.6.0 +filelock==3.9.0 # via # ray # virtualenv -flyteidl==1.0.1 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0b2 +flytekit==1.2.7 # via flytekitplugins-modin -frozenlist==1.3.0 +frozenlist==1.3.3 # via # aiosignal # ray -fsspec==2022.3.0 +fsspec==2023.1.0 # via # flytekitplugins-modin # modin -googleapis-common-protos==1.56.0 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status @@ -83,29 +85,33 @@ grpcio-status==1.43.0 # via # flytekit # flytekitplugins-modin -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.11.3 +importlib-metadata==6.0.0 + # via + # flytekit + # keyring +jaraco-classes==3.2.3 # via keyring -importlib-resources==5.7.1 - # via jsonschema jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.1.1 +jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -jsonschema==4.4.0 +joblib==1.2.0 + # via flytekit +jsonschema==4.17.3 # via ray -keyring==23.5.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.15.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -114,33 +120,36 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -modin==0.14.0 +modin==0.18.1 # via flytekitplugins-modin -msgpack==1.0.3 +more-itertools==9.0.0 + # via jaraco-classes +msgpack==1.0.4 # via ray -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # 
via flytekit -numpy==1.22.3 +numpy==1.23.5 # via + # flytekit # modin # pandas # pyarrow # ray -packaging==21.3 +packaging==23.0 # via + # docker # marshmallow # modin -pandas==1.4.1 + # ray +pandas==1.5.3 # via # flytekit # modin -platformdirs==2.5.2 +platformdirs==3.0.0 # via virtualenv -poyo==0.5.0 - # via cookiecutter -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -150,15 +159,17 @@ protobuf==3.20.2 # ray protoc-gen-swagger==0.1.0 # via flyteidl +psutil==5.9.4 + # via modin py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyparsing==3.0.8 - # via packaging -pyrsistent==0.18.1 +pyopenssl==23.0.0 + # via flytekit +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via @@ -166,71 +177,73 @@ python-dateutil==2.8.2 # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.1 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas pyyaml==6.0 # via + # cookiecutter # flytekit # ray -ray==1.12.0 +ray==2.2.0 # via flytekitplugins-modin -regex==2022.3.15 +regex==2022.10.31 # via docker-image-py -requests==2.27.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # ray # responses -responses==0.20.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring six==1.16.0 # via - # cookiecutter # grpcio # python-dateutil - # virtualenv sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.2.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # flytekit # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -virtualenv==20.14.1 +virtualenv==20.19.0 # via ray -websocket-client==1.3.2 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit -wrapt==1.14.0 +wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via - # importlib-metadata - # importlib-resources +zipp==3.14.0 + # via importlib-metadata diff --git a/plugins/flytekit-onnx-pytorch/requirements.txt b/plugins/flytekit-onnx-pytorch/requirements.txt index 7a4bdea79b..0660386414 100644 --- a/plugins/flytekit-onnx-pytorch/requirements.txt +++ b/plugins/flytekit-onnx-pytorch/requirements.txt @@ -1,35 +1,39 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-onnxpytorch # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit +coloredlogs==15.0.1 + # via onnxruntime cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -38,45 +42,58 @@ deprecated==1.2.13 # via flytekit 
diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flatbuffers==2.0 +flatbuffers==23.1.21 # via onnxruntime -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-onnxpytorch -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +humanfriendly==10.0 + # via coloredlogs +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via + # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -85,27 +102,45 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mpmath==1.2.1 + # via sympy +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit -numpy==1.23.0 +numpy==1.21.6 # via + # flytekit # onnxruntime # pandas # pyarrow # torchvision -onnxruntime==1.11.1 +nvidia-cublas-cu11==11.10.3.66 + # via + # nvidia-cudnn-cu11 + # torch +nvidia-cuda-nvrtc-cu11==11.7.99 + # via torch +nvidia-cuda-runtime-cu11==11.7.99 + # via torch +nvidia-cudnn-cu11==8.5.0.96 + # via torch +onnxruntime==1.14.0 # via -r requirements.in -packaging==21.3 - # via marshmallow -pandas==1.4.3 +packaging==23.0 + # via + # docker + # marshmallow + # onnxruntime +pandas==1.3.5 # via flytekit -pillow==9.2.0 +pillow==9.4.0 # via # -r requirements.in # torchvision -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -117,27 +152,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -145,19 +178,23 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses # torchvision -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.3 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via # grpcio @@ -166,34 +203,52 @@ sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit +sympy==1.10.1 + # via onnxruntime text-unidecode==1.3 # via python-slugify -torch==1.12.0 +toml==0.10.2 + # via responses +torch==1.13.1 # via # flytekitplugins-onnxpytorch # torchvision -torchvision==0.13.0 +torchvision==0.14.1 # via -r requirements.in -typing-extensions==4.3.0 +types-toml==0.10.8.5 + # via responses 
+typing-extensions==4.5.0 # via + # arrow # flytekit + # importlib-metadata + # responses # torch # torchvision # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 - # via flytekit +wheel==0.38.4 + # via + # flytekit + # nvidia-cublas-cu11 + # nvidia-cuda-runtime-cu11 wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.txt b/plugins/flytekit-onnx-scikitlearn/requirements.txt index 86c474c8b4..b2171351b6 100644 --- a/plugins/flytekit-onnx-scikitlearn/requirements.txt +++ b/plugins/flytekit-onnx-scikitlearn/requirements.txt @@ -1,35 +1,39 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-onnxscikitlearn # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit +coloredlogs==15.0.1 + # via onnxruntime cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -38,47 +42,60 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flatbuffers==2.0 +flatbuffers==23.1.21 # via onnxruntime -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-onnxscikitlearn -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +humanfriendly==10.0 + # via coloredlogs +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via + # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 - # via scikit-learn -keyring==23.6.0 +joblib==1.2.0 + # via + # flytekit + # scikit-learn +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -87,12 +104,17 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mpmath==1.2.1 + # via sympy +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # 
via flytekit -numpy==1.23.0 +numpy==1.21.6 # via + # flytekit # onnx # onnxconverter-common # onnxruntime @@ -101,19 +123,23 @@ numpy==1.23.0 # scikit-learn # scipy # skl2onnx -onnx==1.12.0 +onnx==1.13.1 # via # onnxconverter-common # skl2onnx -onnxconverter-common==1.9.0 +onnxconverter-common==1.13.0 # via skl2onnx -onnxruntime==1.11.1 +onnxruntime==1.14.0 # via -r requirements.in -packaging==21.3 - # via marshmallow -pandas==1.4.3 +packaging==23.0 + # via + # docker + # marshmallow + # onnxconverter-common + # onnxruntime +pandas==1.3.5 # via flytekit -protobuf==3.19.5 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -128,27 +154,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -156,57 +180,73 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -scikit-learn==1.1.1 +scikit-learn==1.0.2 # via skl2onnx -scipy==1.8.1 +scipy==1.7.3 # via # scikit-learn # skl2onnx +secretstorage==3.3.3 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via # grpcio # python-dateutil -skl2onnx==1.11.2 +skl2onnx==1.13 # via flytekitplugins-onnxscikitlearn sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit +sympy==1.10.1 + # via onnxruntime text-unidecode==1.3 # via python-slugify threadpoolctl==3.1.0 # via scikit-learn -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via + # arrow # flytekit + # importlib-metadata # onnx + # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-onnx-tensorflow/requirements.txt b/plugins/flytekit-onnx-tensorflow/requirements.txt index b25b24fe1e..33dddb5b25 100644 --- a/plugins/flytekit-onnx-tensorflow/requirements.txt +++ b/plugins/flytekit-onnx-tensorflow/requirements.txt @@ -57,9 +57,9 @@ flatbuffers==1.12 # onnxruntime # tensorflow # tf2onnx -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-onnxtensorflow gast==0.4.0 # via tensorflow diff --git a/plugins/flytekit-pandera/requirements.txt b/plugins/flytekit-pandera/requirements.txt index fe8ed5c840..f6c56bf14f 100644 --- a/plugins/flytekit-pandera/requirements.txt +++ b/plugins/flytekit-pandera/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e 
file:.#egg=flytekitplugins-pandera # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-pandera -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,9 +96,11 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via @@ -100,17 +108,18 @@ numpy==1.21.6 # pandas # pandera # pyarrow -packaging==21.3 +packaging==23.0 # via + # docker # marshmallow # pandera pandas==1.3.5 # via # flytekit # pandera -pandera==0.9.0 +pandera==0.13.4 # via flytekitplugins-pandera -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -121,31 +130,27 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 - # via - # flytekit - # pandera +pyarrow==10.0.1 + # via flytekit pycparser==2.21 # via cffi -pydantic==1.9.1 +pydantic==1.10.5 # via pandera -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -153,19 +158,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -179,7 +184,11 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses 
+typing-extensions==4.5.0 # via # arrow # flytekit @@ -188,23 +197,26 @@ typing-extensions==4.3.0 # pydantic # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via # dataclasses-json # pandera -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit # pandera -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index a16d80d781..cc9b26c4fa 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -48,6 +48,8 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]: def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: return transform2(df=transform1(df=invalid_df)) + invalid_wf() + # raise error when executing workflow with invalid input @workflow def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index ed26284a17..5285174ac3 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,3 +1,3 @@ -flyteidl>=1.0.0 +flyteidl==1.2.9 -e file:../../.#egg=flytekitplugins-pod&subdirectory=plugins/flytekit-k8s-pod -e file:../../.#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 2a42c36a96..0dfc2d57bb 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile dev-requirements.in @@ -34,8 +34,10 @@ cookiecutter==2.1.1 # via flytekit croniter==1.2.0 # via flytekit -cryptography==37.0.4 - # via secretstorage +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.6 # via flytekit decorator==5.1.1 @@ -50,11 +52,11 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.13 # via flytekit -flyteidl==1.0.0.post1 +flyteidl==1.2.9 # via # -r dev-requirements.in # flytekit -flytekit==1.1.0b0 +flytekit==1.2.7 # via # flytekitplugins-pod # flytekitplugins-spark @@ -73,7 +75,9 @@ grpcio-status==1.47.0 idna==3.3 # via requests importlib-metadata==4.10.1 - # via keyring + # via + # flytekit + # keyring jeepney==0.8.0 # via # keyring @@ -84,6 +88,8 @@ jinja2==3.0.3 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +joblib==1.2.0 + # via flytekit keyring==23.5.0 # via flytekit kubernetes==24.2.0 @@ -135,6 +141,8 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi +pyopenssl==23.0.0 + # via flytekit pyspark==3.3.0 # via flytekitplugins-spark python-dateutil==2.8.1 @@ -179,8 +187,6 @@ rsa==4.9 # via google-auth secretstorage==3.3.3 # via keyring -singledispatchmethod==1.0 - # via flytekit six==1.16.0 # via # google-auth @@ -196,9 +202,7 @@ text-unidecode==1.3 # via python-slugify typing-extensions==4.0.1 # via - # arrow # flytekit - # importlib-metadata # typing-inspect typing-inspect==0.7.1 # via dataclasses-json diff --git 
a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 04f821ccf3..6c160c2690 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]: return self._config_task_instance.get_k8s_pod(settings) def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]: - return self._config_task_instance.get_config(settings) + return {**super().get_config(settings), **self._config_task_instance.get_config(settings)} def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return self._config_task_instance.pre_execute(user_params) diff --git a/plugins/flytekit-papermill/requirements.txt b/plugins/flytekit-papermill/requirements.txt index 125caca902..3964df44fa 100644 --- a/plugins/flytekit-papermill/requirements.txt +++ b/plugins/flytekit-papermill/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # # pip-compile requirements.in # @@ -8,44 +8,44 @@ # via -r requirements.in ansiwrap==0.8.4 # via papermill -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time -attrs==21.4.0 +attrs==22.2.0 # via jsonschema backcall==0.2.0 # via ipython -beautifulsoup4==4.11.1 +beautifulsoup4==4.11.2 # via nbconvert binaryornot==0.4.4 # via cookiecutter -bleach==5.0.1 +bleach==6.0.0 # via nbconvert -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit # papermill -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage dataclasses-json==0.5.7 # via flytekit -debugpy==1.6.0 +debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via @@ -57,48 +57,50 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit entrypoints==0.4 # via # jupyter-client - # nbconvert # papermill -fastjsonschema==2.15.3 +fastjsonschema==2.16.2 # via nbformat -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-papermill -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via - # click # flytekit + # keyring + # nbconvert +importlib-resources==5.12.0 + # via # jsonschema # keyring -importlib-resources==5.8.0 - # via jsonschema -ipykernel==6.15.0 +ipykernel==6.16.2 # via flytekitplugins-papermill ipython==7.34.0 # via ipykernel -jedi==0.18.1 +jaraco-classes==3.2.3 + # via keyring +jedi==0.18.2 # via ipython jeepney==0.8.0 # via @@ -111,26 +113,29 @@ jinja2==3.1.2 # nbconvert jinja2-time==0.2.0 # via cookiecutter -jsonschema==4.6.1 +joblib==1.2.0 + # via flytekit +jsonschema==4.17.3 # via nbformat 
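The ``get_config`` change in the papermill ``task.py`` hunk above merges the base task's config with the wrapped config task's config, and relies on plain dict-unpacking semantics where the later mapping wins on key collisions. A tiny sketch with purely illustrative keys:

```python
# Later unpacking wins on duplicate keys, so the wrapped config task's
# values override the base task's (keys and values below are illustrative only).
base_config = {"runtime_flavor": "python", "retries": "1"}
config_task_config = {"retries": "3"}
merged = {**base_config, **config_task_config}
assert merged == {"runtime_flavor": "python", "retries": "3"}
```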
-jupyter-client==7.3.4 +jupyter-client==7.4.9 # via # ipykernel # nbclient -jupyter-core==4.11.2 +jupyter-core==4.12.0 # via # jupyter-client + # nbclient # nbconvert # nbformat jupyterlab-pygments==0.2.2 # via nbconvert -keyring==23.6.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via # jinja2 # nbconvert -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -139,39 +144,41 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib-inline==0.1.3 +matplotlib-inline==0.1.6 # via # ipykernel # ipython -mistune==2.0.3 +mistune==2.0.5 # via nbconvert -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit -nbclient==0.6.6 +nbclient==0.7.2 # via # nbconvert # papermill -nbconvert==7.0.0rc2 +nbconvert==7.2.9 # via flytekitplugins-papermill -nbformat==5.4.0 +nbformat==5.7.3 # via # nbclient # nbconvert # papermill -nest-asyncio==1.5.5 +nest-asyncio==1.5.6 # via # ipykernel # jupyter-client - # nbclient numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 +packaging==23.0 # via + # docker # ipykernel # marshmallow # nbconvert @@ -179,7 +186,7 @@ pandas==1.3.5 # via flytekit pandocfilters==1.5.0 # via nbconvert -papermill==2.3.4 +papermill==2.4.0 # via flytekitplugins-papermill parso==0.8.3 # via jedi @@ -187,9 +194,11 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.30 +pkgutil-resolve-name==1.3.10 + # via jsonschema +prompt-toolkit==3.0.37 # via ipython -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -198,25 +207,23 @@ protobuf==3.20.2 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.1 +psutil==5.9.4 # via ipykernel ptyprocess==0.7.0 # via pexpect py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pygments==2.12.0 +pygments==2.14.0 # via # ipython # nbconvert -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging -pyrsistent==0.18.1 +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via @@ -225,13 +232,13 @@ python-dateutil==2.8.2 # flytekit # jupyter-client # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -240,27 +247,25 @@ pyyaml==6.0 # cookiecutter # flytekit # papermill -pyzmq==23.2.0 +pyzmq==25.0.0 # via # ipykernel # jupyter-client -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # papermill # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring -singledispatchmethod==1.0 - # via flytekit six==1.16.0 # via # bleach @@ -268,25 +273,27 @@ six==1.16.0 # python-dateutil sortedcontainers==2.4.0 # via flytekit -soupsieve==2.3.2.post1 +soupsieve==2.4 # via beautifulsoup4 statsd==3.3.0 # via flytekit -tenacity==8.0.1 +tenacity==8.2.1 # via papermill text-unidecode==1.3 # via python-slugify textwrap3==0.9.2 # via ansiwrap -tinycss2==1.1.1 +tinycss2==1.2.1 # via nbconvert +toml==0.10.2 + # via responses tornado==6.2 # via # ipykernel # jupyter-client -tqdm==4.64.0 +tqdm==4.64.1 # via papermill -traitlets==5.3.0 +traitlets==5.9.0 # via # 
ipykernel # ipython @@ -296,36 +303,35 @@ traitlets==5.3.0 # nbclient # nbconvert # nbformat -typing-extensions==4.3.0 +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via - # arrow # flytekit - # importlib-metadata - # jsonschema - # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -wcwidth==0.2.5 +wcwidth==0.2.6 # via prompt-toolkit webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.14.0 # via # importlib-metadata # importlib-resources diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 0dfd0c6516..0b5bf8e577 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -64,11 +64,10 @@ def decode( ) -> pl.DataFrame: local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) - path = f"{local_dir}/00000" if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pl.read_parquet(path, columns=columns) - return pl.read_parquet(path) + return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True) + return pl.read_parquet(local_dir, use_pyarrow=True) StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) diff --git a/plugins/flytekit-polars/requirements.txt b/plugins/flytekit-polars/requirements.txt index 7595ab781c..9deb7d3482 100644 --- a/plugins/flytekit-polars/requirements.txt +++ b/plugins/flytekit-polars/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-polars # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-polars -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring 
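The polars decoder change above (the ``sd_transformers.py`` hunk) stops hard-coding the ``00000`` part file and instead hands the whole directory to polars, letting pyarrow handle multi-part datasets. A rough sketch of the behavior it relies on, using hypothetical paths:

```python
import os

import pandas as pd
import polars as pl

# Write a two-part parquet dataset the way an encoder might (hypothetical path).
os.makedirs("/tmp/sd_demo", exist_ok=True)
pd.DataFrame({"a": [1, 2]}).to_parquet("/tmp/sd_demo/00000", engine="pyarrow")
pd.DataFrame({"a": [3, 4]}).to_parquet("/tmp/sd_demo/00001", engine="pyarrow")

# With use_pyarrow=True, the directory is read as a single dataset, so the
# decoder no longer needs to know how many part files the writer produced.
df = pl.read_parquet("/tmp/sd_demo", use_pyarrow=True)
assert df.shape == (4, 1)
```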
+importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,23 +96,26 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas - # polars # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -polars==0.13.51 +polars==0.16.8 # via flytekitplugins-polars -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -117,27 +126,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -145,19 +152,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -171,7 +178,11 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit @@ -179,20 +190,23 @@ typing-extensions==4.3.0 # polars # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index b991cd5d13..15a195e5d5 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -66,3 +66,16 @@ def test_polars_renderer(): assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( df.describe().transpose(), columns=df.describe().columns ).to_html(index=False) + + +def test_parquet_to_polars(): + data = {"name": ["Alice"], "age": [5]} + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame(data=data) + return StructuredDataset(dataframe=df) + + sd = create_sd() + polars_df = sd.open(pl.DataFrame).all() + assert pl.DataFrame(data).frame_equal(polars_df) diff --git a/plugins/flytekit-ray/requirements.txt b/plugins/flytekit-ray/requirements.txt 
index 3162769907..5f81a1059e 100644 --- a/plugins/flytekit-ray/requirements.txt +++ b/plugins/flytekit-ray/requirements.txt @@ -6,16 +6,31 @@ # -e file:.#egg=flytekitplugins-ray # via -r requirements.in -aiosignal==1.2.0 +aiohttp==3.8.3 + # via + # aiohttp-cors + # ray +aiohttp-cors==0.7.0 # via ray +aiosignal==1.2.0 + # via + # aiohttp + # ray arrow==1.2.2 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp attrs==21.4.0 # via + # aiohttp # jsonschema # ray binaryornot==0.4.4 # via cookiecutter +blessed==1.19.1 + # via gpustat +cachetools==5.2.0 + # via google-auth certifi==2022.5.18.1 # via requests cffi==1.15.0 @@ -23,7 +38,9 @@ cffi==1.15.0 chardet==4.0.0 # via binaryornot charset-normalizer==2.0.12 - # via requests + # via + # aiohttp + # requests click==8.0.4 # via # cookiecutter @@ -31,6 +48,8 @@ click==8.0.4 # ray cloudpickle==2.1.0 # via flytekit +colorful==0.5.5 + # via ray cookiecutter==1.7.3 # via flytekit croniter==1.3.5 @@ -57,20 +76,28 @@ filelock==3.7.1 # via # ray # virtualenv -flyteidl==1.1.10 +flyteidl==1.2.9 # via # flytekit # flytekitplugins-ray -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-ray frozenlist==1.3.0 # via + # aiohttp # aiosignal # ray -googleapis-common-protos==1.56.1 +google-api-core==2.11.0 + # via opencensus +google-auth==2.15.0 + # via google-api-core +googleapis-common-protos==1.57.0 # via # flyteidl + # google-api-core # grpcio-status +gpustat==1.0.0 + # via ray grpcio==1.43.0 # via # flytekit @@ -79,7 +106,9 @@ grpcio==1.43.0 grpcio-status==1.43.0 # via flytekit idna==3.3 - # via requests + # via + # requests + # yarl importlib-metadata==4.11.3 # via # flytekit @@ -107,6 +136,10 @@ marshmallow-jsonschema==0.13.0 # via flytekit msgpack==1.0.4 # via ray +multidict==6.0.3 + # via + # aiohttp + # yarl mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 @@ -116,6 +149,12 @@ numpy==1.21.6 # pandas # pyarrow # ray +nvidia-ml-py==11.495.46 + # via gpustat +opencensus==0.11.0 + # via ray +opencensus-context==0.1.3 + # via opencensus packaging==21.3 # via marshmallow pandas==1.3.5 @@ -124,20 +163,33 @@ platformdirs==2.5.2 # via virtualenv poyo==0.5.0 # via cookiecutter +prometheus-client==0.13.1 + # via ray protobuf==3.20.2 # via # flyteidl # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger # ray protoc-gen-swagger==0.1.0 # via flyteidl +psutil==5.9.4 + # via gpustat py==1.11.0 # via retry +py-spy==0.3.14 + # via ray pyarrow==6.0.1 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi pyopenssl==22.0.0 @@ -166,7 +218,7 @@ pyyaml==6.0 # via # flytekit # ray -ray==1.13.0 +ray[default]==1.13.0 # via flytekitplugins-ray regex==2022.4.24 # via docker-image-py @@ -175,18 +227,26 @@ requests==2.27.1 # cookiecutter # docker # flytekit + # google-api-core # ray # responses responses==0.20.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth six==1.16.0 # via + # blessed # cookiecutter + # google-auth + # gpustat # grpcio # python-dateutil # virtualenv +smart-open==6.2.0 + # via ray sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -206,6 +266,8 @@ urllib3==1.26.9 # responses virtualenv==20.15.1 # via ray +wcwidth==0.2.5 + # via blessed websocket-client==1.3.2 # via docker wheel==0.38.0 @@ -214,5 +276,10 @@ wrapt==1.14.1 # via # deprecated # flytekit +yarl==1.8.2 + # via aiohttp zipp==3.8.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# 
setuptools diff --git a/plugins/flytekit-ray/setup.py b/plugins/flytekit-ray/setup.py index 84fdd15646..a83b0822fc 100644 --- a/plugins/flytekit-ray/setup.py +++ b/plugins/flytekit-ray/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["ray", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.1.10"] +plugin_requires = ["ray[default]", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.1.10"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-snowflake/requirements.txt b/plugins/flytekit-snowflake/requirements.txt index da4bb43c2a..9e44535ee0 100644 --- a/plugins/flytekit-snowflake/requirements.txt +++ b/plugins/flytekit-snowflake/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-snowflake # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-snowflake -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -114,27 +124,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via 
flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -142,19 +150,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,27 +176,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index a012e38d99..672f4a19ad 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -70,7 +70,7 @@ def test_local_exec(): ) assert len(snowflake_task.interface.inputs) == 1 - assert snowflake_task.query_template == "select 1\\n" + assert snowflake_task.query_template == "select 1" assert len(snowflake_task.interface.outputs) == 1 # will not run locally @@ -86,4 +86,4 @@ def test_sql_template(): custom where column = 1""", output_schema_type=FlyteSchema, ) - assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1" + assert snowflake_task.query_template == "select 1 from custom where column = 1" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py index 7e0c0b77e7..e769540aea 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py @@ -20,4 +20,4 @@ from .pyspark_transformers import PySparkPipelineModelTransformer from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler -from .task import Spark, new_spark_session # noqa +from .task import Databricks, Spark, new_spark_session # noqa diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index 53b1620331..28e67ac631 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -1,7 +1,9 @@ import enum -import typing +from typing import Dict, Optional from flyteidl.plugins import spark_pb2 as _spark_task +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common @@ -17,12 +19,15 @@ class SparkType(enum.Enum): class SparkJob(_common.FlyteIdlEntity): def __init__( self, - spark_type, - application_file, - main_class, - spark_conf, - 
hadoop_conf, - executor_path, + spark_type: SparkType, + application_file: str, + main_class: str, + spark_conf: Dict[str, str], + hadoop_conf: Dict[str, str], + executor_path: str, + databricks_conf: Dict[str, Dict[str, Dict]] = {}, + databricks_token: Optional[str] = None, + databricks_instance: Optional[str] = None, ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -30,6 +35,9 @@ def __init__( :param application_file: The main application file to execute. :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. + :param Optional[dict[Text, dict]] databricks_conf: A definition of key-value pairs for Databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit. + :param Optional[str] databricks_token: Databricks access token. + :param Optional[str] databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com. """ self._application_file = application_file self._spark_type = spark_type @@ -37,9 +45,15 @@ def __init__( self._executor_path = executor_path self._spark_conf = spark_conf self._hadoop_conf = hadoop_conf + self._databricks_conf = databricks_conf + self._databricks_token = databricks_token + self._databricks_instance = databricks_instance def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None + self, + new_spark_conf: Optional[Dict[str, str]] = None, + new_hadoop_conf: Optional[Dict[str, str]] = None, + new_databricks_conf: Optional[Dict[str, Dict]] = None, ) -> "SparkJob": if not new_spark_conf: new_spark_conf = self.spark_conf @@ -47,12 +61,18 @@ def with_overrides( if not new_hadoop_conf: new_hadoop_conf = self.hadoop_conf + if not new_databricks_conf: + new_databricks_conf = self.databricks_conf + return SparkJob( spark_type=self.spark_type, application_file=self.application_file, main_class=self.main_class, spark_conf=new_spark_conf, hadoop_conf=new_hadoop_conf, + databricks_conf=new_databricks_conf, + databricks_token=self.databricks_token, + databricks_instance=self.databricks_instance, executor_path=self.executor_path, ) @@ -104,6 +124,31 @@ def hadoop_conf(self): """ return self._hadoop_conf + @property + def databricks_conf(self) -> Dict[str, Dict]: + """ + databricks_conf: Databricks job configuration. + Config structure can be found here: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + :rtype: dict[Text, dict[Text, Text]] + """ + return self._databricks_conf + + @property + def databricks_token(self) -> str: + """ + Databricks access token + :rtype: str + """ + return self._databricks_token + + @property + def databricks_instance(self) -> str: + """ + Domain name of your deployment. Use the form <account>.cloud.databricks.com.
+ :rtype: str + """ + return self._databricks_instance + def to_flyte_idl(self): """ :rtype: flyteidl.plugins.spark_pb2.SparkJob """ @@ -120,6 +165,9 @@ def to_flyte_idl(self): else: raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") + databricks_conf = Struct() + databricks_conf.update(self.databricks_conf) + return _spark_task.SparkJob( applicationType=application_type, mainApplicationFile=self.application_file, @@ -127,6 +175,9 @@ def to_flyte_idl(self): executorPath=self.executor_path, sparkConf=self.spark_conf, hadoopConf=self.hadoop_conf, + databricksConf=databricks_conf, + databricksToken=self.databricks_token, + databricksInstance=self.databricks_instance, ) @classmethod @@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object): main_class=pb2_object.mainClass, hadoop_conf=pb2_object.hadoopConf, executor_path=pb2_object.executorPath, + databricks_conf=json_format.MessageToDict(pb2_object.databricksConf), + databricks_token=pb2_object.databricksToken, + databricks_instance=pb2_object.databricksInstance, ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 46079f40dd..386570be5c 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,6 +1,7 @@ import typing import pandas as pd +import pyspark from pyspark.sql.dataframe import DataFrame from flytekit import FlyteContext @@ -38,7 +39,10 @@ def encode( ) -> literals.StructuredDataset: path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() df = typing.cast(DataFrame, structured_dataset.dataframe) - df.write.mode("overwrite").parquet(path) + ss = pyspark.sql.SparkSession.builder.getOrCreate() + # Avoid generating SUCCESS files + ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false") + df.write.mode("overwrite").parquet(path=path) return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 8428b492ce..7b32e9f28b 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -36,6 +36,23 @@ def __post_init__(self): self.hadoop_conf = {} +@dataclass +class Databricks(Spark): + """ + Use this to configure a Databricks task. Tasks marked with this config will automatically execute + natively on the Databricks platform as a distributed Spark execution. + + Args: + databricks_conf: Databricks job configuration. Config structure can be found here: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. + databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com. + """ + + databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None + databricks_token: Optional[str] = None + databricks_instance: Optional[str] = None + + # This method does not reset the SparkSession since it's a bit hard to handle multiple # Spark sessions in a single application as it's described in: # https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application.
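To make the new Databricks config above concrete, here is a minimal sketch of a task that opts into it, mirroring the shape exercised by test_spark_task.py later in this diff. The cluster settings, token, and instance values are illustrative placeholders, not values defined by this change:

# Sketch only: assumes the Databricks config class added above; cluster
# settings, token, and instance are placeholders, not part of this diff.
import flytekit
from flytekit import task
from flytekitplugins.spark import Databricks

@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1g"},  # plain Spark conf still applies
        databricks_conf={
            "new_cluster": {
                "spark_version": "11.0.x-scala2.12",
                "node_type_id": "r3.xlarge",
                "num_workers": 4,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
        databricks_instance="<account>.cloud.databricks.com",  # placeholder
        databricks_token="<access-token>",  # placeholder
    )
)
def count_partitions(n: int) -> int:
    # The plugin injects a live SparkSession into the task context.
    sess = flytekit.current_context().spark_session
    return sess.sparkContext.parallelize(range(n)).count()

At serialization time, get_custom (next hunk) copies these fields onto the SparkJob model so they reach the backend plugin.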
@@ -100,6 +117,12 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: main_class="", spark_type=SparkType.PYTHON, ) + if isinstance(self.task_config, Databricks): + cfg = typing.cast(Databricks, self.task_config) + job._databricks_conf = cfg.databricks_conf + job._databricks_token = cfg.databricks_token + job._databricks_instance = cfg.databricks_instance + return MessageToDict(job.to_flyte_idl()) def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: @@ -127,3 +150,4 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: # Inject the Spark plugin into flytekits dynamic plugin loading system TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask) +TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask) diff --git a/plugins/flytekit-spark/requirements.txt b/plugins/flytekit-spark/requirements.txt index ebb767cf0c..c30b61fe55 100644 --- a/plugins/flytekit-spark/requirements.txt +++ b/plugins/flytekit-spark/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-spark # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,33 +40,37 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-spark -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -77,11 +81,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 # via flytekit -markupsafe==2.1.1 +keyring==23.13.1 + # via flytekit +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -90,20 +96,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -116,15 +126,13 @@ py==1.11.0 # via retry 
py4j==0.10.9.5 # via pyspark -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging -pyspark==3.3.0 +pyspark==3.3.2 # via flytekitplugins-spark python-dateutil==2.8.2 # via @@ -132,13 +140,13 @@ python-dateutil==2.8.2 # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -146,19 +154,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -172,27 +180,34 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 96081bd789..11935a30af 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -34,4 +34,5 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], scripts=["scripts/flytekit_install_spark3.sh"], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-spark/tests/test_remote_register.py b/plugins/flytekit-spark/tests/test_remote_register.py index 8eaf8a0794..3bb65d09bc 100644 --- a/plugins/flytekit-spark/tests/test_remote_register.py +++ b/plugins/flytekit-spark/tests/test_remote_register.py @@ -21,6 +21,7 @@ def my_python_task(a: str) -> int: mock_client = MagicMock() remote._client = mock_client + remote._client_initialized = True remote.register_task( my_spark, diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 38555fa9b8..7049684a2d 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -2,7 +2,7 @@ import pyspark import pytest from flytekitplugins.spark import Spark -from flytekitplugins.spark.task import new_spark_session +from flytekitplugins.spark.task import Databricks, new_spark_session from pyspark.sql import SparkSession import flytekit @@ -19,6 +19,23 @@ def reset_spark_session() -> None: def test_spark_task(reset_spark_session): + databricks_conf = { + "name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "11.0.x-scala2.12", + "node_type_id": "r3.xlarge", + "aws_attributes": {"availability": "ON_DEMAND"}, + "num_workers": 4, + "docker_image": {"url": "pingsutw/databricks:latest"}, + }, + "timeout_seconds": 3600, + "max_retries": 1, + "spark_python_task": { + "python_file": 
"dbfs:///FileStore/tables/entrypoint-1.py", + "parameters": "ls", + }, + } + @task(task_config=Spark(spark_conf={"spark": "1"})) def my_spark(a: str) -> int: session = flytekit.current_context().spark_session @@ -53,6 +70,28 @@ def my_spark(a: str) -> int: assert ("spark", "1") in configs assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs + databricks_token = "token" + databricks_instance = "account.cloud.databricks.com" + + @task( + task_config=Databricks( + spark_conf={"spark": "2"}, + databricks_conf=databricks_conf, + databricks_instance="account.cloud.databricks.com", + databricks_token="token", + ) + ) + def my_databricks(a: str) -> int: + session = flytekit.current_context().spark_session + assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" + return 10 + + assert my_databricks.task_config is not None + assert my_databricks.task_config.spark_conf == {"spark": "2"} + assert my_databricks.task_config.databricks_conf == databricks_conf + assert my_databricks.task_config.databricks_instance == databricks_instance + assert my_databricks.task_config.databricks_token == databricks_token + def test_new_spark_session(): name = "SessionName" diff --git a/plugins/flytekit-sqlalchemy/requirements.txt b/plugins/flytekit-sqlalchemy/requirements.txt index 4879071049..d79a8ff148 100644 --- a/plugins/flytekit-sqlalchemy/requirements.txt +++ b/plugins/flytekit-sqlalchemy/requirements.txt @@ -1,34 +1,34 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-sqlalchemy # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 +cryptography==39.0.1 # via # pyopenssl # secretstorage @@ -40,36 +40,40 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.8 +flyteidl==1.2.9 # via flytekit -flytekit==1.1.0 +flytekit==1.2.7 # via flytekitplugins-sqlalchemy -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -greenlet==1.1.2 +greenlet==2.0.2 # via sqlalchemy -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring # sqlalchemy +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring jeepney==0.8.0 # via # keyring @@ -80,11 +84,13 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter -keyring==23.6.0 +joblib==1.2.0 + # via flytekit +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -93,20 +99,24 @@ marshmallow-enum==1.5.1 # via dataclasses-json 
marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via # flytekit # pandas # pyarrow -packaging==21.3 - # via marshmallow +packaging==23.0 + # via + # docker + # marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.2 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -117,27 +127,25 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.7.1 # via # flytekit # pandas @@ -145,19 +153,19 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -168,32 +176,41 @@ six==1.16.0 sortedcontainers==2.4.0 # via flytekit sqlalchemy==1.4.39 - # via flytekitplugins-sqlalchemy + # via + # -r requirements.in + # flytekitplugins-sqlalchemy statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via # arrow # flytekit # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.14 # via + # docker # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 6d20027b2a..7537a3a1de 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -70,7 +70,23 @@ def test_task_schema(sql_server): assert df is not None -def test_workflow(sql_server): +@pytest.mark.parametrize( + "query_template", + [ + "select * from tracks limit {{.inputs.limit}}", + """ + select * from tracks + limit {{.inputs.limit}} + """, + """select * from tracks + limit {{.inputs.limit}} + """, + """ + select * from tracks + limit {{.inputs.limit}}""", + ], +) +def test_workflow(sql_server, query_template): @task def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) @@ -84,7 +100,7 @@ def my_task(df: pandas.DataFrame) -> int: sql_task = SQLAlchemyTask( "test", - query_template="select * from tracks limit {{.inputs.limit}}", + query_template=query_template, inputs=kwtypes(limit=int), task_config=SQLAlchemyConfig(uri=sql_server), ) diff --git a/plugins/flytekit-vaex/requirements.txt b/plugins/flytekit-vaex/requirements.txt index 4f8266cc16..9c8a1c7009 100644 --- a/plugins/flytekit-vaex/requirements.txt +++ b/plugins/flytekit-vaex/requirements.txt @@ -1,45 +1,45 @@ # -# This file is 
autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # -# pip-compile --output-file=requirements.txt requirements.in +# pip-compile requirements.in # -e file:.#egg=flytekitplugins-vaex # via -r requirements.in aplus==0.11.0 # via vaex-core -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -blake3==0.3.1 +blake3==0.3.3 # via vaex-core -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.1 # via # dask # flytekit # vaex-core -commonmark==0.9.1 - # via rich cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl -dask==2022.10.0 +cryptography==39.0.1 + # via + # pyopenssl + # secretstorage +dask==2022.2.0 # via vaex-core dataclasses-json==0.5.7 # via flytekit @@ -49,55 +49,66 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==6.0.0 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -filelock==3.8.0 +filelock==3.9.0 # via vaex-core -flyteidl==1.1.12 +flyteidl==1.2.9 # via flytekit -flytekit==1.2.1 +flytekit==1.2.7 # via flytekitplugins-vaex -frozendict==2.3.4 +frozendict==2.3.5 # via vaex-core -fsspec==2022.10.0 +fsspec==2023.1.0 # via dask -future==0.18.2 +future==0.18.3 # via vaex-core -googleapis-common-protos==1.56.4 +googleapis-common-protos==1.58.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==6.0.0 # via + # click # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 +joblib==1.2.0 # via flytekit -keyring==23.8.2 +keyring==23.13.1 # via flytekit locket==1.0.0 # via partd -markupsafe==2.1.1 +markdown-it-py==2.2.0 + # via rich +markupsafe==2.1.2 # via jinja2 -marshmallow==3.17.1 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -106,18 +117,23 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -mypy-extensions==0.4.3 +mdurl==0.1.2 + # via markdown-it-py +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit nest-asyncio==1.5.6 # via vaex-core numpy==1.21.6 # via + # flytekit # pandas # pyarrow # vaex-core -packaging==21.3 +packaging==23.0 # via # dask # docker @@ -128,9 +144,9 @@ pandas==1.3.5 # vaex-core partd==1.3.0 # via dask -progressbar2==4.1.1 +progressbar2==4.2.0 # via vaex-core -protobuf==3.20.1 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -141,65 +157,66 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via # flytekit # vaex-core pycparser==2.21 # via cffi -pydantic==1.10.2 +pydantic==1.10.5 # via vaex-core -pygments==2.13.0 +pygments==2.14.0 # via rich -pyopenssl==22.0.0 +pyopenssl==23.0.0 # via flytekit -pyparsing==3.0.9 - # via packaging 
python-dateutil==2.8.2 # via # arrow # croniter # flytekit # pandas -python-json-logger==2.0.4 +python-json-logger==2.0.7 # via flytekit -python-slugify==6.1.2 +python-slugify==8.0.0 # via cookiecutter -python-utils==3.3.3 +python-utils==3.5.2 # via progressbar2 pytimeparse==1.1.8 # via flytekit -pytz==2022.2.1 +pytz==2022.7.1 # via # flytekit # pandas -pyyaml==5.4.1 +pyyaml==6.0 # via # cookiecutter # dask # flytekit # vaex-core -regex==2022.8.17 +regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit # responses # vaex-core -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit -rich==12.6.0 +rich==13.3.1 # via vaex-core +secretstorage==3.3.3 + # via keyring +singledispatchmethod==1.0 + # via flytekit six==1.16.0 # via # grpcio # python-dateutil # vaex-core - # websocket-client sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -208,18 +225,28 @@ tabulate==0.9.0 # via vaex-core text-unidecode==1.3 # via python-slugify +toml==0.10.2 + # via responses toolz==0.12.0 # via # dask # partd -typing-extensions==4.3.0 +types-toml==0.10.8.5 + # via responses +typing-extensions==4.5.0 # via + # arrow # flytekit + # importlib-metadata + # markdown-it-py # pydantic + # python-utils + # responses + # rich # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit @@ -227,13 +254,15 @@ urllib3==1.26.12 # responses vaex-core==4.13.0 # via flytekitplugins-vaex -websocket-client==0.59.0 +websocket-client==1.5.1 # via docker -wheel==0.38.0 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.1 - # via importlib-metadata +zipp==3.14.0 + # via + # importlib-metadata + # importlib-resources diff --git a/plugins/flytekit-whylogs/requirements.txt b/plugins/flytekit-whylogs/requirements.txt index 7596afc017..8dc1500e09 100644 --- a/plugins/flytekit-whylogs/requirements.txt +++ b/plugins/flytekit-whylogs/requirements.txt @@ -1,66 +1,85 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-whylogs # via -r requirements.in -appnope==0.1.3 - # via ipython -asttokens==2.0.8 - # via stack-data backcall==0.2.0 # via ipython +certifi==2022.12.7 + # via requests +charset-normalizer==3.0.1 + # via requests decorator==5.1.1 # via ipython -executing==1.0.0 - # via stack-data -ipython==8.5.0 +idna==3.4 + # via requests +importlib-metadata==4.2.0 + # via whylogs +ipython==7.34.0 # via whylogs -jedi==0.18.1 +jedi==0.18.2 # via ipython matplotlib-inline==0.1.6 # via ipython -numpy==1.23.3 - # via scipy +numpy==1.21.6 + # via + # scipy + # whylogs parso==0.8.3 # via jedi pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.31 +pillow==9.4.0 + # via whylogs +prompt-toolkit==3.0.37 # via ipython -protobuf==3.20.2 +protobuf==3.20.3 # via # flytekitplugins-whylogs # whylogs ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.2 - # via stack-data pybars3==0.9.7 # via whylogs -pygments==2.13.0 +pygments==2.14.0 # via ipython pymeta3==0.5.1 # via pybars3 -scipy==1.9.1 +python-dateutil==2.8.2 + # via whylabs-client +requests==2.28.2 + # via whylogs +scipy==1.7.3 # via whylogs six==1.16.0 - # via asttokens -stack-data==0.5.0 - # via ipython -traitlets==5.4.0 + # via python-dateutil +traitlets==5.9.0 # via # ipython # matplotlib-inline 
-typing-extensions==4.3.0 - # via whylogs -wcwidth==0.2.5 +typing-extensions==4.5.0 + # via + # importlib-metadata + # whylogs +urllib3==1.26.14 + # via + # requests + # whylabs-client +wcwidth==0.2.6 # via prompt-toolkit -whylogs[viz]==1.1.0 +whylabs-client==0.4.2 + # via whylogs +whylogs[viz]==1.1.27 # via flytekitplugins-whylogs whylogs-sketching==3.4.1.dev3 # via whylogs +zipp==3.14.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 783282f409..1c7dc65df6 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -16,25 +16,29 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.9.24 - # via requests +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -52,11 +56,18 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.5 +flyteidl==1.2.9 # via flytekit -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +google-auth==2.16.0 + # via kubernetes +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status grpcio==1.48.2 # via @@ -66,12 +77,14 @@ grpcio-status==1.48.2 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # click # flytekit # jsonschema # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -88,9 +101,11 @@ joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.11.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +kubernetes==25.3.0 + # via flytekit +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -113,8 +128,11 @@ numpy==1.21.6 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==21.3 # via + # -r requirements.in # docker # marshmallow pandas==1.3.5 @@ -134,27 +152,34 @@ py==1.11.0 # via retry pyarrow==10.0.0 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -163,28 +188,38 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit + # kubernetes regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # google-auth # jsonschema + # kubernetes # 
python-dateutil # websocket-client +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -199,29 +234,34 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit + # kubernetes # requests # responses websocket-client==0.59.0 # via # -r requirements.in # docker + # kubernetes wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements.txt b/requirements.txt index e4d3fbcdc0..15449d437d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # make requirements.txt # @@ -14,25 +14,29 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.9.24 - # via requests +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -50,11 +54,18 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.5 +flyteidl==1.2.9 # via flytekit -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +google-auth==2.16.0 + # via kubernetes +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status grpcio==1.48.2 # via @@ -64,12 +75,14 @@ grpcio-status==1.48.2 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # click # flytekit # jsonschema # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -82,9 +95,11 @@ joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.11.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +kubernetes==25.3.0 + # via flytekit +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -107,6 +122,8 @@ numpy==1.21.6 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==21.3 # via # docker @@ -128,27 +145,34 @@ py==1.11.0 # via retry pyarrow==10.0.0 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -157,26 +181,38 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit + # kubernetes regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # 
cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth +secretstorage==3.3.3 + # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # google-auth # jsonschema + # kubernetes # python-dateutil # websocket-client +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -191,29 +227,34 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit + # kubernetes # requests # responses websocket-client==0.59.0 # via # -r requirements.in # docker + # kubernetes wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/setup.py b/setup.py index 0ac5dd5fdc..9de6a9f359 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,5 @@ -import sys - from setuptools import find_packages, setup # noqa -MIN_PYTHON_VERSION = (3, 7) -CURRENT_PYTHON = sys.version_info[:2] -if CURRENT_PYTHON < MIN_PYTHON_VERSION: - print( - f"Flytekit API is only supported for Python version is {MIN_PYTHON_VERSION}+. Detected you are on" - f" version {CURRENT_PYTHON}, installation will not proceed!" - ) - sys.exit(-1) - extras_require = {} __version__ = "0.0.0+develop" @@ -39,7 +28,7 @@ ] }, install_requires=[ - "flyteidl>=1.2.0,<1.3.0", + "flyteidl>=1.2.9,<1.3.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -82,6 +71,8 @@ # TODO: We should remove mentions to the deprecated numpy # aliases. 
More details in https://github.com/flyteorg/flyte/issues/3166 "numpy<1.24.0", + "gitpython", + "kubernetes>=12.0.1", ], extras_require=extras_require, scripts=[ @@ -90,7 +81,7 @@ "flytekit/bin/entrypoint.py", ], license="apache2", - python_requires=">=3.7", + python_requires=">=3.7,<3.11", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 4b48fbcfb9..d5b07fe420 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -121,8 +121,9 @@ discovery_version, deprecated, cache_serializable, + pod_template_name, ) - for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable in product( + for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product( [True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], @@ -131,6 +132,7 @@ ["1.0"], ["deprecated"], [True, False], + ["A", "B"], ) ] diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 71225babcf..adecd293ca 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -14,19 +14,19 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -46,29 +46,36 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.5 +flyteidl==1.3.5 # via flytekit -flytekit==1.2.4 +flytekit==1.3.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in fonttools==4.38.0 # via matplotlib -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status -grpcio==1.48.2 +grpcio==1.51.1 # via # flytekit # grpcio-status -grpcio-status==1.48.2 +grpcio-status==1.51.1 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -85,11 +92,11 @@ joblib==1.2.0 # via # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in # flytekit -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -115,21 +122,20 @@ numpy==1.21.6 # opencv-python # pandas # pyarrow -opencv-python==4.6.0.66 +opencv-python==4.7.0.68 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -packaging==21.3 +packaging==23.0 # via # docker # marshmallow # matplotlib pandas==1.3.5 # via flytekit -pillow==9.3.0 +pillow==9.4.0 # via matplotlib -protobuf==3.20.3 +protobuf==4.21.12 # via # flyteidl - # flytekit # googleapis-common-protos # grpcio-status # protoc-gen-swagger @@ -141,12 +147,10 @@ 
pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 - # via - # matplotlib - # packaging + # via matplotlib python-dateutil==2.8.2 # via # arrow @@ -156,11 +160,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -170,7 +174,7 @@ pyyaml==6.0 # flytekit regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker @@ -185,9 +189,9 @@ secretstorage==3.3.3 singledispatchmethod==1.0 # via flytekit six==1.16.0 - # via - # grpcio - # python-dateutil + # via python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -202,19 +206,20 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # kiwisolver # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit # requests # responses -websocket-client==1.4.2 +websocket-client==1.5.0 # via docker wheel==0.38.4 # via @@ -224,5 +229,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index dd021eb3be..78a828a507 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -8,9 +8,8 @@ import joblib import pytest -from flytekit import kwtypes +from flytekit import LaunchPlan, kwtypes from flytekit.configuration import Config -from flytekit.core.launch_plan import LaunchPlan from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.remote.remote import FlyteRemote @@ -221,6 +220,7 @@ def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register): flyte_task = remote.fetch_task(name="workflows.basic.dict_str_wf.convert_to_string", version=f"v{VERSION}") d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} execution = remote.execute(flyte_task, {"d": d}, wait=True) + remote.sync_execution(execution, sync_nodes=True) assert json.loads(execution.outputs["o0"]) == {"key1": "value1", "key2": "value2"} diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index c8d1d93cbd..6dd1785585 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,7 +1,9 @@ +import os import typing from collections import OrderedDict import mock +import pytest from flyteidl.core.errors_pb2 import ErrorDocument from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution @@ -108,6 +110,37 @@ def verify_output(*args, **kwargs): assert mock_write_to_file.call_count == 1 +@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"}) +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, 
mock_load_proto): + mock_get_data.return_value = True + mock_upload_dir.return_value = True + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + python_task = mock.MagicMock() + python_task.dispatch_execute.side_effect = Exception("random") + + empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl() + mock_load_proto.return_value = empty_literal_map + + def verify_output(*args, **kwargs): + assert isinstance(args[0], ErrorDocument) + + mock_write_to_file.side_effect = verify_output + + with pytest.raises(SystemExit) as cm: + _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix") + assert cm.value.code == 1 + + # This function collects outputs instead of writing them to a file. # See flytekit.core.utils.write_proto_to_file for the original def get_output_collector(results: OrderedDict): diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py new file mode 100644 index 0000000000..8389295af2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -0,0 +1,46 @@ +from datetime import datetime, timedelta + +import click +import pytest +from click.testing import CliRunner +from mock import mock + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.backfill import resolve_backfill_window +from flytekit.remote import FlyteRemote + + +def test_resolve_backfill_window(): + dt = datetime(2022, 12, 1, 8) + window = timedelta(days=10) + assert resolve_backfill_window(None, dt + window, window) == (dt, dt + window) + assert resolve_backfill_window(dt, None, window) == (dt, dt + window) + assert resolve_backfill_window(dt, dt + window) == (dt, dt + window) + with pytest.raises(click.BadParameter): + resolve_backfill_window() + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +def test_pyflyte_backfill(mock_remote): + mock_remote.generate_console_url.return_value = "ex" + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke( + pyflyte.main, + [ + "backfill", + "--parallel", + "-p", + "flytesnacks", + "-d", + "development", + "--from-date", + "now", + "--backfill-window", + "5 day", + "daily", + "--dry-run", + ], + ) + assert result.exit_code == 0 + assert "Execution launched" in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 4951d4be46..a6c0bb91d8 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,6 +8,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.configuration import Config from flytekit.core import context_manager from flytekit.remote.remote import FlyteRemote @@ -34,6 +35,7 @@ def test_saving_remote(mock_remote): mock_context.obj = {} get_and_save_remote_with_click_context(mock_context, "p", "d") assert mock_context.obj["flyte_remote"] is not None + mock_remote.assert_called_once_with(Config.for_sandbox(), default_project="p", default_domain="d") def test_register_with_no_package_or_module_argument(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py
b/tests/flytekit/unit/cli/pyflyte/test_run.py index d5db7296b9..b211153f44 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -2,6 +2,7 @@ import os import pathlib import typing +from datetime import datetime, timedelta from enum import Enum import click @@ -16,6 +17,8 @@ from flytekit.clis.sdk_in_container.run import ( REMOTE_FLAG_KEY, RUN_LEVEL_PARAMS_KEY, + DateTimeType, + DurationParamType, FileParamType, FlyteLiteralConverter, get_entities_in_file, @@ -32,12 +35,21 @@ DIR_NAME = os.path.dirname(os.path.realpath(__file__)) -def test_pyflyte_run_wf(): - runner = CliRunner() - module_path = WORKFLOW_FILE - result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) +@pytest.fixture +def remote(): + with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client: + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client = mock_client + return flyte_remote - assert result.exit_code == 0 + +def test_pyflyte_run_wf(remote): + with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"): + runner = CliRunner() + module_path = WORKFLOW_FILE + result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) + + assert result.exit_code == 0 def test_imperative_wf(): @@ -330,3 +342,22 @@ def test_enum_converter(): assert union_lt.stored_type.simple is None assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] + + +def test_duration_type(): + t = DurationParamType() + assert t.convert(value="1 day", param=None, ctx=None) == timedelta(days=1) + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + +def test_datetime_type(): + t = DateTimeType() + + assert t.convert("2020-01-01", None, None) == datetime(2020, 1, 1) + + now = datetime.now() + v = t.convert("now", None, None) + assert v.day == now.day + assert v.month == now.month diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 71d4e29a7c..85438eb00d 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -55,8 +55,9 @@ def print_all( j: datetime.timedelta, k: Color, l: dict, + m: dict, ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}") @task @@ -85,9 +86,10 @@ def my_wf( l: dict, remote: pd.DataFrame, image: StructuredDataset, + m: dict = {"hello": "world"}, ) -> Annotated[StructuredDataset, subset_cols]: x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m) return x diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index b3f1807b96..10a7e09333 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock: return auth_stub_mock +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") 
@mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True @@ -73,6 +74,7 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.get_basic_authorization_header") @mock.patch("flytekit.clients.raw.get_token") @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic( mock_get_token, mock_get_basic_header, mock_dataproxy, + mock_signal, ): mock_secure_channel.return_value = True mock_channel.return_value = True @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic( assert client._metadata[0][0] == "authorization" +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True diff --git a/tests/flytekit/unit/configuration/configs/good.config b/tests/flytekit/unit/configuration/configs/good.config index 56bb837b00..06c2579d42 100644 --- a/tests/flytekit/unit/configuration/configs/good.config +++ b/tests/flytekit/unit/configuration/configs/good.config @@ -7,8 +7,8 @@ assumable_iam_role=some_role [platform] - url=fakeflyte.com +insecure=false [madeup] diff --git a/tests/flytekit/unit/configuration/configs/nossl.yaml b/tests/flytekit/unit/configuration/configs/nossl.yaml new file mode 100644 index 0000000000..f7acdde5a5 --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/nossl.yaml @@ -0,0 +1,4 @@ +admin: + endpoint: dns:///flyte.mycorp.io + authType: Pkce + insecure: false diff --git a/tests/flytekit/unit/configuration/test_file.py b/tests/flytekit/unit/configuration/test_file.py index cb10bf42c0..3ce03f9c50 100644 --- a/tests/flytekit/unit/configuration/test_file.py +++ b/tests/flytekit/unit/configuration/test_file.py @@ -7,7 +7,8 @@ from pytimeparse.timeparse import timeparse from flytekit.configuration import ConfigEntry, get_config_file, set_if_exists -from flytekit.configuration.file import LegacyConfigEntry +from flytekit.configuration.file import LegacyConfigEntry, _exists +from flytekit.configuration.internal import Platform def test_set_if_exists(): @@ -21,6 +22,25 @@ def test_set_if_exists(): assert d["k"] == "x" +@pytest.mark.parametrize( + "data, expected", + [ + [1, True], + [1.0, True], + ["foo", True], + [True, True], + [False, True], + [[1], True], + [{"k": "v"}, True], + [None, False], + [[], False], + [{}, False], + ], +) +def test_exists(data, expected): + assert _exists(data) is expected + + def 
test_get_config_file(): c = get_config_file(None) assert c is None @@ -118,3 +138,9 @@ def test_env_var_bool_transformer(mock_file_read): # The last read should've triggered the file read since now the env var is no longer set. assert mock_file_read.call_count == 1 + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/configuration/test_image_config.py b/tests/flytekit/unit/configuration/test_image_config.py index be59b883af..84c767f8fb 100644 --- a/tests/flytekit/unit/configuration/test_image_config.py +++ b/tests/flytekit/unit/configuration/test_image_config.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize( "python_version_enum, expected_image_string", [ - (PythonVersion.PYTHON_3_7, "ghcr.io/flyteorg/flytekit:py3.7-latest"), - (PythonVersion.PYTHON_3_8, "ghcr.io/flyteorg/flytekit:py3.8-latest"), - (PythonVersion.PYTHON_3_9, "ghcr.io/flyteorg/flytekit:py3.9-latest"), - (PythonVersion.PYTHON_3_10, "ghcr.io/flyteorg/flytekit:py3.10-latest"), + (PythonVersion.PYTHON_3_7, "cr.flyte.org/flyteorg/flytekit:py3.7-latest"), + (PythonVersion.PYTHON_3_8, "cr.flyte.org/flyteorg/flytekit:py3.8-latest"), + (PythonVersion.PYTHON_3_9, "cr.flyte.org/flyteorg/flytekit:py3.9-latest"), + (PythonVersion.PYTHON_3_10, "cr.flyte.org/flyteorg/flytekit:py3.10-latest"), ], ) def test_defaults(python_version_enum, expected_image_string): @@ -24,8 +24,8 @@ def test_defaults(python_version_enum, expected_image_string): @pytest.mark.parametrize( "python_version_enum, flytekit_version, expected_image_string", [ - (PythonVersion.PYTHON_3_7, "v0.32.0", "ghcr.io/flyteorg/flytekit:py3.7-0.32.0"), - (PythonVersion.PYTHON_3_8, "1.31.3", "ghcr.io/flyteorg/flytekit:py3.8-1.31.3"), + (PythonVersion.PYTHON_3_7, "v0.32.0", "cr.flyte.org/flyteorg/flytekit:py3.7-0.32.0"), + (PythonVersion.PYTHON_3_8, "1.31.3", "cr.flyte.org/flyteorg/flytekit:py3.8-1.31.3"), ], ) def test_set_both(python_version_enum, flytekit_version, expected_image_string): @@ -36,7 +36,7 @@ def test_image_config_auto(): x = ImageConfig.auto_default_image() assert x.images[0].name == "default" version_str = f"{sys.version_info.major}.{sys.version_info.minor}" - assert x.images[0].full == f"ghcr.io/flyteorg/flytekit:py{version_str}-latest" + assert x.images[0].full == f"cr.flyte.org/flyteorg/flytekit:py{version_str}-latest" def test_image_from_flytectl_config(): @@ -56,7 +56,7 @@ def test_not_version(mock_sys): def test_image_create(): with pytest.raises(ValueError): - ImageConfig.create_from("ghcr.io/im/g:latest") + ImageConfig.create_from("cr.flyte.org/im/g:latest") - ic = ImageConfig.from_images("ghcr.io/im/g:latest") - assert ic.default_image.fqn == "ghcr.io/im/g" + ic = ImageConfig.from_images("cr.flyte.org/im/g:latest") + assert ic.default_image.fqn == "cr.flyte.org/im/g" diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 7f6be53a55..97e30b5612 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -77,3 +77,9 @@ def test_some_int(mocked): res = AWS.RETRIES.read(cfg) assert type(res) is int assert res == 5 + + +def test_default_platform_config_endpoint_insecure(): + platform_config = PlatformConfig() + assert platform_config.endpoint == "localhost:30080" + assert platform_config.insecure is False diff --git 
a/tests/flytekit/unit/configuration/test_yaml_file.py b/tests/flytekit/unit/configuration/test_yaml_file.py index 7e1c3eee98..ba2c61e158 100644 --- a/tests/flytekit/unit/configuration/test_yaml_file.py +++ b/tests/flytekit/unit/configuration/test_yaml_file.py @@ -14,6 +14,7 @@ def test_config_entry_file(): assert c.read() is None cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml")) + assert cfg.yaml_config is not None assert c.read(cfg) == "flyte.mycorp.io" c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist @@ -26,6 +27,7 @@ def test_config_entry_file_normal(): cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/no_images.yaml")) images_dict = Images.get_specified_images(cfg) assert images_dict == {} + assert cfg.yaml_config is not None @mock.patch("flytekit.configuration.file.getenv") @@ -43,6 +45,7 @@ def test_config_entry_file_2(mock_get): cfg = get_config_file(sample_yaml_file_name) assert c.read(cfg) == "flyte.mycorp.io" + assert cfg.yaml_config is not None c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist assert c.read(cfg) is None @@ -67,3 +70,9 @@ def test_real_config(): res = Credentials.SCOPES.read(config_file) assert res == ["all"] + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/nossl.yaml")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index f1dbbbd5ef..2add1b9e7d 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -4,6 +4,7 @@ import flytekit from flytekit.core.checkpointer import SyncCheckpoint +from flytekit.core.local_cache import LocalTaskCache def test_sync_checkpoint_write(tmpdir): @@ -123,5 +124,23 @@ def t1(n: int) -> int: return n + 1 +@flytekit.task(cache=True, cache_version="v0") +def t2(n: int) -> int: + ctx = flytekit.current_context() + cp = ctx.checkpoint + cp.write(bytes(n + 1)) + return n + 1 + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + LocalTaskCache.initialize() + LocalTaskCache.clear() + + def test_checkpoint_task(): assert t1(n=5) == 6 + + +def test_checkpoint_cached_task(): + assert t2(n=5) == 6 diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 3963c77c8d..09140d8cb7 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -198,3 +198,5 @@ def t3(c: Optional[int] = 3) -> Optional[int]: @workflow def wf(): return t3() + + wf() diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index be85918b74..24f051fbf7 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -71,6 +71,22 @@ def multiplier_2(my_input: float) -> float: multiplier_2(my_input=10.0) +def test_condition_else_int(): + @workflow + def multiplier_3(my_input: int) -> float: + return ( + conditional("fractions") + .if_((my_input >= 0) & (my_input < 1.0)) + .then(double(n=my_input)) + .elif_((my_input > 1.0) & (my_input < 10.0)) + .then(square(n=my_input)) + .else_() + .fail("The input must be between 0 and 10") + ) + + assert multiplier_3(my_input=0) == 0 + + def test_condition_sub_workflows(): @task def sum_div_sub(a: int, b: int) -> typing.NamedTuple("Outputs", 
sum=int, div=int, sub=int): @@ -151,12 +167,16 @@ def decompose_unary() -> int: result = return_true() return conditional("test").if_(result).then(success()).else_().then(failed()) + decompose_unary() + with pytest.raises(AssertionError): @workflow def decompose_none() -> int: return conditional("test").if_(None).then(success()).else_().then(failed()) + decompose_none() + with pytest.raises(AssertionError): @workflow @@ -164,6 +184,8 @@ def decompose_is() -> int: result = return_true() return conditional("test").if_(result is True).then(success()).else_().then(failed()) + decompose_is() + @workflow def decompose() -> int: result = return_true() diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 98af80638a..6e68c9d4be 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -207,7 +207,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 376 + assert len(tp) == 388 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index cccf406c71..b9b0ebd3fa 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -34,11 +34,16 @@ def t1(a: int) -> str: a = a + 2 return "fast-" + str(a) + @workflow + def subwf(a: int): + t1(a=a) + @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] for i in range(a): s.append(t1(a=i)) + subwf(a=a) return s @workflow @@ -58,7 +63,7 @@ def my_wf(a: int) -> typing.List[str]: ) as ctx: input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5}) dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map) - assert len(dynamic_job_spec._nodes) == 5 + assert len(dynamic_job_spec._nodes) == 6 assert len(dynamic_job_spec.tasks) == 1 args = " ".join(dynamic_job_spec.tasks[0].container.args) assert args.startswith( diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 318a6b76f3..7ceec809b1 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -95,5 +95,5 @@ def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotati task_spec = get_serializable(OrderedDict(), serialization_settings, t1) variants = task_spec.template.interface.inputs["data"].type.union_type.variants assert variants[0].blob.format == "NumpyArray" - assert variants[1].structured_dataset_type.format == "parquet" + assert variants[1].structured_dataset_type.format == "" assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py new file mode 100644 index 0000000000..c92e1c9e19 --- /dev/null +++ b/tests/flytekit/unit/core/test_gate.py @@ -0,0 +1,325 @@ +import typing +from collections import OrderedDict +from datetime import timedelta +from io import StringIO + +from mock import patch + +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.condition import conditional +from flytekit.core.context_manager import ExecutionState, FlyteContextManager +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.gate import approve, sleep, wait_for_input +from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine 
+from flytekit.core.workflow import workflow +from flytekit.remote.entities import FlyteWorkflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def test_basic_sleep(): + @task + def t1(a: int) -> int: + return a + 5 + + @workflow + def wf_sleep() -> int: + x = sleep(timedelta(seconds=10)) + b = t1(a=5) + x >> b + return b + + wf_spec = get_serializable(OrderedDict(), serialization_settings, wf_sleep) + assert len(wf_spec.template.nodes) == 2 + assert wf_spec.template.nodes[0].gate_node is not None + assert wf_spec.template.nodes[0].gate_node.sleep.duration == timedelta(seconds=10) + assert wf_spec.template.nodes[1].upstream_node_ids == ["n0"] + + +def test_basic_signal(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @workflow + def wf(a: int) -> typing.Tuple[int, int, int]: + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + x >> s1 + s1 >> z + + return y, z, q + + with patch("sys.stdin", StringIO("y\n3\ny\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + res = wf(a=5) + assert res == (9, 10, 15) + assert stdin.read() == "" # all input consumed + + wf_spec = get_serializable(OrderedDict(), serialization_settings, wf) + assert len(wf_spec.template.nodes) == 7 + # The first t1 call + assert wf_spec.template.nodes[0].task_node is not None + + # The first signal s1, dependent on the first t1 call + assert wf_spec.template.nodes[1].upstream_node_ids == ["n0"] + assert wf_spec.template.nodes[1].gate_node is not None + assert wf_spec.template.nodes[1].gate_node.signal.signal_id == "my-signal-name" + assert wf_spec.template.nodes[1].gate_node.signal.type.simple == 4 + assert wf_spec.template.nodes[1].gate_node.signal.output_variable_name == "o0" + + # The second signal + assert wf_spec.template.nodes[2].upstream_node_ids == [] + assert wf_spec.template.nodes[2].gate_node is not None + assert wf_spec.template.nodes[2].gate_node.signal.signal_id == "my-signal-name-2" + assert wf_spec.template.nodes[2].gate_node.signal.type.simple == 1 + assert wf_spec.template.nodes[2].gate_node.signal.output_variable_name == "o0" + + # The second call to t1, dependent on the first signal + assert wf_spec.template.nodes[3].upstream_node_ids == ["n1"] + assert wf_spec.template.nodes[3].task_node is not None + + # The call to t2, dependent on the second signal + assert wf_spec.template.nodes[4].upstream_node_ids == ["n2"] + assert wf_spec.template.nodes[4].task_node is not None + + # Approval node + assert wf_spec.template.nodes[5].gate_node is not None + assert wf_spec.template.nodes[5].gate_node.approve is not None + assert wf_spec.template.nodes[5].upstream_node_ids == ["n4"] + assert len(wf_spec.template.nodes[5].inputs) == 1 + assert wf_spec.template.nodes[5].inputs[0].binding.promise.node_id == "n4" + assert wf_spec.template.nodes[5].inputs[0].binding.promise.var == "o0" + assert wf_spec.template.nodes[6].inputs[0].binding.promise.node_id == "n5" + assert wf_spec.template.nodes[6].inputs[0].binding.promise.var == "o0" + + assert
wf_spec.template.outputs[0].binding.promise.node_id == "n4" + assert wf_spec.template.outputs[1].binding.promise.node_id == "n3" + assert wf_spec.template.outputs[2].binding.promise.node_id == "n6" + + +def test_dyn_signal(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @dynamic + def dyn(a: int) -> typing.Tuple[int, int, int]: + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + x >> s1 + s1 >> z + + return y, z, q + + @workflow + def wf_dyn(a: int) -> typing.Tuple[int, int, int]: + y, z, q = dyn(a=a) + return y, z, q + + with patch("sys.stdin", StringIO("y\n3\ny\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + res = wf_dyn(a=5) + assert res == (9, 10, 15) + assert stdin.read() == "" # all input consumed + + wf_spec = get_serializable(OrderedDict(), serialization_settings, wf_dyn) + assert len(wf_spec.template.nodes) == 1 + # The first t1 call + assert wf_spec.template.nodes[0].task_node is not None + + with FlyteContextManager.with_context( + FlyteContextManager.current_context().with_serialization_settings(serialization_settings) + ) as ctx: + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + ) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 50}) + dynamic_job_spec = dyn.dispatch_execute(ctx, input_literal_map) + print(dynamic_job_spec) + + assert dynamic_job_spec.nodes[1].upstream_node_ids == ["dn0"] + assert dynamic_job_spec.nodes[1].gate_node is not None + assert dynamic_job_spec.nodes[1].gate_node.signal.signal_id == "my-signal-name" + assert dynamic_job_spec.nodes[1].gate_node.signal.type.simple == 4 + assert dynamic_job_spec.nodes[1].gate_node.signal.output_variable_name == "o0" + + assert dynamic_job_spec.nodes[2].upstream_node_ids == [] + assert dynamic_job_spec.nodes[2].gate_node is not None + assert dynamic_job_spec.nodes[2].gate_node.signal.signal_id == "my-signal-name-2" + assert dynamic_job_spec.nodes[2].gate_node.signal.type.simple == 1 + assert dynamic_job_spec.nodes[2].gate_node.signal.output_variable_name == "o0" + + assert dynamic_job_spec.nodes[5].gate_node is not None + assert dynamic_job_spec.nodes[5].gate_node.approve is not None + assert dynamic_job_spec.nodes[5].upstream_node_ids == ["dn4"] + assert len(dynamic_job_spec.nodes[5].inputs) == 1 + assert dynamic_job_spec.nodes[5].inputs[0].binding.promise.node_id == "dn4" + assert dynamic_job_spec.nodes[5].inputs[0].binding.promise.var == "o0" + assert dynamic_job_spec.nodes[6].inputs[0].binding.promise.node_id == "dn5" + assert dynamic_job_spec.nodes[6].inputs[0].binding.promise.var == "o0" + + +def test_dyn_signal_no_approve(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @dynamic + def dyn(a: int) -> typing.Tuple[int, int]: + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + x >> s1 + s1 >> z + + return y, z + + @workflow + def wf_dyn(a: int) -> typing.Tuple[int, int]: + y, z = dyn(a=a) + return y, z + + with patch("sys.stdin", StringIO("y\n3\n")) as stdin, 
patch("sys.stdout", new_callable=StringIO): + wf_dyn(a=5) + assert stdin.read() == "" # all input consumed + + +def test_subwf(): + nt = typing.NamedTuple("Multi", named1=int, named2=int) + + @task + def nt1(a: int) -> nt: + a = a + 2 + return nt(a, a) + + @workflow + def subwf(a: int) -> nt: + return nt1(a=a) + + @workflow + def parent_wf(b: int) -> nt: + out = subwf(a=b) + return nt1(a=approve(out.named1, "subwf approve", timeout=timedelta(hours=2))) + + with patch("sys.stdin", StringIO("y\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + x = parent_wf(b=3) + assert stdin.read() == "" # all input consumed + assert x == (7, 7) + + +def test_cond(): + @task + def five() -> int: + return 5 + + @task + def square(n: float) -> float: + return n * n + + @task + def double(n: float) -> float: + return 2 * n + + @workflow + def cond_wf() -> float: + f = five() + # Because approve itself produces a node, call approve outside of the conditional. + app = approve(f, "jfdkl", timeout=timedelta(hours=2)) + return conditional("fractions").if_(app == 5).then(double(n=f)).else_().then(square(n=f)) + + with patch("sys.stdin", StringIO("y\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + x = cond_wf() + assert x == 10.0 + assert stdin.read() == "" + + +def test_cond_wait(): + @task + def square(n: float) -> float: + return n * n + + @task + def double(n: float) -> float: + return 2 * n + + @workflow + def cond_wf(a: int) -> float: + # Because approve itself produces a node, call approve outside of the conditional. + input_1 = wait_for_input("top-input", timeout=timedelta(hours=1), expected_type=int) + return conditional("fractions").if_(input_1 >= 5).then(double(n=a)).else_().then(square(n=a)) + + with patch("sys.stdin", StringIO("3\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + x = cond_wf(a=3) + assert x == 9 + assert stdin.read() == "" + + with patch("sys.stdin", StringIO("8\n")) as stdin, patch("sys.stdout", new_callable=StringIO): + x = cond_wf(a=3) + assert x == 6 + assert stdin.read() == "" + + +def test_promote(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @workflow + def wf(a: int) -> typing.Tuple[int, int, int]: + zzz = sleep(timedelta(seconds=10)) + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + zzz >> x + x >> s1 + s1 >> z + + return y, z, q + + entries = OrderedDict() + wf_spec = get_serializable(entries, serialization_settings, wf) + tts, wf_specs, lp_specs = gather_dependent_entities(entries) + + fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts) + assert fwf.template.nodes[2].gate_node is not None diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index ab90d991b1..db4b32f6a9 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -16,7 +16,6 @@ from flytekit.tools.translator import get_serializable from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema -from flytekit.types.structured.structured_dataset import StructuredDatasetType default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -373,6 +372,4 @@ def ref_t2( assert 
len(wf_spec.template.interface.outputs) == 1 assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type is not None - assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type == StructuredDatasetType( - format="parquet" - ) + assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type.format == "" diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 442851a8a2..db05de0ddb 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,6 +4,7 @@ from typing_extensions import Annotated # type: ignore +from flytekit import task from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -320,3 +321,20 @@ def z(a: Foo) -> Foo: assert params.parameters["a"].default is None assert our_interface.outputs["o0"].__origin__ == FlytePickle assert our_interface.inputs["a"].__origin__ == FlytePickle + + +def test_doc_string(): + @task + def t1(a: int) -> int: + """Set the temperature value. + + The value of the temp parameter is stored as a value in + the class variable temperature. + """ + return a + + assert t1.docs.short_description == "Set the temperature value." + assert ( + t1.docs.long_description.value + == "The value of the temp parameter is stored as a value in\nthe class variable temperature." + ) diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 14c9620ae6..95927873d0 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -159,6 +159,8 @@ def wf1(a: int): def wf2(a: typing.List[int]): return map_task(wf1)(a=a) + wf2() + lp = LaunchPlan.create("test", wf1) with pytest.raises(ValueError): @@ -167,6 +169,8 @@ def wf2(a: typing.List[int]): def wf3(a: typing.List[int]): return map_task(lp)(a=a) + wf3() + def test_inputs_outputs_length(): @task diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 47c8af9830..48d3020e88 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -1,6 +1,7 @@ import datetime import typing from collections import OrderedDict +from dataclasses import dataclass import pytest @@ -95,6 +96,8 @@ def empty_wf2(): def empty_wf2(): create_node(t2, "foo") + empty_wf2() + def test_more_normal_task(): nt = typing.NamedTuple("OneOutput", t1_str_output=str) @@ -144,6 +147,8 @@ def my_wf(a: int) -> str: t1_node = create_node(t1, a=a) return t1_node.outputs + my_wf() + def test_runs_before(): @task @@ -333,6 +338,8 @@ def t1(a: str) -> str: def my_wf(a: str) -> str: return t1(a=a).with_overrides(timeout="foo") + my_wf() + @pytest.mark.parametrize( "retries,expected", @@ -424,3 +431,27 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + + +def test_config_override(): + @dataclass + class DummyConfig: + name: str + + @task(task_config=DummyConfig(name="hello")) + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=DummyConfig("flyte")) + + assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte" + + with pytest.raises(ValueError): + + @workflow + def my_wf(a: str) -> str: + 
return t1(a=a).with_overrides(task_config=None) + + my_wf() diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 23b3de4573..d8b043116e 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -4,7 +4,7 @@ import pytest from dataclasses_json import dataclass_json -from flytekit import task +from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager from flytekit.core.context_manager import CompilationState from flytekit.core.promise import ( @@ -64,6 +64,32 @@ def t2(a: int) -> int: assert len(p.ref.node.bindings) == 1 + +def test_create_and_link_node_from_remote_ignore(): + @workflow + def wf(i: int, j: int): + ... + + lp = LaunchPlan.get_or_create(wf, name="promise-test", fixed_inputs={"i": 1}, default_inputs={"j": 10}) + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + + # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, + # which is incorrect + with pytest.raises(FlyteAssertion, match="Missing input `i` type `simple: INTEGER"): + create_and_link_node_from_remote(ctx, lp) + + # Even if j is not provided it will default + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}) + + # value of `i` cannot be overridden + with pytest.raises( + FlyteAssertion, match="ixed inputs cannot be specified. Please remove the following inputs - {'i'}" + ): + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}, i=15) + + # It is ok to override `j` which is a default input + create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}, j=15) + + @pytest.mark.parametrize( "input", [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 108b42ebf7..24856432b4 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,9 +1,14 @@ from typing import Any import pytest +from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1Volume from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.base_task import TaskMetadata +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image +from flytekit.core.resources import Resources +from flytekit.tools.translator import get_serializable_task @pytest.fixture @@ -36,7 +41,6 @@ def execute(self, **kwargs) -> Any: task = DummyAutoContainerTask(name="x", task_config=None, task_type="t") -task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") def test_default_command(default_serialization_settings): @@ -68,14 +72,227 @@ def test_get_container(default_serialization_settings): assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar"} + ts = get_serializable_task(default_serialization_settings, task) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar"} + + +task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") + def
test_get_container_with_task_envvars(default_serialization_settings): c = task_with_env_vars.get_container(default_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar", "HAM": "spam"} + ts = get_serializable_task(default_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} + def test_get_container_without_serialization_settings_envvars(minimal_serialization_settings): c = task_with_env_vars.get_container(minimal_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"HAM": "spam"} + + ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"HAM": "spam"} + + +task_with_pod_template = DummyAutoContainerTask( + name="x", + metadata=TaskMetadata( + pod_template_name="podTemplateB", # should be overwritten + retries=3, # ensure other fields still exists + ), + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + requests=Resources(cpu="3", gpu="1"), + limits=Resources(cpu="6", gpu="2"), + environment={"eKeyA": "eValA", "eKeyB": "vKeyB"}, + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="notPrimary", + ), + V1Container( + name="primary", + image="repo/placeholderImage:0.0.0", + command="placeholderCommand", + args="placeholderArgs", + resources=V1ResourceRequirements(limits={"cpu": "999", "gpu": "999"}), + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + volumes=[V1Volume(name="volume")], + ), + ), + pod_template_name="podTemplateA", +) + + +def test_pod_template(default_serialization_settings): + ################# + # Test get_k8s_pod + ################# + + container = task_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_pod_template.get_k8s_pod(default_serialization_settings) + + # labels/annotations should be passed + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert metadata.annotations == {"aKeyA": "aValA", "aKeyB": "aValB"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][1] + + # To test overwritten attributes + + # image + assert primary_container["image"] == "repo/image:0.0.0" + # command + assert primary_container["command"] == [] + # args + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_pod_template", + ] + # resource + assert primary_container["resources"]["requests"] == {"cpu": "3", "gpu": "1"} + assert primary_container["resources"]["limits"] == {"cpu": "6", "gpu": "2"} + + # To test union attributes + assert primary_container["env"] == [ + {"name": "FOO", "value": "bar"}, + {"name": "eKeyA", "value": "eValA"}, + {"name": 
"eKeyB", "value": "vKeyB"}, + {"name": "eKeyC", "value": "eValC"}, + {"name": "eKeyD", "value": "eValD"}, + ] + + # To test not overwritten attributes + assert pod_spec["volumes"][0] == {"name": "volume"} + + ################# + # Test pod_template_name + ################# + assert task_with_pod_template.metadata.pod_template_name == "podTemplateA" + assert task_with_pod_template.metadata.retries == 3 + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + + ################# + # Test config + ################# + assert config == {"primary_container_name": "primary"} + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + + assert ts.template.metadata.pod_template_name == "podTemplateA" + assert ts.template.metadata.retries.retries == 3 + assert ts.template.config is not None + + +task_with_minimum_pod_template = DummyAutoContainerTask( + name="x", + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + ), + pod_template_name="A", +) + + +def test_minimum_pod_template(default_serialization_settings): + + ################# + # Test get_k8s_pod + ################# + + container = task_with_minimum_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_minimum_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_minimum_pod_template", + ] + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + assert config == {"primary_container_name": "primary"} + + ################# + # Test pod_teamplte_name + ################# + assert task_with_minimum_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" + assert ts.template.config is not None diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 34aaefaeb3..02e04a302f 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -1,10 +1,13 
@@ import pytest +from kubernetes.client.models import V1Container, V1PodSpec from flytekit import task from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import isnested, istestfunction +from flytekit.tools.translator import get_serializable_task from tests.flytekit.unit.core import tasks @@ -122,3 +125,83 @@ def foo_missing_cache_version(i: str): @task(cache_serialize=True) def foo_missing_cache(i: str): print(f"{i}") + + +def test_pod_template(): + @task( + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary", + ), + ] + ), + ), + pod_template_name="A", + ) + def func_with_pod_template(i: str): + print(i + 3) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ################# + # Test get_k8s_pod + ################# + + container = func_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = func_with_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_function_task", + "task-name", + "func_with_pod_template", + ] + + ################# + # Test pod_template_name + ################# + assert func_with_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, func_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index a96a94843b..862c469460 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -440,6 +440,8 @@ def my_wf() -> wf_outputs: # Note only Namedtuples can be created like this return wf_outputs(say_hello(), say_hello()) + my_wf() + def test_serialized_docstrings(): @task diff --git a/tests/flytekit/unit/core/test_signal.py b/tests/flytekit/unit/core/test_signal.py new file mode 100644 index 0000000000..a3bee2e4c7 --- /dev/null +++
b/tests/flytekit/unit/core/test_signal.py @@ -0,0 +1,48 @@ +import pytest +from flyteidl.admin.signal_pb2 import Signal, SignalList +from mock import MagicMock + +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier +from flytekit.remote.remote import FlyteRemote + + +@pytest.fixture +def remote(): + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client_initialized = True + return flyte_remote + + +def test_remote_list_signals(remote): + ctx = FlyteContextManager.current_context() + wfeid = WorkflowExecutionIdentifier("p", "d", "execid") + signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl() + lt = TypeEngine.to_literal_type(int) + signal = Signal( + id=signal_id, + type=lt.to_flyte_idl(), + value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(), + ) + + mock_client = MagicMock() + mock_client.list_signals.return_value = SignalList(signals=[signal], token="") + + remote._client = mock_client + res = remote.list_signals("execid", "p", "d", limit=10) + assert len(res) == 1 + + +def test_remote_set_signal(remote): + mock_client = MagicMock() + + def checker(request): + assert request.id.signal_id == "sigid" + assert request.value.scalar.primitive.integer == 3 + + mock_client.set_signal.side_effect = checker + + remote._client = mock_client + remote.set_signal("sigid", "execid", 3) diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 7793df430f..bfb41d0fef 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -7,11 +7,13 @@ from typing_extensions import Annotated import flytekit.configuration -from flytekit import kwtypes, task from flytekit.configuration import Image, ImageConfig -from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.base_task import kwtypes +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine +from flytekit.core.workflow import workflow from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType @@ -38,6 +40,7 @@ image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")), env={}, ) +df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) def test_protocol(): @@ -49,13 +52,67 @@ def generate_pandas() -> pd.DataFrame: return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]}) +def test_formats_make_sense(): + @task + def t1(a: pd.DataFrame) -> pd.DataFrame: + print(a) + return generate_pandas() + + # this should be an empty string format + assert t1.interface.outputs["o0"].type.structured_dataset_type.format == "" + assert t1.interface.inputs["a"].type.structured_dataset_type.format == "" + + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) + ) + ): + result = t1(a=generate_pandas()) + val = result.val.scalar.value + assert 
val.metadata.structured_dataset_type.format == "parquet" + + +def test_setting_of_unset_formats(): + + custom = Annotated[StructuredDataset, "parquet"] + example = custom(dataframe=df, uri="/path") + # It's okay that the annotation is not used here yet. + assert example.file_format == "" + + @task + def t2(path: str) -> StructuredDataset: + sd = StructuredDataset(dataframe=df, uri=path) + return sd + + @workflow + def wf(path: str) -> StructuredDataset: + return t2(path=path) + + res = wf(path="/tmp/somewhere") + # Now that it's passed through an encoder however, it should be set. + assert res.file_format == "parquet" + + +def test_json(): + sd = StructuredDataset(dataframe=df, uri="/some/path") + sd.file_format = "myformat" + json_str = sd.to_json() + new_sd = StructuredDataset.from_json(json_str) + assert new_sd.file_format == "myformat" + + def test_types_pandas(): pt = pd.DataFrame lt = TypeEngine.to_literal_type(pt) assert lt.structured_dataset_type is not None - assert lt.structured_dataset_type.format == PARQUET + assert lt.structured_dataset_type.format == "" assert lt.structured_dataset_type.columns == [] + pt = Annotated[pd.DataFrame, "csv"] + lt = TypeEngine.to_literal_type(pt) + assert lt.structured_dataset_type.format == "csv" + def test_annotate_extraction(): xyz = Annotated[pd.DataFrame, "myformat"] @@ -68,7 +125,7 @@ def test_annotate_extraction(): a, b, c, d = extract_cols_and_format(pd.DataFrame) assert a is pd.DataFrame assert b is None - assert c is None + assert c == "" assert d is None @@ -115,9 +172,10 @@ def test_types_sd(): def test_retrieving(): assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None - with pytest.raises(ValueError): - # We don't have a default "" format encoder - StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "") + # Asking for a generic means you're okay with any one registered for that type assuming there's just one. 
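+    # A rough sketch of the lookup rule asserted here (hypothetical names, for
+    # illustration only, not the engine's actual code): with
+    #   handlers = ENCODERS[pd.DataFrame]["file"]
+    # an empty format is assumed to fall back to the lone registered encoder:
+    #   if fmt == "" and len(handlers) == 1:
+    #       return next(iter(handlers.values()))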
+ assert StructuredDatasetTransformerEngine.get_encoder( + pd.DataFrame, "file", "" + ) is StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) class TempEncoder(StructuredDatasetEncoder): def __init__(self, protocol): @@ -188,9 +246,10 @@ def encode( ) -> literals.StructuredDataset: return literals.StructuredDataset(uri="") - StructuredDatasetTransformerEngine.register(TempEncoder("myavro"), default_for_type=True) + default_encoder = TempEncoder("myavro") + StructuredDatasetTransformerEngine.register(default_encoder, default_for_type=True) lt = TypeEngine.to_literal_type(MyDF) - assert lt.structured_dataset_type.format == "myavro" + assert lt.structured_dataset_type.format == "" ctx = FlyteContextManager.current_context() fdt = StructuredDatasetTransformerEngine() @@ -228,7 +287,7 @@ def encode( def test_sd(): sd = StructuredDataset(dataframe="hi") sd.uri = "my uri" - assert sd.file_format == PARQUET + assert sd.file_format == "" with pytest.raises(ValueError, match="No dataframe type set"): sd.all() @@ -383,7 +442,7 @@ def encode( assert df_literal_type.structured_dataset_type.format == "avro" sd = annotated_sd_type(df) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Failed to find a handler"): TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type) StructuredDatasetTransformerEngine.register(TempEncoder(), default_for_type=False) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3e813c0fb7..bd270fd360 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,6 +6,7 @@ from datetime import timedelta from enum import Enum +import mock import pandas as pd import pyarrow as pa import pytest @@ -44,7 +45,7 @@ from flytekit.models.annotation import TypeAnnotation from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Void -from flytekit.models.types import LiteralType, SimpleType, TypeStructure +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import TensorboardLogs from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FileExt, JPEGImageFile @@ -569,6 +570,90 @@ def test_dataclass_int_preserving(): assert ot == o +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclass(mock_upload_dir): + mock_upload_dir.return_value = True + + @dataclass_json + @dataclass + class A(object): + a: int + + @dataclass_json + @dataclass + class TestFileStruct(object): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A] = None + i_prime: typing.Optional[A] = A(a=99) + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + 
g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A(a=99) + + def test_flyte_file_in_dataclass(): @dataclass_json @dataclass @@ -664,18 +749,19 @@ class TestFileStruct(object): def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] @dataclass_json @dataclass class InnerDatasetStruct(object): a: StructuredDataset - b: typing.List[StructuredDataset] - c: typing.Dict[str, StructuredDataset] + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] @dataclass_json @dataclass class DatasetStruct(object): - a: StructuredDataset + a: People b: InnerDatasetStruct sd = StructuredDataset(dataframe=df, file_format="parquet") @@ -856,6 +942,18 @@ def test_union_transformer(): assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int +def test_union_guess_type(): + ut = UnionTransformer() + t = ut.guess_python_type( + LiteralType( + union_type=UnionType( + variants=[LiteralType(simple=SimpleType.STRING), LiteralType(simple=SimpleType.INTEGER)] + ) + ) + ) + assert t == typing.Union[str, int] + + def test_union_type_with_annotated(): pt = typing.Union[ Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})] diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index b6d2d77ae5..9da416c1e8 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -176,11 +176,11 @@ def my_wf(a: int, b: str) -> (int, str): d = t2(a=y, b=b) return x, d - assert len(my_wf._nodes) == 2 + assert 
len(my_wf.nodes) == 2 assert my_wf._nodes[0].id == "n0" assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0] - assert len(my_wf._output_bindings) == 2 + assert len(my_wf.output_bindings) == 2 assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" @@ -282,18 +282,24 @@ def test_wf_output_mismatch(): def my_wf(a: int, b: str) -> (int, str): return a + my_wf() + with pytest.raises(AssertionError): @workflow def my_wf2(a: int, b: str) -> int: return a, b # type: ignore + my_wf2() + with pytest.raises(AssertionError): @workflow def my_wf3(a: int, b: str) -> int: return (a,) # type: ignore + my_wf3() + assert context_manager.FlyteContextManager.size() == 1 @@ -678,7 +684,7 @@ def lister() -> typing.List[str]: return s assert len(lister.interface.outputs) == 1 - binding_data = lister._output_bindings[0].binding # the property should be named binding_data + binding_data = lister.output_bindings[0].binding # the property should be named binding_data assert binding_data.collection is not None assert len(binding_data.collection.bindings) == 10 @@ -802,6 +808,8 @@ def my_wf(a: int, b: str) -> (int, str): conditional("test2").if_(x == 4).then(t2(a=b)).elif_(x >= 5).then(t2(a=y)).else_().fail("blah") return x, d + my_wf() + assert context_manager.FlyteContextManager.size() == 1 diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 46389daed2..4f1082df63 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -7,8 +7,9 @@ from typing_extensions import Annotated # type: ignore import flytekit.configuration -from flytekit import StructuredDataset, kwtypes +from flytekit import FlyteContextManager, StructuredDataset, kwtypes from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager from flytekit.core.condition import conditional from flytekit.core.task import task from flytekit.core.workflow import WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, workflow @@ -156,6 +157,8 @@ def no_outputs_wf(): def one_output_wf() -> int: # noqa t1(a=3) + one_output_wf() + def test_wf_no_output(): @task @@ -320,3 +323,18 @@ def test_structured_dataset_wf(): assert_frame_equal(sd_to_schema_wf(), superset_df) assert_frame_equal(schema_to_sd_wf()[0], subset_df) assert_frame_equal(schema_to_sd_wf()[1], subset_df) + + +def test_compile_wf_at_compile_time(): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ): + + @workflow + def wf(): + t4() + + assert ctx.compilation_state is None diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py index 9724a01182..a470b646d4 100644 --- a/tests/flytekit/unit/extras/pytorch/test_transformations.py +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -40,6 +40,7 @@ def test_get_literal_type(transformer, python_type, format): tf = transformer lt = tf.get_literal_type(python_type) assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)) + assert tf.guess_python_type(lt) == python_type @pytest.mark.parametrize( diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index cc0f2deee1..40fc94a3d2 100644 --- 
a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,4 +1,5 @@ import pandas +import pytest from flytekit import kwtypes, task, workflow from flytekit.configuration import DefaultImages @@ -108,3 +109,36 @@ def test_task_serialization(): sql_task._container_image = image tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS) assert tt.container.image == image + + +@pytest.mark.parametrize( + "query_template, expected_query", + [ + ( + """ +select * +from tracks +limit {{.inputs.limit}}""", + "select * from tracks limit {{.inputs.limit}}", + ), + ( + """ \ +select * \ +from tracks \ +limit {{.inputs.limit}}""", + "select * from tracks limit {{.inputs.limit}}", + ), + ("select * from abc", "select * from abc"), + ], +) +def test_query_sanitization(query_template, expected_query): + sql_task = SQLite3Task( + "test", + query_template=query_template, + inputs=kwtypes(limit=int), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), + ) + assert sql_task.query_template == expected_query diff --git a/tests/flytekit/unit/extras/tensorflow/__init__.py b/tests/flytekit/unit/extras/tensorflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/record/__init__.py b/tests/flytekit/unit/extras/tensorflow/record/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/record/test_record.py b/tests/flytekit/unit/extras/tensorflow/record/test_record.py new file mode 100644 index 0000000000..c523475562 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/record/test_record.py @@ -0,0 +1,117 @@ +from typing import Dict, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow.python.data.ops.readers import TFRecordDatasetV2 +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.extras.tensorflow.record import TFRecordDatasetConfig +from flytekit.types.directory import TFRecordsDirectory +from flytekit.types.file import TFRecordFile + +a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"])) +b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0])) +c = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])) +d = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"ham", b"spam"])) +e = tf.train.Feature(float_list=tf.train.FloatList(value=[8.0, 9.0])) +f = tf.train.Feature(int64_list=tf.train.Int64List(value=[22, 23])) +features1 = tf.train.Features(feature=dict(a=a, b=b, c=c)) +features2 = tf.train.Features(feature=dict(a=d, b=e, c=f)) + + +def decode_fn(dataset: TFRecordDatasetV2) -> Dict[str, np.ndarray]: + examples_list = [] + # parse serialised tensors https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file_2 + for batch in list(dataset.as_numpy_iterator()): + example = tf.train.Example() + example.ParseFromString(batch) + examples_list.append(example) + result = {} + for a in examples_list: + # convert example to dict of numpy arrays + # https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file_2 + for key, feature in a.features.feature.items(): + kind = feature.WhichOneof("kind") + val = np.array(getattr(feature, kind).value) + if key not in result.keys(): + result[key] = val + else: + result.update({key: np.concatenate((result[key], val))}) + return result + + +@task +def generate_tf_record_file() -> TFRecordFile: + return tf.train.Example(features=features1) + + 
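+# For orientation, a sketch of what writing these Examples to a TFRecord file by
+# hand might look like (not used by the tests; the TFRecordFile transformer is
+# assumed to do the equivalent internally):
+#   with tf.io.TFRecordWriter("/tmp/example.tfrecord") as writer:
+#       writer.write(tf.train.Example(features=features1).SerializeToString())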
+@task +def generate_tf_record_dir() -> TFRecordsDirectory: + return [tf.train.Example(features=features1), tf.train.Example(features=features2)] + + +@task +def t1( + dataset: Annotated[ + TFRecordFile, + TFRecordDatasetConfig(buffer_size=1024, num_parallel_reads=3, compression_type="GZIP"), + ] +): + assert isinstance(dataset, TFRecordDatasetV2) + assert dataset._compression_type == "GZIP" + assert dataset._buffer_size == 1024 + assert dataset._num_parallel_reads == 3 + + +@task +def t2(dataset: TFRecordFile): + + # if not annotated with TFRecordDatasetConfig, all attributes should default to None + assert isinstance(dataset, TFRecordDatasetV2) + assert dataset._compression_type is None + assert dataset._buffer_size is None + assert dataset._num_parallel_reads is None + + +@task +def t3(dataset: TFRecordsDirectory): + + # if not annotated with TFRecordDatasetConfig, all attributes should default to None + assert isinstance(dataset, TFRecordDatasetV2) + assert dataset._compression_type is None + assert dataset._buffer_size is None + assert dataset._num_parallel_reads is None + + +@task +def t4(dataset: Annotated[TFRecordFile, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]: + return decode_fn(dataset) + + +@task +def t5(dataset: Annotated[TFRecordsDirectory, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]: + return decode_fn(dataset) + + +@workflow +def wf() -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + file = generate_tf_record_file() + files = generate_tf_record_dir() + t1(dataset=file) + t2(dataset=file) + t3(dataset=files) + files_res = t4(dataset=file) + dir_res = t5(dataset=files) + return files_res, dir_res + + +def test_wf(): + file_res, dir_res = wf() + assert np.array_equal(file_res["a"], np.array([b"foo", b"bar"])) + assert np.array_equal(file_res["b"], np.array([1.0, 2.0])) + assert np.array_equal(file_res["c"], np.array([3, 4])) + + assert np.array_equal(np.sort(dir_res["a"]), np.array([b"bar", b"foo", b"ham", b"spam"])) + assert np.array_equal(np.sort(dir_res["b"]), np.array([1.0, 2.0, 8.0, 9.0])) + assert np.array_equal(np.sort(dir_res["c"]), np.array([3, 4, 22, 23])) diff --git a/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py new file mode 100644 index 0000000000..f236696a0c --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py @@ -0,0 +1,89 @@ +import pytest +import tensorflow +import tensorflow as tf +from tensorflow.core.example.example_pb2 import Example +from tensorflow.python.data.ops.readers import TFRecordDatasetV2 +from typing_extensions import Annotated + +import flytekit +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.tensorflow.record import ( + TensorFlowRecordFileTransformer, + TensorFlowRecordsDirTransformer, + TFRecordDatasetConfig, +) +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.types.directory import TFRecordsDirectory +from flytekit.types.file import TFRecordFile + +from .test_record import features1, features2 + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + 
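+# Shape of the round trip exercised below (a sketch mirroring the asserts, not the
+# transformers' internals):
+#   lt = transformer.get_literal_type(TFRecordFile)               # blob literal type
+#   lv = transformer.to_literal(ctx, example, type(example), lt)  # writes a blob
+#   ds = transformer.to_python_value(ctx, lv, TFRecordFile)       # yields a TFRecordDatasetV2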
diff --git a/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py
new file mode 100644
index 0000000000..f236696a0c
--- /dev/null
+++ b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py
@@ -0,0 +1,89 @@
+import pytest
+import tensorflow
+import tensorflow as tf
+from tensorflow.core.example.example_pb2 import Example
+from tensorflow.python.data.ops.readers import TFRecordDatasetV2
+from typing_extensions import Annotated
+
+import flytekit
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core import context_manager
+from flytekit.extras.tensorflow.record import (
+    TensorFlowRecordFileTransformer,
+    TensorFlowRecordsDirTransformer,
+    TFRecordDatasetConfig,
+)
+from flytekit.models.core.types import BlobType
+from flytekit.models.literals import BlobMetadata
+from flytekit.models.types import LiteralType
+from flytekit.types.directory import TFRecordsDirectory
+from flytekit.types.file import TFRecordFile
+
+from .test_record import features1, features2
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,dimensionality",
+    [
+        (TensorFlowRecordFileTransformer(), TFRecordFile, TensorFlowRecordFileTransformer.TENSORFLOW_FORMAT, 0),
+        (TensorFlowRecordsDirTransformer(), TFRecordsDirectory, TensorFlowRecordsDirTransformer.TENSORFLOW_FORMAT, 1),
+    ],
+)
+def test_get_literal_type(transformer, python_type, format, dimensionality):
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=dimensionality))
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,python_val,dimension",
+    [
+        (
+            TensorFlowRecordFileTransformer(),
+            TFRecordFile,
+            TensorFlowRecordFileTransformer.TENSORFLOW_FORMAT,
+            tf.train.Example(features=features1),
+            BlobType.BlobDimensionality.SINGLE,
+        ),
+        (
+            TensorFlowRecordsDirTransformer(),
+            TFRecordsDirectory,
+            TensorFlowRecordsDirTransformer.TENSORFLOW_FORMAT,
+            [tf.train.Example(features=features1), tf.train.Example(features=features2)],
+            BlobType.BlobDimensionality.MULTIPART,
+        ),
+    ],
+)
+def test_to_python_value_and_literal(transformer, python_type, format, python_val, dimension):
+    ctx = context_manager.FlyteContext.current_context()
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
+    assert lv.scalar.blob.metadata == BlobMetadata(
+        type=BlobType(
+            format=format,
+            dimensionality=dimension,
+        )
+    )
+    assert lv.scalar.blob.uri is not None
+    output = tf.to_python_value(ctx, lv, Annotated[python_type, TFRecordDatasetConfig(buffer_size=1024)])
+    assert isinstance(output, TFRecordDatasetV2)
+    results = []
+    example = tensorflow.train.Example()
+    for raw_record in output:
+        example.ParseFromString(raw_record.numpy())
+        results.append(example)
+    if isinstance(python_val, list):
+        assert len(results) == 2
+        assert all(list(map(lambda x: isinstance(x, Example), python_val)))
+    else:
+        assert results == [python_val]
diff --git a/tests/flytekit/unit/models/test_documentation.py b/tests/flytekit/unit/models/test_documentation.py
new file mode 100644
index 0000000000..7702df0452
--- /dev/null
+++ b/tests/flytekit/unit/models/test_documentation.py
@@ -0,0 +1,29 @@
+from flytekit.models.documentation import Description, Documentation, SourceCode
+
+
+def test_long_description():
+    value = "long"
+    icon_link = "http://icon"
+    obj = Description(value=value, icon_link=icon_link)
+    assert Description.from_flyte_idl(obj.to_flyte_idl()) == obj
+    assert obj.value == value
+    assert obj.icon_link == icon_link
+    assert obj.format == Description.DescriptionFormat.RST
+
+
+def test_source_code():
+    link = "https://github.com/flyteorg/flytekit"
+    obj = SourceCode(link=link)
+    assert SourceCode.from_flyte_idl(obj.to_flyte_idl()) == obj
+    assert obj.link == link
+
+
+def test_documentation():
+    short_description = "short"
+    long_description = Description(value="long", icon_link="http://icon")
+    source_code = SourceCode(link="https://github.com/flyteorg/flytekit")
+    obj = Documentation(short_description=short_description, long_description=long_description, source_code=source_code)
+    assert Documentation.from_flyte_idl(obj.to_flyte_idl()) == obj
+    assert obj.short_description == short_description
+    assert obj.long_description == long_description
+    assert obj.source_code == source_code
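`test_documentation.py` exercises the serialization contract that every flytekit model wrapper in these tests follows: `to_flyte_idl()` emits the protobuf message, `from_flyte_idl()` rebuilds the wrapper, and the round trip must be lossless under `==`. A minimal sketch of the property, using a plain dict where the real models use a flyteidl message; `ExampleModel` is hypothetical:

```python
from dataclasses import dataclass


@dataclass
class ExampleModel:
    value: str
    icon_link: str

    def to_flyte_idl(self) -> dict:
        # Real flytekit models return a protobuf message here, not a dict.
        return {"value": self.value, "icon_link": self.icon_link}

    @classmethod
    def from_flyte_idl(cls, pb: dict) -> "ExampleModel":
        return cls(value=pb["value"], icon_link=pb["icon_link"])


obj = ExampleModel(value="long", icon_link="http://icon")
assert ExampleModel.from_flyte_idl(obj.to_flyte_idl()) == obj  # lossless round trip
```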
diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py
index b327d5e9d6..fac5604543
--- a/tests/flytekit/unit/models/test_execution.py
+++ b/tests/flytekit/unit/models/test_execution.py
@@ -28,6 +28,8 @@ def test_execution_closure_with_output():
         started_at=test_datetime,
         duration=test_timedelta,
         outputs=test_outputs,
+        created_at=None,
+        updated_at=test_datetime,
     )
     assert obj.phase == _core_exec.WorkflowExecutionPhase.SUCCEEDED
     assert obj.started_at == test_datetime
@@ -39,6 +41,8 @@
     assert obj2.started_at == test_datetime
     assert obj2.duration == test_timedelta
     assert obj2.outputs == test_outputs
+    assert obj2.created_at is None
+    assert obj2.updated_at == test_datetime
 
 
 def test_execution_closure_with_error():
@@ -53,6 +57,8 @@
         started_at=test_datetime,
         duration=test_timedelta,
         error=test_error,
+        created_at=test_datetime,
+        updated_at=None,
     )
     assert obj.phase == _core_exec.WorkflowExecutionPhase.SUCCEEDED
     assert obj.started_at == test_datetime
@@ -62,6 +68,8 @@
     assert obj2 == obj
     assert obj2.phase == _core_exec.WorkflowExecutionPhase.SUCCEEDED
     assert obj2.started_at == test_datetime
+    assert obj2.created_at == test_datetime
+    assert obj2.updated_at is None
     assert obj2.duration == test_timedelta
     assert obj2.error == test_error
diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py
index fcebf465f9..a979a39b66 100644
--- a/tests/flytekit/unit/models/test_tasks.py
+++ b/tests/flytekit/unit/models/test_tasks.py
@@ -7,6 +7,7 @@
 import flytekit.models.interface as interface_models
 import flytekit.models.literals as literal_models
+from flytekit import Description, Documentation, SourceCode
 from flytekit.models import literals, task, types
 from flytekit.models.core import identifier
 from tests.flytekit.common import parameterizers
@@ -70,6 +71,7 @@ def test_task_metadata():
         "0.1.1b0",
         "This is deprecated!",
         True,
+        "A",
     )
 
     assert obj.discoverable is True
@@ -81,6 +83,7 @@
     assert obj.runtime.version == "1.0.0"
     assert obj.deprecated_error_message == "This is deprecated!"
     assert obj.discovery_version == "0.1.1b0"
+    assert obj.pod_template_name == "A"
     assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
 
 
@@ -123,7 +126,62 @@ def test_task_template(in_tuple):
     assert obj.config == {"a": "b"}
 
 
-def test_task_template__k8s_pod_target():
+def test_task_spec():
+    task_metadata = task.TaskMetadata(
+        True,
+        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
+        timedelta(days=1),
+        literals.RetryStrategy(3),
+        True,
+        "0.1.1b0",
+        "This is deprecated!",
+        True,
+        "A",
+    )
+
+    int_type = types.LiteralType(types.SimpleType.INTEGER)
+    interfaces = interface_models.TypedInterface(
+        {"a": interface_models.Variable(int_type, "description1")},
+        {
+            "b": interface_models.Variable(int_type, "description2"),
+            "c": interface_models.Variable(int_type, "description3"),
+        },
+    )
+
+    resource = [task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1")]
+    resources = task.Resources(resource, resource)
+
+    template = task.TaskTemplate(
+        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
+        "python",
+        task_metadata,
+        interfaces,
+        {"a": 1, "b": {"c": 2, "d": 3}},
+        container=task.Container(
+            "my_image",
+            ["this", "is", "a", "cmd"],
+            ["this", "is", "an", "arg"],
+            resources,
+            {"a": "b"},
+            {"d": "e"},
+        ),
+        config={"a": "b"},
+    )
+
+    short_description = "short"
+    long_description = Description(value="long", icon_link="http://icon")
+    source_code = SourceCode(link="https://github.com/flyteorg/flytekit")
+    docs = Documentation(
+        short_description=short_description, long_description=long_description, source_code=source_code
+    )
+
+    obj = task.TaskSpec(template, docs)
+    assert task.TaskSpec.from_flyte_idl(obj.to_flyte_idl()) == obj
+    assert obj.docs == docs
+    assert obj.template == template
+
+
+def test_task_template_k8s_pod_target():
     int_type = types.LiteralType(types.SimpleType.INTEGER)
     obj = task.TaskTemplate(
         identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
@@ -137,6 +195,7 @@
             "1.0",
             "deprecated",
             False,
+            "A",
         ),
         interface_models.TypedInterface(
             # inputs
diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py
index 3a42f5af81..2b5b06696b 100644
--- a/tests/flytekit/unit/models/test_workflow_closure.py
+++ b/tests/flytekit/unit/models/test_workflow_closure.py
@@ -5,8 +5,10 @@
 from flytekit.models import task as _task
 from flytekit.models import types as _types
 from flytekit.models import workflow_closure as _workflow_closure
+from flytekit.models.admin.workflow import WorkflowSpec
 from flytekit.models.core import identifier as _identifier
 from flytekit.models.core import workflow as _workflow
+from flytekit.models.documentation import Description, Documentation, SourceCode
 
 
 def test_workflow_closure():
@@ -36,6 +38,7 @@
         "0.1.1b0",
         "This is deprecated!",
         True,
+        "A",
     )
 
     cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
@@ -81,3 +84,16 @@
     obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
     assert obj == obj2
+
+    short_description = "short"
+    long_description = Description(value="long", icon_link="http://icon")
+    source_code = SourceCode(link="https://github.com/flyteorg/flytekit")
+    docs = Documentation(
+        short_description=short_description, long_description=long_description, source_code=source_code
+    )
+
+    workflow_spec = WorkflowSpec(template=template, sub_workflows=[], docs=docs)
+    assert WorkflowSpec.from_flyte_idl(workflow_spec.to_flyte_idl()) == workflow_spec
+    assert workflow_spec.docs.short_description == short_description
+    assert workflow_spec.docs.long_description == long_description
+    assert workflow_spec.docs.source_code == source_code
diff --git a/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb b/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb
new file mode 100644
index 0000000000..1f3ce5c79a
Binary files /dev/null and b/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb differ
diff --git a/tests/flytekit/unit/remote/test_backfill.py b/tests/flytekit/unit/remote/test_backfill.py
new file mode 100644
index 0000000000..1d4884115d
--- /dev/null
+++ b/tests/flytekit/unit/remote/test_backfill.py
@@ -0,0 +1,95 @@
+from datetime import datetime, timedelta
+
+import pytest
+
+from flytekit import CronSchedule, FixedRate, LaunchPlan, task, workflow
+from flytekit.remote.backfill import create_backfill_workflow
+
+
+@task
+def tk(t: datetime, v: int):
+    print(f"Invoked at {t} with v {v}")
+
+
+@workflow
+def example_wf(t: datetime, v: int):
+    tk(t=t, v=v)
+
+
+def test_create_backfiller_error():
+    no_schedule = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="nos",
+        fixed_inputs={"v": 10},
+    )
+    rate_schedule = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="rate",
+        fixed_inputs={"v": 10},
+        schedule=FixedRate(duration=timedelta(days=1)),
+    )
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(start_date, end_date, no_schedule)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(end_date, start_date, no_schedule)
+
+    with pytest.raises(ValueError):
+        create_backfill_workflow(end_date, start_date, None)
+
+    with pytest.raises(NotImplementedError):
+        create_backfill_workflow(start_date, end_date, rate_schedule)
+
+
+def test_create_backfiller():
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    wf, start, end = create_backfill_workflow(start_date, end_date, daily_lp)
+    assert isinstance(wf.nodes[0].flyte_entity, LaunchPlan)
+    b0, b1 = wf.nodes[0].bindings[0], wf.nodes[0].bindings[1]
+    assert b0.var == "t"
+    assert b0.binding.scalar.primitive.datetime.day == 2
+    assert b1.var == "v"
+    assert b1.binding.scalar.primitive.integer == 10
+    assert len(wf.nodes) == 9
+    assert len(wf.nodes[0].upstream_nodes) == 0
+    assert len(wf.nodes[1].upstream_nodes) == 1
+    assert wf.nodes[1].upstream_nodes[0] == wf.nodes[0]
+    assert start
+    assert end
+
+
+def test_create_backfiller_parallel():
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    wf, start, end = create_backfill_workflow(start_date, end_date, daily_lp, parallel=True)
+    assert isinstance(wf.nodes[0].flyte_entity, LaunchPlan)
+    b0, b1 = wf.nodes[0].bindings[0], wf.nodes[0].bindings[1]
+    assert b0.var == "t"
+    assert b0.binding.scalar.primitive.datetime.day == 2
+    assert b1.var == "v"
+    assert b1.binding.scalar.primitive.integer == 10
+    assert len(wf.nodes) == 9
+    assert len(wf.nodes[0].upstream_nodes) == 0
+    assert len(wf.nodes[1].upstream_nodes) == 0
+    assert start
+    assert end
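The expected counts in these backfill tests follow from expanding the cron expression over the half-open window between `start_date` and `end_date`: a daily 08:00 schedule from 2022-12-01 08:00 to 2022-12-11 08:00 fires nine times, on Dec 2 through Dec 10, which is why `len(wf.nodes) == 9` and the first kickoff binding lands on day 2. A sketch of that expansion using `croniter` (already pinned in dev-requirements); `backfill_times` is an illustrative helper, not necessarily how `create_backfill_workflow` computes the windows:

```python
from datetime import datetime
from typing import List

from croniter import croniter


def backfill_times(schedule: str, start: datetime, end: datetime) -> List[datetime]:
    # Collect every firing strictly after `start` and strictly before `end`.
    it = croniter(schedule, start)
    times = []
    t = it.get_next(datetime)
    while t < end:
        times.append(t)
        t = it.get_next(datetime)
    return times


runs = backfill_times("0 8 * * *", datetime(2022, 12, 1, 8), datetime(2022, 12, 11, 8))
assert len(runs) == 9    # one launch-plan node per firing
assert runs[0].day == 2  # matches b0.binding.scalar.primitive.datetime.day == 2
```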
diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py
index 00d80464c3..289fba37d7 100644
--- a/tests/flytekit/unit/remote/test_calling.py
+++ b/tests/flytekit/unit/remote/test_calling.py
@@ -12,11 +12,10 @@
 from flytekit.core.type_engine import TypeEngine
 from flytekit.core.workflow import workflow
 from flytekit.exceptions.user import FlyteAssertion
-from flytekit.models.core.workflow import WorkflowTemplate
-from flytekit.models.task import TaskTemplate
-from flytekit.remote import FlyteLaunchPlan, FlyteTask
+from flytekit.models.admin.workflow import WorkflowSpec
+from flytekit.models.task import TaskSpec
+from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow
 from flytekit.remote.interface import TypedInterface
-from flytekit.remote.workflow import FlyteWorkflow
 from flytekit.tools.translator import gather_dependent_entities, get_serializable
 
 default_img = Image(name="default", fqn="test", tag="tag")
@@ -63,7 +62,7 @@ def wf(a: int) -> int:
     serialized = OrderedDict()
     wf_spec = get_serializable(serialized, serialization_settings, wf)
     vals = [v for v in serialized.values()]
-    tts = [f for f in filter(lambda x: isinstance(x, TaskTemplate), vals)]
+    tts = [f for f in filter(lambda x: isinstance(x, TaskSpec), vals)]
     assert len(tts) == 1
     assert wf_spec.template.nodes[0].id == "foobar"
     assert wf_spec.template.outputs[0].binding.promise.node_id == "foobar"
@@ -76,6 +75,8 @@ def test_misnamed():
     def wf(a: int) -> int:
         return ft(b=a)
 
+    wf()
+
 
 def test_calling_lp():
     sub_wf_lp = LaunchPlan.get_or_create(sub_wf)
@@ -143,9 +144,11 @@ def my_subwf(a: int) -> typing.List[int]:
 def test_calling_wf():
     # No way to fetch from Admin in unit tests so we serialize and then promote back
     serialized = OrderedDict()
-    wf_spec = get_serializable(serialized, serialization_settings, sub_wf)
+    wf_spec: WorkflowSpec = get_serializable(serialized, serialization_settings, sub_wf)
     task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
-    fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=task_templates)
+    fwf = FlyteWorkflow.promote_from_model(
+        wf_spec.template, tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}
+    )
 
     @workflow
     def parent_1(a: int, b: str) -> typing.Tuple[int, str]:
@@ -162,8 +165,14 @@ def parent_1(a: int, b: str) -> typing.Tuple[int, str]:
     # Pick out the subworkflow templates from the ordereddict. We can't use the output of the gather_dependent_entities
     # function because that only looks for WorkflowSpecs
-    subwf_templates = {x.id: x for x in list(filter(lambda x: isinstance(x, WorkflowTemplate), serialized.values()))}
-    fwf_p1 = FlyteWorkflow.promote_from_model(wf_spec.template, sub_workflows=subwf_templates, tasks=task_templates_p1)
+    subwf_templates = {
+        x.template.id: x.template for x in list(filter(lambda x: isinstance(x, WorkflowSpec), serialized.values()))
+    }
+    fwf_p1 = FlyteWorkflow.promote_from_model(
+        wf_spec.template,
+        sub_workflows=subwf_templates,
+        tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates_p1.items()},
+    )
 
     @workflow
     def parent_2(a: int, b: str) -> typing.Tuple[int, str]:
diff --git a/tests/flytekit/unit/remote/test_lazy_entity.py b/tests/flytekit/unit/remote/test_lazy_entity.py
new file mode 100644
index 0000000000..5328a2caf0
--- /dev/null
+++ b/tests/flytekit/unit/remote/test_lazy_entity.py
@@ -0,0 +1,78 @@
+import pytest
+from mock import patch
+
+from flytekit import TaskMetadata
+from flytekit.core import context_manager
+from flytekit.models.core.identifier import Identifier, ResourceType
+from flytekit.models.interface import TypedInterface
+from flytekit.remote import FlyteTask
+from flytekit.remote.lazy_entity import LazyEntity
+
+
+def test_missing_getter():
+    with pytest.raises(ValueError):
+        LazyEntity("x", None)
+
+
+dummy_task = FlyteTask(
+    id=Identifier(ResourceType.TASK, "p", "d", "n", "v"),
+    type="t",
+    metadata=TaskMetadata().to_taskmetadata_model(),
+    interface=TypedInterface(inputs={}, outputs={}),
+    custom=None,
+)
+
+
+def test_lazy_loading():
+    once = True
+
+    def _getter():
+        nonlocal once
+        if not once:
+            raise ValueError("Should be called once only")
+        once = False
+        return dummy_task
+
+    e = LazyEntity("x", _getter)
+    assert e.__repr__() == "Promise for entity [x]"
+    assert e.name == "x"
+    assert e._entity is None
+    assert not e.entity_fetched()
+    v = e.entity
+    assert e._entity is not None
+    assert v == dummy_task
+    assert e.entity == dummy_task
+    assert e.entity_fetched()
+
+
+@patch("flytekit.remote.remote_callable.create_and_link_node_from_remote")
+def test_lazy_loading_compile(create_and_link_node_from_remote_mock):
+    once = True
+
+    def _getter():
+        nonlocal once
+        if not once:
+            raise ValueError("Should be called once only")
+        once = False
+        return dummy_task
+
+    e = LazyEntity("x", _getter)
+    assert e.name == "x"
+    assert e._entity is None
+    ctx = context_manager.FlyteContext.current_context()
+    e.compile(ctx)
+    assert e._entity is not None
+    assert e.entity == dummy_task
+
+
+def test_lazy_loading_exception():
+    def _getter():
+        raise AttributeError("Error")
+
+    e = LazyEntity("x", _getter)
+    assert e.name == "x"
+    assert e._entity is None
+    with pytest.raises(RuntimeError) as exc:
+        assert e.blah
+
+    assert isinstance(exc.value.__cause__, AttributeError)
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index dd37b97f87..4b8f82fb7e 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -1,19 +1,29 @@
 import os
 import pathlib
 import tempfile
+from collections import OrderedDict
+from datetime import datetime, timedelta
 
 import pytest
+from flyteidl.core import compiler_pb2 as _compiler_pb2
 from mock import MagicMock, patch
 
 import flytekit.configuration
+from flytekit import CronSchedule, LaunchPlan, task, workflow
 from flytekit.configuration import Config, DefaultImages, ImageConfig
+from flytekit.core.base_task import PythonTask
 from flytekit.exceptions import user as user_exceptions
 from flytekit.models import common as common_models
 from flytekit.models import security
-from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier
+from flytekit.models.admin.workflow import Workflow, WorkflowClosure
+from flytekit.models.core.compiler import CompiledWorkflowClosure
+from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
 from flytekit.models.execution import Execution
+from flytekit.models.task import Task
+from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote import FlyteRemote
-from flytekit.tools.translator import Options
+from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
+from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES
 
 CLIENT_METHODS = {
     ResourceType.WORKFLOW: "list_workflows_paginated",
@@ -34,29 +44,36 @@
 }
 
 
-@patch("flytekit.clients.friendly.SynchronousFlyteClient")
-def test_remote_fetch_execution(mock_client_manager):
+@pytest.fixture
+def remote():
+    with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
+        flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
+        flyte_remote._client_initialized = True
+        flyte_remote._client = mock_client
+        return flyte_remote
+
+
+def test_remote_fetch_execution(remote):
     admin_workflow_execution = Execution(
         id=WorkflowExecutionIdentifier("p1", "d1", "n1"),
         spec=MagicMock(),
         closure=MagicMock(),
     )
 
-    mock_client = MagicMock()
     mock_client.get_execution.return_value = admin_workflow_execution
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     flyte_workflow_execution = remote.fetch_execution(name="n1")
     assert flyte_workflow_execution.id == admin_workflow_execution.id
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
+@pytest.fixture
+def mock_wf_exec():
+    return patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
+
+
+def test_underscore_execute_uses_launch_plan_attributes(remote, mock_wf_exec):
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     def local_assertions(*args, **kwargs):
@@ -83,12 +100,9 @@ def local_assertions(*args, **kwargs):
     )
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
+def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     options = Options(
@@ -114,14 +128,11 @@ def local_assertions(*args, **kwargs):
     )
 
 
-@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
-def test_execute_with_wrong_input_key(mock_wf_exec):
+def test_execute_with_wrong_input_key(remote, mock_wf_exec):
     # mock_url.get.return_value = "localhost"
     # mock_insecure.get.return_value = True
     mock_wf_exec.return_value = True
     mock_client = MagicMock()
-
-    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
     remote._client = mock_client
 
     mock_entity = MagicMock()
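The refactor in these hunks replaces per-test `@patch` decorators and repeated `FlyteRemote(...)` construction with a shared pytest fixture, so every test simply takes `remote` and receives an instance whose client is already a mock. A stripped-down sketch of the pattern; `Remote` here is a hypothetical stand-in, not flytekit's class:

```python
from unittest.mock import MagicMock

import pytest


class Remote:
    # Hypothetical stand-in: holds a client and delegates calls to it.
    def __init__(self) -> None:
        self._client = MagicMock()

    def fetch_execution(self, name: str):
        return self._client.get_execution(name)


@pytest.fixture
def remote() -> Remote:
    # Each test gets a fresh instance with the mock client pre-attached,
    # instead of repeating @patch plus constructor boilerplate per test.
    return Remote()


def test_fetch_execution(remote):
    remote._client.get_execution.return_value = "execution"
    assert remote.fetch_execution("n1") == "execution"
```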
@@ -152,7 +163,7 @@ def test_passing_of_kwargs(mock_client):
         "root_certificates": 5,
         "certificate_chain": 6,
     }
-    FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args)
+    FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args).client
     assert mock_client.called
     assert mock_client.call_args[1] == additional_args
 
@@ -225,7 +236,7 @@ def test_generate_console_http_domain_sandbox_rewrite(mock_client):
     remote = FlyteRemote(
         config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain"
     )
-    assert remote.generate_console_http_domain() == "http://localhost:30080"
+    assert remote.generate_console_http_domain() == "http://localhost:30081"
 
     with open(temp_filename, "w") as f:
         # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
@@ -247,3 +258,91 @@
         os.remove(temp_filename)
     except OSError:
         pass
+
+
+def get_compiled_workflow_closure():
+    """
+    :rtype: flytekit.models.core.compiler.CompiledWorkflowClosure
+    """
+    cwc_pb = _compiler_pb2.CompiledWorkflowClosure()
+    # So that tests that use this work when run from any directory
+    basepath = os.path.dirname(__file__)
+    filepath = os.path.abspath(os.path.join(basepath, "responses", "CompiledWorkflowClosure.pb"))
+    with open(filepath, "rb") as fh:
+        cwc_pb.ParseFromString(fh.read())
+
+    return CompiledWorkflowClosure.from_flyte_idl(cwc_pb)
+
+
+def test_fetch_lazy(remote):
+    mock_client = remote._client
+    mock_client.get_task.return_value = Task(
+        id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0]
+    )
+
+    mock_client.get_workflow.return_value = Workflow(
+        id=Identifier(ResourceType.TASK, "p", "d", "n", "v"),
+        closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()),
+    )
+
+    lw = remote.fetch_workflow_lazy(name="wn", version="v")
+    assert isinstance(lw, LazyEntity)
+    assert lw._getter
+    assert lw._entity is None
+    assert lw.entity
+
+    lt = remote.fetch_task_lazy(name="n", version="v")
+    assert isinstance(lw, LazyEntity)
+    assert lt._getter
+    assert lt._entity is None
+    tk = lt.entity
+    assert tk.name == "n"
+
+
+@task
+def tk(t: datetime, v: int):
+    print(f"Invoked at {t} with v {v}")
+
+
+@workflow
+def example_wf(t: datetime, v: int):
+    tk(t=t, v=v)
+
+
+def test_launch_backfill(remote):
+    daily_lp = LaunchPlan.get_or_create(
+        workflow=example_wf,
+        name="daily2",
+        fixed_inputs={"v": 10},
+        schedule=CronSchedule(schedule="0 8 * * *", kickoff_time_input_arg="t"),
+    )
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env=None,
+        image_config=ImageConfig.auto(img_name=DefaultImages.default_image()),
+    )
+
+    start_date = datetime(2022, 12, 1, 8)
+    end_date = start_date + timedelta(days=10)
+
+    ser_lp = get_serializable_launch_plan(OrderedDict(), serialization_settings, daily_lp, recurse_downstream=False)
+    m = OrderedDict()
+    ser_wf = get_serializable(m, serialization_settings, example_wf)
+    tasks = []
+    for k, v in m.items():
+        if isinstance(k, PythonTask):
+            tasks.append(v)
+    mock_client = remote._client
+    mock_client.get_launch_plan.return_value = ser_lp
+    mock_client.get_workflow.return_value = Workflow(
+        id=Identifier(ResourceType.WORKFLOW, "p", "d", "daily2", "v"),
+        closure=WorkflowClosure(
+            compiled_workflow=CompiledWorkflowClosure(primary=ser_wf, sub_workflows=[], tasks=tasks)
+        ),
+    )
+
+    wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True)
+    assert wf
diff --git a/tests/flytekit/unit/remote/test_with_responses.py b/tests/flytekit/unit/remote/test_with_responses.py
index ee3fbb4d8a..7dd7b97910 100644
--- a/tests/flytekit/unit/remote/test_with_responses.py
+++ b/tests/flytekit/unit/remote/test_with_responses.py
@@ -66,11 +66,11 @@ def test_normal_task(mock_client):
     )
     admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely)
     mock_client.get_task.return_value = admin_task
-    ft = rr.fetch_task(name="merge_sort_remotely", version="tst")
+    remote_task = rr.fetch_task(name="merge_sort_remotely", version="tst")
 
     @workflow
     def my_wf(numbers: typing.List[int], run_local_at_count: int) -> typing.List[int]:
-        t1_node = create_node(ft, numbers=numbers, run_local_at_count=run_local_at_count)
+        t1_node = create_node(remote_task, numbers=numbers, run_local_at_count=run_local_at_count)
         return t1_node.o0
 
     serialization_settings = flytekit.configuration.SerializationSettings(
diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py
index 4a08cb7724..82ba538883 100644
--- a/tests/flytekit/unit/remote/test_wrapper_classes.py
+++ b/tests/flytekit/unit/remote/test_wrapper_classes.py
@@ -9,7 +9,7 @@
 from flytekit.core.launch_plan import LaunchPlan
 from flytekit.core.task import task
 from flytekit.core.workflow import workflow
-from flytekit.remote import FlyteWorkflow
+from flytekit.remote import FlyteTask, FlyteWorkflow
 from flytekit.tools.translator import gather_dependent_entities, get_serializable
 
 default_img = Image(name="default", fqn="test", tag="tag")
@@ -58,11 +58,14 @@ def wf(b: int) -> int:
     serialized = OrderedDict()
     wf_spec = get_serializable(serialized, serialization_settings, wf)
-    sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows}
     task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
+    sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows}
 
     fwf = FlyteWorkflow.promote_from_model(
-        wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans=lp_specs, tasks=task_templates
+        wf_spec.template,
+        sub_workflows=sub_wf_dict,
+        node_launch_plans=lp_specs,
+        tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()},
     )
     assert len(fwf.outputs) == 1
     assert list(fwf.interface.inputs.keys()) == ["b"]
@@ -79,7 +82,10 @@ def wf2(b: int) -> int:
     task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
 
     fwf = FlyteWorkflow.promote_from_model(
-        wf_spec.template, sub_workflows={}, node_launch_plans=lp_specs, tasks=task_templates
+        wf_spec.template,
+        sub_workflows={},
+        node_launch_plans=lp_specs,
+        tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()},
     )
     assert len(fwf.outputs) == 1
     assert list(fwf.interface.inputs.keys()) == ["b"]
@@ -111,7 +117,10 @@ def my_wf(a: int) -> str:
     task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
 
     fwf = FlyteWorkflow.promote_from_model(
-        wf_spec.template, sub_workflows={}, node_launch_plans={}, tasks=task_templates
+        wf_spec.template,
+        sub_workflows={},
+        node_launch_plans={},
+        tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()},
     )
 
     assert len(fwf.flyte_nodes[0].upstream_nodes) == 0
@@ -125,11 +134,14 @@ def parent(a: int) -> (str, str):
     serialized = OrderedDict()
     wf_spec = get_serializable(serialized, serialization_settings, parent)
-    sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows}
     task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
+    sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows}
 
     fwf = FlyteWorkflow.promote_from_model(
-        wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans={}, tasks=task_templates
+        wf_spec.template,
+        sub_workflows=sub_wf_dict,
+        node_launch_plans={},
+        tasks={k: FlyteTask.promote_from_model(v) for k, v in task_templates.items()},
     )
     # Test upstream nodes don't get confused by subworkflows
     assert len(fwf.flyte_nodes[0].upstream_nodes) == 0
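A recurring change across these remote tests is that `FlyteWorkflow.promote_from_model` is now fed `FlyteTask` wrappers rather than raw task templates, so every call site converts the dict first. A minimal sketch of that conversion step; `Model` and `Wrapper` are illustrative stand-ins for the flytekit classes:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Model:
    name: str


@dataclass
class Wrapper:
    model: Model

    @classmethod
    def promote_from_model(cls, model: Model) -> "Wrapper":
        # The real FlyteTask.promote_from_model hydrates a runtime wrapper
        # from the serialized task model.
        return cls(model=model)


templates: Dict[str, Model] = {"t1": Model("t1"), "t2": Model("t2")}
# Same dict-comprehension shape as the diff:
tasks = {k: Wrapper.promote_from_model(t) for k, t in templates.items()}
assert all(isinstance(t, Wrapper) for t in tasks.values())
```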