diff --git a/dev-requirements.txt b/dev-requirements.txt
index ea0e19354a..eeedd85fe8 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -8,7 +8,7 @@
# via
# -c requirements.txt
# pytest-flyte
-absl-py==1.3.0
+absl-py==1.4.0
# via
# tensorboard
# tensorflow
@@ -34,11 +34,14 @@ binaryornot==0.4.4
# cookiecutter
cached-property==1.5.2
# via docker-compose
-cachetools==5.2.0
- # via google-auth
+cachetools==5.3.0
+ # via
+ # -c requirements.txt
+ # google-auth
certifi==2022.12.7
# via
# -c requirements.txt
+ # kubernetes
# requests
cffi==1.15.1
# via
@@ -51,7 +54,7 @@ chardet==5.1.0
# via
# -c requirements.txt
# binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
# via
# -c requirements.txt
# requests
@@ -60,7 +63,7 @@ click==8.1.3
# -c requirements.txt
# cookiecutter
# flytekit
-cloudpickle==2.2.0
+cloudpickle==2.2.1
# via
# -c requirements.txt
# flytekit
@@ -70,7 +73,7 @@ cookiecutter==2.1.1
# via
# -c requirements.txt
# flytekit
-coverage[toml]==6.5.0
+coverage[toml]==7.1.0
# via
# -r dev-requirements.in
# pytest-cov
@@ -78,7 +81,7 @@ croniter==1.3.8
# via
# -c requirements.txt
# flytekit
-cryptography==38.0.4
+cryptography==39.0.0
# via
# -c requirements.txt
# paramiko
@@ -124,49 +127,58 @@ docstring-parser==0.15
# via
# -c requirements.txt
# flytekit
-exceptiongroup==1.0.4
+exceptiongroup==1.1.0
# via pytest
-filelock==3.8.2
+filelock==3.9.0
# via virtualenv
-flatbuffers==22.12.6
+flatbuffers==23.1.21
# via tensorflow
-flyteidl==1.3.1
+flyteidl==1.3.5
# via
# -c requirements.txt
# flytekit
gast==0.5.3
# via tensorflow
+gitdb==4.0.10
+ # via
+ # -c requirements.txt
+ # gitpython
+gitpython==3.1.30
+ # via
+ # -c requirements.txt
+ # flytekit
google-api-core[grpc]==2.11.0
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
# google-cloud-core
-google-auth==2.15.0
+google-auth==2.16.0
# via
+ # -c requirements.txt
# google-api-core
# google-auth-oauthlib
# google-cloud-core
+ # kubernetes
# tensorboard
google-auth-oauthlib==0.4.6
# via tensorboard
-google-cloud-bigquery==3.4.0
+google-cloud-bigquery==3.4.2
+ # via -r dev-requirements.in
+google-cloud-bigquery-storage==2.18.1
# via -r dev-requirements.in
-google-cloud-bigquery-storage==2.16.2
- # via
- # -r dev-requirements.in
- # google-cloud-bigquery
google-cloud-core==2.3.2
# via google-cloud-bigquery
google-crc32c==1.5.0
# via google-resumable-media
google-pasta==0.2.0
# via tensorflow
-google-resumable-media==2.4.0
+google-resumable-media==2.4.1
# via google-cloud-bigquery
-googleapis-common-protos==1.57.0
+googleapis-common-protos==1.58.0
# via
# -c requirements.txt
# flyteidl
+ # flytekit
# google-api-core
# grpcio-status
grpcio==1.51.1
@@ -183,15 +195,15 @@ grpcio-status==1.51.1
# -c requirements.txt
# flytekit
# google-api-core
-h5py==3.7.0
+h5py==3.8.0
# via tensorflow
-identify==2.5.9
+identify==2.5.17
# via pre-commit
idna==3.4
# via
# -c requirements.txt
# requests
-importlib-metadata==5.1.0
+importlib-metadata==6.0.0
# via
# -c requirements.txt
# click
@@ -203,7 +215,11 @@ importlib-metadata==5.1.0
# pre-commit
# pytest
# virtualenv
-iniconfig==1.1.1
+importlib-resources==5.10.2
+ # via
+ # -c requirements.txt
+ # keyring
+iniconfig==2.0.0
# via pytest
ipython==7.34.0
# via -r dev-requirements.in
@@ -242,15 +258,19 @@ keras==2.8.0
# via tensorflow
keras-preprocessing==1.1.2
# via tensorflow
-keyring==23.11.0
+keyring==23.13.1
+ # via
+ # -c requirements.txt
+ # flytekit
+kubernetes==25.3.0
# via
# -c requirements.txt
# flytekit
-libclang==14.0.6
+libclang==15.0.6.1
# via tensorflow
markdown==3.4.1
# via tensorboard
-markupsafe==2.1.1
+markupsafe==2.1.2
# via
# -c requirements.txt
# jinja2
@@ -271,7 +291,7 @@ marshmallow-jsonschema==0.13.0
# flytekit
matplotlib-inline==0.1.6
# via ipython
-mock==4.0.3
+mock==5.0.1
# via -r dev-requirements.in
more-itertools==9.0.0
# via
@@ -304,7 +324,9 @@ numpy==1.21.6
# tensorboard
# tensorflow
oauthlib==3.2.2
- # via requests-oauthlib
+ # via
+ # -c requirements.txt
+ # requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
packaging==21.3
@@ -318,7 +340,7 @@ pandas==1.3.5
# via
# -c requirements.txt
# flytekit
-paramiko==2.12.0
+paramiko==3.0.0
# via docker
parso==0.8.3
# via jedi
@@ -326,19 +348,19 @@ pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
-platformdirs==2.6.0
+platformdirs==2.6.2
# via virtualenv
pluggy==1.0.0
# via pytest
-pre-commit==2.20.0
+pre-commit==2.21.0
# via -r dev-requirements.in
prompt-toolkit==3.0.36
# via ipython
-proto-plus==1.22.1
+proto-plus==1.22.2
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
-protobuf==4.21.10
+protobuf==4.21.12
# via
# -c requirements.txt
# flyteidl
@@ -365,22 +387,24 @@ pyarrow==10.0.1
# via
# -c requirements.txt
# flytekit
- # google-cloud-bigquery
pyasn1==0.4.8
# via
+ # -c requirements.txt
# pyasn1-modules
# rsa
pyasn1-modules==0.2.8
- # via google-auth
+ # via
+ # -c requirements.txt
+ # google-auth
pycparser==2.21
# via
# -c requirements.txt
# cffi
-pygments==2.13.0
+pygments==2.14.0
# via ipython
pynacl==1.5.0
# via paramiko
-pyopenssl==22.1.0
+pyopenssl==23.0.0
# via
# -c requirements.txt
# flytekit
@@ -388,11 +412,11 @@ pyparsing==3.0.9
# via
# -c requirements.txt
# packaging
-pyrsistent==0.19.2
+pyrsistent==0.19.3
# via
# -c requirements.txt
# jsonschema
-pytest==7.2.0
+pytest==7.2.1
# via
# -r dev-requirements.in
# pytest-cov
@@ -411,14 +435,15 @@ python-dateutil==2.8.2
# croniter
# flytekit
# google-cloud-bigquery
+ # kubernetes
# pandas
-python-dotenv==0.21.0
+python-dotenv==0.21.1
# via docker-compose
python-json-logger==2.0.4
# via
# -c requirements.txt
# flytekit
-python-slugify==7.0.0
+python-slugify==8.0.0
# via
# -c requirements.txt
# cookiecutter
@@ -426,7 +451,7 @@ pytimeparse==1.1.8
# via
# -c requirements.txt
# flytekit
-pytz==2022.6
+pytz==2022.7.1
# via
# -c requirements.txt
# flytekit
@@ -437,12 +462,13 @@ pyyaml==5.4.1
# cookiecutter
# docker-compose
# flytekit
+ # kubernetes
# pre-commit
regex==2022.10.31
# via
# -c requirements.txt
# docker-image-py
-requests==2.28.1
+requests==2.28.2
# via
# -c requirements.txt
# cookiecutter
@@ -451,11 +477,15 @@ requests==2.28.1
# flytekit
# google-api-core
# google-cloud-bigquery
+ # kubernetes
# requests-oauthlib
# responses
# tensorboard
requests-oauthlib==1.3.1
- # via google-auth-oauthlib
+ # via
+ # -c requirements.txt
+ # google-auth-oauthlib
+ # kubernetes
responses==0.22.0
# via
# -c requirements.txt
@@ -465,7 +495,9 @@ retry==0.9.2
# -c requirements.txt
# flytekit
rsa==4.9
- # via google-auth
+ # via
+ # -c requirements.txt
+ # google-auth
scikit-learn==1.0.2
# via -r dev-requirements.in
scipy==1.7.3
@@ -487,10 +519,14 @@ six==1.16.0
# google-pasta
# jsonschema
# keras-preprocessing
- # paramiko
+ # kubernetes
# python-dateutil
# tensorflow
# websocket-client
+smmap==5.0.0
+ # via
+ # -c requirements.txt
+ # gitdb
sortedcontainers==2.4.0
# via
# -c requirements.txt
@@ -509,9 +545,9 @@ tensorflow==2.8.1
# via -r dev-requirements.in
tensorflow-estimator==2.8.0
# via tensorflow
-tensorflow-io-gcs-filesystem==0.28.0
+tensorflow-io-gcs-filesystem==0.30.0
# via tensorflow
-termcolor==2.1.1
+termcolor==2.2.0
# via tensorflow
text-unidecode==1.3
# via
@@ -524,16 +560,15 @@ threadpoolctl==3.1.0
toml==0.10.2
# via
# -c requirements.txt
- # pre-commit
# responses
tomli==2.0.1
# via
# coverage
# mypy
# pytest
-torch==1.13.1
+torch==1.12.1
# via -r dev-requirements.in
-traitlets==5.6.0
+traitlets==5.9.0
# via
# ipython
# matplotlib-inline
@@ -548,8 +583,10 @@ typing-extensions==4.4.0
# -c requirements.txt
# arrow
# flytekit
+ # gitpython
# importlib-metadata
# mypy
+ # platformdirs
# responses
# tensorflow
# torch
@@ -558,22 +595,24 @@ typing-inspect==0.8.0
# via
# -c requirements.txt
# dataclasses-json
-urllib3==1.26.13
+urllib3==1.26.14
# via
# -c requirements.txt
# docker
# flytekit
+ # kubernetes
# requests
# responses
virtualenv==20.17.1
# via pre-commit
-wcwidth==0.2.5
+wcwidth==0.2.6
# via prompt-toolkit
websocket-client==0.59.0
# via
# -c requirements.txt
# docker
# docker-compose
+ # kubernetes
werkzeug==2.2.2
# via tensorboard
wheel==0.38.4
@@ -588,10 +627,11 @@ wrapt==1.14.1
# deprecated
# flytekit
# tensorflow
-zipp==3.11.0
+zipp==3.12.0
# via
# -c requirements.txt
# importlib-metadata
+ # importlib-resources
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/doc-requirements.txt b/doc-requirements.txt
index 8eb39a5a1e..2eb0532253 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -6,17 +6,17 @@
#
-e file:.#egg=flytekit
# via -r doc-requirements.in
-absl-py==1.3.0
+absl-py==1.4.0
# via
# tensorboard
# tensorflow
aiosignal==1.3.1
# via ray
-alabaster==0.7.12
+alabaster==0.7.13
# via sphinx
-alembic==1.9.1
+alembic==1.9.2
# via mlflow
-altair==4.2.0
+altair==4.2.2
# via great-expectations
ansiwrap==0.8.4
# via papermill
@@ -27,10 +27,6 @@ anyio==3.6.2
# watchfiles
aplus==0.11.0
# via vaex-core
-appnope==0.1.3
- # via
- # ipykernel
- # ipython
argon2-cffi==21.3.0
# via
# jupyter-server
@@ -42,9 +38,9 @@ arrow==1.2.3
# via
# isoduration
# jinja2-time
-astroid==2.12.13
+astroid==2.14.1
# via sphinx-autoapi
-astropy==5.2
+astropy==5.2.1
# via vaex-astro
asttokens==2.2.1
# via stack-data
@@ -59,7 +55,7 @@ babel==2.11.0
# via sphinx
backcall==0.2.0
# via ipython
-beautifulsoup4==4.11.1
+beautifulsoup4==4.11.2
# via
# furo
# nbconvert
@@ -69,15 +65,17 @@ binaryornot==0.4.4
# via cookiecutter
blake3==0.3.3
# via vaex-core
-bleach==5.0.1
+bleach==6.0.0
# via nbconvert
-botocore==1.29.44
+botocore==1.29.61
# via -r doc-requirements.in
bqplot==0.12.36
- # via vaex-jupyter
+ # via
+ # ipyvolume
+ # vaex-jupyter
branca==0.6.0
# via ipyleaflet
-cachetools==5.2.0
+cachetools==5.3.0
# via
# google-auth
# vaex-server
@@ -93,13 +91,14 @@ cfgv==3.3.1
# via pre-commit
chardet==5.1.0
# via binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
# via requests
click==8.1.3
# via
# cookiecutter
# dask
# databricks-cli
+ # distributed
# flask
# flytekit
# great-expectations
@@ -108,9 +107,10 @@ click==8.1.3
# ray
# sphinx-click
# uvicorn
-cloudpickle==2.2.0
+cloudpickle==2.2.1
# via
# dask
+ # distributed
# flytekit
# mlflow
# shap
@@ -119,9 +119,7 @@ colorama==0.4.6
# via great-expectations
comm==0.1.2
# via ipykernel
-commonmark==0.9.1
- # via rich
-contourpy==1.0.6
+contourpy==1.0.7
# via matplotlib
cookiecutter==2.1.1
# via flytekit
@@ -132,19 +130,23 @@ cryptography==39.0.0
# -r doc-requirements.in
# great-expectations
# pyopenssl
+ # secretstorage
css-html-js-minify==2.5.5
# via sphinx-material
cycler==0.11.0
# via matplotlib
-dask==2022.12.1
- # via vaex-core
+dask[distributed]==2023.1.1
+ # via
+ # -r doc-requirements.in
+ # distributed
+ # vaex-core
databricks-cli==0.17.4
# via mlflow
dataclasses-json==0.5.7
# via
# dolt-integrations
# flytekit
-debugpy==1.6.4
+debugpy==1.6.6
# via ipykernel
decorator==5.1.1
# via
@@ -158,6 +160,8 @@ diskcache==5.4.0
# via flytekit
distlib==0.3.6
# via virtualenv
+distributed==2023.1.1
+ # via dask
docker==6.0.1
# via
# flytekit
@@ -178,12 +182,11 @@ doltcli==0.1.17
entrypoints==0.4
# via
# altair
- # jupyter-client
# mlflow
# papermill
executing==1.2.0
# via stack-data
-fastapi==0.88.0
+fastapi==0.89.1
# via vaex-server
fastjsonschema==2.16.2
# via nbformat
@@ -194,9 +197,9 @@ filelock==3.9.0
# virtualenv
flask==2.2.2
# via mlflow
-flatbuffers==23.1.4
+flatbuffers==23.1.21
# via tensorflow
-flyteidl==1.3.2
+flyteidl==1.3.5
# via flytekit
fonttools==4.38.0
# via matplotlib
@@ -208,27 +211,29 @@ frozenlist==1.3.3
# via
# aiosignal
# ray
-fsspec==2022.11.0
+fsspec==2023.1.0
# via
# -r doc-requirements.in
# dask
# modin
furo @ git+https://github.com/flyteorg/furo@main
# via -r doc-requirements.in
-future==0.18.2
+future==0.18.3
# via vaex-core
gast==0.5.3
# via tensorflow
gitdb==4.0.10
# via gitpython
gitpython==3.1.30
- # via mlflow
+ # via
+ # flytekit
+ # mlflow
google-api-core[grpc]==2.11.0
# via
# -r doc-requirements.in
# google-cloud-bigquery
# google-cloud-core
-google-auth==2.15.0
+google-auth==2.16.0
# via
# google-api-core
# google-auth-oauthlib
@@ -239,7 +244,7 @@ google-auth-oauthlib==0.4.6
# via tensorboard
google-cloud==0.34.0
# via -r doc-requirements.in
-google-cloud-bigquery==3.4.1
+google-cloud-bigquery==3.5.0
# via -r doc-requirements.in
google-cloud-core==2.3.2
# via google-cloud-bigquery
@@ -247,17 +252,17 @@ google-crc32c==1.5.0
# via google-resumable-media
google-pasta==0.2.0
# via tensorflow
-google-resumable-media==2.4.0
+google-resumable-media==2.4.1
# via google-cloud-bigquery
-googleapis-common-protos==1.57.1
+googleapis-common-protos==1.58.0
# via
# flyteidl
# flytekit
# google-api-core
# grpcio-status
-great-expectations==0.15.42
+great-expectations==0.15.46
# via -r doc-requirements.in
-greenlet==2.0.1
+greenlet==2.0.2
# via sqlalchemy
grpcio==1.51.1
# via
@@ -276,15 +281,17 @@ gunicorn==20.1.0
# via mlflow
h11==0.14.0
# via uvicorn
-h5py==3.7.0
+h5py==3.8.0
# via
# tensorflow
# vaex-hdf5
+heapdict==1.0.1
+ # via zict
htmlmin==0.1.12
- # via pandas-profiling
+ # via ydata-profiling
httptools==0.5.0
# via uvicorn
-identify==2.5.12
+identify==2.5.17
# via pre-commit
idna==3.4
# via
@@ -300,6 +307,7 @@ importlib-metadata==5.2.0
# flask
# flytekit
# great-expectations
+ # jupyter-client
# keyring
# markdown
# mlflow
@@ -307,7 +315,7 @@ importlib-metadata==5.2.0
# sphinx
ipydatawidgets==4.3.2
# via pythreejs
-ipykernel==6.19.4
+ipykernel==6.20.2
# via
# ipywidgets
# jupyter
@@ -319,7 +327,7 @@ ipyleaflet==0.17.2
# via vaex-jupyter
ipympl==0.9.2
# via vaex-jupyter
-ipython==8.8.0
+ipython==8.9.0
# via
# great-expectations
# ipykernel
@@ -332,12 +340,16 @@ ipython-genutils==0.2.0
# nbclassic
# notebook
# qtconsole
-ipyvolume==0.5.2
+ipyvolume==0.6.0
# via vaex-jupyter
ipyvue==1.8.0
- # via ipyvuetify
+ # via
+ # ipyvolume
+ # ipyvuetify
ipyvuetify==1.8.4
- # via vaex-jupyter
+ # via
+ # ipyvolume
+ # vaex-jupyter
ipywebrtc==0.6.0
# via ipyvolume
ipywidgets==8.0.4
@@ -359,11 +371,16 @@ jaraco-classes==3.2.3
# via keyring
jedi==0.18.2
# via ipython
+jeepney==0.8.0
+ # via
+ # keyring
+ # secretstorage
jinja2==3.1.2
# via
# altair
# branca
# cookiecutter
+ # distributed
# flask
# great-expectations
# jinja2-time
@@ -372,10 +389,10 @@ jinja2==3.1.2
# nbclassic
# nbconvert
# notebook
- # pandas-profiling
# sphinx
# sphinx-autoapi
# vaex-ml
+ # ydata-profiling
jinja2-time==0.2.0
# via cookiecutter
jmespath==1.0.1
@@ -400,7 +417,7 @@ jsonschema[format-nongpl]==4.17.3
# ray
jupyter==1.0.0
# via -r doc-requirements.in
-jupyter-client==7.4.8
+jupyter-client==8.0.2
# via
# ipykernel
# jupyter-console
@@ -411,7 +428,7 @@ jupyter-client==7.4.8
# qtconsole
jupyter-console==6.4.4
# via jupyter
-jupyter-core==5.1.2
+jupyter-core==5.2.0
# via
# jupyter-client
# jupyter-server
@@ -421,13 +438,13 @@ jupyter-core==5.1.2
# nbformat
# notebook
# qtconsole
-jupyter-events==0.5.0
+jupyter-events==0.6.3
# via jupyter-server
-jupyter-server==2.0.6
+jupyter-server==2.2.0
# via
# nbclassic
# notebook-shim
-jupyter-server-terminals==0.4.3
+jupyter-server-terminals==0.4.4
# via jupyter-server
jupyterlab-pygments==0.2.2
# via nbconvert
@@ -442,15 +459,19 @@ keyring==23.13.1
kiwisolver==1.4.4
# via matplotlib
kubernetes==25.3.0
- # via -r doc-requirements.in
+ # via
+ # -r doc-requirements.in
+ # flytekit
lazy-object-proxy==1.9.0
# via astroid
-libclang==14.0.6
+libclang==15.0.6.1
# via tensorflow
llvmlite==0.39.1
# via numba
locket==1.0.0
- # via partd
+ # via
+ # distributed
+ # partd
lxml==4.9.2
# via sphinx-material
makefun==1.15.0
@@ -462,7 +483,9 @@ markdown==3.4.1
# -r doc-requirements.in
# mlflow
# tensorboard
-markupsafe==2.1.1
+markdown-it-py==2.1.0
+ # via rich
+markupsafe==2.1.2
# via
# jinja2
# mako
@@ -478,51 +501,56 @@ marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
-matplotlib==3.6.2
+matplotlib==3.6.3
# via
# ipympl
+ # ipyvolume
# mlflow
- # pandas-profiling
# phik
# seaborn
# vaex-viz
+ # ydata-profiling
matplotlib-inline==0.1.6
# via
# ipykernel
# ipython
+mdurl==0.1.2
+ # via markdown-it-py
mistune==2.0.4
# via
# great-expectations
# nbconvert
mlflow==2.1.1
# via -r doc-requirements.in
-modin==0.18.0
+modin==0.18.1
# via -r doc-requirements.in
more-itertools==9.0.0
# via jaraco-classes
msgpack==1.0.4
- # via ray
+ # via
+ # distributed
+ # ray
multimethod==1.9.1
# via
- # pandas-profiling
# visions
+ # ydata-profiling
mypy-extensions==0.4.3
# via typing-inspect
natsort==8.2.0
# via flytekit
-nbclassic==0.4.8
+nbclassic==0.5.1
# via notebook
nbclient==0.7.2
# via
# nbconvert
# papermill
-nbconvert==7.2.7
+nbconvert==7.2.9
# via
# jupyter
# jupyter-server
# nbclassic
# notebook
-nbformat==5.7.1
+nbformat==5.7.3
# via
# great-expectations
# jupyter-server
@@ -534,11 +562,10 @@ nbformat==5.7.1
nest-asyncio==1.5.6
# via
# ipykernel
- # jupyter-client
# nbclassic
# notebook
# vaex-core
-networkx==2.8.8
+networkx==3.0
# via visions
nodeenv==1.7.0
# via pre-commit
@@ -572,7 +599,6 @@ numpy==1.23.5
# numba
# opt-einsum
# pandas
- # pandas-profiling
# pandera
# patsy
# phik
@@ -591,16 +617,28 @@ numpy==1.23.5
# vaex-core
# visions
# xarray
+ # ydata-profiling
+nvidia-cublas-cu11==11.10.3.66
+ # via
+ # nvidia-cudnn-cu11
+ # torch
+nvidia-cuda-nvrtc-cu11==11.7.99
+ # via torch
+nvidia-cuda-runtime-cu11==11.7.99
+ # via torch
+nvidia-cudnn-cu11==8.5.0.96
+ # via torch
oauthlib==3.2.2
# via
# databricks-cli
# requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
-packaging==21.3
+packaging==22.0
# via
# astropy
# dask
+ # distributed
# docker
# google-cloud-bigquery
# great-expectations
@@ -617,7 +655,7 @@ packaging==21.3
# sphinx
# statsmodels
# xarray
-pandas==1.5.2
+pandas==1.5.3
# via
# altair
# bqplot
@@ -626,7 +664,6 @@ pandas==1.5.2
# great-expectations
# mlflow
# modin
- # pandas-profiling
# pandera
# phik
# seaborn
@@ -635,7 +672,8 @@ pandas==1.5.2
# vaex-core
# visions
# xarray
-pandas-profiling==3.6.2
+ # ydata-profiling
+pandas-profiling==3.6.6
# via -r doc-requirements.in
pandera==0.13.4
# via -r doc-requirements.in
@@ -652,7 +690,7 @@ patsy==0.5.3
pexpect==4.8.0
# via ipython
phik==0.12.3
- # via pandas-profiling
+ # via ydata-profiling
pickleshare==0.7.5
# via ipython
pillow==9.4.0
@@ -667,13 +705,13 @@ platformdirs==2.6.2
# via
# jupyter-core
# virtualenv
-plotly==5.11.0
+plotly==5.13.0
# via -r doc-requirements.in
-pre-commit==2.21.0
+pre-commit==3.0.2
# via sphinx-tags
progressbar2==4.2.0
# via vaex-core
-prometheus-client==0.15.0
+prometheus-client==0.16.0
# via
# jupyter-server
# nbclassic
@@ -682,7 +720,7 @@ prompt-toolkit==3.0.36
# via
# ipython
# jupyter-console
-proto-plus==1.22.1
+proto-plus==1.22.2
# via google-cloud-bigquery
protobuf==4.21.12
# via
@@ -702,6 +740,7 @@ protoc-gen-swagger==0.1.0
# via flyteidl
psutil==5.9.4
# via
+ # distributed
# ipykernel
# modin
ptyprocess==0.7.0
@@ -731,9 +770,9 @@ pydantic==1.10.4
# via
# fastapi
# great-expectations
- # pandas-profiling
# pandera
# vaex-core
+ # ydata-profiling
pyerfa==2.0.0.1
# via astropy
pygments==2.14.0
@@ -754,7 +793,6 @@ pyparsing==3.0.9
# via
# great-expectations
# matplotlib
- # packaging
pyrsistent==0.19.3
# via jsonschema
pyspark==3.3.1
@@ -772,13 +810,13 @@ python-dateutil==2.8.2
# matplotlib
# pandas
# whylabs-client
-python-dotenv==0.21.0
+python-dotenv==0.21.1
# via uvicorn
python-json-logger==2.0.4
# via
# flytekit
# jupyter-events
-python-slugify[unidecode]==7.0.0
+python-slugify[unidecode]==8.0.0
# via
# cookiecutter
# sphinx-material
@@ -788,7 +826,7 @@ pythreejs==2.4.1
# via ipyvolume
pytimeparse==1.1.8
# via flytekit
-pytz==2022.7
+pytz==2022.7.1
# via
# babel
# flytekit
@@ -804,18 +842,19 @@ pyyaml==6.0
# astropy
# cookiecutter
# dask
+ # distributed
# flytekit
# jupyter-events
# kubernetes
# mlflow
- # pandas-profiling
# papermill
# pre-commit
# ray
# sphinx-autoapi
# uvicorn
# vaex-core
-pyzmq==24.0.1
+ # ydata-profiling
+pyzmq==25.0.0
# via
# ipykernel
# jupyter-client
@@ -833,7 +872,7 @@ ray==2.2.0
# via -r doc-requirements.in
regex==2022.10.31
# via docker-image-py
-requests==2.28.1
+requests==2.28.2
# via
# cookiecutter
# databricks-cli
@@ -845,7 +884,6 @@ requests==2.28.1
# ipyvolume
# kubernetes
# mlflow
- # pandas-profiling
# papermill
# ray
# requests-oauthlib
@@ -853,6 +891,7 @@ requests==2.28.1
# sphinx
# tensorboard
# vaex-core
+ # ydata-profiling
requests-oauthlib==1.3.1
# via
# google-auth-oauthlib
@@ -862,10 +901,14 @@ responses==0.22.0
retry==0.9.2
# via flytekit
rfc3339-validator==0.1.4
- # via jsonschema
+ # via
+ # jsonschema
+ # jupyter-events
rfc3986-validator==0.1.1
- # via jsonschema
-rich==13.0.0
+ # via
+ # jsonschema
+ # jupyter-events
+rich==13.3.1
# via vaex-core
rsa==4.9
# via google-auth
@@ -873,7 +916,7 @@ ruamel-yaml==0.17.17
# via great-expectations
ruamel-yaml-clib==0.2.7
# via ruamel-yaml
-scikit-learn==1.2.0
+scikit-learn==1.2.1
# via
# -r doc-requirements.in
# mlflow
@@ -883,13 +926,15 @@ scipy==1.9.3
# great-expectations
# imagehash
# mlflow
- # pandas-profiling
# phik
# scikit-learn
# shap
# statsmodels
+ # ydata-profiling
seaborn==0.12.2
- # via pandas-profiling
+ # via ydata-profiling
+secretstorage==3.3.3
+ # via keyring
send2trash==1.8.0
# via
# jupyter-server
@@ -923,7 +968,9 @@ sniffio==1.3.0
snowballstemmer==2.2.0
# via sphinx
sortedcontainers==2.4.0
- # via flytekit
+ # via
+ # distributed
+ # flytekit
soupsieve==2.3.2.post1
# via beautifulsoup4
sphinx==4.5.0
@@ -942,7 +989,7 @@ sphinx==4.5.0
# sphinx-prompt
# sphinx-tags
# sphinxcontrib-yt
-sphinx-autoapi==2.0.0
+sphinx-autoapi==2.0.1
# via -r doc-requirements.in
sphinx-basic-ng==1.0.0b1
# via furo
@@ -962,13 +1009,13 @@ sphinx-panels==0.6.0
# via -r doc-requirements.in
sphinx-prompt==1.5.0
# via -r doc-requirements.in
-sphinx-tags==0.1.6
+sphinx-tags==0.2.0
# via -r doc-requirements.in
-sphinxcontrib-applehelp==1.0.2
+sphinxcontrib-applehelp==1.0.4
# via sphinx
sphinxcontrib-devhelp==1.0.2
# via sphinx
-sphinxcontrib-htmlhelp==2.0.0
+sphinxcontrib-htmlhelp==2.0.1
# via sphinx
sphinxcontrib-jsmath==1.0.1
# via sphinx
@@ -992,13 +1039,15 @@ starlette==0.22.0
statsd==3.3.0
# via flytekit
statsmodels==0.13.5
- # via pandas-profiling
+ # via ydata-profiling
tabulate==0.9.0
# via
# databricks-cli
# vaex-core
tangled-up-in-unicode==0.2.0
# via visions
+tblib==1.7.0
+ # via distributed
tenacity==8.1.0
# via
# papermill
@@ -1013,7 +1062,7 @@ tensorflow==2.8.1
# via -r doc-requirements.in
tensorflow-estimator==2.8.0
# via tensorflow
-tensorflow-io-gcs-filesystem==0.29.0
+tensorflow-io-gcs-filesystem==0.30.0
# via tensorflow
termcolor==2.2.0
# via tensorflow
@@ -1037,11 +1086,13 @@ toolz==0.12.0
# via
# altair
# dask
+ # distributed
# partd
torch==1.13.1
# via -r doc-requirements.in
tornado==6.2
# via
+ # distributed
# ipykernel
# jupyter-client
# jupyter-server
@@ -1052,10 +1103,10 @@ tornado==6.2
tqdm==4.64.1
# via
# great-expectations
- # pandas-profiling
# papermill
# shap
-traitlets==5.8.0
+ # ydata-profiling
+traitlets==5.9.0
# via
# bqplot
# comm
@@ -1085,7 +1136,7 @@ traittypes==0.2.1
# ipyleaflet
# ipyvolume
typeguard==2.13.3
- # via pandas-profiling
+ # via ydata-profiling
types-toml==0.10.8.1
# via responses
typing-extensions==4.4.0
@@ -1113,9 +1164,10 @@ unidecode==1.3.6
# sphinx-autoapi
uri-template==1.2.0
# via jsonschema
-urllib3==1.26.13
+urllib3==1.26.14
# via
# botocore
+ # distributed
# docker
# flytekit
# great-expectations
@@ -1157,10 +1209,10 @@ virtualenv==20.17.1
# pre-commit
# ray
visions[type_image_path]==0.7.5
- # via pandas-profiling
+ # via ydata-profiling
watchfiles==0.18.1
# via uvicorn
-wcwidth==0.2.5
+wcwidth==0.2.6
# via prompt-toolkit
webcolors==1.12
# via jsonschema
@@ -1168,7 +1220,7 @@ webencodings==0.5.1
# via
# bleach
# tinycss2
-websocket-client==1.4.2
+websocket-client==1.5.0
# via
# docker
# jupyter-server
@@ -1183,10 +1235,12 @@ wheel==0.38.4
# via
# astunparse
# flytekit
+ # nvidia-cublas-cu11
+ # nvidia-cuda-runtime-cu11
# tensorboard
-whylabs-client==0.4.2
+whylabs-client==0.4.3
# via -r doc-requirements.in
-whylogs==1.1.20
+whylogs==1.1.24
# via -r doc-requirements.in
whylogs-sketching==3.4.1.dev3
# via whylogs
@@ -1199,11 +1253,15 @@ wrapt==1.14.1
# flytekit
# pandera
# tensorflow
-xarray==2022.12.0
+xarray==2023.1.0
# via vaex-jupyter
xyzservices==2022.9.0
# via ipyleaflet
-zipp==3.11.0
+ydata-profiling==4.0.0
+ # via pandas-profiling
+zict==2.2.0
+ # via distributed
+zipp==3.12.0
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index e028cbaab9..1748a7a17e 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -58,7 +58,7 @@
TaskMetadata - Wrapper object that allows users to specify Task
Resources - Things like CPUs/Memory, etc.
WorkflowFailurePolicy - Customizes what happens when a workflow fails.
-
+ PodTemplate - Custom PodTemplate for a task.
Dynamic and Nested Workflows
==============================
@@ -175,6 +175,7 @@
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.map_task import map_task
from flytekit.core.notification import Email, PagerDuty, Slack
+from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.reference import get_reference_entity
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index 491bed4385..2b21f7c40b 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -85,6 +85,7 @@ class TaskMetadata(object):
timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task
should be executed for. The execution will be terminated if the runtime exceeds the given timeout
(approximately)
+ pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task.
"""
cache: bool = False
@@ -94,6 +95,7 @@ class TaskMetadata(object):
deprecated: str = ""
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
+ pod_template_name: Optional[str] = None
def __post_init__(self):
if self.timeout:
@@ -127,6 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
discovery_version=self.cache_version,
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
+ pod_template_name=self.pod_template_name,
)
diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py
index 848c1d2524..d470fb54fe 100644
--- a/flytekit/core/container_task.py
+++ b/flytekit/core/container_task.py
@@ -10,6 +10,7 @@
from flytekit.models.security import Secret, SecurityContext
+# TODO: do we need pod_template here? Seems that it is a raw container not running in pods
class ContainerTask(PythonTask):
"""
This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast
diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py
new file mode 100644
index 0000000000..be8991db45
--- /dev/null
+++ b/flytekit/core/pod_template.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+from kubernetes.client.models import V1PodSpec
+
+from flytekit.exceptions import user as _user_exceptions
+
+PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
+
+
+@dataclass
+class PodTemplate(object):
+ pod_spec: V1PodSpec = V1PodSpec(containers=[])
+ primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
+ labels: Optional[Dict[str, str]] = None
+ annotations: Optional[Dict[str, str]] = None
+
+ def __post_init__(self):
+ if not self.primary_container_name:
+ raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 06133d9784..2d05df3c3d 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -4,11 +4,16 @@
import re
from abc import ABC
from types import ModuleType
-from typing import Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+
+from flyteidl.core import tasks_pb2 as _core_task
+from kubernetes.client import ApiClient
+from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements
from flytekit.configuration import ImageConfig, SerializationSettings
-from flytekit.core.base_task import PythonTask, TaskResolverMixin
+from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
+from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
@@ -18,6 +23,11 @@
from flytekit.models.security import Secret, SecurityContext
T = TypeVar("T")
+_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
+
+
+def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
+ return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
@@ -40,6 +50,8 @@ def __init__(
environment: Optional[Dict[str, str]] = None,
task_resolver: Optional[TaskResolverMixin] = None,
secret_requests: Optional[List[Secret]] = None,
+ pod_template: Optional[PodTemplate] = None,
+ pod_template_name: Optional[str] = None,
**kwargs,
):
"""
@@ -64,6 +76,8 @@ def __init__(
- `Confidant `__
- `Kube secrets `__
- `AWS Parameter store `__
+ :param pod_template: Custom PodTemplate for this task.
+ :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
sec_ctx = None
if secret_requests:
@@ -71,6 +85,11 @@ def __init__(
if not isinstance(s, Secret):
raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
sec_ctx = SecurityContext(secrets=secret_requests)
+
+ # pod_template_name overwrites the metedata.pod_template_name
+ kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata()
+ kwargs["metadata"].pod_template_name = pod_template_name
+
super().__init__(
task_type=task_type,
name=name,
@@ -98,6 +117,8 @@ def __init__(
self._task_resolver = task_resolver or default_task_resolver
self._get_command_fn = self.get_default_command
+ self.pod_template = pod_template
+
@property
def task_resolver(self) -> Optional[TaskResolverMixin]:
return self._task_resolver
@@ -157,6 +178,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return self._get_command_fn(settings)
def get_container(self, settings: SerializationSettings) -> _task_model.Container:
+ # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container
+ if self.pod_template is not None:
+ return None
+ else:
+ return self._get_container(settings)
+
+ def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {}
for elem in (settings.env, self.environment):
if elem:
@@ -179,6 +207,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
memory_limit=self.resources.limits.mem,
)
+ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]:
+ containers = self.pod_template.pod_spec.containers
+ primary_exists = False
+
+ for container in containers:
+ if container.name == self.pod_template.primary_container_name:
+ primary_exists = True
+ break
+
+ if not primary_exists:
+ # insert a placeholder primary container if it is not defined in the pod spec.
+ containers.append(V1Container(name=self.pod_template.primary_container_name))
+ final_containers = []
+ for container in containers:
+ # In the case of the primary container, we overwrite specific container attributes
+ # with the default values used in the regular Python task.
+ # The attributes include: image, command, args, resource, and env (env is unioned)
+ if container.name == self.pod_template.primary_container_name:
+ sdk_default_container = self._get_container(settings)
+ container.image = sdk_default_container.image
+ # clear existing commands
+ container.command = sdk_default_container.command
+ # also clear existing args
+ container.args = sdk_default_container.args
+ limits, requests = {}, {}
+ for resource in sdk_default_container.resources.limits:
+ limits[_sanitize_resource_name(resource)] = resource.value
+ for resource in sdk_default_container.resources.requests:
+ requests[_sanitize_resource_name(resource)] = resource.value
+ resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
+ if len(limits) > 0 or len(requests) > 0:
+ # Important! Only copy over resource requirements if they are non-empty.
+ container.resources = resource_requirements
+ container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + (
+ container.env or []
+ )
+ final_containers.append(container)
+ self.pod_template.pod_spec.containers = final_containers
+
+ return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec)
+
+ def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
+ if self.pod_template is None:
+ return None
+ return _task_model.K8sPod(
+ pod_spec=self._serialize_pod_spec(settings),
+ metadata=_task_model.K8sObjectMetadata(
+ labels=self.pod_template.labels,
+ annotations=self.pod_template.annotations,
+ ),
+ )
+
+ # need to call super in all its children tasks
+ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
+ if self.pod_template is None:
+ return {}
+ return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}
+
class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
diff --git a/flytekit/core/task.py b/flytekit/core/task.py
index bb24181338..28c5b5def7 100644
--- a/flytekit/core/task.py
+++ b/flytekit/core/task.py
@@ -4,6 +4,7 @@
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
+from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
@@ -92,6 +93,8 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
+ pod_template: Optional[PodTemplate] = None,
+ pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
This is the core decorator to use for any task type in flytekit.
@@ -182,6 +185,8 @@ def foo2():
:param task_resolver: Provide a custom task resolver.
:param disable_deck: If true, this task will not output deck html file
:param docs: Documentation about this task
+ :param pod_template: Custom PodTemplate for this task.
+ :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
def wrapper(fn) -> PythonFunctionTask:
@@ -208,6 +213,8 @@ def wrapper(fn) -> PythonFunctionTask:
task_resolver=task_resolver,
disable_deck=disable_deck,
docs=docs,
+ pod_template=pod_template,
+ pod_template_name=pod_template_name,
)
update_wrapper(task_instance, fn)
return task_instance
diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index 2129cdd88f..fc79c87a2d 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -177,6 +177,7 @@ def __init__(
discovery_version,
deprecated_error_message,
cache_serializable,
+ pod_template_name,
):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
@@ -196,6 +197,7 @@ def __init__(
receive deprecation warnings.
:param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a
single instance over identical inputs is executed, other concurrent executions wait for the cached results.
+ :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
self._discoverable = discoverable
self._runtime = runtime
@@ -205,6 +207,7 @@ def __init__(
self._discovery_version = discovery_version
self._deprecated_error_message = deprecated_error_message
self._cache_serializable = cache_serializable
+ self._pod_template_name = pod_template_name
@property
def discoverable(self):
@@ -274,6 +277,14 @@ def cache_serializable(self):
"""
return self._cache_serializable
+ @property
+ def pod_template_name(self):
+ """
+ The name of the existing PodTemplate resource which will be used in this task.
+ :rtype: Text
+ """
+ return self._pod_template_name
+
def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.task_pb2.TaskMetadata
@@ -286,6 +297,7 @@ def to_flyte_idl(self):
discovery_version=self.discovery_version,
deprecated_error_message=self.deprecated_error_message,
cache_serializable=self.cache_serializable,
+ pod_template_name=self.pod_template_name,
)
if self.timeout:
tm.timeout.FromTimedelta(self.timeout)
@@ -306,6 +318,7 @@ def from_flyte_idl(cls, pb2_object):
discovery_version=pb2_object.discovery_version,
deprecated_error_message=pb2_object.deprecated_error_message,
cache_serializable=pb2_object.cache_serializable,
+ pod_template_name=pb2_object.pod_template_name,
)
diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
index 8787e70011..0e67b2e50b 100644
--- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
+++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
@@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
# Parameters in taskTemplate config will be used to create aws job definition.
# More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
- return {"platformCapabilities": self._task_config.platformCapabilities}
+ return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities}
def get_command(self, settings: SerializationSettings) -> List[str]:
container_args = [
diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py
index 04f821ccf3..6c160c2690 100644
--- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py
+++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py
@@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]:
return self._config_task_instance.get_k8s_pod(settings)
def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]:
- return self._config_task_instance.get_config(settings)
+ return {**super().get_config(settings), **self._config_task_instance.get_config(settings)}
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)
diff --git a/requirements-spark2.txt b/requirements-spark2.txt
index c6d0ff7fc0..55bd796722 100644
--- a/requirements-spark2.txt
+++ b/requirements-spark2.txt
@@ -16,25 +16,29 @@ attrs==20.3.0
# jsonschema
binaryornot==0.4.4
# via cookiecutter
+cachetools==5.3.0
+ # via google-auth
certifi==2022.12.7
- # via requests
+ # via
+ # kubernetes
+ # requests
cffi==1.15.1
# via cryptography
chardet==5.1.0
# via binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
# via requests
click==8.1.3
# via
# cookiecutter
# flytekit
-cloudpickle==2.2.0
+cloudpickle==2.2.1
# via flytekit
cookiecutter==2.1.1
# via flytekit
croniter==1.3.8
# via flytekit
-cryptography==38.0.4
+cryptography==39.0.0
# via
# pyopenssl
# secretstorage
@@ -52,11 +56,18 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
-flyteidl==1.3.1
+flyteidl==1.3.5
# via flytekit
-googleapis-common-protos==1.57.0
+gitdb==4.0.10
+ # via gitpython
+gitpython==3.1.30
+ # via flytekit
+google-auth==2.16.0
+ # via kubernetes
+googleapis-common-protos==1.58.0
# via
# flyteidl
+ # flytekit
# grpcio-status
grpcio==1.51.1
# via
@@ -66,12 +77,14 @@ grpcio-status==1.51.1
# via flytekit
idna==3.4
# via requests
-importlib-metadata==5.1.0
+importlib-metadata==6.0.0
# via
# click
# flytekit
# jsonschema
# keyring
+importlib-resources==5.10.2
+ # via keyring
jaraco-classes==3.2.3
# via keyring
jeepney==0.8.0
@@ -88,9 +101,11 @@ joblib==1.2.0
# via flytekit
jsonschema==3.2.0
# via -r requirements.in
-keyring==23.11.0
+keyring==23.13.1
# via flytekit
-markupsafe==2.1.1
+kubernetes==25.3.0
+ # via flytekit
+markupsafe==2.1.2
# via jinja2
marshmallow==3.19.0
# via
@@ -113,15 +128,18 @@ numpy==1.21.6
# flytekit
# pandas
# pyarrow
+oauthlib==3.2.2
+ # via requests-oauthlib
packaging==21.3
# via
+ # -r requirements.in
# docker
# marshmallow
pandas==1.3.5
# via
# -r requirements.in
# flytekit
-protobuf==4.21.10
+protobuf==4.21.12
# via
# flyteidl
# googleapis-common-protos
@@ -133,27 +151,34 @@ py==1.11.0
# via retry
pyarrow==10.0.1
# via flytekit
+pyasn1==0.4.8
+ # via
+ # pyasn1-modules
+ # rsa
+pyasn1-modules==0.2.8
+ # via google-auth
pycparser==2.21
# via cffi
-pyopenssl==22.1.0
+pyopenssl==23.0.0
# via flytekit
pyparsing==3.0.9
# via packaging
-pyrsistent==0.19.2
+pyrsistent==0.19.3
# via jsonschema
python-dateutil==2.8.2
# via
# arrow
# croniter
# flytekit
+ # kubernetes
# pandas
python-json-logger==2.0.4
# via flytekit
-python-slugify==7.0.0
+python-slugify==8.0.0
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
-pytz==2022.6
+pytz==2022.7.1
# via
# flytekit
# pandas
@@ -162,27 +187,38 @@ pyyaml==5.4.1
# -r requirements.in
# cookiecutter
# flytekit
+ # kubernetes
regex==2022.10.31
# via docker-image-py
-requests==2.28.1
+requests==2.28.2
# via
# cookiecutter
# docker
# flytekit
+ # kubernetes
+ # requests-oauthlib
# responses
+requests-oauthlib==1.3.1
+ # via kubernetes
responses==0.22.0
# via flytekit
retry==0.9.2
# via flytekit
+rsa==4.9
+ # via google-auth
secretstorage==3.3.3
# via keyring
singledispatchmethod==1.0
# via flytekit
six==1.16.0
# via
+ # google-auth
# jsonschema
+ # kubernetes
# python-dateutil
# websocket-client
+smmap==5.0.0
+ # via gitdb
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
@@ -197,29 +233,34 @@ typing-extensions==4.4.0
# via
# arrow
# flytekit
+ # gitpython
# importlib-metadata
# responses
# typing-inspect
typing-inspect==0.8.0
# via dataclasses-json
-urllib3==1.26.13
+urllib3==1.26.14
# via
# docker
# flytekit
+ # kubernetes
# requests
# responses
websocket-client==0.59.0
# via
# -r requirements.in
# docker
+ # kubernetes
wheel==0.38.4
# via flytekit
wrapt==1.14.1
# via
# deprecated
# flytekit
-zipp==3.11.0
- # via importlib-metadata
+zipp==3.12.0
+ # via
+ # importlib-metadata
+ # importlib-resources
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/requirements.txt b/requirements.txt
index 22d31976ed..53085a92c0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
#
-# This file is autogenerated by pip-compile with python 3.9
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.7
+# by the following command:
#
# make requirements.txt
#
@@ -14,25 +14,29 @@ attrs==20.3.0
# jsonschema
binaryornot==0.4.4
# via cookiecutter
+cachetools==5.3.0
+ # via google-auth
certifi==2022.12.7
- # via requests
+ # via
+ # kubernetes
+ # requests
cffi==1.15.1
# via cryptography
chardet==5.1.0
# via binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
# via requests
click==8.1.3
# via
# cookiecutter
# flytekit
-cloudpickle==2.2.0
+cloudpickle==2.2.1
# via flytekit
cookiecutter==2.1.1
# via flytekit
croniter==1.3.8
# via flytekit
-cryptography==38.0.4
+cryptography==39.0.0
# via
# pyopenssl
# secretstorage
@@ -50,11 +54,18 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
-flyteidl==1.3.1
+flyteidl==1.3.5
+ # via flytekit
+gitdb==4.0.10
+ # via gitpython
+gitpython==3.1.30
# via flytekit
-googleapis-common-protos==1.57.0
+google-auth==2.16.0
+ # via kubernetes
+googleapis-common-protos==1.58.0
# via
# flyteidl
+ # flytekit
# grpcio-status
grpcio==1.51.1
# via
@@ -64,10 +75,14 @@ grpcio-status==1.51.1
# via flytekit
idna==3.4
# via requests
-importlib-metadata==5.1.0
+importlib-metadata==6.0.0
# via
+ # click
# flytekit
+ # jsonschema
# keyring
+importlib-resources==5.10.2
+ # via keyring
jaraco-classes==3.2.3
# via keyring
jeepney==0.8.0
@@ -84,9 +99,11 @@ joblib==1.2.0
# via flytekit
jsonschema==3.2.0
# via -r requirements.in
-keyring==23.11.0
+keyring==23.13.1
# via flytekit
-markupsafe==2.1.1
+kubernetes==25.3.0
+ # via flytekit
+markupsafe==2.1.2
# via jinja2
marshmallow==3.19.0
# via
@@ -106,8 +123,11 @@ natsort==8.2.0
numpy==1.21.6
# via
# -r requirements.in
+ # flytekit
# pandas
# pyarrow
+oauthlib==3.2.2
+ # via requests-oauthlib
packaging==21.3
# via
# -r requirements.in
@@ -117,7 +137,7 @@ pandas==1.3.5
# via
# -r requirements.in
# flytekit
-protobuf==4.21.10
+protobuf==4.21.12
# via
# flyteidl
# googleapis-common-protos
@@ -129,27 +149,34 @@ py==1.11.0
# via retry
pyarrow==10.0.1
# via flytekit
+pyasn1==0.4.8
+ # via
+ # pyasn1-modules
+ # rsa
+pyasn1-modules==0.2.8
+ # via google-auth
pycparser==2.21
# via cffi
-pyopenssl==22.1.0
+pyopenssl==23.0.0
# via flytekit
pyparsing==3.0.9
# via packaging
-pyrsistent==0.19.2
+pyrsistent==0.19.3
# via jsonschema
python-dateutil==2.8.2
# via
# arrow
# croniter
# flytekit
+ # kubernetes
# pandas
python-json-logger==2.0.4
# via flytekit
-python-slugify==7.0.0
+python-slugify==8.0.0
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
-pytz==2022.6
+pytz==2022.7.1
# via
# flytekit
# pandas
@@ -158,27 +185,38 @@ pyyaml==5.4.1
# -r requirements.in
# cookiecutter
# flytekit
+ # kubernetes
regex==2022.10.31
# via docker-image-py
-requests==2.28.1
+requests==2.28.2
# via
# cookiecutter
# docker
# flytekit
+ # kubernetes
+ # requests-oauthlib
# responses
+requests-oauthlib==1.3.1
+ # via kubernetes
responses==0.22.0
# via flytekit
retry==0.9.2
# via flytekit
+rsa==4.9
+ # via google-auth
secretstorage==3.3.3
# via keyring
singledispatchmethod==1.0
# via flytekit
six==1.16.0
# via
+ # google-auth
# jsonschema
+ # kubernetes
# python-dateutil
# websocket-client
+smmap==5.0.0
+ # via gitdb
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
@@ -191,28 +229,36 @@ types-toml==0.10.8.1
# via responses
typing-extensions==4.4.0
# via
+ # arrow
# flytekit
+ # gitpython
+ # importlib-metadata
+ # responses
# typing-inspect
typing-inspect==0.8.0
# via dataclasses-json
-urllib3==1.26.13
+urllib3==1.26.14
# via
# docker
# flytekit
+ # kubernetes
# requests
# responses
websocket-client==0.59.0
# via
# -r requirements.in
# docker
+ # kubernetes
wheel==0.38.4
# via flytekit
wrapt==1.14.1
# via
# deprecated
# flytekit
-zipp==3.11.0
- # via importlib-metadata
+zipp==3.12.0
+ # via
+ # importlib-metadata
+ # importlib-resources
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/setup.py b/setup.py
index b42f8a8490..f560e6c112 100644
--- a/setup.py
+++ b/setup.py
@@ -40,7 +40,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
- "flyteidl>=1.3.0,<1.4.0",
+ "flyteidl>=1.3.5,<1.4.0",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
@@ -83,6 +83,7 @@
# aliases. More details in https://github.com/flyteorg/flyte/issues/3166
"numpy<1.24.0",
"gitpython",
+ "kubernetes>=12.0.1",
],
extras_require=extras_require,
scripts=[
diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py
index 4b48fbcfb9..d5b07fe420 100644
--- a/tests/flytekit/common/parameterizers.py
+++ b/tests/flytekit/common/parameterizers.py
@@ -121,8 +121,9 @@
discovery_version,
deprecated,
cache_serializable,
+ pod_template_name,
)
- for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable in product(
+ for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product(
[True, False],
LIST_OF_RUNTIME_METADATA,
[timedelta(days=i) for i in range(3)],
@@ -131,6 +132,7 @@
["1.0"],
["deprecated"],
[True, False],
+ ["A", "B"],
)
]
diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt
index b8a781224a..7bd27f438b 100644
--- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt
+++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt
@@ -14,19 +14,19 @@ cffi==1.15.1
# via cryptography
chardet==5.1.0
# via binaryornot
-charset-normalizer==2.1.1
+charset-normalizer==3.0.1
# via requests
click==8.1.3
# via
# cookiecutter
# flytekit
-cloudpickle==2.2.0
+cloudpickle==2.2.1
# via flytekit
cookiecutter==2.1.1
# via flytekit
croniter==1.3.8
# via flytekit
-cryptography==38.0.4
+cryptography==39.0.0
# via
# pyopenssl
# secretstorage
@@ -46,29 +46,36 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
-flyteidl==1.2.5
+flyteidl==1.3.5
# via flytekit
-flytekit==1.2.5
+flytekit==1.3.1
# via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in
fonttools==4.38.0
# via matplotlib
-googleapis-common-protos==1.57.0
+gitdb==4.0.10
+ # via gitpython
+gitpython==3.1.30
+ # via flytekit
+googleapis-common-protos==1.58.0
# via
# flyteidl
+ # flytekit
# grpcio-status
-grpcio==1.48.2
+grpcio==1.51.1
# via
# flytekit
# grpcio-status
-grpcio-status==1.48.2
+grpcio-status==1.51.1
# via flytekit
idna==3.4
# via requests
-importlib-metadata==5.1.0
+importlib-metadata==6.0.0
# via
# click
# flytekit
# keyring
+importlib-resources==5.10.2
+ # via keyring
jaraco-classes==3.2.3
# via keyring
jeepney==0.8.0
@@ -85,11 +92,11 @@ joblib==1.2.0
# via
# -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in
# flytekit
-keyring==23.11.0
+keyring==23.13.1
# via flytekit
kiwisolver==1.4.4
# via matplotlib
-markupsafe==2.1.1
+markupsafe==2.1.2
# via jinja2
marshmallow==3.19.0
# via
@@ -108,28 +115,27 @@ mypy-extensions==0.4.3
# via typing-inspect
natsort==8.2.0
# via flytekit
-numpy==1.22.0
+numpy==1.21.6
# via
# flytekit
# matplotlib
# opencv-python
# pandas
# pyarrow
-opencv-python==4.6.0.66
+opencv-python==4.7.0.68
# via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in
-packaging==21.3
+packaging==23.0
# via
# docker
# marshmallow
# matplotlib
pandas==1.3.5
# via flytekit
-pillow==9.3.0
+pillow==9.4.0
# via matplotlib
-protobuf==3.20.3
+protobuf==4.21.12
# via
# flyteidl
- # flytekit
# googleapis-common-protos
# grpcio-status
# protoc-gen-swagger
@@ -141,12 +147,10 @@ pyarrow==10.0.1
# via flytekit
pycparser==2.21
# via cffi
-pyopenssl==22.1.0
+pyopenssl==23.0.0
# via flytekit
pyparsing==3.0.9
- # via
- # matplotlib
- # packaging
+ # via matplotlib
python-dateutil==2.8.2
# via
# arrow
@@ -156,11 +160,11 @@ python-dateutil==2.8.2
# pandas
python-json-logger==2.0.4
# via flytekit
-python-slugify==7.0.0
+python-slugify==8.0.0
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
-pytz==2022.6
+pytz==2022.7.1
# via
# flytekit
# pandas
@@ -170,7 +174,7 @@ pyyaml==6.0
# flytekit
regex==2022.10.31
# via docker-image-py
-requests==2.28.1
+requests==2.28.2
# via
# cookiecutter
# docker
@@ -185,9 +189,9 @@ secretstorage==3.3.3
singledispatchmethod==1.0
# via flytekit
six==1.16.0
- # via
- # grpcio
- # python-dateutil
+ # via python-dateutil
+smmap==5.0.0
+ # via gitdb
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
@@ -202,19 +206,20 @@ typing-extensions==4.4.0
# via
# arrow
# flytekit
+ # gitpython
# importlib-metadata
# kiwisolver
# responses
# typing-inspect
typing-inspect==0.8.0
# via dataclasses-json
-urllib3==1.26.13
+urllib3==1.26.14
# via
# docker
# flytekit
# requests
# responses
-websocket-client==1.4.2
+websocket-client==1.5.0
# via docker
wheel==0.38.4
# via
@@ -224,5 +229,7 @@ wrapt==1.14.1
# via
# deprecated
# flytekit
-zipp==3.11.0
- # via importlib-metadata
+zipp==3.12.0
+ # via
+ # importlib-metadata
+ # importlib-resources
diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py
index 108b42ebf7..24856432b4 100644
--- a/tests/flytekit/unit/core/test_python_auto_container.py
+++ b/tests/flytekit/unit/core/test_python_auto_container.py
@@ -1,9 +1,14 @@
from typing import Any
import pytest
+from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1Volume
from flytekit.configuration import Image, ImageConfig, SerializationSettings
+from flytekit.core.base_task import TaskMetadata
+from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image
+from flytekit.core.resources import Resources
+from flytekit.tools.translator import get_serializable_task
@pytest.fixture
@@ -36,7 +41,6 @@ def execute(self, **kwargs) -> Any:
task = DummyAutoContainerTask(name="x", task_config=None, task_type="t")
-task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t")
def test_default_command(default_serialization_settings):
@@ -68,14 +72,227 @@ def test_get_container(default_serialization_settings):
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"FOO": "bar"}
+ ts = get_serializable_task(default_serialization_settings, task)
+ assert ts.template.container.image == "docker.io/xyz:some-git-hash"
+ assert ts.template.container.env == {"FOO": "bar"}
+
+
+task_with_env_vars = DummyAutoContainerTask(name="x", environment={"HAM": "spam"}, task_config=None, task_type="t")
+
def test_get_container_with_task_envvars(default_serialization_settings):
c = task_with_env_vars.get_container(default_serialization_settings)
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"FOO": "bar", "HAM": "spam"}
+ ts = get_serializable_task(default_serialization_settings, task_with_env_vars)
+ assert ts.template.container.image == "docker.io/xyz:some-git-hash"
+ assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"}
+
def test_get_container_without_serialization_settings_envvars(minimal_serialization_settings):
c = task_with_env_vars.get_container(minimal_serialization_settings)
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"HAM": "spam"}
+
+ ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars)
+ assert ts.template.container.image == "docker.io/xyz:some-git-hash"
+ assert ts.template.container.env == {"HAM": "spam"}
+
+
+task_with_pod_template = DummyAutoContainerTask(
+ name="x",
+ metadata=TaskMetadata(
+ pod_template_name="podTemplateB", # should be overwritten
+ retries=3, # ensure other fields still exists
+ ),
+ task_config=None,
+ task_type="t",
+ container_image="repo/image:0.0.0",
+ requests=Resources(cpu="3", gpu="1"),
+ limits=Resources(cpu="6", gpu="2"),
+ environment={"eKeyA": "eValA", "eKeyB": "vKeyB"},
+ pod_template=PodTemplate(
+ primary_container_name="primary",
+ labels={"lKeyA": "lValA", "lKeyB": "lValB"},
+ annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
+ pod_spec=V1PodSpec(
+ containers=[
+ V1Container(
+ name="notPrimary",
+ ),
+ V1Container(
+ name="primary",
+ image="repo/placeholderImage:0.0.0",
+ command="placeholderCommand",
+ args="placeholderArgs",
+ resources=V1ResourceRequirements(limits={"cpu": "999", "gpu": "999"}),
+ env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
+ ),
+ ],
+ volumes=[V1Volume(name="volume")],
+ ),
+ ),
+ pod_template_name="podTemplateA",
+)
+
+
+def test_pod_template(default_serialization_settings):
+ #################
+ # Test get_k8s_pod
+ #################
+
+ container = task_with_pod_template.get_container(default_serialization_settings)
+ assert container is None
+
+ k8s_pod = task_with_pod_template.get_k8s_pod(default_serialization_settings)
+
+ # labels/annotations should be passed
+ metadata = k8s_pod.metadata
+ assert metadata.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
+ assert metadata.annotations == {"aKeyA": "aValA", "aKeyB": "aValB"}
+
+ pod_spec = k8s_pod.pod_spec
+ primary_container = pod_spec["containers"][1]
+
+ # To test overwritten attributes
+
+ # image
+ assert primary_container["image"] == "repo/image:0.0.0"
+ # command
+ assert primary_container["command"] == []
+ # args
+ assert primary_container["args"] == [
+ "pyflyte-execute",
+ "--inputs",
+ "{{.input}}",
+ "--output-prefix",
+ "{{.outputPrefix}}",
+ "--raw-output-data-prefix",
+ "{{.rawOutputDataPrefix}}",
+ "--checkpoint-path",
+ "{{.checkpointOutputPrefix}}",
+ "--prev-checkpoint",
+ "{{.prevCheckpointPrefix}}",
+ "--resolver",
+ "flytekit.core.python_auto_container.default_task_resolver",
+ "--",
+ "task-module",
+ "tests.flytekit.unit.core.test_python_auto_container",
+ "task-name",
+ "task_with_pod_template",
+ ]
+ # resource
+ assert primary_container["resources"]["requests"] == {"cpu": "3", "gpu": "1"}
+ assert primary_container["resources"]["limits"] == {"cpu": "6", "gpu": "2"}
+
+ # To test union attributes
+ assert primary_container["env"] == [
+ {"name": "FOO", "value": "bar"},
+ {"name": "eKeyA", "value": "eValA"},
+ {"name": "eKeyB", "value": "vKeyB"},
+ {"name": "eKeyC", "value": "eValC"},
+ {"name": "eKeyD", "value": "eValD"},
+ ]
+
+ # To test not overwritten attributes
+ assert pod_spec["volumes"][0] == {"name": "volume"}
+
+ #################
+ # Test pod_template_name
+ #################
+ assert task_with_pod_template.metadata.pod_template_name == "podTemplateA"
+ assert task_with_pod_template.metadata.retries == 3
+
+ config = task_with_minimum_pod_template.get_config(default_serialization_settings)
+
+ #################
+ # Test config
+ #################
+ assert config == {"primary_container_name": "primary"}
+
+ #################
+ # Test Serialization
+ #################
+ ts = get_serializable_task(default_serialization_settings, task_with_pod_template)
+ assert ts.template.container is None
+ # k8s_pod content is already verified above, so only check the existence here
+ assert ts.template.k8s_pod is not None
+
+ assert ts.template.metadata.pod_template_name == "podTemplateA"
+ assert ts.template.metadata.retries.retries == 3
+ assert ts.template.config is not None
+
+
+task_with_minimum_pod_template = DummyAutoContainerTask(
+ name="x",
+ task_config=None,
+ task_type="t",
+ container_image="repo/image:0.0.0",
+ pod_template=PodTemplate(
+ primary_container_name="primary",
+ labels={"lKeyA": "lValA"},
+ annotations={"aKeyA": "aValA"},
+ ),
+ pod_template_name="A",
+)
+
+
+def test_minimum_pod_template(default_serialization_settings):
+
+ #################
+ # Test get_k8s_pod
+ #################
+
+ container = task_with_minimum_pod_template.get_container(default_serialization_settings)
+ assert container is None
+
+ k8s_pod = task_with_minimum_pod_template.get_k8s_pod(default_serialization_settings)
+
+ metadata = k8s_pod.metadata
+ assert metadata.labels == {"lKeyA": "lValA"}
+ assert metadata.annotations == {"aKeyA": "aValA"}
+
+ pod_spec = k8s_pod.pod_spec
+ primary_container = pod_spec["containers"][0]
+
+ assert primary_container["image"] == "repo/image:0.0.0"
+ assert primary_container["command"] == []
+ assert primary_container["args"] == [
+ "pyflyte-execute",
+ "--inputs",
+ "{{.input}}",
+ "--output-prefix",
+ "{{.outputPrefix}}",
+ "--raw-output-data-prefix",
+ "{{.rawOutputDataPrefix}}",
+ "--checkpoint-path",
+ "{{.checkpointOutputPrefix}}",
+ "--prev-checkpoint",
+ "{{.prevCheckpointPrefix}}",
+ "--resolver",
+ "flytekit.core.python_auto_container.default_task_resolver",
+ "--",
+ "task-module",
+ "tests.flytekit.unit.core.test_python_auto_container",
+ "task-name",
+ "task_with_minimum_pod_template",
+ ]
+
+ config = task_with_minimum_pod_template.get_config(default_serialization_settings)
+ assert config == {"primary_container_name": "primary"}
+
+ #################
+ # Test pod_teamplte_name
+ #################
+ assert task_with_minimum_pod_template.metadata.pod_template_name == "A"
+
+ #################
+ # Test Serialization
+ #################
+ ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template)
+ assert ts.template.container is None
+ # k8s_pod content is already verified above, so only check the existence here
+ assert ts.template.k8s_pod is not None
+ assert ts.template.metadata.pod_template_name == "A"
+ assert ts.template.config is not None
diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py
index 34aaefaeb3..02e04a302f 100644
--- a/tests/flytekit/unit/core/test_python_function_task.py
+++ b/tests/flytekit/unit/core/test_python_function_task.py
@@ -1,10 +1,13 @@
import pytest
+from kubernetes.client.models import V1Container, V1PodSpec
from flytekit import task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
+from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.tracker import isnested, istestfunction
+from flytekit.tools.translator import get_serializable_task
from tests.flytekit.unit.core import tasks
@@ -122,3 +125,83 @@ def foo_missing_cache_version(i: str):
@task(cache_serialize=True)
def foo_missing_cache(i: str):
print(f"{i}")
+
+
+def test_pod_template():
+ @task(
+ container_image="repo/image:0.0.0",
+ pod_template=PodTemplate(
+ primary_container_name="primary",
+ labels={"lKeyA": "lValA"},
+ annotations={"aKeyA": "aValA"},
+ pod_spec=V1PodSpec(
+ containers=[
+ V1Container(
+ name="primary",
+ ),
+ ]
+ ),
+ ),
+ pod_template_name="A",
+ )
+ def func_with_pod_template(i: str):
+ print(i + 3)
+
+ default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash")
+ default_image_config = ImageConfig(default_image=default_image)
+ default_serialization_settings = SerializationSettings(
+ project="p", domain="d", version="v", image_config=default_image_config
+ )
+
+ #################
+ # Test get_k8s_pod
+ #################
+
+ container = func_with_pod_template.get_container(default_serialization_settings)
+ assert container is None
+
+ k8s_pod = func_with_pod_template.get_k8s_pod(default_serialization_settings)
+
+ metadata = k8s_pod.metadata
+ assert metadata.labels == {"lKeyA": "lValA"}
+ assert metadata.annotations == {"aKeyA": "aValA"}
+
+ pod_spec = k8s_pod.pod_spec
+ primary_container = pod_spec["containers"][0]
+
+ assert primary_container["image"] == "repo/image:0.0.0"
+ assert primary_container["command"] == []
+ assert primary_container["args"] == [
+ "pyflyte-execute",
+ "--inputs",
+ "{{.input}}",
+ "--output-prefix",
+ "{{.outputPrefix}}",
+ "--raw-output-data-prefix",
+ "{{.rawOutputDataPrefix}}",
+ "--checkpoint-path",
+ "{{.checkpointOutputPrefix}}",
+ "--prev-checkpoint",
+ "{{.prevCheckpointPrefix}}",
+ "--resolver",
+ "flytekit.core.python_auto_container.default_task_resolver",
+ "--",
+ "task-module",
+ "tests.flytekit.unit.core.test_python_function_task",
+ "task-name",
+ "func_with_pod_template",
+ ]
+
+ #################
+ # Test pod_teamplte_name
+ #################
+ assert func_with_pod_template.metadata.pod_template_name == "A"
+
+ #################
+ # Test Serialization
+ #################
+ ts = get_serializable_task(default_serialization_settings, func_with_pod_template)
+ assert ts.template.container is None
+ # k8s_pod content is already verified above, so only check the existence here
+ assert ts.template.k8s_pod is not None
+ assert ts.template.metadata.pod_template_name == "A"
diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py
index fed32b63aa..a979a39b66 100644
--- a/tests/flytekit/unit/models/test_tasks.py
+++ b/tests/flytekit/unit/models/test_tasks.py
@@ -71,6 +71,7 @@ def test_task_metadata():
"0.1.1b0",
"This is deprecated!",
True,
+ "A",
)
assert obj.discoverable is True
@@ -82,6 +83,7 @@ def test_task_metadata():
assert obj.runtime.version == "1.0.0"
assert obj.deprecated_error_message == "This is deprecated!"
assert obj.discovery_version == "0.1.1b0"
+ assert obj.pod_template_name == "A"
assert obj == task.TaskMetadata.from_flyte_idl(obj.to_flyte_idl())
@@ -134,6 +136,7 @@ def test_task_spec():
"0.1.1b0",
"This is deprecated!",
True,
+ "A",
)
int_type = types.LiteralType(types.SimpleType.INTEGER)
@@ -178,7 +181,7 @@ def test_task_spec():
assert obj.template == template
-def test_task_template__k8s_pod_target():
+def test_task_template_k8s_pod_target():
int_type = types.LiteralType(types.SimpleType.INTEGER)
obj = task.TaskTemplate(
identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
@@ -192,6 +195,7 @@ def test_task_template__k8s_pod_target():
"1.0",
"deprecated",
False,
+ "A",
),
interface_models.TypedInterface(
# inputs
diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py
index d229d0d5c9..2b5b06696b 100644
--- a/tests/flytekit/unit/models/test_workflow_closure.py
+++ b/tests/flytekit/unit/models/test_workflow_closure.py
@@ -38,6 +38,7 @@ def test_workflow_closure():
"0.1.1b0",
"This is deprecated!",
True,
+ "A",
)
cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")