From 59574e55fbe2938b09f6df0da5c99ad86ab66e2c Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Thu, 2 Feb 2023 16:34:55 -0800 Subject: [PATCH] Add `pod_template` and `pod_template_name` arguments for `PythonAutoContainerTask`, its downstream tasks, and `@task`. (#1425) * Add `pod_template` and `pod_template_name` arguments for `PythonAutoContainerTask`, its downstream tasks, and `@task` Signed-off-by: byhsu * clean Signed-off-by: byhsu * fix test Signed-off-by: byhsu * Fix taskmetadata Signed-off-by: byhsu * add kubernetes in setup.py Signed-off-by: byhsu * address comments Signed-off-by: byhsu * Regenerate requirements using python 3.7 Signed-off-by: Eduardo Apolinario Signed-off-by: byhsu * keep container validation Signed-off-by: byhsu * bump idl version Signed-off-by: byhsu * Regenerate requirements using python 3.7 Signed-off-by: Eduardo Apolinario * Regenerate doc-requirements.txt Signed-off-by: Eduardo Apolinario * fix Signed-off-by: byhsu --------- Signed-off-by: byhsu Signed-off-by: Eduardo Apolinario Co-authored-by: byhsu Co-authored-by: Eduardo Apolinario --- dev-requirements.txt | 156 +++++--- doc-requirements.txt | 378 ++++++++---------- flytekit/__init__.py | 3 +- flytekit/core/base_task.py | 3 + flytekit/core/container_task.py | 1 + flytekit/core/pod_template.py | 20 + flytekit/core/python_auto_container.py | 90 ++++- flytekit/core/task.py | 7 + flytekit/models/task.py | 13 + .../flytekitplugins/awsbatch/task.py | 2 +- .../flytekitplugins/papermill/task.py | 2 +- requirements-spark2.txt | 76 +++- requirements.txt | 87 +++- setup.py | 1 + tests/flytekit/common/parameterizers.py | 4 +- .../workflows/requirements.txt | 67 ++-- .../unit/core/test_python_auto_container.py | 219 +++++++++- .../unit/core/test_python_function_task.py | 83 ++++ tests/flytekit/unit/models/test_tasks.py | 6 +- .../unit/models/test_workflow_closure.py | 1 + 20 files changed, 885 insertions(+), 334 deletions(-) create mode 100644 flytekit/core/pod_template.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 3820870a89..9e4eba39fd 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,7 +8,7 @@ # via # -c requirements.txt # pytest-flyte -absl-py==1.3.0 +absl-py==1.4.0 # via # tensorboard # tensorflow @@ -34,11 +34,16 @@ binaryornot==0.4.4 # via # -c requirements.txt # cookiecutter -cachetools==5.2.0 - # via google-auth -certifi==2022.9.24 +cached-property==1.5.2 + # via docker-compose +cachetools==5.3.0 + # via + # -c requirements.txt + # google-auth +certifi==2022.12.7 # via # -c requirements.txt + # kubernetes # requests cffi==1.15.1 # via @@ -51,7 +56,7 @@ chardet==5.0.0 # via # -c requirements.txt # binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via # -c requirements.txt # requests @@ -60,7 +65,7 @@ click==8.1.3 # -c requirements.txt # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via # -c requirements.txt # flytekit @@ -70,7 +75,7 @@ cookiecutter==2.1.1 # via # -c requirements.txt # flytekit -coverage[toml]==6.5.0 +coverage[toml]==7.1.0 # via # -r dev-requirements.in # pytest-cov @@ -78,7 +83,7 @@ croniter==1.3.7 # via # -c requirements.txt # flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # -c requirements.txt # paramiko @@ -123,11 +128,11 @@ docstring-parser==0.15 # via # -c requirements.txt # flytekit -exceptiongroup==1.0.4 +exceptiongroup==1.1.0 # via pytest -filelock==3.8.0 +filelock==3.9.0 # via virtualenv -flatbuffers==22.12.6 +flatbuffers==23.1.21 # via tensorflow flyteidl==1.2.9 # via @@ -135,37 +140,46 @@ flyteidl==1.2.9 # flytekit gast==0.5.3 # via tensorflow -google-api-core[grpc]==2.10.2 +gitdb==4.0.10 + # via + # -c requirements.txt + # gitpython +gitpython==3.1.30 + # via + # -c requirements.txt + # flytekit +google-api-core[grpc]==2.11.0 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.14.1 +google-auth==2.16.0 # via + # -c requirements.txt # google-api-core # google-auth-oauthlib # google-cloud-core + # kubernetes # tensorboard google-auth-oauthlib==0.4.6 # via tensorboard -google-cloud-bigquery==3.4.0 +google-cloud-bigquery==3.4.2 + # via -r dev-requirements.in +google-cloud-bigquery-storage==2.18.1 # via -r dev-requirements.in -google-cloud-bigquery-storage==2.16.2 - # via - # -r dev-requirements.in - # google-cloud-bigquery google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.5.0 # via google-resumable-media google-pasta==0.2.0 # via tensorflow -google-resumable-media==2.4.0 +google-resumable-media==2.4.1 # via google-cloud-bigquery -googleapis-common-protos==1.57.0 +googleapis-common-protos==1.58.0 # via # -c requirements.txt # flyteidl + # flytekit # google-api-core # grpcio-status grpcio==1.48.2 @@ -183,21 +197,29 @@ grpcio-status==1.48.2 # -r dev-requirements.in # flytekit # google-api-core -h5py==3.7.0 +h5py==3.8.0 # via tensorflow -identify==2.5.9 +identify==2.5.17 # via pre-commit idna==3.4 # via # -c requirements.txt # requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # -c requirements.txt # flytekit # keyring # markdown -iniconfig==1.1.1 + # pluggy + # pre-commit + # pytest + # virtualenv +importlib-resources==5.10.2 + # via + # -c requirements.txt + # keyring +iniconfig==2.0.0 # via pytest ipython==7.34.0 # via -r dev-requirements.in @@ -231,15 +253,19 @@ keras==2.8.0 # via tensorflow keras-preprocessing==1.1.2 # via tensorflow -keyring==23.11.0 +keyring==23.13.1 + # via + # -c requirements.txt + # flytekit +kubernetes==25.3.0 # via # -c requirements.txt # flytekit -libclang==14.0.6 +libclang==15.0.6.1 # via tensorflow markdown==3.4.1 # via tensorboard -markupsafe==2.1.1 +markupsafe==2.1.2 # via # -c requirements.txt # jinja2 @@ -260,7 +286,7 @@ marshmallow-jsonschema==0.13.0 # flytekit matplotlib-inline==0.1.6 # via ipython -mock==4.0.3 +mock==5.0.1 # via -r dev-requirements.in more-itertools==9.0.0 # via @@ -290,7 +316,9 @@ numpy==1.21.6 # tensorboard # tensorflow oauthlib==3.2.2 - # via requests-oauthlib + # via + # -c requirements.txt + # requests-oauthlib opt-einsum==3.3.0 # via tensorflow # scikit-learn @@ -306,7 +334,7 @@ pandas==1.3.5 # via # -c requirements.txt # flytekit -paramiko==2.12.0 +paramiko==3.0.0 # via docker parso==0.8.3 # via jedi @@ -314,15 +342,15 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -platformdirs==2.5.4 +platformdirs==2.6.2 # via virtualenv pluggy==1.0.0 # via pytest -pre-commit==2.20.0 +pre-commit==2.21.0 # via -r dev-requirements.in prompt-toolkit==3.0.32 # via ipython -proto-plus==1.22.1 +proto-plus==1.22.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -354,22 +382,24 @@ pyarrow==10.0.0 # via # -c requirements.txt # flytekit - # google-cloud-bigquery pyasn1==0.4.8 # via + # -c requirements.txt # pyasn1-modules # rsa pyasn1-modules==0.2.8 - # via google-auth + # via + # -c requirements.txt + # google-auth pycparser==2.21 # via # -c requirements.txt # cffi -pygments==2.13.0 +pygments==2.14.0 # via ipython pynacl==1.5.0 # via paramiko -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via # -c requirements.txt # flytekit @@ -377,11 +407,11 @@ pyparsing==3.0.9 # via # -c requirements.txt # packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via # -c requirements.txt # jsonschema -pytest==7.2.0 +pytest==7.2.1 # via # -r dev-requirements.in # pytest-cov @@ -400,14 +430,15 @@ python-dateutil==2.8.2 # croniter # flytekit # google-cloud-bigquery + # kubernetes # pandas -python-dotenv==0.21.0 +python-dotenv==0.21.1 # via docker-compose python-json-logger==2.0.4 # via # -c requirements.txt # flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via # -c requirements.txt # cookiecutter @@ -415,7 +446,7 @@ pytimeparse==1.1.8 # via # -c requirements.txt # flytekit -pytz==2022.6 +pytz==2022.7.1 # via # -c requirements.txt # flytekit @@ -426,12 +457,13 @@ pyyaml==5.4.1 # cookiecutter # docker-compose # flytekit + # kubernetes # pre-commit regex==2022.10.31 # via # -c requirements.txt # docker-image-py -requests==2.28.1 +requests==2.28.2 # via # -c requirements.txt # cookiecutter @@ -440,11 +472,15 @@ requests==2.28.1 # flytekit # google-api-core # google-cloud-bigquery + # kubernetes # requests-oauthlib # responses # tensorboard requests-oauthlib==1.3.1 - # via google-auth-oauthlib + # via + # -c requirements.txt + # google-auth-oauthlib + # kubernetes responses==0.22.0 # via # -c requirements.txt @@ -454,7 +490,9 @@ retry==0.9.2 # -c requirements.txt # flytekit rsa==4.9 - # via google-auth + # via + # -c requirements.txt + # google-auth scikit-learn==1.0.2 # via -r dev-requirements.in scipy==1.7.3 @@ -473,10 +511,14 @@ six==1.16.0 # grpcio # jsonschema # keras-preprocessing - # paramiko + # kubernetes # python-dateutil # tensorflow # websocket-client +smmap==5.0.0 + # via + # -c requirements.txt + # gitdb sortedcontainers==2.4.0 # via # -c requirements.txt @@ -495,9 +537,9 @@ tensorflow==2.8.1 # via -r dev-requirements.in tensorflow-estimator==2.8.0 # via tensorflow -tensorflow-io-gcs-filesystem==0.27.0 +tensorflow-io-gcs-filesystem==0.30.0 # via tensorflow -termcolor==2.0.1 +termcolor==2.2.0 # via tensorflow text-unidecode==1.3 # via @@ -510,7 +552,6 @@ threadpoolctl==3.1.0 toml==0.10.2 # via # -c requirements.txt - # pre-commit # responses tomli==2.0.1 # via @@ -519,7 +560,7 @@ tomli==2.0.1 # pytest torch==1.12.1 # via -r dev-requirements.in -traitlets==5.5.0 +traitlets==5.9.0 # via # ipython # matplotlib-inline @@ -533,7 +574,11 @@ typing-extensions==4.4.0 # via # -c requirements.txt # flytekit + # gitpython + # importlib-metadata # mypy + # platformdirs + # responses # tensorflow # torch # typing-inspect @@ -541,22 +586,26 @@ typing-inspect==0.8.0 # via # -c requirements.txt # dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # -c requirements.txt # docker # flytekit + # kubernetes # requests # responses virtualenv==20.16.7 # via pre-commit -wcwidth==0.2.5 +wcwidth==0.2.6 # via prompt-toolkit websocket-client==0.59.0 # via # -c requirements.txt # docker # docker-compose + # kubernetes +werkzeug==2.2.2 + # via tensorboard wheel==0.38.4 # via # -c requirements.txt @@ -569,10 +618,11 @@ wrapt==1.14.1 # deprecated # flytekit # tensorflow -zipp==3.10.0 +zipp==3.12.0 # via # -c requirements.txt # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/doc-requirements.txt b/doc-requirements.txt index 314c926e39..03ebcfdc5c 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -6,17 +6,17 @@ # -e file:.#egg=flytekit # via -r doc-requirements.in -absl-py==1.3.0 +absl-py==1.4.0 # via # tensorboard # tensorflow aiosignal==1.3.1 # via ray -alabaster==0.7.12 +alabaster==0.7.13 # via sphinx -alembic==1.9.1 +alembic==1.9.2 # via mlflow -altair==4.2.0 +altair==4.2.2 # via great-expectations ansiwrap==0.8.4 # via papermill @@ -27,10 +27,6 @@ anyio==3.6.2 # watchfiles aplus==0.11.0 # via vaex-core -appnope==0.1.3 - # via - # ipykernel - # ipython argon2-cffi==21.3.0 # via # jupyter-server @@ -39,10 +35,12 @@ argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.3 - # via jinja2-time -astroid==2.12.12 + # via + # isoduration + # jinja2-time +astroid==2.14.1 # via sphinx-autoapi -astropy==5.2 +astropy==5.2.1 # via vaex-astro asttokens==2.2.1 # via stack-data @@ -57,7 +55,7 @@ babel==2.11.0 # via sphinx backcall==0.2.0 # via ipython -beautifulsoup4==4.11.1 +beautifulsoup4==4.11.2 # via # furo # nbconvert @@ -67,15 +65,17 @@ binaryornot==0.4.4 # via cookiecutter blake3==0.3.3 # via vaex-core -bleach==5.0.1 +bleach==6.0.0 # via nbconvert -botocore==1.29.44 +botocore==1.29.61 # via -r doc-requirements.in bqplot==0.12.36 - # via vaex-jupyter + # via + # ipyvolume + # vaex-jupyter branca==0.6.0 # via ipyleaflet -cachetools==5.2.0 +cachetools==5.3.0 # via # google-auth # vaex-server @@ -91,13 +91,14 @@ cfgv==3.3.1 # via pre-commit chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # dask # databricks-cli + # distributed # flask # flytekit # great-expectations @@ -106,9 +107,10 @@ click==8.1.3 # ray # sphinx-click # uvicorn -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via # dask + # distributed # flytekit # mlflow # shap @@ -117,36 +119,34 @@ colorama==0.4.6 # via great-expectations comm==0.1.2 # via ipykernel -commonmark==0.9.1 - # via rich -contourpy==1.0.6 +contourpy==1.0.7 # via matplotlib cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -<<<<<<< HEAD -cryptography==38.0.3 -======= cryptography==39.0.0 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # -r doc-requirements.in # great-expectations # pyopenssl + # secretstorage css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib -dask==2022.12.1 - # via vaex-core +dask[distributed]==2023.1.1 + # via + # -r doc-requirements.in + # distributed + # vaex-core databricks-cli==0.17.4 # via mlflow dataclasses-json==0.5.7 # via # dolt-integrations # flytekit -debugpy==1.6.3 +debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via @@ -160,6 +160,8 @@ diskcache==5.4.0 # via flytekit distlib==0.3.6 # via virtualenv +distributed==2023.1.1 + # via dask docker==6.0.1 # via # flytekit @@ -180,33 +182,24 @@ doltcli==0.1.17 entrypoints==0.4 # via # altair - # jupyter-client # mlflow # papermill executing==1.2.0 # via stack-data -fastapi==0.88.0 +fastapi==0.89.1 # via vaex-server fastjsonschema==2.16.2 # via nbformat -<<<<<<< HEAD -filelock==3.8.0 -======= filelock==3.9.0 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # ray # vaex-core # virtualenv flask==2.2.2 # via mlflow -flatbuffers==23.1.4 +flatbuffers==23.1.21 # via tensorflow -<<<<<<< HEAD flyteidl==1.2.9 -======= -flyteidl==1.3.2 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via flytekit fonttools==4.38.0 # via matplotlib @@ -218,31 +211,29 @@ frozenlist==1.3.3 # via # aiosignal # ray -fsspec==2022.11.0 +fsspec==2023.1.0 # via # -r doc-requirements.in # dask # modin furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in -future==0.18.2 +future==0.18.3 # via vaex-core gast==0.5.3 # via tensorflow -<<<<<<< HEAD -google-api-core[grpc]==2.8.2 -======= gitdb==4.0.10 # via gitpython gitpython==3.1.30 - # via mlflow + # via + # flytekit + # mlflow google-api-core[grpc]==2.11.0 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # -r doc-requirements.in # google-cloud-bigquery # google-cloud-core -google-auth==2.14.1 +google-auth==2.16.0 # via # google-api-core # google-auth-oauthlib @@ -253,11 +244,7 @@ google-auth-oauthlib==0.4.6 # via tensorboard google-cloud==0.34.0 # via -r doc-requirements.in -<<<<<<< HEAD -google-cloud-bigquery==3.1.0 -======= -google-cloud-bigquery==3.4.1 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) +google-cloud-bigquery==3.5.0 # via -r doc-requirements.in google-cloud-core==2.3.2 # via google-cloud-bigquery @@ -265,17 +252,17 @@ google-crc32c==1.5.0 # via google-resumable-media google-pasta==0.2.0 # via tensorflow -google-resumable-media==2.4.0 +google-resumable-media==2.4.1 # via google-cloud-bigquery -googleapis-common-protos==1.57.1 +googleapis-common-protos==1.58.0 # via # flyteidl # flytekit # google-api-core # grpcio-status -great-expectations==0.15.42 +great-expectations==0.15.46 # via -r doc-requirements.in -greenlet==2.0.1 +greenlet==2.0.2 # via sqlalchemy grpcio==1.43.0 # via @@ -296,15 +283,17 @@ gunicorn==20.1.0 # via mlflow h11==0.14.0 # via uvicorn -h5py==3.7.0 +h5py==3.8.0 # via # tensorflow # vaex-hdf5 +heapdict==1.0.1 + # via zict htmlmin==0.1.12 - # via pandas-profiling + # via ydata-profiling httptools==0.5.0 # via uvicorn -identify==2.5.12 +identify==2.5.17 # via pre-commit idna==3.4 # via @@ -314,15 +303,12 @@ imagehash==4.3.1 # via visions imagesize==1.4.1 # via sphinx -<<<<<<< HEAD -importlib-metadata==5.0.0 -======= importlib-metadata==5.2.0 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # flask # flytekit # great-expectations + # jupyter-client # keyring # markdown # mlflow @@ -330,7 +316,7 @@ importlib-metadata==5.2.0 # sphinx ipydatawidgets==4.3.2 # via pythreejs -ipykernel==6.19.4 +ipykernel==6.20.2 # via # ipywidgets # jupyter @@ -342,7 +328,7 @@ ipyleaflet==0.17.2 # via vaex-jupyter ipympl==0.9.2 # via vaex-jupyter -ipython==8.8.0 +ipython==8.9.0 # via # great-expectations # ipykernel @@ -355,12 +341,16 @@ ipython-genutils==0.2.0 # nbclassic # notebook # qtconsole -ipyvolume==0.5.2 +ipyvolume==0.6.0 # via vaex-jupyter ipyvue==1.8.0 - # via ipyvuetify + # via + # ipyvolume + # ipyvuetify ipyvuetify==1.8.4 - # via vaex-jupyter + # via + # ipyvolume + # vaex-jupyter ipywebrtc==0.6.0 # via ipyvolume ipywidgets==8.0.4 @@ -382,11 +372,16 @@ jaraco-classes==3.2.3 # via keyring jedi==0.18.1 # via ipython +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # altair # branca # cookiecutter + # distributed # flask # great-expectations # jinja2-time @@ -395,10 +390,10 @@ jinja2==3.1.2 # nbclassic # nbconvert # notebook - # pandas-profiling # sphinx # sphinx-autoapi # vaex-ml + # ydata-profiling jinja2-time==0.2.0 # via cookiecutter jmespath==1.0.1 @@ -422,7 +417,7 @@ jsonschema[format-nongpl]==4.17.3 # ray jupyter==1.0.0 # via -r doc-requirements.in -jupyter-client==7.4.6 +jupyter-client==8.0.2 # via # ipykernel # jupyter-console @@ -433,11 +428,7 @@ jupyter-client==7.4.6 # qtconsole jupyter-console==6.4.4 # via jupyter -<<<<<<< HEAD -jupyter-core==5.0.0 -======= -jupyter-core==5.1.2 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) +jupyter-core==5.2.0 # via # jupyter-client # jupyter-server @@ -446,18 +437,13 @@ jupyter-core==5.1.2 # nbformat # notebook # qtconsole -jupyter-events==0.5.0 +jupyter-events==0.6.3 # via jupyter-server -jupyter-server==2.0.6 +jupyter-server==2.2.0 # via # nbclassic # notebook-shim -<<<<<<< HEAD -jupyterlab-pygments==0.2.2 - # via nbconvert -jupyterlab-widgets==3.0.3 -======= -jupyter-server-terminals==0.4.3 +jupyter-server-terminals==0.4.4 # via jupyter-server jupyterlab-pygments==0.2.2 # via nbconvert @@ -473,15 +459,19 @@ keyring==23.13.1 kiwisolver==1.4.4 # via matplotlib kubernetes==25.3.0 - # via -r doc-requirements.in + # via + # -r doc-requirements.in + # flytekit lazy-object-proxy==1.9.0 # via astroid -libclang==14.0.6 +libclang==15.0.6.1 # via tensorflow llvmlite==0.39.1 # via numba locket==1.0.0 - # via partd + # via + # distributed + # partd lxml==4.9.2 # via sphinx-material makefun==1.15.0 @@ -493,7 +483,9 @@ markdown==3.4.1 # -r doc-requirements.in # mlflow # tensorboard -markupsafe==2.1.1 +markdown-it-py==2.1.0 + # via rich +markupsafe==2.1.2 # via # jinja2 # mako @@ -509,61 +501,56 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.3 +matplotlib==3.6.3 # via # ipympl + # ipyvolume # mlflow - # pandas-profiling # phik # seaborn # vaex-viz + # ydata-profiling matplotlib-inline==0.1.6 # via # ipykernel # ipython -missingno==0.5.1 - # via pandas-profiling +mdurl==0.1.2 + # via markdown-it-py mistune==2.0.4 # via # great-expectations # nbconvert -<<<<<<< HEAD -modin==0.17.0 -======= mlflow==2.1.1 # via -r doc-requirements.in -modin==0.18.0 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) +modin==0.18.1 # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes msgpack==1.0.4 - # via ray + # via + # distributed + # ray multimethod==1.9.1 # via - # pandas-profiling # visions + # ydata-profiling mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit -nbclassic==0.4.8 +nbclassic==0.5.1 # via notebook nbclient==0.7.0 # via # nbconvert # papermill -<<<<<<< HEAD -nbconvert==7.2.5 -======= -nbconvert==7.2.7 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) +nbconvert==7.2.9 # via # jupyter # jupyter-server # nbclassic # notebook -nbformat==5.7.1 +nbformat==5.7.3 # via # great-expectations # jupyter-server @@ -575,12 +562,11 @@ nbformat==5.7.1 nest-asyncio==1.5.6 # via # ipykernel - # jupyter-client # nbclassic # nbclient # notebook # vaex-core -networkx==2.8.8 +networkx==3.0 # via visions nodeenv==1.7.0 # via pre-commit @@ -609,16 +595,12 @@ numpy==1.23.5 # ipyvolume # keras-preprocessing # matplotlib -<<<<<<< HEAD # missingno -======= # mlflow ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # modin # numba # opt-einsum # pandas - # pandas-profiling # pandera # patsy # phik @@ -630,19 +612,26 @@ numpy==1.23.5 # scikit-learn # scipy # seaborn -<<<<<<< HEAD # skl2onnx -======= # shap ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # statsmodels # tensorboard # tensorflow # vaex-core # visions # xarray + # ydata-profiling +nvidia-cublas-cu11==11.10.3.66 + # via + # nvidia-cudnn-cu11 + # torch +nvidia-cuda-nvrtc-cu11==11.7.99 + # via torch +nvidia-cuda-runtime-cu11==11.7.99 + # via torch +nvidia-cudnn-cu11==8.5.0.96 + # via torch oauthlib==3.2.2 -<<<<<<< HEAD # via requests-oauthlib onnx==1.12.0 # via @@ -651,17 +640,13 @@ onnx==1.12.0 # tf2onnx onnxconverter-common==1.13.0 # via skl2onnx -======= - # via - # databricks-cli - # requests-oauthlib ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) opt-einsum==3.3.0 # via tensorflow -packaging==21.3 +packaging==22.0 # via # astropy # dask + # distributed # docker # google-cloud-bigquery # great-expectations @@ -679,7 +664,7 @@ packaging==21.3 # sphinx # statsmodels # xarray -pandas==1.5.2 +pandas==1.5.3 # via # altair # bqplot @@ -688,7 +673,6 @@ pandas==1.5.2 # great-expectations # mlflow # modin - # pandas-profiling # pandera # phik # seaborn @@ -697,7 +681,8 @@ pandas==1.5.2 # vaex-core # visions # xarray -pandas-profiling==3.6.2 + # ydata-profiling +pandas-profiling==3.6.6 # via -r doc-requirements.in pandera==0.13.4 # via -r doc-requirements.in @@ -713,8 +698,8 @@ patsy==0.5.3 # via statsmodels pexpect==4.8.0 # via ipython -phik==0.12.2 - # via pandas-profiling +phik==0.12.3 + # via ydata-profiling pickleshare==0.7.5 # via ipython pillow==9.4.0 @@ -725,21 +710,17 @@ pillow==9.4.0 # matplotlib # vaex-viz # visions -<<<<<<< HEAD -platformdirs==2.5.4 -======= platformdirs==2.6.2 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter-core # virtualenv -plotly==5.11.0 +plotly==5.13.0 # via -r doc-requirements.in -pre-commit==2.21.0 +pre-commit==3.0.2 # via sphinx-tags progressbar2==4.2.0 # via vaex-core -prometheus-client==0.15.0 +prometheus-client==0.16.0 # via # jupyter-server # nbclassic @@ -748,16 +729,11 @@ prompt-toolkit==3.0.32 # via # ipython # jupyter-console -proto-plus==1.22.1 -<<<<<<< HEAD # via # google-cloud-bigquery # google-cloud-bigquery-storage protobuf==3.19.6 -======= - # via google-cloud-bigquery -protobuf==4.21.12 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) +proto-plus==1.22.2 # via # flyteidl # flytekit @@ -765,12 +741,7 @@ protobuf==4.21.12 # google-cloud-bigquery # googleapis-common-protos # grpcio-status -<<<<<<< HEAD - # onnx - # onnxconverter-common -======= # mlflow ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # proto-plus # protoc-gen-swagger # ray @@ -782,6 +753,7 @@ protoc-gen-swagger==0.1.0 # via flyteidl psutil==5.9.4 # via + # distributed # ipykernel # modin ptyprocess==0.7.0 @@ -811,9 +783,9 @@ pydantic==1.10.4 # via # fastapi # great-expectations - # pandas-profiling # pandera # vaex-core + # ydata-profiling pyerfa==2.0.0.1 # via astropy pygments==2.14.0 @@ -834,7 +806,6 @@ pyparsing==3.0.9 # via # great-expectations # matplotlib - # packaging pyrsistent==0.19.3 # via jsonschema pyspark==3.3.1 @@ -852,11 +823,13 @@ python-dateutil==2.8.2 # matplotlib # pandas # whylabs-client -python-dotenv==0.21.0 +python-dotenv==0.21.1 # via uvicorn python-json-logger==2.0.4 - # via flytekit -python-slugify[unidecode]==6.1.2 + # via + # flytekit + # jupyter-events +python-slugify[unidecode]==8.0.0 # via # cookiecutter # sphinx-material @@ -866,7 +839,7 @@ pythreejs==2.4.1 # via ipyvolume pytimeparse==1.1.8 # via flytekit -pytz==2022.7 +pytz==2022.7.1 # via # babel # flytekit @@ -882,17 +855,18 @@ pyyaml==6.0 # astropy # cookiecutter # dask + # distributed # flytekit # kubernetes # mlflow - # pandas-profiling # papermill # pre-commit # ray # sphinx-autoapi # uvicorn # vaex-core -pyzmq==24.0.1 + # ydata-profiling +pyzmq==25.0.0 # via # ipykernel # jupyter-client @@ -910,7 +884,7 @@ ray==2.2.0 # via -r doc-requirements.in regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # databricks-cli @@ -922,7 +896,6 @@ requests==2.28.1 # ipyvolume # kubernetes # mlflow - # pandas-profiling # papermill # ray # requests-oauthlib @@ -930,6 +903,7 @@ requests==2.28.1 # sphinx # tensorboard # vaex-core + # ydata-profiling requests-oauthlib==1.3.1 # via # google-auth-oauthlib @@ -939,10 +913,14 @@ responses==0.22.0 retry==0.9.2 # via flytekit rfc3339-validator==0.1.4 - # via jsonschema + # via + # jsonschema + # jupyter-events rfc3986-validator==0.1.1 - # via jsonschema -rich==13.0.0 + # via + # jsonschema + # jupyter-events +rich==13.3.1 # via vaex-core rsa==4.9 # via google-auth @@ -950,7 +928,7 @@ ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.7 # via ruamel-yaml -scikit-learn==1.2.0 +scikit-learn==1.2.1 # via # -r doc-requirements.in # mlflow @@ -959,22 +937,16 @@ scipy==1.9.3 # via # great-expectations # imagehash -<<<<<<< HEAD - # missingno - # pandas-profiling - # phik - # scikit-learn - # skl2onnx -======= # mlflow - # pandas-profiling # phik # scikit-learn # shap ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # statsmodels + # ydata-profiling seaborn==0.12.2 - # via pandas-profiling + # via ydata-profiling +secretstorage==3.3.3 + # via keyring send2trash==1.8.0 # via # jupyter-server @@ -995,11 +967,8 @@ six==1.16.0 # kubernetes # patsy # python-dateutil -<<<<<<< HEAD -======= # querystring-parser # rfc3339-validator ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # sphinx-code-include # tensorflow # vaex-core @@ -1012,7 +981,9 @@ sniffio==1.3.0 snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 - # via flytekit + # via + # distributed + # flytekit soupsieve==2.3.2.post1 # via beautifulsoup4 sphinx==4.5.0 @@ -1031,7 +1002,7 @@ sphinx==4.5.0 # sphinx-prompt # sphinx-tags # sphinxcontrib-yt -sphinx-autoapi==2.0.0 +sphinx-autoapi==2.0.1 # via -r doc-requirements.in sphinx-basic-ng==1.0.0b1 # via furo @@ -1051,13 +1022,13 @@ sphinx-panels==0.6.0 # via -r doc-requirements.in sphinx-prompt==1.5.0 # via -r doc-requirements.in -sphinx-tags==0.1.6 +sphinx-tags==0.2.0 # via -r doc-requirements.in -sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-htmlhelp==2.0.1 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx @@ -1067,11 +1038,6 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in -<<<<<<< HEAD -sqlalchemy==1.4.44 - # via -r doc-requirements.in -stack-data==0.6.1 -======= sqlalchemy==1.4.46 # via # -r doc-requirements.in @@ -1080,20 +1046,21 @@ sqlalchemy==1.4.46 sqlparse==0.4.3 # via mlflow stack-data==0.6.2 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via ipython starlette==0.22.0 # via fastapi statsd==3.3.0 # via flytekit statsmodels==0.13.5 - # via pandas-profiling + # via ydata-profiling tabulate==0.9.0 # via # databricks-cli # vaex-core tangled-up-in-unicode==0.2.0 # via visions +tblib==1.7.0 + # via distributed tenacity==8.1.0 # via # papermill @@ -1108,21 +1075,11 @@ tensorflow==2.9.0 # via -r doc-requirements.in tensorflow-estimator==2.9.0 # via tensorflow -<<<<<<< HEAD -tensorflow-io-gcs-filesystem==0.27.0 - # via tensorflow -termcolor==2.0.1 - # via - # great-expectations - # tensorflow -terminado==0.17.0 -======= -tensorflow-io-gcs-filesystem==0.29.0 +tensorflow-io-gcs-filesystem==0.30.0 # via tensorflow termcolor==2.2.0 # via tensorflow terminado==0.17.1 ->>>>>>> 55b6602a (Add support MLFlow plugin (#1274)) # via # jupyter-server # nbclassic @@ -1143,11 +1100,13 @@ toolz==0.12.0 # via # altair # dask + # distributed # partd torch==1.13.1 # via -r doc-requirements.in tornado==6.2 # via + # distributed # ipykernel # jupyter-client # jupyter-server @@ -1158,10 +1117,10 @@ tornado==6.2 tqdm==4.64.1 # via # great-expectations - # pandas-profiling # papermill # shap -traitlets==5.8.0 + # ydata-profiling +traitlets==5.9.0 # via # bqplot # comm @@ -1190,7 +1149,7 @@ traittypes==0.2.1 # ipyleaflet # ipyvolume typeguard==2.13.3 - # via pandas-profiling + # via ydata-profiling types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 @@ -1217,9 +1176,12 @@ unidecode==1.3.6 # via # python-slugify # sphinx-autoapi -urllib3==1.26.12 +uri-template==1.2.0 + # via jsonschema +urllib3==1.26.14 # via # botocore + # distributed # docker # flytekit # great-expectations @@ -1261,16 +1223,16 @@ virtualenv==20.17.1 # pre-commit # ray visions[type_image_path]==0.7.5 - # via pandas-profiling + # via ydata-profiling watchfiles==0.18.1 # via uvicorn -wcwidth==0.2.5 +wcwidth==0.2.6 # via prompt-toolkit webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.4.2 +websocket-client==1.5.0 # via # docker # jupyter-server @@ -1285,10 +1247,12 @@ wheel==0.38.4 # via # astunparse # flytekit + # nvidia-cublas-cu11 + # nvidia-cuda-runtime-cu11 # tensorboard -whylabs-client==0.4.0 +whylabs-client==0.4.3 # via -r doc-requirements.in -whylogs==1.1.20 +whylogs==1.1.24 # via -r doc-requirements.in whylogs-sketching==3.4.1.dev3 # via whylogs @@ -1301,11 +1265,15 @@ wrapt==1.14.1 # flytekit # pandera # tensorflow -xarray==2022.12.0 +xarray==2023.1.0 # via vaex-jupyter xyzservices==2022.9.0 # via ipyleaflet -zipp==3.11.0 +ydata-profiling==4.0.0 + # via pandas-profiling +zict==2.2.0 + # via distributed +zipp==3.12.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/flytekit/__init__.py b/flytekit/__init__.py index e40f060b4d..dfbb6594f7 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -58,7 +58,7 @@ TaskMetadata - Wrapper object that allows users to specify Task Resources - Things like CPUs/Memory, etc. WorkflowFailurePolicy - Customizes what happens when a workflow fails. - + PodTemplate - Custom PodTemplate for a task. Dynamic and Nested Workflows ============================== @@ -176,6 +176,7 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.map_task import map_task from flytekit.core.notification import Email, PagerDuty, Slack +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.reference import get_reference_entity from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 491bed4385..2b21f7c40b 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -85,6 +85,7 @@ class TaskMetadata(object): timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task should be executed for. The execution will be terminated if the runtime exceeds the given timeout (approximately) + pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task. """ cache: bool = False @@ -94,6 +95,7 @@ class TaskMetadata(object): deprecated: str = "" retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None + pod_template_name: Optional[str] = None def __post_init__(self): if self.timeout: @@ -127,6 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: discovery_version=self.cache_version, deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, + pod_template_name=self.pod_template_name, ) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 848c1d2524..d470fb54fe 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -10,6 +10,7 @@ from flytekit.models.security import Secret, SecurityContext +# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py new file mode 100644 index 0000000000..be8991db45 --- /dev/null +++ b/flytekit/core/pod_template.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Dict, Optional + +from kubernetes.client.models import V1PodSpec + +from flytekit.exceptions import user as _user_exceptions + +PRIMARY_CONTAINER_DEFAULT_NAME = "primary" + + +@dataclass +class PodTemplate(object): + pod_spec: V1PodSpec = V1PodSpec(containers=[]) + primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + + def __post_init__(self): + if not self.primary_container_name: + raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 06133d9784..2d05df3c3d 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -4,11 +4,16 @@ import re from abc import ABC from types import ModuleType -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.configuration import ImageConfig, SerializationSettings -from flytekit.core.base_task import PythonTask, TaskResolverMixin +from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module @@ -18,6 +23,11 @@ from flytekit.models.security import Secret, SecurityContext T = TypeVar("T") +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + + +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): @@ -40,6 +50,8 @@ def __init__( environment: Optional[Dict[str, str]] = None, task_resolver: Optional[TaskResolverMixin] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): """ @@ -64,6 +76,8 @@ def __init__( - `Confidant `__ - `Kube secrets `__ - `AWS Parameter store `__ + :param pod_template: Custom PodTemplate for this task. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ sec_ctx = None if secret_requests: @@ -71,6 +85,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metedata.pod_template_name + kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() + kwargs["metadata"].pod_template_name = pod_template_name + super().__init__( task_type=task_type, name=name, @@ -98,6 +117,8 @@ def __init__( self._task_resolver = task_resolver or default_task_resolver self._get_command_fn = self.get_default_command + self.pod_template = pod_template + @property def task_resolver(self) -> Optional[TaskResolverMixin]: return self._task_resolver @@ -157,6 +178,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._get_command_fn(settings) def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + else: + return self._get_container(settings) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = {} for elem in (settings.env, self.environment): if elem: @@ -179,6 +207,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe memory_limit=self.resources.limits.mem, ) + def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: + containers = self.pod_template.pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == self.pod_template.primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=self.pod_template.primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the default values used in the regular Python task. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == self.pod_template.primary_container_name: + sdk_default_container = self._get_container(settings) + container.image = sdk_default_container.image + # clear existing commands + container.command = sdk_default_container.command + # also clear existing args + container.args = sdk_default_container.args + limits, requests = {}, {} + for resource in sdk_default_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in sdk_default_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + self.pod_template.pod_spec.containers = final_containers + + return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=self._serialize_pod_spec(settings), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + ) + + # need to call super in all its children tasks + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} + class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): """ diff --git a/flytekit/core/task.py b/flytekit/core/task.py index bb24181338..28c5b5def7 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -4,6 +4,7 @@ from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources @@ -92,6 +93,8 @@ def task( task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: bool = True, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, ) -> Union[Callable, PythonFunctionTask]: """ This is the core decorator to use for any task type in flytekit. @@ -182,6 +185,8 @@ def foo2(): :param task_resolver: Provide a custom task resolver. :param disable_deck: If true, this task will not output deck html file :param docs: Documentation about this task + :param pod_template: Custom PodTemplate for this task. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ def wrapper(fn) -> PythonFunctionTask: @@ -208,6 +213,8 @@ def wrapper(fn) -> PythonFunctionTask: task_resolver=task_resolver, disable_deck=disable_deck, docs=docs, + pod_template=pod_template, + pod_template_name=pod_template_name, ) update_wrapper(task_instance, fn) return task_instance diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 2129cdd88f..fc79c87a2d 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -177,6 +177,7 @@ def __init__( discovery_version, deprecated_error_message, cache_serializable, + pod_template_name, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -196,6 +197,7 @@ def __init__( receive deprecation warnings. :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. + :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ self._discoverable = discoverable self._runtime = runtime @@ -205,6 +207,7 @@ def __init__( self._discovery_version = discovery_version self._deprecated_error_message = deprecated_error_message self._cache_serializable = cache_serializable + self._pod_template_name = pod_template_name @property def discoverable(self): @@ -274,6 +277,14 @@ def cache_serializable(self): """ return self._cache_serializable + @property + def pod_template_name(self): + """ + The name of the existing PodTemplate resource which will be used in this task. + :rtype: Text + """ + return self._pod_template_name + def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskMetadata @@ -286,6 +297,7 @@ def to_flyte_idl(self): discovery_version=self.discovery_version, deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, + pod_template_name=self.pod_template_name, ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) @@ -306,6 +318,7 @@ def from_flyte_idl(cls, pb2_object): discovery_version=pb2_object.discovery_version, deprecated_error_message=pb2_object.deprecated_error_message, cache_serializable=pb2_object.cache_serializable, + pod_template_name=pb2_object.pod_template_name, ) diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 8787e70011..0e67b2e50b 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_config(self, settings: SerializationSettings) -> Dict[str, str]: # Parameters in taskTemplate config will be used to create aws job definition. # More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html - return {"platformCapabilities": self._task_config.platformCapabilities} + return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities} def get_command(self, settings: SerializationSettings) -> List[str]: container_args = [ diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 04f821ccf3..6c160c2690 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]: return self._config_task_instance.get_k8s_pod(settings) def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]: - return self._config_task_instance.get_config(settings) + return {**super().get_config(settings), **self._config_task_instance.get_config(settings)} def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return self._config_task_instance.pre_execute(user_params) diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 4e997568f2..1c7dc65df6 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -16,25 +16,29 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.9.24 - # via requests +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -54,9 +58,16 @@ docstring-parser==0.15 # via flytekit flyteidl==1.2.9 # via flytekit -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +google-auth==2.16.0 + # via kubernetes +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status grpcio==1.48.2 # via @@ -66,12 +77,14 @@ grpcio-status==1.48.2 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # click # flytekit # jsonschema # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -88,9 +101,11 @@ joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.11.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +kubernetes==25.3.0 + # via flytekit +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -113,8 +128,11 @@ numpy==1.21.6 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==21.3 # via + # -r requirements.in # docker # marshmallow pandas==1.3.5 @@ -134,27 +152,34 @@ py==1.11.0 # via retry pyarrow==10.0.0 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -163,28 +188,38 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit + # kubernetes regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # google-auth # jsonschema + # kubernetes # python-dateutil # websocket-client +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -199,29 +234,34 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit + # kubernetes # requests # responses websocket-client==0.59.0 # via # -r requirements.in # docker + # kubernetes wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements.txt b/requirements.txt index 5a62372566..15449d437d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # make requirements.txt # @@ -14,25 +14,29 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.9.24 - # via requests +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -52,9 +56,16 @@ docstring-parser==0.15 # via flytekit flyteidl==1.2.9 # via flytekit -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +google-auth==2.16.0 + # via kubernetes +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status grpcio==1.48.2 # via @@ -64,10 +75,14 @@ grpcio-status==1.48.2 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via + # click # flytekit + # jsonschema # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -80,9 +95,11 @@ joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.11.0 +keyring==23.13.1 # via flytekit -markupsafe==2.1.1 +kubernetes==25.3.0 + # via flytekit +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -102,8 +119,11 @@ natsort==8.2.0 numpy==1.21.6 # via # -r requirements.in + # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==21.3 # via # docker @@ -125,27 +145,34 @@ py==1.11.0 # via retry pyarrow==10.0.0 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -154,26 +181,38 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit + # kubernetes regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth +secretstorage==3.3.3 + # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # google-auth # jsonschema + # kubernetes # python-dateutil # websocket-client +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -186,28 +225,36 @@ types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 # via + # arrow # flytekit + # gitpython + # importlib-metadata + # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit + # kubernetes # requests # responses websocket-client==0.59.0 # via # -r requirements.in # docker + # kubernetes wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/setup.py b/setup.py index 49434fc684..319079dde3 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,7 @@ # aliases. More details in https://github.com/flyteorg/flyte/issues/3166 "numpy<1.24.0", "gitpython", + "kubernetes>=12.0.1", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 4b48fbcfb9..d5b07fe420 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -121,8 +121,9 @@ discovery_version, deprecated, cache_serializable, + pod_template_name, ) - for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable in product( + for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product( [True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], @@ -131,6 +132,7 @@ ["1.0"], ["deprecated"], [True, False], + ["A", "B"], ) ] diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index f2a5629f13..adecd293ca 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -14,19 +14,19 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.1 +charset-normalizer==3.0.1 # via requests click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.2.0 +cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 # via flytekit croniter==1.3.7 # via flytekit -cryptography==38.0.3 +cryptography==39.0.0 # via # pyopenssl # secretstorage @@ -46,29 +46,36 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 +flyteidl==1.3.5 # via flytekit -flytekit==1.2.4 +flytekit==1.3.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in fonttools==4.38.0 # via matplotlib -googleapis-common-protos==1.57.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via flytekit +googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # grpcio-status -grpcio==1.48.2 +grpcio==1.51.1 # via # flytekit # grpcio-status -grpcio-status==1.48.2 +grpcio-status==1.51.1 # via flytekit idna==3.4 # via requests -importlib-metadata==5.0.0 +importlib-metadata==6.0.0 # via # click # flytekit # keyring +importlib-resources==5.10.2 + # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -85,11 +92,11 @@ joblib==1.2.0 # via # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in # flytekit -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib -markupsafe==2.1.1 +markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 # via @@ -115,21 +122,20 @@ numpy==1.21.6 # opencv-python # pandas # pyarrow -opencv-python==4.6.0.66 +opencv-python==4.7.0.68 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -packaging==21.3 +packaging==23.0 # via # docker # marshmallow # matplotlib pandas==1.3.5 # via flytekit -pillow==9.3.0 +pillow==9.4.0 # via matplotlib -protobuf==3.20.3 +protobuf==4.21.12 # via # flyteidl - # flytekit # googleapis-common-protos # grpcio-status # protoc-gen-swagger @@ -141,12 +147,10 @@ pyarrow==6.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.1.0 +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 - # via - # matplotlib - # packaging + # via matplotlib python-dateutil==2.8.2 # via # arrow @@ -156,11 +160,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==7.0.0 +python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7.1 # via # flytekit # pandas @@ -170,7 +174,7 @@ pyyaml==6.0 # flytekit regex==2022.10.31 # via docker-image-py -requests==2.28.1 +requests==2.28.2 # via # cookiecutter # docker @@ -185,9 +189,9 @@ secretstorage==3.3.3 singledispatchmethod==1.0 # via flytekit six==1.16.0 - # via - # grpcio - # python-dateutil + # via python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 @@ -202,19 +206,20 @@ typing-extensions==4.4.0 # via # arrow # flytekit + # gitpython # importlib-metadata # kiwisolver # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.14 # via # docker # flytekit # requests # responses -websocket-client==1.4.2 +websocket-client==1.5.0 # via docker wheel==0.38.4 # via @@ -224,5 +229,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.10.0 - # via importlib-metadata +zipp==3.12.0 + # via + # importlib-metadata + # importlib-resources diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 108b42ebf7..24856432b4 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,9 +1,14 @@ from typing import Any import pytest +from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1Volume from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.base_task import TaskMetadata +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image +from flytekit.core.resources import Resources +from flytekit.tools.translator import get_serializable_task @pytest.fixture @@ -36,7 +41,6 @@ def execute(self, **kwargs) -> Any: task = DummyAutoContainerTask(name="x", task_config=None, task_type="t") -task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") def test_default_command(default_serialization_settings): @@ -68,14 +72,227 @@ def test_get_container(default_serialization_settings): assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar"} + ts = get_serializable_task(default_serialization_settings, task) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar"} + + +task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t") + def test_get_container_with_task_envvars(default_serialization_settings): c = task_with_env_vars.get_container(default_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar", "HAM": "spam"} + ts = get_serializable_task(default_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} + def test_get_container_without_serialization_settings_envvars(minimal_serialization_settings): c = task_with_env_vars.get_container(minimal_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"HAM": "spam"} + + ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"HAM": "spam"} + + +task_with_pod_template = DummyAutoContainerTask( + name="x", + metadata=TaskMetadata( + pod_template_name="podTemplateB", # should be overwritten + retries=3, # ensure other fields still exists + ), + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + requests=Resources(cpu="3", gpu="1"), + limits=Resources(cpu="6", gpu="2"), + environment={"eKeyA": "eValA", "eKeyB": "vKeyB"}, + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA", "lKeyB": "lValB"}, + annotations={"aKeyA": "aValA", "aKeyB": "aValB"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="notPrimary", + ), + V1Container( + name="primary", + image="repo/placeholderImage:0.0.0", + command="placeholderCommand", + args="placeholderArgs", + resources=V1ResourceRequirements(limits={"cpu": "999", "gpu": "999"}), + env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")], + ), + ], + volumes=[V1Volume(name="volume")], + ), + ), + pod_template_name="podTemplateA", +) + + +def test_pod_template(default_serialization_settings): + ################# + # Test get_k8s_pod + ################# + + container = task_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_pod_template.get_k8s_pod(default_serialization_settings) + + # labels/annotations should be passed + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA", "lKeyB": "lValB"} + assert metadata.annotations == {"aKeyA": "aValA", "aKeyB": "aValB"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][1] + + # To test overwritten attributes + + # image + assert primary_container["image"] == "repo/image:0.0.0" + # command + assert primary_container["command"] == [] + # args + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_pod_template", + ] + # resource + assert primary_container["resources"]["requests"] == {"cpu": "3", "gpu": "1"} + assert primary_container["resources"]["limits"] == {"cpu": "6", "gpu": "2"} + + # To test union attributes + assert primary_container["env"] == [ + {"name": "FOO", "value": "bar"}, + {"name": "eKeyA", "value": "eValA"}, + {"name": "eKeyB", "value": "vKeyB"}, + {"name": "eKeyC", "value": "eValC"}, + {"name": "eKeyD", "value": "eValD"}, + ] + + # To test not overwritten attributes + assert pod_spec["volumes"][0] == {"name": "volume"} + + ################# + # Test pod_template_name + ################# + assert task_with_pod_template.metadata.pod_template_name == "podTemplateA" + assert task_with_pod_template.metadata.retries == 3 + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + + ################# + # Test config + ################# + assert config == {"primary_container_name": "primary"} + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + + assert ts.template.metadata.pod_template_name == "podTemplateA" + assert ts.template.metadata.retries.retries == 3 + assert ts.template.config is not None + + +task_with_minimum_pod_template = DummyAutoContainerTask( + name="x", + task_config=None, + task_type="t", + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + ), + pod_template_name="A", +) + + +def test_minimum_pod_template(default_serialization_settings): + + ################# + # Test get_k8s_pod + ################# + + container = task_with_minimum_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = task_with_minimum_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_auto_container", + "task-name", + "task_with_minimum_pod_template", + ] + + config = task_with_minimum_pod_template.get_config(default_serialization_settings) + assert config == {"primary_container_name": "primary"} + + ################# + # Test pod_teamplte_name + ################# + assert task_with_minimum_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" + assert ts.template.config is not None diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 34aaefaeb3..02e04a302f 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -1,10 +1,13 @@ import pytest +from kubernetes.client.models import V1Container, V1PodSpec from flytekit import task from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import isnested, istestfunction +from flytekit.tools.translator import get_serializable_task from tests.flytekit.unit.core import tasks @@ -122,3 +125,83 @@ def foo_missing_cache_version(i: str): @task(cache_serialize=True) def foo_missing_cache(i: str): print(f"{i}") + + +def test_pod_template(): + @task( + container_image="repo/image:0.0.0", + pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + pod_spec=V1PodSpec( + containers=[ + V1Container( + name="primary", + ), + ] + ), + ), + pod_template_name="A", + ) + def func_with_pod_template(i: str): + print(i + 3) + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + ################# + # Test get_k8s_pod + ################# + + container = func_with_pod_template.get_container(default_serialization_settings) + assert container is None + + k8s_pod = func_with_pod_template.get_k8s_pod(default_serialization_settings) + + metadata = k8s_pod.metadata + assert metadata.labels == {"lKeyA": "lValA"} + assert metadata.annotations == {"aKeyA": "aValA"} + + pod_spec = k8s_pod.pod_spec + primary_container = pod_spec["containers"][0] + + assert primary_container["image"] == "repo/image:0.0.0" + assert primary_container["command"] == [] + assert primary_container["args"] == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.flytekit.unit.core.test_python_function_task", + "task-name", + "func_with_pod_template", + ] + + ################# + # Test pod_teamplte_name + ################# + assert func_with_pod_template.metadata.pod_template_name == "A" + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, func_with_pod_template) + assert ts.template.container is None + # k8s_pod content is already verified above, so only check the existence here + assert ts.template.k8s_pod is not None + assert ts.template.metadata.pod_template_name == "A" diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index fed32b63aa..a979a39b66 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -71,6 +71,7 @@ def test_task_metadata(): "0.1.1b0", "This is deprecated!", True, + "A", ) assert obj.discoverable is True @@ -82,6 +83,7 @@ def test_task_metadata(): assert obj.runtime.version == "1.0.0" assert obj.deprecated_error_message == "This is deprecated!" assert obj.discovery_version == "0.1.1b0" + assert obj.pod_template_name == "A" assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl()) @@ -134,6 +136,7 @@ def test_task_spec(): "0.1.1b0", "This is deprecated!", True, + "A", ) int_type = types.LiteralType(types.SimpleType.INTEGER) @@ -178,7 +181,7 @@ def test_task_spec(): assert obj.template == template -def test_task_template__k8s_pod_target(): +def test_task_template_k8s_pod_target(): int_type = types.LiteralType(types.SimpleType.INTEGER) obj = task.TaskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), @@ -192,6 +195,7 @@ def test_task_template__k8s_pod_target(): "1.0", "deprecated", False, + "A", ), interface_models.TypedInterface( # inputs diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index d229d0d5c9..2b5b06696b 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -38,6 +38,7 @@ def test_workflow_closure(): "0.1.1b0", "This is deprecated!", True, + "A", ) cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")