[![PyPI version fury.io](https://badge.fury.io/py/flytekit.svg)](https://pypi.python.org/pypi/flytekit/)
[![PyPI download day](https://img.shields.io/pypi/dd/flytekit.svg)](https://pypi.python.org/pypi/flytekit/)
diff --git a/dev-requirements.in b/dev-requirements.in
index 09f3b90c46..b2e6cf74a1 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -11,3 +11,4 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
+torch
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 419503beff..3705a578c0 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -28,13 +28,13 @@ binaryornot==0.4.4
# cookiecutter
cached-property==1.5.2
# via docker-compose
-cachetools==5.1.0
+cachetools==5.2.0
# via google-auth
-certifi==2022.5.18.1
+certifi==2022.6.15
# via
# -c requirements.txt
# requests
-cffi==1.15.0
+cffi==1.15.1
# via
# -c requirements.txt
# bcrypt
@@ -42,11 +42,11 @@ cffi==1.15.0
# pynacl
cfgv==3.3.1
# via pre-commit
-chardet==4.0.0
+chardet==5.0.0
# via
# -c requirements.txt
# binaryornot
-charset-normalizer==2.0.12
+charset-normalizer==2.1.0
# via
# -c requirements.txt
# requests
@@ -61,20 +61,21 @@ cloudpickle==2.1.0
# flytekit
codespell==2.1.0
# via -r dev-requirements.in
-cookiecutter==1.7.3
+cookiecutter==2.1.1
# via
# -c requirements.txt
# flytekit
-coverage[toml]==6.3.3
+coverage[toml]==6.4.1
# via -r dev-requirements.in
croniter==1.3.5
# via
# -c requirements.txt
# flytekit
-cryptography==37.0.2
+cryptography==37.0.4
# via
# -c requirements.txt
# paramiko
+ # pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via
@@ -118,47 +119,47 @@ docstring-parser==0.14.1
# via
# -c requirements.txt
# flytekit
-filelock==3.7.0
+filelock==3.7.1
# via virtualenv
-flyteidl==1.0.1
+flyteidl==1.1.8
# via
# -c requirements.txt
# flytekit
-google-api-core[grpc]==2.8.0
+google-api-core[grpc]==2.8.2
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
# google-cloud-core
-google-auth==2.6.6
+google-auth==2.9.0
# via
# google-api-core
# google-cloud-core
-google-cloud-bigquery==3.1.0
+google-cloud-bigquery==3.2.0
# via -r dev-requirements.in
-google-cloud-bigquery-storage==2.13.1
+google-cloud-bigquery-storage==2.13.2
# via
# -r dev-requirements.in
# google-cloud-bigquery
-google-cloud-core==2.3.0
+google-cloud-core==2.3.1
# via google-cloud-bigquery
google-crc32c==1.3.0
# via google-resumable-media
google-resumable-media==2.3.3
# via google-cloud-bigquery
-googleapis-common-protos==1.56.1
+googleapis-common-protos==1.56.3
# via
# -c requirements.txt
# flyteidl
# google-api-core
# grpcio-status
-grpcio==1.46.1
+grpcio==1.47.0
# via
# -c requirements.txt
# flytekit
# google-api-core
# google-cloud-bigquery
# grpcio-status
-grpcio-status==1.46.1
+grpcio-status==1.47.0
# via
# -c requirements.txt
# flytekit
@@ -169,10 +170,11 @@ idna==3.3
# via
# -c requirements.txt
# requests
-importlib-metadata==4.11.3
+importlib-metadata==4.12.0
# via
# -c requirements.txt
# click
+ # flytekit
# jsonschema
# keyring
# pluggy
@@ -181,7 +183,7 @@ importlib-metadata==4.11.3
# virtualenv
iniconfig==1.1.1
# via pytest
-ipython==7.33.0
+ipython==7.34.0
# via -r dev-requirements.in
jedi==0.18.1
# via ipython
@@ -206,7 +208,7 @@ jsonschema==3.2.0
# via
# -c requirements.txt
# docker-compose
-keyring==23.5.0
+keyring==23.6.0
# via
# -c requirements.txt
# flytekit
@@ -214,7 +216,7 @@ markupsafe==2.1.1
# via
# -c requirements.txt
# jinja2
-marshmallow==3.15.0
+marshmallow==3.17.0
# via
# -c requirements.txt
# dataclasses-json
@@ -232,7 +234,7 @@ matplotlib-inline==0.1.3
# via ipython
mock==4.0.3
# via -r dev-requirements.in
-mypy==0.950
+mypy==0.961
# via -r dev-requirements.in
mypy-extensions==0.4.3
# via
@@ -243,7 +245,7 @@ natsort==8.1.0
# via
# -c requirements.txt
# flytekit
-nodeenv==1.6.0
+nodeenv==1.7.0
# via pre-commit
numpy==1.21.6
# via
@@ -273,15 +275,11 @@ platformdirs==2.5.2
# via virtualenv
pluggy==1.0.0
# via pytest
-poyo==0.5.0
- # via
- # -c requirements.txt
- # cookiecutter
pre-commit==2.19.0
# via -r dev-requirements.in
-prompt-toolkit==3.0.29
+prompt-toolkit==3.0.30
# via ipython
-proto-plus==1.20.4
+proto-plus==1.20.6
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
@@ -292,6 +290,7 @@ protobuf==3.20.1
# flytekit
# google-api-core
# google-cloud-bigquery
+ # google-cloud-bigquery-storage
# googleapis-common-protos
# grpcio-status
# proto-plus
@@ -326,6 +325,10 @@ pygments==2.12.0
# via ipython
pynacl==1.5.0
# via paramiko
+pyopenssl==22.0.0
+ # via
+ # -c requirements.txt
+ # flytekit
pyparsing==3.0.9
# via
# -c requirements.txt
@@ -373,14 +376,15 @@ pytz==2022.1
pyyaml==5.4.1
# via
# -c requirements.txt
+ # cookiecutter
# docker-compose
# flytekit
# pre-commit
-regex==2022.4.24
+regex==2022.6.2
# via
# -c requirements.txt
# docker-image-py
-requests==2.27.1
+requests==2.28.1
# via
# -c requirements.txt
# cookiecutter
@@ -390,7 +394,7 @@ requests==2.27.1
# google-api-core
# google-cloud-bigquery
# responses
-responses==0.20.0
+responses==0.21.0
# via
# -c requirements.txt
# flytekit
@@ -411,7 +415,6 @@ singledispatchmethod==1.0
six==1.16.0
# via
# -c requirements.txt
- # cookiecutter
# dockerpty
# google-auth
# grpcio
@@ -441,19 +444,23 @@ tomli==2.0.1
# coverage
# mypy
# pytest
-traitlets==5.2.1.post0
+torch==1.11.0
+ # via -r dev-requirements.in
+traitlets==5.3.0
# via
# ipython
# matplotlib-inline
-typed-ast==1.5.3
+typed-ast==1.5.4
# via mypy
-typing-extensions==4.2.0
+typing-extensions==4.3.0
# via
# -c requirements.txt
# arrow
# flytekit
# importlib-metadata
# mypy
+ # responses
+ # torch
# typing-inspect
typing-inspect==0.7.1
# via
@@ -465,7 +472,7 @@ urllib3==1.26.9
# flytekit
# requests
# responses
-virtualenv==20.14.1
+virtualenv==20.15.1
# via pre-commit
wcwidth==0.2.5
# via prompt-toolkit
diff --git a/doc-requirements.in b/doc-requirements.in
index 4d60a6919b..760d3903dc 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
jupyter # papermill
pyspark # spark
sqlalchemy # sqlalchemy
+torch # pytorch
diff --git a/doc-requirements.txt b/doc-requirements.txt
index 6032cc31f4..e9f16c89a0 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -18,13 +18,13 @@ argon2-cffi-bindings==21.2.0
# via argon2-cffi
arrow==1.2.2
# via jinja2-time
-astroid==2.11.5
+astroid==2.11.6
# via sphinx-autoapi
attrs==21.4.0
# via
# jsonschema
# visions
-babel==2.10.1
+babel==2.10.3
# via sphinx
backcall==0.2.0
# via ipython
@@ -40,23 +40,23 @@ beautifulsoup4==4.11.1
# sphinx-material
binaryornot==0.4.4
# via cookiecutter
-bleach==5.0.0
+bleach==5.0.1
# via nbconvert
-botocore==1.26.5
+botocore==1.27.22
# via -r doc-requirements.in
-cachetools==5.1.0
+cachetools==5.2.0
# via google-auth
-certifi==2022.5.18.1
+certifi==2022.6.15
# via
# kubernetes
# requests
-cffi==1.15.0
+cffi==1.15.1
# via
# argon2-cffi-bindings
# cryptography
-chardet==4.0.0
+chardet==5.0.0
# via binaryornot
-charset-normalizer==2.0.12
+charset-normalizer==2.1.0
# via requests
click==8.1.3
# via
@@ -66,16 +66,17 @@ click==8.1.3
# papermill
cloudpickle==2.1.0
# via flytekit
-colorama==0.4.4
+colorama==0.4.5
# via great-expectations
-cookiecutter==1.7.3
+cookiecutter==2.1.1
# via flytekit
croniter==1.3.5
# via flytekit
-cryptography==36.0.2
+cryptography==37.0.4
# via
# -r doc-requirements.in
# great-expectations
+ # pyopenssl
# secretstorage
css-html-js-minify==2.5.5
# via sphinx-material
@@ -119,7 +120,7 @@ entrypoints==0.4
# papermill
fastjsonschema==2.15.3
# via nbformat
-flyteidl==1.0.1
+flyteidl==1.1.8
# via flytekit
fonttools==4.33.3
# via matplotlib
@@ -129,45 +130,45 @@ fsspec==2022.5.0
# modin
furo @ git+https://github.com/flyteorg/furo@main
# via -r doc-requirements.in
-google-api-core[grpc]==2.8.0
+google-api-core[grpc]==2.8.2
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
# google-cloud-core
-google-auth==2.6.6
+google-auth==2.9.0
# via
# google-api-core
# google-cloud-core
# kubernetes
google-cloud==0.34.0
# via -r doc-requirements.in
-google-cloud-bigquery==3.1.0
+google-cloud-bigquery==3.2.0
# via -r doc-requirements.in
-google-cloud-bigquery-storage==2.13.1
+google-cloud-bigquery-storage==2.13.2
# via google-cloud-bigquery
-google-cloud-core==2.3.0
+google-cloud-core==2.3.1
# via google-cloud-bigquery
google-crc32c==1.3.0
# via google-resumable-media
google-resumable-media==2.3.3
# via google-cloud-bigquery
-googleapis-common-protos==1.56.1
+googleapis-common-protos==1.56.3
# via
# flyteidl
# google-api-core
# grpcio-status
-great-expectations==0.15.6
+great-expectations==0.15.12
# via -r doc-requirements.in
greenlet==1.1.2
# via sqlalchemy
-grpcio==1.46.1
+grpcio==1.47.0
# via
# -r doc-requirements.in
# flytekit
# google-api-core
# google-cloud-bigquery
# grpcio-status
-grpcio-status==1.46.1
+grpcio-status==1.47.0
# via
# flytekit
# google-api-core
@@ -177,27 +178,28 @@ idna==3.3
# via requests
imagehash==4.2.1
# via visions
-imagesize==1.3.0
+imagesize==1.4.1
# via sphinx
-importlib-metadata==4.11.3
+importlib-metadata==4.12.0
# via
# click
+ # flytekit
# great-expectations
# jsonschema
# keyring
# markdown
# sphinx
# sqlalchemy
-importlib-resources==5.7.1
+importlib-resources==5.8.0
# via jsonschema
-ipykernel==6.13.0
+ipykernel==6.15.0
# via
# ipywidgets
# jupyter
# jupyter-console
# notebook
# qtconsole
-ipython==7.33.0
+ipython==7.34.0
# via
# great-expectations
# ipykernel
@@ -208,7 +210,7 @@ ipython-genutils==0.2.0
# ipywidgets
# notebook
# qtconsole
-ipywidgets==7.7.0
+ipywidgets==7.7.1
# via jupyter
jedi==0.18.1
# via ipython
@@ -216,7 +218,7 @@ jeepney==0.8.0
# via
# keyring
# secretstorage
-jinja2==3.0.3
+jinja2==3.1.2
# via
# altair
# cookiecutter
@@ -229,7 +231,7 @@ jinja2==3.0.3
# sphinx-autoapi
jinja2-time==0.2.0
# via cookiecutter
-jmespath==1.0.0
+jmespath==1.0.1
# via botocore
joblib==1.1.0
# via
@@ -239,21 +241,21 @@ jsonpatch==1.32
# via great-expectations
jsonpointer==2.3
# via jsonpatch
-jsonschema==4.5.1
+jsonschema==4.6.1
# via
# altair
# great-expectations
# nbformat
jupyter==1.0.0
# via -r doc-requirements.in
-jupyter-client==7.3.1
+jupyter-client==7.3.4
# via
# ipykernel
# jupyter-console
# nbclient
# notebook
# qtconsole
-jupyter-console==6.4.3
+jupyter-console==6.4.4
# via jupyter
jupyter-core==4.10.0
# via
@@ -264,17 +266,17 @@ jupyter-core==4.10.0
# qtconsole
jupyterlab-pygments==0.2.2
# via nbconvert
-jupyterlab-widgets==1.1.0
+jupyterlab-widgets==1.1.1
# via ipywidgets
-keyring==23.5.0
+keyring==23.6.0
# via flytekit
-kiwisolver==1.4.2
+kiwisolver==1.4.3
# via matplotlib
-kubernetes==23.6.0
+kubernetes==24.2.0
# via -r doc-requirements.in
lazy-object-proxy==1.7.1
# via astroid
-lxml==4.8.0
+lxml==4.9.1
# via sphinx-material
markdown==3.3.7
# via -r doc-requirements.in
@@ -283,7 +285,7 @@ markupsafe==2.1.1
# jinja2
# nbconvert
# pandas-profiling
-marshmallow==3.15.0
+marshmallow==3.17.0
# via
# dataclasses-json
# marshmallow-enum
@@ -318,7 +320,7 @@ mypy-extensions==0.4.3
# via typing-inspect
natsort==8.1.0
# via flytekit
-nbclient==0.6.3
+nbclient==0.6.6
# via
# nbconvert
# papermill
@@ -329,7 +331,6 @@ nbconvert==6.5.0
nbformat==5.4.0
# via
# great-expectations
- # ipywidgets
# nbclient
# nbconvert
# notebook
@@ -342,7 +343,7 @@ nest-asyncio==1.5.5
# notebook
networkx==2.6.3
# via visions
-notebook==6.4.11
+notebook==6.4.12
# via
# great-expectations
# jupyter
@@ -407,22 +408,20 @@ phik==0.12.2
# via pandas-profiling
pickleshare==0.7.5
# via ipython
-pillow==9.1.1
+pillow==9.2.0
# via
# imagehash
# matplotlib
# visions
-plotly==5.8.0
+plotly==5.9.0
# via -r doc-requirements.in
-poyo==0.5.0
- # via cookiecutter
prometheus-client==0.14.1
# via notebook
-prompt-toolkit==3.0.29
+prompt-toolkit==3.0.30
# via
# ipython
# jupyter-console
-proto-plus==1.20.4
+proto-plus==1.20.6
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
@@ -432,13 +431,14 @@ protobuf==3.20.1
# flytekit
# google-api-core
# google-cloud-bigquery
+ # google-cloud-bigquery-storage
# googleapis-common-protos
# grpcio-status
# proto-plus
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via flyteidl
-psutil==5.9.0
+psutil==5.9.1
# via ipykernel
ptyprocess==0.7.0
# via
@@ -446,7 +446,7 @@ ptyprocess==0.7.0
# terminado
py==1.11.0
# via retry
-py4j==0.10.9.3
+py4j==0.10.9.5
# via pyspark
pyarrow==6.0.1
# via
@@ -467,12 +467,15 @@ pydantic==1.9.1
# pandera
pygments==2.12.0
# via
+ # furo
# ipython
# jupyter-console
# nbconvert
# qtconsole
# sphinx
# sphinx-prompt
+pyopenssl==22.0.0
+ # via flytekit
pyparsing==2.4.7
# via
# great-expectations
@@ -480,7 +483,7 @@ pyparsing==2.4.7
# packaging
pyrsistent==0.18.1
# via jsonschema
-pyspark==3.2.1
+pyspark==3.3.0
# via -r doc-requirements.in
python-dateutil==2.8.2
# via
@@ -514,23 +517,25 @@ pywavelets==1.3.0
# via imagehash
pyyaml==6.0
# via
+ # cookiecutter
# flytekit
# kubernetes
# pandas-profiling
# papermill
# sphinx-autoapi
-pyzmq==23.0.0
+pyzmq==23.2.0
# via
+ # ipykernel
# jupyter-client
# notebook
# qtconsole
-qtconsole==5.3.0
+qtconsole==5.3.1
# via jupyter
qtpy==2.1.0
# via qtconsole
-regex==2022.4.24
+regex==2022.6.2
# via docker-image-py
-requests==2.27.1
+requests==2.28.1
# via
# cookiecutter
# docker
@@ -546,7 +551,7 @@ requests==2.27.1
# sphinx
requests-oauthlib==1.3.1
# via kubernetes
-responses==0.20.0
+responses==0.21.0
# via flytekit
retry==0.9.2
# via flytekit
@@ -577,7 +582,6 @@ singledispatchmethod==1.0
six==1.16.0
# via
# bleach
- # cookiecutter
# google-auth
# grpcio
# imagehash
@@ -595,6 +599,7 @@ sphinx==4.5.0
# -r doc-requirements.in
# furo
# sphinx-autoapi
+ # sphinx-basic-ng
# sphinx-code-include
# sphinx-copybutton
# sphinx-fontawesome
@@ -605,6 +610,8 @@ sphinx==4.5.0
# sphinxcontrib-yt
sphinx-autoapi==1.8.4
# via -r doc-requirements.in
+sphinx-basic-ng==0.0.1a12
+ # via furo
sphinx-code-include==1.1.1
# via -r doc-requirements.in
sphinx-copybutton==0.5.0
@@ -633,7 +640,7 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
sphinxcontrib-yt==0.2.2
# via -r doc-requirements.in
-sqlalchemy==1.4.36
+sqlalchemy==1.4.39
# via -r doc-requirements.in
statsd==3.3.0
# via flytekit
@@ -657,7 +664,9 @@ tinycss2==1.1.1
# via nbconvert
toolz==0.11.2
# via altair
-tornado==6.1
+torch==1.11.0
+ # via -r doc-requirements.in
+tornado==6.2
# via
# ipykernel
# jupyter-client
@@ -668,7 +677,7 @@ tqdm==4.64.0
# great-expectations
# pandas-profiling
# papermill
-traitlets==5.2.1.post0
+traitlets==5.3.0
# via
# ipykernel
# ipython
@@ -681,9 +690,9 @@ traitlets==5.2.1.post0
# nbformat
# notebook
# qtconsole
-typed-ast==1.5.3
+typed-ast==1.5.4
# via astroid
-typing-extensions==4.2.0
+typing-extensions==4.3.0
# via
# argon2-cffi
# arrow
@@ -695,6 +704,8 @@ typing-extensions==4.2.0
# kiwisolver
# pandera
# pydantic
+ # responses
+ # torch
# typing-inspect
typing-inspect==0.7.1
# via
@@ -724,13 +735,13 @@ webencodings==0.5.1
# via
# bleach
# tinycss2
-websocket-client==1.3.2
+websocket-client==1.3.3
# via
# docker
# kubernetes
wheel==0.37.1
# via flytekit
-widgetsnbextension==3.6.0
+widgetsnbextension==3.6.1
# via ipywidgets
wrapt==1.14.1
# via
diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst
index f200500fbc..3f06961022 100644
--- a/docs/source/data.extend.rst
+++ b/docs/source/data.extend.rst
@@ -1,7 +1,7 @@
##############################
Extend Data Persistence layer
##############################
-Flytekit provides a data persistence layer, which is used for recording metadata that is shared with backend Flyte. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible.
+Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible.
Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface.
.. note::
@@ -16,3 +16,22 @@ Moreover, it is designed to be extensible and users can bring their own data per
:no-members:
:no-inherited-members:
:no-special-members:
+
+The ``fsspec`` Data Plugin
+--------------------------
+
+Flytekit ships with a default storage driver that uses aws-cli on AWS and gsutil on GCP. By default, Flyte uploads the task outputs to S3 or GCS using these storage drivers.
+
+Why ``fsspec``?
+^^^^^^^^^^^^^^^
+
+You can use the fsspec plugin implementation to utilize all its available plugins with flytekit. The `fsspec <https://github.com/fsspec/filesystem_spec>`_ plugin provides an implementation of the data persistence layer in Flytekit. For example: HDFS, FTP are supported in fsspec, so you can use them with flytekit too.
+The data persistence layer helps store logs of metadata and raw user data.
+As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``.
+
+`Here <https://github.com/flyteorg/flytekit/blob/master/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py>`_ is a code snippet that shows protocols mapped to the class it implements.
+
+Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins <https://github.com/flyteorg/flytekit/blob/master/flytekit/core/data_persistence.py>`_ and provides the ones supported by fsspec.
+
+.. note::
+ This plugin installs fsspec core only. To install all the fsspec plugins, see `here <https://filesystem-spec.readthedocs.io/en/latest/index.html>`_.
diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst
new file mode 100644
index 0000000000..12fd3d62d9
--- /dev/null
+++ b/docs/source/extras.pytorch.rst
@@ -0,0 +1,7 @@
+############
+PyTorch Type
+############
+.. automodule:: flytekit.extras.pytorch
+ :no-members:
+ :no-inherited-members:
+ :no-special-members:
diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst
index f1b15455dd..f0cdff28dc 100644
--- a/docs/source/types.extend.rst
+++ b/docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
types.builtins.structured
types.builtins.file
types.builtins.directory
+ extras.pytorch
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index b6bd104a2d..c67a8a04b4 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -182,6 +182,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
+from flytekit.extras import pytorch
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
@@ -189,7 +190,7 @@
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
-from flytekit.types import directory, file, schema
+from flytekit.types import directory, file, numpy, schema
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 68bd2578e5..1f8dd78ef0 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -243,6 +243,7 @@ def setup_execution(
tmp_dir=user_workspace_dir,
raw_output_prefix=raw_output_data_prefix,
checkpoint=checkpointer,
+ task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
)
try:
diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py
index 8db7c93a98..d542af5f7e 100644
--- a/flytekit/clients/friendly.py
+++ b/flytekit/clients/friendly.py
@@ -1004,3 +1004,17 @@ def get_upload_signed_url(
expires_in=expires_in_pb,
)
)
+
+ def get_download_signed_url(
+ self, native_url: str, expires_in: datetime.timedelta = None
+    ) -> _data_proxy_pb2.CreateDownloadLocationResponse:
+ expires_in_pb = None
+ if expires_in:
+ expires_in_pb = Duration()
+ expires_in_pb.FromTimedelta(expires_in)
+ return super(SynchronousFlyteClient, self).create_download_location(
+ _data_proxy_pb2.CreateDownloadLocationRequest(
+ native_url=native_url,
+ expires_in=expires_in_pb,
+ )
+ )
diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index 9aba78888a..9bdf85a178 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -1,12 +1,14 @@
from __future__ import annotations
import base64 as _base64
+import ssl
import subprocess
import time
import typing
from typing import Optional
import grpc
+import OpenSSL
import requests as _requests
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.service import admin_pb2_grpc as _admin_service
@@ -110,6 +112,26 @@ def __init__(self, cfg: PlatformConfig, **kwargs):
self._cfg = cfg
if cfg.insecure:
self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs)
+ elif cfg.insecure_skip_verify:
+ # Get port from endpoint or use 443
+ endpoint_parts = cfg.endpoint.rsplit(":", 1)
+ if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
+ server_address = tuple(endpoint_parts)
+ else:
+ server_address = (cfg.endpoint, "443")
+
+ cert = ssl.get_server_certificate(server_address)
+ x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
+ cn = x509.get_subject().CN
+ credentials = grpc.ssl_channel_credentials(str.encode(cert))
+ options = kwargs.get("options", [])
+ options.append(("grpc.ssl_target_name_override", cn))
+ self._channel = grpc.secure_channel(
+ target=cfg.endpoint,
+ credentials=credentials,
+ options=options,
+ compression=kwargs.get("compression", None),
+ )
else:
if "credentials" not in kwargs:
credentials = grpc.ssl_channel_credentials(
@@ -251,7 +273,10 @@ def _refresh_credentials_from_command(self):
except subprocess.CalledProcessError as e:
cli_logger.error("Failed to generate token from command {}".format(command))
raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e))
- self.set_access_token(output.stdout.strip())
+ authorization_header_key = self.public_client_config.authorization_metadata_key or None
+ if not authorization_header_key:
+ self.set_access_token(output.stdout.strip())
+ self.set_access_token(output.stdout.strip(), authorization_header_key)
def _refresh_credentials_noop(self):
pass
@@ -833,6 +858,17 @@ def create_upload_location(
"""
return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata)
+ @_handle_rpc_error(retry=True)
+ def create_download_location(
+ self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest
+ ) -> _dataproxy_pb2.CreateDownloadLocationResponse:
+ """
+        Get a signed url to download the data pointed to by a native url
+ :param flyteidl.service.dataproxy_pb2.CreateDownloadLocationRequest create_download_location_request:
+ :rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse
+ """
+ return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata)
+
def get_token(token_endpoint, authorization_header, scope):
"""
diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py
index 3628af7beb..21aec1c4ad 100644
--- a/flytekit/clis/flyte_cli/main.py
+++ b/flytekit/clis/flyte_cli/main.py
@@ -4,6 +4,7 @@
import os as _os
import stat as _stat
import sys as _sys
+from dataclasses import replace
from typing import Callable, Dict, List, Tuple, Union
import click as _click
@@ -276,7 +277,7 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC
if parent_ctx.obj["cacert"]:
kwargs["root_certificates"] = parent_ctx.obj["cacert"]
cfg = parent_ctx.obj["config"]
- cfg = cfg.with_parameters(endpoint=host, insecure=insecure)
+ cfg = replace(cfg, endpoint=host, insecure=insecure)
return _friendly_client.SynchronousFlyteClient(cfg, **kwargs)
diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py
index 73274e972e..d922f5e3c1 100644
--- a/flytekit/clis/helpers.py
+++ b/flytekit/clis/helpers.py
@@ -1,5 +1,7 @@
+import sys
from typing import Tuple, Union
+import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan
from flyteidl.admin.task_pb2 import TaskSpec
from flyteidl.admin.workflow_pb2 import WorkflowSpec
@@ -125,3 +127,9 @@ def hydrate_registration_parameters(
del entity.sub_workflows[:]
entity.sub_workflows.extend(refreshed_sub_workflows)
return identifier, entity
+
+
+def display_help_with_error(ctx: click.Context, message: str):
+ click.echo(f"{ctx.get_help()}\n")
+ click.secho(message, fg="red")
+ sys.exit(1)
diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py
new file mode 100644
index 0000000000..a9a9c4900d
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/helpers.py
@@ -0,0 +1,32 @@
+import click
+
+from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE
+from flytekit.configuration import Config
+from flytekit.loggers import cli_logger
+from flytekit.remote.remote import FlyteRemote
+
+FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote"
+
+
+def get_and_save_remote_with_click_context(
+ ctx: click.Context, project: str, domain: str, save: bool = True
+) -> FlyteRemote:
+ """
+ NB: This function will by default mutate the click Context.obj dictionary, adding a remote key with value
+ of the created FlyteRemote object.
+
+ :param ctx: the click context object
+ :param project: default project for the remote instance
+ :param domain: default domain
+ :param save: If false, will not mutate the context.obj dict
+ :return: FlyteRemote instance
+ """
+ cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE)
+ cfg_obj = Config.auto(cfg_file_location)
+ cli_logger.info(
+ f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
+ )
+ r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain)
+ if save:
+ ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
+ return r
diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py
index 71efeab576..2a884e29da 100644
--- a/flytekit/clis/sdk_in_container/package.py
+++ b/flytekit/clis/sdk_in_container/package.py
@@ -1,8 +1,8 @@
import os
-import sys
import click
+from flytekit.clis.helpers import display_help_with_error
from flytekit.clis.sdk_in_container import constants
from flytekit.configuration import (
DEFAULT_RUNTIME_PYTHON_INTERPRETER,
@@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_
pkgs = ctx.obj[constants.CTX_PACKAGES]
if not pkgs:
- click.secho("No packages to scan for flyte entities. Aborting!", fg="red")
- sys.exit(-1)
+ display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")
try:
serialize_and_package(pkgs, serialization_settings, source, output, fast)
diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py
index c2b5f2b045..76777c5663 100644
--- a/flytekit/clis/sdk_in_container/pyflyte.py
+++ b/flytekit/clis/sdk_in_container/pyflyte.py
@@ -5,6 +5,7 @@
from flytekit.clis.sdk_in_container.init import init
from flytekit.clis.sdk_in_container.local_cache import local_cache
from flytekit.clis.sdk_in_container.package import package
+from flytekit.clis.sdk_in_container.register import register
from flytekit.clis.sdk_in_container.run import run
from flytekit.clis.sdk_in_container.serialize import serialize
from flytekit.configuration.internal import LocalSDK
@@ -68,6 +69,7 @@ def main(ctx, pkgs=None, config=None):
main.add_command(local_cache)
main.add_command(init)
main.add_command(run)
+main.add_command(register)
if __name__ == "__main__":
main()
diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
new file mode 100644
index 0000000000..03e00d7896
--- /dev/null
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -0,0 +1,179 @@
+import os
+import pathlib
+import typing
+
+import click
+
+from flytekit.clis.helpers import display_help_with_error
+from flytekit.clis.sdk_in_container import constants
+from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
+from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
+from flytekit.configuration.default_images import DefaultImages
+from flytekit.loggers import cli_logger
+from flytekit.tools.fast_registration import fast_package
+from flytekit.tools.repo import find_common_root, load_packages_and_modules
+from flytekit.tools.repo import register as repo_register
+from flytekit.tools.translator import Options
+
+_register_help = """
+This command is similar to package but instead of producing a zip file, all your Flyte entities are compiled,
+and then sent to the backend specified by your config file. Think of this as combining the pyflyte package
+and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl
+like service account here.
+
+Note: This command runs "fast" register by default. Future work to come to add a non-fast version.
+This means that a zip is created from the detected root of the packages given, and uploaded. Just like with
+pyflyte run, tasks registered from this command will download and unzip that code package before running.
+
+Note: This command only works on regular Python packages, not namespace packages. When determining
+ the root of your project, it finds the first folder that does not have an __init__.py file.
+"""
+
+
+@click.command("register", help=_register_help)
+@click.option(
+ "-p",
+ "--project",
+ required=False,
+ type=str,
+ default="flytesnacks",
+ help="Project to register and run this workflow in",
+)
+@click.option(
+ "-d",
+ "--domain",
+ required=False,
+ type=str,
+ default="development",
+ help="Domain to register and run this workflow in",
+)
+@click.option(
+ "-i",
+ "--image",
+ "image_config",
+ required=False,
+ multiple=True,
+ type=click.UNPROCESSED,
+ callback=ImageConfig.validate_image,
+ default=[DefaultImages.default_image()],
+ help="A fully qualified tag for an docker image, e.g. somedocker.com/myimage:someversion123. This is a "
+ "multi-option and can be of the form --image xyz.io/docker:latest "
+ "--image my_image=xyz.io/docker2:latest. Note, the `name=image_uri`. The name is optional, if not "
+ "provided the image will be used as the default image. All the names have to be unique, and thus "
+ "there can only be one --image option with no name.",
+)
+@click.option(
+ "-o",
+ "--output",
+ required=False,
+ type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
+ default=None,
+ help="Directory to write the output zip file containing the protobuf definitions",
+)
+@click.option(
+ "-d",
+ "--destination-dir",
+ required=False,
+ type=str,
+ default="/root",
+ help="Directory inside the image where the tar file containing the code will be copied to",
+)
+@click.option(
+ "--service-account",
+ required=False,
+ type=str,
+ default="",
+ help="Service account used when creating launch plans",
+)
+@click.option(
+ "--raw-data-prefix",
+ required=False,
+ type=str,
+ default="",
+ help="Raw output data prefix when creating launch plans, where offloaded data will be stored",
+)
+@click.option(
+ "-v",
+ "--version",
+ required=False,
+ type=str,
+ help="Version the package or module is registered with",
+)
+@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
+@click.pass_context
+def register(
+ ctx: click.Context,
+ project: str,
+ domain: str,
+ image_config: ImageConfig,
+ output: str,
+ destination_dir: str,
+ service_account: str,
+ raw_data_prefix: str,
+ version: typing.Optional[str],
+ package_or_module: typing.Tuple[str],
+):
+ """
+ see help
+ """
+ pkgs = ctx.obj[constants.CTX_PACKAGES]
+ if not pkgs:
+ cli_logger.debug("No pkgs")
+ if pkgs:
+ raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")
+
+ if len(package_or_module) == 0:
+ display_help_with_error(
+ ctx,
+ "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
+ )
+
+ cli_logger.debug(
+ f"Running pyflyte register from {os.getcwd()} "
+ f"with images {image_config} "
+ f"and image destinationfolder {destination_dir} "
+ f"on {len(package_or_module)} package(s) {package_or_module}"
+ )
+
+ # Create and save FlyteRemote,
+ remote = get_and_save_remote_with_click_context(ctx, project, domain)
+
+ # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings
+ # Create a zip file containing all the entries.
+ detected_root = find_common_root(package_or_module)
+ cli_logger.debug(f"Using {detected_root} as root folder for project")
+ zip_file = fast_package(detected_root, output)
+
+ # Upload zip file to Admin using FlyteRemote.
+ md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
+ cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}")
+
+ # Create serialization settings
+ # Todo: Rely on default Python interpreter for now, this will break custom Spark containers
+ serialization_settings = SerializationSettings(
+ project=project,
+ domain=domain,
+ image_config=image_config,
+ fast_serialization_settings=FastSerializationSettings(
+ enabled=True,
+ destination_dir=destination_dir,
+ distribution_location=native_url,
+ ),
+ )
+
+ options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix)
+
+ # Load all the entities
+ registerable_entities = load_packages_and_modules(
+ serialization_settings, detected_root, list(package_or_module), options
+ )
+ if len(registerable_entities) == 0:
+ display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
+ cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")
+
+ if not version:
+ version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa
+ cli_logger.info(f"Computed version is {version}")
+
+ # Register using repo code
+ repo_register(registerable_entities, project, domain, version, remote.client)
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index b6f723fe0c..bd7b8ee20b 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -2,6 +2,7 @@
import functools
import importlib
import json
+import logging
import os
import pathlib
import typing
@@ -11,10 +12,12 @@
import click
from dataclasses_json import DataClassJsonMixin
from pytimeparse import parse
+from typing_extensions import get_args
from flytekit import BlobType, Literal, Scalar
-from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT
-from flytekit.configuration import Config, ImageConfig, SerializationSettings
+from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT
+from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
+from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager, tracker
from flytekit.core.base_task import PythonTask
@@ -22,19 +25,17 @@
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
-from flytekit.loggers import cli_logger
from flytekit.models import literals
from flytekit.models.interface import Variable
from flytekit.models.literals import Blob, BlobMetadata, Primitive
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote.executions import FlyteWorkflowExecution
-from flytekit.remote.remote import FlyteRemote
from flytekit.tools import module_loader, script_mode
+from flytekit.tools.script_mode import _find_project_root
from flytekit.tools.translator import Options
REMOTE_FLAG_KEY = "remote"
RUN_LEVEL_PARAMS_KEY = "run_level_params"
-FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote"
DATA_PROXY_CALLBACK_KEY = "data_proxy"
@@ -244,6 +245,31 @@ def convert_to_blob(
return lit
+ def convert_to_union(
+ self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
+ ) -> Literal:
+ lt = self._literal_type
+ for i in range(len(self._literal_type.union_type.variants)):
+ variant = self._literal_type.union_type.variants[i]
+ python_type = get_args(self._python_type)[i]
+ converter = FlyteLiteralConverter(
+ ctx,
+ self._flyte_ctx,
+ variant,
+ python_type,
+ self._create_upload_fn,
+ )
+ try:
+ # Here we use click converter to convert the input in command line to native python type,
+ # and then use flyte converter to convert it to literal.
+ python_val = converter._click_type.convert(value, param, ctx)
+ literal = converter.convert_to_literal(ctx, param, python_val)
+ self._python_type = python_type
+ return literal
+ except (Exception or AttributeError) as e:
+ logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
+ raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")
+
def convert_to_literal(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any
) -> Literal:
@@ -255,7 +281,7 @@ def convert_to_literal(
if self._literal_type.collection_type or self._literal_type.map_value_type:
# TODO Does not support nested flytefile, flyteschema types
- v = json.loads(value)
+ v = json.loads(value) if isinstance(value, str) else value
if self._literal_type.collection_type and not isinstance(v, list):
raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}")
if self._literal_type.map_value_type and not isinstance(v, dict):
@@ -263,11 +289,14 @@ def convert_to_literal(
return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type)
if self._literal_type.union_type:
- raise NotImplementedError("Union type is not yet implemented for pyflyte run")
+ return self.convert_to_union(ctx, param, value)
if self._literal_type.simple or self._literal_type.enum_type:
if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT:
- o = cast(DataClassJsonMixin, self._python_type).from_json(value)
+ if type(value) != self._python_type:
+ o = cast(DataClassJsonMixin, self._python_type).from_json(value)
+ else:
+ o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)
return Literal(scalar=self._converter.convert(value, self._python_type))
@@ -396,16 +425,14 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
]
-def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]:
+def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]:
"""
Load the workflow of a the script file.
N.B.: it assumes that the file is self-contained, in other words, there are no relative imports.
"""
- flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
- SerializationSettings(None)
- )
- with context_manager.FlyteContextManager.with_context(flyte_ctx):
- with module_loader.add_sys_path(os.getcwd()):
+ flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder()
+ with context_manager.FlyteContextManager.with_context(flyte_ctx_builder):
+ with module_loader.add_sys_path(project_root):
importlib.import_module(module_name)
return module_loader.load_object_from_module(f"{module_name}.{entity_name}")
@@ -444,9 +471,7 @@ def get_entities_in_file(filename: str) -> Entities:
"""
Returns a list of flyte workflow names and list of Flyte tasks in a file.
"""
- flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
- SerializationSettings(None)
- )
+ flyte_ctx = context_manager.FlyteContextManager.current_context().new_builder()
module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".")
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(os.getcwd()):
@@ -473,6 +498,8 @@ def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow,
"""
def _run(*args, **kwargs):
+ # By the time we get to this function, all the loading has already happened
+
run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
project, domain = run_level_params.get("project"), run_level_params.get("domain")
inputs = {}
@@ -486,10 +513,6 @@ def _run(*args, **kwargs):
remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
- # StructuredDatasetTransformerEngine.register(
- # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True
- # )
-
remote_entity = remote.register_script(
entity,
project=project,
@@ -532,32 +555,41 @@ class WorkflowCommand(click.MultiCommand):
def __init__(self, filename: str, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._filename = filename
+ self._filename = pathlib.Path(filename).resolve()
def list_commands(self, ctx):
entities = get_entities_in_file(self._filename)
return entities.all()
def get_command(self, ctx, exe_entity):
+ """
+ This command uses the filename with which this command was created, and the string name of the entity passed
+ after the Python filename on the command line, to load the Python object, and then return the Command that
+ click should run.
+ :param ctx: The click Context object.
+ :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task
+ function.
+ :return:
+ """
+
rel_path = os.path.relpath(self._filename)
if rel_path.startswith(".."):
raise ValueError(
f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}"
)
+ project_root = _find_project_root(self._filename)
+ # Find the relative path for the filename relative to the root of the project.
+ # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as
+ # a parameter.
+ rel_path = self._filename.relative_to(project_root)
module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")
- entity = load_naive_entity(module, exe_entity)
+ entity = load_naive_entity(module, exe_entity, project_root)
# If this is a remote execution, which we should know at this point, then create the remote object
p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
d = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN)
- cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE)
- cfg_obj = Config.auto(cfg_file_location)
- cli_logger.info(
- f"Run is using config object {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "")
- )
- r = FlyteRemote(cfg_obj, default_project=p, default_domain=d)
- ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r
+ r = get_and_save_remote_with_click_context(ctx, p, d)
get_upload_url_fn = functools.partial(r.client.get_upload_signed_url, project=p, domain=d)
flyte_ctx = context_manager.FlyteContextManager.current_context()
@@ -596,8 +628,16 @@ def get_command(self, ctx, filename):
return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode")
+_run_help = """
+This command can execute either a workflow or a task from the commandline, for fully self-contained scripts.
+Tasks and workflows cannot be imported from other files currently. Please use `pyflyte package` or
+`pyflyte register` to handle those and then launch from the Flyte UI or `flytectl`
+
+Note: This command only works on regular Python packages, not namespace packages. When determining
+ the root of your project, it finds the first folder that does not have an __init__.py file.
+"""
+
run = RunCommand(
name="run",
- help="Run command: This command can execute either a workflow or a task from the commandline, for "
- "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.",
+ help=_run_help,
)
diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index 188f6ebeee..5354abde4f 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -297,6 +297,7 @@ class PlatformConfig(object):
:param endpoint: DNS for Flyte backend
:param insecure: Whether or not to use SSL
+ :param insecure_skip_verify: Whether to skip SSL certificate verification
:param command: This command is executed to return a token using an external process.
:param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment.
More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
@@ -309,32 +310,13 @@ class PlatformConfig(object):
endpoint: str = "localhost:30081"
insecure: bool = False
+ insecure_skip_verify: bool = False
command: typing.Optional[typing.List[str]] = None
client_id: typing.Optional[str] = None
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
auth_mode: AuthType = AuthType.STANDARD
- def with_parameters(
- self,
- endpoint: str = "localhost:30081",
- insecure: bool = False,
- command: typing.Optional[typing.List[str]] = None,
- client_id: typing.Optional[str] = None,
- client_credentials_secret: typing.Optional[str] = None,
- scopes: List[str] = None,
- auth_mode: AuthType = AuthType.STANDARD,
- ) -> PlatformConfig:
- return PlatformConfig(
- endpoint=endpoint,
- insecure=insecure,
- command=command,
- client_id=client_id,
- client_credentials_secret=client_credentials_secret,
- scopes=scopes if scopes else [],
- auth_mode=auth_mode,
- )
-
@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
"""
@@ -345,6 +327,9 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "insecure", _internal.Platform.INSECURE.read(config_file))
+ kwargs = set_if_exists(
+ kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file)
+ )
kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file))
kwargs = set_if_exists(
@@ -562,7 +547,7 @@ def for_sandbox(cls) -> Config:
:return: Config
"""
return Config(
- platform=PlatformConfig(insecure=True),
+ platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True),
data_config=DataConfig(
s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage")
),
diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py
index b329a16112..467f660d42 100644
--- a/flytekit/configuration/file.py
+++ b/flytekit/configuration/file.py
@@ -32,13 +32,16 @@ class LegacyConfigEntry(object):
option: str
type_: typing.Type = str
+ def get_env_name(self):
+ return f"FLYTE_{self.section.upper()}_{self.option.upper()}"
+
def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]:
"""
Reads the config entry from environment variable, the structure of the env var is current
``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future.
:return:
"""
- env = f"FLYTE_{self.section.upper()}_{self.option.upper()}"
+ env = self.get_env_name()
v = os.environ.get(env, None)
if v is None:
return None
@@ -159,7 +162,7 @@ def __init__(self, location: str):
Load the config from this location
"""
self._location = location
- if location.endswith("yaml"):
+ if location.endswith("yaml") or location.endswith("yml"):
self._legacy_config = None
self._yaml_config = self._read_yaml_config(location)
else:
diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py
index 9bb7567be2..fb09015b6f 100644
--- a/flytekit/configuration/internal.py
+++ b/flytekit/configuration/internal.py
@@ -105,6 +105,9 @@ class Platform(object):
LegacyConfigEntry(SECTION, "url"), YamlConfigEntry("admin.endpoint"), lambda x: x.replace("dns:///", "")
)
INSECURE = ConfigEntry(LegacyConfigEntry(SECTION, "insecure", bool), YamlConfigEntry("admin.insecure", bool))
+ INSECURE_SKIP_VERIFY = ConfigEntry(
+ LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool)
+ )
class LocalSDK(object):
diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py
index 77ac6516a7..a17a9148b6 100644
--- a/flytekit/core/context_manager.py
+++ b/flytekit/core/context_manager.py
@@ -27,7 +27,6 @@
from enum import Enum
from typing import Generator, List, Optional, Union
-import flytekit
from flytekit.clients import friendly as friendly_client # noqa
from flytekit.configuration import Config, SecretsConfig, SerializationSettings
from flytekit.core import mock_stats, utils
@@ -39,6 +38,9 @@
from flytekit.loggers import logger, user_space_logger
from flytekit.models.core import identifier as _identifier
+if typing.TYPE_CHECKING:
+ from flytekit.deck.deck import Deck
+
# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin
# Enables static type checking https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING
@@ -80,12 +82,13 @@ class Builder(object):
stats: taggable.TaggableStats
execution_date: datetime
logging: _logging.Logger
- execution_id: str
+ execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier]
attrs: typing.Dict[str, typing.Any]
working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir]
checkpoint: typing.Optional[Checkpoint]
- decks: List[flytekit.Deck]
+ decks: List[Deck]
raw_output_prefix: str
+ task_id: typing.Optional[_identifier.Identifier]
def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.stats = current.stats if current else None
@@ -97,6 +100,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None):
self.decks = current._decks if current else []
self.attrs = current._attrs if current else {}
self.raw_output_prefix = current.raw_output_prefix if current else None
+ self.task_id = current.task_id if current else None
def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:
self.attrs[key] = v
@@ -114,6 +118,7 @@ def build(self) -> ExecutionParameters:
checkpoint=self.checkpoint,
decks=self.decks,
raw_output_prefix=self.raw_output_prefix,
+ task_id=self.task_id,
**self.attrs,
)
@@ -143,11 +148,12 @@ def __init__(
execution_date,
tmp_dir,
stats,
- execution_id,
+ execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier],
logging,
raw_output_prefix,
checkpoint=None,
decks=None,
+ task_id: typing.Optional[_identifier.Identifier] = None,
**kwargs,
):
"""
@@ -173,6 +179,7 @@ def __init__(
self._secrets_manager = SecretsManager()
self._checkpoint = checkpoint
self._decks = decks
+ self._task_id = task_id
@property
def stats(self) -> taggable.TaggableStats:
@@ -230,6 +237,14 @@ def execution_id(self) -> _identifier.WorkflowExecutionIdentifier:
"""
return self._execution_id
+ @property
+ def task_id(self) -> typing.Optional[_identifier.Identifier]:
+ """
+ At production run-time, this will be generated by reading environment variables that are set
+ by the backend.
+ """
+ return self._task_id
+
@property
def secrets(self) -> SecretsManager:
return self._secrets_manager
@@ -248,8 +263,8 @@ def decks(self) -> typing.List:
return self._decks
@property
- def default_deck(self) -> "Deck":
- from flytekit import Deck
+ def default_deck(self) -> Deck:
+ from flytekit.deck.deck import Deck
return Deck("default")
@@ -602,10 +617,10 @@ def current_context() -> Optional[FlyteContext]:
"""
return FlyteContextManager.current_context()
- def get_deck(self):
+ def get_deck(self) -> str:
from flytekit.deck.deck import _get_deck
- _get_deck(self.execution_state.user_space_params)
+ return _get_deck(self.execution_state.user_space_params)
@dataclass
class Builder(object):
@@ -796,6 +811,7 @@ def initialize():
default_context = FlyteContext(file_access=default_local_file_access_provider)
default_user_space_params = ExecutionParameters(
execution_id=WorkflowExecutionIdentifier.promote_from_model(default_execution_id),
+ task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"),
execution_date=_datetime.datetime.utcnow(),
stats=mock_stats.MockStats(),
logging=user_space_logger,
diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py
index fdd47d1741..0f651410bf 100644
--- a/flytekit/core/interface.py
+++ b/flytekit/core/interface.py
@@ -7,12 +7,15 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
+from typing_extensions import get_args, get_origin, get_type_hints
+
from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
+from flytekit.models.literals import Void
from flytekit.types.pickle import FlytePickle
T = typing.TypeVar("T")
@@ -182,11 +185,17 @@ def transform_inputs_to_parameters(
inputs_with_def = interface.inputs_with_defaults
for k, v in inputs_vars.items():
val, _default = inputs_with_def[k]
- required = _default is None
- default_lv = None
- if _default is not None:
- default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
- params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
+ if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val):
+ from flytekit import Literal, Scalar
+
+ literal = Literal(scalar=Scalar(none_type=Void()))
+ params[k] = _interface_models.Parameter(var=v, default=literal, required=False)
+ else:
+ required = _default is None
+ default_lv = None
+ if _default is not None:
+ default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
+ params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
return _interface_models.ParameterMap(params)
@@ -274,11 +283,8 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
For now the fancy object, maybe in the future a dumb object.
"""
- try:
- # include_extras can only be used in python >= 3.9
- type_hints = typing.get_type_hints(fn, include_extras=True)
- except TypeError:
- type_hints = typing.get_type_hints(fn)
+
+ type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
@@ -386,7 +392,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
bases = return_annotation.__bases__ # type: ignore
if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"):
logger.debug(f"Task returns named tuple {return_annotation}")
- return dict(typing.get_type_hints(return_annotation))
+ return dict(get_type_hints(return_annotation, include_extras=True))
if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore
# Handle option 3
diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py
index e1f53dc5d2..e8f5bd4aa3 100644
--- a/flytekit/core/launch_plan.py
+++ b/flytekit/core/launch_plan.py
@@ -292,7 +292,6 @@ def get_or_create(
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
- # TODO: Add QoS after it's done
def __init__(
self,
name: str,
@@ -359,6 +358,10 @@ def clone_with(
def python_interface(self) -> Interface:
return self.workflow.python_interface
+ @property
+ def interface(self) -> _interface_models.TypedInterface:
+ return self.workflow.interface
+
@property
def name(self) -> str:
return self._name
diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py
index 6eb4f51d81..1694af6d70 100644
--- a/flytekit/core/node_creation.py
+++ b/flytekit/core/node_creation.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import collections
-from typing import Type, Union
+from typing import TYPE_CHECKING, Type, Union
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
@@ -12,11 +12,15 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
+if TYPE_CHECKING:
+ from flytekit.remote.remote_callable import RemoteEntity
+
+
# This file exists instead of moving to node.py because it needs Task/Workflow/LaunchPlan and those depend on Node
def create_node(
- entity: Union[PythonTask, LaunchPlan, WorkflowBase], *args, **kwargs
+ entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs
) -> Union[Node, VoidPromise, Type[collections.namedtuple]]:
"""
This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or
@@ -65,6 +69,8 @@ def sub_wf():
t2(t1_node.o0)
"""
+ from flytekit.remote.remote_callable import RemoteEntity
+
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"Only keyword args are supported to pass inputs to workflows and tasks."
@@ -75,8 +81,9 @@ def sub_wf():
not isinstance(entity, PythonTask)
and not isinstance(entity, WorkflowBase)
and not isinstance(entity, LaunchPlan)
+ and not isinstance(entity, RemoteEntity)
):
- raise AssertionError("Should be but it's not")
+ raise AssertionError(f"Should be a callable Flyte entity (either local or fetched) but is {type(entity)}")
# This function is only called from inside workflows and dynamic tasks.
# That means there are two scenarios we need to take care of, compilation and local workflow execution.
@@ -84,7 +91,6 @@ def sub_wf():
# When compiling, calling the entity will create a node.
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
-
outputs = entity(**kwargs)
# This is always the output of create_and_link_node which returns create_task_output, which can be
# VoidPromise, Promise, or our custom namedtuple of Promises.
@@ -105,9 +111,11 @@ def sub_wf():
return node
# If a Promise or custom namedtuple of Promises, we need to attach each output as an attribute to the node.
- if entity.python_interface.outputs:
+ # todo: fix the noqas below somehow... can't add abstract property to RemoteEntity because it has to come
+ # before the model Template classes in FlyteTask/Workflow/LaunchPlan
+ if entity.interface.outputs: # noqa
if isinstance(outputs, tuple):
- for output_name in entity.python_interface.output_names:
+ for output_name in entity.interface.outputs.keys(): # noqa
attr = getattr(outputs, output_name)
if attr is None:
raise _user_exceptions.FlyteAssertion(
@@ -120,7 +128,7 @@ def sub_wf():
setattr(node, output_name, attr)
node.outputs[output_name] = attr
else:
- output_names = entity.python_interface.output_names
+ output_names = [k for k in entity.interface.outputs.keys()] # noqa
if len(output_names) != 1:
raise _user_exceptions.FlyteAssertion(f"Output of length 1 expected but {len(output_names)} found")
@@ -136,6 +144,9 @@ def sub_wf():
# Handling local execution
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
+ if isinstance(entity, RemoteEntity):
+ raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}")
+
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
logger.warning(f"Manual node creation cannot be used in branch logic {entity.name}")
raise Exception("Being more restrictive for now and disallowing manual node creation in branch logic")
diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py
index e5877e21fe..4c9150881d 100644
--- a/flytekit/core/promise.py
+++ b/flytekit/core/promise.py
@@ -5,7 +5,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union, cast
-from typing_extensions import Protocol
+from typing_extensions import Protocol, get_args
from flytekit.core import constants as _common_constants
from flytekit.core import context_manager as _flyte_context
@@ -23,6 +23,7 @@
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
+from flytekit.models.types import SimpleType
def translate_inputs_to_literals(
@@ -68,29 +69,43 @@ def extract_value(
) -> _literal_models.Literal:
if isinstance(input_val, list):
- if flyte_literal_type.collection_type is None:
+ lt = flyte_literal_type
+ python_type = val_type
+ if flyte_literal_type.union_type:
+ for i in range(len(flyte_literal_type.union_type.variants)):
+ variant = flyte_literal_type.union_type.variants[i]
+ if variant.collection_type:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if lt.collection_type is None:
raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}")
try:
- sub_type = ListTransformer.get_sub_type(val_type)
+ sub_type = ListTransformer.get_sub_type(python_type)
except ValueError:
if len(input_val) == 0:
raise
sub_type = type(input_val[0])
- literal_list = [extract_value(ctx, v, sub_type, flyte_literal_type.collection_type) for v in input_val]
+ literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list))
elif isinstance(input_val, dict):
- if (
- flyte_literal_type.map_value_type is None
- and flyte_literal_type.simple != _type_models.SimpleType.STRUCT
- ):
- raise TypeError(f"Not a map type {flyte_literal_type} but got a map {input_val}")
- k_type, sub_type = DictTransformer.get_dict_types(val_type) # type: ignore
- if flyte_literal_type.simple == _type_models.SimpleType.STRUCT:
- return TypeEngine.to_literal(ctx, input_val, type(input_val), flyte_literal_type)
+ lt = flyte_literal_type
+ python_type = val_type
+ if flyte_literal_type.union_type:
+ for i in range(len(flyte_literal_type.union_type.variants)):
+ variant = flyte_literal_type.union_type.variants[i]
+ if variant.map_value_type:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if variant.simple == _type_models.SimpleType.STRUCT:
+ lt = variant
+ python_type = get_args(val_type)[i]
+ if lt.map_value_type is None and lt.simple != _type_models.SimpleType.STRUCT:
+ raise TypeError(f"Not a map type {lt} but got a map {input_val}")
+ if lt.simple == _type_models.SimpleType.STRUCT:
+ return TypeEngine.to_literal(ctx, input_val, type(input_val), lt)
else:
- literal_map = {
- k: extract_value(ctx, v, sub_type, flyte_literal_type.map_value_type) for k, v in input_val.items()
- }
+ k_type, sub_type = DictTransformer.get_dict_types(python_type) # type: ignore
+ literal_map = {k: extract_value(ctx, v, sub_type, lt.map_value_type) for k, v in input_val.items()}
return _literal_models.Literal(map=_literal_models.LiteralMap(literals=literal_map))
elif isinstance(input_val, Promise):
# In the example above, this handles the "in2=a" type of argument
@@ -863,7 +878,7 @@ def create_and_link_node(
ctx: FlyteContext,
entity: SupportsNodeCreation,
**kwargs,
-):
+) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
"""
This method is used to generate a node with bindings. This is not used in the execution path.
"""
@@ -881,7 +896,20 @@ def create_and_link_node(
for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if k not in kwargs:
- raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
+ is_optional = False
+ if var.type.union_type:
+ for variant in var.type.union_type.variants:
+ if variant.simple == SimpleType.NONE:
+ val, _default = interface.inputs_with_defaults[k]
+ if _default is not None:
+ raise ValueError(
+ f"The default value for the optional type must be None, but got {_default}"
+ )
+ is_optional = True
+ if not is_optional:
+ raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
+ else:
+ continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 0fff171f36..4b6743afd3 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -156,7 +156,10 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return self._get_command_fn(settings)
def get_container(self, settings: SerializationSettings) -> _task_model.Container:
- env = {**settings.env, **self.environment} if self.environment else settings.env
+ env = {}
+ for elem in (settings.env, self.environment):
+ if elem:
+ env.update(elem)
return _get_container_definition(
image=get_registerable_container_image(self.container_image, settings.image_config),
command=[],
@@ -253,4 +256,4 @@ def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> st
# fqn will access the fully qualified name of the image (e.g. registry/imagename:version -> registry/imagename)
# version will access the version part of the image (e.g. registry/imagename:version -> version)
# With empty attribute, it'll access the full image path (e.g. registry/imagename:version -> registry/imagename:version)
-_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z]+))(?:\.([a-zA-Z]+))?\s*}})", re.IGNORECASE)
+_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z0-9_]+))(?:\.([a-zA-Z0-9_]+))?\s*}})", re.IGNORECASE)
diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py
index 0c5fe786ae..7addc89197 100644
--- a/flytekit/core/schedule.py
+++ b/flytekit/core/schedule.py
@@ -15,13 +15,14 @@
# Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass.
class CronSchedule(_schedule_models.Schedule):
"""
- Use this when you have a launch plan that you want to run on a cron expression. The syntax currently used for this
- follows the `AWS convention `__
+ Use this when you have a launch plan that you want to run on a cron expression.
+ This uses standard `cron format `__
+ when you are using the default native scheduler via the ``schedule`` attribute.
.. code-block::
CronSchedule(
- cron_expression="0 10 * * ? *",
+ schedule="*/1 * * * *", # This schedule runs every minute
)
See the :std:ref:`User Guide ` for further examples.
@@ -54,9 +55,10 @@ def __init__(
self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None
):
"""
- :param str cron_expression: This should be a cron expression in AWS style.
+ :param str cron_expression: This should be a cron expression in AWS style. It shouldn't be used with the native scheduler.
:param str schedule: This takes a cron alias (see ``_VALID_CRON_ALIASES``) or a croniter parseable schedule.
- Only one of this or ``cron_expression`` can be set, not both.
+ Only one of this or ``cron_expression`` can be set, not both. This uses the standard `cron format `_
+ and is supported by the native scheduler.
:param str offset:
:param str kickoff_time_input_arg: This is a convenient argument to use when your code needs to know what time
a run was kicked off. Supply the name of the input argument of your workflow to this argument here. Note
@@ -67,7 +69,7 @@ def __init__(
def my_wf(kickoff_time: datetime): ...
schedule = CronSchedule(
- cron_expression="0 10 * * ? *",
+ schedule="*/1 * * * *",
kickoff_time_input_arg="kickoff_time")
"""
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 60c42ac339..b3851e77ce 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -319,7 +319,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)
- def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
+ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any:
"""
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field.
"""
@@ -328,36 +328,42 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset
- for f in dataclasses.fields(python_type):
- v = python_val.__getattribute__(f.name)
- field_type = f.type
- if inspect.isclass(field_type) and (
- issubclass(field_type, FlyteSchema)
- or issubclass(field_type, FlyteFile)
- or issubclass(field_type, FlyteDirectory)
- or issubclass(field_type, StructuredDataset)
- ):
- lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
- # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
- # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
- # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
- # so that dataclass_json can always get a remote path.
- # In other words, the file transformer has special code that handles the fact that if remote_source is
- # set, then the real uri in the literal should be the remote source, not the path (which may be an
- # auto-generated random local path). To be sure we're writing the right path to the json, use the uri
- # as determined by the transformer.
- if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
- python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))
- elif issubclass(field_type, StructuredDataset):
- python_val.__setattr__(
- f.name,
- field_type(
- uri=lv.scalar.structured_dataset.uri,
- ),
- )
+ if hasattr(python_type, "__origin__") and python_type.__origin__ is list:
+ return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val]
+
+ if hasattr(python_type, "__origin__") and python_type.__origin__ is dict:
+ return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()}
- elif dataclasses.is_dataclass(field_type):
- self._serialize_flyte_type(v, field_type)
+ if not dataclasses.is_dataclass(python_type):
+ return python_val
+
+ if inspect.isclass(python_type) and (
+ issubclass(python_type, FlyteSchema)
+ or issubclass(python_type, FlyteFile)
+ or issubclass(python_type, FlyteDirectory)
+ or issubclass(python_type, StructuredDataset)
+ ):
+ lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None)
+ # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
+ # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
+ # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
+ # so that dataclass_json can always get a remote path.
+ # In other words, the file transformer has special code that handles the fact that if remote_source is
+ # set, then the real uri in the literal should be the remote source, not the path (which may be an
+ # auto-generated random local path). To be sure we're writing the right path to the json, use the uri
+ # as determined by the transformer.
+ if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory):
+ return python_type(path=lv.scalar.blob.uri)
+ elif issubclass(python_type, StructuredDataset):
+ return python_type(uri=lv.scalar.structured_dataset.uri)
+ else:
+ return python_val
+ else:
+ for v in dataclasses.fields(python_type):
+ val = python_val.__getattribute__(v.name)
+ field_type = v.type
+ python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type))
+ return python_val
def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
@@ -365,6 +371,12 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine
+ if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list:
+ return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val]
+
+ if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is dict:
+ return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()}
+
if not dataclasses.is_dataclass(expected_python_type):
return python_val
@@ -737,11 +749,11 @@ def literal_map_to_kwargs(
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
"""
- if len(lm.literals) != len(python_types):
+ if len(lm.literals) > len(python_types):
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
)
- return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()}
+ return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()}
@classmethod
def dict_to_literal_map(
@@ -993,7 +1005,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
# Should really never happen, sanity check
raise TypeError("Ambiguous choice of variant for union type")
found_res = True
- except TypeTransformerFailedError as e:
+ except (TypeTransformerFailedError, AttributeError) as e:
logger.debug(f"Failed to convert from {python_val} to {t}", e)
continue
@@ -1047,7 +1059,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
)
res_tag = trans.name
found_res = True
- except TypeTransformerFailedError as e:
+ except (TypeTransformerFailedError, AttributeError) as e:
logger.debug(f"Failed to convert from {lv} to {v}", e)
if found_res:
diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py
index b5f8aa1eab..599b886ab0 100644
--- a/flytekit/deck/deck.py
+++ b/flytekit/deck/deck.py
@@ -3,7 +3,7 @@
from jinja2 import Environment, FileSystemLoader
-from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
+from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
from flytekit.loggers import logger
OUTPUT_DIR_JUPYTER_PREFIX = "jupyter"
@@ -89,7 +89,7 @@ def _ipython_check() -> bool:
return is_ipython
-def _get_deck(new_user_params: ExecutionParameters):
+def _get_deck(new_user_params: ExecutionParameters) -> str:
"""
Get flyte deck html string
"""
@@ -98,7 +98,11 @@ def _get_deck(new_user_params: ExecutionParameters):
def _output_deck(task_name: str, new_user_params: ExecutionParameters):
- output_dir = FlyteContext.current_context().file_access.get_random_local_directory()
+ ctx = FlyteContext.current_context()
+ if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
+ output_dir = ctx.execution_state.engine_dir
+ else:
+ output_dir = ctx.file_access.get_random_local_directory()
deck_path = os.path.join(output_dir, DECK_FILE_NAME)
with open(deck_path, "w") as f:
f.write(_get_deck(new_user_params))
diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html
index 5886f1d936..3992ab9c0f 100644
--- a/flytekit/deck/html/template.html
+++ b/flytekit/deck/html/template.html
@@ -4,6 +4,7 @@
User Content
+
-