diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index ca77c9ff9b..c2411ee328 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -51,7 +51,7 @@ jobs: run: | coverage run -m pytest tests/flytekit/unit -m "not sandbox_test" - name: Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v1.5.2 with: fail_ci_if_error: true # optional (default = false) @@ -82,6 +82,7 @@ jobs: - flytekit-onnx-tensorflow - flytekit-pandera - flytekit-papermill + - flytekit-polars - flytekit-snowflake - flytekit-spark - flytekit-sqlalchemy @@ -90,6 +91,10 @@ jobs: # Issue tracked in https://github.com/ray-project/ray/issues/19116. - python-version: 3.10 plugin-names: "flytekit-modin" + # Great-expectations does not support python 3.10 yet + # https://github.com/great-expectations/great_expectations/blob/develop/setup.py#L87-L89 + - python-version: 3.10 + plugin-names: "flytekit-greatexpectations" steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -107,7 +112,7 @@ jobs: run: | make setup cd plugins/${{ matrix.plugin-names }} - pip install -e . + pip install -r requirements.txt if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi pip install --no-deps -U https://github.com/flyteorg/flytekit/archive/${{ github.sha }}.zip#egg=flytekit pip freeze diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 316f1b0ee2..39470b7370 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 4.0.1 hooks: - id: flake8 - repo: https://github.com/psf/black @@ -8,17 +8,17 @@ repos: hooks: - id: black - repo: https://github.com/PyCQA/isort - rev: 5.9.3 + rev: 5.10.1 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.7.2.1 + rev: v0.8.0.4 hooks: - id: shellcheck diff --git a/CODEOWNERS b/CODEOWNERS index 96d7d9d004..9389524869 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,3 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence. -* @wild-endeavor @kumare3 @eapolinario +* @wild-endeavor @kumare3 @eapolinario @pingsutw diff --git a/README.md b/README.md index 7144575a37..67b6d12297 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,17 @@ - -

- Flyte Logo -

-

- Flytekit Python -

-

- Flytekit Python is the Python SDK built on top of Flyte -

-

- Plugins - · - Contribution Guide -

- +

+ Flyte Logo +

+

+ Flytekit Python +

+

+ Flytekit Python is the Python SDK built on top of Flyte +

+

+ Plugins + · + Contribution Guide +

[![PyPI version fury.io](https://badge.fury.io/py/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![PyPI download day](https://img.shields.io/pypi/dd/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) diff --git a/dev-requirements.in b/dev-requirements.in index 09f3b90c46..b2e6cf74a1 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -11,3 +11,4 @@ codespell google-cloud-bigquery google-cloud-bigquery-storage IPython +torch diff --git a/dev-requirements.txt b/dev-requirements.txt index 419503beff..3705a578c0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -28,13 +28,13 @@ binaryornot==0.4.4 # cookiecutter cached-property==1.5.2 # via docker-compose -cachetools==5.1.0 +cachetools==5.2.0 # via google-auth -certifi==2022.5.18.1 +certifi==2022.6.15 # via # -c requirements.txt # requests -cffi==1.15.0 +cffi==1.15.1 # via # -c requirements.txt # bcrypt @@ -42,11 +42,11 @@ cffi==1.15.0 # pynacl cfgv==3.3.1 # via pre-commit -chardet==4.0.0 +chardet==5.0.0 # via # -c requirements.txt # binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via # -c requirements.txt # requests @@ -61,20 +61,21 @@ cloudpickle==2.1.0 # flytekit codespell==2.1.0 # via -r dev-requirements.in -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via # -c requirements.txt # flytekit -coverage[toml]==6.3.3 +coverage[toml]==6.4.1 # via -r dev-requirements.in croniter==1.3.5 # via # -c requirements.txt # flytekit -cryptography==37.0.2 +cryptography==37.0.4 # via # -c requirements.txt # paramiko + # pyopenssl # secretstorage dataclasses-json==0.5.7 # via @@ -118,47 +119,47 @@ docstring-parser==0.14.1 # via # -c requirements.txt # flytekit -filelock==3.7.0 +filelock==3.7.1 # via virtualenv -flyteidl==1.0.1 +flyteidl==1.1.8 # via # -c requirements.txt # flytekit -google-api-core[grpc]==2.8.0 +google-api-core[grpc]==2.8.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.6.6 +google-auth==2.9.0 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.1.0 +google-cloud-bigquery==3.2.0 # via -r dev-requirements.in -google-cloud-bigquery-storage==2.13.1 +google-cloud-bigquery-storage==2.13.2 # via # -r dev-requirements.in # google-cloud-bigquery -google-cloud-core==2.3.0 +google-cloud-core==2.3.1 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.1 +googleapis-common-protos==1.56.3 # via # -c requirements.txt # flyteidl # google-api-core # grpcio-status -grpcio==1.46.1 +grpcio==1.47.0 # via # -c requirements.txt # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.46.1 +grpcio-status==1.47.0 # via # -c requirements.txt # flytekit @@ -169,10 +170,11 @@ idna==3.3 # via # -c requirements.txt # requests -importlib-metadata==4.11.3 +importlib-metadata==4.12.0 # via # -c requirements.txt # click + # flytekit # jsonschema # keyring # pluggy @@ -181,7 +183,7 @@ importlib-metadata==4.11.3 # virtualenv iniconfig==1.1.1 # via pytest -ipython==7.33.0 +ipython==7.34.0 # via -r dev-requirements.in jedi==0.18.1 # via ipython @@ -206,7 +208,7 @@ jsonschema==3.2.0 # via # -c requirements.txt # docker-compose -keyring==23.5.0 +keyring==23.6.0 # via # -c requirements.txt # flytekit @@ -214,7 +216,7 @@ markupsafe==2.1.1 # via # -c requirements.txt # jinja2 -marshmallow==3.15.0 +marshmallow==3.17.0 # via # -c requirements.txt # dataclasses-json @@ -232,7 +234,7 @@ matplotlib-inline==0.1.3 # via ipython mock==4.0.3 # via -r dev-requirements.in -mypy==0.950 +mypy==0.961 # via -r dev-requirements.in mypy-extensions==0.4.3 # via @@ -243,7 +245,7 @@ natsort==8.1.0 # via # -c requirements.txt # flytekit -nodeenv==1.6.0 +nodeenv==1.7.0 # via pre-commit numpy==1.21.6 # via @@ -273,15 +275,11 @@ platformdirs==2.5.2 # via virtualenv pluggy==1.0.0 # via pytest -poyo==0.5.0 - # via - # -c requirements.txt - # cookiecutter pre-commit==2.19.0 # via -r dev-requirements.in -prompt-toolkit==3.0.29 +prompt-toolkit==3.0.30 # via ipython -proto-plus==1.20.4 +proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -292,6 +290,7 @@ protobuf==3.20.1 # flytekit # google-api-core # google-cloud-bigquery + # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus @@ -326,6 +325,10 @@ pygments==2.12.0 # via ipython pynacl==1.5.0 # via paramiko +pyopenssl==22.0.0 + # via + # -c requirements.txt + # flytekit pyparsing==3.0.9 # via # -c requirements.txt @@ -373,14 +376,15 @@ pytz==2022.1 pyyaml==5.4.1 # via # -c requirements.txt + # cookiecutter # docker-compose # flytekit # pre-commit -regex==2022.4.24 +regex==2022.6.2 # via # -c requirements.txt # docker-image-py -requests==2.27.1 +requests==2.28.1 # via # -c requirements.txt # cookiecutter @@ -390,7 +394,7 @@ requests==2.27.1 # google-api-core # google-cloud-bigquery # responses -responses==0.20.0 +responses==0.21.0 # via # -c requirements.txt # flytekit @@ -411,7 +415,6 @@ singledispatchmethod==1.0 six==1.16.0 # via # -c requirements.txt - # cookiecutter # dockerpty # google-auth # grpcio @@ -441,19 +444,23 @@ tomli==2.0.1 # coverage # mypy # pytest -traitlets==5.2.1.post0 +torch==1.11.0 + # via -r dev-requirements.in +traitlets==5.3.0 # via # ipython # matplotlib-inline -typed-ast==1.5.3 +typed-ast==1.5.4 # via mypy -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via # -c requirements.txt # arrow # flytekit # importlib-metadata # mypy + # responses + # torch # typing-inspect typing-inspect==0.7.1 # via @@ -465,7 +472,7 @@ urllib3==1.26.9 # flytekit # requests # responses -virtualenv==20.14.1 +virtualenv==20.15.1 # via pre-commit wcwidth==0.2.5 # via prompt-toolkit diff --git a/doc-requirements.in b/doc-requirements.in index 4d60a6919b..760d3903dc 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -33,3 +33,4 @@ papermill # papermill jupyter # papermill pyspark # spark sqlalchemy # sqlalchemy +torch # pytorch diff --git a/doc-requirements.txt b/doc-requirements.txt index 6032cc31f4..e9f16c89a0 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -18,13 +18,13 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.2 # via jinja2-time -astroid==2.11.5 +astroid==2.11.6 # via sphinx-autoapi attrs==21.4.0 # via # jsonschema # visions -babel==2.10.1 +babel==2.10.3 # via sphinx backcall==0.2.0 # via ipython @@ -40,23 +40,23 @@ beautifulsoup4==4.11.1 # sphinx-material binaryornot==0.4.4 # via cookiecutter -bleach==5.0.0 +bleach==5.0.1 # via nbconvert -botocore==1.26.5 +botocore==1.27.22 # via -r doc-requirements.in -cachetools==5.1.0 +cachetools==5.2.0 # via google-auth -certifi==2022.5.18.1 +certifi==2022.6.15 # via # kubernetes # requests -cffi==1.15.0 +cffi==1.15.1 # via # argon2-cffi-bindings # cryptography -chardet==4.0.0 +chardet==5.0.0 # via binaryornot -charset-normalizer==2.0.12 +charset-normalizer==2.1.0 # via requests click==8.1.3 # via @@ -66,16 +66,17 @@ click==8.1.3 # papermill cloudpickle==2.1.0 # via flytekit -colorama==0.4.4 +colorama==0.4.5 # via great-expectations -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit croniter==1.3.5 # via flytekit -cryptography==36.0.2 +cryptography==37.0.4 # via # -r doc-requirements.in # great-expectations + # pyopenssl # secretstorage css-html-js-minify==2.5.5 # via sphinx-material @@ -119,7 +120,7 @@ entrypoints==0.4 # papermill fastjsonschema==2.15.3 # via nbformat -flyteidl==1.0.1 +flyteidl==1.1.8 # via flytekit fonttools==4.33.3 # via matplotlib @@ -129,45 +130,45 @@ fsspec==2022.5.0 # modin furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in -google-api-core[grpc]==2.8.0 +google-api-core[grpc]==2.8.2 # via # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.6.6 +google-auth==2.9.0 # via # google-api-core # google-cloud-core # kubernetes google-cloud==0.34.0 # via -r doc-requirements.in -google-cloud-bigquery==3.1.0 +google-cloud-bigquery==3.2.0 # via -r doc-requirements.in -google-cloud-bigquery-storage==2.13.1 +google-cloud-bigquery-storage==2.13.2 # via google-cloud-bigquery -google-cloud-core==2.3.0 +google-cloud-core==2.3.1 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.1 +googleapis-common-protos==1.56.3 # via # flyteidl # google-api-core # grpcio-status -great-expectations==0.15.6 +great-expectations==0.15.12 # via -r doc-requirements.in greenlet==1.1.2 # via sqlalchemy -grpcio==1.46.1 +grpcio==1.47.0 # via # -r doc-requirements.in # flytekit # google-api-core # google-cloud-bigquery # grpcio-status -grpcio-status==1.46.1 +grpcio-status==1.47.0 # via # flytekit # google-api-core @@ -177,27 +178,28 @@ idna==3.3 # via requests imagehash==4.2.1 # via visions -imagesize==1.3.0 +imagesize==1.4.1 # via sphinx -importlib-metadata==4.11.3 +importlib-metadata==4.12.0 # via # click + # flytekit # great-expectations # jsonschema # keyring # markdown # sphinx # sqlalchemy -importlib-resources==5.7.1 +importlib-resources==5.8.0 # via jsonschema -ipykernel==6.13.0 +ipykernel==6.15.0 # via # ipywidgets # jupyter # jupyter-console # notebook # qtconsole -ipython==7.33.0 +ipython==7.34.0 # via # great-expectations # ipykernel @@ -208,7 +210,7 @@ ipython-genutils==0.2.0 # ipywidgets # notebook # qtconsole -ipywidgets==7.7.0 +ipywidgets==7.7.1 # via jupyter jedi==0.18.1 # via ipython @@ -216,7 +218,7 @@ jeepney==0.8.0 # via # keyring # secretstorage -jinja2==3.0.3 +jinja2==3.1.2 # via # altair # cookiecutter @@ -229,7 +231,7 @@ jinja2==3.0.3 # sphinx-autoapi jinja2-time==0.2.0 # via cookiecutter -jmespath==1.0.0 +jmespath==1.0.1 # via botocore joblib==1.1.0 # via @@ -239,21 +241,21 @@ jsonpatch==1.32 # via great-expectations jsonpointer==2.3 # via jsonpatch -jsonschema==4.5.1 +jsonschema==4.6.1 # via # altair # great-expectations # nbformat jupyter==1.0.0 # via -r doc-requirements.in -jupyter-client==7.3.1 +jupyter-client==7.3.4 # via # ipykernel # jupyter-console # nbclient # notebook # qtconsole -jupyter-console==6.4.3 +jupyter-console==6.4.4 # via jupyter jupyter-core==4.10.0 # via @@ -264,17 +266,17 @@ jupyter-core==4.10.0 # qtconsole jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-widgets==1.1.0 +jupyterlab-widgets==1.1.1 # via ipywidgets -keyring==23.5.0 +keyring==23.6.0 # via flytekit -kiwisolver==1.4.2 +kiwisolver==1.4.3 # via matplotlib -kubernetes==23.6.0 +kubernetes==24.2.0 # via -r doc-requirements.in lazy-object-proxy==1.7.1 # via astroid -lxml==4.8.0 +lxml==4.9.1 # via sphinx-material markdown==3.3.7 # via -r doc-requirements.in @@ -283,7 +285,7 @@ markupsafe==2.1.1 # jinja2 # nbconvert # pandas-profiling -marshmallow==3.15.0 +marshmallow==3.17.0 # via # dataclasses-json # marshmallow-enum @@ -318,7 +320,7 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 # via flytekit -nbclient==0.6.3 +nbclient==0.6.6 # via # nbconvert # papermill @@ -329,7 +331,6 @@ nbconvert==6.5.0 nbformat==5.4.0 # via # great-expectations - # ipywidgets # nbclient # nbconvert # notebook @@ -342,7 +343,7 @@ nest-asyncio==1.5.5 # notebook networkx==2.6.3 # via visions -notebook==6.4.11 +notebook==6.4.12 # via # great-expectations # jupyter @@ -407,22 +408,20 @@ phik==0.12.2 # via pandas-profiling pickleshare==0.7.5 # via ipython -pillow==9.1.1 +pillow==9.2.0 # via # imagehash # matplotlib # visions -plotly==5.8.0 +plotly==5.9.0 # via -r doc-requirements.in -poyo==0.5.0 - # via cookiecutter prometheus-client==0.14.1 # via notebook -prompt-toolkit==3.0.29 +prompt-toolkit==3.0.30 # via # ipython # jupyter-console -proto-plus==1.20.4 +proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -432,13 +431,14 @@ protobuf==3.20.1 # flytekit # google-api-core # google-cloud-bigquery + # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status # proto-plus # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -psutil==5.9.0 +psutil==5.9.1 # via ipykernel ptyprocess==0.7.0 # via @@ -446,7 +446,7 @@ ptyprocess==0.7.0 # terminado py==1.11.0 # via retry -py4j==0.10.9.3 +py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 # via @@ -467,12 +467,15 @@ pydantic==1.9.1 # pandera pygments==2.12.0 # via + # furo # ipython # jupyter-console # nbconvert # qtconsole # sphinx # sphinx-prompt +pyopenssl==22.0.0 + # via flytekit pyparsing==2.4.7 # via # great-expectations @@ -480,7 +483,7 @@ pyparsing==2.4.7 # packaging pyrsistent==0.18.1 # via jsonschema -pyspark==3.2.1 +pyspark==3.3.0 # via -r doc-requirements.in python-dateutil==2.8.2 # via @@ -514,23 +517,25 @@ pywavelets==1.3.0 # via imagehash pyyaml==6.0 # via + # cookiecutter # flytekit # kubernetes # pandas-profiling # papermill # sphinx-autoapi -pyzmq==23.0.0 +pyzmq==23.2.0 # via + # ipykernel # jupyter-client # notebook # qtconsole -qtconsole==5.3.0 +qtconsole==5.3.1 # via jupyter qtpy==2.1.0 # via qtconsole -regex==2022.4.24 +regex==2022.6.2 # via docker-image-py -requests==2.27.1 +requests==2.28.1 # via # cookiecutter # docker @@ -546,7 +551,7 @@ requests==2.27.1 # sphinx requests-oauthlib==1.3.1 # via kubernetes -responses==0.20.0 +responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit @@ -577,7 +582,6 @@ singledispatchmethod==1.0 six==1.16.0 # via # bleach - # cookiecutter # google-auth # grpcio # imagehash @@ -595,6 +599,7 @@ sphinx==4.5.0 # -r doc-requirements.in # furo # sphinx-autoapi + # sphinx-basic-ng # sphinx-code-include # sphinx-copybutton # sphinx-fontawesome @@ -605,6 +610,8 @@ sphinx==4.5.0 # sphinxcontrib-yt sphinx-autoapi==1.8.4 # via -r doc-requirements.in +sphinx-basic-ng==0.0.1a12 + # via furo sphinx-code-include==1.1.1 # via -r doc-requirements.in sphinx-copybutton==0.5.0 @@ -633,7 +640,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in -sqlalchemy==1.4.36 +sqlalchemy==1.4.39 # via -r doc-requirements.in statsd==3.3.0 # via flytekit @@ -657,7 +664,9 @@ tinycss2==1.1.1 # via nbconvert toolz==0.11.2 # via altair -tornado==6.1 +torch==1.11.0 + # via -r doc-requirements.in +tornado==6.2 # via # ipykernel # jupyter-client @@ -668,7 +677,7 @@ tqdm==4.64.0 # great-expectations # pandas-profiling # papermill -traitlets==5.2.1.post0 +traitlets==5.3.0 # via # ipykernel # ipython @@ -681,9 +690,9 @@ traitlets==5.2.1.post0 # nbformat # notebook # qtconsole -typed-ast==1.5.3 +typed-ast==1.5.4 # via astroid -typing-extensions==4.2.0 +typing-extensions==4.3.0 # via # argon2-cffi # arrow @@ -695,6 +704,8 @@ typing-extensions==4.2.0 # kiwisolver # pandera # pydantic + # responses + # torch # typing-inspect typing-inspect==0.7.1 # via @@ -724,13 +735,13 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.3.2 +websocket-client==1.3.3 # via # docker # kubernetes wheel==0.37.1 # via flytekit -widgetsnbextension==3.6.0 +widgetsnbextension==3.6.1 # via ipywidgets wrapt==1.14.1 # via diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst index f200500fbc..3f06961022 100644 --- a/docs/source/data.extend.rst +++ b/docs/source/data.extend.rst @@ -1,7 +1,7 @@ ############################## Extend Data Persistence layer ############################## -Flytekit provides a data persistence layer, which is used for recording metadata that is shared with backend Flyte. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible. +Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible. Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface. .. note:: @@ -16,3 +16,22 @@ Moreover, it is designed to be extensible and users can bring their own data per :no-members: :no-inherited-members: :no-special-members: + +The ``fsspec`` Data Plugin +-------------------------- + +Flytekit ships with a default storage driver that uses aws-cli on AWS and gsutil on GCP. By default, Flyte uploads the task outputs to S3 or GCS using these storage drivers. + +Why ``fsspec``? +^^^^^^^^^^^^^^^ + +You can use the fsspec plugin implementation to utilize all its available plugins with flytekit. The `fsspec `_ plugin provides an implementation of the data persistence layer in Flytekit. For example: HDFS, FTP are supported in fsspec, so you can use them with flytekit too. +The data persistence layer helps store logs of metadata and raw user data. +As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``. + +`Here `_ is a code snippet that shows protocols mapped to the class it implements. + +Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins `_ and provides the ones supported by fsspec. + +.. note:: + This plugin installs fsspec core only. To install all the fsspec plugins, see `here `_. diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst new file mode 100644 index 0000000000..12fd3d62d9 --- /dev/null +++ b/docs/source/extras.pytorch.rst @@ -0,0 +1,7 @@ +############ +PyTorch Type +############ +.. automodule:: flytekit.extras.pytorch + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst index f1b15455dd..f0cdff28dc 100644 --- a/docs/source/types.extend.rst +++ b/docs/source/types.extend.rst @@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types. types.builtins.structured types.builtins.file types.builtins.directory + extras.pytorch diff --git a/flytekit/__init__.py b/flytekit/__init__.py index b6bd104a2d..c67a8a04b4 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -182,6 +182,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck +from flytekit.extras import pytorch from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -189,7 +190,7 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types import directory, file, schema +from flytekit.types import directory, file, numpy, schema from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 68bd2578e5..1f8dd78ef0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -243,6 +243,7 @@ def setup_execution( tmp_dir=user_workspace_dir, raw_output_prefix=raw_output_data_prefix, checkpoint=checkpointer, + task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), ) try: diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 8db7c93a98..d542af5f7e 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1004,3 +1004,17 @@ def get_upload_signed_url( expires_in=expires_in_pb, ) ) + + def get_download_signed_url( + self, native_url: str, expires_in: datetime.timedelta = None + ) -> _data_proxy_pb2.CreateUploadLocationResponse: + expires_in_pb = None + if expires_in: + expires_in_pb = Duration() + expires_in_pb.FromTimedelta(expires_in) + return super(SynchronousFlyteClient, self).create_download_location( + _data_proxy_pb2.CreateDownloadLocationRequest( + native_url=native_url, + expires_in=expires_in_pb, + ) + ) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 9aba78888a..9bdf85a178 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,12 +1,14 @@ from __future__ import annotations import base64 as _base64 +import ssl import subprocess import time import typing from typing import Optional import grpc +import OpenSSL import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest from flyteidl.service import admin_pb2_grpc as _admin_service @@ -110,6 +112,26 @@ def __init__(self, cfg: PlatformConfig, **kwargs): self._cfg = cfg if cfg.insecure: self._channel = grpc.insecure_channel(cfg.endpoint, **kwargs) + elif cfg.insecure_skip_verify: + # Get port from endpoint or use 443 + endpoint_parts = cfg.endpoint.rsplit(":", 1) + if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit(): + server_address = tuple(endpoint_parts) + else: + server_address = (cfg.endpoint, "443") + + cert = ssl.get_server_certificate(server_address) + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + cn = x509.get_subject().CN + credentials = grpc.ssl_channel_credentials(str.encode(cert)) + options = kwargs.get("options", []) + options.append(("grpc.ssl_target_name_override", cn)) + self._channel = grpc.secure_channel( + target=cfg.endpoint, + credentials=credentials, + options=options, + compression=kwargs.get("compression", None), + ) else: if "credentials" not in kwargs: credentials = grpc.ssl_channel_credentials( @@ -251,7 +273,10 @@ def _refresh_credentials_from_command(self): except subprocess.CalledProcessError as e: cli_logger.error("Failed to generate token from command {}".format(command)) raise _user_exceptions.FlyteAuthenticationException("Problems refreshing token with command: " + str(e)) - self.set_access_token(output.stdout.strip()) + authorization_header_key = self.public_client_config.authorization_metadata_key or None + if not authorization_header_key: + self.set_access_token(output.stdout.strip()) + self.set_access_token(output.stdout.strip(), authorization_header_key) def _refresh_credentials_noop(self): pass @@ -833,6 +858,17 @@ def create_upload_location( """ return self._dataproxy_stub.CreateUploadLocation(create_upload_location_request, metadata=self._metadata) + @_handle_rpc_error(retry=True) + def create_download_location( + self, create_download_location_request: _dataproxy_pb2.CreateDownloadLocationRequest + ) -> _dataproxy_pb2.CreateDownloadLocationResponse: + """ + Get a signed url to be used during fast registration + :param flyteidl.service.dataproxy_pb2.CreateDownloadLocationRequest create_download_location_request: + :rtype: flyteidl.service.dataproxy_pb2.CreateDownloadLocationResponse + """ + return self._dataproxy_stub.CreateDownloadLocation(create_download_location_request, metadata=self._metadata) + def get_token(token_endpoint, authorization_header, scope): """ diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 3628af7beb..21aec1c4ad 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -4,6 +4,7 @@ import os as _os import stat as _stat import sys as _sys +from dataclasses import replace from typing import Callable, Dict, List, Tuple, Union import click as _click @@ -276,7 +277,7 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC if parent_ctx.obj["cacert"]: kwargs["root_certificates"] = parent_ctx.obj["cacert"] cfg = parent_ctx.obj["config"] - cfg = cfg.with_parameters(endpoint=host, insecure=insecure) + cfg = replace(cfg, endpoint=host, insecure=insecure) return _friendly_client.SynchronousFlyteClient(cfg, **kwargs) diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index 73274e972e..d922f5e3c1 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -1,5 +1,7 @@ +import sys from typing import Tuple, Union +import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan from flyteidl.admin.task_pb2 import TaskSpec from flyteidl.admin.workflow_pb2 import WorkflowSpec @@ -125,3 +127,9 @@ def hydrate_registration_parameters( del entity.sub_workflows[:] entity.sub_workflows.extend(refreshed_sub_workflows) return identifier, entity + + +def display_help_with_error(ctx: click.Context, message: str): + click.echo(f"{ctx.get_help()}\n") + click.secho(message, fg="red") + sys.exit(1) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py new file mode 100644 index 0000000000..a9a9c4900d --- /dev/null +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -0,0 +1,32 @@ +import click + +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE +from flytekit.configuration import Config +from flytekit.loggers import cli_logger +from flytekit.remote.remote import FlyteRemote + +FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" + + +def get_and_save_remote_with_click_context( + ctx: click.Context, project: str, domain: str, save: bool = True +) -> FlyteRemote: + """ + NB: This function will by default mutate the click Context.obj dictionary, adding a remote key with value + of the created FlyteRemote object. + + :param ctx: the click context object + :param project: default project for the remote instance + :param domain: default domain + :param save: If false, will not mutate the context.obj dict + :return: FlyteRemote instance + """ + cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) + cfg_obj = Config.auto(cfg_file_location) + cli_logger.info( + f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") + ) + r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain) + if save: + ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r + return r diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index 71efeab576..2a884e29da 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,8 +1,8 @@ import os -import sys import click +from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants from flytekit.configuration import ( DEFAULT_RUNTIME_PYTHON_INTERPRETER, @@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ pkgs = ctx.obj[constants.CTX_PACKAGES] if not pkgs: - click.secho("No packages to scan for flyte entities. Aborting!", fg="red") - sys.exit(-1) + display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!") try: serialize_and_package(pkgs, serialization_settings, source, output, fast) diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index c2b5f2b045..76777c5663 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -5,6 +5,7 @@ from flytekit.clis.sdk_in_container.init import init from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package +from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.configuration.internal import LocalSDK @@ -68,6 +69,7 @@ def main(ctx, pkgs=None, config=None): main.add_command(local_cache) main.add_command(init) main.add_command(run) +main.add_command(register) if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py new file mode 100644 index 0000000000..03e00d7896 --- /dev/null +++ b/flytekit/clis/sdk_in_container/register.py @@ -0,0 +1,179 @@ +import os +import pathlib +import typing + +import click + +from flytekit.clis.helpers import display_help_with_error +from flytekit.clis.sdk_in_container import constants +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration.default_images import DefaultImages +from flytekit.loggers import cli_logger +from flytekit.tools.fast_registration import fast_package +from flytekit.tools.repo import find_common_root, load_packages_and_modules +from flytekit.tools.repo import register as repo_register +from flytekit.tools.translator import Options + +_register_help = """ +This command is similar to package but instead of producing a zip file, all your Flyte entities are compiled, +and then sent to the backend specified by your config file. Think of this as combining the pyflyte package +and the flytectl register step in one command. This is why you see switches you'd normally use with flytectl +like service account here. + +Note: This command runs "fast" register by default. Future work to come to add a non-fast version. +This means that a zip is created from the detected root of the packages given, and uploaded. Just like with +pyflyte run, tasks registered from this command will download and unzip that code package before running. + +Note: This command only works on regular Python packages, not namespace packages. When determining + the root of your project, it finds the first folder that does not have an __init__.py file. +""" + + +@click.command("register", help=_register_help) +@click.option( + "-p", + "--project", + required=False, + type=str, + default="flytesnacks", + help="Project to register and run this workflow in", +) +@click.option( + "-d", + "--domain", + required=False, + type=str, + default="development", + help="Domain to register and run this workflow in", +) +@click.option( + "-i", + "--image", + "image_config", + required=False, + multiple=True, + type=click.UNPROCESSED, + callback=ImageConfig.validate_image, + default=[DefaultImages.default_image()], + help="A fully qualified tag for an docker image, e.g. somedocker.com/myimage:someversion123. This is a " + "multi-option and can be of the form --image xyz.io/docker:latest " + "--image my_image=xyz.io/docker2:latest. Note, the `name=image_uri`. The name is optional, if not " + "provided the image will be used as the default image. All the names have to be unique, and thus " + "there can only be one --image option with no name.", +) +@click.option( + "-o", + "--output", + required=False, + type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True), + default=None, + help="Directory to write the output zip file containing the protobuf definitions", +) +@click.option( + "-d", + "--destination-dir", + required=False, + type=str, + default="/root", + help="Directory inside the image where the tar file containing the code will be copied to", +) +@click.option( + "--service-account", + required=False, + type=str, + default="", + help="Service account used when creating launch plans", +) +@click.option( + "--raw-data-prefix", + required=False, + type=str, + default="", + help="Raw output data prefix when creating launch plans, where offloaded data will be stored", +) +@click.option( + "-v", + "--version", + required=False, + type=str, + help="Version the package or module is registered with", +) +@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) +@click.pass_context +def register( + ctx: click.Context, + project: str, + domain: str, + image_config: ImageConfig, + output: str, + destination_dir: str, + service_account: str, + raw_data_prefix: str, + version: typing.Optional[str], + package_or_module: typing.Tuple[str], +): + """ + see help + """ + pkgs = ctx.obj[constants.CTX_PACKAGES] + if not pkgs: + cli_logger.debug("No pkgs") + if pkgs: + raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") + + if len(package_or_module) == 0: + display_help_with_error( + ctx, + "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", + ) + + cli_logger.debug( + f"Running pyflyte register from {os.getcwd()} " + f"with images {image_config} " + f"and image destinationfolder {destination_dir} " + f"on {len(package_or_module)} package(s) {package_or_module}" + ) + + # Create and save FlyteRemote, + remote = get_and_save_remote_with_click_context(ctx, project, domain) + + # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings + # Create a zip file containing all the entries. + detected_root = find_common_root(package_or_module) + cli_logger.debug(f"Using {detected_root} as root folder for project") + zip_file = fast_package(detected_root, output) + + # Upload zip file to Admin using FlyteRemote. + md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) + cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}") + + # Create serialization settings + # Todo: Rely on default Python interpreter for now, this will break custom Spark containers + serialization_settings = SerializationSettings( + project=project, + domain=domain, + image_config=image_config, + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ), + ) + + options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) + + # Load all the entities + registerable_entities = load_packages_and_modules( + serialization_settings, detected_root, list(package_or_module), options + ) + if len(registerable_entities) == 0: + display_help_with_error(ctx, "No Flyte entities were detected. Aborting!") + cli_logger.info(f"Found and serialized {len(registerable_entities)} entities") + + if not version: + version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa + cli_logger.info(f"Computed version is {version}") + + # Register using repo code + repo_register(registerable_entities, project, domain, version, remote.client) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index b6f723fe0c..bd7b8ee20b 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -2,6 +2,7 @@ import functools import importlib import json +import logging import os import pathlib import typing @@ -11,10 +12,12 @@ import click from dataclasses_json import DataClassJsonMixin from pytimeparse import parse +from typing_extensions import get_args from flytekit import BlobType, Literal, Scalar -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT -from flytekit.configuration import Config, ImageConfig, SerializationSettings +from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PROJECT +from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context +from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager, tracker from flytekit.core.base_task import PythonTask @@ -22,19 +25,17 @@ from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase -from flytekit.loggers import cli_logger from flytekit.models import literals from flytekit.models.interface import Variable from flytekit.models.literals import Blob, BlobMetadata, Primitive from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution -from flytekit.remote.remote import FlyteRemote from flytekit.tools import module_loader, script_mode +from flytekit.tools.script_mode import _find_project_root from flytekit.tools.translator import Options REMOTE_FLAG_KEY = "remote" RUN_LEVEL_PARAMS_KEY = "run_level_params" -FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" DATA_PROXY_CALLBACK_KEY = "data_proxy" @@ -244,6 +245,31 @@ def convert_to_blob( return lit + def convert_to_union( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any + ) -> Literal: + lt = self._literal_type + for i in range(len(self._literal_type.union_type.variants)): + variant = self._literal_type.union_type.variants[i] + python_type = get_args(self._python_type)[i] + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + variant, + python_type, + self._create_upload_fn, + ) + try: + # Here we use click converter to convert the input in command line to native python type, + # and then use flyte converter to convert it to literal. + python_val = converter._click_type.convert(value, param, ctx) + literal = converter.convert_to_literal(ctx, param, python_val) + self._python_type = python_type + return literal + except (Exception or AttributeError) as e: + logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) + raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") + def convert_to_literal( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any ) -> Literal: @@ -255,7 +281,7 @@ def convert_to_literal( if self._literal_type.collection_type or self._literal_type.map_value_type: # TODO Does not support nested flytefile, flyteschema types - v = json.loads(value) + v = json.loads(value) if isinstance(value, str) else value if self._literal_type.collection_type and not isinstance(v, list): raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}") if self._literal_type.map_value_type and not isinstance(v, dict): @@ -263,11 +289,14 @@ def convert_to_literal( return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type) if self._literal_type.union_type: - raise NotImplementedError("Union type is not yet implemented for pyflyte run") + return self.convert_to_union(ctx, param, value) if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - o = cast(DataClassJsonMixin, self._python_type).from_json(value) + if type(value) != self._python_type: + o = cast(DataClassJsonMixin, self._python_type).from_json(value) + else: + o = value return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) return Literal(scalar=self._converter.convert(value, self._python_type)) @@ -396,16 +425,14 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: ] -def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]: +def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]: """ Load the workflow of a the script file. N.B.: it assumes that the file is self-contained, in other words, there are no relative imports. """ - flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - SerializationSettings(None) - ) - with context_manager.FlyteContextManager.with_context(flyte_ctx): - with module_loader.add_sys_path(os.getcwd()): + flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder() + with context_manager.FlyteContextManager.with_context(flyte_ctx_builder): + with module_loader.add_sys_path(project_root): importlib.import_module(module_name) return module_loader.load_object_from_module(f"{module_name}.{entity_name}") @@ -444,9 +471,7 @@ def get_entities_in_file(filename: str) -> Entities: """ Returns a list of flyte workflow names and list of Flyte tasks in a file. """ - flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - SerializationSettings(None) - ) + flyte_ctx = context_manager.FlyteContextManager.current_context().new_builder() module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): @@ -473,6 +498,8 @@ def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, """ def _run(*args, **kwargs): + # By the time we get to this function, all the loading has already happened + run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] project, domain = run_level_params.get("project"), run_level_params.get("domain") inputs = {} @@ -486,10 +513,6 @@ def _run(*args, **kwargs): remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] - # StructuredDatasetTransformerEngine.register( - # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True - # ) - remote_entity = remote.register_script( entity, project=project, @@ -532,32 +555,41 @@ class WorkflowCommand(click.MultiCommand): def __init__(self, filename: str, *args, **kwargs): super().__init__(*args, **kwargs) - self._filename = filename + self._filename = pathlib.Path(filename).resolve() def list_commands(self, ctx): entities = get_entities_in_file(self._filename) return entities.all() def get_command(self, ctx, exe_entity): + """ + This command uses the filename with which this command was created, and the string name of the entity passed + after the Python filename on the command line, to load the Python object, and then return the Command that + click should run. + :param ctx: The click Context object. + :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task + function. + :return: + """ + rel_path = os.path.relpath(self._filename) if rel_path.startswith(".."): raise ValueError( f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}" ) + project_root = _find_project_root(self._filename) + # Find the relative path for the filename relative to the root of the project. + # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as + # a parameter. + rel_path = self._filename.relative_to(project_root) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - entity = load_naive_entity(module, exe_entity) + entity = load_naive_entity(module, exe_entity, project_root) # If this is a remote execution, which we should know at this point, then create the remote object p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT) d = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN) - cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - cfg_obj = Config.auto(cfg_file_location) - cli_logger.info( - f"Run is using config object {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") - ) - r = FlyteRemote(cfg_obj, default_project=p, default_domain=d) - ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r + r = get_and_save_remote_with_click_context(ctx, p, d) get_upload_url_fn = functools.partial(r.client.get_upload_signed_url, project=p, domain=d) flyte_ctx = context_manager.FlyteContextManager.current_context() @@ -596,8 +628,16 @@ def get_command(self, ctx, filename): return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode") +_run_help = """ +This command can execute either a workflow or a task from the commandline, for fully self-contained scripts. +Tasks and workflows cannot be imported from other files currently. Please use `pyflyte package` or +`pyflyte register` to handle those and then launch from the Flyte UI or `flytectl` + +Note: This command only works on regular Python packages, not namespace packages. When determining + the root of your project, it finds the first folder that does not have an __init__.py file. +""" + run = RunCommand( name="run", - help="Run command: This command can execute either a workflow or a task from the commandline, for " - "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.", + help=_run_help, ) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 188f6ebeee..5354abde4f 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -297,6 +297,7 @@ class PlatformConfig(object): :param endpoint: DNS for Flyte backend :param insecure: Whether or not to use SSL + :param insecure_skip_verify: Wether to skip SSL certificate verification :param command: This command is executed to return a token using an external process. :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. @@ -309,32 +310,13 @@ class PlatformConfig(object): endpoint: str = "localhost:30081" insecure: bool = False + insecure_skip_verify: bool = False command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD - def with_parameters( - self, - endpoint: str = "localhost:30081", - insecure: bool = False, - command: typing.Optional[typing.List[str]] = None, - client_id: typing.Optional[str] = None, - client_credentials_secret: typing.Optional[str] = None, - scopes: List[str] = None, - auth_mode: AuthType = AuthType.STANDARD, - ) -> PlatformConfig: - return PlatformConfig( - endpoint=endpoint, - insecure=insecure, - command=command, - client_id=client_id, - client_credentials_secret=client_credentials_secret, - scopes=scopes if scopes else [], - auth_mode=auth_mode, - ) - @classmethod def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig: """ @@ -345,6 +327,9 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None config_file = get_config_file(config_file) kwargs = {} kwargs = set_if_exists(kwargs, "insecure", _internal.Platform.INSECURE.read(config_file)) + kwargs = set_if_exists( + kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file) + ) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( @@ -562,7 +547,7 @@ def for_sandbox(cls) -> Config: :return: Config """ return Config( - platform=PlatformConfig(insecure=True), + platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True), data_config=DataConfig( s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage") ), diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index b329a16112..467f660d42 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -32,13 +32,16 @@ class LegacyConfigEntry(object): option: str type_: typing.Type = str + def get_env_name(self): + return f"FLYTE_{self.section.upper()}_{self.option.upper()}" + def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]: """ Reads the config entry from environment variable, the structure of the env var is current ``FLYTE_{SECTION}_{OPTION}`` all upper cased. We will change this in the future. :return: """ - env = f"FLYTE_{self.section.upper()}_{self.option.upper()}" + env = self.get_env_name() v = os.environ.get(env, None) if v is None: return None @@ -159,7 +162,7 @@ def __init__(self, location: str): Load the config from this location """ self._location = location - if location.endswith("yaml"): + if location.endswith("yaml") or location.endswith("yml"): self._legacy_config = None self._yaml_config = self._read_yaml_config(location) else: diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 9bb7567be2..fb09015b6f 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -105,6 +105,9 @@ class Platform(object): LegacyConfigEntry(SECTION, "url"), YamlConfigEntry("admin.endpoint"), lambda x: x.replace("dns:///", "") ) INSECURE = ConfigEntry(LegacyConfigEntry(SECTION, "insecure", bool), YamlConfigEntry("admin.insecure", bool)) + INSECURE_SKIP_VERIFY = ConfigEntry( + LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool) + ) class LocalSDK(object): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 77ac6516a7..a17a9148b6 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,7 +27,6 @@ from enum import Enum from typing import Generator, List, Optional, Union -import flytekit from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils @@ -39,6 +38,9 @@ from flytekit.loggers import logger, user_space_logger from flytekit.models.core import identifier as _identifier +if typing.TYPE_CHECKING: + from flytekit.deck.deck import Deck + # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin # Enables static type checking https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING @@ -80,12 +82,13 @@ class Builder(object): stats: taggable.TaggableStats execution_date: datetime logging: _logging.Logger - execution_id: str + execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] attrs: typing.Dict[str, typing.Any] working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir] checkpoint: typing.Optional[Checkpoint] - decks: List[flytekit.Deck] + decks: List[Deck] raw_output_prefix: str + task_id: typing.Optional[_identifier.Identifier] def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.stats = current.stats if current else None @@ -97,6 +100,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.decks = current._decks if current else [] self.attrs = current._attrs if current else {} self.raw_output_prefix = current.raw_output_prefix if current else None + self.task_id = current.task_id if current else None def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: self.attrs[key] = v @@ -114,6 +118,7 @@ def build(self) -> ExecutionParameters: checkpoint=self.checkpoint, decks=self.decks, raw_output_prefix=self.raw_output_prefix, + task_id=self.task_id, **self.attrs, ) @@ -143,11 +148,12 @@ def __init__( execution_date, tmp_dir, stats, - execution_id, + execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier], logging, raw_output_prefix, checkpoint=None, decks=None, + task_id: typing.Optional[_identifier.Identifier] = None, **kwargs, ): """ @@ -173,6 +179,7 @@ def __init__( self._secrets_manager = SecretsManager() self._checkpoint = checkpoint self._decks = decks + self._task_id = task_id @property def stats(self) -> taggable.TaggableStats: @@ -230,6 +237,14 @@ def execution_id(self) -> _identifier.WorkflowExecutionIdentifier: """ return self._execution_id + @property + def task_id(self) -> typing.Optional[_identifier.Identifier]: + """ + At production run-time, this will be generated by reading environment variables that are set + by the backend. + """ + return self._task_id + @property def secrets(self) -> SecretsManager: return self._secrets_manager @@ -248,8 +263,8 @@ def decks(self) -> typing.List: return self._decks @property - def default_deck(self) -> "Deck": - from flytekit import Deck + def default_deck(self) -> Deck: + from flytekit.deck.deck import Deck return Deck("default") @@ -602,10 +617,10 @@ def current_context() -> Optional[FlyteContext]: """ return FlyteContextManager.current_context() - def get_deck(self): + def get_deck(self) -> str: from flytekit.deck.deck import _get_deck - _get_deck(self.execution_state.user_space_params) + return _get_deck(self.execution_state.user_space_params) @dataclass class Builder(object): @@ -796,6 +811,7 @@ def initialize(): default_context = FlyteContext(file_access=default_local_file_access_provider) default_user_space_params = ExecutionParameters( execution_id=WorkflowExecutionIdentifier.promote_from_model(default_execution_id), + task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"), execution_date=_datetime.datetime.utcnow(), stats=mock_stats.MockStats(), logging=user_space_logger, diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index fdd47d1741..0f651410bf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,12 +7,15 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing_extensions import get_args, get_origin, get_type_hints + from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.models.literals import Void from flytekit.types.pickle import FlytePickle T = typing.TypeVar("T") @@ -182,11 +185,17 @@ def transform_inputs_to_parameters( inputs_with_def = interface.inputs_with_defaults for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] - required = _default is None - default_lv = None - if _default is not None: - default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) - params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) + if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): + from flytekit import Literal, Scalar + + literal = Literal(scalar=Scalar(none_type=Void())) + params[k] = _interface_models.Parameter(var=v, default=literal, required=False) + else: + required = _default is None + default_lv = None + if _default is not None: + default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) + params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) return _interface_models.ParameterMap(params) @@ -274,11 +283,8 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object. """ - try: - # include_extras can only be used in python >= 3.9 - type_hints = typing.get_type_hints(fn, include_extras=True) - except TypeError: - type_hints = typing.get_type_hints(fn) + + type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) @@ -386,7 +392,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(typing.get_type_hints(return_annotation)) + return dict(get_type_hints(return_annotation, include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index e1f53dc5d2..e8f5bd4aa3 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -292,7 +292,6 @@ def get_or_create( LaunchPlan.CACHE[name or workflow.name] = lp return lp - # TODO: Add QoS after it's done def __init__( self, name: str, @@ -359,6 +358,10 @@ def clone_with( def python_interface(self) -> Interface: return self.workflow.python_interface + @property + def interface(self) -> _interface_models.TypedInterface: + return self.workflow.interface + @property def name(self) -> str: return self._name diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 6eb4f51d81..1694af6d70 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,7 @@ from __future__ import annotations import collections -from typing import Type, Union +from typing import TYPE_CHECKING, Type, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -12,11 +12,15 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import logger +if TYPE_CHECKING: + from flytekit.remote.remote_callable import RemoteEntity + + # This file exists instead of moving to node.py because it needs Task/Workflow/LaunchPlan and those depend on Node def create_node( - entity: Union[PythonTask, LaunchPlan, WorkflowBase], *args, **kwargs + entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs ) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or @@ -65,6 +69,8 @@ def sub_wf(): t2(t1_node.o0) """ + from flytekit.remote.remote_callable import RemoteEntity + if len(args) > 0: raise _user_exceptions.FlyteAssertion( f"Only keyword args are supported to pass inputs to workflows and tasks." @@ -75,8 +81,9 @@ def sub_wf(): not isinstance(entity, PythonTask) and not isinstance(entity, WorkflowBase) and not isinstance(entity, LaunchPlan) + and not isinstance(entity, RemoteEntity) ): - raise AssertionError("Should be but it's not") + raise AssertionError(f"Should be a callable Flyte entity (either local or fetched) but is {type(entity)}") # This function is only called from inside workflows and dynamic tasks. # That means there are two scenarios we need to take care of, compilation and local workflow execution. @@ -84,7 +91,6 @@ def sub_wf(): # When compiling, calling the entity will create a node. ctx = FlyteContext.current_context() if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: - outputs = entity(**kwargs) # This is always the output of create_and_link_node which returns create_task_output, which can be # VoidPromise, Promise, or our custom namedtuple of Promises. @@ -105,9 +111,11 @@ def sub_wf(): return node # If a Promise or custom namedtuple of Promises, we need to attach each output as an attribute to the node. - if entity.python_interface.outputs: + # todo: fix the noqas below somehow... can't add abstract property to RemoteEntity because it has to come + # before the model Template classes in FlyteTask/Workflow/LaunchPlan + if entity.interface.outputs: # noqa if isinstance(outputs, tuple): - for output_name in entity.python_interface.output_names: + for output_name in entity.interface.outputs.keys(): # noqa attr = getattr(outputs, output_name) if attr is None: raise _user_exceptions.FlyteAssertion( @@ -120,7 +128,7 @@ def sub_wf(): setattr(node, output_name, attr) node.outputs[output_name] = attr else: - output_names = entity.python_interface.output_names + output_names = [k for k in entity.interface.outputs.keys()] # noqa if len(output_names) != 1: raise _user_exceptions.FlyteAssertion(f"Output of length 1 expected but {len(output_names)} found") @@ -136,6 +144,9 @@ def sub_wf(): # Handling local execution elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if isinstance(entity, RemoteEntity): + raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") + if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: logger.warning(f"Manual node creation cannot be used in branch logic {entity.name}") raise Exception("Being more restrictive for now and disallowing manual node creation in branch logic") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index e5877e21fe..4c9150881d 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast -from typing_extensions import Protocol +from typing_extensions import Protocol, get_args from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -23,6 +23,7 @@ from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive +from flytekit.models.types import SimpleType def translate_inputs_to_literals( @@ -68,29 +69,43 @@ def extract_value( ) -> _literal_models.Literal: if isinstance(input_val, list): - if flyte_literal_type.collection_type is None: + lt = flyte_literal_type + python_type = val_type + if flyte_literal_type.union_type: + for i in range(len(flyte_literal_type.union_type.variants)): + variant = flyte_literal_type.union_type.variants[i] + if variant.collection_type: + lt = variant + python_type = get_args(val_type)[i] + if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(val_type) + sub_type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise sub_type = type(input_val[0]) - literal_list = [extract_value(ctx, v, sub_type, flyte_literal_type.collection_type) for v in input_val] + literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): - if ( - flyte_literal_type.map_value_type is None - and flyte_literal_type.simple != _type_models.SimpleType.STRUCT - ): - raise TypeError(f"Not a map type {flyte_literal_type} but got a map {input_val}") - k_type, sub_type = DictTransformer.get_dict_types(val_type) # type: ignore - if flyte_literal_type.simple == _type_models.SimpleType.STRUCT: - return TypeEngine.to_literal(ctx, input_val, type(input_val), flyte_literal_type) + lt = flyte_literal_type + python_type = val_type + if flyte_literal_type.union_type: + for i in range(len(flyte_literal_type.union_type.variants)): + variant = flyte_literal_type.union_type.variants[i] + if variant.map_value_type: + lt = variant + python_type = get_args(val_type)[i] + if variant.simple == _type_models.SimpleType.STRUCT: + lt = variant + python_type = get_args(val_type)[i] + if lt.map_value_type is None and lt.simple != _type_models.SimpleType.STRUCT: + raise TypeError(f"Not a map type {lt} but got a map {input_val}") + if lt.simple == _type_models.SimpleType.STRUCT: + return TypeEngine.to_literal(ctx, input_val, type(input_val), lt) else: - literal_map = { - k: extract_value(ctx, v, sub_type, flyte_literal_type.map_value_type) for k, v in input_val.items() - } + k_type, sub_type = DictTransformer.get_dict_types(python_type) # type: ignore + literal_map = {k: extract_value(ctx, v, sub_type, lt.map_value_type) for k, v in input_val.items()} return _literal_models.Literal(map=_literal_models.LiteralMap(literals=literal_map)) elif isinstance(input_val, Promise): # In the example above, this handles the "in2=a" type of argument @@ -863,7 +878,7 @@ def create_and_link_node( ctx: FlyteContext, entity: SupportsNodeCreation, **kwargs, -): +) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ This method is used to generate a node with bindings. This is not used in the execution path. """ @@ -881,7 +896,20 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + is_optional = False + if var.type.union_type: + for variant in var.type.union_type.variants: + if variant.simple == SimpleType.NONE: + val, _default = interface.inputs_with_defaults[k] + if _default is not None: + raise ValueError( + f"The default value for the optional type must be None, but got {_default}" + ) + is_optional = True + if not is_optional: + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + else: + continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 0fff171f36..4b6743afd3 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -156,7 +156,10 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._get_command_fn(settings) def get_container(self, settings: SerializationSettings) -> _task_model.Container: - env = {**settings.env, **self.environment} if self.environment else settings.env + env = {} + for elem in (settings.env, self.environment): + if elem: + env.update(elem) return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), command=[], @@ -253,4 +256,4 @@ def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> st # fqn will access the fully qualified name of the image (e.g. registry/imagename:version -> registry/imagename) # version will access the version part of the image (e.g. registry/imagename:version -> version) # With empty attribute, it'll access the full image path (e.g. registry/imagename:version -> registry/imagename:version) -_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z]+))(?:\.([a-zA-Z]+))?\s*}})", re.IGNORECASE) +_IMAGE_REPLACE_REGEX = re.compile(r"({{\s*\.image[s]?(?:\.([a-zA-Z0-9_]+))(?:\.([a-zA-Z0-9_]+))?\s*}})", re.IGNORECASE) diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 0c5fe786ae..7addc89197 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -15,13 +15,14 @@ # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass. class CronSchedule(_schedule_models.Schedule): """ - Use this when you have a launch plan that you want to run on a cron expression. The syntax currently used for this - follows the `AWS convention `__ + Use this when you have a launch plan that you want to run on a cron expression. + This uses standard `cron format `__ + in case where you are using default native scheduler using the schedule attribute. .. code-block:: CronSchedule( - cron_expression="0 10 * * ? *", + schedule="*/1 * * * *", # Following schedule runs every min ) See the :std:ref:`User Guide ` for further examples. @@ -54,9 +55,10 @@ def __init__( self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None ): """ - :param str cron_expression: This should be a cron expression in AWS style. + :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. :param str schedule: This takes a cron alias (see ``_VALID_CRON_ALIASES``) or a croniter parseable schedule. - Only one of this or ``cron_expression`` can be set, not both. + Only one of this or ``cron_expression`` can be set, not both. This uses standard `cron format `_ + and is supported by native scheduler :param str offset: :param str kickoff_time_input_arg: This is a convenient argument to use when your code needs to know what time a run was kicked off. Supply the name of the input argument of your workflow to this argument here. Note @@ -67,7 +69,7 @@ def __init__( def my_wf(kickoff_time: datetime): ... schedule = CronSchedule( - cron_expression="0 10 * * ? *", + schedule="*/1 * * * *" kickoff_time_input_arg="kickoff_time") """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 60c42ac339..b3851e77ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -319,7 +319,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) ) - def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. """ @@ -328,36 +328,42 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]): from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset - for f in dataclasses.fields(python_type): - v = python_val.__getattribute__(f.name) - field_type = f.type - if inspect.isclass(field_type) and ( - issubclass(field_type, FlyteSchema) - or issubclass(field_type, FlyteFile) - or issubclass(field_type, FlyteDirectory) - or issubclass(field_type, StructuredDataset) - ): - lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None) - # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a - # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the - # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, - # so that dataclass_json can always get a remote path. - # In other words, the file transformer has special code that handles the fact that if remote_source is - # set, then the real uri in the literal should be the remote source, not the path (which may be an - # auto-generated random local path). To be sure we're writing the right path to the json, use the uri - # as determined by the transformer. - if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory): - python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri)) - elif issubclass(field_type, StructuredDataset): - python_val.__setattr__( - f.name, - field_type( - uri=lv.scalar.structured_dataset.uri, - ), - ) + if hasattr(python_type, "__origin__") and python_type.__origin__ is list: + return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + + if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: + return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} - elif dataclasses.is_dataclass(field_type): - self._serialize_flyte_type(v, field_type) + if not dataclasses.is_dataclass(python_type): + return python_val + + if inspect.isclass(python_type) and ( + issubclass(python_type, FlyteSchema) + or issubclass(python_type, FlyteFile) + or issubclass(python_type, FlyteDirectory) + or issubclass(python_type, StructuredDataset) + ): + lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None) + # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a + # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the + # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, + # so that dataclass_json can always get a remote path. + # In other words, the file transformer has special code that handles the fact that if remote_source is + # set, then the real uri in the literal should be the remote source, not the path (which may be an + # auto-generated random local path). To be sure we're writing the right path to the json, use the uri + # as determined by the transformer. + if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory): + return python_type(path=lv.scalar.blob.uri) + elif issubclass(python_type, StructuredDataset): + return python_type(uri=lv.scalar.structured_dataset.uri) + else: + return python_val + else: + for v in dataclasses.fields(python_type): + val = python_val.__getattribute__(v.name) + field_type = v.type + python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type)) + return python_val def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer @@ -365,6 +371,12 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: + return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] + + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is dict: + return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()} + if not dataclasses.is_dataclass(expected_python_type): return python_val @@ -737,11 +749,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - if len(lm.literals) != len(python_types): + if len(lm.literals) > len(python_types): raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} @classmethod def dict_to_literal_map( @@ -993,7 +1005,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp # Should really never happen, sanity check raise TypeError("Ambiguous choice of variant for union type") found_res = True - except TypeTransformerFailedError as e: + except (TypeTransformerFailedError, AttributeError) as e: logger.debug(f"Failed to convert from {python_val} to {t}", e) continue @@ -1047,7 +1059,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: ) res_tag = trans.name found_res = True - except TypeTransformerFailedError as e: + except (TypeTransformerFailedError, AttributeError) as e: logger.debug(f"Failed to convert from {lv} to {v}", e) if found_res: diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index b5f8aa1eab..599b886ab0 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -3,7 +3,7 @@ from jinja2 import Environment, FileSystemLoader -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.loggers import logger OUTPUT_DIR_JUPYTER_PREFIX = "jupyter" @@ -89,7 +89,7 @@ def _ipython_check() -> bool: return is_ipython -def _get_deck(new_user_params: ExecutionParameters): +def _get_deck(new_user_params: ExecutionParameters) -> str: """ Get flyte deck html string """ @@ -98,7 +98,11 @@ def _get_deck(new_user_params: ExecutionParameters): def _output_deck(task_name: str, new_user_params: ExecutionParameters): - output_dir = FlyteContext.current_context().file_access.get_random_local_directory() + ctx = FlyteContext.current_context() + if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: + output_dir = ctx.execution_state.engine_dir + else: + output_dir = ctx.file_access.get_random_local_directory() deck_path = os.path.join(output_dir, DECK_FILE_NAME) with open(deck_path, "w") as f: f.write(_get_deck(new_user_params)) diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html index 5886f1d936..3992ab9c0f 100644 --- a/flytekit/deck/html/template.html +++ b/flytekit/deck/html/template.html @@ -4,6 +4,7 @@ User Content + -