diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml
index 2e2ef772b2..0f47137169 100644
--- a/.github/workflows/pythonbuild.yml
+++ b/.github/workflows/pythonbuild.yml
@@ -7,7 +7,7 @@ on:
   pull_request:
 
 env:
-  FLYTE_SDK_LOGGING_LEVEL: 10 # debug
+  FLYTE_SDK_LOGGING_LEVEL: 10 # debug
 
 jobs:
   build:
@@ -49,11 +49,11 @@ jobs:
           pip freeze
       - name: Test with coverage
         run: |
-          pytest tests/flytekit/unit -m "not sandbox_test" --cov=./ --cov-report=xml
+          make unit_test_codecov
       - name: Codecov
         uses: codecov/codecov-action@v3.1.0
         with:
-          fail_ci_if_error: true # optional (default = false)
+          fail_ci_if_error: false
 
   build-plugins:
     runs-on: ubuntu-latest
@@ -68,6 +68,7 @@ jobs:
           - flytekit-aws-sagemaker
           - flytekit-bigquery
           - flytekit-data-fsspec
+          - flytekit-dbt
          - flytekit-deck-standard
           - flytekit-dolt
           - flytekit-greatexpectations
@@ -87,6 +88,7 @@ jobs:
           - flytekit-snowflake
           - flytekit-spark
           - flytekit-sqlalchemy
+          - flytekit-vaex
           - flytekit-whylogs
         exclude:
           # flytekit-modin depends on ray which does not have a 3.10 wheel yet.
@@ -109,8 +111,6 @@ jobs:
           # Issue tracked: https://github.com/whylabs/whylogs/issues/697
           - python-version: 3.10
             plugin-names: "flytekit-whylogs"
-
-
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}
@@ -136,7 +136,6 @@ jobs:
         run: |
           cd plugins/${{ matrix.plugin-names }}
           coverage run -m pytest tests
-
   lint:
     runs-on: ubuntu-latest
     steps:
@@ -178,4 +177,6 @@ jobs:
           python -m pip install --upgrade pip==21.2.4 setuptools wheel
           pip install -r doc-requirements.txt
       - name: Build the documentation
-        run: make -C docs html
+        run: |
+          # TODO: Remove after buf migration is done and packages updated
+          PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python make -C docs html
diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml
index 6553b84507..097d82323e 100644
--- a/.github/workflows/pythonpublish.yml
+++ b/.github/workflows/pythonpublish.yml
@@ -65,31 +65,61 @@ jobs:
       - uses: actions/checkout@v2
         with:
           fetch-depth: "0"
-      - name: Build & Push Flytekit Python${{ matrix.python-version }} Docker Image to Github Registry
-        uses: whoan/docker-build-with-cache-action@v5
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v1
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@v1
+      - name: Login to GitHub Container Registry
+        if: ${{ github.event_name == 'release' }}
+        uses: docker/login-action@v1
         with:
-          # https://docs.github.com/en/packages/learn-github-packages/publishing-a-package
+          registry: ghcr.io
           username: "${{ secrets.FLYTE_BOT_USERNAME }}"
           password: "${{ secrets.FLYTE_BOT_PAT }}"
-          image_name: ${{ github.repository_owner }}/flytekit
-          image_tag: py${{ matrix.python-version }}-latest,py${{ matrix.python-version }}-${{ github.sha }},py${{ matrix.python-version }}-${{ needs.deploy.outputs.version }}
-          push_git_tag: true
-          push_image_and_stages: true
-          registry: ghcr.io
-          build_extra_args: "--compress=true --build-arg=VERSION=${{ needs.deploy.outputs.version }} --build-arg=DOCKER_IMAGE=ghcr.io/flyteorg/flytekit:py${{ matrix.python-version }}-${{ needs.deploy.outputs.version }}"
+      - name: Prepare Flytekit Image Names
+        id: flytekit-names
+        uses: docker/metadata-action@v3
+        with:
+          images: |
+            ghcr.io/${{ github.repository_owner }}/flytekit
+          tags: |
+            py${{ matrix.python-version }}-latest
+            py${{ matrix.python-version }}-${{ github.sha }}
+            py${{ matrix.python-version }}-${{ needs.deploy.outputs.version }}
+      - name: Build & Push Flytekit Python${{ matrix.python-version }} Docker Image to Github Registry
+        uses: docker/build-push-action@v2
+        with:
           context: .
-          dockerfile: Dockerfile.py${{ matrix.python-version }}
+          platforms: linux/arm64, linux/amd64
+          push: ${{ github.event_name == 'release' }}
+          tags: ${{ steps.flytekit-names.outputs.tags }}
+          build-args: |
+            VERSION=${{ needs.deploy.outputs.version }}
+            DOCKER_IMAGE=ghcr.io/${{ github.repository_owner }}/flytekit:py${{ matrix.python-version }}-${{ needs.deploy.outputs.version }}
+            PYTHON_VERSION=${{ matrix.python-version }}
+          file: Dockerfile
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
+      - name: Prepare SQLAlchemy Image Names
+        id: sqlalchemy-names
+        uses: docker/metadata-action@v3
+        with:
+          images: |
+            ghcr.io/${{ github.repository_owner }}/flytekit
+          tags: |
+            py${{ matrix.python-version }}-sqlalchemy-latest
+            py${{ matrix.python-version }}-sqlalchemy-${{ github.sha }}
+            py${{ matrix.python-version }}-sqlalchemy-${{ needs.deploy.outputs.version }}
       - name: Push SQLAlchemy Image to GitHub Registry
-        uses: whoan/docker-build-with-cache-action@v5
+        uses: docker/build-push-action@v2
         with:
-          # https://docs.github.com/en/packages/learn-github-packages/publishing-a-package
-          username: "${{ secrets.FLYTE_BOT_USERNAME }}"
-          password: "${{ secrets.FLYTE_BOT_PAT }}"
-          image_name: ${{ github.repository_owner }}/flytekit
-          image_tag: py${{ matrix.python-version }}-sqlalchemy-latest,py${{ matrix.python-version }}-sqlalchemy-${{ github.sha }},py${{ matrix.python-version }}-sqlalchemy-${{ needs.deploy.outputs.version }}
-          push_git_tag: true
-          push_image_and_stages: true
-          registry: ghcr.io
-          build_extra_args: "--compress=true --build-arg=VERSION=${{ needs.deploy.outputs.version }}"
           context: "./plugins/flytekit-sqlalchemy/"
-          dockerfile: Dockerfile.py${{ matrix.python-version }}
+          platforms: linux/arm64, linux/amd64
+          push: ${{ github.event_name == 'release' }}
+          tags: ${{ steps.sqlalchemy-names.outputs.tags }}
+          build-args: |
+            VERSION=${{ needs.deploy.outputs.version }}
+          file: ./plugins/flytekit-sqlalchemy/Dockerfile.py${{ matrix.python-version }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
diff --git a/.gitignore b/.gitignore
index 2bd9a0161f..fc76e7d07c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,3 +31,4 @@ docs/source/plugins/generated/
 htmlcov
 *.ipynb
 *dat
+source/_tags/
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000..82f4fe5366
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,23 @@
+ARG PYTHON_VERSION
+FROM python:${PYTHON_VERSION}-slim-buster
+
+MAINTAINER Flyte Team
+LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
+
+WORKDIR /root
+ENV PYTHONPATH /root
+
+ARG VERSION
+ARG DOCKER_IMAGE
+
+RUN apt-get update && apt-get install build-essential -y
+
+# Pod tasks should be exposed in the default image
+RUN pip install -U flytekit==$VERSION \
+	flytekitplugins-pod==$VERSION \
+	flytekitplugins-deck-standard==$VERSION \
+	flytekitplugins-data-fsspec[aws]==$VERSION \
+	flytekitplugins-data-fsspec[gcp]==$VERSION \
+	scikit-learn
+
+ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/Dockerfile.py3.10 b/Dockerfile.py3.10
deleted file mode 100644
index e5cbf2d390..0000000000
--- a/Dockerfile.py3.10
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM python:3.10-slim-buster
-
-MAINTAINER Flyte Team
-LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
-
-WORKDIR /root
-ENV PYTHONPATH /root
-
-RUN pip install awscli
-RUN pip install gsutil
-
-ARG VERSION
-ARG DOCKER_IMAGE
-
-# Pod tasks should be exposed in the default image
-RUN pip install -U flytekit==$VERSION flytekitplugins-pod==$VERSION
-
-ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/Dockerfile.py3.7 b/Dockerfile.py3.7
deleted file mode 100644
index f4e5517955..0000000000
--- a/Dockerfile.py3.7
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM python:3.7-slim-buster
-
-MAINTAINER Flyte Team
-LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
-
-WORKDIR /root
-ENV PYTHONPATH /root
-
-RUN pip install awscli
-RUN pip install gsutil
-
-ARG VERSION
-ARG DOCKER_IMAGE
-
-# Pod tasks should be exposed in the default image
-RUN pip install -U flytekit==$VERSION flytekitplugins-pod==$VERSION
-
-ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/Dockerfile.py3.8 b/Dockerfile.py3.8
deleted file mode 100644
index f102bdaafb..0000000000
--- a/Dockerfile.py3.8
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM python:3.8-slim-buster
-
-MAINTAINER Flyte Team
-LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
-
-WORKDIR /root
-ENV PYTHONPATH /root
-
-RUN pip install awscli
-RUN pip install gsutil
-
-ARG VERSION
-ARG DOCKER_IMAGE
-
-# Pod tasks should be exposed in the default image
-RUN pip install -U flytekit==$VERSION flytekitplugins-pod==$VERSION
-
-ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/Dockerfile.py3.9 b/Dockerfile.py3.9
deleted file mode 100644
index 6efc02cf43..0000000000
--- a/Dockerfile.py3.9
+++ /dev/null
@@ -1,18 +0,0 @@
-FROM python:3.9-slim-buster
-
-MAINTAINER Flyte Team
-LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit
-
-WORKDIR /root
-ENV PYTHONPATH /root
-
-RUN pip install awscli
-RUN pip install gsutil
-
-ARG VERSION
-ARG DOCKER_IMAGE
-
-# Pod tasks should be exposed in the default image
-RUN pip install -U flytekit==$VERSION flytekitplugins-pod==$VERSION
-
-ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/MANIFEST.in b/MANIFEST.in
index 9b92cba26b..18dc8ed77c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -3,7 +3,6 @@
 # include folders
 recursive-include flytekit *
 recursive-include flytekit_scripts *
-recursive-include plugins *
 
 # include specific files
 include README.md
@@ -25,6 +24,7 @@ recursive-exclude tests *
 recursive-exclude docs *
 recursive-exclude boilerplate *
 recursive-exclude .github *
+recursive-exclude plugins *
 
 # exclude dist folder:
 # - contains the generated *.tar.gz and .whl files.
diff --git a/Makefile b/Makefile
index 4b3278bec0..8254904469 100644
--- a/Makefile
+++ b/Makefile
@@ -49,9 +49,18 @@ spellcheck:  ## Runs a spellchecker over all code and documentation
 .PHONY: test
 test: lint unit_test
 
+.PHONY: unit_test_codecov
+unit_test_codecov:
+	# Ensure coverage file
+	rm coverage.xml || true
+	$(MAKE) CODECOV_OPTS="--cov=./ --cov-report=xml --cov-append" unit_test
+
 .PHONY: unit_test
 unit_test:
-	pytest -m "not sandbox_test" tests/flytekit/unit
+	# Skip tensorflow tests and run them with the necessary env var set so that a working (albeit slower)
+	# library is used to serialize/deserialize protobufs.
+	pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} && \
+		PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS}
 
 requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt
 requirements-spark2.txt: requirements-spark2.in install-piptools
diff --git a/dev-requirements.in b/dev-requirements.in
index 313ce1d82b..a02c8fa144 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -12,4 +12,8 @@ codespell
 google-cloud-bigquery
 google-cloud-bigquery-storage
 IPython
-torch
+tensorflow==2.8.1
+# Newer versions of torch bring in nvidia dependencies that are not present in windows, so
+# we put this constraint while we do not have per-environment requirements files
+torch<=1.12.1
+scikit-learn
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 13f2006571..a9de992c14 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,6 +1,6 @@
 #
-# This file is autogenerated by pip-compile with python 3.7
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.7
+# by the following command:
 #
 #    make dev-requirements.txt
 #
@@ -8,12 +8,16 @@
     # via
     #   -c requirements.txt
     #   pytest-flyte
-appnope==0.1.3
-    # via ipython
-arrow==1.2.2
+absl-py==1.3.0
+    # via
+    #   tensorboard
+    #   tensorflow
+arrow==1.2.3
     # via
     #   -c requirements.txt
     #   jinja2-time
+astunparse==1.6.3
+    # via tensorflow
 attrs==20.3.0
     # via
     #   -c requirements.txt
     #   jsonschema
     #   pytest
     #   pytest-docker
 backcall==0.2.0
     # via ipython
-bcrypt==4.0.0
+bcrypt==4.0.1
     # via paramiko
 binaryornot==0.4.4
     # via
     #   -c requirements.txt
     #   cookiecutter
 cached-property==1.5.2
     # via docker-compose
 cachetools==5.2.0
     # via google-auth
-certifi==2022.6.15
+certifi==2022.12.7
     # via
     #   -c requirements.txt
     #   requests
 cffi==1.15.1
     # via
     #   -c requirements.txt
     #   cryptography
     #   pynacl
 cfgv==3.3.1
     # via pre-commit
-chardet==5.0.0
+chardet==5.1.0
     # via
     #   -c requirements.txt
     #   binaryornot
 click==8.1.3
     # via
     #   -c requirements.txt
     #   cookiecutter
     #   flytekit
-cloudpickle==2.1.0
+cloudpickle==2.2.0
     # via
     #   -c requirements.txt
     #   flytekit
-codespell==2.2.1
+codespell==2.2.2
     # via -r dev-requirements.in
 cookiecutter==2.1.1
     # via
     #   -c requirements.txt
     #   flytekit
-coverage[toml]==6.4.4
+coverage[toml]==6.5.0
     # via
     #   -r dev-requirements.in
     #   pytest-cov
-croniter==1.3.5
+croniter==1.3.8
     # via
     #   -c requirements.txt
     #   flytekit
-cryptography==37.0.4
+cryptography==38.0.4
     # via
     #   -c requirements.txt
     #   paramiko
     #   pyopenssl
+    #   secretstorage
 dataclasses-json==0.5.7
     # via
     #   -c requirements.txt
     #   flytekit
@@ -98,9 +103,9 @@ diskcache==5.4.0
     #   flytekit
 distlib==0.3.6
     # via virtualenv
-distro==1.7.0
+distro==1.8.0
     # via docker-compose
-docker[ssh]==6.0.0
+docker[ssh]==6.0.1
     # via
     #   -c requirements.txt
     #   docker-compose
 dockerpty==0.4.1
     # via docker-compose
 docopt==0.6.2
     # via docker-compose
-docstring-parser==0.14.1
+docstring-parser==0.15
     # via
     #   -c requirements.txt
     #   flytekit
-filelock==3.8.0
+exceptiongroup==1.0.4
+    # via pytest
+filelock==3.8.2
     # via virtualenv
-flyteidl==1.1.12
+flatbuffers==22.12.6
+    # via tensorflow
+flyteidl==1.3.0
     # via
     #   -c requirements.txt
     #   flytekit
-google-api-core[grpc]==2.8.2
+gast==0.5.3
+    # via tensorflow
+google-api-core[grpc]==2.11.0
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.11.0
+google-auth==2.15.0
     # via
     #   google-api-core
+    #   google-auth-oauthlib
     #   google-cloud-core
-google-cloud-bigquery==3.3.2
+    #   tensorboard
+google-auth-oauthlib==0.4.6
+    # via tensorboard
+google-cloud-bigquery==3.4.0
     # via -r dev-requirements.in
-google-cloud-bigquery-storage==2.14.2
+google-cloud-bigquery-storage==2.16.2
     # via
     #   -r dev-requirements.in
     #   google-cloud-bigquery
 google-cloud-core==2.3.2
     # via google-cloud-bigquery
-google-crc32c==1.3.0
+google-crc32c==1.5.0
     # via google-resumable-media
-google-resumable-media==2.3.3
+google-pasta==0.2.0
+    # via tensorflow
+google-resumable-media==2.4.0
     # via google-cloud-bigquery
-googleapis-common-protos==1.56.4
+googleapis-common-protos==1.57.0
     # via
     #   -c requirements.txt
     #   flyteidl
     #   google-api-core
     #   grpcio-status
-grpcio==1.47.0
+grpcio==1.51.1
     # via
     #   -c requirements.txt
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
     #   grpcio-status
-grpcio-status==1.47.0
+    #   tensorboard
+    #   tensorflow
+grpcio-status==1.51.1
     # via
     #   -c requirements.txt
     #   flytekit
     #   google-api-core
-identify==2.5.3
+h5py==3.7.0
+    # via tensorflow
+identify==2.5.9
     # via pre-commit
-idna==3.3
+idna==3.4
     # via
     #   -c requirements.txt
     #   requests
-importlib-metadata==4.12.0
+importlib-metadata==5.1.0
     # via
     #   -c requirements.txt
     #   click
     #   flytekit
     #   jsonschema
     #   keyring
+    #   markdown
     #   pluggy
     #   pre-commit
     #   pytest
@@ -185,8 +207,17 @@ iniconfig==1.1.1
     # via pytest
 ipython==7.34.0
     # via -r dev-requirements.in
-jedi==0.18.1
+jaraco-classes==3.2.3
+    # via
+    #   -c requirements.txt
+    #   keyring
+jedi==0.18.2
     # via ipython
+jeepney==0.8.0
+    # via
+    #   -c requirements.txt
+    #   keyring
+    #   secretstorage
 jinja2==3.1.2
     # via
     #   -c requirements.txt
@@ -197,24 +228,34 @@ jinja2-time==0.2.0
     # via
     #   -c requirements.txt
     #   cookiecutter
-joblib==1.1.0
+joblib==1.2.0
     # via
     #   -c requirements.txt
     #   -r dev-requirements.in
     #   flytekit
+    #   scikit-learn
 jsonschema==3.2.0
     # via
     #   -c requirements.txt
     #   docker-compose
-keyring==23.8.2
+keras==2.8.0
+    # via tensorflow
+keras-preprocessing==1.1.2
+    # via tensorflow
+keyring==23.11.0
     # via
     #   -c requirements.txt
     #   flytekit
+libclang==14.0.6
+    # via tensorflow
+markdown==3.4.1
+    # via tensorboard
 markupsafe==2.1.1
     # via
     #   -c requirements.txt
     #   jinja2
-marshmallow==3.17.1
+    #   werkzeug
+marshmallow==3.19.0
     # via
     #   -c requirements.txt
     #   dataclasses-json
@@ -232,14 +273,18 @@ matplotlib-inline==0.1.6
     # via ipython
 mock==4.0.3
     # via -r dev-requirements.in
-mypy==0.971
+more-itertools==9.0.0
+    # via
+    #   -c requirements.txt
+    #   jaraco-classes
+mypy==0.991
     # via -r dev-requirements.in
 mypy-extensions==0.4.3
     # via
     #   -c requirements.txt
     #   mypy
     #   typing-inspect
-natsort==8.1.0
+natsort==8.2.0
     # via
     #   -c requirements.txt
     #   flytekit
@@ -249,8 +294,19 @@ numpy==1.21.6
     # via
     #   -c requirements.txt
     #   flytekit
+    #   h5py
+    #   keras-preprocessing
+    #   opt-einsum
     #   pandas
     #   pyarrow
+    #   scikit-learn
+    #   scipy
+    #   tensorboard
+    #   tensorflow
+oauthlib==3.2.2
+    # via requests-oauthlib
+opt-einsum==3.3.0
+    # via tensorflow
 packaging==21.3
     # via
     #   -c requirements.txt
@@ -262,7 +318,7 @@ pandas==1.3.5
     # via
     #   -c requirements.txt
     #   flytekit
-paramiko==2.11.0
+paramiko==2.12.0
     # via docker
 parso==0.8.3
     # via jedi
@@ -270,23 +326,22 @@ pexpect==4.8.0
     # via ipython
 pickleshare==0.7.5
     # via ipython
-platformdirs==2.5.2
+platformdirs==2.6.0
     # via virtualenv
 pluggy==1.0.0
     # via pytest
 pre-commit==2.20.0
     # via -r dev-requirements.in
-prompt-toolkit==3.0.30
+prompt-toolkit==3.0.36
     # via ipython
 proto-plus==1.22.1
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
-protobuf==3.20.1
+protobuf==4.21.10
     # via
     #   -c requirements.txt
     #   flyteidl
-    #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
@@ -294,6 +349,8 @@ protobuf==3.20.1
     #   grpcio-status
     #   proto-plus
     #   protoc-gen-swagger
+    #   tensorboard
+    #   tensorflow
 protoc-gen-swagger==0.1.0
     # via
     #   -c requirements.txt
     #   flyteidl
@@ -303,9 +360,8 @@ ptyprocess==0.7.0
 py==1.11.0
     # via
     #   -c requirements.txt
-    #   pytest
     #   retry
-pyarrow==6.0.1
+pyarrow==10.0.1
     # via
     #   -c requirements.txt
     #   flytekit
@@ -324,7 +380,7 @@ pygments==2.13.0
     # via ipython
 pynacl==1.5.0
     # via paramiko
-pyopenssl==22.0.0
+pyopenssl==22.1.0
     # via
     #   -c requirements.txt
     #   flytekit
@@ -332,19 +388,19 @@ pyparsing==3.0.9
     # via
     #   -c requirements.txt
     #   packaging
-pyrsistent==0.18.1
+pyrsistent==0.19.2
     # via
     #   -c requirements.txt
     #   jsonschema
-pytest==7.1.2
+pytest==7.2.0
     # via
     #   -r dev-requirements.in
     #   pytest-cov
     #   pytest-docker
     #   pytest-flyte
-pytest-cov==3.0.0
+pytest-cov==4.0.0
     # via -r dev-requirements.in
-pytest-docker==1.0.0
+pytest-docker==1.0.1
     # via pytest-flyte
 pytest-flyte @ git+https://github.com/flyteorg/pytest-flyte@main
     # via -r dev-requirements.in
@@ -356,13 +412,13 @@ python-dateutil==2.8.2
     #   flytekit
     #   google-cloud-bigquery
     #   pandas
-python-dotenv==0.20.0
+python-dotenv==0.21.0
     # via docker-compose
 python-json-logger==2.0.4
     # via
     #   -c requirements.txt
     #   flytekit
-python-slugify==6.1.2
+python-slugify==7.0.0
     # via
     #   -c requirements.txt
     #   cookiecutter
@@ -370,7 +426,7 @@ pytimeparse==1.1.8
     # via
     #   -c requirements.txt
     #   flytekit
-pytz==2022.2.1
+pytz==2022.6
     # via
     #   -c requirements.txt
     #   flytekit
@@ -382,7 +438,7 @@ pyyaml==5.4.1
     #   docker-compose
     #   flytekit
     #   pre-commit
-regex==2022.8.17
+regex==2022.10.31
     # via
     #   -c requirements.txt
     #   docker-image-py
@@ -395,8 +451,12 @@ requests==2.28.1
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
+    #   requests-oauthlib
     #   responses
-responses==0.21.0
+    #   tensorboard
+requests-oauthlib==1.3.1
+    # via google-auth-oauthlib
+responses==0.22.0
     # via
     #   -c requirements.txt
     #   flytekit
@@ -406,6 +466,14 @@ retry==0.9.2
     #   flytekit
 rsa==4.9
     # via google-auth
+scikit-learn==1.0.2
+    # via -r dev-requirements.in
+scipy==1.7.3
+    # via scikit-learn
+secretstorage==3.3.3
+    # via
+    #   -c requirements.txt
+    #   keyring
 singledispatchmethod==1.0
     # via
     #   -c requirements.txt
@@ -413,12 +481,15 @@ singledispatchmethod==1.0
 six==1.16.0
     # via
     #   -c requirements.txt
+    #   astunparse
     #   dockerpty
     #   google-auth
-    #   grpcio
+    #   google-pasta
     #   jsonschema
+    #   keras-preprocessing
     #   paramiko
     #   python-dateutil
+    #   tensorflow
     #   websocket-client
 sortedcontainers==2.4.0
     # via
@@ -428,14 +499,33 @@ statsd==3.3.0
     # via
     #   -c requirements.txt
     #   flytekit
+tensorboard==2.8.0
+    # via tensorflow
+tensorboard-data-server==0.6.1
+    # via tensorboard
+tensorboard-plugin-wit==1.8.1
+    # via tensorboard
+tensorflow==2.8.1
+    # via -r dev-requirements.in
+tensorflow-estimator==2.8.0
+    # via tensorflow
+tensorflow-io-gcs-filesystem==0.28.0
+    # via tensorflow
+termcolor==2.1.1
+    # via tensorflow
 text-unidecode==1.3
     # via
     #   -c requirements.txt
     #   python-slugify
-texttable==1.6.4
+texttable==1.6.7
     # via docker-compose
+threadpoolctl==3.1.0
+    # via scikit-learn
 toml==0.10.2
-    # via pre-commit
+    # via
+    #   -c requirements.txt
+    #   pre-commit
+    #   responses
 tomli==2.0.1
     # via
     #   coverage
@@ -443,13 +533,17 @@ tomli==2.0.1
     #   pytest
 torch==1.12.1
     # via -r dev-requirements.in
-traitlets==5.3.0
+traitlets==5.6.0
     # via
     #   ipython
     #   matplotlib-inline
 typed-ast==1.5.4
     # via mypy
-typing-extensions==4.3.0
+types-toml==0.10.8.1
+    # via
+    #   -c requirements.txt
+    #   responses
+typing-extensions==4.4.0
     # via
     #   -c requirements.txt
     #   arrow
@@ -457,20 +551,21 @@ typing-extensions==4.3.0
     #   importlib-metadata
     #   mypy
     #   responses
+    #   tensorflow
     #   torch
     #   typing-inspect
 typing-inspect==0.8.0
     # via
     #   -c requirements.txt
     #   dataclasses-json
-urllib3==1.26.12
+urllib3==1.26.13
     # via
     #   -c requirements.txt
     #   docker
     #   flytekit
     #   requests
     #   responses
-virtualenv==20.16.4
+virtualenv==20.17.1
     # via pre-commit
 wcwidth==0.2.5
     # via prompt-toolkit
@@ -479,16 +574,21 @@ websocket-client==0.59.0
     #   -c requirements.txt
     #   docker
     #   docker-compose
-wheel==0.37.1
+werkzeug==2.2.2
+    # via tensorboard
+wheel==0.38.4
     # via
     #   -c requirements.txt
+    #   astunparse
     #   flytekit
+    #   tensorboard
 wrapt==1.14.1
     # via
     #   -c requirements.txt
     #   deprecated
     #   flytekit
-zipp==3.8.1
+    #   tensorflow
+zipp==3.11.0
     # via
     #   -c requirements.txt
     #   importlib-metadata
diff --git a/doc-requirements.in b/doc-requirements.in
index e929e404ae..7a1c95353c 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -12,8 +12,10 @@ sphinx-copybutton
 sphinx_fontawesome
 sphinx-panels
 sphinxcontrib-yt
-grpcio
 cryptography
+google-api-core[grpc]
+scikit-learn
+sphinx-tags
 
 # Packages for Plugin docs
 # Package name        Plugin needing it
@@ -34,6 +36,11 @@ jupyter          # papermill
 pyspark          # spark
 sqlalchemy       # sqlalchemy
 torch            # pytorch
-skl2onnx         # onnxscikitlearn
-tf2onnx          # onnxtensorflow
-tensorflow       # onnxtensorflow
+# TODO: Remove after buf migration is done and packages updated
+# skl2onnx       # onnxscikitlearn
+# tf2onnx        # onnxtensorflow
+tensorflow==2.8.1  # onnxtensorflow
+whylogs          # whylogs
+whylabs-client   # whylogs
+ray              # ray
+scikit-learn     # scikit-learn
diff --git a/doc-requirements.txt b/doc-requirements.txt
index db87a80b6e..fefdd3c2da 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -1,47 +1,51 @@
 #
-# This file is autogenerated by pip-compile with python 3.7
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.9
+# by the following command:
 #
 #    make doc-requirements.txt
 #
 -e file:.#egg=flytekit
     # via -r doc-requirements.in
-absl-py==1.2.0
+absl-py==1.3.0
     # via
     #   tensorboard
     #   tensorflow
+aiosignal==1.3.1
+    # via ray
 alabaster==0.7.12
     # via sphinx
 altair==4.2.0
     # via great-expectations
 ansiwrap==0.8.4
     # via papermill
-appnope==0.1.3
-    # via
-    #   ipykernel
-    #   ipython
+anyio==3.6.2
+    # via jupyter-server
 argon2-cffi==21.3.0
-    # via notebook
+    # via
+    #   jupyter-server
+    #   nbclassic
+    #   notebook
 argon2-cffi-bindings==21.2.0
     # via argon2-cffi
-arrow==1.2.2
-    # via jinja2-time
-astroid==2.12.5
+arrow==1.2.3
+    # via
+    #   isoduration
+    #   jinja2-time
+astroid==2.12.13
     # via sphinx-autoapi
+asttokens==2.2.1
+    # via stack-data
 astunparse==1.6.3
     # via tensorflow
 attrs==22.1.0
     # via
     #   jsonschema
+    #   ray
     #   visions
-babel==2.10.3
+babel==2.11.0
     # via sphinx
 backcall==0.2.0
     # via ipython
-backports-zoneinfo==0.2.1
-    # via
-    #   pytz-deprecation-shim
-    #   tzlocal
 beautifulsoup4==4.11.1
     # via
     #   furo
@@ -52,11 +56,11 @@ binaryornot==0.4.4
     # via cookiecutter
 bleach==5.0.1
     # via nbconvert
-botocore==1.27.63
+botocore==1.29.25
     # via -r doc-requirements.in
 cachetools==5.2.0
     # via google-auth
-certifi==2022.6.15
+certifi==2022.12.7
     # via
     #   kubernetes
     #   requests
@@ -64,29 +68,37 @@ cffi==1.15.1
     # via
     #   argon2-cffi-bindings
     #   cryptography
-chardet==5.0.0
+cfgv==3.3.1
+    # via pre-commit
+chardet==5.1.0
     # via binaryornot
 charset-normalizer==2.1.1
     # via requests
-click==8.1.3
+click==8.0.4
     # via
     #   cookiecutter
     #   flytekit
     #   great-expectations
     #   papermill
-cloudpickle==2.1.0
+    #   ray
+cloudpickle==2.2.0
     # via flytekit
-colorama==0.4.5
+colorama==0.4.6
     # via great-expectations
+comm==0.1.1
+    # via ipykernel
+contourpy==1.0.6
+    # via matplotlib
 cookiecutter==2.1.1
     # via flytekit
-croniter==1.3.5
+croniter==1.3.8
     # via flytekit
-cryptography==37.0.4
+cryptography==38.0.4
     # via
     #   -r doc-requirements.in
     #   great-expectations
     #   pyopenssl
+    #   secretstorage
 css-html-js-minify==2.5.5
     # via sphinx-material
 cycler==0.11.0
@@ -95,7 +107,7 @@ dataclasses-json==0.5.7
     # via
     #   dolt-integrations
     #   flytekit
-debugpy==1.6.3
+debugpy==1.6.4
     # via ipykernel
 decorator==5.1.1
     # via
@@ -107,11 +119,13 @@ deprecated==1.2.13
     # via flytekit
 diskcache==5.4.0
     # via flytekit
-docker==6.0.0
+distlib==0.3.6
+    # via virtualenv
+docker==6.0.1
     # via flytekit
 docker-image-py==0.1.12
     # via flytekit
-docstring-parser==0.14.1
+docstring-parser==0.15
     # via flytekit
 docutils==0.17.1
     # via
@@ -126,30 +140,41 @@ entrypoints==0.4
     #   altair
     #   jupyter-client
     #   papermill
-fastjsonschema==2.16.1
+executing==1.2.0
+    # via stack-data
+fastjsonschema==2.16.2
     # via nbformat
-flatbuffers==1.12
+filelock==3.8.2
     # via
-    #   tensorflow
-    #   tf2onnx
-flyteidl==1.1.12
+    #   ray
+    #   virtualenv
+flatbuffers==22.12.6
+    # via tensorflow
+flyteidl==1.3.0
     # via flytekit
-fonttools==4.37.1
+fonttools==4.38.0
     # via matplotlib
-fsspec==2022.8.0
+fqdn==1.5.1
+    # via jsonschema
+frozenlist==1.3.3
+    # via
+    #   aiosignal
+    #   ray
+fsspec==2022.11.0
     # via
     #   -r doc-requirements.in
     #   modin
 furo @ git+https://github.com/flyteorg/furo@main
     # via -r doc-requirements.in
-gast==0.4.0
+gast==0.5.3
     # via tensorflow
-google-api-core[grpc]==2.8.2
+google-api-core[grpc]==2.11.0
     # via
+    #   -r doc-requirements.in
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.11.0
+google-auth==2.15.0
     # via
     #   google-api-core
     #   google-auth-oauthlib
@@ -160,37 +185,37 @@ google-auth-oauthlib==0.4.6
     # via tensorboard
 google-cloud==0.34.0
     # via -r doc-requirements.in
-google-cloud-bigquery==3.3.2
+google-cloud-bigquery==3.4.0
     # via -r doc-requirements.in
-google-cloud-bigquery-storage==2.14.2
+google-cloud-bigquery-storage==2.16.2
     # via google-cloud-bigquery
 google-cloud-core==2.3.2
     # via google-cloud-bigquery
-google-crc32c==1.3.0
+google-crc32c==1.5.0
     # via google-resumable-media
 google-pasta==0.2.0
     # via tensorflow
-google-resumable-media==2.3.3
+google-resumable-media==2.4.0
     # via google-cloud-bigquery
-googleapis-common-protos==1.56.4
+googleapis-common-protos==1.57.0
     # via
     #   flyteidl
     #   google-api-core
     #   grpcio-status
-great-expectations==0.15.20
+great-expectations==0.15.36
     # via -r doc-requirements.in
-greenlet==1.1.3
+greenlet==2.0.1
     # via sqlalchemy
-grpcio==1.47.0
+grpcio==1.51.1
     # via
-    #   -r doc-requirements.in
     #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
     #   grpcio-status
+    #   ray
     #   tensorboard
     #   tensorflow
-grpcio-status==1.47.0
+grpcio-status==1.51.1
     # via
     #   flytekit
     #   google-api-core
@@ -198,33 +223,34 @@ h5py==3.7.0
     # via tensorflow
 htmlmin==0.1.12
     # via pandas-profiling
-idna==3.3
-    # via requests
-imagehash==4.2.1
+identify==2.5.9
+    # via pre-commit
+idna==3.4
+    # via
+    #   anyio
+    #   jsonschema
+    #   requests
+imagehash==4.3.1
     # via visions
 imagesize==1.4.1
     # via sphinx
-importlib-metadata==4.12.0
+importlib-metadata==5.1.0
     # via
-    #   click
     #   flytekit
     #   great-expectations
-    #   jsonschema
     #   keyring
     #   markdown
     #   nbconvert
     #   sphinx
-    #   sqlalchemy
-importlib-resources==5.9.0
-    # via jsonschema
-ipykernel==6.15.2
+ipykernel==6.19.0
     # via
     #   ipywidgets
     #   jupyter
     #   jupyter-console
+    #   nbclassic
     #   notebook
     #   qtconsole
-ipython==7.34.0
+ipython==8.7.0
     # via
     #   great-expectations
     #   ipykernel
@@ -232,20 +258,31 @@ ipython==7.34.0
     #   jupyter-console
 ipython-genutils==0.2.0
     # via
+    #   nbclassic
     #   notebook
     #   qtconsole
-ipywidgets==8.0.1
+ipywidgets==8.0.3
     # via
     #   great-expectations
     #   jupyter
-jedi==0.18.1
+isoduration==20.11.0
+    # via jsonschema
+jaraco-classes==3.2.3
+    # via keyring
+jedi==0.18.2
     # via ipython
+jeepney==0.8.0
+    # via
+    #   keyring
+    #   secretstorage
 jinja2==3.1.2
     # via
     #   altair
     #   cookiecutter
     #   great-expectations
     #   jinja2-time
+    #   jupyter-server
+    #   nbclassic
     #   nbconvert
     #   notebook
     #   pandas-profiling
@@ -255,62 +292,76 @@ jinja2-time==0.2.0
     # via cookiecutter
 jmespath==1.0.1
     # via botocore
-joblib==1.1.0
+joblib==1.2.0
     # via
     #   flytekit
-    #   pandas-profiling
     #   phik
     #   scikit-learn
 jsonpatch==1.32
     # via great-expectations
 jsonpointer==2.3
-    # via jsonpatch
-jsonschema==4.14.0
+    # via
+    #   jsonpatch
+    #   jsonschema
+jsonschema[format-nongpl]==4.7.2
     # via
     #   altair
     #   great-expectations
+    #   jupyter-events
     #   nbformat
+    #   ray
 jupyter==1.0.0
     # via -r doc-requirements.in
-jupyter-client==7.3.5
+jupyter-client==7.4.8
     # via
     #   ipykernel
     #   jupyter-console
+    #   jupyter-server
+    #   nbclassic
     #   nbclient
     #   notebook
     #   qtconsole
 jupyter-console==6.4.4
     # via jupyter
-jupyter-core==4.11.1
+jupyter-core==5.1.0
     # via
     #   jupyter-client
+    #   jupyter-server
+    #   nbclassic
+    #   nbclient
     #   nbconvert
     #   nbformat
     #   notebook
     #   qtconsole
+jupyter-events==0.5.0
+    # via jupyter-server
+jupyter-server==2.0.0
+    # via
+    #   nbclassic
+    #   notebook-shim
+jupyter-server-terminals==0.4.2
+    # via jupyter-server
 jupyterlab-pygments==0.2.2
     # via nbconvert
-jupyterlab-widgets==3.0.2
+jupyterlab-widgets==3.0.4
     # via ipywidgets
-keras==2.9.0
+keras==2.8.0
     # via tensorflow
 keras-preprocessing==1.1.2
     # via tensorflow
-keyring==23.8.2
+keyring==23.11.0
     # via flytekit
 kiwisolver==1.4.4
     # via matplotlib
-kubernetes==24.2.0
+kubernetes==25.3.0
     # via -r doc-requirements.in
-lazy-object-proxy==1.7.1
+lazy-object-proxy==1.8.0
     # via astroid
 libclang==14.0.6
     # via tensorflow
 lxml==4.9.1
-    # via
-    #   nbconvert
-    #   sphinx-material
-makefun==1.14.0
+    # via sphinx-material
+makefun==1.15.0
     # via great-expectations
 markdown==3.4.1
     # via
@@ -320,20 +371,19 @@ markupsafe==2.1.1
     # via
     #   jinja2
     #   nbconvert
-    #   pandas-profiling
     #   werkzeug
-marshmallow==3.17.1
+marshmallow==3.19.0
     # via
     #   dataclasses-json
+    #   great-expectations
     #   marshmallow-enum
     #   marshmallow-jsonschema
 marshmallow-enum==1.5.1
     # via dataclasses-json
 marshmallow-jsonschema==0.13.0
     # via flytekit
-matplotlib==3.5.3
+matplotlib==3.6.2
     # via
-    #   missingno
     #   pandas-profiling
     #   phik
     #   seaborn
@@ -341,86 +391,99 @@ matplotlib-inline==0.1.6
     # via
     #   ipykernel
     #   ipython
-missingno==0.5.1
-    # via pandas-profiling
 mistune==2.0.4
     # via
     #   great-expectations
     #   nbconvert
-modin==0.12.1
+modin==0.17.1
     # via -r doc-requirements.in
-multimethod==1.8
+more-itertools==9.0.0
+    # via jaraco-classes
+msgpack==1.0.4
+    # via ray
+multimethod==1.9
     # via
     #   pandas-profiling
     #   visions
 mypy-extensions==0.4.3
     # via typing-inspect
-natsort==8.1.0
+natsort==8.2.0
     # via flytekit
-nbclient==0.6.7
+nbclassic==0.4.8
+    # via notebook
+nbclient==0.7.2
     # via
     #   nbconvert
     #   papermill
-nbconvert==7.0.0
+nbconvert==7.2.6
     # via
     #   jupyter
+    #   jupyter-server
+    #   nbclassic
     #   notebook
-nbformat==5.4.0
+nbformat==5.7.0
     # via
     #   great-expectations
+    #   jupyter-server
+    #   nbclassic
     #   nbclient
     #   nbconvert
     #   notebook
     #   papermill
-nest-asyncio==1.5.5
+nest-asyncio==1.5.6
     # via
     #   ipykernel
     #   jupyter-client
-    #   nbclient
+    #   nbclassic
     #   notebook
-networkx==2.6.3
+networkx==2.8.8
     # via visions
-notebook==6.4.12
+nodeenv==1.7.0
+    # via pre-commit
+notebook==6.5.2
     # via
     #   great-expectations
     #   jupyter
-numpy==1.21.6
+notebook-shim==0.2.2
+    # via nbclassic
+numpy==1.23.5
     # via
     #   altair
-    #   flytekit
+    #   contourpy
     #   great-expectations
     #   h5py
     #   imagehash
     #   keras-preprocessing
     #   matplotlib
-    #   missingno
     #   modin
-    #   onnx
-    #   onnxconverter-common
     #   opt-einsum
     #   pandas
     #   pandas-profiling
     #   pandera
+    #   patsy
     #   phik
     #   pyarrow
     #   pywavelets
+    #   ray
     #   scikit-learn
     #   scipy
     #   seaborn
-    #   skl2onnx
+    #   statsmodels
     #   tensorboard
     #   tensorflow
-    #   tf2onnx
     #   visions
-oauthlib==3.2.1
-    # via requests-oauthlib
-onnx==1.12.0
+nvidia-cublas-cu11==11.10.3.66
     # via
-    #   onnxconverter-common
-    #   skl2onnx
-    #   tf2onnx
-onnxconverter-common==1.12.2
-    # via skl2onnx
+    #   nvidia-cudnn-cu11
+    #   torch
+nvidia-cuda-nvrtc-cu11==11.7.99
+    # via torch
+nvidia-cuda-runtime-cu11==11.7.99
+    # via torch
+nvidia-cudnn-cu11==8.5.0.96
+    # via torch
+oauthlib==3.2.2
+    # via requests-oauthlib
 opt-einsum==3.3.0
     # via tensorflow
 packaging==21.3
@@ -429,6 +492,7 @@ packaging==21.3
     #   google-cloud-bigquery
     #   great-expectations
     #   ipykernel
+    #   jupyter-server
     #   marshmallow
     #   matplotlib
     #   modin
@@ -436,8 +500,8 @@ packaging==21.3
     #   pandera
     #   qtpy
     #   sphinx
-    #   tensorflow
-pandas==1.3.5
+    #   statsmodels
+pandas==1.5.2
     # via
     #   altair
     #   dolt-integrations
@@ -448,10 +512,11 @@ pandas==1.3.5
     #   pandera
     #   phik
     #   seaborn
+    #   statsmodels
     #   visions
-pandas-profiling==3.2.0
+pandas-profiling==3.5.0
     # via -r doc-requirements.in
-pandera==0.9.0
+pandera==0.13.4
     # via -r doc-requirements.in
 pandocfilters==1.5.0
     # via nbconvert
@@ -459,24 +524,33 @@ papermill==2.4.0
     # via -r doc-requirements.in
 parso==0.8.3
     # via jedi
+patsy==0.5.3
+    # via statsmodels
 pexpect==4.8.0
     # via ipython
-phik==0.12.2
+phik==0.12.3
     # via pandas-profiling
 pickleshare==0.7.5
     # via ipython
-pillow==9.2.0
+pillow==9.3.0
     # via
     #   imagehash
     #   matplotlib
     #   visions
-pkgutil-resolve-name==1.3.10
-    # via jsonschema
-plotly==5.10.0
+platformdirs==2.6.0
+    # via
+    #   jupyter-core
+    #   virtualenv
+plotly==5.11.0
     # via -r doc-requirements.in
-prometheus-client==0.14.1
-    # via notebook
-prompt-toolkit==3.0.30
+pre-commit==2.20.0
+    # via sphinx-tags
+prometheus-client==0.15.0
+    # via
+    #   jupyter-server
+    #   nbclassic
+    #   notebook
+prompt-toolkit==3.0.36
     # via
     #   ipython
     #   jupyter-console
@@ -484,39 +558,40 @@ proto-plus==1.22.1
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
-protobuf==3.19.4
+protobuf==4.21.10
     # via
     #   flyteidl
-    #   flytekit
     #   google-api-core
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   googleapis-common-protos
     #   grpcio-status
-    #   onnx
-    #   onnxconverter-common
     #   proto-plus
     #   protoc-gen-swagger
-    #   skl2onnx
+    #   ray
     #   tensorboard
     #   tensorflow
+    #   whylogs
 protoc-gen-swagger==0.1.0
     # via flyteidl
-psutil==5.9.1
-    # via ipykernel
+psutil==5.9.4
+    # via
+    #   ipykernel
+    #   modin
 ptyprocess==0.7.0
     # via
     #   pexpect
     #   terminado
+pure-eval==0.2.2
+    # via stack-data
 py==1.11.0
     # via retry
 py4j==0.10.9.5
     # via pyspark
-pyarrow==6.0.1
+pyarrow==10.0.1
     # via
     #   flytekit
     #   google-cloud-bigquery
-    #   pandera
 pyasn1==0.4.8
     # via
     #   pyasn1-modules
@@ -525,8 +600,9 @@ pyasn1-modules==0.2.8
     # via google-auth
 pycparser==2.21
     # via cffi
-pydantic==1.10.0
+pydantic==1.10.2
     # via
+    #   great-expectations
     #   pandas-profiling
     #   pandera
 pygments==2.13.0
@@ -538,16 +614,16 @@ pygments==2.13.0
     #   qtconsole
     #   sphinx
     #   sphinx-prompt
-pyopenssl==22.0.0
+pyopenssl==22.1.0
     # via flytekit
-pyparsing==2.4.7
+pyparsing==3.0.9
     # via
     #   great-expectations
     #   matplotlib
     #   packaging
-pyrsistent==0.18.1
+pyrsistent==0.19.2
     # via jsonschema
-pyspark==3.3.0
+pyspark==3.3.1
     # via -r doc-requirements.in
 python-dateutil==2.8.2
     # via
@@ -561,15 +637,18 @@ python-dateutil==2.8.2
     #   kubernetes
     #   matplotlib
     #   pandas
+    #   whylabs-client
 python-json-logger==2.0.4
-    # via flytekit
-python-slugify[unidecode]==6.1.2
+    # via
+    #   flytekit
+    #   jupyter-events
+python-slugify[unidecode]==7.0.0
     # via
     #   cookiecutter
     #   sphinx-material
 pytimeparse==1.1.8
     # via flytekit
-pytz==2022.2.1
+pytz==2022.6
     # via
     #   babel
     #   flytekit
@@ -577,27 +656,34 @@ pytz==2022.2.1
     #   pandas
 pytz-deprecation-shim==0.1.0.post0
     # via tzlocal
-pywavelets==1.3.0
+pywavelets==1.4.1
     # via imagehash
 pyyaml==6.0
     # via
     #   cookiecutter
     #   flytekit
+    #   jupyter-events
     #   kubernetes
     #   pandas-profiling
     #   papermill
+    #   pre-commit
+    #   ray
     #   sphinx-autoapi
-pyzmq==23.2.1
+pyzmq==24.0.1
     # via
     #   ipykernel
     #   jupyter-client
+    #   jupyter-server
+    #   nbclassic
     #   notebook
     #   qtconsole
-qtconsole==5.3.2
+qtconsole==5.4.0
     # via jupyter
-qtpy==2.2.0
+qtpy==2.3.0
     # via qtconsole
-regex==2022.8.17
+ray==2.1.0
+    # via -r doc-requirements.in
+regex==2022.10.31
     # via docker-image-py
 requests==2.28.1
     # via
@@ -610,61 +696,64 @@ requests==2.28.1
     #   kubernetes
     #   pandas-profiling
     #   papermill
+    #   ray
     #   requests-oauthlib
     #   responses
     #   sphinx
     #   tensorboard
-    #   tf2onnx
 requests-oauthlib==1.3.1
     # via
     #   google-auth-oauthlib
     #   kubernetes
-responses==0.21.0
+responses==0.22.0
     # via flytekit
 retry==0.9.2
     # via flytekit
+rfc3339-validator==0.1.4
+    # via jsonschema
+rfc3986-validator==0.1.1
+    # via jsonschema
 rsa==4.9
     # via google-auth
 ruamel-yaml==0.17.17
     # via great-expectations
-ruamel-yaml-clib==0.2.6
+ruamel-yaml-clib==0.2.7
     # via ruamel-yaml
-scikit-learn==1.0.2
-    # via skl2onnx
-scipy==1.7.3
+scikit-learn==1.1.3
+    # via -r doc-requirements.in
+scipy==1.9.3
     # via
     #   great-expectations
     #   imagehash
-    #   missingno
     #   pandas-profiling
     #   phik
     #   scikit-learn
-    #   seaborn
-    #   skl2onnx
-seaborn==0.11.2
-    # via
-    #   missingno
-    #   pandas-profiling
+    #   statsmodels
+seaborn==0.12.1
+    # via pandas-profiling
+secretstorage==3.3.3
+    # via keyring
 send2trash==1.8.0
-    # via notebook
-singledispatchmethod==1.0
-    # via flytekit
+    # via
+    #   jupyter-server
+    #   nbclassic
+    #   notebook
 six==1.16.0
     # via
+    #   asttokens
     #   astunparse
     #   bleach
     #   google-auth
     #   google-pasta
-    #   grpcio
-    #   imagehash
     #   keras-preprocessing
     #   kubernetes
+    #   patsy
     #   python-dateutil
+    #   rfc3339-validator
     #   sphinx-code-include
     #   tensorflow
-    #   tf2onnx
-skl2onnx==1.12
-    # via -r doc-requirements.in
+sniffio==1.3.0
+    # via anyio
 snowballstemmer==2.2.0
     # via sphinx
 sortedcontainers==2.4.0
@@ -684,14 +773,15 @@ sphinx==4.5.0
     #   sphinx-material
     #   sphinx-panels
     #   sphinx-prompt
+    #   sphinx-tags
     #   sphinxcontrib-yt
-sphinx-autoapi==1.9.0
+sphinx-autoapi==2.0.0
     # via -r doc-requirements.in
-sphinx-basic-ng==0.0.1a12
+sphinx-basic-ng==1.0.0b1
     # via furo
 sphinx-code-include==1.1.1
     # via -r doc-requirements.in
-sphinx-copybutton==0.5.0
+sphinx-copybutton==0.5.1
     # via -r doc-requirements.in
 sphinx-fontawesome==0.0.6
     # via -r doc-requirements.in
@@ -703,6 +793,8 @@ sphinx-panels==0.6.0
     # via -r doc-requirements.in
 sphinx-prompt==1.5.0
     # via -r doc-requirements.in
+sphinx-tags==0.1.6
+    # via -r doc-requirements.in
 sphinxcontrib-applehelp==1.0.2
     # via sphinx
 sphinxcontrib-devhelp==1.0.2
@@ -717,106 +809,115 @@ sphinxcontrib-serializinghtml==1.1.5
     # via sphinx
 sphinxcontrib-yt==0.2.2
     # via -r doc-requirements.in
-sqlalchemy==1.4.40
+sqlalchemy==1.4.44
     # via -r doc-requirements.in
+stack-data==0.6.2
+    # via ipython
 statsd==3.3.0
     # via flytekit
+statsmodels==0.13.5
+    # via pandas-profiling
 tangled-up-in-unicode==0.2.0
-    # via
-    #   pandas-profiling
-    #   visions
-tenacity==8.0.1
+    # via visions
+tenacity==8.1.0
     # via
     #   papermill
     #   plotly
-tensorboard==2.9.1
+tensorboard==2.8.0
     # via tensorflow
 tensorboard-data-server==0.6.1
     # via tensorboard
 tensorboard-plugin-wit==1.8.1
     # via tensorboard
-tensorflow==2.9.1
+tensorflow==2.8.1
     # via -r doc-requirements.in
-tensorflow-estimator==2.9.0
+tensorflow-estimator==2.8.0
+    # via tensorflow
+tensorflow-io-gcs-filesystem==0.28.0
     # via tensorflow
-tensorflow-io-gcs-filesystem==0.26.0
+termcolor==2.1.1
     # via tensorflow
-termcolor==1.1.0
+terminado==0.17.1
     # via
-    #   great-expectations
-    #   tensorflow
-terminado==0.15.0
-    # via notebook
+    #   jupyter-server
+    #   jupyter-server-terminals
+    #   nbclassic
+    #   notebook
 text-unidecode==1.3
     # via python-slugify
 textwrap3==0.9.2
     # via ansiwrap
-tf2onnx==1.12.0
-    # via -r doc-requirements.in
 threadpoolctl==3.1.0
     # via scikit-learn
-tinycss2==1.1.1
+tinycss2==1.2.1
     # via nbconvert
+toml==0.10.2
+    # via
+    #   pre-commit
+    #   responses
 toolz==0.12.0
     # via altair
-torch==1.12.1
+torch==1.13.0
     # via -r doc-requirements.in
 tornado==6.2
     # via
     #   ipykernel
     #   jupyter-client
+    #   jupyter-server
+    #   nbclassic
     #   notebook
     #   terminado
-tqdm==4.64.0
+tqdm==4.64.1
     # via
     #   great-expectations
     #   pandas-profiling
     #   papermill
-traitlets==5.3.0
+traitlets==5.6.0
     # via
+    #   comm
     #   ipykernel
     #   ipython
     #   ipywidgets
     #   jupyter-client
     #   jupyter-core
+    #   jupyter-events
+    #   jupyter-server
     #   matplotlib-inline
+    #   nbclassic
     #   nbclient
     #   nbconvert
     #   nbformat
     #   notebook
     #   qtconsole
-typed-ast==1.5.4
-    # via astroid
-typing-extensions==4.3.0
+typeguard==2.13.3
+    # via pandas-profiling
+types-toml==0.10.8.1
+    # via responses
+typing-extensions==4.4.0
     # via
-    #   argon2-cffi
-    #   arrow
     #   astroid
     #   flytekit
     #   great-expectations
-    #   importlib-metadata
-    #   jsonschema
-    #   kiwisolver
-    #   onnx
-    #   pandera
     #   pydantic
-    #   responses
     #   tensorflow
     #   torch
     #   typing-inspect
+    #   whylogs
 typing-inspect==0.8.0
     # via
     #   dataclasses-json
     #   pandera
-tzdata==2022.2
+tzdata==2022.7
     # via pytz-deprecation-shim
 tzlocal==4.2
     # via great-expectations
-unidecode==1.3.4
+unidecode==1.3.6
     # via
     #   python-slugify
     #   sphinx-autoapi
-urllib3==1.26.12
+uri-template==1.2.0
+    # via jsonschema
+urllib3==1.26.13
     # via
     #   botocore
     #   docker
@@ -825,26 +926,42 @@ urllib3==1.26.12
     #   kubernetes
     #   requests
     #   responses
-visions[type_image_path]==0.7.4
+    #   whylabs-client
+virtualenv==20.17.1
+    # via
+    #   pre-commit
+    #   ray
+visions[type_image_path]==0.7.5
     # via pandas-profiling
 wcwidth==0.2.5
     # via prompt-toolkit
+webcolors==1.12
+    # via jsonschema
 webencodings==0.5.1
     # via
     #   bleach
     #   tinycss2
-websocket-client==1.4.0
+websocket-client==1.4.2
     # via
     #   docker
+    #   jupyter-server
     #   kubernetes
 werkzeug==2.2.2
     # via tensorboard
-wheel==0.37.1
+wheel==0.38.4
     # via
     #   astunparse
     #   flytekit
+    #   nvidia-cublas-cu11
+    #   nvidia-cuda-runtime-cu11
     #   tensorboard
-widgetsnbextension==4.0.2
+whylabs-client==0.4.2
+    # via -r doc-requirements.in
+whylogs==1.1.16
+    # via -r doc-requirements.in
+whylogs-sketching==3.4.1.dev3
+    # via whylogs
+widgetsnbextension==4.0.4
     # via ipywidgets
 wrapt==1.14.1
     # via
@@ -853,10 +970,8 @@ wrapt==1.14.1
     #   flytekit
     #   pandera
     #   tensorflow
-zipp==3.8.1
-    # via
-    #   importlib-metadata
-    #   importlib-resources
+zipp==3.11.0
+    # via importlib-metadata
 
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4eb205e786..6aba967ae9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -61,6 +61,7 @@
     "sphinx_fontawesome",
     "sphinx_panels",
     "sphinxcontrib.yt",
+    "sphinx_tags",
 ]
 
 # build the templated autosummary files
@@ -68,6 +69,8 @@
 autodoc_typehints = "description"
 
+suppress_warnings = ["autosectionlabel.*"]
+
 # autosectionlabel throws warnings if section names are duplicated.
 # The following tells autosectionlabel to not throw a warning for
 # duplicated section names that are in different documents.
@@ -224,3 +227,8 @@
 }
 
 autoclass_content = "both"
+
+# Tags config
+tags_create_tags = True
+tags_page_title = "Tag"
+tags_overview_title = "All Tags"
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index 48e0b7e940..3d887f6f3c 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -4,6 +4,8 @@
 Flytekit Contribution Guide
 ###########################
 
+.. tags:: Contribute, Basic
+
 First off, thank you for thinking about contributing! Below you'll find instructions that will hopefully guide you through how to fix, improve, and extend Flytekit.
 
 Please also take some time to read through the :std:ref:`design guides `, which describe the various parts of Flytekit and should make contributing easier.
diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst
index 3f06961022..517df6dd3f 100644
--- a/docs/source/data.extend.rst
+++ b/docs/source/data.extend.rst
@@ -1,6 +1,9 @@
-##############################
-Extend Data Persistence layer
-##############################
+######################
+Data Persistence Layer
+######################
+
+.. tags:: Data, AWS, GCP, Intermediate
+
 Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible.
 Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface.
@@ -29,7 +32,7 @@ You can use the fsspec plugin implementation to utilize all its available plugin
 The data persistence layer helps store logs of metadata and raw user data.
 As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``.
 
-`Here `_ is a code snippet that shows protocols mapped to the class it implements.
+`Here `__ is a code snippet that shows protocols mapped to the class it implements.
 
 Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins `_ and provides the ones supported by fsspec.
diff --git a/docs/source/deck.rst b/docs/source/deck.rst
new file mode 100644
index 0000000000..43159c51f4
--- /dev/null
+++ b/docs/source/deck.rst
@@ -0,0 +1,5 @@
+
+.. automodule:: flytekit.deck
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/design/authoring.rst b/docs/source/design/authoring.rst
index be2a32b1df..46a094399d 100644
--- a/docs/source/design/authoring.rst
+++ b/docs/source/design/authoring.rst
@@ -4,6 +4,8 @@
 Authoring Structure
 #######################
 
+.. tags:: Design, Basic
+
 One of the core features of Flytekit is to enable users to write tasks and workflows. In this section, we will understand how it works internally.
 
 .. note::
diff --git a/docs/source/design/clis.rst b/docs/source/design/clis.rst
index cd6632b8b9..bde51e774d 100644
--- a/docs/source/design/clis.rst
+++ b/docs/source/design/clis.rst
@@ -4,6 +4,8 @@
 Command Line Interfaces and Clients
 ###################################
 
+.. tags:: CLI, Basic
+
 Flytekit currently ships with two CLIs, both of which rely on the same client implementation code.
 
 *******
@@ -34,6 +36,8 @@ Pyflyte
 Unlike ``flytectl``, think of this CLI as code-aware, which is responsible for the serialization (compilation) step in the registration flow. It will parse through the user code, looking for tasks, workflows, and launch plans, and compile them to `protobuf files `__.
 
+.. _pyflyte-run:
+
 What is ``pyflyte run``?
 ========================
@@ -45,6 +49,8 @@ Suppose you execute a script that defines 10 tasks and a workflow that calls onl
 It is considered fast registration because when a script is executed using ``pyflyte run``, the script is bundled up and uploaded to FlyteAdmin. When the task is executed in the backend, this zipped file is extracted and used.
 
+.. _pyflyte-register:
+
 What is ``pyflyte register``?
 =============================
@@ -63,7 +69,7 @@ The ``pyflyte register`` command bridges the gap between ``pyflyte package`` + `
 
 .. note ::
 
-   You can’t use ``pyflyte register`` if you are unaware of the run-time options yet (IAM role, service account, and so on).
+   You can't use ``pyflyte register`` if you are unaware of the run-time options yet (IAM role, service account, and so on).
 
 Usage
 =====
@@ -75,7 +81,7 @@ Usage
 In a broad way, ``pyflyte register`` is equivalent to ``pyflyte run`` minus launching workflows, with the exception that ``pyflyte run`` can only register a single workflow, whereas ``pyflyte register`` can register all workflows in a repository.
 
 What is the difference between ``pyflyte package + flytectl register`` and ``pyflyte register``?
-==============================================================================================
+================================================================================================
 
 ``pyflyte package + flytectl register`` works well with multiple FlyteAdmins since it produces a portable package. You can also use it to run scripts in CI.
diff --git a/docs/source/design/control_plane.rst b/docs/source/design/control_plane.rst
index d4c57d733d..5aa9d30e55 100644
--- a/docs/source/design/control_plane.rst
+++ b/docs/source/design/control_plane.rst
@@ -4,6 +4,8 @@
 FlyteRemote: A Programmatic Control Plane Interface
 ###################################################
 
+.. tags:: Remote, Basic
+
 For those who require programmatic access to the control plane, the :mod:`~flytekit.remote` module enables you to perform certain operations in a Python runtime environment.
@@ -299,6 +301,10 @@ You can use :meth:`~flytekit.remote.remote.FlyteRemote.sync` to sync the entity
     synced_execution = remote.sync(execution, sync_nodes=True)
     node_keys = synced_execution.node_executions.keys()
 
+.. note::
+
+    During the sync, you may come across a ``Received message larger than max (xxx vs. 4194304)`` error if the message size is too large. In that case, edit the ``flyte-admin-base-config`` config map using the command ``kubectl edit cm flyte-admin-base-config -n flyte`` to increase the ``maxMessageSizeBytes`` value. Refer to the :ref:`troubleshooting guide ` in case you have queries about the command's usage.
+
 ``node_executions`` will fetch all the underlying node executions recursively.
 
 To fetch output of a specific node execution:
diff --git a/docs/source/design/execution.rst b/docs/source/design/execution.rst
index bf72f6ad3a..22683fb8b2 100644
--- a/docs/source/design/execution.rst
+++ b/docs/source/design/execution.rst
@@ -3,6 +3,9 @@
 #######################
 Execution Time Support
 #######################
+
+.. tags:: Design, Basic
+
 Most of the tasks that are written in Flytekit will be Python functions decorated with ``@task`` which turns the body of the function into a Flyte task, capable of being run independently, or included in any number of workflows. The interaction between Flytekit and these tasks do not end once they have been serialized and registered onto the Flyte control plane however. When compiled, the command that will be executed when the task is run is hardcoded into the task definition itself.
 
 In the basic ``@task`` decorated function scenario, the command to be run will be something containing ``pyflyte-execute``, which is one of the CLIs discussed in that section.
diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst
index 355c35acd0..1539baa3a1 100644
--- a/docs/source/design/index.rst
+++ b/docs/source/design/index.rst
@@ -1,8 +1,8 @@
 .. _design:
 
-############################
+########
 Overview
-############################
+########
 
 Flytekit is comprised of a handful of different logical components, each discusssed in greater detail below:
diff --git a/docs/source/design/models.rst b/docs/source/design/models.rst
index 63fed55ea2..7e92d77dae 100644
--- a/docs/source/design/models.rst
+++ b/docs/source/design/models.rst
@@ -1,8 +1,10 @@
 .. _design-models:
 
-######################
+###########
 Model Files
-######################
+###########
+
+.. tags:: Design, Basic
 
 ***********
 Description
diff --git a/docs/source/extras.pytorch.rst b/docs/source/extras.pytorch.rst
index 12fd3d62d9..0f51bf219d 100644
--- a/docs/source/extras.pytorch.rst
+++ b/docs/source/extras.pytorch.rst
@@ -1,6 +1,9 @@
 ############
 PyTorch Type
 ############
+
+.. tags:: MachineLearning, Basic
+
 .. automodule:: flytekit.extras.pytorch
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/extras.sklearn.rst b/docs/source/extras.sklearn.rst
new file mode 100644
index 0000000000..a2efcfa84b
--- /dev/null
+++ b/docs/source/extras.sklearn.rst
@@ -0,0 +1,7 @@
+############
+Sklearn Type
+############
+.. automodule:: flytekit.extras.sklearn
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/extras.sqlite3.rst b/docs/source/extras.sqlite3.rst
new file mode 100644
index 0000000000..f1a1174480
--- /dev/null
+++ b/docs/source/extras.sqlite3.rst
@@ -0,0 +1,10 @@
+############
+SQLite3 Task
+############
+
+.. tags:: SQL, Basic
+
+.. automodule:: flytekit.extras.sqlite3
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/extras.tasks.rst b/docs/source/extras.tasks.rst
new file mode 100644
index 0000000000..f25e607b23
--- /dev/null
+++ b/docs/source/extras.tasks.rst
@@ -0,0 +1,8 @@
+##########
+Shell Task
+##########
+
+.. automodule:: flytekit.extras.tasks
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/extras.tensorflow.rst b/docs/source/extras.tensorflow.rst
new file mode 100644
index 0000000000..699dd44da0
--- /dev/null
+++ b/docs/source/extras.tensorflow.rst
@@ -0,0 +1,7 @@
+###############
+TensorFlow Type
+###############
+.. automodule:: flytekit.extras.tensorflow
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index a65a50ea2e..13630e9e58 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -7,6 +7,8 @@
 Flytekit Python Reference
 *************************
 
+`Flytekit Tags <_tags/tagsindex.html>`__
+
 This section of the documentation provides detailed descriptions of the high-level design of ``Flytekit`` and an API reference for specific usage details of Python functions, classes, and decorators that you import to specify tasks, build workflows, and extend ``Flytekit``.
@@ -76,6 +78,7 @@ Expected output:
    remote
    testing
    extend
+   deck
    plugins/index
    tasks.extend
    types.extend
diff --git a/docs/source/plugins/athena.rst b/docs/source/plugins/athena.rst
index b5e506e64a..30e7292258 100644
--- a/docs/source/plugins/athena.rst
+++ b/docs/source/plugins/athena.rst
@@ -4,6 +4,8 @@
 AWS Athena Plugin API reference
 ###################################################
 
+.. tags:: Integration, AWS, Data
+
 .. automodule:: flytekitplugins.athena
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/awsbatch.rst b/docs/source/plugins/awsbatch.rst
index 5bdf739dbc..0882ef143c 100644
--- a/docs/source/plugins/awsbatch.rst
+++ b/docs/source/plugins/awsbatch.rst
@@ -1,9 +1,11 @@
 .. _awsbatch:
 
 ###################################################
-AWSBatch Plugin API reference
+AWS Batch Plugin API reference
 ###################################################
 
+.. tags:: Data, Integration, AWS
+
 .. automodule:: flytekitplugins.awsbatch
     :no-inherited-members:
     :no-special-members:
diff --git a/docs/source/plugins/awssagemaker.rst b/docs/source/plugins/awssagemaker.rst
index 2ba9455f0e..5921c0eb40 100644
--- a/docs/source/plugins/awssagemaker.rst
+++ b/docs/source/plugins/awssagemaker.rst
@@ -4,6 +4,8 @@
 AWS Sagemaker API reference
 ###################################################
 
+.. tags:: Integration, MachineLearning, AWS
+
 .. automodule:: flytekitplugins.awssagemaker
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/bigquery.rst b/docs/source/plugins/bigquery.rst
index 525229d991..fd7c8d5fcd 100644
--- a/docs/source/plugins/bigquery.rst
+++ b/docs/source/plugins/bigquery.rst
@@ -4,6 +4,8 @@
 Google Bigquery Plugin API reference
 ###################################################
 
+.. tags:: GCP, Data, Integration
+
 .. automodule:: flytekitplugins.bigquery
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/dbt.rst b/docs/source/plugins/dbt.rst
new file mode 100644
index 0000000000..0588eda641
--- /dev/null
+++ b/docs/source/plugins/dbt.rst
@@ -0,0 +1,12 @@
+.. _dbt:
+
+###################################################
+DBT Plugin API reference
+###################################################
+
+.. tags:: Data
+
+.. automodule:: flytekitplugins.dbt
+    :no-members:
+    :no-inherited-members:
+    :no-special-members:
diff --git a/docs/source/plugins/deck.rst b/docs/source/plugins/deck.rst
index 496fec70c3..58922c7e23 100644
--- a/docs/source/plugins/deck.rst
+++ b/docs/source/plugins/deck.rst
@@ -4,6 +4,8 @@
 Deck standard Plugin API reference
 ###################################################
 
+.. tags:: UI
+
 .. automodule:: flytekitplugins.deck
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/dolt.rst b/docs/source/plugins/dolt.rst
index bb2ff50577..2f937b9dbb 100644
--- a/docs/source/plugins/dolt.rst
+++ b/docs/source/plugins/dolt.rst
@@ -4,6 +4,8 @@
 Dolt standard Plugin API reference
 ###################################################
 
+.. tags:: Integration, Data, SQL
+
 .. automodule:: flytekitplugins.dolt
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/fsspec.rst b/docs/source/plugins/fsspec.rst
index 4a40165c17..cbf3816e54 100644
--- a/docs/source/plugins/fsspec.rst
+++ b/docs/source/plugins/fsspec.rst
@@ -4,6 +4,8 @@
 FS Spec API reference
 ###################################################
 
+.. tags:: Data, AWS, GCP
+
 .. automodule:: flytekitplugins.fsspec
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/greatexpectations.rst b/docs/source/plugins/greatexpectations.rst
index d6bbaca0a8..d993c9e129 100644
--- a/docs/source/plugins/greatexpectations.rst
+++ b/docs/source/plugins/greatexpectations.rst
@@ -4,6 +4,8 @@
 Great expectations API reference
 ###################################################
 
+.. tags:: Integration, Data, SQL
+
 .. automodule:: flytekitplugins.great_expectations
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/hive.rst b/docs/source/plugins/hive.rst
index 4f0cce6177..63e0abdd7c 100644
--- a/docs/source/plugins/hive.rst
+++ b/docs/source/plugins/hive.rst
@@ -4,6 +4,8 @@
 Hive API reference
 ###################################################
 
+.. tags:: Integration, Data
+
 .. automodule:: flytekitplugins.hive
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst
index 7cf2edc539..dd1f59238d 100644
--- a/docs/source/plugins/index.rst
+++ b/docs/source/plugins/index.rst
@@ -26,6 +26,8 @@ Plugin API reference
 * :ref:`ONNX PyTorch ` - ONNX PyTorch API reference
 * :ref:`ONNX TensorFlow ` - ONNX TensorFlow API reference
 * :ref:`ONNX ScikitLearn ` - ONNX ScikitLearn API reference
+* :ref:`Ray ` - Ray
+* :ref:`DBT ` - DBT API reference
 
 .. toctree::
     :maxdepth: 2
@@ -53,3 +55,5 @@ Plugin API reference
     ONNX PyTorch
     ONNX TensorFlow
     ONNX ScikitLearn
+    Ray
+    DBT
diff --git a/docs/source/plugins/kfmpi.rst b/docs/source/plugins/kfmpi.rst
index 520eb3794f..1ed280b3c7 100644
--- a/docs/source/plugins/kfmpi.rst
+++ b/docs/source/plugins/kfmpi.rst
@@ -4,6 +4,8 @@
 KfMPI API reference
 ###################################################
 
+.. tags:: Integration, DistributedComputing, MachineLearning, KubernetesOperator
+
 .. automodule:: flytekitplugins.kfmpi
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/kfpytorch.rst b/docs/source/plugins/kfpytorch.rst
index b06680ebab..343d726a8f 100644
--- a/docs/source/plugins/kfpytorch.rst
+++ b/docs/source/plugins/kfpytorch.rst
@@ -4,6 +4,8 @@
 KF Pytorch API reference
 ###################################################
 
+.. tags:: Integration, DistributedComputing, MachineLearning, KubernetesOperator
+
 .. automodule:: flytekitplugins.kfpytorch
     :no-members:
     :no-inherited-members:
diff --git a/docs/source/plugins/kftensorflow.rst b/docs/source/plugins/kftensorflow.rst
index 094b20fa7e..1a49ed7a05 100644
--- a/docs/source/plugins/kftensorflow.rst
+++ b/docs/source/plugins/kftensorflow.rst
@@ -4,6 +4,8 @@
 KFTensorflow Plugin API reference
 ###################################################
 
+.. tags:: Integration, DistributedComputing, MachineLearning, KubernetesOperator
+
 .. automodule:: flytekitplugins.kftensorflow
automodule:: flytekitplugins.kftensorflow :no-members: :no-inherited-members: diff --git a/docs/source/plugins/modin.rst b/docs/source/plugins/modin.rst index 4405f534aa..735e23460a 100644 --- a/docs/source/plugins/modin.rst +++ b/docs/source/plugins/modin.rst @@ -4,6 +4,8 @@ Modin API reference ################################################### +.. tags:: Integration, DataFrame + .. automodule:: flytekitplugins.modin :no-members: :no-inherited-members: diff --git a/docs/source/plugins/onnxpytorch.rst b/docs/source/plugins/onnxpytorch.rst index 827a7c4bf2..f66be6c08e 100644 --- a/docs/source/plugins/onnxpytorch.rst +++ b/docs/source/plugins/onnxpytorch.rst @@ -4,6 +4,8 @@ PyTorch ONNX API reference ########################## +.. tags:: Integration, MachineLearning + .. automodule:: flytekitplugins.onnxpytorch :no-members: :no-inherited-members: diff --git a/docs/source/plugins/onnxscikitlearn.rst b/docs/source/plugins/onnxscikitlearn.rst index 3c26b69500..cd6e9f2492 100644 --- a/docs/source/plugins/onnxscikitlearn.rst +++ b/docs/source/plugins/onnxscikitlearn.rst @@ -4,6 +4,8 @@ ScikitLearn ONNX API reference ############################## +.. tags:: Integration, MachineLearning + .. automodule:: flytekitplugins.onnxscikitlearn :no-members: :no-inherited-members: diff --git a/docs/source/plugins/onnxtensorflow.rst b/docs/source/plugins/onnxtensorflow.rst index 9b5036cea8..42496526d9 100644 --- a/docs/source/plugins/onnxtensorflow.rst +++ b/docs/source/plugins/onnxtensorflow.rst @@ -4,6 +4,8 @@ TensorFlow ONNX API reference ############################# +.. tags:: Integration, MachineLearning + .. automodule:: flytekitplugins.onnxtensorflow :no-members: :no-inherited-members: diff --git a/docs/source/plugins/pandera.rst b/docs/source/plugins/pandera.rst index 878f1509d1..fc4b6b2f25 100644 --- a/docs/source/plugins/pandera.rst +++ b/docs/source/plugins/pandera.rst @@ -4,6 +4,8 @@ Pandera API reference ################################################### +.. tags:: Integration, DataFrame + .. automodule:: flytekitplugins.pandera :no-members: :no-inherited-members: diff --git a/docs/source/plugins/papermill.rst b/docs/source/plugins/papermill.rst index f1562a3111..b61c5e6c90 100644 --- a/docs/source/plugins/papermill.rst +++ b/docs/source/plugins/papermill.rst @@ -4,6 +4,8 @@ Papermill API reference ################################################### +.. tags:: Integration, Jupyter + .. automodule:: flytekitplugins.papermill :no-members: :no-inherited-members: diff --git a/docs/source/plugins/pod.rst b/docs/source/plugins/pod.rst index 36ace3a56d..29ba0b9c0c 100644 --- a/docs/source/plugins/pod.rst +++ b/docs/source/plugins/pod.rst @@ -4,6 +4,8 @@ Pod API reference ################################################### +.. tags:: Integration, Kubernetes, KubernetesOperator + .. automodule:: flytekitplugins.pod :no-members: :no-inherited-members: diff --git a/docs/source/plugins/ray.rst b/docs/source/plugins/ray.rst new file mode 100644 index 0000000000..cb96ab7adc --- /dev/null +++ b/docs/source/plugins/ray.rst @@ -0,0 +1,12 @@ +.. _ray: + +################################################### +Ray API reference +################################################### + +.. tags:: Integration, DistributedComputing, KubernetesOperator + +.. 
automodule:: flytekitplugins.ray + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/plugins/snowflake.rst b/docs/source/plugins/snowflake.rst index 3dde70a44a..0b2ba27ab3 100644 --- a/docs/source/plugins/snowflake.rst +++ b/docs/source/plugins/snowflake.rst @@ -4,6 +4,8 @@ Snowflake API reference ################################################### +.. tags:: Integration, Data + .. automodule:: flytekitplugins.snowflake :no-members: :no-inherited-members: diff --git a/docs/source/plugins/spark.rst b/docs/source/plugins/spark.rst index 79dbeec42d..9dc828c177 100644 --- a/docs/source/plugins/spark.rst +++ b/docs/source/plugins/spark.rst @@ -4,6 +4,8 @@ Spark API reference ################################################### +.. tags:: Spark, Integration, DistributedComputing + .. automodule:: flytekitplugins.spark :no-members: :no-inherited-members: diff --git a/docs/source/plugins/sqlalchemy.rst b/docs/source/plugins/sqlalchemy.rst index 469e9f6915..554b38e5d2 100644 --- a/docs/source/plugins/sqlalchemy.rst +++ b/docs/source/plugins/sqlalchemy.rst @@ -4,6 +4,8 @@ Sqlalchemy Plugin API reference ################################################### +.. tags:: Integration, Data, SQL + .. automodule:: flytekitplugins.sqlalchemy :no-members: :no-inherited-members: diff --git a/docs/source/plugins/vaex.rst b/docs/source/plugins/vaex.rst new file mode 100644 index 0000000000..efb72076bc --- /dev/null +++ b/docs/source/plugins/vaex.rst @@ -0,0 +1,10 @@ +.. _vaex: + +################################################### +Vaex API reference +################################################### + +.. automodule:: flytekitplugins.vaex + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/tasks.extend.rst b/docs/source/tasks.extend.rst index 31cfb124da..b012b2f1b1 100644 --- a/docs/source/tasks.extend.rst +++ b/docs/source/tasks.extend.rst @@ -1,8 +1,12 @@ -############################# -Build Custom Task Types -############################# +############ +Custom Tasks +############ -These modules are useful to extend the base task types. +.. tags:: Extensibility, Intermediate + +Flytekit ships with an extensible task system, which makes it easy for anyone to extend and add new task types. + +Refer to the :ref:`cookbook:prebuilt_container` and :ref:`cookbook:user_container` guides if you'd like to contribute a new task type. .. automodule:: flytekit.core.base_task :no-members: @@ -13,3 +17,9 @@ These modules are useful to extend the base task types. :no-members: :no-inherited-members: :no-special-members: + +.. toctree:: + :maxdepth: 1 + + extras.tasks + extras.sqlite3 diff --git a/docs/source/types.extend.rst b/docs/source/types.extend.rst index f0cdff28dc..e0b1d6aaf6 100644 --- a/docs/source/types.extend.rst +++ b/docs/source/types.extend.rst @@ -1,9 +1,12 @@ -################### -Extend Type System -################### +############ +Custom Types +############ + +.. tags:: Extensibility, Intermediate + Flytekit ships with an extensible type system, which makes it easy for anyone to extend and add new types. -Feel free to follow the pattern of the built-in types. +Refer to the :ref:`extensibility contribution guide ` if you'd like to contribute a Flyte type. .. toctree:: :maxdepth: 1 @@ -12,3 +15,4 @@ Feel free to follow the pattern of the built-in types. 
types.builtins.file types.builtins.directory extras.pytorch + extras.tensorflow diff --git a/flytekit/__init__.py b/flytekit/__init__.py index c67a8a04b4..9a2ed7f180 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -5,7 +5,7 @@ .. currentmodule:: flytekit -This package contains all of the most common abstractions you'll need to write Flyte workflows, and extend Flytekit. +This package contains all of the most common abstractions you'll need to write Flyte workflows and extend Flytekit. Basic Authoring =============== @@ -25,6 +25,8 @@ FlyteContext map_task ~core.workflow.ImperativeWorkflow + ~core.node_creation.create_node + FlyteContextManager Running Locally ------------------ @@ -152,7 +154,6 @@ LiteralType BlobType """ - import sys from typing import Generator @@ -182,7 +183,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extras import pytorch +from flytekit.extras import pytorch, tensorflow from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 3ac0fb90c3..bf4c1e2d0c 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -77,8 +77,10 @@ def handler(self, create_request): if e.code() == grpc.StatusCode.INVALID_ARGUMENT: cli_logger.error("Error creating Flyte entity because of invalid arguments. Create request: ") cli_logger.error(_MessageToJson(create_request)) - - # In any case, re-raise since we're not truly handling the error here + cli_logger.error("Details returned from the flyte admin: ") + cli_logger.error(e.details) + e.details += "create_request: " + _MessageToJson(create_request) + # Re-raise since we're not handling the error here and add the create_request details raise e return handler diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index d922f5e3c1..1d959f56d7 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -32,6 +32,7 @@ def str2bool(str): return not str.lower() in ["false", "0", "off", "no"] +# TODO Deprecated delete after deleting flyte_cli register def _hydrate_identifier( project: str, domain: str, version: str, identifier: _identifier_pb2.Identifier ) -> _identifier_pb2.Identifier: @@ -46,6 +47,7 @@ def _hydrate_identifier( return identifier +# TODO Deprecated delete after deleting flyte_cli register def _hydrate_node(project: str, domain: str, version: str, node: _workflow_pb2.Node) -> _workflow_pb2.Node: if node.HasField("task_node"): task_node = node.task_node @@ -79,6 +81,7 @@ def _hydrate_node(project: str, domain: str, version: str, node: _workflow_pb2.N return node +# TODO Deprecated delete after deleting flyte_cli register def _hydrate_workflow_template_nodes( project: str, domain: str, version: str, template: _workflow_pb2.WorkflowTemplate ) -> _workflow_pb2.WorkflowTemplate: @@ -92,6 +95,7 @@ def _hydrate_workflow_template_nodes( return template +# TODO Deprecated delete after deleting flyte_cli register def hydrate_registration_parameters( resource_type: int, project: str, diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index a9a9c4900d..72246bcba4 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,7 +1,10 @@ +from dataclasses 
import replace +from typing import Optional + import click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.configuration import Config +from flytekit.configuration import Config, ImageConfig from flytekit.loggers import cli_logger from flytekit.remote.remote import FlyteRemote @@ -30,3 +33,28 @@ def get_and_save_remote_with_click_context( if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r + + +def patch_image_config(config_file: Optional[str], image_config: ImageConfig) -> ImageConfig: + """ + Merge ImageConfig object with images defined in config file + """ + # Images come from three places: + # * The default flytekit images, which are already supplied by the base run_level_params. + # * The images provided by the user on the command line. + # * The images provided by the user via the config file, if there is one. (Images on the command line should + # override all). + # + # However, the run_level_params already contains both the default flytekit images (lowest priority), as well + # as the images from the command line (highest priority). So when we read from the config file, we only + # want to add in the images that are missing, including the default, if that's also missing. + additional_image_names = set([v.name for v in image_config.images]) + new_additional_images = [v for v in image_config.images] + new_default = image_config.default_image + if config_file: + cfg_ic = ImageConfig.auto(config_file=config_file) + new_default = new_default or cfg_ic.default_image + for addl in cfg_ic.images: + if addl.name not in additional_image_names: + new_additional_images.append(addl) + return replace(image_config, default_image=new_default, images=new_additional_images) diff --git a/flytekit/clis/sdk_in_container/init.py b/flytekit/clis/sdk_in_container/init.py index 079aab54c3..1ec2f57c32 100644 --- a/flytekit/clis/sdk_in_container/init.py +++ b/flytekit/clis/sdk_in_container/init.py @@ -20,8 +20,7 @@ def init(template, project_name): } cookiecutter( "https://github.com/flyteorg/flytekit-python-template.git", - # TODO: remove this once we make the transition to cookie-cutter official. - checkout="cookie-cutter", + checkout="main", no_input=True, # We do not want to clobber existing files/directories. 
overwrite_if_exists=False, diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 024b70edde..1556e343bf 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -1,19 +1,15 @@ import os -import pathlib import typing import click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants -from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context -from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, patch_image_config +from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.loggers import cli_logger -from flytekit.tools.fast_registration import fast_package -from flytekit.tools.repo import find_common_root, load_packages_and_modules -from flytekit.tools.repo import register as repo_register -from flytekit.tools.translator import Options +from flytekit.tools import repo _register_help = """ This command is similar to package but instead of producing a zip file, all your Flyte entities are compiled, @@ -105,6 +101,12 @@ is_flag=True, help="Enables symlink dereferencing when packaging files in fast registration", ) +@click.option( + "--non-fast", + default=False, + is_flag=True, + help="Skip zipping and uploading the package", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -118,6 +120,7 @@ def register( raw_data_prefix: str, version: typing.Optional[str], deref_symlinks: bool, + non_fast: bool, package_or_module: typing.Tuple[str], ): """ @@ -129,58 +132,44 @@ def register( if pkgs: raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") + if non_fast and not version: + raise ValueError("A version is required when --non-fast is specified.") + if len(package_or_module) == 0: display_help_with_error( ctx, "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", ) - cli_logger.debug( + # Use extra images in the config file if that file exists + config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) + if config_file: + image_config = patch_image_config(config_file, image_config) + + click.secho( f"Running pyflyte register from {os.getcwd()} " f"with images {image_config} " - f"and image destinationfolder {destination_dir} " - f"on {len(package_or_module)} package(s) {package_or_module}" + f"and image destination folder {destination_dir} " + f"on {len(package_or_module)} package(s) {package_or_module}", + dim=True, ) # Create and save FlyteRemote, remote = get_and_save_remote_with_click_context(ctx, project, domain) - - # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings - # Create a zip file containing all the entries. - detected_root = find_common_root(package_or_module) - cli_logger.debug(f"Using {detected_root} as root folder for project") - zip_file = fast_package(detected_root, output, deref_symlinks) - - # Upload zip file to Admin using FlyteRemote.
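As context for the register changes above: the command now merges config-file images through the patch_image_config helper introduced in helpers.py. A rough usage sketch, not part of this PR (the image names, registry, and config path below are made up): names supplied on the command line always win, and only missing names are filled in from the file.

    from flytekit.configuration import Image, ImageConfig
    from flytekit.clis.sdk_in_container.helpers import patch_image_config

    # Images supplied on the command line (highest priority).
    cli_images = ImageConfig(
        default_image=Image(name="default", fqn="ghcr.io/example/flytekit", tag="v1"),
        images=[Image(name="spark", fqn="ghcr.io/example/spark", tag="v1")],
    )
    # Only names absent from cli_images (e.g. a "gpu" image defined solely in
    # the file) are appended; a conflicting "spark" entry in the file is ignored.
    merged = patch_image_config("/path/to/config.yaml", cli_images)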
- md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) - cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}") - - # Create serialization settings - # Todo: Rely on default Python interpreter for now, this will break custom Spark containers - serialization_settings = SerializationSettings( - project=project, - domain=domain, - image_config=image_config, - fast_serialization_settings=FastSerializationSettings( - enabled=True, - destination_dir=destination_dir, - distribution_location=native_url, - ), - ) - - options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) - - # Load all the entities - registerable_entities = load_packages_and_modules( - serialization_settings, detected_root, list(package_or_module), options - ) - if len(registerable_entities) == 0: - display_help_with_error(ctx, "No Flyte entities were detected. Aborting!") - cli_logger.info(f"Found and serialized {len(registerable_entities)} entities") - - if not version: - version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa - cli_logger.info(f"Computed version is {version}") - - # Register using repo code - repo_register(registerable_entities, project, domain, version, remote.client) + try: + repo.register( + project, + domain, + image_config, + output, + destination_dir, + service_account, + raw_data_prefix, + version, + deref_symlinks, + fast=not non_fast, + package_or_module=package_or_module, + remote=remote, + ) + except Exception as e: + raise e diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index d0b890ba7b..0121094805 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -6,7 +6,7 @@ import os import pathlib import typing -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import cast import click @@ -22,7 +22,11 @@ CTX_PROJECT, CTX_PROJECT_ROOT, ) -from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.helpers import ( + FLYTE_REMOTE_INSTANCE_KEY, + get_and_save_remote_with_click_context, + patch_image_config, +) from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager @@ -298,7 +302,9 @@ def convert_to_literal( if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - if type(value) != self._python_type: + if self._python_type == dict: + o = json.loads(value) + elif type(value) != self._python_type: o = cast(DataClassJsonMixin, self._python_type).from_json(value) else: o = value @@ -517,27 +523,8 @@ def _run(*args, **kwargs): remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] config_file = ctx.obj.get(CTX_CONFIG_FILE) - # Images come from three places: - # * The default flytekit images, which are already supplied by the base run_level_params. - # * The images provided by the user on the command line. - # * The images provided by the user via the config file, if there is one. (Images on the command line should - # override all). - # - # However, the run_level_params already contains both the default flytekit images (lowest priority), as well - # as the images from the command line (highest priority). 
So when we read from the config file, we only - # want to add in the images that are missing, including the default, if that's also missing. - image_config_from_parent_cmd = run_level_params.get("image_config", None) - additional_image_names = set([v.name for v in image_config_from_parent_cmd.images]) - new_additional_images = [v for v in image_config_from_parent_cmd.images] - new_default = image_config_from_parent_cmd.default_image - if config_file: - cfg_ic = ImageConfig.auto(config_file=config_file) - new_default = new_default or cfg_ic.default_image - for addl in cfg_ic.images: - if addl.name not in additional_image_names: - new_additional_images.append(addl) - - image_config = replace(image_config_from_parent_cmd, default_image=new_default, images=new_additional_images) + image_config = run_level_params.get("image_config") + image_config = patch_image_config(config_file, image_config) remote_entity = remote.register_script( entity, diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index c0edf07ab2..c9e15c2eba 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -94,6 +94,7 @@ from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.loggers import logger PROJECT_PLACEHOLDER = "{{ registration.project }}" DOMAIN_PLACEHOLDER = "{{ registration.domain }}" @@ -298,6 +299,7 @@ class PlatformConfig(object): :param endpoint: DNS for Flyte backend :param insecure: Whether or not to use SSL :param insecure_skip_verify: Wether to skip SSL certificate verification + :param console_endpoint: endpoint for the console, if different from the Flyte backend :param command: This command is executed to return a token using an external process. :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
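The console_endpoint parameter documented above maps to the new PlatformConfig field added in the next hunk; a minimal sketch (hostnames are placeholders) of when you would set it:

    from flytekit.configuration import PlatformConfig

    platform = PlatformConfig(
        endpoint="flyte.example.com:443",
        insecure=False,
        # Only needed when the console is served somewhere other than the
        # Admin endpoint; it defaults to None otherwise.
        console_endpoint="console.example.com",
    )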
@@ -311,6 +313,7 @@ class PlatformConfig(object): endpoint: str = "localhost:30081" insecure: bool = False insecure_skip_verify: bool = False + console_endpoint: typing.Optional[str] = None command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None @@ -336,14 +339,21 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) ) + client_credentials_secret = read_file_if_exists( + _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file) + ) + if client_credentials_secret and client_credentials_secret.endswith("\n"): + logger.info("Newline stripped from client secret") + client_credentials_secret = client_credentials_secret.strip() kwargs = set_if_exists( kwargs, "client_credentials_secret", - read_file_if_exists(_internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file)), + client_credentials_secret, ) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) + kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) return PlatformConfig(**kwargs) @classmethod diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 4221a11966..5c29045db5 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -35,11 +35,6 @@ def get_specified_images(cfg: typing.Optional[ConfigFile]) -> typing.Dict[str, s return cfg.yaml_config.get("images", images) -class Deck(object): - SECTION = "deck" - DISABLE_DECK = ConfigEntry(LegacyConfigEntry(SECTION, "disable_deck", bool)) - - class AWS(object): SECTION = "aws" S3_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "endpoint"), YamlConfigEntry("storage.connection.endpoint")) @@ -112,6 +107,7 @@ class Platform(object): INSECURE_SKIP_VERIFY = ConfigEntry( LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool) ) + CONSOLE_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "console_endpoint"), YamlConfigEntry("console.endpoint")) class LocalSDK(object): diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 78f4341839..d2e4838ed8 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -41,7 +41,7 @@ def __init__( task_config=task_config, **kwargs, ) - self._query_template = query_template + self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t") @property def query_template(self) -> str: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 6aef36305e..dccbaec803 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -24,7 +24,6 @@ from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union from flytekit.configuration import SerializationSettings -from flytekit.configuration import internal as _internal from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache @@ -218,7 +217,7 @@ def get_input_types(self) -> 
Optional[Dict[str, type]]: """ return None - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: """ This function is used only in the local execution path and is responsible for calling dispatch execute. Use this function when calling a task with native values (or Promises containing Flyte literals derived from @@ -365,7 +364,7 @@ def __init__( task_config: T, interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, - disable_deck: bool = False, + disable_deck: bool = True, **kwargs, ): """ @@ -488,6 +487,7 @@ def dispatch_execute( # Short circuit the translation to literal map because what's returned may be a dj spec (or an # already-constructed LiteralMap if the dynamic task was a no-op), not python native values + # dynamic_execute returns a literal map in local execute so this also gets triggered. if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( native_outputs, _dynamic_job.DynamicJobSpec ): @@ -527,18 +527,18 @@ def dispatch_execute( f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}" ) from e - INPUT = "input" - OUTPUT = "output" + if self._disable_deck is False: + INPUT = "input" + OUTPUT = "output" - input_deck = Deck(INPUT) - for k, v in native_inputs.items(): - input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v))) + input_deck = Deck(INPUT) + for k, v in native_inputs.items(): + input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v))) - output_deck = Deck(OUTPUT) - for k, v in native_outputs_as_map.items(): - output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) + output_deck = Deck(OUTPUT) + for k, v in native_outputs_as_map.items(): + output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) - if _internal.Deck.DISABLE_DECK.read() is not True and self.disable_deck is False: _output_deck(self.name.split(".")[-1], new_user_params) outputs_literal_map = _literal_models.LiteralMap(literals=literals) diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index bd17b7748b..c1eb933ec6 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -81,7 +81,7 @@ def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = self._checkpoint_dest = checkpoint_dest self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None self._td = tempfile.TemporaryDirectory() - self._prev_download_path = None + self._prev_download_path: typing.Optional[Path] = None def __del__(self): self._td.cleanup() @@ -146,12 +146,14 @@ def read(self) -> typing.Optional[bytes]: if p is None: return None files = list(p.iterdir()) - if len(files) == 0 or len(files) > 1: + if len(files) == 0: + return None + if len(files) > 1: raise ValueError(f"Expected exactly one checkpoint - found {len(files)}") f = files[0] return f.read_bytes() def write(self, b: bytes): - f = io.BytesIO(b) - f = typing.cast(io.BufferedReader, f) + p = io.BytesIO(b) + f = typing.cast(io.BufferedReader, p) self.save(f) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index a31ad8150b..848c1d2524 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -7,6 +7,7 @@ from flytekit.core.resources import Resources, ResourceSpec from 
flytekit.core.utils import _get_container_definition from flytekit.models import task as _task_model +from flytekit.models.security import Secret, SecurityContext class ContainerTask(PythonTask): @@ -44,14 +45,22 @@ def __init__( output_data_dir: str = None, metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: IOStrategy = None, + secret_requests: Optional[List[Secret]] = None, **kwargs, ): + sec_ctx = None + if secret_requests: + for s in secret_requests: + if not isinstance(s, Secret): + raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") + sec_ctx = SecurityContext(secrets=secret_requests) super().__init__( task_type="raw-container", name=name, interface=Interface(inputs, outputs), metadata=metadata, task_config=None, + security_ctx=sec_ctx, **kwargs, ) self._image = image @@ -81,7 +90,8 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: - env = {**settings.env, **self.environment} if self.environment else settings.env + env = settings.env or {} + env = {**env, **self.environment} if self.environment else env return _get_container_definition( image=self._image, command=self._cmd, @@ -94,10 +104,14 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe io_strategy=self._io_strategy.value if self._io_strategy else None, ), environment=env, + storage_request=self.resources.requests.storage, + ephemeral_storage_request=self.resources.requests.ephemeral_storage, cpu_request=self.resources.requests.cpu, - cpu_limit=self.resources.limits.cpu, + gpu_request=self.resources.requests.gpu, memory_request=self.resources.requests.mem, - memory_limit=self.resources.limits.mem, - ephemeral_storage_request=self.resources.requests.ephemeral_storage, + storage_limit=self.resources.limits.storage, ephemeral_storage_limit=self.resources.limits.ephemeral_storage, + cpu_limit=self.resources.limits.cpu, + gpu_limit=self.resources.limits.gpu, + memory_limit=self.resources.limits.mem, ) diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index a17a9148b6..7e4600b3bb 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -80,15 +80,15 @@ class ExecutionParameters(object): @dataclass(init=False) class Builder(object): stats: taggable.TaggableStats - execution_date: datetime - logging: _logging.Logger - execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] attrs: typing.Dict[str, typing.Any] - working_dir: typing.Union[os.PathLike, utils.AutoDeletingTempDir] - checkpoint: typing.Optional[Checkpoint] decks: List[Deck] - raw_output_prefix: str - task_id: typing.Optional[_identifier.Identifier] + raw_output_prefix: Optional[str] = None + execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None + working_dir: typing.Optional[utils.AutoDeletingTempDir] = None + checkpoint: typing.Optional[Checkpoint] = None + execution_date: typing.Optional[datetime] = None + logging: Optional[_logging.Logger] = None + task_id: typing.Optional[_identifier.Identifier] = None def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.stats = current.stats if current else None @@ -450,20 +450,20 @@ class Mode(Enum): Defines the possible execution modes, which in turn affects execution behavior. """ - #: This is the mode that is used when a task execution mimics the actual runtime environment. 
- #: NOTE: This is important to understand the difference between TASK_EXECUTION and LOCAL_TASK_EXECUTION - #: LOCAL_TASK_EXECUTION, is the mode that is run purely locally and in some cases the difference between local - #: and runtime environment may be different. For example for Dynamic tasks local_task_execution will just run it - #: as a regular function, while task_execution will extract a runtime spec + # This is the mode that is used when a task execution mimics the actual runtime environment. + # NOTE: It is important to understand the difference between TASK_EXECUTION and LOCAL_TASK_EXECUTION. + # LOCAL_TASK_EXECUTION is the mode that runs purely locally, and in some cases its behavior may differ + # from the runtime environment. For example, for dynamic tasks LOCAL_TASK_EXECUTION will just run the task + # as a regular function, while TASK_EXECUTION will extract a runtime spec. TASK_EXECUTION = 1 - #: This represents when flytekit is locally running a workflow. The behavior of tasks differs in this case - #: because instead of running a task's user defined function directly, it'll need to wrap the return values in - #: NodeOutput + # This represents when flytekit is locally running a workflow. The behavior of tasks differs in this case + # because instead of running a task's user-defined function directly, it'll need to wrap the return values in + # NodeOutput. LOCAL_WORKFLOW_EXECUTION = 2 - #: This is the mode that is used to to indicate a purely local task execution - i.e. running without a container - #: or propeller. + # This is the mode that is used to indicate a purely local task execution - i.e. running without a container + # or propeller. LOCAL_TASK_EXECUTION = 3 mode: Optional[ExecutionState.Mode] @@ -617,7 +617,26 @@ def current_context() -> Optional[FlyteContext]: """ return FlyteContextManager.current_context() - def get_deck(self) -> str: + def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ignore + """ + Returns the deck that was created as part of the last execution. + + The return value depends on the execution environment. In a notebook, the return value is compatible with + IPython.display and should be rendered in the notebook. + + .. code-block:: python + + with flytekit.new_context() as ctx: + my_task(...) + ctx.get_deck() + + OR if you wish to explicitly display + + ..
code-block:: python + + from IPython import display + display(ctx.get_deck()) + """ from flytekit.deck.deck import _get_deck return _get_deck(self.execution_state.user_space_params) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index a2ad5311f1..d407b3528b 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -25,20 +25,24 @@ import os import pathlib import re +import shutil +import sys import tempfile import typing from abc import abstractmethod -from distutils import dir_util from shutil import copyfile from typing import Dict, Union from uuid import UUID from flytekit.configuration import DataConfig from flytekit.core.utils import PerformanceTimer -from flytekit.exceptions.user import FlyteAssertion +from flytekit.exceptions.user import FlyteAssertion, FlyteValueException from flytekit.interfaces.random import random from flytekit.loggers import logger +CURRENT_PYTHON = sys.version_info[:2] +THREE_SEVEN = (3, 7) + class UnsupportedPersistenceOp(Exception): """ @@ -221,17 +225,34 @@ def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, N def exists(self, path: str): return os.path.exists(self.strip_file_header(path)) + def copy_tree(self, from_path: str, to_path: str): + # TODO: Remove this code after support for 3.7 is dropped and inline this function back + # 3.7 doesn't have dirs_exist_ok + if CURRENT_PYTHON == THREE_SEVEN: + tp = pathlib.Path(self.strip_file_header(to_path)) + if tp.exists(): + if not tp.is_dir(): + raise FlyteValueException(tp, f"Target {tp} exists but is not a dir") + files = os.listdir(tp) + if len(files) != 0: + logger.debug(f"Deleting existing target dir {tp} with files {files}") + shutil.rmtree(tp) + shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + else: + # copytree will overwrite existing files in the to_path + shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True) + def get(self, from_path: str, to_path: str, recursive: bool = False): if from_path != to_path: if recursive: - dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + self.copy_tree(from_path, to_path) else: copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) def put(self, from_path: str, to_path: str, recursive: bool = False): if from_path != to_path: if recursive: - dir_util.copy_tree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + self.copy_tree(from_path, to_path) else: # Emulate s3's flat storage by automatically creating directory path self._make_local_path(os.path.dirname(self.strip_file_header(to_path))) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index c721e2e160..63d7c8106f 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -256,7 +256,7 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: return Interface(inputs=map_inputs, outputs=map_outputs) -def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]: +def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T], Annotated]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): if get_origin(t) is list: @@ -294,9 +294,9 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): - outputs[k] = _change_unrecognized_type_to_pickle(v) + outputs[k] = 
_change_unrecognized_type_to_pickle(v) # type: ignore inputs = OrderedDict() - for k, v in signature.parameters.items(): + for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index e8f5bd4aa3..0d143e5fe8 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -303,7 +303,7 @@ def __init__( labels: _common_models.Labels = None, annotations: _common_models.Annotations = None, raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, + max_parallelism: typing.Optional[int] = None, security_context: typing.Optional[security.SecurityContext] = None, ): self._name = name @@ -375,7 +375,7 @@ def fixed_inputs(self) -> _literal_models.LiteralMap: return self._fixed_inputs @property - def workflow(self) -> _annotated_workflow.PythonFunctionWorkflow: + def workflow(self) -> _annotated_workflow.WorkflowBase: return self._workflow @property @@ -407,7 +407,7 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] return self._raw_output_data_config @property - def max_parallelism(self) -> int: + def max_parallelism(self) -> typing.Optional[int]: return self._max_parallelism @property diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 48b6f9c7da..11cb3b926c 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,8 +1,7 @@ from typing import Optional +import joblib from diskcache import Cache -from google.protobuf.struct_pb2 import Struct -from joblib.hashing import NumpyHasher from flytekit.models.literals import Literal, LiteralCollection, LiteralMap @@ -28,26 +27,17 @@ def _recursive_hash_placement(literal: Literal) -> Literal: return literal -class ProtoJoblibHasher(NumpyHasher): - def save(self, obj): - if isinstance(obj, Struct): - obj = dict( - rewrite_rule="google.protobuf.struct_pb2.Struct", - cls=obj.__class__, - obj=dict(sorted(obj.fields.items())), - ) - NumpyHasher.save(self, obj) - - def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} for key, literal in input_literal_map.literals.items(): literal_map_overridden[key] = _recursive_hash_placement(literal) - # Generate a hash key of inputs with joblib - hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden) - return f"{task_name}-{cache_version}-{hashed_inputs}" + # Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the + # protobuf library. 
+ hashed_inputs = LiteralMap(literal_map_overridden).to_flyte_idl().SerializeToString(deterministic=True) + # Use joblib to hash the string representation of the literal into a fixed length string + return f"{task_name}-{cache_version}-{joblib.hash(hashed_inputs)}" class LocalTaskCache(object): diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index fc3b0e9117..3b5c0a09ca 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -99,8 +99,8 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return self._cmd_prefix + container_args return container_args - def set_command_prefix(self, cmd: typing.List[str]): - self._cmd_prefix = cmd + def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): + self._cmd_prefix = cmd # type: ignore @contextmanager def prepare_target(self): @@ -128,7 +128,7 @@ def get_sql(self, settings: SerializationSettings) -> Sql: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return ArrayJob(parallelism=self._max_concurrency, min_success_ratio=self._min_success_ratio).to_dict() - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: return self._run_task.get_config(settings) @property diff --git a/flytekit/core/node.py b/flytekit/core/node.py index cf2625b87a..d849ef5397 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -49,6 +49,7 @@ def runs_before(self, other: Node): def __rshift__(self, other: Node): self.runs_before(other) + return other @property def outputs(self): @@ -107,6 +108,8 @@ def with_overrides(self, *args, **kwargs): ) if "interruptible" in kwargs: self._metadata._interruptible = kwargs["interruptible"] + if "name" in kwargs: + self._metadata._name = kwargs["name"] return self @@ -133,7 +136,10 @@ def _convert_resource_overrides( ) if resources.ephemeral_storage is not None: resource_entries.append( - _resources_model.ResourceEntry(_resources_model.ResourceName.EPHEMERAL_STORAGE, resources.ephemeral_storage) + _resources_model.ResourceEntry( + _resources_model.ResourceName.EPHEMERAL_STORAGE, + resources.ephemeral_storage, + ) ) return resource_entries diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 0fdecd319a..de33393c13 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -164,9 +164,9 @@ def sub_wf(): # The reason we return it if it's a tuple is to handle the case where the task returns a typing.NamedTuple. # In that case, it's already a tuple and we don't need to further tupletize. 
if isinstance(results, VoidPromise) or isinstance(results, tuple): - return results + return results # type: ignore - output_names = entity.python_interface.output_names + output_names = entity.python_interface.output_names # type: ignore if not output_names: raise Exception(f"Non-VoidPromise received {results} but interface for {entity.name} doesn't have outputs") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 4fe8e669ab..0fb404fb9e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -65,7 +65,10 @@ def my_wf(in1: int, in2: int) -> int: """ def extract_value( - ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: _type_models.LiteralType + ctx: FlyteContext, + input_val: Any, + val_type: type, + flyte_literal_type: _type_models.LiteralType, ) -> _literal_models.Literal: if isinstance(input_val, list): @@ -350,6 +353,7 @@ def __hash__(self): def __rshift__(self, other: typing.Union[Promise, VoidPromise]): if not self.is_ready: self.ref.node.runs_before(other.ref.node) + return other def with_var(self, new_var: str) -> Promise: if self.is_ready: @@ -453,7 +457,9 @@ def __str__(self): def create_native_named_tuple( - ctx: FlyteContext, promises: Optional[Union[Promise, typing.List[Promise]]], entity_interface: Interface + ctx: FlyteContext, + promises: Optional[Union[Promise, typing.List[Promise]]], + entity_interface: Interface, ) -> Optional[Tuple]: """ Creates and returns a Named tuple with all variables that match the expected named outputs. this makes @@ -500,7 +506,8 @@ def create_native_named_tuple( # To create a class that is a named tuple, we might have to create namedtuplemeta and manipulate the tuple def create_task_output( - promises: Optional[Union[List[Promise], Promise]], entity_interface: Optional[Interface] = None + promises: Optional[Union[List[Promise], Promise]], + entity_interface: Optional[Interface] = None, ) -> Optional[Union[Tuple[Promise], Promise]]: # TODO: Add VoidPromise here to simplify things at call site. Consider returning for [] below as well instead of # raising an exception. @@ -562,7 +569,7 @@ def runs_before(self, other: Any): def __rshift__(self, other: Any): # See comment for runs_before - return self + return other return Output(*promises) # type: ignore @@ -618,7 +625,10 @@ def binding_data_from_python_std( lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) return _literals_models.BindingData(scalar=lit.scalar) else: - _, v_type = DictTransformer.get_dict_types(t_value_type) if t_value_type else None, None + _, v_type = ( + DictTransformer.get_dict_types(t_value_type) if t_value_type else None, + None, + ) m = _literals_models.BindingDataMap( bindings={ k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type) @@ -672,12 +682,13 @@ def runs_before(self, *args, **kwargs): """ @property - def ref(self) -> NodeOutput: + def ref(self) -> typing.Optional[NodeOutput]: return self._ref def __rshift__(self, other: typing.Union[Promise, VoidPromise]): if self.ref: self.ref.node.runs_before(other.ref.node) + return other def with_overrides(self, *args, **kwargs): if self.ref: @@ -869,7 +880,7 @@ def create_and_link_node_from_remote( ctx.compilation_state.add_node(flytekit_node) if len(typed_interface.outputs) == 0: - return VoidPromise(entity.name) + return VoidPromise(entity.name, NodeOutput(node=flytekit_node, var="placeholder")) # Create a node output object for each output, they should all point to this node of course. 
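A side effect of the __rshift__ changes above (returning other instead of self, or nothing) is that dependency declarations can now be chained; a small sketch with hypothetical tasks:

    from flytekit import task, workflow

    @task
    def step(name: str) -> str:
        return name

    @workflow
    def wf():
        a = step(name="a")
        b = step(name="b")
        c = step(name="c")
        # Each >> returns its right-hand operand, so this chains left to
        # right: b runs after a, and c runs after b.
        a >> b >> c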
node_outputs = [] @@ -927,7 +938,11 @@ def create_and_link_node( try: bindings.append( binding_from_python_std( - ctx, var_name=k, expected_literal_type=var.type, t_value=v, t_value_type=interface.inputs[k] + ctx, + var_name=k, + expected_literal_type=var.type, + t_value=v, + t_value_type=interface.inputs[k], ) ) used_inputs.add(k) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 4b6743afd3..06133d9784 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,7 +3,8 @@ import importlib import re from abc import ABC -from typing import Callable, Dict, List, Optional, TypeVar +from types import ModuleType +from typing import Callable, Dict, List, Optional, TypeVar, Union from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskResolverMixin @@ -82,7 +83,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) - self._environment = environment + self._environment = environment or {} compilation_state = FlyteContextManager.current_context().compilation_state if compilation_state and compilation_state.task_resolver: @@ -92,13 +93,13 @@ def __init__( ) self._task_resolver = compilation_state.task_resolver if self._task_resolver.task_name(self) is not None: - self._name = self._task_resolver.task_name(self) + self._name = self._task_resolver.task_name(self) or "" else: self._task_resolver = task_resolver or default_task_resolver self._get_command_fn = self.get_default_command @property - def task_resolver(self) -> TaskResolverMixin: + def task_resolver(self) -> Optional[TaskResolverMixin]: return self._task_resolver @property @@ -139,7 +140,7 @@ def set_command_fn(self, get_command_fn: Optional[Callable[[SerializationSetting However, it can be useful to update the command with which the task is serialized for specific cases like running map tasks ("pyflyte-map-execute") or for fast-executed tasks. 
""" - self._get_command_fn = get_command_fn + self._get_command_fn = get_command_fn # type: ignore def reset_command_fn(self): """ @@ -187,7 +188,7 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: + def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args task_module = importlib.import_module(task_module) diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index f06b96933c..eee0dce9b8 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -105,7 +105,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) - self._environment = environment + self._environment = environment or {} self._container_image = container_image self._task_resolver = task_resolver or default_task_template_resolver @@ -166,10 +166,12 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe data_loading_config=None, environment=env, storage_request=self.resources.requests.storage, + ephemeral_storage_request=self.resources.requests.ephemeral_storage, cpu_request=self.resources.requests.cpu, gpu_request=self.resources.requests.gpu, memory_request=self.resources.requests.mem, storage_limit=self.resources.limits.storage, + ephemeral_storage_limit=self.resources.limits.ephemeral_storage, cpu_limit=self.resources.limits.cpu, gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 669b634f89..bcb80f34ca 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -19,12 +19,11 @@ from enum import Enum from typing import Any, Callable, List, Optional, TypeVar, Union -from flytekit.configuration import SerializationSettings -from flytekit.configuration.default_images import DefaultImages from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring from flytekit.core.interface import transform_function_to_interface +from flytekit.core.promise import VoidPromise, translate_inputs_to_literals from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver from flytekit.core.tracker import extract_task_module, is_functools_wrapped_module_level, isnested, istestfunction from flytekit.core.workflow import ( @@ -34,6 +33,7 @@ WorkflowMetadataDefaults, ) from flytekit.exceptions import scopes as exception_scopes +from flytekit.exceptions.user import FlyteValueException from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models @@ -100,7 +100,7 @@ def __init__( task_function: Callable, task_type="python-task", ignore_input_vars: Optional[List[str]] = None, - execution_mode: Optional[ExecutionBehavior] = ExecutionBehavior.DEFAULT, + execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, **kwargs, ): @@ -145,6 +145,7 @@ def __init__( ) self._task_function = task_function self._execution_mode = execution_mode + self._wf = 
None # For dynamic tasks @property def execution_mode(self) -> ExecutionBehavior: @@ -164,6 +165,14 @@ def execute(self, **kwargs) -> Any: elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: return self.dynamic_execute(self._task_function, **kwargs) + def _create_and_cache_dynamic_workflow(self): + if self._wf is None: + workflow_meta = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY) + defaults = WorkflowMetadataDefaults( + interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False + ) + self._wf = PythonFunctionWorkflow(self._task_function, metadata=workflow_meta, default_metadata=defaults) + def compile_into_workflow( self, ctx: FlyteContext, task_function: Callable, **kwargs ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]: @@ -183,12 +192,7 @@ def compile_into_workflow( # TODO: Resolve circular import from flytekit.tools.translator import get_serializable - workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY) - defaults = WorkflowMetadataDefaults( - interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False - ) - - self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults) + self._create_and_cache_dynamic_workflow() self._wf.compile(**kwargs) wf = self._wf @@ -259,19 +263,44 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: representing that newly generated workflow, instead of executing it. """ ctx = FlyteContextManager.current_context() - # This is a placeholder SerializationSettings placeholder and is only used to test compilation for dynamic tasks - # when run locally. The output of the compilation should never actually be used anywhere. - _LOCAL_ONLY_SS = SerializationSettings.for_image(DefaultImages.default_image(), "v", "p", "d") - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: - updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION) - with FlyteContextManager.with_context( - ctx.with_execution_state(updated_exec_state).with_serialization_settings(_LOCAL_ONLY_SS) - ) as ctx: - logger.debug(f"Running compilation for {self} as part of local run as check") - self.compile_into_workflow(ctx, task_function, **kwargs) - logger.info("Executing Dynamic workflow, using raw inputs") - return exception_scopes.user_entry_point(task_function)(**kwargs) + # The rest of this function mimics the local_execute of the workflow. We can't use the workflow + # local_execute directly though since that converts inputs into Promises. + logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") + self._create_and_cache_dynamic_workflow() + function_outputs = self._wf.execute(**kwargs) + + if isinstance(function_outputs, VoidPromise) or function_outputs is None: + return VoidPromise(self.name) + + if len(self._wf.python_interface.outputs) == 0: + raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") + + # TODO: This will need to be cleaned up when we revisit top-level tuple support. + expected_output_names = list(self.python_interface.outputs.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the wf could've been declared with a typing.NamedTuple of + # length one. 
That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + if self.python_interface.output_tuple_name and isinstance(function_outputs, tuple): + wf_outputs_as_map = {expected_output_names[0]: function_outputs[0]} + else: + wf_outputs_as_map = {expected_output_names[0]: function_outputs} + else: + wf_outputs_as_map = { + expected_output_names[i]: function_outputs[i] for i, _ in enumerate(function_outputs) + } + + # In a normal workflow, we'd repackage the promises coming from tasks into new Promises matching the + # workflow's interface. For a dynamic workflow, just return the literal map. + wf_outputs_as_literal_dict = translate_inputs_to_literals( + ctx, + wf_outputs_as_map, + flyte_interface_types=self.interface.outputs, + native_types=self.python_interface.outputs, + ) + return _literal_models.LiteralMap(literals=wf_outputs_as_literal_dict) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: return self.compile_into_workflow(ctx, task_function, **kwargs) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index d210aaa2ed..6e5b0a6b6a 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -89,7 +89,7 @@ def task( secret_requests: Optional[List[Secret]] = None, execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, - disable_deck: bool = False, + disable_deck: bool = True, ) -> Union[Callable, PythonFunctionTask]: """ This is the core decorator to use for any task type in flytekit. @@ -227,7 +227,7 @@ def __init__( super().__init__(TaskReference(project, domain, name, version), inputs, outputs) # Reference tasks shouldn't call the parent constructor, but the parent constructor is what sets the resolver - self._task_resolver = None + self._task_resolver = None # type: ignore def reference_task( diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 9851e2e98b..2a203d4861 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -229,16 +229,18 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, mod_name = mod.__name__ name = f.lhs # We cannot get the sourcefile for an instance, so we replace it with the module - f = mod + g = mod + inspect_file = inspect.getfile(g) else: - mod = inspect.getmodule(f) + mod = inspect.getmodule(f) # type: ignore if mod is None: raise AssertionError(f"Unable to determine module of {f}") mod_name = mod.__name__ name = f.__name__.split(".")[-1] + inspect_file = inspect.getfile(f) if mod_name == "__main__": - return name, "", name, os.path.abspath(inspect.getfile(f)) + return name, "", name, os.path.abspath(inspect_file) mod_name = get_full_module_path(mod, mod_name) return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 43c11065ba..40b39eae90 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,10 +15,10 @@ from dataclasses_json import DataClassJsonMixin, dataclass_json from google.protobuf import json_format as _json_format -from google.protobuf import reflection as _proto_reflection from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict from google.protobuf.json_format import 
ParseDict as _ParseDict +from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema @@ -62,13 +62,10 @@ class TypeTransformer(typing.Generic[T]): Base transformer type that should be implemented for every python native type that can be handled by flytekit """ - def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True, hash_overridable: bool = False): + def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): self._t = t self._name = name self._type_assertions_enabled = enable_type_assertions - # `hash_overridable` indicates that the literals produced by this type transformer can set their hashes if needed. - # See (link to documentation where this feature is explained). - self._hash_overridable = hash_overridable @property def name(self): @@ -88,10 +85,6 @@ def type_assertions_enabled(self) -> bool: """ return self._type_assertions_enabled - @property - def hash_overridable(self) -> bool: - return self._hash_overridable - def assert_type(self, t: Type[T], v: T): if not hasattr(t, "__origin__") and not isinstance(v, t): raise TypeTransformerFailedError(f"Type of Val '{v}' is not an instance of {t}") @@ -395,7 +388,9 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory): return python_type(path=lv.scalar.blob.uri) elif issubclass(python_type, StructuredDataset): - return python_type(uri=lv.scalar.structured_dataset.uri) + sd = python_type(uri=lv.scalar.structured_dataset.uri) + sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format + return sd else: return python_val else: @@ -412,10 +407,10 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: - return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] + return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is dict: - return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()} + return {k: self._deserialize_flyte_type(v, expected_python_type.__args__[1]) for k, v in python_val.items()} # type: ignore if not dataclasses.is_dataclass(expected_python_type): return python_val @@ -491,6 +486,14 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val + + if get_origin(t) is typing.Union and type(None) in get_args(t): + # Handle optional type. e.g. Optional[int], Optional[dataclass] + # Marshmallow doesn't support union type, so the type here is always an optional type. + # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796 + # Note: Union[None, int] is also an optional type, but Marshmallow does not support it. + t = get_args(t)[0] + if t == int: return int(val) @@ -503,13 +506,6 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: # Handle nested Dict. e.g. 
{1: {2: 3}, 4: {5: 6}}) return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} - if get_origin(t) is typing.Union and type(None) in get_args(t): - # Handle optional type. e.g. Optional[int], Optional[dataclass] - # Marshmallow doesn't support union type, so the type here is always an optional type. - # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796 - # Note: Union[None, int] is also an optional type, but Marshmallow does not support it. - return self._fix_val_int(get_args(t)[0], val) - if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore @@ -540,7 +536,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"serialized correctly" ) - dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic)) + json_str = _json_format.MessageToJson(lv.scalar.generic) + dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run`` @@ -557,11 +554,11 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") -class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]): +class ProtobufTransformer(TypeTransformer[Message]): PB_FIELD_KEY = "pb_type" def __init__(self): - super().__init__("Protobuf-Transformer", _proto_reflection.GeneratedProtocolMessageType) + super().__init__("Protobuf-Transformer", Message) @staticmethod def tag(expected_python_type: Type[T]) -> str: @@ -657,14 +654,13 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc Step 3: - if v is of type data class, use the dataclass transformer - - Step 4: Walk the inheritance hierarchy of v and find a transformer that matches the first base class. This is potentially non-deterministic - will depend on the registration pattern. TODO lets make this deterministic by using an ordered dict + Step 4: + if v is of type data class, use the dataclass transformer """ # Step 1 @@ -687,9 +683,6 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.") # Step 3 - if dataclasses.is_dataclass(python_type): - return cls._DATACLASS_TRANSFORMER - # To facilitate cases where users may specify one transformer for multiple types that all inherit from one # parent. for base_type in cls._REGISTRY.keys(): @@ -704,6 +697,11 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: # As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which # is the case for one of the restricted types, namely NamedTuple. logger.debug(f"Invalid base type {base_type} in call to isinstance", exc_info=True) + + # Step 4 + if dataclasses.is_dataclass(python_type): + return cls._DATACLASS_TRANSFORMER + raise ValueError(f"Type {python_type} not supported currently in Flytekit. 
Please register a new transformer") @classmethod @@ -742,7 +740,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type # In case the value is an annotated type we inspect the annotations and look for hash-related annotations. hash = None - if transformer.hash_overridable and get_origin(python_type) is Annotated: + if get_origin(python_type) is Annotated: # We are now dealing with one of two cases: # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using # the method indicated in the annotation. @@ -885,7 +883,6 @@ def get_sub_type(t: Type[T]) -> Type[T]: """ Return the generic Type T of the List """ - if hasattr(t, "__origin__"): # Handle annotation on list generic, eg: # Annotated[typing.List[int], 'foo'] @@ -1065,7 +1062,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp # Should really never happen, sanity check raise TypeError("Ambiguous choice of variant for union type") found_res = True - except (TypeTransformerFailedError, AttributeError) as e: + except (TypeTransformerFailedError, AttributeError, ValueError, AssertionError) as e: logger.debug(f"Failed to convert from {python_val} to {t}", e) continue @@ -1129,7 +1126,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)] + return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)] # type: ignore raise ValueError(f"Union transformer cannot reverse {literal_type}") @@ -1158,7 +1155,7 @@ def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Opti parsed."
) if _origin is dict and _args is not None: - return _args + return _args # type: ignore return None, None @staticmethod @@ -1225,14 +1222,14 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]: if literal_type.map_value_type: mt = TypeEngine.guess_python_type(literal_type.map_value_type) return typing.Dict[str, mt] # type: ignore if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is None: - return dict + return dict # type: ignore raise ValueError(f"Dictionary transformer cannot reverse {literal_type}") @@ -1366,13 +1363,13 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) -def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]: +def _get_element_type(element_property: typing.Dict[str, str]) -> Type: element_type = element_property["type"] element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: # Element type of Optional[int] is [integer, None] - return typing.Optional[_get_element_type({"type": element_type[0]})] + return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore if element_type == "string": return str diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0f36f86ec3..468a5aa7ea 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -134,7 +134,7 @@ def get_promise(binding_data: _literal_models.BindingData, outputs_cache: Dict[N val=_literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literals)), ) elif binding_data.map is not None: - literals = {} + literals = {} # type: ignore for k, bd in binding_data.map.bindings.items(): p = get_promise(bd, outputs_cache) literals[k] = p.val @@ -178,7 +178,7 @@ def __init__( self._inputs = {} self._unbound_inputs = set() self._nodes = [] - self._output_bindings: Optional[List[_literal_models.Binding]] = [] + self._output_bindings: List[_literal_models.Binding] = [] FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -240,7 +240,7 @@ def __call__(self, *args, **kwargs): def execute(self, **kwargs): raise Exception("Should not be called") - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. 
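To illustrate the invariant described in the comment above, a small hypothetical example of a sub-workflow invoked with a plain Python value during a local run; local_execute wraps the native value into a Promise holding the equivalent Flyte literal.

from flytekit import task, workflow

@task
def add_one(a: int) -> int:
    return a + 1

@workflow
def child(a: int) -> int:
    return add_one(a=a)

@workflow
def parent() -> int:
    return child(a=41)  # native int, converted to a Promise internally

print(parent())  # 42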
for k, v in kwargs.items(): @@ -487,7 +487,7 @@ def get_input_values(input_value): for input_value in filter(lambda x: isinstance(x, Promise), all_input_values): if input_value in self._unbound_inputs: self._unbound_inputs.remove(input_value) - return n + return n # type: ignore def add_workflow_input(self, input_name: str, python_type: Type) -> Interface: """ diff --git a/flytekit/deck/__init__.py b/flytekit/deck/__init__.py index 93e3b6ef70..f83049ac48 100644 --- a/flytekit/deck/__init__.py +++ b/flytekit/deck/__init__.py @@ -1,2 +1,18 @@ +""" +========== +Flyte Deck +========== + +.. currentmodule:: flytekit.deck + +Contains deck renderers provided by flytekit. + +.. autosummary:: + :toctree: generated/ + + Deck + TopFrameRenderer +""" + from .deck import Deck from .renderer import TopFrameRenderer diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index 24f8dcf514..cec59e7318 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -1,4 +1,5 @@ import os +import typing from typing import Optional from jinja2 import Environment, FileSystemLoader, select_autoescape @@ -30,26 +31,26 @@ class Deck: This feature is in beta. - .. code-block:: python + .. code-block:: python - iris_df = px.data.iris() + iris_df = px.data.iris() - @task() - def t1() -> str: - md_text = "#Hello Flyte\n##Hello Flyte\n###Hello Flyte" - m = MarkdownRenderer() - s = BoxRenderer("sepal_length") - deck = flytekit.Deck("demo", s.to_html(iris_df)) - deck.append(m.to_html(md_text)) - default_deck = flytekit.current_context().default_deck - default_deck.append(m.to_html(md_text)) - return md_text + @task() + def t1() -> str: + md_text = '#Hello Flyte##Hello Flyte###Hello Flyte' + m = MarkdownRenderer() + s = BoxRenderer("sepal_length") + deck = flytekit.Deck("demo", s.to_html(iris_df)) + deck.append(m.to_html(md_text)) + default_deck = flytekit.current_context().default_deck + default_deck.append(m.to_html(md_text)) + return md_text - # Use Annotated to override default renderer - @task() - def t2() -> Annotated[pd.DataFrame, TopFrameRenderer(10)]: - return iris_df + # Use Annotated to override default renderer + @task() + def t2() -> Annotated[pd.DataFrame, TopFrameRenderer(10)]: + return iris_df """ @@ -89,12 +90,20 @@ def _ipython_check() -> bool: return is_ipython -def _get_deck(new_user_params: ExecutionParameters) -> str: +def _get_deck( + new_user_params: ExecutionParameters, ignore_jupyter: bool = False +) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ignore """ Get flyte deck html string + If ignore_jupyter is set to True, then it will return a str even in a jupyter environment. 
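Putting this together with the disable_deck default flip in flytekit/core/task.py above, decks are now opt-in; a rough usage sketch (the task body is invented for illustration):

import flytekit
import pandas as pd
from flytekit import task
from flytekit.deck.renderer import TopFrameRenderer

@task(disable_deck=False)  # decks are now opt-in, the decorator default changed to True
def preview() -> str:
    df = pd.DataFrame({"a": range(1000)})
    # the renderer now truncates large frames (defaults introduced later in this diff)
    flytekit.Deck("preview", TopFrameRenderer().to_html(df))
    return "ok"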
""" deck_map = {deck.name: deck.html for deck in new_user_params.decks} - return template.render(metadata=deck_map) + raw_html = template.render(metadata=deck_map) + if not ignore_jupyter and _ipython_check(): + from IPython.core.display import HTML + + return HTML(raw_html) + return raw_html def _output_deck(task_name: str, new_user_params: ExecutionParameters): @@ -105,7 +114,7 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): output_dir = ctx.file_access.get_random_local_directory() deck_path = os.path.join(output_dir, DECK_FILE_NAME) with open(deck_path, "w") as f: - f.write(_get_deck(new_user_params)) + f.write(_get_deck(new_user_params, ignore_jupyter=True)) logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}") diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index 0cf781d3da..dddb88e420 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pandas import pyarrow @@ -14,17 +14,22 @@ def to_html(self, python_value: Any) -> str: raise NotImplementedError +DEFAULT_MAX_ROWS = 10 +DEFAULT_MAX_COLS = 100 + + class TopFrameRenderer: """ Render a DataFrame as an HTML table. """ - def __init__(self, max_rows: Optional[int] = None): + def __init__(self, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS): self._max_rows = max_rows + self._max_cols = max_cols def to_html(self, df: pandas.DataFrame) -> str: assert isinstance(df, pandas.DataFrame) - return df.to_html(max_rows=self._max_rows) + return df.to_html(max_rows=self._max_rows, max_cols=self._max_cols) class ArrowRenderer: diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py new file mode 100644 index 0000000000..0a1bf2dda5 --- /dev/null +++ b/flytekit/extras/sklearn/__init__.py @@ -0,0 +1,26 @@ +""" +Flytekit Sklearn +========================================= +.. currentmodule:: flytekit.extras.sklearn + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ +""" +from flytekit.loggers import logger + +# TODO: abstract this out so that there's an established pattern for registering plugins +# that have soft dependencies +try: + # isolate the exception to the sklearn import + import sklearn + + _sklearn_installed = True +except (ImportError, OSError): + _sklearn_installed = False + + +if _sklearn_installed: + from .native import SklearnEstimatorTransformer +else: + logger.info("We won't register SklearnEstimatorTransformer because scikit-learn is not installed.") diff --git a/flytekit/extras/sklearn/native.py b/flytekit/extras/sklearn/native.py new file mode 100644 index 0000000000..59ecca70c5 --- /dev/null +++ b/flytekit/extras/sklearn/native.py @@ -0,0 +1,79 @@ +import pathlib +from typing import Generic, Type, TypeVar + +import joblib +import sklearn + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + +T = TypeVar("T") + + +class SklearnTypeTransformer(TypeTransformer, Generic[T]): + def get_literal_type(self, t: Type[T]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.SKLEARN_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: T, + python_type: Type[T], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.SKLEARN_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".joblib" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save sklearn estimator to a file + joblib.dump(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + # load sklearn estimator from a file + return joblib.load(local_path) + + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.SKLEARN_FORMAT + ): + return self.python_type + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +class SklearnEstimatorTransformer(SklearnTypeTransformer[sklearn.base.BaseEstimator]): + SKLEARN_FORMAT = "SklearnEstimator" + + def __init__(self): + super().__init__(name="Sklearn Estimator", t=sklearn.base.BaseEstimator) + + +TypeEngine.register(SklearnEstimatorTransformer()) diff --git a/flytekit/extras/sqlite3/__init__.py b/flytekit/extras/sqlite3/__init__.py index e69de29bb2..cc016ac4af 100644 --- a/flytekit/extras/sqlite3/__init__.py +++ b/flytekit/extras/sqlite3/__init__.py @@ -0,0 +1,12 @@ +""" +Flytekit SQLite3Task 
+========================================= +.. currentmodule:: flytekit.extras.sqlite3 + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + ~task.SQLite3Task + ~task.SQLite3Config +""" diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 1018b5254b..0284440da3 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -9,7 +9,7 @@ import pandas as pd from flytekit import FlyteContext, kwtypes -from flytekit.configuration import SerializationSettings +from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor @@ -79,6 +79,7 @@ def __init__( inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: @@ -87,8 +88,8 @@ super().__init__( name=name, task_config=task_config, - # If you make changes to this task itself, you'll have to bump this image to what the release _will_ be. - container_image="ghcr.io/flyteorg/flytekit:v0.19.0", + # Use the provided container image if given, otherwise fall back to the default flytekit image + container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, query_template=query_template, @@ -96,6 +97,9 @@ outputs=outputs, **kwargs, ) + # Sanitize the query by replacing newlines with spaces, since the query + # template can be a multiline string. + self._query_template = query_template.replace("\n", " ") @property def output_columns(self) -> typing.Optional[typing.List[str]]: diff --git a/flytekit/extras/tasks/__init__.py b/flytekit/extras/tasks/__init__.py index e69de29bb2..94ca9244fd 100644 --- a/flytekit/extras/tasks/__init__.py +++ b/flytekit/extras/tasks/__init__.py @@ -0,0 +1,12 @@ +""" +Flytekit ShellTask +========================================= +.. currentmodule:: flytekit.extras.tasks + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + ~shell.ShellTask + ~shell.OutputLocation +""" diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py new file mode 100644 index 0000000000..f51da24dae --- /dev/null +++ b/flytekit/extras/tensorflow/__init__.py @@ -0,0 +1,34 @@ +""" +Flytekit TensorFlow +========================================= +.. currentmodule:: flytekit.extras.tensorflow + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + TensorFlowRecord +""" +from flytekit.loggers import logger + +# TODO: abstract this out so that there's an established pattern for registering plugins +# that have soft dependencies +try: + # isolate the exception to the tensorflow import + import tensorflow + + _tensorflow_installed = True +except TypeError as e: + logger.warn(f"Unsupported version of tensorflow installed. 
Error message from protobuf library: {e}") + _tensorflow_installed = False +except (ImportError, OSError): + _tensorflow_installed = False + + +if _tensorflow_installed: + from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer +else: + logger.info( + "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer " + "because tensorflow is not installed." + ) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py new file mode 100644 index 0000000000..d5d750b521 --- /dev/null +++ b/flytekit/extras/tensorflow/record.py @@ -0,0 +1,188 @@ +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Type, Union + +import tensorflow as tf +from dataclasses_json import dataclass_json +from tensorflow.python.data.ops.readers import TFRecordDatasetV2 +from typing_extensions import Annotated, get_args, get_origin + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType +from flytekit.types.directory import TFRecordsDirectory +from flytekit.types.file import TFRecordFile + + +@dataclass_json +@dataclass +class TFRecordDatasetConfig: + """ + TFRecordDatasetConfig can be used while creating tf.data.TFRecordDataset comprising + records from one or more TFRecord files. + + Args: + compression_type: A scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP". + buffer_size: The number of bytes in the read buffer. If None, a sensible default for both local and remote file systems is used. + num_parallel_reads: The number of files to read in parallel. If greater than one, the records from files read in parallel are output in an interleaved order. + name: A name for the operation. + """ + + compression_type: Optional[str] = None + buffer_size: Optional[int] = None + num_parallel_reads: Optional[int] = None + name: Optional[str] = None + + +def extract_metadata_and_uri( + lv: Literal, t: Type[Union[TFRecordFile, TFRecordsDirectory]] +) -> Tuple[Union[TFRecordFile, TFRecordsDirectory], TFRecordDatasetConfig]: + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {t}") + metadata = TFRecordDatasetConfig() + if get_origin(t) is Annotated: + _, metadata = get_args(t) + if isinstance(metadata, TFRecordDatasetConfig): + return uri, metadata + else: + raise TypeTransformerFailedError(f"{t}'s metadata needs to be of type TFRecordDatasetConfig") + return uri, metadata + + +class TensorFlowRecordFileTransformer(TypeTransformer[TFRecordFile]): + """ + TypeTransformer that supports serialising and deserialising to and from TFRecord file. 
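A hedged sketch of how the Annotated metadata is consumed end to end (the task names and the feature layout are invented for illustration):

import tensorflow as tf
from typing_extensions import Annotated
from flytekit import task
from flytekit.extras.tensorflow.record import TFRecordDatasetConfig
from flytekit.types.file import TFRecordFile

@task
def produce() -> TFRecordFile:
    # the file transformer writes this example into a single .tfrecord blob
    return tf.train.Example(
        features=tf.train.Features(feature={"x": tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))})
    )

@task
def consume(ds: Annotated[TFRecordFile, TFRecordDatasetConfig(buffer_size=1024)]) -> int:
    # ds arrives as a tf.data.TFRecordDataset built with the annotated config
    return sum(1 for _ in ds)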
+ https://www.tensorflow.org/tutorials/load_data/tfrecord + """ + + TENSORFLOW_FORMAT = "TensorFlowRecord" + + def __init__(self): + super().__init__(name="TensorFlow Record File", t=TFRecordFile) + + def get_literal_type(self, t: Type[TFRecordFile]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + + def to_literal( + self, ctx: FlyteContext, python_val: TFRecordFile, python_type: Type[TFRecordFile], expected: LiteralType + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) + ) + local_dir = ctx.file_access.get_random_local_directory() + remote_path = ctx.file_access.get_random_remote_path() + local_path = os.path.join(local_dir, "0000.tfrecord") + with tf.io.TFRecordWriter(local_path) as writer: + writer.write(python_val.SerializeToString()) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordFile] + ) -> TFRecordDatasetV2: + uri, metadata = extract_metadata_and_uri(lv, expected_python_type) + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + filenames = [local_path] + return tf.data.TFRecordDataset( + filenames=filenames, + compression_type=metadata.compression_type, + buffer_size=metadata.buffer_size, + num_parallel_reads=metadata.num_parallel_reads, + name=metadata.name, + ) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordFile]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return TFRecordFile + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +class TensorFlowRecordsDirTransformer(TypeTransformer[TFRecordsDirectory]): + """ + TypeTransformer that supports serialising and deserialising to and from TFRecord directory. 
+ https://www.tensorflow.org/tutorials/load_data/tfrecord + """ + + TENSORFLOW_FORMAT = "TensorFlowRecord" + + def __init__(self): + super().__init__(name="TensorFlow Record Directory", t=TFRecordsDirectory) + + def get_literal_type(self, t: Type[TFRecordsDirectory]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: TFRecordsDirectory, + python_type: Type[TFRecordsDirectory], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + local_dir = ctx.file_access.get_random_local_directory() + remote_path = ctx.file_access.get_random_remote_directory() + for i, val in enumerate(python_val): + local_path = f"{local_dir}/part_{i}.tfrecord" + with tf.io.TFRecordWriter(local_path) as writer: + writer.write(val.SerializeToString()) + ctx.file_access.upload_directory(local_dir, remote_path) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory] + ) -> TFRecordDatasetV2: + + uri, metadata = extract_metadata_and_uri(lv, expected_python_type) + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(uri, local_dir, is_multipart=True) + files = os.scandir(local_dir) + filenames = [os.path.join(local_dir, f.name) for f in files] + return tf.data.TFRecordDataset( + filenames=filenames, + compression_type=metadata.compression_type, + buffer_size=metadata.buffer_size, + num_parallel_reads=metadata.num_parallel_reads, + name=metadata.name, + ) + + def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordsDirectory]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return TFRecordsDirectory + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlowRecordsDirTransformer()) +TypeEngine.register(TensorFlowRecordFileTransformer()) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 63253bf399..7236dd15ce 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -65,6 +65,9 @@ def verbose_string(self): """ return self.short_string() + def serialize_to_string(self) -> str: + return self.to_flyte_idl().SerializeToString() + @property def is_empty(self): return len(self.to_flyte_idl().SerializeToString()) == 0 diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index b65f7d269d..bf46ace349 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -90,6 +90,12 @@ def from_flyte_idl(cls, p): version=p.version, ) + def __repr__(self): + return self.__str__() + + def __str__(self): + return f"{self.resource_type_name()}:{self.project}:{self.domain}:{self.name}:{self.version}" + class WorkflowExecutionIdentifier(_common_models.FlyteIdlEntity): def __init__(self, project, domain, name): diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 3dbdfb8564..75e040891b 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -1,10 +1,15 @@ +from __future__ import annotations + +import 
datetime import typing +import flyteidl import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 import pytz as _pytz +import flytekit from flytekit.models import common as _common_models from flytekit.models import literals as _literals_models from flytekit.models import security @@ -13,52 +18,126 @@ from flytekit.models.node_execution import DynamicWorkflowNodeMetadata +class SystemMetadata(_common_models.FlyteIdlEntity): + def __init__(self, execution_cluster: str): + self._execution_cluster = execution_cluster + + @property + def execution_cluster(self) -> str: + return self._execution_cluster + + def to_flyte_idl(self) -> flyteidl.admin.execution_pb2.SystemMetadata: + return _execution_pb2.SystemMetadata(execution_cluster=self.execution_cluster) + + @classmethod + def from_flyte_idl(cls, pb2_object: flyteidl.admin.execution_pb2.SystemMetadata) -> SystemMetadata: + return cls( + execution_cluster=pb2_object.execution_cluster, + ) + + class ExecutionMetadata(_common_models.FlyteIdlEntity): class ExecutionMode(object): MANUAL = 0 SCHEDULED = 1 SYSTEM = 2 - def __init__(self, mode, principal, nesting): + def __init__( + self, + mode: int, + principal: str, + nesting: int, + scheduled_at: typing.Optional[datetime.datetime] = None, + parent_node_execution: typing.Optional[_identifier.NodeExecutionIdentifier] = None, + reference_execution: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None, + system_metadata: typing.Optional[SystemMetadata] = None, + ): """ - :param int mode: An enum value from ExecutionMetadata.ExecutionMode which specifies how the job started. - :param Text principal: The entity that triggered the execution - :param int nesting: An integer representing how deeply nested the workflow is (i.e. was it triggered by a parent + :param mode: An enum value from ExecutionMetadata.ExecutionMode which specifies how the job started. + :param principal: The entity that triggered the execution + :param nesting: An integer representing how deeply nested the workflow is (i.e. was it triggered by a parent workflow) + :param scheduled_at: For scheduled executions, the requested time for execution for this specific schedule invocation. + :param parent_node_execution: Which subworkflow node (if any) launched this execution + :param reference_execution: Optional, reference workflow execution related to this execution + :param system_metadata: Optional, platform-specific metadata about the execution. """ self._mode = mode self._principal = principal self._nesting = nesting + self._scheduled_at = scheduled_at + self._parent_node_execution = parent_node_execution + self._reference_execution = reference_execution + self._system_metadata = system_metadata @property - def mode(self): + def mode(self) -> int: """ An enum value from ExecutionMetadata.ExecutionMode which specifies how the job started. - :rtype: int """ return self._mode @property - def principal(self): + def principal(self) -> str: """ The entity that triggered the execution - :rtype: Text """ return self._principal @property - def nesting(self): + def nesting(self) -> int: """ An integer representing how deeply nested the workflow is (i.e. was it triggered by a parent workflow) - :rtype: int """ return self._nesting + @property + def scheduled_at(self) -> datetime.datetime: + """ + For scheduled executions, the requested time for execution for this specific schedule invocation. 
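A quick round-trip sketch for the new scheduled_at field (the values are arbitrary):

import datetime
from flytekit.models.execution import ExecutionMetadata

md = ExecutionMetadata(
    mode=ExecutionMetadata.ExecutionMode.SCHEDULED,
    principal="scheduler",
    nesting=0,
    scheduled_at=datetime.datetime(2022, 10, 1, 12, 0, 0),
)
pb = md.to_flyte_idl()  # scheduled_at is copied via Timestamp.FromDatetime
assert ExecutionMetadata.from_flyte_idl(pb).scheduled_at == md.scheduled_at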
+ """ + return self._scheduled_at + + @property + def parent_node_execution(self) -> _identifier.NodeExecutionIdentifier: + """ + Which subworkflow node (if any) launched this execution + """ + return self._parent_node_execution + + @property + def reference_execution(self) -> _identifier.WorkflowExecutionIdentifier: + """ + Optional, reference workflow execution related to this execution + """ + return self._reference_execution + + @property + def system_metadata(self) -> SystemMetadata: + """ + Optional, platform-specific metadata about the execution. + """ + return self._system_metadata + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionMetadata """ - return _execution_pb2.ExecutionMetadata(mode=self.mode, principal=self.principal, nesting=self.nesting) + p = _execution_pb2.ExecutionMetadata( + mode=self.mode, + principal=self.principal, + nesting=self.nesting, + parent_node_execution=self.parent_node_execution.to_flyte_idl() + if self.parent_node_execution is not None + else None, + reference_execution=self.reference_execution.to_flyte_idl() + if self.reference_execution is not None + else None, + system_metadata=self.system_metadata.to_flyte_idl() if self.system_metadata is not None else None, + ) + if self.scheduled_at is not None: + p.scheduled_at.FromDatetime(self.scheduled_at) + return p @classmethod def from_flyte_idl(cls, pb2_object): @@ -70,6 +149,16 @@ def from_flyte_idl(cls, pb2_object): mode=pb2_object.mode, principal=pb2_object.principal, nesting=pb2_object.nesting, + scheduled_at=pb2_object.scheduled_at.ToDatetime() if pb2_object.HasField("scheduled_at") else None, + parent_node_execution=_identifier.NodeExecutionIdentifier.from_flyte_idl(pb2_object.parent_node_execution) + if pb2_object.HasField("parent_node_execution") + else None, + reference_execution=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(pb2_object.reference_execution) + if pb2_object.HasField("reference_execution") + else None, + system_metadata=SystemMetadata.from_flyte_idl(pb2_object.system_metadata) + if pb2_object.HasField("system_metadata") + else None, ) @@ -319,57 +408,82 @@ def from_flyte_idl(cls, pb): ) +class AbortMetadata(_common_models.FlyteIdlEntity): + def __init__(self, cause: str, principal: str): + self._cause = cause + self._principal = principal + + @property + def cause(self) -> str: + return self._cause + + @property + def principal(self) -> str: + return self._principal + + def to_flyte_idl(self) -> flyteidl.admin.execution_pb2.AbortMetadata: + return _execution_pb2.AbortMetadata(cause=self.cause, principal=self.principal) + + @classmethod + def from_flyte_idl(cls, pb2_object: flyteidl.admin.execution_pb2.AbortMetadata) -> AbortMetadata: + return cls( + cause=pb2_object.cause, + principal=pb2_object.principal, + ) + + class ExecutionClosure(_common_models.FlyteIdlEntity): - def __init__(self, phase, started_at, duration, error=None, outputs=None): + def __init__( + self, + phase: int, + started_at: datetime.datetime, + duration: datetime.timedelta, + error: typing.Optional[flytekit.models.core.execution.ExecutionError] = None, + outputs: typing.Optional[LiteralMapBlob] = None, + abort_metadata: typing.Optional[AbortMetadata] = None, + ): """ - :param int phase: From the flytekit.models.core.execution.WorkflowExecutionPhase enum - :param datetime.datetime started_at: - :param datetime.timedelta duration: Duration for which the execution has been running. 
- :param flytekit.models.core.execution.ExecutionError error: - :param LiteralMapBlob outputs: + :param phase: From the flytekit.models.core.execution.WorkflowExecutionPhase enum + :param started_at: + :param duration: Duration for which the execution has been running. + :param error: + :param outputs: + :param abort_metadata: Specifies metadata around an aborted workflow execution. """ self._phase = phase self._started_at = started_at self._duration = duration self._error = error self._outputs = outputs + self._abort_metadata = abort_metadata @property - def error(self): - """ - :rtype: flytekit.models.core.execution.ExecutionError - """ + def error(self) -> flytekit.models.core.execution.ExecutionError: return self._error @property - def phase(self): + def phase(self) -> int: """ From the flytekit.models.core.execution.WorkflowExecutionPhase enum - :rtype: int """ return self._phase @property - def started_at(self): - """ - :rtype: datetime.datetime - """ + def started_at(self) -> datetime.datetime: return self._started_at @property - def duration(self): - """ - :rtype: datetime.timedelta - """ + def duration(self) -> datetime.timedelta: return self._duration @property - def outputs(self): - """ - :rtype: LiteralMapBlob - """ + def outputs(self) -> LiteralMapBlob: return self._outputs + @property + def abort_metadata(self) -> AbortMetadata: + return self._abort_metadata + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionClosure @@ -378,6 +492,7 @@ def to_flyte_idl(self): phase=self.phase, error=self.error.to_flyte_idl() if self.error is not None else None, outputs=self.outputs.to_flyte_idl() if self.outputs is not None else None, + abort_metadata=self.abort_metadata.to_flyte_idl() if self.abort_metadata is not None else None, ) obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None)) obj.duration.FromTimedelta(self.duration) @@ -395,12 +510,16 @@ def from_flyte_idl(cls, pb2_object): outputs = None if pb2_object.HasField("outputs"): outputs = LiteralMapBlob.from_flyte_idl(pb2_object.outputs) + abort_metadata = None + if pb2_object.HasField("abort_metadata"): + abort_metadata = AbortMetadata.from_flyte_idl(pb2_object.abort_metadata) return cls( error=error, outputs=outputs, phase=pb2_object.phase, started_at=pb2_object.started_at.ToDatetime().replace(tzinfo=_pytz.UTC), duration=pb2_object.duration.ToTimedelta(), + abort_metadata=abort_metadata, ) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 1bc69ae41b..4f06c3d3c6 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -464,7 +464,9 @@ def to_literal_model(self): ) ) elif self.map: - return Literal(map=LiteralMap(literals={k: binding.to_literal_model() for k, binding in self.map.bindings})) + return Literal( + map=LiteralMap(literals={k: binding.to_literal_model() for k, binding in self.map.bindings.items()}) + ) class Binding(_common.FlyteIdlEntity): diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index f02226decc..00227f88f0 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -60,8 +60,15 @@ from flytekit.remote.remote_callable import RemoteEntity from flytekit.remote.task import FlyteTask from flytekit.remote.workflow import FlyteWorkflow +from flytekit.tools.fast_registration import fast_package from flytekit.tools.script_mode import fast_register_single_script, hash_file -from flytekit.tools.translator import FlyteLocalEntity, Options, get_serializable, get_serializable_launch_plan +from 
flytekit.tools.translator import ( + FlyteControlPlaneEntity, + FlyteLocalEntity, + Options, + get_serializable, + get_serializable_launch_plan, +) ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] @@ -364,6 +371,91 @@ def _resolve_identifier(self, t: int, name: str, version: str, ss: Serialization ) return ident + def raw_register( + self, + cp_entity: FlyteControlPlaneEntity, + settings: typing.Optional[SerializationSettings], + version: str, + create_default_launchplan: bool = True, + options: Options = None, + og_entity: FlyteLocalEntity = None, + ) -> typing.Optional[Identifier]: + """ + Raw register method, can be used to register control plane entities. Usually if you have a Flyte Entity like a + WorkflowBase, Task, LaunchPlan then use other methods. This should be used only if you have already serialized entities + + :param cp_entity: The controlplane "serializable" version of a flyte entity. This is in the form that FlyteAdmin + understands. + :param settings: SerializationSettings to be used for registration - especially to identify the id + :param version: Version to be registered + :param create_default_launchplan: boolean that indicates if a default launch plan should be created + :param options: Options to be used if registering a default launch plan + :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true + :return: Identifier of the created entity + """ + if isinstance( + cp_entity, + ( + workflow_model.Node, + workflow_model.WorkflowNode, + workflow_model.BranchNode, + workflow_model.TaskNode, + ), + ): + remote_logger.debug("Ignoring nodes for registration.") + return None + + elif isinstance(cp_entity, ReferenceSpec): + remote_logger.debug(f"Skipping registration of Reference entity, name: {cp_entity.template.id.name}") + return None + + if isinstance(cp_entity, task_models.TaskSpec): + ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings) + try: + self.client.create_task(task_identifer=ident, task_spec=cp_entity) + except FlyteEntityAlreadyExistsException: + remote_logger.info(f" {ident} Already Exists!") + return ident + + if isinstance(cp_entity, admin_workflow_models.WorkflowSpec): + ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings) + try: + self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) + except FlyteEntityAlreadyExistsException: + remote_logger.info(f" {ident} Already Exists!") + + if create_default_launchplan: + if not og_entity: + raise user_exceptions.FlyteValueException( + "To create default launch plan, please pass in the original flytekit workflow `og_entity`" + ) + + # Let us also create a default launch-plan, ideally the default launchplan should be added + # to the orderedDict, but we do not. 
+ default_lp = LaunchPlan.get_default_launch_plan(self.context, og_entity) + lp_entity = get_serializable_launch_plan( + OrderedDict(), + settings, + default_lp, + recurse_downstream=False, + options=options, + ) + try: + self.client.create_launch_plan(lp_entity.id, lp_entity.spec) + except FlyteEntityAlreadyExistsException: + remote_logger.info(f" {lp_entity.id} Already Exists!") + return ident + + if isinstance(cp_entity, launch_plan_models.LaunchPlan): + ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, cp_entity.id.name, version, settings) + try: + self.client.create_launch_plan(launch_plan_identifer=ident, launch_plan_spec=cp_entity.spec) + except FlyteEntityAlreadyExistsException: + remote_logger.info(f" {ident} Already Exists!") + return ident + + raise AssertionError(f"Unknown entity of type {type(cp_entity)}") + def _serialize_and_register( self, entity: FlyteLocalEntity, @@ -378,16 +470,16 @@ def _serialize_and_register( m = OrderedDict() # Create dummy serialization settings for now. # TODO: Clean this up by using lazy usage of serialization settings in translator.py - serialization_settings = ( - settings - if settings - else SerializationSettings( + serialization_settings = settings + is_dummy_serialization_setting = False + if not settings: + serialization_settings = SerializationSettings( ImageConfig.auto_default_image(), project=self.default_project, domain=self.default_domain, version=version, ) - ) + is_dummy_serialization_setting = True _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) ident = None @@ -395,59 +487,23 @@ def _serialize_and_register( if isinstance(entity, RemoteEntity): remote_logger.debug(f"Skipping registration of remote entity: {entity.name}") continue - if isinstance( - cp_entity, - ( - workflow_model.Node, - workflow_model.WorkflowNode, - workflow_model.BranchNode, - workflow_model.TaskNode, - ), - ): - remote_logger.debug("Ignoring nodes for registration.") - continue - elif isinstance(cp_entity, ReferenceSpec): - remote_logger.debug(f"Skipping registration of Reference entity, name: {entity.name}") - continue - if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and not settings: + if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: # Only in the case of workflows can we use the dummy serialization settings. raise user_exceptions.FlyteValueException( settings, - f"No serialization settings set, but workflow contains entities that need to be " - f"registered. Type: {type(entity)} {entity.name}", + f"No serialization settings set, but workflow contains entities that need to be registered. {cp_entity.id.name}", ) - try: - if isinstance(cp_entity, task_models.TaskSpec): - ident = self._resolve_identifier(ResourceType.TASK, entity.name, version, settings) - self.client.create_task(task_identifer=ident, task_spec=cp_entity) - elif isinstance(cp_entity, admin_workflow_models.WorkflowSpec): - ident = self._resolve_identifier(ResourceType.WORKFLOW, entity.name, version, settings) - try: - self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) - except FlyteEntityAlreadyExistsException: - remote_logger.info(f"{entity.name} already exists") - # Let us also create a default launch-plan, ideally the default launchplan should be added - # to the orderedDict, but we do not. 
- default_lp = LaunchPlan.get_default_launch_plan(self.context, entity) - lp_entity = get_serializable_launch_plan( - OrderedDict(), - settings or serialization_settings, - default_lp, - recurse_downstream=False, - options=options, - ) - self.client.create_launch_plan(lp_entity.id, lp_entity.spec) - elif isinstance(cp_entity, launch_plan_models.LaunchPlan): - ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, settings) - self.client.create_launch_plan(launch_plan_identifer=ident, launch_plan_spec=cp_entity.spec) - else: - raise AssertionError(f"Unknown entity of type {type(cp_entity)}") - except FlyteEntityAlreadyExistsException: - remote_logger.info(f"{entity.name} already exists") - except Exception as e: - remote_logger.info(f"Failed to register entity {entity.name} with error {e}") - raise + + ident = self.raw_register( + cp_entity, + settings=settings, + version=version, + create_default_launchplan=True, + options=options, + og_entity=entity, + ) + return ident def register_task( @@ -508,6 +564,21 @@ def register_workflow( fwf._python_interface = entity.python_interface return fwf + def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: str = None) -> (bytes, str): + """ + Packages the given paths into an installable zip and returns the md5_bytes and the URL of the uploaded location + :param root: path to the root of the package system that should be uploaded + :param output: output path. Optional, will default to a tempdir + :param deref_symlinks: if symlinks should be dereferenced. Defaults to True + :return: md5_bytes, url + """ + # Create a zip file containing all the entries. + zip_file = fast_package(root, output, deref_symlinks) + md5_bytes, _ = hash_file(pathlib.Path(zip_file)) + + # Upload zip file to Admin using FlyteRemote. + return self._upload_file(pathlib.Path(zip_file)) + def _upload_file( self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None ) -> typing.Tuple[bytes, str]: @@ -1473,13 +1544,15 @@ def _get_output_literal_map(self, execution_data: ExecutionDataResponse) -> lite ) return literal_models.LiteralMap({}) - def generate_http_domain(self) -> str: + def generate_console_http_domain(self) -> str: """ - This should generate the domain where the HTTP endpoints for the Flyte backend are hosted. This should be - the domain that console is hosted on. + This should generate the domain where console is hosted. :return: """ + # If the console endpoint is explicitly set, return it, else derive it from the admin config + if self.config.platform.console_endpoint: + return self.config.platform.console_endpoint protocol = "http" if self.config.platform.insecure else "https" endpoint = self.config.platform.endpoint # N.B.: this assumes that in case we have an identical configuration as the sandbox default config we are running single binary. 
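A sketch of the new fast_package entry point on FlyteRemote, assuming a reachable Flyte backend and a project under ./src (both hypothetical):

import pathlib
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
# package the source tree and upload it; returns the archive's md5 and the uploaded URL
md5_bytes, native_url = remote.fast_package(pathlib.Path("./src"), deref_symlinks=True)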
The intent here is @@ -1491,4 +1564,4 @@ def generate_http_domain(self) -> str: def generate_console_url( self, execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution] ): - return f"{self.generate_http_domain()}/console/projects/{execution.id.project}/domains/{execution.id.domain}/executions/{execution.id.name}" + return f"{self.generate_console_http_domain()}/console/projects/{execution.id.project}/domains/{execution.id.domain}/executions/{execution.id.name}" diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index 34faadc58c..b2f7efcc65 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -18,8 +18,6 @@ FAST_PREFIX = "fast" FAST_FILEENDING = ".tar.gz" -file_access = FlyteContextManager.current_context().file_access - def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: bool = False) -> os.PathLike: """ @@ -36,7 +34,7 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: b if output_dir is None: output_dir = tempfile.mkdtemp() - click.secho(f"Output given as {None}, using a temporary directory at {output_dir} instead", fg="yellow") + click.secho(f"No output path provided, using a temporary directory at {output_dir} instead", fg="yellow") archive_fname = os.path.join(output_dir, archive_fname) @@ -105,10 +103,15 @@ def download_distribution(additional_distribution: str, destination: str): :param Text additional_distribution: :param os.PathLike destination: """ - file_access.get_data(additional_distribution, destination) + if not os.path.isdir(destination): + raise ValueError("Destination path is required to download a distribution and it should be a directory") + # NOTE the os.path.join(destination, ''). This is to ensure that the given path is in fact a directory and that all + # downloaded data should be copied into this directory. We do this to account for a difference in behavior in + # fsspec, which requires a trailing slash in the case of a pre-existing directory. + FlyteContextManager.current_context().file_access.get_data(additional_distribution, os.path.join(destination, "")) tarfile_name = os.path.basename(additional_distribution) if not tarfile_name.endswith(".tar.gz"): - raise ValueError("Unrecognized additional distribution format for {}".format(additional_distribution)) + raise RuntimeError("Unrecognized additional distribution format for {}".format(additional_distribution)) # This will overwrite the existing user flyte workflow code in the current working code dir. 
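A minimal sketch of the corrected download behavior (the distribution URL is hypothetical); the destination must be an existing directory, otherwise a ValueError is raised:

import tempfile
from flytekit.tools.fast_registration import download_distribution

dest = tempfile.mkdtemp()
download_distribution("s3://my-bucket/fast/example-distribution.tar.gz", dest)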
result = _subprocess.run( diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index ceaee36435..870299e5ad 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -5,24 +5,16 @@ from pathlib import Path import click -from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan -from flyteidl.admin.launch_plan_pb2 import LaunchPlanCreateRequest -from flyteidl.admin.task_pb2 import TaskCreateRequest -from flyteidl.admin.task_pb2 import TaskSpec as _idl_admin_TaskSpec -from flyteidl.admin.workflow_pb2 import WorkflowCreateRequest -from flyteidl.admin.workflow_pb2 import WorkflowSpec as _idl_admin_WorkflowSpec -from flyteidl.core import identifier_pb2 - -from flytekit.clients.friendly import SynchronousFlyteClient -from flytekit.clis.helpers import hydrate_registration_parameters -from flytekit.configuration import SerializationSettings + +from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.core.context_manager import FlyteContextManager -from flytekit.exceptions.user import FlyteEntityAlreadyExistsException from flytekit.loggers import logger +from flytekit.models import launch_plan +from flytekit.remote import FlyteRemote from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root -from flytekit.tools.serialize_helpers import RegistrableEntity, get_registrable_entities, persist_registrable_entities -from flytekit.tools.translator import Options +from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities +from flytekit.tools.translator import FlyteControlPlaneEntity, Options class NoSerializableEntitiesError(Exception): @@ -34,7 +26,7 @@ def serialize( settings: SerializationSettings, local_source_root: typing.Optional[str] = None, options: typing.Optional[Options] = None, -) -> typing.List[RegistrableEntity]: +) -> typing.List[FlyteControlPlaneEntity]: """ See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the entity type. 
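For reference, a sketch of producing the control-plane entities that package and register now consume (the module name myapp.workflows is hypothetical):

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.tools.repo import serialize

ss = SerializationSettings(
    ImageConfig.auto_default_image(),
    project="flytesnacks",
    domain="development",
    version="v1",
)
entities = serialize(["myapp.workflows"], ss, local_source_root=".")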
@@ -71,7 +63,7 @@ def serialize_to_folder( def package( - registrable_entities: typing.List[RegistrableEntity], + serializable_entities: typing.List[FlyteControlPlaneEntity], source: str = ".", output: str = "./flyte-package.tgz", fast: bool = False, @@ -79,17 +71,17 @@ def package( ): """ Package the given entities and the source code (if fast is enabled) into a package with the given name in output - :param registrable_entities: Entities that can be serialized + :param serializable_entities: Entities that can be serialized :param source: source folder :param output: output package name with suffix :param fast: fast enabled implies source code is bundled :param deref_symlinks: if enabled then symlinks are dereferenced during packaging """ - if not registrable_entities: + if not serializable_entities: raise NoSerializableEntitiesError("Nothing to package") with tempfile.TemporaryDirectory() as output_tmpdir: - persist_registrable_entities(registrable_entities, output_tmpdir) + persist_registrable_entities(serializable_entities, output_tmpdir) # If Fast serialization is enabled, then an archive is also created and packaged if fast: @@ -103,7 +95,7 @@ def package( with tarfile.open(output, "w:gz") as tar: tar.add(output_tmpdir, arcname="") - click.secho(f"Successfully packaged {len(registrable_entities)} flyte objects into {output}", fg="green") + click.secho(f"Successfully packaged {len(serializable_entities)} flyte objects into {output}", fg="green") def serialize_and_package( @@ -118,48 +110,8 @@ def serialize_and_package( """ First serialize and then package all entities """ - registrable_entities = serialize(pkgs, settings, source, options=options) - package(registrable_entities, source, output, fast, deref_symlinks) - - -def register( - registrable_entities: typing.List[RegistrableEntity], - project: str, - domain: str, - version: str, - client: SynchronousFlyteClient, -): - # The incoming registrable entities are already in base protobuf form, not model form, so we use the - # raw client's methods instead of the friendly client's methods by calling super - for admin_entity in registrable_entities: - try: - if isinstance(admin_entity, _idl_admin_TaskSpec): - ident, task_spec = hydrate_registration_parameters( - identifier_pb2.TASK, project, domain, version, admin_entity - ) - logger.debug(f"Creating task {ident}") - super(SynchronousFlyteClient, client).create_task(TaskCreateRequest(id=ident, spec=task_spec)) - elif isinstance(admin_entity, _idl_admin_WorkflowSpec): - ident, wf_spec = hydrate_registration_parameters( - identifier_pb2.WORKFLOW, project, domain, version, admin_entity - ) - logger.debug(f"Creating workflow {ident}") - super(SynchronousFlyteClient, client).create_workflow(WorkflowCreateRequest(id=ident, spec=wf_spec)) - elif isinstance(admin_entity, _idl_admin_LaunchPlan): - ident, admin_lp = hydrate_registration_parameters( - identifier_pb2.LAUNCH_PLAN, project, domain, version, admin_entity - ) - logger.debug(f"Creating launch plan {ident}") - super(SynchronousFlyteClient, client).create_launch_plan( - LaunchPlanCreateRequest(id=ident, spec=admin_lp.spec) - ) - else: - raise AssertionError(f"Unknown entity of type {type(admin_entity)}") - except FlyteEntityAlreadyExistsException: - logger.info(f"{admin_entity} already exists") - except Exception as e: - logger.info(f"Failed to register entity {admin_entity} with error {e}") - raise e + serializable_entities = serialize(pkgs, settings, source, options=options) + package(serializable_entities, source, output, fast, 
deref_symlinks) def find_common_root( @@ -192,7 +144,7 @@ def load_packages_and_modules( project_root: Path, pkgs_or_mods: typing.List[str], options: typing.Optional[Options] = None, -) -> typing.List[RegistrableEntity]: +) -> typing.List[FlyteControlPlaneEntity]: """ The project root is added as the first entry to sys.path, and then all the specified packages and modules given are loaded with all submodules. The reason for prepending the entry is to ensure that the name that @@ -225,3 +177,68 @@ def load_packages_and_modules( registrable_entities = serialize(pkgs_and_modules, ss, str(project_root), options) return registrable_entities + + +def register( + project: str, + domain: str, + image_config: ImageConfig, + output: str, + destination_dir: str, + service_account: str, + raw_data_prefix: str, + version: typing.Optional[str], + deref_symlinks: bool, + fast: bool, + package_or_module: typing.Tuple[str], + remote: FlyteRemote, +): + detected_root = find_common_root(package_or_module) + click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") + fast_serialization_settings = None + if fast: + md5_bytes, native_url = remote.fast_package(detected_root, deref_symlinks, output) + fast_serialization_settings = FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ) + + # Create serialization settings + # Todo: Rely on default Python interpreter for now, this will break custom Spark containers + serialization_settings = SerializationSettings( + project=project, + domain=domain, + version=version, + image_config=image_config, + fast_serialization_settings=fast_serialization_settings, + ) + + if not version and fast: + version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa + click.secho(f"Computed version is {version}", fg="yellow") + elif not version: + click.secho("Version is required.", fg="red") + return + + b = serialization_settings.new_builder() + b.version = version + serialization_settings = b.build() + + options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) + + # Load all the entities + serializable_entities = load_packages_and_modules( + serialization_settings, detected_root, list(package_or_module), options + ) + if len(serializable_entities) == 0: + click.secho("No Flyte entities were detected. 
Aborting!", fg="red") + return + click.secho(f"Found and serialized {len(serializable_entities)} entities") + + for cp_entity in serializable_entities: + name = cp_entity.id.name if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id.name + click.secho(f" Registering {name}....", dim=True, nl=False) + i = remote.raw_register(cp_entity, serialization_settings, version=version, create_default_launchplan=False) + click.secho(f"done, {i.resource_type_name()} with version {i.version}.", dim=True) + click.secho(f"Successfully registered {len(serializable_entities)} entities", fg="green") diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 9fd7f05e89..f7937443be 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -5,9 +5,6 @@ from collections import OrderedDict import click -from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan -from flyteidl.admin.task_pb2 import TaskSpec as _idl_admin_TaskSpec -from flyteidl.admin.workflow_pb2 import WorkflowSpec as _idl_admin_WorkflowSpec from flytekit import LaunchPlan from flytekit.core import context_manager as flyte_context @@ -17,10 +14,10 @@ from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier -from flytekit.tools.translator import Options, get_serializable - -RegistrableEntity = typing.Union[_idl_admin_TaskSpec, _idl_admin_LaunchPlan, _idl_admin_WorkflowSpec] +from flytekit.models.task import TaskSpec +from flytekit.tools.translator import FlyteControlPlaneEntity, Options, get_serializable def _determine_text_chars(length): @@ -62,7 +59,7 @@ def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Se def get_registrable_entities( ctx: flyte_context.FlyteContext, options: typing.Optional[Options] = None -) -> typing.List[RegistrableEntity]: +) -> typing.List[FlyteControlPlaneEntity]: """ Returns all entities that can be serialized and should be sent over to Flyte backend. 
This will filter any entities that are not known to Admin
@@ -94,10 +91,10 @@ def get_registrable_entities(
             f"Multiple definitions of the following tasks were found: {duplicate_task_names}"
         )

-    return [v.to_flyte_idl() for v in entities_to_be_serialized]
+    return entities_to_be_serialized


-def persist_registrable_entities(entities: typing.List[RegistrableEntity], folder: str):
+def persist_registrable_entities(entities: typing.List[FlyteControlPlaneEntity], folder: str):
     """
     For protobuf serializable list of entities, writes a file with the name of the entity and enumeration order to
     the specified folder
@@ -113,13 +110,13 @@ def persist_registrable_entities(entities: typing.List[RegistrableEntity], folde
     zero_padded_length = _determine_text_chars(len(entities))
     for i, entity in enumerate(entities):
         fname_index = str(i).zfill(zero_padded_length)
-        if isinstance(entity, _idl_admin_TaskSpec):
+        if isinstance(entity, TaskSpec):
             name = entity.template.id.name
             fname = "{}_{}_1.pb".format(fname_index, entity.template.id.name)
-        elif isinstance(entity, _idl_admin_WorkflowSpec):
+        elif isinstance(entity, WorkflowSpec):
             name = entity.template.id.name
             fname = "{}_{}_2.pb".format(fname_index, entity.template.id.name)
-        elif isinstance(entity, _idl_admin_LaunchPlan):
+        elif isinstance(entity, _launch_plan_models.LaunchPlan):
             name = entity.id.name
             fname = "{}_{}_3.pb".format(fname_index, entity.id.name)
         else:
@@ -128,4 +125,4 @@ def persist_registrable_entities(entities: typing.List[RegistrableEntity], folde
         click.secho(f" Packaging {name} -> {fname}", dim=True)
         fname = _os.path.join(folder, fname)
         with open(fname, "wb") as writer:
-            writer.write(entity.SerializeToString())
+            writer.write(entity.serialize_to_string())
diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py
index 6aa193d7ce..c2ab8fd438 100644
--- a/flytekit/types/directory/__init__.py
+++ b/flytekit/types/directory/__init__.py
@@ -11,6 +11,7 @@

     FlyteDirectory
     TensorboardLogs
+    TFRecordsDirectory
"""
 import typing
@@ -26,3 +27,11 @@
     This is usually the SummaryWriter output in PyTorch or Keras callbacks
     which record the history readable by TensorBoard.
"""
+
+tfrecords_dir = typing.TypeVar("tfrecords_dir")
+TFRecordsDirectory = FlyteDirectory[tfrecords_dir]
+"""
+    This type can be used to denote that the output is a folder that contains TensorFlow record files.
+    This is usually the TFRecordWriter output in TensorFlow, which writes serialized tf.train.Example
+    messages (protobufs) to tfrecord files.
+"""
diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py
index b08323cd1e..afb59d58d0 100644
--- a/flytekit/types/directory/types.py
+++ b/flytekit/types/directory/types.py
@@ -15,8 +15,10 @@
 from flytekit.models.core import types as _core_types
 from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
 from flytekit.models.types import LiteralType
+from flytekit.types.file import FileExt

 T = typing.TypeVar("T")
+PathType = typing.Union[str, os.PathLike]


 def noop():
@@ -26,7 +28,7 @@ def noop():
 @dataclass_json
 @dataclass
 class FlyteDirectory(os.PathLike, typing.Generic[T]):
-    path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
+    path: PathType = field(default=None, metadata=config(mm_field=fields.String()))  # type: ignore
    """
    ..
warning::
@@ -144,7 +146,9 @@ def extension(cls) -> str:
     def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]:
         if item is None:
             return cls
-        item_string = str(item)
+
+        item_string = FileExt.check_and_convert_to_str(item)
+        item_string = item_string.strip().lstrip("~").lstrip(".")
         if item_string == "":
             return cls
diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py
index 9e8fca1971..871c48d4c6 100644
--- a/flytekit/types/file/__init__.py
+++ b/flytekit/types/file/__init__.py
@@ -20,63 +20,97 @@
     PythonNotebook
     SVGImageFile
"""
-
 import typing

+from typing_extensions import Annotated, get_args, get_origin
+
 from .file import FlyteFile

+
+class FileExt:
+    """
+    Used for annotating file extension types of FlyteFile.
+    This is useful for extensions that have periods in them, e.g., "tar.gz".
+
+    Example:
+        TAR_GZ = Annotated[str, FileExt("tar.gz")]
+    """
+
+    def __init__(self, ext: str):
+        self._ext = ext
+
+    def __str__(self):
+        return self._ext
+
+    def __repr__(self):
+        return self._ext
+
+    @staticmethod
+    def check_and_convert_to_str(item: typing.Union[typing.Type, str]) -> str:
+        if not get_origin(item) is Annotated:
+            return str(item)
+        if get_args(item)[0] == str:
+            return str(get_args(item)[1])
+        raise ValueError("Underlying type of File Extension must be of type <str>")
+
+
 # The following section provides some predefined aliases for commonly used FlyteFile formats.
 # This makes their usage extremely simple for the users. Please keep the list sorted.

-hdf5 = typing.TypeVar("hdf5")
+hdf5 = Annotated[str, FileExt("hdf5")]
 #: This can be used to denote that the returned file is of type hdf5 and can be received by other tasks that
 #: accept an hdf5 format. This is usually useful for serializing Tensorflow models
 HDF5EncodedFile = FlyteFile[hdf5]

-html = typing.TypeVar("html")
+html = Annotated[str, FileExt("html")]
 #: Can be used to receive or return an HTMLPage. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 HTMLPage = FlyteFile[html]

-joblib = typing.TypeVar("joblib")
+joblib = Annotated[str, FileExt("joblib")]
 #: This File represents a file that was serialized using the `joblib.dump` method and can be loaded back using `joblib.load`.
 JoblibSerializedFile = FlyteFile[joblib]

-jpeg = typing.TypeVar("jpeg")
+jpeg = Annotated[str, FileExt("jpeg")]
 #: Can be used to receive or return a JPEGImage. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 JPEGImageFile = FlyteFile[jpeg]

-pdf = typing.TypeVar("pdf")
+pdf = Annotated[str, FileExt("pdf")]
 #: Can be used to receive or return a PDFFile. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 PDFFile = FlyteFile[pdf]

-png = typing.TypeVar("png")
+png = Annotated[str, FileExt("png")]
 #: Can be used to receive or return a PNGImage. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 PNGImageFile = FlyteFile[png]

-python_pickle = typing.TypeVar("python_pickle")
+python_pickle = Annotated[str, FileExt("python_pickle")]
 #: This type can be used when a serialized Python pickled object is returned and shared between tasks.
This only
 #: adds metadata to the file in Flyte, but does not really carry any object information.
 PythonPickledFile = FlyteFile[python_pickle]

-ipynb = typing.TypeVar("ipynb")
+ipynb = Annotated[str, FileExt("ipynb")]
 #: This type is used to identify a Python notebook file.
 PythonNotebook = FlyteFile[ipynb]

-svg = typing.TypeVar("svg")
+svg = Annotated[str, FileExt("svg")]
 #: Can be used to receive or return an SVGImage. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 SVGImageFile = FlyteFile[svg]

-csv = typing.TypeVar("csv")
+csv = Annotated[str, FileExt("csv")]
 #: Can be used to receive or return a CSVFile. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 CSVFile = FlyteFile[csv]

-onnx = typing.TypeVar("onnx")
+onnx = Annotated[str, FileExt("onnx")]
 #: Can be used to receive or return an ONNXFile. The underlying type is a FlyteFile type. This is just a
 #: decoration and useful for attaching content type information with the file and automatically documenting code.
 ONNXFile = FlyteFile[onnx]
+
+tfrecords_file = Annotated[str, FileExt("tfrecord")]
+#: Can be used to receive or return a TFRecordFile. The underlying type is a FlyteFile type. This is just a
+#: decoration and useful for attaching content type information with the file and automatically documenting code.
+TFRecordFile = FlyteFile[tfrecords_file]
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 744f56de6b..9fc55f76ce 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -27,7 +27,7 @@ def noop():
 @dataclass_json
 @dataclass
 class FlyteFile(os.PathLike, typing.Generic[T]):
-    path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
+    path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))  # type: ignore
     """
     Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
     exists for Flyte's Integer type) we need to create one so that users can express that their tasks take
@@ -149,9 +149,13 @@ def extension(cls) -> str:
         return ""

     def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]:
+        from .
import FileExt + if item is None: return cls - item_string = str(item) + + item_string = FileExt.check_and_convert_to_str(item) + item_string = item_string.strip().lstrip("~").lstrip(".") if item == "": return cls @@ -167,7 +171,10 @@ def extension(cls) -> str: return _SpecificFormatClass def __init__( - self, path: typing.Union[str, os.PathLike], downloader: typing.Callable = noop, remote_path: os.PathLike = None + self, + path: typing.Union[str, os.PathLike], + downloader: typing.Callable = noop, + remote_path: typing.Optional[os.PathLike] = None, ): """ :param path: The source path that users are expected to call open() on @@ -205,7 +212,7 @@ def downloaded(self) -> bool: return self._downloaded @property - def remote_path(self) -> os.PathLike: + def remote_path(self) -> typing.Optional[os.PathLike]: return self._remote_path @property diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index b4f67b94f1..38fedfacca 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -1,8 +1,10 @@ import pathlib import typing -from typing import Type +from collections import OrderedDict +from typing import Dict, Tuple, Type import numpy as np +from typing_extensions import Annotated, get_args, get_origin from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError @@ -11,6 +13,17 @@ from flytekit.models.types import LiteralType +def extract_metadata(t: Type[np.ndarray]) -> Tuple[Type[np.ndarray], Dict[str, bool]]: + metadata = {} + if get_origin(t) is Annotated: + base_type, metadata = get_args(t) + if isinstance(metadata, OrderedDict): + return base_type, metadata + else: + raise TypeTransformerFailedError(f"{t}'s metadata needs to be of type kwtypes.") + return t, metadata + + class NumpyArrayTransformer(TypeTransformer[np.ndarray]): """ TypeTransformer that supports np.ndarray as a native type. 
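Since `extract_metadata` reads per-signature options from `Annotated` types, a task can opt into pickling or memory-mapping without touching the transformer. A minimal sketch, assuming `kwtypes` (which builds the `OrderedDict` the helper checks for) and illustrative task names:

```python
import numpy as np
from typing_extensions import Annotated

from flytekit import kwtypes, task


@task
def make_object_array() -> Annotated[np.ndarray, kwtypes(allow_pickle=True)]:
    # dtype=object arrays can only be written by np.save with allow_pickle=True
    return np.array([{"a": 1}, {"b": 2}])


@task
def first_value(arr: Annotated[np.ndarray, kwtypes(mmap_mode="r")]) -> float:
    # the transformer forwards mmap_mode to np.load, memory-mapping the array
    return float(arr[0])
```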
@@ -31,6 +44,8 @@ def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType: def to_literal( self, ctx: FlyteContext, python_val: np.ndarray, python_type: Type[np.ndarray], expected: LiteralType ) -> Literal: + python_type, metadata = extract_metadata(python_type) + meta = BlobMetadata( type=_core_types.BlobType( format=self.NUMPY_ARRAY_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE @@ -40,9 +55,8 @@ def to_literal( local_path = ctx.file_access.get_random_local_path() + ".npy" pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - # save numpy array to a file - # allow_pickle=False prevents numpy from trying to save object arrays (dtype=object) using pickle - np.save(file=local_path, arr=python_val, allow_pickle=False) + # save numpy array to file + np.save(file=local_path, arr=python_val, allow_pickle=metadata.get("allow_pickle", False)) remote_path = ctx.file_access.get_random_remote_path(local_path) ctx.file_access.put_data(local_path, remote_path, is_multipart=False) @@ -54,11 +68,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: except AttributeError: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + expected_python_type, metadata = extract_metadata(expected_python_type) + local_path = ctx.file_access.get_random_local_path() ctx.file_access.get_data(uri, local_path, is_multipart=False) # load numpy array from a file - return np.load(file=local_path) + return np.load( + file=local_path, + allow_pickle=metadata.get("allow_pickle", False), + mmap_mode=metadata.get("mmap_mode"), + ) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: if ( diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 6f01cea085..f486d1012e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -16,6 +16,7 @@ from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.loggers import logger from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -361,7 +362,11 @@ def to_literal( remote_path = python_val.remote_path if remote_path is None or remote_path == "": remote_path = ctx.file_access.get_random_remote_path() - ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True) + if python_val.supported_mode == SchemaOpenMode.READ and not python_val._downloaded: + # This means the local path is empty. 
Don't try to overwrite the remote data + logger.debug(f"Skipping upload for {python_val} because it was never downloaded.") + else: + ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) schema = python_type( diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index 0d3bc8cd78..e4c6078e94 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -81,7 +81,6 @@ class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]): def __init__(self): super().__init__("PandasDataFrame<->GenericSchema", pandas.DataFrame) self._parquet_engine = ParquetIO() - self._hash_overridable = True @staticmethod def _get_schema_type() -> SchemaType: diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 71dff61c5e..39f8d11e24 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -101,10 +101,10 @@ def decode( return pq.read_table(local_dir) -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler()) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler()) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler()) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler()) +StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) +StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index b4020a2022..f0fd917340 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -34,6 +34,7 @@ # Storage formats PARQUET: StructuredDatasetFormat = "parquet" +GENERIC_FORMAT: StructuredDatasetFormat = "" @dataclass_json @@ -45,9 +46,7 @@ class (that is just a model, a Python class representation of the protobuf). 
""" uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) - file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String())) - - DEFAULT_FILE_FORMAT = PARQUET + file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -60,7 +59,7 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: Optional[str] = None, + uri: Optional[str, os.PathLike] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): @@ -68,15 +67,17 @@ def __init__( # Make these fields public, so that the dataclass transformer can set a value for it # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 self.uri = uri + # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string + self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT # This is a special attribute that indicates if the data was either downloaded or uploaded self._metadata = metadata # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type = None + self._dataframe_type: Optional[Type[DF]] = None @property - def dataframe(self) -> Type[typing.Any]: + def dataframe(self) -> Optional[Type[DF]]: return self._dataframe @property @@ -336,20 +337,42 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): @classmethod def _finder(cls, handler_map, df_type: Type, protocol: str, format: str): - try: - return handler_map[df_type][protocol][format] - except KeyError: + # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler + # if missing, see if there's a generic format handler. Error if missing. + # If the incoming format requested is the generic format (""), then see if it's present, + # if not, look to see if there is a default format for the df_type and a handler for that format. + # if still missing, look to see if there's only _one_ handler for that type, if so then use that. + if format != GENERIC_FORMAT: try: - default_format = cls.DEFAULT_FORMATS.get(df_type, PARQUET) - hh = handler_map[df_type][protocol][default_format] - logger.info( - f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" - f" format {format}, using default {default_format} instead." - ) - return hh + return handler_map[df_type][protocol][format] + except KeyError: + try: + return handler_map[df_type][protocol][GENERIC_FORMAT] + except KeyError: + ... + else: + try: + return handler_map[df_type][protocol][GENERIC_FORMAT] except KeyError: - ... - raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}") + if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]: + hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]] + logger.debug( + f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" + f" using the generic handler {hh} instead." 
+ ) + return hh + if len(handler_map[df_type][protocol]) == 1: + hh = list(handler_map[df_type][protocol].values())[0] + logger.debug( + f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}" + ) + return hh + else: + logger.warning( + f"Did not automatically pick a handler for {df_type}," + f" more than one detected {handler_map[df_type][protocol].keys()}" + ) + raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") @classmethod def get_encoder(cls, df_type: Type, protocol: str, format: str): @@ -364,28 +387,32 @@ def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]: if isinstance(h, StructuredDatasetEncoder): top_level = cls.ENCODERS elif isinstance(h, StructuredDatasetDecoder): - top_level = cls.DECODERS + top_level = cls.DECODERS # type: ignore else: raise TypeError(f"We don't support this type of handler {h}") if h.python_type not in top_level: top_level[h.python_type] = {} if protocol not in top_level[h.python_type]: top_level[h.python_type][protocol] = {} - return top_level[h.python_type][protocol] + return top_level[h.python_type][protocol] # type: ignore def __init__(self): super().__init__("StructuredDataset Transformer", StructuredDataset) self._type_assertions_enabled = False - # Instances of StructuredDataset opt-in to the ability of being cached. - self._hash_overridable = True - @classmethod def register_renderer(cls, python_type: Type, renderer: Renderable): cls.Renderers[python_type] = renderer @classmethod - def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False): + def register( + cls, + h: Handlers, + default_for_type: bool = False, + override: bool = False, + default_format_for_type: bool = False, + default_storage_for_type: bool = False, + ): """ Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not specify a protocol (e.g. s3, gs, etc.) field, then @@ -399,6 +426,10 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid In these cases, the protocol is determined by the raw output data prefix set in the active context. :param override: Override any previous registrations. If default_for_type is also set, this will also override the default. + :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage + as the default, this will only set the format. Error if already set, unless override is specified. + :param default_storage_for_type: Same as above but only for the storage format. Error if already set, + unless override is specified. 
""" if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)): raise TypeError(f"We don't support this type of handler {h}") @@ -413,17 +444,29 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid stripped = DataPersistencePlugins.get_protocol(persistence_protocol) logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") try: - cls.register_for_protocol(h, stripped, False, override) + cls.register_for_protocol( + h, stripped, False, override, default_format_for_type, default_storage_for_type + ) except DuplicateHandlerError: logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") elif h.protocol == "": raise ValueError(f"Use None instead of empty string for registering handler {h}") else: - cls.register_for_protocol(h, h.protocol, default_for_type, override) + cls.register_for_protocol( + h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type + ) @classmethod - def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool): + def register_for_protocol( + cls, + h: Handlers, + protocol: str, + default_for_type: bool, + override: bool, + default_format_for_type: bool, + default_storage_for_type: bool, + ): """ See the main register function instead. """ @@ -438,12 +481,24 @@ def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: boo lowest_level[h.supported_format] = h logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}") - if default_for_type: - logger.debug( - f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}" - ) - cls.DEFAULT_FORMATS[h.python_type] = h.supported_format - cls.DEFAULT_PROTOCOLS[h.python_type] = protocol + if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT: + if h.python_type in cls.DEFAULT_FORMATS and not override: + logger.warning( + f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified." + ) + else: + logger.debug( + f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}" + ) + cls.DEFAULT_FORMATS[h.python_type] = h.supported_format + if default_storage_for_type or default_for_type: + if h.protocol in cls.DEFAULT_PROTOCOLS and not override: + logger.warning( + f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified." + ) + else: + logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}") + cls.DEFAULT_PROTOCOLS[h.python_type] = protocol # Register with the type engine as well # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as @@ -465,7 +520,7 @@ def to_literal( # Check first to see if it's even an SD type. 
For backwards compatibility, we may be getting a FlyteSchema python_type, *attrs = extract_cols_and_format(python_type) # In case it's a FlyteSchema - sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None)) + sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) if expected and expected.structured_dataset_type: sdt = StructuredDatasetType( @@ -569,7 +624,9 @@ def encode( sd_model.metadata._structured_dataset_type.format = handler.supported_format return Literal(scalar=Scalar(structured_dataset=sd_model)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset + ) -> T | StructuredDataset: """ The only tricky thing with converting a Literal (say the output of an earlier task), to a Python value at the start of a task execution, is the column subsetting behavior. For example, if you have, @@ -758,7 +815,6 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any] # Get the column information converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) - return StructuredDatasetType( columns=converted_cols, format=storage_format, @@ -776,7 +832,7 @@ def get_literal_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) """ return LiteralType(structured_dataset_type=self._get_dataset_type(t)) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[StructuredDataset]: # todo: technically we should return the dataframe type specified in the constructor, but to do that, # we'd have to store that, which we don't do today. See possibly #1363 if literal_type.structured_dataset_type is not None: diff --git a/flytekit_scripts/flytekit_build_image.sh b/flytekit_scripts/flytekit_build_image.sh index 12d6ced04b..8ddfa54358 100755 --- a/flytekit_scripts/flytekit_build_image.sh +++ b/flytekit_scripts/flytekit_build_image.sh @@ -61,10 +61,20 @@ if [ -n "$REGISTRY" ]; then else FLYTE_INTERNAL_IMAGE=${IMAGE_NAME}:${PREFIX}${TAG} fi -echo "Building: $FLYTE_INTERNAL_IMAGE" + +DOCKER_PLATFORM_OPT=() +# Check if the user set the target build architecture, if not use the default instead. +if [ -n "$TARGET_PLATFORM_BUILD" ]; then + DOCKER_PLATFORM_OPT=(--platform "$TARGET_PLATFORM_BUILD") +else + TARGET_PLATFORM_BUILD="default" +fi + +echo "Building: $FLYTE_INTERNAL_IMAGE for $TARGET_PLATFORM_BUILD architecture" # This build command is the raison d'etre of this script, it ensures that the version is injected into the image itself -docker build . --build-arg tag="$FLYTE_INTERNAL_IMAGE" -t "$FLYTE_INTERNAL_IMAGE" -f "${DOCKERFILE_PATH}" +docker build . "${DOCKER_PLATFORM_OPT[@]}" --build-arg tag="${FLYTE_INTERNAL_IMAGE}" -t "${FLYTE_INTERNAL_IMAGE}" -f "${DOCKERFILE_PATH}" + echo "$IMAGE_NAME built locally." # Create the appropriate tags diff --git a/plugins/README.md b/plugins/README.md index 4986e3445f..447b91a37c 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -4,23 +4,24 @@ All the Flytekit plugins maintained by the core team are added here. 
It is not n ## Currently Available Plugins 🔌 -| Plugin | Installation | Description | Version | Type | -|------------------------------|------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------|---------------| -| AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend | -| Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | -| K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | -| K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | -| K8s native MPI Jobs | ```bash pip install flytekitplugins-kfmpi ``` | Installs SDK to author Distributed MPI Jobs in python using Kubeflow MPI Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | -| Papermill based Tasks | ```bash pip install flytekitplugins-papermill ``` | Execute entire notebooks as Flyte Tasks and pass inputs and outputs between them and python tasks | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-papermill/) | Flytekit-only | -| Pod Tasks | ```bash pip install flytekitplugins-pod ``` | Installs SDK to author Pods in python. 
These pods can have multiple containers, use volumes and have non exiting side-cars | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-pod/) | Flytekit-only | -| spark | ```bash pip install flytekitplugins-spark ``` | Installs SDK to author Spark jobs that can be executed natively on Kubernetes with a supported backend Flyte plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-spark/) | Backend | -| AWS Athena Queries | ```bash pip install flytekitplugins-athena ``` | Installs SDK to author queries executed on AWS Athena | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-athena/) | Backend | -| DOLT | ```bash pip install flytekitplugins-dolt ``` | Read & write dolt data sets and use dolt tables as native types | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-dolt/) | Flytekit-only | -| Pandera | ```bash pip install flytekitplugins-pandera ``` | Use Pandera schemas as native Flyte types, which enable data quality checks. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-pandera/) | Flytekit-only | -| SQLAlchemy | ```bash pip install flytekitplugins-sqlalchemy ``` | Write queries for any database that supports SQLAlchemy | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-sqlalchemy/) | Flytekit-only | -| Great Expectations | ```bash pip install flytekitplugins-great-expectations``` | Enforce data quality for various data types within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-great-expectations.svg)](https://pypi.python.org/pypi/flytekitplugins-great-expectations/) | Flytekit-only | -| Snowflake | ```bash pip install flytekitplugins-snowflake``` | Use Snowflake as a 'data warehouse-as-a-service' within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-great-expectations.svg)](https://pypi.python.org/pypi/flytekitplugins-great-expectations/) | Backend | - +| Plugin | Installation | Description | Version | Type | +|------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| +| AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend | +| Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | +| K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch 
Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | +| K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | +| K8s native MPI Jobs | ```bash pip install flytekitplugins-kfmpi ``` | Installs SDK to author Distributed MPI Jobs in python using Kubeflow MPI Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfmpi.svg)](https://pypi.python.org/pypi/flytekitplugins-kfmpi/) | Backend | +| Papermill based Tasks | ```bash pip install flytekitplugins-papermill ``` | Execute entire notebooks as Flyte Tasks and pass inputs and outputs between them and python tasks | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-papermill.svg)](https://pypi.python.org/pypi/flytekitplugins-papermill/) | Flytekit-only | +| Pod Tasks | ```bash pip install flytekitplugins-pod ``` | Installs SDK to author Pods in python. These pods can have multiple containers, use volumes and have non exiting side-cars | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pod.svg)](https://pypi.python.org/pypi/flytekitplugins-pod/) | Flytekit-only | +| spark | ```bash pip install flytekitplugins-spark ``` | Installs SDK to author Spark jobs that can be executed natively on Kubernetes with a supported backend Flyte plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-spark/) | Backend | +| AWS Athena Queries | ```bash pip install flytekitplugins-athena ``` | Installs SDK to author queries executed on AWS Athena | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-athena.svg)](https://pypi.python.org/pypi/flytekitplugins-athena/) | Backend | +| DOLT | ```bash pip install flytekitplugins-dolt ``` | Read & write dolt data sets and use dolt tables as native types | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dolt.svg)](https://pypi.python.org/pypi/flytekitplugins-dolt/) | Flytekit-only | +| Pandera | ```bash pip install flytekitplugins-pandera ``` | Use Pandera schemas as native Flyte types, which enable data quality checks. 
| [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pandera.svg)](https://pypi.python.org/pypi/flytekitplugins-pandera/) | Flytekit-only |
+| SQLAlchemy | ```bash pip install flytekitplugins-sqlalchemy ``` | Write queries for any database that supports SQLAlchemy | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-sqlalchemy.svg)](https://pypi.python.org/pypi/flytekitplugins-sqlalchemy/) | Flytekit-only |
+| Great Expectations | ```bash pip install flytekitplugins-great-expectations``` | Enforce data quality for various data types within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-great-expectations.svg)](https://pypi.python.org/pypi/flytekitplugins-great-expectations/) | Flytekit-only |
+| Snowflake | ```bash pip install flytekitplugins-snowflake``` | Use Snowflake as a 'data warehouse-as-a-service' within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-snowflake.svg)](https://pypi.python.org/pypi/flytekitplugins-snowflake/) | Backend |
+| dbt | ```bash pip install flytekitplugins-dbt``` | Run dbt within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dbt.svg)](https://pypi.python.org/pypi/flytekitplugins-dbt/) | Flytekit-only |
+| Huggingface | ```bash pip install flytekitplugins-huggingface``` | Read & write Huggingface Datasets as Flyte StructuredDatasets | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-huggingface.svg)](https://pypi.python.org/pypi/flytekitplugins-huggingface/) | Flytekit-only |

 ## Have a Plugin Idea? 💡
 Please [file an issue](https://github.com/flyteorg/flyte/issues/new?assignees=&labels=untriaged%2Cplugins&template=backend-plugin-request.md&title=%5BPlugin%5D).
@@ -71,7 +72,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 # microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}"

 # TODO add additional requirements
-plugin_requires = ["flytekit>=0.21.3,<1.0.0", ""]
+plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", ""]

 __version__ = "0.0.0+develop"
diff --git a/plugins/flytekit-aws-athena/requirements.txt b/plugins/flytekit-aws-athena/requirements.txt
index 41c8b68f54..0659434eb4 100644
--- a/plugins/flytekit-aws-athena/requirements.txt
+++ b/plugins/flytekit-aws-athena/requirements.txt
@@ -103,7 +103,7 @@ packaging==21.3
     # via marshmallow
 pandas==1.3.5
     # via flytekit
-protobuf==3.20.1
+protobuf==3.20.2
     # via
     #   flyteidl
     #   flytekit
@@ -184,7 +184,7 @@ urllib3==1.26.9
     #   responses
 websocket-client==1.3.3
     # via docker
-wheel==0.37.1
+wheel==0.38.0
     # via flytekit
 wrapt==1.14.1
     # via
diff --git a/plugins/flytekit-aws-batch/requirements.txt b/plugins/flytekit-aws-batch/requirements.txt
index 14947ba486..b73cef90da 100644
--- a/plugins/flytekit-aws-batch/requirements.txt
+++ b/plugins/flytekit-aws-batch/requirements.txt
@@ -103,7 +103,7 @@ packaging==21.3
     # via marshmallow
 pandas==1.3.5
     # via flytekit
-protobuf==3.20.1
+protobuf==3.20.2
     # via
     #   flyteidl
     #   flytekit
@@ -184,7 +184,7 @@ urllib3==1.26.9
     #   responses
 websocket-client==1.3.3
     # via docker
-wheel==0.37.1
+wheel==0.38.0
     # via flytekit
 wrapt==1.14.1
     # via
diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py
index 7e6c0726ce..9951758a42 100644
--- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py
+++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py
@@ -1,9 +1,10 @@
 import typing
 from dataclasses import dataclass
-from typing import Any,
Callable, Dict, TypeVar +from typing import Any, Callable, Dict from flytekitplugins.awssagemaker.distributed_training import DistributedTrainingContext from google.protobuf.json_format import MessageToDict +from typing_extensions import Annotated import flytekit from flytekit import ExecutionParameters, FlyteContextManager, PythonFunctionTask, kwtypes @@ -11,7 +12,7 @@ from flytekit.extend import ExecutionState, IgnoreOutputs, Interface, PythonTask, TaskPlugins from flytekit.loggers import logger from flytekit.types.directory.types import FlyteDirectory -from flytekit.types.file import FlyteFile +from flytekit.types.file import FileExt, FlyteFile from .models import training_job as _training_job_models @@ -48,7 +49,7 @@ class SagemakerBuiltinAlgorithmsTask(PythonTask[SagemakerTrainingJobConfig]): _SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task" - OUTPUT_TYPE = TypeVar("tar.gz") + OUTPUT_TYPE = Annotated[str, FileExt("tar.gz")] def __init__( self, @@ -70,7 +71,9 @@ def __init__( ): raise ValueError("TaskConfig, algorithm_specification, training_job_resource_config are required") - input_type = TypeVar(self._content_type_to_blob_format(task_config.algorithm_specification.input_content_type)) + input_type = Annotated[ + str, FileExt(self._content_type_to_blob_format(task_config.algorithm_specification.input_content_type)) + ] interface = Interface( # TODO change train and validation to be FlyteDirectory when available diff --git a/plugins/flytekit-aws-sagemaker/requirements.txt b/plugins/flytekit-aws-sagemaker/requirements.txt index a4806c45c3..82be0fa00b 100644 --- a/plugins/flytekit-aws-sagemaker/requirements.txt +++ b/plugins/flytekit-aws-sagemaker/requirements.txt @@ -129,7 +129,7 @@ pandas==1.3.5 # via flytekit paramiko==2.11.0 # via sagemaker-training -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -230,7 +230,7 @@ websocket-client==1.3.3 # via docker werkzeug==2.1.2 # via sagemaker-training -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-bigquery/requirements.txt b/plugins/flytekit-bigquery/requirements.txt index 43f6e6c8ba..6ae47dbecf 100644 --- a/plugins/flytekit-bigquery/requirements.txt +++ b/plugins/flytekit-bigquery/requirements.txt @@ -135,7 +135,7 @@ proto-plus==1.20.6 # via # google-cloud-bigquery # google-cloud-bigquery-storage -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -234,7 +234,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py index 6fed3cd488..4fe1b22baa 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py @@ -115,7 +115,12 @@ def put(self, from_path: str, to_path: str, recursive: bool = False): from fsspec.utils import other_paths lfs = LocalFileSystem() - lpaths = lfs.expand_path(from_path, recursive=recursive) + try: + lpaths = lfs.expand_path(from_path, recursive=recursive) + except FileNotFoundError: + # In some cases, there is no file in the original directory, so we just skip copying the file to the remote path + logger.debug(f"there is no file in the {from_path}") + return rpaths = other_paths(lpaths, to_path) for l, r in zip(lpaths, rpaths): fs.put_file(l, r) diff --git a/plugins/flytekit-data-fsspec/requirements.txt 
b/plugins/flytekit-data-fsspec/requirements.txt index b29beebd95..86d5df921f 100644 --- a/plugins/flytekit-data-fsspec/requirements.txt +++ b/plugins/flytekit-data-fsspec/requirements.txt @@ -109,7 +109,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -192,7 +192,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-dbt/README.md b/plugins/flytekit-dbt/README.md new file mode 100644 index 0000000000..bfea79849e --- /dev/null +++ b/plugins/flytekit-dbt/README.md @@ -0,0 +1,15 @@ +# Flytekit dbt plugin + +Flytekit plugin for performing DBT tasks. Currently it supports `dbt run` , `dbt test`, `dbt source freshness` tasks. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-dbt +``` + +_Example coming soon!_ + +## Contributors + +- [Gojek](https://www.gojek.io/) diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/__init__.py b/plugins/flytekit-dbt/flytekitplugins/dbt/__init__.py new file mode 100644 index 0000000000..437607657f --- /dev/null +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/__init__.py @@ -0,0 +1,21 @@ +""" +.. currentmodule:: flytekitplugins.dbt + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + DBTRun + DBTTest + BaseDBTInput + BaseDBTOutput + DBTRunInput + DBTRunOutput + DBTTestInput + DBTTestOutput +""" + +from .schema import BaseDBTInput, BaseDBTOutput, DBTRunInput, DBTRunOutput, DBTTestInput, DBTTestOutput +from .task import DBTRun, DBTTest diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/error.py b/plugins/flytekit-dbt/flytekitplugins/dbt/error.py new file mode 100644 index 0000000000..a617b6e505 --- /dev/null +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/error.py @@ -0,0 +1,49 @@ +from typing import List + + +class DBTHandledError(Exception): + """ + DBTHandledError wraps error logs and message from command execution that returns ``exit code 1``. + + Parameters + ---------- + message : str + Error message. + logs : list of str + Logs produced by the command execution. + + Attributes + ---------- + message : str + Error message. + logs : list of str + Logs produced by the command execution. + """ + + def __init__(self, message: str, logs: List[str]): + self.logs = logs + self.message = message + + +class DBTUnhandledError(Exception): + """ + DBTUnhandledError wraps error logs and message from command execution that returns ``exit code 2``. + + Parameters + ---------- + message : str + Error message. + logs : list of str + Logs produced by the command execution. + + Attributes + ---------- + message : str + Error message. + logs : list of str + Logs produced by the command execution. + """ + + def __init__(self, message: str, logs: List[str]): + self.logs = logs + self.message = message diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py new file mode 100644 index 0000000000..3634118b38 --- /dev/null +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py @@ -0,0 +1,258 @@ +import json +from dataclasses import dataclass +from typing import List, Optional + +from dataclasses_json import dataclass_json + + +@dataclass_json +@dataclass +class BaseDBTInput: + """ + Base class for DBT Task Input. 
+
+    Attributes
+    ----------
+    project_dir : str
+        Path to directory containing the DBT ``dbt_project.yml``.
+    profiles_dir : str
+        Path to directory containing the DBT ``profiles.yml``.
+    profile : str
+        Profile name to be used for the DBT task. It will override the value in ``dbt_project.yml``.
+    target : str
+        Target to load for the given profile (default=None).
+    output_path : str
+        Path to directory where compiled files (e.g. models) will be written when running the task (default=target).
+    ignore_handled_error : bool
+        Ignore handled error (exit code = 1) returned by DBT, see https://docs.getdbt.com/reference/exit-codes (default=False).
+    flags : dict
+        Dictionary containing CLI flags to be added to the ``dbt run`` command (default=None).
+    """
+
+    project_dir: str
+    profiles_dir: str
+    profile: str
+    target: Optional[str] = None
+    output_path: str = "target"
+    ignore_handled_error: bool = False
+    flags: Optional[dict] = None
+
+    def to_args(self) -> List[str]:
+        """
+        Convert the instance of BaseDBTInput into a list of arguments.
+
+        Returns
+        -------
+        List[str]
+            List of arguments.
+        """
+
+        args = []
+        args += ["--project-dir", self.project_dir]
+        args += ["--profiles-dir", self.profiles_dir]
+        args += ["--profile", self.profile]
+        if self.target is not None:
+            args += ["--target", self.target]
+
+        if self.flags is not None:
+            for flag, value in self.flags.items():
+                if not value:
+                    continue
+
+                args.append(f"--{flag}")
+                if isinstance(value, bool):
+                    continue
+
+                if isinstance(value, list):
+                    args += value
+                    continue
+
+                if isinstance(value, dict):
+                    args.append(json.dumps(value))
+                    continue
+
+                args.append(str(value))
+
+        return args
+
+
+@dataclass_json
+@dataclass
+class BaseDBTOutput:
+    """
+    Base class for output of DBT task.
+
+    Attributes
+    ----------
+    command : str
+        Complete CLI command and flags that were executed by DBT Task.
+    exit_code : int
+        Exit code returned by DBT CLI.
+    """
+
+    command: str
+    exit_code: int
+
+
+@dataclass_json
+@dataclass
+class DBTRunInput(BaseDBTInput):
+    """
+    Input to DBT Run task.
+
+    Attributes
+    ----------
+    select : List[str]
+        List of models to be executed (default=None).
+    exclude : List[str]
+        List of models to be excluded (default=None).
+    """
+
+    select: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+    def to_args(self) -> List[str]:
+        """
+        Convert the instance of DBTRunInput into a list of arguments.
+
+        Returns
+        -------
+        List[str]
+            List of arguments.
+        """
+
+        args = BaseDBTInput.to_args(self)
+        if self.select is not None:
+            args += ["--select"] + self.select
+
+        if self.exclude is not None:
+            args += ["--exclude"] + self.exclude
+
+        return args
+
+
+@dataclass_json
+@dataclass
+class DBTRunOutput(BaseDBTOutput):
+    """
+    Output of DBT run task.
+
+    Attributes
+    ----------
+    raw_run_result : str
+        Raw value of DBT's ``run_result.json``.
+    raw_manifest : str
+        Raw value of DBT's ``manifest.json``.
+    """
+
+    raw_run_result: str
+    raw_manifest: str
+
+
+@dataclass_json
+@dataclass
+class DBTTestInput(BaseDBTInput):
+    """
+    Input to DBT Test task.
+
+    Attributes
+    ----------
+    select : List[str]
+        List of models to be executed (default=None).
+    exclude : List[str]
+        List of models to be excluded (default=None).
+    """
+
+    select: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
+
+    def to_args(self) -> List[str]:
+        """
+        Convert the instance of DBTTestInput into a list of arguments.
+
+        Returns
+        -------
+        List[str]
+            List of arguments.
+ """ + + args = BaseDBTInput.to_args(self) + + if self.select is not None: + args += ["--select"] + self.select + + if self.exclude is not None: + args += ["--exclude"] + self.exclude + + return args + + +@dataclass_json +@dataclass +class DBTTestOutput(BaseDBTOutput): + """ + Output of DBT test task. + + Attributes + ---------- + raw_run_result : str + Raw value of DBT's ``run_result.json``. + raw_manifest : str + Raw value of DBT's ``manifest.json``. + """ + + raw_run_result: str + raw_manifest: str + + +@dataclass_json +@dataclass +class DBTFreshnessInput(BaseDBTInput): + """ + Input to DBT Freshness task. + + Attributes + ---------- + select : List[str] + List of model to be executed (default : None). + exclude : List[str] + List of model to be excluded (default : None). + """ + + select: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + def to_args(self) -> List[str]: + """ + Convert the instance of DBTFreshnessInput into list of arguments. + + Returns + ------- + List[str] + List of arguments. + """ + + args = BaseDBTInput.to_args(self) + + if self.select is not None: + args += ["--select"] + self.select + + if self.exclude is not None: + args += ["--exclude"] + self.exclude + + return args + + +@dataclass_json +@dataclass +class DBTFreshnessOutput(BaseDBTOutput): + """ + Output of DBT Freshness task. + + Attributes + ---------- + raw_sources : str + Raw value of DBT's ``sources.json``. + """ + + raw_sources: str diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/task.py b/plugins/flytekit-dbt/flytekitplugins/dbt/task.py new file mode 100644 index 0000000000..4e7c721837 --- /dev/null +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/task.py @@ -0,0 +1,361 @@ +import os + +from flytekitplugins.dbt.error import DBTHandledError, DBTUnhandledError +from flytekitplugins.dbt.schema import ( + DBTFreshnessInput, + DBTFreshnessOutput, + DBTRunInput, + DBTRunOutput, + DBTTestInput, + DBTTestOutput, +) +from flytekitplugins.dbt.util import run_cli + +from flytekit import kwtypes +from flytekit.core.interface import Interface +from flytekit.core.python_function_task import PythonInstanceTask +from flytekit.loggers import logger + +SUCCESS = 0 +HANDLED_ERROR_CODE = 1 +UNHANDLED_ERROR_CODE = 2 + + +class DBTRun(PythonInstanceTask): + """ + Execute DBT Run CLI command. + + The task will execute ``dbt run`` CLI command in a subprocess. + Input from :class:`flytekitplugins.dbt.schema.DBTRunInput` will be converted into the corresponding CLI flags + and stored in :class:`flytekitplugins.dbt.schema.DBTRunOutput`'s command. + + Parameters + ---------- + name : str + Task name. + """ + + def __init__( + self, + name: str, + **kwargs, + ): + super(DBTRun, self).__init__( + task_type="dbt-run", + name=name, + task_config=None, + interface=Interface(inputs=kwtypes(input=DBTRunInput), outputs=kwtypes(output=DBTRunOutput)), + **kwargs, + ) + + def execute(self, **kwargs) -> DBTRunOutput: + """ + This method will be invoked to execute the task. + + Example + ------- + :: + + dbt_run_task = DBTRun(name="test-task") + + @workflow + def my_workflow() -> DBTRunOutput: + return dbt_run_task( + input=DBTRunInput( + project_dir="tests/jaffle_shop", + profiles_dir="tests/jaffle_shop/profiles", + profile="jaffle_shop", + ) + ) + + + Parameters + ---------- + input : DBTRunInput + DBT run input. + + Returns + ------- + DBTRunOutput + DBT run output. + + Raises + ------ + DBTHandledError + If the ``dbt run`` command returns ``exit code 1``. 
+        DBTUnhandledError
+            If the ``dbt run`` command returns ``exit code 2``.
+        """
+
+        task_input: DBTRunInput = kwargs["input"]
+
+        args = task_input.to_args()
+        cmd = ["dbt", "--log-format", "json", "run"] + args
+        full_command = " ".join(cmd)
+
+        logger.info(f"Executing command: {full_command}")
+        exit_code, logs = run_cli(cmd)
+        logger.info(f"dbt exited with return code {exit_code}")
+
+        if exit_code == HANDLED_ERROR_CODE and not task_input.ignore_handled_error:
+            raise DBTHandledError(f"handled error while executing {full_command}", logs)
+
+        if exit_code == UNHANDLED_ERROR_CODE:
+            raise DBTUnhandledError(f"unhandled error while executing {full_command}", logs)
+
+        # read run_results.json
+        output_dir = os.path.join(task_input.project_dir, task_input.output_path)
+        run_result_path = os.path.join(output_dir, "run_results.json")
+        with open(run_result_path) as file:
+            run_result = file.read()
+
+        # read manifest.json
+        manifest_path = os.path.join(output_dir, "manifest.json")
+        with open(manifest_path) as file:
+            manifest = file.read()
+
+        return DBTRunOutput(
+            command=full_command,
+            exit_code=exit_code,
+            raw_run_result=run_result,
+            raw_manifest=manifest,
+        )
+
+
+class DBTTest(PythonInstanceTask):
+    """Execute DBT Test CLI command.
+
+    The task will execute the ``dbt test`` CLI command in a subprocess.
+    Input from :class:`flytekitplugins.dbt.schema.DBTTestInput` will be converted into the corresponding CLI flags,
+    and the full command is stored in :class:`flytekitplugins.dbt.schema.DBTTestOutput`'s ``command``.
+
+    Parameters
+    ----------
+    name : str
+        Task name.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        **kwargs,
+    ):
+        super(DBTTest, self).__init__(
+            task_type="dbt-test",
+            name=name,
+            task_config=None,
+            interface=Interface(
+                inputs={
+                    "input": DBTTestInput,
+                },
+                outputs={"output": DBTTestOutput},
+            ),
+            **kwargs,
+        )
+
+    def execute(self, **kwargs) -> DBTTestOutput:
+        """
+        This method will be invoked to execute the task.
+
+        Example
+        -------
+        ::
+
+            dbt_test_task = DBTTest(name="test-task")
+
+            @workflow
+            def my_workflow() -> DBTTestOutput:
+                # run all tests
+                dbt_test_task(
+                    input=DBTTestInput(
+                        project_dir="tests/jaffle_shop",
+                        profiles_dir="tests/jaffle_shop/profiles",
+                        profile="jaffle_shop",
+                    )
+                )
+
+                # run singular tests only
+                dbt_test_task(
+                    input=DBTTestInput(
+                        project_dir="tests/jaffle_shop",
+                        profiles_dir="tests/jaffle_shop/profiles",
+                        profile="jaffle_shop",
+                        select=["test_type:singular"],
+                    )
+                )
+
+                # run both singular and generic tests
+                return dbt_test_task(
+                    input=DBTTestInput(
+                        project_dir="tests/jaffle_shop",
+                        profiles_dir="tests/jaffle_shop/profiles",
+                        profile="jaffle_shop",
+                        select=["test_type:singular", "test_type:generic"],
+                    )
+                )
+
+
+        Parameters
+        ----------
+        input : DBTTestInput
+            DBT test input.
+
+        Returns
+        -------
+        DBTTestOutput
+            DBT test output.
+
+        Raises
+        ------
+        DBTHandledError
+            If the ``dbt test`` command returns ``exit code 1``.
+        DBTUnhandledError
+            If the ``dbt test`` command returns ``exit code 2``.
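+
+        Notes
+        -----
+        Setting ``DBTTestInput.ignore_handled_error`` suppresses ``DBTHandledError``;
+        ``DBTUnhandledError`` is always raised.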
+ """ + + task_input: DBTTestInput = kwargs["input"] + + args = task_input.to_args() + cmd = ["dbt", "--log-format", "json", "test"] + args + full_command = " ".join(cmd) + + logger.info(f"Executing command: {full_command}") + exit_code, logs = run_cli(cmd) + logger.info(f"dbt exited with return code {exit_code}") + + if exit_code == HANDLED_ERROR_CODE and not task_input.ignore_handled_error: + raise DBTHandledError(f"handled error while executing {full_command}", logs) + + if exit_code == UNHANDLED_ERROR_CODE: + raise DBTUnhandledError(f"unhandled error while executing {full_command}", logs) + + output_dir = os.path.join(task_input.project_dir, task_input.output_path) + run_result_path = os.path.join(output_dir, "run_results.json") + with open(run_result_path) as file: + run_result = file.read() + + # read manifest.json + manifest_path = os.path.join(output_dir, "manifest.json") + with open(manifest_path) as file: + manifest = file.read() + + return DBTTestOutput( + command=full_command, + exit_code=exit_code, + raw_run_result=run_result, + raw_manifest=manifest, + ) + + +class DBTFreshness(PythonInstanceTask): + """Execute DBT Freshness CLI command + + The task will execute ``dbt freshness`` CLI command in a subprocess. + Input from :class:`flytekitplugins.dbt.schema.DBTFreshnessInput` will be converted into the corresponding CLI flags + and stored in :class:`flytekitplugins.dbt.schema.DBTFreshnessOutput`'s command. + + Parameters + ---------- + name : str + Task name. + """ + + def __init__( + self, + name: str, + **kwargs, + ): + super(DBTFreshness, self).__init__( + task_type="dbt-freshness", + name=name, + task_config=None, + interface=Interface( + inputs={ + "input": DBTFreshnessInput, + }, + outputs={"output": DBTFreshnessOutput}, + ), + **kwargs, + ) + + def execute(self, **kwargs) -> DBTFreshnessOutput: + """ + This method will be invoked to execute the task. + + Example + ------- + :: + + dbt_freshness_task = DBTFreshness(name="freshness-task") + + @workflow + def my_workflow() -> DBTFreshnessOutput: + # run all models + dbt_freshness_task( + input=DBTFreshnessInput( + project_dir="tests/jaffle_shop", + profiles_dir="tests/jaffle_shop/profiles", + profile="jaffle_shop", + ) + ) + + # run singular freshness only + dbt_freshness_task( + input=DBTFreshnessInput( + project_dir="tests/jaffle_shop", + profiles_dir="tests/jaffle_shop/profiles", + profile="jaffle_shop", + select=["test_type:singular"], + ) + ) + + # run both singular and generic freshness + return dbt_freshness_task( + input=DBTFreshnessInput( + project_dir="tests/jaffle_shop", + profiles_dir="tests/jaffle_shop/profiles", + profile="jaffle_shop", + select=["test_type:singular", "test_type:generic"], + ) + ) + + + Parameters + ---------- + input : DBTFreshnessInput + DBT freshness input + + Returns + ------- + DBTFreshnessOutput + DBT freshness output + + Raises + ------ + DBTHandledError + If the ``dbt source freshness`` command returns ``exit code 1``. + DBTUnhandledError + If the ``dbt source freshness`` command returns ``exit code 2``. 
+ """ + + task_input: DBTFreshnessInput = kwargs["input"] + + args = task_input.to_args() + cmd = ["dbt", "--log-format", "json", "source", "freshness"] + args + full_command = " ".join(cmd) + + logger.info(f"Executing command: {full_command}") + exit_code, logs = run_cli(cmd) + logger.info(f"dbt exited with return code {exit_code}") + + if exit_code == HANDLED_ERROR_CODE and not task_input.ignore_handled_error: + raise DBTHandledError(f"handled error while executing {full_command}", logs) + + if exit_code == UNHANDLED_ERROR_CODE: + raise DBTUnhandledError(f"unhandled error while executing {full_command}", logs) + + output_dir = os.path.join(task_input.project_dir, task_input.output_path) + sources_path = os.path.join(output_dir, "sources.json") + with open(sources_path) as file: + sources = file.read() + + return DBTFreshnessOutput(command=full_command, exit_code=exit_code, raw_sources=sources) diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/util.py b/plugins/flytekit-dbt/flytekitplugins/dbt/util.py new file mode 100644 index 0000000000..c127c9279c --- /dev/null +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/util.py @@ -0,0 +1,40 @@ +import json +import subprocess +from typing import List + +from flytekit.loggers import logger + + +def run_cli(cmd: List[str]) -> (int, List[str]): + """ + Execute a CLI command in a subprocess + + Parameters + ---------- + cmd : list of str + Command to be executed. + + Returns + ------- + int + Command's exit code. + list of str + Logs produced by the command execution. + """ + + logs = [] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE) + for raw_line in process.stdout or []: + line = raw_line.decode("utf-8") + try: + json_line = json.loads(line) + except json.JSONDecodeError: + logger.info(line.rstrip()) + else: + logs.append(json_line) + # TODO: pluck `levelname` from json_line and choose appropriate level to use + # in flytekit logger instead of defaulting to `info` + logger.info(line.rstrip()) + + process.wait() + return process.returncode, logs diff --git a/plugins/flytekit-dbt/requirements.in b/plugins/flytekit-dbt/requirements.in new file mode 100644 index 0000000000..57d29c207a --- /dev/null +++ b/plugins/flytekit-dbt/requirements.in @@ -0,0 +1,5 @@ +. 
+-e file:.#egg=flytekitplugins-dbt +# dbt-sqlite and dbt-core should be compatible +dbt-core==1.2.1 +dbt-sqlite==1.2.0a2 diff --git a/plugins/flytekit-dbt/requirements.txt b/plugins/flytekit-dbt/requirements.txt new file mode 100644 index 0000000000..ea09d84eb1 --- /dev/null +++ b/plugins/flytekit-dbt/requirements.txt @@ -0,0 +1,265 @@ +# +# This file is autogenerated by pip-compile with python 3.8 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-dbt + # via -r requirements.in +agate==1.6.3 + # via dbt-core +arrow==1.2.2 + # via jinja2-time +attrs==21.4.0 + # via jsonschema +babel==2.9.1 + # via agate +binaryornot==0.4.4 + # via cookiecutter +certifi==2021.10.8 + # via requests +cffi==1.15.0 + # via + # cryptography + # dbt-core +chardet==4.0.0 + # via binaryornot +charset-normalizer==2.0.12 + # via requests +click==8.0.4 + # via + # cookiecutter + # dbt-core + # flytekit +cloudpickle==2.0.0 + # via flytekit +colorama==0.4.4 + # via dbt-core +cookiecutter==1.7.3 + # via flytekit +croniter==1.3.4 + # via flytekit +cryptography==3.4.8 + # via + # pyopenssl + # secretstorage +dataclasses-json==0.5.6 + # via flytekit +dbt-core==1.2.1 + # via + # -r requirements.in + # dbt-sqlite + # flytekitplugins-dbt +dbt-extractor==0.4.1 + # via dbt-core +dbt-sqlite==1.2.0a2 + # via -r requirements.in +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.13 + # via flytekit +flyteidl==1.1.12 + # via flytekit +flytekit==1.1.1 + # via flytekitplugins-dbt +future==0.18.2 + # via parsedatetime +googleapis-common-protos==1.55.0 + # via + # flyteidl + # grpcio-status +grpcio==1.44.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.44.0 + # via flytekit +hologram==0.0.14 + # via dbt-core +idna==3.3 + # via + # dbt-core + # requests +importlib-metadata==4.11.2 + # via + # flytekit + # jsonschema + # keyring +isodate==0.6.1 + # via + # agate + # dbt-core +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==2.11.3 + # via + # cookiecutter + # dbt-core + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +jsonschema==3.1.1 + # via hologram +keyring==23.5.0 + # via flytekit +leather==0.3.4 + # via agate +logbook==1.5.3 + # via dbt-core +markupsafe==2.0.1 + # via + # dbt-core + # jinja2 +marshmallow==3.14.1 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mashumaro==2.9 + # via dbt-core +minimal-snowplow-tracker==0.0.2 + # via dbt-core +msgpack==1.0.3 + # via mashumaro +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +networkx==2.7.1 + # via dbt-core +numpy==1.22.3 + # via + # pandas + # pyarrow +packaging==21.3 + # via dbt-core +pandas==1.4.1 + # via flytekit +parsedatetime==2.4 + # via agate +poyo==0.5.0 + # via cookiecutter +protobuf==3.19.5 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==21.0.0 + # via flytekit +pyparsing==3.0.7 + # via packaging +pyrsistent==0.18.1 + # via jsonschema +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # hologram + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.1 + # via + # agate + # 
cookiecutter +pytimeparse==1.1.8 + # via + # agate + # flytekit +pytz==2021.3 + # via + # babel + # flytekit + # pandas +pyyaml==6.0 + # via + # flytekit + # mashumaro +regex==2022.3.2 + # via docker-image-py +requests==2.27.1 + # via + # cookiecutter + # dbt-core + # docker + # flytekit + # minimal-snowplow-tracker + # responses +responses==0.19.0 + # via flytekit +retry==0.9.2 + # via flytekit +secretstorage==3.3.3 + # via keyring +six==1.16.0 + # via + # agate + # cookiecutter + # grpcio + # isodate + # jsonschema + # leather + # minimal-snowplow-tracker + # pyopenssl + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +sqlparse==0.4.2 + # via dbt-core +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +typing-extensions==3.10.0.2 + # via + # dbt-core + # flytekit + # mashumaro + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.8 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +werkzeug==2.0.3 + # via dbt-core +wheel==0.38.0 + # via flytekit +wrapt==1.13.3 + # via + # deprecated + # flytekit +zipp==3.7.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-dbt/setup.py b/plugins/flytekit-dbt/setup.py new file mode 100644 index 0000000000..ab325d82a0 --- /dev/null +++ b/plugins/flytekit-dbt/setup.py @@ -0,0 +1,42 @@ +from setuptools import setup + +PLUGIN_NAME = "dbt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flytekit>=1.1.0,<2.0.0", + "dbt-core>=1.0.0", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="DBT Plugin for Flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-dbt", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-kf-mpi/__init__.py b/plugins/flytekit-dbt/tests/__init__.py similarity index 100% rename from plugins/flytekit-kf-mpi/__init__.py rename to plugins/flytekit-dbt/tests/__init__.py diff --git a/plugins/flytekit-dbt/tests/test_schema.py b/plugins/flytekit-dbt/tests/test_schema.py new file mode 100644 index 0000000000..6555b209d5 --- /dev/null +++ b/plugins/flytekit-dbt/tests/test_schema.py @@ -0,0 +1,334 @@ +import shlex + +import pytest +from flytekitplugins.dbt.schema import BaseDBTInput, DBTFreshnessInput, DBTRunInput, DBTTestInput + +project_dir = "." 
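+# shared input values; every expected command string below embeds them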
+profiles_dir = "profiles" +profile_name = "development" + + +class TestBaseDBTInput: + @pytest.mark.parametrize( + "task_input,expected", + [ + ( + BaseDBTInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name}", + ), + ( + BaseDBTInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + target="production", + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --target production", + ), + ( + BaseDBTInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + flags={"vars": {"var1": "val1", "var2": 2}}, + ), + f"""--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --vars '{{"var1": "val1", "var2": 2}}'""", + ), + ( + BaseDBTInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + flags={"bool-flag": True}, + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --bool-flag", + ), + ( + BaseDBTInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + flags={"list-flag": ["a", "b", "c"]}, + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --list-flag a b c", + ), + ], + ) + def test_to_args(self, task_input, expected): + assert task_input.to_args() == shlex.split(expected) + + +class TestDBRunTInput: + @pytest.mark.parametrize( + "task_input,expected", + [ + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name}", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select model_a model_b", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select tag:nightly my_model finance.base.*", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select path:marts/finance,tag:nightly,config.materialized:table", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude model_a model_b", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude tag:nightly my_model finance.base.*", + ), + ( + DBTRunInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude path:marts/finance,tag:nightly,config.materialized:table", + ), + 
], + ) + def test_to_args(self, task_input, expected): + assert task_input.to_args() == shlex.split(expected) + + +class TestDBTestInput: + @pytest.mark.parametrize( + "task_input,expected", + [ + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name}", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select test_type:singular", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select model_a model_b", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select tag:nightly my_model finance.base.*", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["tag:nightly", "my_model", "finance.base.*", "test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select tag:nightly my_model finance.base.* test_type:singular", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["path:marts/finance,tag:nightly,config.materialized:table,test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select path:marts/finance,tag:nightly,config.materialized:table,test_type:singular", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude model_a model_b", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude tag:nightly my_model finance.base.*", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude path:marts/finance,tag:nightly,config.materialized:table", + ), + ( + DBTTestInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["test_type:singular"], + exclude=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select test_type:singular --exclude path:marts/finance,tag:nightly,config.materialized:table", + ), + ], + ) + def test_to_args(self, task_input, expected): + assert task_input.to_args() == shlex.split(expected) + + +class TestDBFreshnessInput: + @pytest.mark.parametrize( + "task_input,expected", + [ + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + ), + f"--project-dir 
{project_dir} --profiles-dir {profiles_dir} --profile {profile_name}", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select test_type:singular", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select model_a model_b", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select tag:nightly my_model finance.base.*", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["tag:nightly", "my_model", "finance.base.*", "test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select tag:nightly my_model finance.base.* test_type:singular", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["path:marts/finance,tag:nightly,config.materialized:table,test_type:singular"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select path:marts/finance,tag:nightly,config.materialized:table,test_type:singular", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["model_a", "model_b"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude model_a model_b", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["tag:nightly", "my_model", "finance.base.*"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude tag:nightly my_model finance.base.*", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + exclude=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --exclude path:marts/finance,tag:nightly,config.materialized:table", + ), + ( + DBTFreshnessInput( + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile_name, + select=["test_type:singular"], + exclude=["path:marts/finance,tag:nightly,config.materialized:table"], + ), + f"--project-dir {project_dir} --profiles-dir {profiles_dir} --profile {profile_name} --select test_type:singular --exclude path:marts/finance,tag:nightly,config.materialized:table", + ), + ], + ) + def test_to_args(self, task_input, expected): + assert task_input.to_args() == shlex.split(expected) diff --git a/plugins/flytekit-dbt/tests/test_task.py b/plugins/flytekit-dbt/tests/test_task.py new file mode 100644 index 0000000000..cdad8d73f6 --- /dev/null +++ b/plugins/flytekit-dbt/tests/test_task.py @@ -0,0 +1,225 @@ +import os +import pathlib + +import pytest +from flytekitplugins.dbt.error import DBTUnhandledError +from flytekitplugins.dbt.schema import ( + DBTFreshnessInput, + DBTFreshnessOutput, + DBTRunInput, + DBTRunOutput, + DBTTestInput, + DBTTestOutput, +) +from 
flytekitplugins.dbt.task import DBTFreshness, DBTRun, DBTTest + +from flytekit import workflow +from flytekit.tools.subprocess import check_call + +DBT_PROJECT_DIR = str(pathlib.Path(os.path.dirname(os.path.realpath(__file__)), "testdata", "jaffle_shop")) +DBT_PROFILES_DIR = str(pathlib.Path(os.path.dirname(os.path.realpath(__file__)), "testdata", "profiles")) +DBT_PROFILE = "jaffle_shop" + + +@pytest.fixture(scope="module", autouse=True) +def prepare_db(): + # Ensure path to sqlite database file exists + dbs_path = pathlib.Path(DBT_PROJECT_DIR, "dbs") + dbs_path.mkdir(exist_ok=True, parents=True) + database_file = pathlib.Path(dbs_path, "database_name.db") + database_file.touch() + + # Seed the database + check_call( + [ + "dbt", + "--log-format", + "json", + "seed", + "--project-dir", + DBT_PROJECT_DIR, + "--profiles-dir", + DBT_PROFILES_DIR, + "--profile", + DBT_PROFILE, + ] + ) + + yield + + # Delete the database file + database_file.unlink() + + +class TestDBTRun: + def test_simple_task(self): + dbt_run_task = DBTRun( + name="test-task", + ) + + @workflow + def my_workflow() -> DBTRunOutput: + # run all models + return dbt_run_task( + input=DBTRunInput( + project_dir=DBT_PROJECT_DIR, + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + select=["tag:something"], + exclude=["tag:something-else"], + ) + ) + + result = my_workflow() + assert isinstance(result, DBTRunOutput) + + def test_incorrect_project_dir(self): + dbt_run_task = DBTRun( + name="test-task", + ) + + with pytest.raises(DBTUnhandledError): + dbt_run_task( + input=DBTRunInput( + project_dir=".", + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + ) + ) + + def test_task_output(self): + dbt_run_task = DBTRun( + name="test-task", + ) + + output = dbt_run_task.execute( + input=DBTRunInput(project_dir=DBT_PROJECT_DIR, profiles_dir=DBT_PROFILES_DIR, profile=DBT_PROFILE) + ) + + assert output.exit_code == 0 + assert ( + output.command + == f"dbt --log-format json run --project-dir {DBT_PROJECT_DIR} --profiles-dir {DBT_PROFILES_DIR} --profile {DBT_PROFILE}" + ) + + with open(f"{DBT_PROJECT_DIR}/target/run_results.json", "r") as fp: + exp_run_result = fp.read() + assert output.raw_run_result == exp_run_result + + with open(f"{DBT_PROJECT_DIR}/target/manifest.json", "r") as fp: + exp_manifest = fp.read() + assert output.raw_manifest == exp_manifest + + +class TestDBTTest: + def test_simple_task(self): + dbt_test_task = DBTTest( + name="test-task", + ) + + @workflow + def test_workflow() -> DBTTestOutput: + # run all tests + return dbt_test_task( + input=DBTTestInput( + project_dir=DBT_PROJECT_DIR, + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + ) + ) + + assert isinstance(test_workflow(), DBTTestOutput) + + def test_incorrect_project_dir(self): + dbt_test_task = DBTTest( + name="test-task", + ) + + with pytest.raises(DBTUnhandledError): + dbt_test_task( + input=DBTTestInput( + project_dir=".", + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + ) + ) + + def test_task_output(self): + dbt_test_task = DBTTest( + name="test-task", + ) + + output = dbt_test_task.execute( + input=DBTTestInput(project_dir=DBT_PROJECT_DIR, profiles_dir=DBT_PROFILES_DIR, profile=DBT_PROFILE) + ) + + assert output.exit_code == 0 + assert ( + output.command + == f"dbt --log-format json test --project-dir {DBT_PROJECT_DIR} --profiles-dir {DBT_PROFILES_DIR} --profile {DBT_PROFILE}" + ) + + with open(f"{DBT_PROJECT_DIR}/target/run_results.json", "r") as fp: + exp_run_result = fp.read() + assert output.raw_run_result == exp_run_result + 
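+        # dbt test writes manifest.json next to run_results.json; it should be captured verbatim as well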
+ with open(f"{DBT_PROJECT_DIR}/target/manifest.json", "r") as fp: + exp_manifest = fp.read() + assert output.raw_manifest == exp_manifest + + +class TestDBTFreshness: + def test_simple_task(self): + dbt_freshness_task = DBTFreshness( + name="test-task", + ) + + @workflow + def my_workflow() -> DBTFreshnessOutput: + # run all models + return dbt_freshness_task( + input=DBTFreshnessInput( + project_dir=DBT_PROJECT_DIR, + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + select=["tag:something"], + exclude=["tag:something-else"], + ) + ) + + result = my_workflow() + assert isinstance(result, DBTFreshnessOutput) + + def test_incorrect_project_dir(self): + dbt_freshness_task = DBTFreshness( + name="test-task", + ) + + with pytest.raises(DBTUnhandledError): + dbt_freshness_task( + input=DBTFreshnessInput( + project_dir=".", + profiles_dir=DBT_PROFILES_DIR, + profile=DBT_PROFILE, + ) + ) + + def test_task_output(self): + dbt_freshness_task = DBTFreshness( + name="test-task", + ) + + output = dbt_freshness_task.execute( + input=DBTFreshnessInput(project_dir=DBT_PROJECT_DIR, profiles_dir=DBT_PROFILES_DIR, profile=DBT_PROFILE) + ) + + assert output.exit_code == 0 + assert ( + output.command + == f"dbt --log-format json source freshness --project-dir {DBT_PROJECT_DIR} --profiles-dir {DBT_PROFILES_DIR} --profile {DBT_PROFILE}" + ) + + with open(f"{DBT_PROJECT_DIR}/target/sources.json", "r") as fp: + exp_sources = fp.read() + + assert output.raw_sources == exp_sources diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/.gitignore b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/.gitignore new file mode 100644 index 0000000000..7164422079 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/.gitignore @@ -0,0 +1,5 @@ + +target/ +dbt_modules/ +logs/ +**/.DS_Store diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/LICENSE b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/LICENSE new file mode 100644 index 0000000000..8dada3edaf --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/README.md b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/README.md new file mode 100644 index 0000000000..cd94389ceb --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/README.md @@ -0,0 +1,94 @@ +## Testing dbt project: `jaffle_shop` + +`jaffle_shop` is a fictional ecommerce store. This dbt project transforms raw data from an app database into a customers and orders model ready for analytics. + +### What is this repo? + +What this repo _is_: + +- A self-contained playground dbt project, useful for testing out scripts, and communicating some of the core dbt concepts. + +What this repo _is not_: + +- A tutorial — check out the [Getting Started Tutorial](https://docs.getdbt.com/tutorial/setting-up) for that. Notably, this repo contains some anti-patterns to make it self-contained, namely the use of seeds instead of sources. 
+- A demonstration of best practices — check out the [dbt Learn Demo](https://github.com/dbt-labs/dbt-learn-demo) repo instead. We want to keep this project as simple as possible. As such, we chose not to implement:
+  - our standard file naming patterns (which make more sense on larger projects, rather than this five-model project)
+  - a pull request flow
+  - CI/CD integrations
+- A demonstration of using dbt for a highly complex project, or a demo of advanced features (e.g. macros, packages, hooks, operations) — we're just trying to keep things simple here!
+
+### What's in this repo?
+
+This repo contains [seeds](https://docs.getdbt.com/docs/building-a-dbt-project/seeds) that include some (fake) raw data from a fictional app.
+
+The raw data consists of customers, orders, and payments, with the following entity-relationship diagram:
+
+![Jaffle Shop ERD](./etc/jaffle_shop_erd.png)
+
+### Running this project
+
+To get up and running with this project:
+
+1. Install dbt using [these instructions](https://docs.getdbt.com/docs/installation).
+
+2. Clone this repository.
+
+3. Change into the `jaffle_shop` directory from the command line:
+
+```bash
+$ cd jaffle_shop
+```
+
+4. Set up a profile called `jaffle_shop` to connect to a data warehouse by following [these instructions](https://docs.getdbt.com/docs/configure-your-profile). If you have access to a data warehouse, you can use those credentials – we recommend setting your [target schema](https://docs.getdbt.com/docs/configure-your-profile#section-populating-your-profile) to be a new schema (dbt will create the schema for you, as long as you have the right privileges). If you don't have access to an existing data warehouse, you can also set up a local postgres database and connect to it in your profile.
+
+5. Ensure your profile is set up correctly from the command line:
+
+```bash
+$ dbt debug
+```
+
+6. Load the CSVs with the demo data set. This materializes the CSVs as tables in your target schema. Note that a typical dbt project **does not require this step** since dbt assumes your raw data is already in your warehouse.
+
+```bash
+$ dbt seed
+```
+
+7. Run the models:
+
+```bash
+$ dbt run
+```
+
+> **NOTE:** If this step fails, it might mean that you need to make small changes to the SQL in the models folder to adjust for the flavor of SQL of your target database. Definitely consider this if you are using a community-contributed adapter.
+
+8. Test the output of the models:
+
+```bash
+$ dbt test
+```
+
+9. Generate documentation for the project:
+
+```bash
+$ dbt docs generate
+```
+
+10. View the documentation for the project:
+
+```bash
+$ dbt docs serve
+```
+
+### What is a jaffle?
+
+A jaffle is a toasted sandwich with crimped, sealed edges. Invented in Bondi in 1949, the humble jaffle is an Australian classic. The sealed edges allow jaffle-eaters to enjoy liquid fillings inside the sandwich, which reach temperatures close to the core of the earth during cooking. Often consumed at home after a night out, the most classic filling is tinned spaghetti, while my personal favourite is leftover beef stew with melted cheese.
+
+---
+
+For more information on dbt:
+
+- Read the [introduction to dbt](https://docs.getdbt.com/docs/introduction).
+- Read the [dbt viewpoint](https://docs.getdbt.com/docs/about/viewpoint).
+- Join the [dbt community](http://community.getdbt.com/).
+ +--- diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/dbt_project.yml b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/dbt_project.yml new file mode 100644 index 0000000000..acdce4c57c --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/dbt_project.yml @@ -0,0 +1,26 @@ +name: 'jaffle_shop' + +config-version: 2 +version: '0.1' + +profile: 'jaffle_shop' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +analysis-paths: ["analysis"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +models: + jaffle_shop: + materialized: table + staging: + materialized: view diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/dbdiagram_definition.txt b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/dbdiagram_definition.txt new file mode 100644 index 0000000000..3a6e12c079 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/dbdiagram_definition.txt @@ -0,0 +1,23 @@ +Table orders { + id int PK + user_id int + order_date date + status varchar +} + +Table payments { + id int + order_id int + payment_method int + amount int +} + +Table customers { + id int PK + first_name varchar + last_name varchar +} + +Ref: orders.user_id > customers.id + +Ref: payments.order_id > orders.id diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/jaffle_shop_erd.png b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/jaffle_shop_erd.png new file mode 100644 index 0000000000..dd14739095 Binary files /dev/null and b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/etc/jaffle_shop_erd.png differ diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/customers.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/customers.sql new file mode 100644 index 0000000000..016a004fe5 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/customers.sql @@ -0,0 +1,69 @@ +with customers as ( + + select * from {{ ref('stg_customers') }} + +), + +orders as ( + + select * from {{ ref('stg_orders') }} + +), + +payments as ( + + select * from {{ ref('stg_payments') }} + +), + +customer_orders as ( + + select + customer_id, + + min(order_date) as first_order, + max(order_date) as most_recent_order, + count(order_id) as number_of_orders + from orders + + group by customer_id + +), + +customer_payments as ( + + select + orders.customer_id, + sum(amount) as total_amount + + from payments + + left join orders on + payments.order_id = orders.order_id + + group by orders.customer_id + +), + +final as ( + + select + customers.customer_id, + customers.first_name, + customers.last_name, + customer_orders.first_order, + customer_orders.most_recent_order, + customer_orders.number_of_orders, + customer_payments.total_amount as customer_lifetime_value + + from customers + + left join customer_orders + on customers.customer_id = customer_orders.customer_id + + left join customer_payments + on customers.customer_id = customer_payments.customer_id + +) + +select * from final diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/docs.md b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/docs.md new file mode 100644 index 0000000000..c6ae93be07 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/docs.md @@ -0,0 +1,14 @@ +{% docs orders_status %} + +Orders can be one of the following statuses: + +| status | description | 
+|----------------|-------------------------------------------------------------------------------------------------------------------------|
+| placed         | The order has been placed but has not yet left the warehouse                                                             |
+| shipped        | The order has been shipped to the customer and is currently in transit                                                   |
+| completed      | The order has been received by the customer                                                                              |
+| return_pending | The customer has indicated that they would like to return the order, but it has not yet been received at the warehouse   |
+| returned       | The order has been returned by the customer and received at the warehouse                                                |
+
+
+{% enddocs %}
diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/orders.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/orders.sql
new file mode 100644
index 0000000000..cbb2934911
--- /dev/null
+++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/orders.sql
@@ -0,0 +1,56 @@
+{% set payment_methods = ['credit_card', 'coupon', 'bank_transfer', 'gift_card'] %}
+
+with orders as (
+
+    select * from {{ ref('stg_orders') }}
+
+),
+
+payments as (
+
+    select * from {{ ref('stg_payments') }}
+
+),
+
+order_payments as (
+
+    select
+        order_id,
+
+        {% for payment_method in payment_methods -%}
+        sum(case when payment_method = '{{ payment_method }}' then amount else 0 end) as {{ payment_method }}_amount,
+        {% endfor -%}
+
+        sum(amount) as total_amount
+
+    from payments
+
+    group by order_id
+
+),
+
+final as (
+
+    select
+        orders.order_id,
+        orders.customer_id,
+        orders.order_date,
+        orders.status,
+
+        {% for payment_method in payment_methods -%}
+
+        order_payments.{{ payment_method }}_amount,
+
+        {% endfor -%}
+
+        order_payments.total_amount as amount
+
+    from orders
+
+
+    left join order_payments
+        on orders.order_id = order_payments.order_id
+
+)
+
+select * from final
diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/overview.md b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/overview.md
new file mode 100644
index 0000000000..0544c42b17
--- /dev/null
+++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/overview.md
@@ -0,0 +1,11 @@
+{% docs __overview__ %}
+
+## Data Documentation for Jaffle Shop
+
+`jaffle_shop` is a fictional ecommerce store.
+
+This [dbt](https://www.getdbt.com/) project is for testing out code.
+
+The source code can be found [here](https://github.com/clrcrl/jaffle_shop).
+
+{% enddocs %}
diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/schema.yml b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/schema.yml
new file mode 100644
index 0000000000..381349cfda
--- /dev/null
+++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/schema.yml
@@ -0,0 +1,82 @@
+version: 2
+
+models:
+  - name: customers
+    description: This table has basic information about a customer, as well as some derived facts based on a customer's orders
+
+    columns:
+      - name: customer_id
+        description: This is a unique identifier for a customer
+        tests:
+          - unique
+          - not_null
+
+      - name: first_name
+        description: Customer's first name. PII.
+
+      - name: last_name
+        description: Customer's last name. PII.
+ + - name: first_order + description: Date (UTC) of a customer's first order + + - name: most_recent_order + description: Date (UTC) of a customer's most recent order + + - name: number_of_orders + description: Count of the number of orders a customer has placed + + - name: total_order_amount + description: Total value (AUD) of a customer's orders + + - name: orders + description: This table has basic information about orders, as well as some derived facts based on payments + + columns: + - name: order_id + tests: + - unique + - not_null + description: This is a unique identifier for an order + + - name: customer_id + description: Foreign key to the customers table + tests: + - not_null + - relationships: + to: ref('customers') + field: customer_id + + - name: order_date + description: Date (UTC) that the order was placed + + - name: status + description: '{{ doc("orders_status") }}' + tests: + - accepted_values: + values: ['placed', 'shipped', 'completed', 'return_pending', 'returned'] + + - name: amount + description: Total amount (AUD) of the order + tests: + - not_null + + - name: credit_card_amount + description: Amount of the order (AUD) paid for by credit card + tests: + - not_null + + - name: coupon_amount + description: Amount of the order (AUD) paid for by coupon + tests: + - not_null + + - name: bank_transfer_amount + description: Amount of the order (AUD) paid for by bank transfer + tests: + - not_null + + - name: gift_card_amount + description: Amount of the order (AUD) paid for by gift card + tests: + - not_null diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/schema.yml b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/schema.yml new file mode 100644 index 0000000000..c207e4cf52 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/schema.yml @@ -0,0 +1,31 @@ +version: 2 + +models: + - name: stg_customers + columns: + - name: customer_id + tests: + - unique + - not_null + + - name: stg_orders + columns: + - name: order_id + tests: + - unique + - not_null + - name: status + tests: + - accepted_values: + values: ['placed', 'shipped', 'completed', 'return_pending', 'returned'] + + - name: stg_payments + columns: + - name: payment_id + tests: + - unique + - not_null + - name: payment_method + tests: + - accepted_values: + values: ['credit_card', 'coupon', 'bank_transfer', 'gift_card'] diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_customers.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_customers.sql new file mode 100644 index 0000000000..cad0472695 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_customers.sql @@ -0,0 +1,22 @@ +with source as ( + + {#- + Normally we would select from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ ref('raw_customers') }} + +), + +renamed as ( + + select + id as customer_id, + first_name, + last_name + + from source + +) + +select * from renamed diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_orders.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_orders.sql new file mode 100644 index 0000000000..a654dcb947 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_orders.sql @@ -0,0 +1,23 @@ +with source as ( + + {#- + Normally we would select from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ 
ref('raw_orders') }} + +), + +renamed as ( + + select + id as order_id, + user_id as customer_id, + order_date, + status + + from source + +) + +select * from renamed diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_payments.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_payments.sql new file mode 100644 index 0000000000..f718596ad0 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/models/staging/stg_payments.sql @@ -0,0 +1,25 @@ +with source as ( + + {#- + Normally we would select from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ ref('raw_payments') }} + +), + +renamed as ( + + select + id as payment_id, + order_id, + payment_method, + + -- `amount` is currently stored in cents, so we convert it to dollars + amount / 100 as amount + + from source + +) + +select * from renamed diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/.gitkeep b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_customers.csv b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_customers.csv new file mode 100644 index 0000000000..b3e6747d69 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_customers.csv @@ -0,0 +1,101 @@ +id,first_name,last_name +1,Michael,P. +2,Shawn,M. +3,Kathleen,P. +4,Jimmy,C. +5,Katherine,R. +6,Sarah,R. +7,Martin,M. +8,Frank,R. +9,Jennifer,F. +10,Henry,W. +11,Fred,S. +12,Amy,D. +13,Kathleen,M. +14,Steve,F. +15,Teresa,H. +16,Amanda,H. +17,Kimberly,R. +18,Johnny,K. +19,Virginia,F. +20,Anna,A. +21,Willie,H. +22,Sean,H. +23,Mildred,A. +24,David,G. +25,Victor,H. +26,Aaron,R. +27,Benjamin,B. +28,Lisa,W. +29,Benjamin,K. +30,Christina,W. +31,Jane,G. +32,Thomas,O. +33,Katherine,M. +34,Jennifer,S. +35,Sara,T. +36,Harold,O. +37,Shirley,J. +38,Dennis,J. +39,Louise,W. +40,Maria,A. +41,Gloria,C. +42,Diana,S. +43,Kelly,N. +44,Jane,R. +45,Scott,B. +46,Norma,C. +47,Marie,P. +48,Lillian,C. +49,Judy,N. +50,Billy,L. +51,Howard,R. +52,Laura,F. +53,Anne,B. +54,Rose,M. +55,Nicholas,R. +56,Joshua,K. +57,Paul,W. +58,Kathryn,K. +59,Adam,A. +60,Norma,W. +61,Timothy,R. +62,Elizabeth,P. +63,Edward,G. +64,David,C. +65,Brenda,W. +66,Adam,W. +67,Michael,H. +68,Jesse,E. +69,Janet,P. +70,Helen,F. +71,Gerald,C. +72,Kathryn,O. +73,Alan,B. +74,Harry,A. +75,Andrea,H. +76,Barbara,W. +77,Anne,W. +78,Harry,H. +79,Jack,R. +80,Phillip,H. +81,Shirley,H. +82,Arthur,D. +83,Virginia,R. +84,Christina,R. +85,Theresa,M. +86,Jason,C. +87,Phillip,B. +88,Adam,T. +89,Margaret,J. +90,Paul,P. +91,Todd,W. +92,Willie,O. +93,Frances,R. +94,Gregory,H. +95,Lisa,P. +96,Jacqueline,A. +97,Shirley,D. +98,Nicole,M. +99,Mary,G. +100,Jean,M. 
diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_orders.csv b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_orders.csv new file mode 100644 index 0000000000..7c2be07888 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_orders.csv @@ -0,0 +1,100 @@ +id,user_id,order_date,status +1,1,2018-01-01,returned +2,3,2018-01-02,completed +3,94,2018-01-04,completed +4,50,2018-01-05,completed +5,64,2018-01-05,completed +6,54,2018-01-07,completed +7,88,2018-01-09,completed +8,2,2018-01-11,returned +9,53,2018-01-12,completed +10,7,2018-01-14,completed +11,99,2018-01-14,completed +12,59,2018-01-15,completed +13,84,2018-01-17,completed +14,40,2018-01-17,returned +15,25,2018-01-17,completed +16,39,2018-01-18,completed +17,71,2018-01-18,completed +18,64,2018-01-20,returned +19,54,2018-01-22,completed +20,20,2018-01-23,completed +21,71,2018-01-23,completed +22,86,2018-01-24,completed +23,22,2018-01-26,return_pending +24,3,2018-01-27,completed +25,51,2018-01-28,completed +26,32,2018-01-28,completed +27,94,2018-01-29,completed +28,8,2018-01-29,completed +29,57,2018-01-31,completed +30,69,2018-02-02,completed +31,16,2018-02-02,completed +32,28,2018-02-04,completed +33,42,2018-02-04,completed +34,38,2018-02-06,completed +35,80,2018-02-08,completed +36,85,2018-02-10,completed +37,1,2018-02-10,completed +38,51,2018-02-10,completed +39,26,2018-02-11,completed +40,33,2018-02-13,completed +41,99,2018-02-14,completed +42,92,2018-02-16,completed +43,31,2018-02-17,completed +44,66,2018-02-17,completed +45,22,2018-02-17,completed +46,6,2018-02-19,completed +47,50,2018-02-20,completed +48,27,2018-02-21,completed +49,35,2018-02-21,completed +50,51,2018-02-23,completed +51,71,2018-02-24,completed +52,54,2018-02-25,return_pending +53,34,2018-02-26,completed +54,54,2018-02-26,completed +55,18,2018-02-27,completed +56,79,2018-02-28,completed +57,93,2018-03-01,completed +58,22,2018-03-01,completed +59,30,2018-03-02,completed +60,12,2018-03-03,completed +61,63,2018-03-03,completed +62,57,2018-03-05,completed +63,70,2018-03-06,completed +64,13,2018-03-07,completed +65,26,2018-03-08,completed +66,36,2018-03-10,completed +67,79,2018-03-11,completed +68,53,2018-03-11,completed +69,3,2018-03-11,completed +70,8,2018-03-12,completed +71,42,2018-03-12,shipped +72,30,2018-03-14,shipped +73,19,2018-03-16,completed +74,9,2018-03-17,shipped +75,69,2018-03-18,completed +76,25,2018-03-20,completed +77,35,2018-03-21,shipped +78,90,2018-03-23,shipped +79,52,2018-03-23,shipped +80,11,2018-03-23,shipped +81,76,2018-03-23,shipped +82,46,2018-03-24,shipped +83,54,2018-03-24,shipped +84,70,2018-03-26,placed +85,47,2018-03-26,shipped +86,68,2018-03-26,placed +87,46,2018-03-27,placed +88,91,2018-03-27,shipped +89,21,2018-03-28,placed +90,66,2018-03-30,shipped +91,47,2018-03-31,placed +92,84,2018-04-02,placed +93,66,2018-04-03,placed +94,63,2018-04-03,placed +95,27,2018-04-04,placed +96,90,2018-04-06,placed +97,89,2018-04-07,placed +98,41,2018-04-07,placed +99,85,2018-04-09,placed diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_payments.csv b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_payments.csv new file mode 100644 index 0000000000..a587baab59 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/seeds/raw_payments.csv @@ -0,0 +1,114 @@ +id,order_id,payment_method,amount +1,1,credit_card,1000 +2,2,credit_card,2000 +3,3,coupon,100 +4,4,coupon,2500 +5,5,bank_transfer,1700 +6,6,credit_card,600 +7,7,credit_card,1600 +8,8,credit_card,2300 
+9,9,gift_card,2300 +10,9,bank_transfer,0 +11,10,bank_transfer,2600 +12,11,credit_card,2700 +13,12,credit_card,100 +14,13,credit_card,500 +15,13,bank_transfer,1400 +16,14,bank_transfer,300 +17,15,coupon,2200 +18,16,credit_card,1000 +19,17,bank_transfer,200 +20,18,credit_card,500 +21,18,credit_card,800 +22,19,gift_card,600 +23,20,bank_transfer,1500 +24,21,credit_card,1200 +25,22,bank_transfer,800 +26,23,gift_card,2300 +27,24,coupon,2600 +28,25,bank_transfer,2000 +29,25,credit_card,2200 +30,25,coupon,1600 +31,26,credit_card,3000 +32,27,credit_card,2300 +33,28,bank_transfer,1900 +34,29,bank_transfer,1200 +35,30,credit_card,1300 +36,31,credit_card,1200 +37,32,credit_card,300 +38,33,credit_card,2200 +39,34,bank_transfer,1500 +40,35,credit_card,2900 +41,36,bank_transfer,900 +42,37,credit_card,2300 +43,38,credit_card,1500 +44,39,bank_transfer,800 +45,40,credit_card,1400 +46,41,credit_card,1700 +47,42,coupon,1700 +48,43,gift_card,1800 +49,44,gift_card,1100 +50,45,bank_transfer,500 +51,46,bank_transfer,800 +52,47,credit_card,2200 +53,48,bank_transfer,300 +54,49,credit_card,600 +55,49,credit_card,900 +56,50,credit_card,2600 +57,51,credit_card,2900 +58,51,credit_card,100 +59,52,bank_transfer,1500 +60,53,credit_card,300 +61,54,credit_card,1800 +62,54,bank_transfer,1100 +63,55,credit_card,2900 +64,56,credit_card,400 +65,57,bank_transfer,200 +66,58,coupon,1800 +67,58,gift_card,600 +68,59,gift_card,2800 +69,60,credit_card,400 +70,61,bank_transfer,1600 +71,62,gift_card,1400 +72,63,credit_card,2900 +73,64,bank_transfer,2600 +74,65,credit_card,0 +75,66,credit_card,2800 +76,67,bank_transfer,400 +77,67,credit_card,1900 +78,68,credit_card,1600 +79,69,credit_card,1900 +80,70,credit_card,2600 +81,71,credit_card,500 +82,72,credit_card,2900 +83,73,bank_transfer,300 +84,74,credit_card,3000 +85,75,credit_card,1900 +86,76,coupon,200 +87,77,credit_card,0 +88,77,bank_transfer,1900 +89,78,bank_transfer,2600 +90,79,credit_card,1800 +91,79,credit_card,900 +92,80,gift_card,300 +93,81,coupon,200 +94,82,credit_card,800 +95,83,credit_card,100 +96,84,bank_transfer,2500 +97,85,bank_transfer,1700 +98,86,coupon,2300 +99,87,gift_card,3000 +100,87,credit_card,2600 +101,88,credit_card,2900 +102,89,bank_transfer,2200 +103,90,bank_transfer,200 +104,91,credit_card,1900 +105,92,bank_transfer,1500 +106,92,coupon,200 +107,93,gift_card,2600 +108,94,coupon,700 +109,95,coupon,2400 +110,96,gift_card,1700 +111,97,bank_transfer,1400 +112,98,bank_transfer,1000 +113,99,credit_card,2400 diff --git a/plugins/flytekit-dbt/tests/testdata/jaffle_shop/tests/assert_total_payment_amount_is_positive.sql b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/tests/assert_total_payment_amount_is_positive.sql new file mode 100644 index 0000000000..bfd8ee6b5d --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/jaffle_shop/tests/assert_total_payment_amount_is_positive.sql @@ -0,0 +1,6 @@ +select + order_id, + sum(amount) as total_amount +from {{ ref('orders' )}} +group by 1 +having not(sum(amount) >= 0) diff --git a/plugins/flytekit-dbt/tests/testdata/profiles/profiles.yml b/plugins/flytekit-dbt/tests/testdata/profiles/profiles.yml new file mode 100644 index 0000000000..8d90d76768 --- /dev/null +++ b/plugins/flytekit-dbt/tests/testdata/profiles/profiles.yml @@ -0,0 +1,11 @@ +jaffle_shop: + target: dev + outputs: + dev: + type: sqlite + threads: 1 + database: 'database' + schema: 'main' + schemas_and_paths: + main: 'dbs/database_name.db' + schema_directory: 'file_path' diff --git a/plugins/flytekit-deck-standard/requirements.txt 
b/plugins/flytekit-deck-standard/requirements.txt index 368fa073ea..c52a2739e0 100644 --- a/plugins/flytekit-deck-standard/requirements.txt +++ b/plugins/flytekit-deck-standard/requirements.txt @@ -166,7 +166,7 @@ pillow==9.2.0 # visions plotly==5.9.0 # via flytekitplugins-deck-standard -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -280,7 +280,7 @@ visions[type_image_path]==0.7.4 # via pandas-profiling websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-dolt/requirements.txt b/plugins/flytekit-dolt/requirements.txt index 09d78c7c78..b619ca9520 100644 --- a/plugins/flytekit-dolt/requirements.txt +++ b/plugins/flytekit-dolt/requirements.txt @@ -111,7 +111,7 @@ pandas==1.3.5 # via # dolt-integrations # flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -192,7 +192,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-greatexpectations/requirements.txt b/plugins/flytekit-greatexpectations/requirements.txt index e2f32a907f..57fd7b20d3 100644 --- a/plugins/flytekit-greatexpectations/requirements.txt +++ b/plugins/flytekit-greatexpectations/requirements.txt @@ -151,7 +151,7 @@ jupyter-client==7.3.4 # via # ipykernel # nbclient # notebook -jupyter-core==4.10.0 +jupyter-core==4.11.2 # via # jupyter-client # nbconvert @@ -235,7 +235,7 @@ prometheus-client==0.14.1 # via notebook prompt-toolkit==3.0.30 # via ipython -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -397,7 +397,7 @@ webencodings==0.5.1 # tinycss2 websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-hive/requirements.txt b/plugins/flytekit-hive/requirements.txt index 73e7a6cdb6..b6a0b8b490 100644 --- a/plugins/flytekit-hive/requirements.txt +++ b/plugins/flytekit-hive/requirements.txt @@ -103,7 +103,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -184,7 +184,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-huggingface/README.md b/plugins/flytekit-huggingface/README.md new file mode 100644 index 0000000000..394c489ab9 --- /dev/null +++ b/plugins/flytekit-huggingface/README.md @@ -0,0 +1,10 @@ +# Flytekit Hugging Face Plugin +[Hugging Face](https://github.com/huggingface) is a community and data science platform that provides tools enabling users to build, train, and deploy ML models based on open-source (OS) code and technologies. + +This plugin supports `datasets.Dataset` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html). + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-huggingface +``` diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py new file mode 100644 index 0000000000..30a877bb40 --- /dev/null +++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py @@ -0,0 +1,14 @@ +""" +.. currentmodule:: flytekitplugins.huggingface + +This package contains things that are useful when extending Flytekit. + +..
autosummary:: + :template: custom.rst + :toctree: generated/ + + HuggingFaceDatasetToParquetEncodingHandler + ParquetToHuggingFaceDatasetDecodingHandler +""" + +from .sd_transformers import HuggingFaceDatasetToParquetEncodingHandler, ParquetToHuggingFaceDatasetDecodingHandler diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py new file mode 100644 index 0000000000..0690179bb1 --- /dev/null +++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py @@ -0,0 +1,72 @@ +import typing + +import datasets + +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + PARQUET, + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, +) + + +class HuggingFaceDatasetRenderer: + """ + The datasets.Dataset printable representation is saved to HTML. + """ + + def to_html(self, df: datasets.Dataset) -> str: + assert isinstance(df, datasets.Dataset) + return str(df).replace("\n", "
<br>
") + + +class HuggingFaceDatasetToParquetEncodingHandler(StructuredDatasetEncoder): + def __init__(self): + super().__init__(datasets.Dataset, None, PARQUET) + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + df = typing.cast(datasets.Dataset, structured_dataset.dataframe) + + local_dir = ctx.file_access.get_random_local_directory() + local_path = f"{local_dir}/00000" + + df.to_parquet(local_path) + + remote_dir = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + ctx.file_access.upload_directory(local_dir, remote_dir) + return literals.StructuredDataset(uri=remote_dir, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + +class ParquetToHuggingFaceDatasetDecodingHandler(StructuredDatasetDecoder): + def __init__(self): + super().__init__(datasets.Dataset, None, PARQUET) + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> datasets.Dataset: + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + path = f"{local_dir}/00000" + + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + return datasets.Dataset.from_parquet(path, columns=columns) + return datasets.Dataset.from_parquet(path) + + +StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(datasets.Dataset, HuggingFaceDatasetRenderer()) diff --git a/plugins/flytekit-huggingface/requirements.in b/plugins/flytekit-huggingface/requirements.in new file mode 100644 index 0000000000..9419fdddce --- /dev/null +++ b/plugins/flytekit-huggingface/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-huggingface diff --git a/plugins/flytekit-huggingface/requirements.txt b/plugins/flytekit-huggingface/requirements.txt new file mode 100644 index 0000000000..8cabda2e96 --- /dev/null +++ b/plugins/flytekit-huggingface/requirements.txt @@ -0,0 +1,238 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-huggingface + # via -r requirements.in +aiohttp==3.8.1 + # via + # datasets + # fsspec +aiosignal==1.2.0 + # via aiohttp +arrow==1.2.2 + # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==21.4.0 + # via aiohttp +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.0 + # via + # aiohttp + # requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +datasets==2.4.0 + # via flytekitplugins-huggingface +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +dill==0.3.5.1 + # via + # datasets + # multiprocess +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +filelock==3.7.1 + # via huggingface-hub +flyteidl==1.1.9 + # via flytekit +flytekit==1.1.0 + # via flytekitplugins-huggingface +frozenlist==1.3.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2022.7.0 + # via datasets +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +huggingface-hub==0.8.1 + # via datasets +idna==3.3 + # via + # requests + # yarl +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.7.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +multidict==6.0.2 + # via + # aiohttp + # yarl +multiprocess==0.70.13 + # via datasets +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.23.1 + # via + # datasets + # pandas + # pyarrow +packaging==21.3 + # via + # datasets + # huggingface-hub + # marshmallow +pandas==1.4.3 + # via + # datasets + # flytekit +protobuf==3.20.2 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via + # datasets + # flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # huggingface-hub +regex==2022.7.25 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # datasets + # docker + # flytekit + # fsspec + # huggingface-hub + # responses +responses==0.18.0 + # via + # datasets 
+ # flytekit +retry==0.9.2 + # via flytekit +six==1.16.0 + # via + # grpcio + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +tqdm==4.64.0 + # via + # datasets + # huggingface-hub +typing-extensions==4.3.0 + # via + # flytekit + # huggingface-hub + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.11 + # via + # flytekit + # requests + # responses +websocket-client==1.3.3 + # via docker +wheel==0.38.0 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +xxhash==3.0.0 + # via datasets +yarl==1.7.2 + # via aiohttp +zipp==3.8.1 + # via importlib-metadata diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py new file mode 100644 index 0000000000..bd4da382da --- /dev/null +++ b/plugins/flytekit-huggingface/setup.py @@ -0,0 +1,41 @@ +from setuptools import setup + +PLUGIN_NAME = "huggingface" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flytekit>=1.1.0b0,<2.0.0", + "datasets>=2.4.0", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Evan Sadler", + description="Hugging Face plugin for flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-huggingface", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-huggingface/tests/__init__.py b/plugins/flytekit-huggingface/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py new file mode 100644 index 0000000000..170fdc3789 --- /dev/null +++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py @@ -0,0 +1,70 @@ +from typing import Annotated + +import datasets +import pandas as pd +from flytekitplugins.huggingface.sd_transformers import HuggingFaceDatasetRenderer + +from flytekit import kwtypes, task, workflow +from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset + +subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET] +full_schema = Annotated[StructuredDataset, PARQUET] + + +def test_huggingface_dataset_workflow_subset(): + @task + def generate() -> subset_schema: + df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + dataset = datasets.Dataset.from_pandas(df) + return StructuredDataset(dataframe=dataset) + + @task + def consume(df: subset_schema) -> subset_schema: + dataset = df.open(datasets.Dataset).all() + + assert dataset[0]["col2"] == "a" + assert dataset[1]["col2"] == "b" + assert dataset[2]["col2"] == "c" + + return 
StructuredDataset(dataframe=dataset) + + @workflow + def wf() -> subset_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_huggingface_dataset__workflow_full(): + @task + def generate() -> full_schema: + df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + dataset = datasets.Dataset.from_pandas(df) + return StructuredDataset(dataframe=dataset) + + @task + def consume(df: full_schema) -> full_schema: + dataset = df.open(datasets.Dataset).all() + + assert dataset[0]["col1"] == 1 + assert dataset[1]["col1"] == 3 + assert dataset[2]["col1"] == 2 + assert dataset[0]["col2"] == "a" + assert dataset[1]["col2"] == "b" + assert dataset[2]["col2"] == "c" + + return StructuredDataset(dataframe=dataset) + + @workflow + def wf() -> full_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_datasets_renderer(): + df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + dataset = datasets.Dataset.from_pandas(df) + assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "
<br>
") diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index 7d16ba0b14..e81728ddb4 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -19,6 +19,16 @@ def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> s class Pod(object): + """ + Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod. + This plugin helps expose a fully modifiable Kubernetes pod spec to customize the task execution runtime. + To use pod tasks: (1) Define a pod spec, and (2) Specify the primary container name. + :param V1PodSpec pod_spec: Kubernetes pod spec. https://kubernetes.io/docs/concepts/workloads/pods + :param str primary_container_name: the primary container name + :param Optional[Dict[str, str]] labels: Labels are key/value pairs that are attached to pod spec + :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec. + """ + def __init__( self, pod_spec: V1PodSpec, @@ -98,7 +108,9 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] # Important! Only copy over resource requirements if they are non-empty. container.resources = resource_requirements - container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( + container.env or [] + ) final_containers.append(container) diff --git a/plugins/flytekit-k8s-pod/requirements.txt b/plugins/flytekit-k8s-pod/requirements.txt index cc9a88b96a..75036fceb9 100644 --- a/plugins/flytekit-k8s-pod/requirements.txt +++ b/plugins/flytekit-k8s-pod/requirements.txt @@ -113,7 +113,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -213,7 +213,7 @@ websocket-client==1.3.3 # via # docker # kubernetes -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index e9814c98d3..716190b4df 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -3,6 +3,7 @@ from typing import List from unittest.mock import MagicMock +import pytest from flytekitplugins.pod.task import Pod, PodFunctionTask from kubernetes.client import ApiClient from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount @@ -15,9 +16,10 @@ from flytekit.tools.translator import get_serializable -def get_pod_spec(): +def get_pod_spec(environment=[]): a_container = V1Container( name="a container", + env=environment, ) a_container.command = ["fee", "fi", "fo", "fum"] a_container.volume_mounts = [ @@ -31,10 +33,27 @@ def get_pod_spec(): return pod_spec -def test_pod_task_deserialization(): - pod = Pod(pod_spec=get_pod_spec(), primary_container_name="a container") +@pytest.mark.parametrize( + "task_environment, podspec_env_vars, expected_environment", + [ + ({"FOO": "bar"}, [], [V1EnvVar(name="FOO", value="bar")]), + ( + {"FOO": "bar"}, + [V1EnvVar(name="AN_ENV_VAR", value="42")], + [V1EnvVar(name="FOO", value="bar"), V1EnvVar(name="AN_ENV_VAR", value="42")], + ), + # We do not provide any validation for the duplication of env vars, neither does k8s. 
+ ( + {"FOO": "bar"}, + [V1EnvVar(name="FOO", value="another-bar")], + [V1EnvVar(name="FOO", value="bar"), V1EnvVar(name="FOO", value="another-bar")], + ), + ], +) +def test_pod_task_deserialization(task_environment, podspec_env_vars, expected_environment): + pod = Pod(pod_spec=get_pod_spec(podspec_env_vars), primary_container_name="a container") - @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment={"FOO": "bar"}) + @task(task_config=pod, requests=Resources(cpu="10"), limits=Resources(gpu="2"), environment=task_environment) def simple_pod_task(i: int): pass @@ -85,7 +104,7 @@ def simple_pod_task(i: int): assert primary_container.volume_mounts[0].mount_path == "some/where" assert primary_container.volume_mounts[0].name == "volume mount" assert primary_container.resources == V1ResourceRequirements(limits={"gpu": "2"}, requests={"cpu": "10"}) - assert primary_container.env == [V1EnvVar(name="FOO", value="bar")] + assert primary_container.env == expected_environment assert deserialized_pod_spec.containers[1].name == "another container" config = simple_pod_task.get_config( diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index 36bd1f2895..5c3a8a8efa 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -105,7 +105,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -186,7 +186,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 7a879873c0..fc354d83bd 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -103,7 +103,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -184,7 +184,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index 89a7e29bd1..d1dc578fdf 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -103,7 +103,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -184,7 +184,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-modin/requirements.txt b/plugins/flytekit-modin/requirements.txt index 089d6f5678..ffec3bce45 100644 --- a/plugins/flytekit-modin/requirements.txt +++ b/plugins/flytekit-modin/requirements.txt @@ -31,7 +31,7 @@ click==8.1.2 # ray cloudpickle==2.0.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit croniter==1.3.4 # via flytekit @@ -140,7 +140,7 @@ platformdirs==2.5.2 # via virtualenv poyo==0.5.0 # via cookiecutter -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -224,7 +224,7 @@ virtualenv==20.14.1 # via ray websocket-client==1.3.2 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.0 # via diff --git a/plugins/flytekit-onnx-pytorch/requirements.txt b/plugins/flytekit-onnx-pytorch/requirements.txt index 
fa8ac445af..7a4bdea79b 100644 --- a/plugins/flytekit-onnx-pytorch/requirements.txt +++ b/plugins/flytekit-onnx-pytorch/requirements.txt @@ -105,7 +105,7 @@ pillow==9.2.0 # via # -r requirements.in # torchvision -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -189,7 +189,7 @@ urllib3==1.26.10 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-onnx-scikitlearn/requirements.txt b/plugins/flytekit-onnx-scikitlearn/requirements.txt index 7ff826a50b..86c474c8b4 100644 --- a/plugins/flytekit-onnx-scikitlearn/requirements.txt +++ b/plugins/flytekit-onnx-scikitlearn/requirements.txt @@ -113,7 +113,7 @@ packaging==21.3 # via marshmallow pandas==1.4.3 # via flytekit -protobuf==3.20.1 +protobuf==3.19.5 # via # flyteidl # flytekit @@ -202,7 +202,7 @@ urllib3==1.26.10 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-onnx-tensorflow/requirements.txt b/plugins/flytekit-onnx-tensorflow/requirements.txt index 51f649a4ed..b25b24fe1e 100644 --- a/plugins/flytekit-onnx-tensorflow/requirements.txt +++ b/plugins/flytekit-onnx-tensorflow/requirements.txt @@ -151,7 +151,7 @@ pandas==1.4.3 # via flytekit pillow==9.2.0 # via -r requirements.in -protobuf==3.19.4 +protobuf==3.19.5 # via # flyteidl # flytekit @@ -268,7 +268,7 @@ websocket-client==1.3.3 # via docker werkzeug==2.1.2 # via tensorboard -wheel==0.37.1 +wheel==0.38.0 # via # astunparse # flytekit diff --git a/plugins/flytekit-pandera/requirements.txt b/plugins/flytekit-pandera/requirements.txt index 82660ea15a..fe8ed5c840 100644 --- a/plugins/flytekit-pandera/requirements.txt +++ b/plugins/flytekit-pandera/requirements.txt @@ -110,7 +110,7 @@ pandas==1.3.5 # pandera pandera==0.9.0 # via flytekitplugins-pandera -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -199,7 +199,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index c8294ca254..d714377ad9 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -113,7 +113,7 @@ pandas==1.3.5 # via flytekit poyo==0.5.0 # via cookiecutter -protobuf==3.19.3 +protobuf==3.19.5 # via # flyteidl # flytekit @@ -209,7 +209,7 @@ websocket-client==1.3.2 # via # docker # kubernetes -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.13.3 # via diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 0721c39a37..04f821ccf3 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -168,6 +168,12 @@ def rendered_output_path(self) -> str: return self._notebook_path.split(".ipynb")[0] + "-out.html" def get_container(self, settings: SerializationSettings) -> task_models.Container: + # The task name in the original command is incorrect because we use _dummy_task_func to construct the _config_task_instance. + # Therefore, here we replace the original command with NotebookTask's command.
+ def fn(settings: SerializationSettings) -> typing.List[str]: + return self.get_command(settings) + + self._config_task_instance.set_command_fn(fn) return self._config_task_instance.get_container(settings) def get_k8s_pod(self, settings: SerializationSettings) -> task_models.K8sPod: diff --git a/plugins/flytekit-papermill/requirements.txt b/plugins/flytekit-papermill/requirements.txt index 10c137989e..125caca902 100644 --- a/plugins/flytekit-papermill/requirements.txt +++ b/plugins/flytekit-papermill/requirements.txt @@ -117,7 +117,7 @@ jupyter-client==7.3.4 # via # ipykernel # nbclient -jupyter-core==4.10.0 +jupyter-core==4.11.2 # via # jupyter-client # nbconvert @@ -189,7 +189,7 @@ pickleshare==0.7.5 # via ipython prompt-toolkit==3.0.30 # via ipython -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -319,7 +319,7 @@ webencodings==0.5.1 # tinycss2 websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 4b456f7fdc..1947d09445 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -22,20 +22,34 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy return os.path.abspath(path) if abs else path +nb_name = "nb-simple" +nb_simple = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(pi=float), + outputs=kwtypes(square=float), +) + + def test_notebook_task_simple(): - nb_name = "nb-simple" - nb = NotebookTask( - name="test", - notebook_path=_get_nb_path(nb_name, abs=False), - inputs=kwtypes(pi=float), - outputs=kwtypes(square=float), + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), ) - sqr, out, render = nb.execute(pi=4) + + sqr, out, render = nb_simple.execute(pi=4) assert sqr == 16.0 - assert nb.python_interface.inputs == {"pi": float} - assert nb.python_interface.outputs.keys() == {"square", "out_nb", "out_rendered_nb"} - assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") - assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") + assert nb_simple.python_interface.inputs == {"pi": float} + assert nb_simple.python_interface.outputs.keys() == {"square", "out_nb", "out_rendered_nb"} + assert nb_simple.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") + assert nb_simple.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") + assert ( + nb_simple.get_command(settings=serialization_settings) + == nb_simple.get_container(settings=serialization_settings).args + ) def test_notebook_task_multi_values(): diff --git a/plugins/flytekit-polars/requirements.txt b/plugins/flytekit-polars/requirements.txt index a4fa78640e..7595ab781c 100644 --- a/plugins/flytekit-polars/requirements.txt +++ b/plugins/flytekit-polars/requirements.txt @@ -106,7 +106,7 @@ pandas==1.3.5 # via flytekit polars==0.13.51 # via flytekitplugins-polars -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -188,7 +188,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-ray/flytekitplugins/ray/__init__.py 
b/plugins/flytekit-ray/flytekitplugins/ray/__init__.py index 44543df900..ff6fcfd2e6 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/__init__.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/__init__.py @@ -7,7 +7,9 @@ :template: custom.rst :toctree: generated/ - RayConfig + HeadNodeConfig + RayJobConfig + WorkerNodeConfig """ from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig diff --git a/plugins/flytekit-ray/requirements.txt b/plugins/flytekit-ray/requirements.txt index 080bff46b7..d270a1e0c0 100644 --- a/plugins/flytekit-ray/requirements.txt +++ b/plugins/flytekit-ray/requirements.txt @@ -6,16 +6,31 @@ # -e file:.#egg=flytekitplugins-ray # via -r requirements.in -aiosignal==1.2.0 +aiohttp==3.8.3 + # via + # aiohttp-cors + # ray +aiohttp-cors==0.7.0 # via ray +aiosignal==1.2.0 + # via + # aiohttp + # ray arrow==1.2.2 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp attrs==21.4.0 # via + # aiohttp # jsonschema # ray binaryornot==0.4.4 # via cookiecutter +blessed==1.19.1 + # via gpustat +cachetools==5.2.0 + # via google-auth certifi==2022.5.18.1 # via requests cffi==1.15.0 @@ -23,7 +38,9 @@ cffi==1.15.0 chardet==4.0.0 # via binaryornot charset-normalizer==2.0.12 - # via requests + # via + # aiohttp + # requests click==8.0.4 # via # cookiecutter @@ -31,6 +48,8 @@ click==8.0.4 # ray cloudpickle==2.1.0 # via flytekit +colorful==0.5.5 + # via ray cookiecutter==1.7.3 # via flytekit croniter==1.3.5 @@ -65,12 +84,20 @@ flytekit==1.1.0 # via flytekitplugins-ray frozenlist==1.3.0 # via + # aiohttp # aiosignal # ray -googleapis-common-protos==1.56.1 +google-api-core==2.11.0 + # via opencensus +google-auth==2.15.0 + # via google-api-core +googleapis-common-protos==1.57.0 # via # flyteidl + # google-api-core # grpcio-status +gpustat==1.0.0 + # via ray grpcio==1.43.0 # via # flytekit @@ -79,7 +106,9 @@ grpcio==1.43.0 grpcio-status==1.43.0 # via flytekit idna==3.3 - # via requests + # via + # requests + # yarl importlib-metadata==4.11.3 # via # flytekit @@ -107,6 +136,10 @@ marshmallow-jsonschema==0.13.0 # via flytekit msgpack==1.0.4 # via ray +multidict==6.0.3 + # via + # aiohttp + # yarl mypy-extensions==0.4.3 # via typing-inspect natsort==8.1.0 @@ -116,6 +149,12 @@ numpy==1.21.6 # pandas # pyarrow # ray +nvidia-ml-py==11.495.46 + # via gpustat +opencensus==0.11.0 + # via ray +opencensus-context==0.1.3 + # via opencensus packaging==21.3 # via marshmallow pandas==1.3.5 @@ -124,20 +163,33 @@ platformdirs==2.5.2 # via virtualenv poyo==0.5.0 # via cookiecutter -protobuf==3.20.1 +prometheus-client==0.13.1 + # via ray +protobuf==3.20.2 # via # flyteidl # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger # ray protoc-gen-swagger==0.1.0 # via flyteidl +psutil==5.9.4 + # via gpustat py==1.11.0 # via retry +py-spy==0.3.14 + # via ray pyarrow==6.0.1 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pycparser==2.21 # via cffi pyopenssl==22.0.0 @@ -166,7 +218,7 @@ pyyaml==6.0 # via # flytekit # ray -ray==1.13.0 +ray[default]==1.13.0 # via flytekitplugins-ray regex==2022.4.24 # via docker-image-py @@ -175,18 +227,26 @@ requests==2.27.1 # cookiecutter # docker # flytekit + # google-api-core # ray # responses responses==0.20.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth six==1.16.0 # via + # blessed # cookiecutter + # google-auth + # gpustat # grpcio # python-dateutil # virtualenv +smart-open==6.2.0 + # via ray sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 
@@ -206,13 +266,20 @@ urllib3==1.26.9 # responses virtualenv==20.15.1 # via ray +wcwidth==0.2.5 + # via blessed websocket-client==1.3.2 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit +yarl==1.8.2 + # via aiohttp zipp==3.8.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-ray/setup.py b/plugins/flytekit-ray/setup.py index 8c8a6f479a..2af9fc2334 100644 --- a/plugins/flytekit-ray/setup.py +++ b/plugins/flytekit-ray/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["ray", "flytekit>=1.1.0b0,<2.0.0", "flyteidl>=1.1.10"] +plugin_requires = ["ray[default]", "flytekit>=1.1.0b0,<2.0.0", "flyteidl>=1.1.10"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-snowflake/requirements.txt b/plugins/flytekit-snowflake/requirements.txt index 969a62a435..da4bb43c2a 100644 --- a/plugins/flytekit-snowflake/requirements.txt +++ b/plugins/flytekit-snowflake/requirements.txt @@ -103,7 +103,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -184,7 +184,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index ab558ca534..a012e38d99 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -64,14 +64,26 @@ def test_local_exec(): snowflake_task = SnowflakeTask( name="flytekit.demo.snowflake_task.query2", inputs=kwtypes(ds=str), - query_template=query_template, + query_template="select 1\n", # the schema literal's backend uri will be equal to the value of .raw_output_data output_schema_type=FlyteSchema, ) assert len(snowflake_task.interface.inputs) == 1 + assert snowflake_task.query_template == "select 1\\n" assert len(snowflake_task.interface.outputs) == 1 # will not run locally with pytest.raises(Exception): snowflake_task() + + +def test_sql_template(): + snowflake_task = SnowflakeTask( + name="flytekit.demo.snowflake_task.query2", + inputs=kwtypes(ds=str), + query_template="""select 1 from\t + custom where column = 1""", + output_schema_type=FlyteSchema, + ) + assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1" diff --git a/plugins/flytekit-spark/requirements.txt b/plugins/flytekit-spark/requirements.txt index 979feb79ef..ebb767cf0c 100644 --- a/plugins/flytekit-spark/requirements.txt +++ b/plugins/flytekit-spark/requirements.txt @@ -103,7 +103,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -188,7 +188,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-spark/tests/test_wf.py b/plugins/flytekit-spark/tests/test_wf.py index 8c42a6162f..0c2dd1ebc5 100644 --- a/plugins/flytekit-spark/tests/test_wf.py +++ b/plugins/flytekit-spark/tests/test_wf.py @@ -1,16 +1,12 @@ import pandas as pd import pyspark from flytekitplugins.spark.task import Spark +from typing_extensions import Annotated import flytekit from flytekit import kwtypes, task, workflow from flytekit.types.schema import FlyteSchema -try: - from typing import Annotated -except 
ImportError: - from typing_extensions import Annotated - def test_wf1_with_spark(): @task(task_config=Spark()) diff --git a/plugins/flytekit-sqlalchemy/requirements.txt b/plugins/flytekit-sqlalchemy/requirements.txt index 6d4bbbb40d..4879071049 100644 --- a/plugins/flytekit-sqlalchemy/requirements.txt +++ b/plugins/flytekit-sqlalchemy/requirements.txt @@ -106,7 +106,7 @@ packaging==21.3 # via marshmallow pandas==1.3.5 # via flytekit -protobuf==3.20.1 +protobuf==3.20.2 # via # flyteidl # flytekit @@ -189,7 +189,7 @@ urllib3==1.26.9 # responses websocket-client==1.3.3 # via docker -wheel==0.37.1 +wheel==0.38.0 # via flytekit wrapt==1.14.1 # via diff --git a/plugins/flytekit-vaex/README.md b/plugins/flytekit-vaex/README.md new file mode 100644 index 0000000000..aeef9ad00f --- /dev/null +++ b/plugins/flytekit-vaex/README.md @@ -0,0 +1,11 @@ +# Flytekit Vaex Plugin +[Vaex](https://github.com/vaexio/vaex) is a high-performance Python library for lazy out-of-core DataFrames +(similar to Pandas) to visualize and explore big tabular datasets. + +This plugin supports `vaex.DataFrame` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html). + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-vaex +``` diff --git a/plugins/flytekit-vaex/flytekitplugins/vaex/__init__.py b/plugins/flytekit-vaex/flytekitplugins/vaex/__init__.py new file mode 100644 index 0000000000..36176444eb --- /dev/null +++ b/plugins/flytekit-vaex/flytekitplugins/vaex/__init__.py @@ -0,0 +1,14 @@ +""" +.. currentmodule:: flytekitplugins.vaex + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + VaexDataFrameToParquetEncodingHandler + ParquetToVaexDataFrameDecodingHandler +""" + +from .sd_transformers import ParquetToVaexDataFrameDecodingHandler, VaexDataFrameToParquetEncodingHandler diff --git a/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py new file mode 100644 index 0000000000..78c72c9569 --- /dev/null +++ b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py @@ -0,0 +1,73 @@ +import os +import typing + +import pandas as pd +import vaex + +from flytekit import FlyteContext, StructuredDatasetType +from flytekit.models import literals +from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.types.structured.structured_dataset import ( + PARQUET, + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, +) + + +class VaexDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): + def __init__(self): + super().__init__(vaex.dataframe.DataFrameLocal, None, PARQUET) + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + df = typing.cast(vaex.dataframe.DataFrameLocal, structured_dataset.dataframe) + path = ctx.file_access.get_random_remote_directory() + local_dir = ctx.file_access.get_random_local_directory() + local_path = os.path.join(local_dir, f"{0:05}") + df.export_parquet(local_path) + ctx.file_access.upload_directory(local_dir, path) + return literals.StructuredDataset( + uri=path, + metadata=StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type), + ) + + +class 
ParquetToVaexDataFrameDecodingHandler(StructuredDatasetDecoder): + def __init__(self): + super().__init__(vaex.dataframe.DataFrameLocal, None, PARQUET) + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> vaex.dataframe.DataFrameLocal: + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + path = f"{local_dir}/00000" + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + return vaex.open(path)[columns] + return vaex.open(path) + + +class VaexDataFrameRenderer: + """ + Render a Vaex dataframe schema as an HTML table. + """ + + def to_html(self, df: vaex.dataframe.DataFrameLocal) -> str: + assert isinstance(df, vaex.dataframe.DataFrameLocal) + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + + +StructuredDatasetTransformerEngine.register(VaexDataFrameToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToVaexDataFrameDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(vaex.dataframe.DataFrameLocal, VaexDataFrameRenderer()) diff --git a/plugins/flytekit-vaex/requirements.in b/plugins/flytekit-vaex/requirements.in new file mode 100644 index 0000000000..271c723dad --- /dev/null +++ b/plugins/flytekit-vaex/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-vaex diff --git a/plugins/flytekit-vaex/requirements.txt b/plugins/flytekit-vaex/requirements.txt new file mode 100644 index 0000000000..4f8266cc16 --- /dev/null +++ b/plugins/flytekit-vaex/requirements.txt @@ -0,0 +1,239 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile --output-file=requirements.txt requirements.in +# +-e file:.#egg=flytekitplugins-vaex + # via -r requirements.in +aplus==0.11.0 + # via vaex-core +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +blake3==0.3.1 + # via vaex-core +certifi==2022.6.15 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.1.0 + # via + # dask + # flytekit + # vaex-core +commonmark==0.9.1 + # via rich +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.5 + # via flytekit +cryptography==37.0.4 + # via pyopenssl +dask==2022.10.0 + # via vaex-core +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.0 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.14.1 + # via flytekit +filelock==3.8.0 + # via vaex-core +flyteidl==1.1.12 + # via flytekit +flytekit==1.2.1 + # via flytekitplugins-vaex +frozendict==2.3.4 + # via vaex-core +fsspec==2022.10.0 + # via dask +future==0.18.2 + # via vaex-core +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +grpcio==1.47.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.47.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.12.0 + # via + # flytekit + # keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.1.0 + # via flytekit +keyring==23.8.2 + # via flytekit 
+locket==1.0.0 + # via partd +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.17.1 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +nest-asyncio==1.5.6 + # via vaex-core +numpy==1.21.6 + # via + # pandas + # pyarrow + # vaex-core +packaging==21.3 + # via + # dask + # docker + # marshmallow +pandas==1.3.5 + # via + # flytekit + # vaex-core +partd==1.3.0 + # via dask +progressbar2==4.1.1 + # via vaex-core +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via + # flytekit + # vaex-core +pycparser==2.21 + # via cffi +pydantic==1.10.2 + # via vaex-core +pygments==2.13.0 + # via rich +pyopenssl==22.0.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +python-utils==3.3.3 + # via progressbar2 +pytimeparse==1.1.8 + # via flytekit +pytz==2022.2.1 + # via + # flytekit + # pandas +pyyaml==5.4.1 + # via + # cookiecutter + # dask + # flytekit + # vaex-core +regex==2022.8.17 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses + # vaex-core +responses==0.21.0 + # via flytekit +retry==0.9.2 + # via flytekit +rich==12.6.0 + # via vaex-core +six==1.16.0 + # via + # grpcio + # python-dateutil + # vaex-core + # websocket-client +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +tabulate==0.9.0 + # via vaex-core +text-unidecode==1.3 + # via python-slugify +toolz==0.12.0 + # via + # dask + # partd +typing-extensions==4.3.0 + # via + # flytekit + # pydantic + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # docker + # flytekit + # requests + # responses +vaex-core==4.13.0 + # via flytekitplugins-vaex +websocket-client==0.59.0 + # via docker +wheel==0.38.0 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.8.1 + # via importlib-metadata diff --git a/plugins/flytekit-vaex/setup.py b/plugins/flytekit-vaex/setup.py new file mode 100644 index 0000000000..c012748057 --- /dev/null +++ b/plugins/flytekit-vaex/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "vaex" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "vaex-core>=4.13.0,<4.14"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="admin@flyte.org", + description="Vaex plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + 
"Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-vaex/tests/__init__.py b/plugins/flytekit-vaex/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-vaex/tests/test_vaex_plugin_sd.py b/plugins/flytekit-vaex/tests/test_vaex_plugin_sd.py new file mode 100644 index 0000000000..8e10411b01 --- /dev/null +++ b/plugins/flytekit-vaex/tests/test_vaex_plugin_sd.py @@ -0,0 +1,89 @@ +import pandas as pd +import vaex +from flytekitplugins.vaex.sd_transformers import VaexDataFrameRenderer +from typing_extensions import Annotated + +from flytekit import kwtypes, task, workflow +from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset + +full_schema = Annotated[StructuredDataset, kwtypes(x=int, y=str), PARQUET] +subset_schema = Annotated[StructuredDataset, kwtypes(y=str), PARQUET] +vaex_df = vaex.from_dict(dict(x=[1, 3, 2], y=["a", "b", "c"])) + + +def test_vaex_workflow_subset(): + @task + def generate() -> subset_schema: + return StructuredDataset(dataframe=vaex_df) + + @task + def consume(df: subset_schema) -> subset_schema: + subset_df = df.open(vaex.dataframe.DataFrameLocal).all() + assert subset_df.column_names == ["y"] + coly = subset_df.y.values.tolist() + assert coly[0] == "a" + assert coly[1] == "b" + assert coly[2] == "c" + return StructuredDataset(dataframe=subset_df) + + @workflow + def wf() -> subset_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_vaex_workflow_full(): + @task + def generate() -> full_schema: + return StructuredDataset(dataframe=vaex_df) + + @task + def consume(df: full_schema) -> full_schema: + full_df = df.open(vaex.dataframe.DataFrameLocal).all() + assert full_df.column_names == ["x", "y"] + colx = full_df.x.values.tolist() + coly = full_df.y.values.tolist() + + assert colx[0] == 1 + assert colx[1] == 3 + assert colx[2] == 2 + assert coly[0] == "a" + assert coly[1] == "b" + assert coly[2] == "c" + + return StructuredDataset(dataframe=full_df.sort("x")) + + @workflow + def wf() -> full_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_vaex_renderer(): + vaex_df = vaex.from_dict(dict(x=[1, 3, 2], y=["a", "b", "c"])) + assert VaexDataFrameRenderer().to_html(vaex_df) == pd.DataFrame( + vaex_df.describe().transpose(), columns=vaex_df.columns + ).to_html(index=False) + + +def test_vaex_type(): + @task + def create_vaex_df() -> vaex.dataframe.DataFrameLocal: + return vaex.from_pandas(pd.DataFrame(data={"column_1": [-1, 2, -3], "column_2": [1.5, 2.21, 3.9]})) + + @task + def consume_vaex_df(vaex_df: vaex.dataframe.DataFrameLocal) -> vaex.dataframe.DataFrameLocal: + df_negative = vaex_df[vaex_df.column_1 < 0] + return df_negative + + @workflow + def wf() -> vaex.dataframe.DataFrameLocal: + return consume_vaex_df(vaex_df=create_vaex_df()) + + result = wf() + assert isinstance(result, vaex.dataframe.DataFrameLocal) + assert len(result) == 2 diff --git a/plugins/flytekit-whylogs/README.md b/plugins/flytekit-whylogs/README.md index aeaff969e5..827d4b9cbc 100644 --- a/plugins/flytekit-whylogs/README.md +++ b/plugins/flytekit-whylogs/README.md @@ -15,15 +15,17 @@ pip install flytekitplugins-whylogs To generate profiles, you can add a task like the following: ```python +import whylogs as why from whylogs.core import DatasetProfileView -import whylogs as ylog import pandas as pd 
+from flytekit import task
+
 @task
 def profile(df: pd.DataFrame) -> DatasetProfileView:
-    result = ylog.log(df) # Various overloads for different common data types exist
-    profile = result.view()
-    return profile
+    result = why.log(df) # Various overloads for different common data types exist
+    profile_view = result.view()
+    return profile_view
 ```

@@ -37,21 +39,20 @@ if the data in the workflow doesn't conform to some configured constraints,
 like min/max values on features, data types on features, etc.
 
 ```python
+from whylogs.core.constraints import ConstraintsBuilder
+from whylogs.core.constraints.factories import greater_than_number, mean_between_range
+
 @task
-def validate_data(profile: DatasetProfileView):
-    column = profile.get_column("my_column")
-    print(column.to_summary_dict()) # To see available things you can validate against
-    builder = ConstraintsBuilder(profile)
-    numConstraint = MetricConstraint(
-        name='numbers between 0 and 4 only',
-        condition=lambda x: x.min > 0 and x.max < 4,
-        metric_selector=MetricsSelector(metric_name='distribution', column_name='my_column'))
-    builder.add_constraint(numConstraint)
+def validate_data(profile_view: DatasetProfileView):
+    builder = ConstraintsBuilder(dataset_profile_view=profile_view)
+    builder.add_constraint(greater_than_number(column_name="my_column", number=0.14))
+    builder.add_constraint(mean_between_range(column_name="my_other_column", lower=2, upper=3))
     constraint = builder.build()
     valid = constraint.validate()
-    if(not valid):
+    if not valid:
+        print(constraint.report())
         raise Exception("Invalid data found")
 ```
 
-Check out our [constraints notebook](https://github.com/whylabs/whylogs/blob/1.0.x/python/examples/basic/MetricConstraints.ipynb) for more examples.
+If you want to learn more about whylogs, check out our [example notebooks](https://github.com/whylabs/whylogs/tree/mainline/python/examples).
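The README presents profiling and validation as separate tasks. As a quick illustration of how they compose, here is a minimal sketch, not part of the plugin README itself; `profile` and `validate_data` refer to the example tasks above, and the workflow name is hypothetical:

```python
import pandas as pd

from flytekit import workflow
from whylogs.core import DatasetProfileView


# Minimal sketch: chain the README's example tasks so validation runs on
# the profile produced upstream. Flyte infers the dependency from the
# promise passed between the two tasks.
@workflow
def data_quality_wf(df: pd.DataFrame) -> DatasetProfileView:
    profile_view = profile(df=df)
    validate_data(profile_view=profile_view)  # raises on constraint failure
    return profile_view
```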
diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py
index 71247255f7..46646a012a 100644
--- a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py
+++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py
@@ -1,6 +1,7 @@
 from typing import Type
 
 from whylogs.core import DatasetProfileView
+from whylogs.viz.extensions.reports.profile_summary import ProfileSummaryReport
 
 from flytekit import BlobType, FlyteContext
 from flytekit.extend import T, TypeEngine, TypeTransformer
@@ -42,9 +43,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     def to_html(
         self, ctx: FlyteContext, python_val: DatasetProfileView, expected_python_type: Type[DatasetProfileView]
     ) -> str:
-        pandas_profile = str(python_val.to_pandas().to_html())
-        header = str("<h1>Profile View</h1>\n")
-        return header + pandas_profile
+        report = ProfileSummaryReport(target_view=python_val)
+        return report.report().data
 
 
 TypeEngine.register(WhylogsDatasetProfileTransformer())
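For reference, a small sketch (not part of the patch) of what the new rendering path produces: `ProfileSummaryReport` ships with `whylogs[viz]`, and its `report()` result exposes the HTML string on `.data`, which is what the transformer now hands to Flyte Decks:

```python
import pandas as pd
import whylogs as why
from whylogs.viz.extensions.reports.profile_summary import ProfileSummaryReport

# Build a tiny profile and render it the same way the transformer does.
profile_view = why.log(pandas=pd.DataFrame({"a": [1, 2, 3, 4]})).view()
html = ProfileSummaryReport(target_view=profile_view).report().data

assert isinstance(html, str)
assert "Profile Visualizer" in html  # mirrors the assertion in test_schema.py below
```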
diff --git a/plugins/flytekit-whylogs/requirements.txt b/plugins/flytekit-whylogs/requirements.txt
index 9001bc05e0..7596afc017 100644
--- a/plugins/flytekit-whylogs/requirements.txt
+++ b/plugins/flytekit-whylogs/requirements.txt
@@ -6,21 +6,61 @@
 #
 -e file:.#egg=flytekitplugins-whylogs
     # via -r requirements.in
-flake8==4.0.1
+appnope==0.1.3
+    # via ipython
+asttokens==2.0.8
+    # via stack-data
+backcall==0.2.0
+    # via ipython
+decorator==5.1.1
+    # via ipython
+executing==1.0.0
+    # via stack-data
+ipython==8.5.0
     # via whylogs
-mccabe==0.6.1
-    # via flake8
-protobuf==3.20.1
+jedi==0.18.1
+    # via ipython
+matplotlib-inline==0.1.6
+    # via ipython
+numpy==1.23.3
+    # via scipy
+parso==0.8.3
+    # via jedi
+pexpect==4.8.0
+    # via ipython
+pickleshare==0.7.5
+    # via ipython
+prompt-toolkit==3.0.31
+    # via ipython
+protobuf==3.20.2
     # via
     #   flytekitplugins-whylogs
     #   whylogs
-pycodestyle==2.8.0
-    # via flake8
-pyflakes==2.4.0
-    # via flake8
+ptyprocess==0.7.0
+    # via pexpect
+pure-eval==0.2.2
+    # via stack-data
+pybars3==0.9.7
+    # via whylogs
+pygments==2.13.0
+    # via ipython
+pymeta3==0.5.1
+    # via pybars3
+scipy==1.9.1
+    # via whylogs
+six==1.16.0
+    # via asttokens
+stack-data==0.5.0
+    # via ipython
+traitlets==5.4.0
+    # via
+    #   ipython
+    #   matplotlib-inline
 typing-extensions==4.3.0
     # via whylogs
-whylogs==1.0.6
+wcwidth==0.2.5
+    # via prompt-toolkit
+whylogs[viz]==1.1.0
     # via flytekitplugins-whylogs
-whylogs-sketching==3.4.1.dev2
+whylogs-sketching==3.4.1.dev3
     # via whylogs
diff --git a/plugins/flytekit-whylogs/setup.py b/plugins/flytekit-whylogs/setup.py
index 54af3c474e..ce10e877f6 100644
--- a/plugins/flytekit-whylogs/setup.py
+++ b/plugins/flytekit-whylogs/setup.py
@@ -4,7 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-plugin_requires = ["protobuf>=3.15,<4.0.0", "whylogs", "whylogs[viz]"]
+plugin_requires = ["protobuf>=3.15,<4.0.0", "whylogs[viz]>=1.0.8"]
 
 __version__ = "0.0.0+develop"
 
diff --git a/plugins/flytekit-whylogs/tests/test_schema.py b/plugins/flytekit-whylogs/tests/test_schema.py
index 8fffae1c75..032d57c16f 100644
--- a/plugins/flytekit-whylogs/tests/test_schema.py
+++ b/plugins/flytekit-whylogs/tests/test_schema.py
@@ -1,21 +1,19 @@
 from datetime import datetime
+from typing import Type
 
 import pandas as pd
-import pytest
 import whylogs as why
+from flytekitplugins.whylogs.schema import WhylogsDatasetProfileTransformer
 from whylogs.core import DatasetProfileView
 
 from flytekit import task, workflow
-
-
-@pytest.fixture
-def input_data():
-    return pd.DataFrame({"a": [1, 2, 3, 4]})
+from flytekit.core.context_manager import FlyteContextManager
 
 
 @task
-def whylogs_profiling(data: pd.DataFrame) -> DatasetProfileView:
-    result = why.log(pandas=data)
+def whylogs_profiling() -> DatasetProfileView:
+    df = pd.DataFrame({"a": [1, 2, 3, 4]})
+    result = why.log(pandas=df)
     return result.view()
 
 
@@ -25,18 +23,27 @@ def fetch_whylogs_datetime(profile_view: DatasetProfileView) -> datetime:
 
 
 @workflow
-def whylogs_wf(data: pd.DataFrame) -> datetime:
-    profile_view = whylogs_profiling(data=data)
+def whylogs_wf() -> datetime:
+    profile_view = whylogs_profiling()
     return fetch_whylogs_datetime(profile_view=profile_view)
 
 
-def test_task_returns_whylogs_profile_view(input_data):
-    actual_profile = whylogs_profiling(data=input_data)
+def test_task_returns_whylogs_profile_view() -> None:
+    actual_profile = 
whylogs_profiling() assert actual_profile is not None assert isinstance(actual_profile, DatasetProfileView) -def test_profile_view_gets_passed_on_tasks(input_data): - result = whylogs_wf(data=input_data) +def test_profile_view_gets_passed_on_tasks() -> None: + result = whylogs_wf() assert result is not None assert isinstance(result, datetime) + + +def test_to_html_method() -> None: + tf = WhylogsDatasetProfileTransformer() + profile_view = whylogs_profiling() + report = tf.to_html(FlyteContextManager.current_context(), profile_view, Type[DatasetProfileView]) + + assert isinstance(report, str) + assert "Profile Visualizer" in report diff --git a/plugins/setup.py b/plugins/setup.py index 96072304d5..fe5c8c200d 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -13,6 +13,7 @@ "flytekitplugins-awssagemaker": "flytekit-aws-sagemaker", "flytekitplugins-bigquery": "flytekit-bigquery", "flytekitplugins-fsspec": "flytekit-data-fsspec", + "flytekitplugins-dbt": "flytekit-dbt", "flytekitplugins-dolt": "flytekit-dolt", "flytekitplugins-great_expectations": "flytekit-greatexpectations", "flytekitplugins-hive": "flytekit-hive", @@ -30,6 +31,7 @@ "flytekitplugins-snowflake": "flytekit-snowflake", "flytekitplugins-spark": "flytekit-spark", "flytekitplugins-sqlalchemy": "flytekit-sqlalchemy", + "flytekitplugins-whylogs": "flytekit-whylogs", } diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 6543d204fd..3c07cb3cee 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # make requirements-spark2.txt # @@ -8,7 +8,7 @@ # via # -r requirements-spark2.in # -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time attrs==20.3.0 # via @@ -16,11 +16,11 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot charset-normalizer==2.1.1 # via requests @@ -28,14 +28,16 @@ click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.0 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==38.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,47 +46,53 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==6.0.0 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.12 +flyteidl==1.3.0 # via flytekit -googleapis-common-protos==1.56.4 +googleapis-common-protos==1.57.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.51.1 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.51.1 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==5.1.0 # via # click # flytekit # jsonschema # keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 +joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.8.2 +keyring==23.11.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.1 +marshmallow==3.19.0 # via # 
dataclasses-json # marshmallow-enum @@ -93,9 +101,11 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +more-itertools==9.0.0 + # via jaraco-classes mypy-extensions==0.4.3 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via @@ -111,10 +121,9 @@ pandas==1.3.5 # via # -r requirements.in # flytekit -protobuf==3.20.1 +protobuf==4.21.10 # via # flyteidl - # flytekit # googleapis-common-protos # grpcio-status # protoc-gen-swagger @@ -122,15 +131,15 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.18.1 +pyrsistent==0.19.2 # via jsonschema python-dateutil==2.8.2 # via @@ -140,11 +149,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==6.1.2 +python-slugify==7.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.2.1 +pytz==2022.6 # via # flytekit # pandas @@ -153,7 +162,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.8.17 +regex==2022.10.31 # via docker-image-py requests==2.28.1 # via @@ -161,15 +170,16 @@ requests==2.28.1 # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.3 + # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio # jsonschema # python-dateutil # websocket-client @@ -179,7 +189,11 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.1 + # via responses +typing-extensions==4.4.0 # via # arrow # flytekit @@ -188,7 +202,7 @@ typing-extensions==4.3.0 # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.13 # via # docker # flytekit @@ -198,13 +212,13 @@ websocket-client==0.59.0 # via # -r requirements.in # docker -wheel==0.37.1 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.1 +zipp==3.11.0 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.in b/requirements.in index 0c3ba0378d..09828957fe 100644 --- a/requirements.in +++ b/requirements.in @@ -13,3 +13,5 @@ pandas<1.4.0 numpy<1.22.0 # This is required by docker-compose and otherwise clashes with docker-py websocket-client<1.0.0 +# TODO: Remove after buf migration is done and packages updated +packaging<22.0 diff --git a/requirements.txt b/requirements.txt index 32b4ae49a6..caff0db497 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # make requirements.txt # -e file:.#egg=flytekit # via -r requirements.in -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time attrs==20.3.0 # via @@ -14,11 +14,11 @@ attrs==20.3.0 # jsonschema binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot charset-normalizer==2.1.1 # via requests @@ -26,14 +26,16 @@ click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.0 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit 
-cryptography==37.0.4 - # via pyopenssl +cryptography==38.0.4 + # via + # pyopenssl + # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -42,47 +44,53 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==6.0.0 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.12 +flyteidl==1.3.0 # via flytekit -googleapis-common-protos==1.56.4 +googleapis-common-protos==1.57.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.51.1 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.51.1 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==5.1.0 # via # click # flytekit # jsonschema # keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 +joblib==1.2.0 # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.8.2 +keyring==23.11.0 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.1 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -91,9 +99,11 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +more-itertools==9.0.0 + # via jaraco-classes mypy-extensions==0.4.3 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit numpy==1.21.6 # via @@ -103,16 +113,16 @@ numpy==1.21.6 # pyarrow packaging==21.3 # via + # -r requirements.in # docker # marshmallow pandas==1.3.5 # via # -r requirements.in # flytekit -protobuf==3.20.1 +protobuf==4.21.10 # via # flyteidl - # flytekit # googleapis-common-protos # grpcio-status # protoc-gen-swagger @@ -120,15 +130,15 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 # via packaging -pyrsistent==0.18.1 +pyrsistent==0.19.2 # via jsonschema python-dateutil==2.8.2 # via @@ -138,11 +148,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==6.1.2 +python-slugify==7.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.2.1 +pytz==2022.6 # via # flytekit # pandas @@ -151,7 +161,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.8.17 +regex==2022.10.31 # via docker-image-py requests==2.28.1 # via @@ -159,15 +169,16 @@ requests==2.28.1 # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.3 + # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio # jsonschema # python-dateutil # websocket-client @@ -177,7 +188,11 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.1 + # via responses +typing-extensions==4.4.0 # via # arrow # flytekit @@ -186,7 +201,7 @@ typing-extensions==4.3.0 # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.13 # via # docker # flytekit @@ -196,13 +211,13 @@ websocket-client==0.59.0 # via # -r requirements.in # docker -wheel==0.37.1 +wheel==0.38.4 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.1 +zipp==3.11.0 # via importlib-metadata # The following packages are considered to be unsafe 
in a requirements file: diff --git a/setup.py b/setup.py index e5bc3cfa33..94499f67eb 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,8 @@ maintainer="Flyte Contributors", maintainer_email="admin@flyte.org", packages=find_packages( - include=["flytekit", "flytekit_scripts", "plugins"], - exclude=["boilerplate", "docs", "tests*"], + include=["flytekit", "flytekit_scripts"], + exclude=["boilerplate", "docs", "plugins", "tests*"], ), include_package_data=True, url="https://github.com/flyteorg/flytekit", @@ -39,21 +39,22 @@ ] }, install_requires=[ - "flyteidl>=1.1.3,<1.2.0", + "flyteidl>=1.3.0,<1.4.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", - "pyarrow>=4.0.0,<7.0.0", + "pyarrow>=4.0.0,<11.0.0", "click>=6.6,<9.0", "croniter>=0.3.20,<4.0.0", "deprecated>=1.0,<2.0", "docker>=5.0.3,<7.0.0", "python-dateutil>=2.1", - "grpcio>=1.43.0,!=1.45.0,<2.0", - "grpcio-status>=1.43,!=1.45.0", + # Restrict grpcio and grpcio-status. Version 1.50.0 pulls in a version of protobuf that is not compatible + # with the old protobuf library (as described in https://developers.google.com/protocol-buffers/docs/news/2022-05-06) + "grpcio>=1.50.0,<2.0", + "grpcio-status>=1.50.0,<2.0", "importlib-metadata", "pyopenssl", "joblib", - "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", "pytz", diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 57e35e64b3..b8a781224a 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -1,18 +1,18 @@ # -# This file is autogenerated by pip-compile with python 3.7 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.7 +# by the following command: # # make tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt # -arrow==1.2.2 +arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter -certifi==2022.6.15 +certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -chardet==5.0.0 +chardet==5.1.0 # via binaryornot charset-normalizer==2.1.1 # via requests @@ -20,14 +20,16 @@ click==8.1.3 # via # cookiecutter # flytekit -cloudpickle==2.1.0 +cloudpickle==2.2.0 # via flytekit cookiecutter==2.1.1 # via flytekit -croniter==1.3.5 +croniter==1.3.8 # via flytekit -cryptography==37.0.4 - # via pyopenssl +cryptography==38.0.4 + # via + # pyopenssl + # secretstorage cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 @@ -38,50 +40,58 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.1 # via flytekit docker-image-py==0.1.12 # via flytekit -docstring-parser==0.14.1 +docstring-parser==0.15 # via flytekit -flyteidl==1.1.12 +flyteidl==1.2.5 # via flytekit -flytekit==1.1.1 +flytekit==1.2.5 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.37.1 +fonttools==4.38.0 # via matplotlib -googleapis-common-protos==1.56.4 +googleapis-common-protos==1.57.0 # via # flyteidl # grpcio-status -grpcio==1.47.0 +grpcio==1.48.2 # via # flytekit # grpcio-status -grpcio-status==1.47.0 +grpcio-status==1.48.2 # via flytekit -idna==3.3 +idna==3.4 # via requests -importlib-metadata==4.12.0 +importlib-metadata==5.1.0 # via # click # flytekit # keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time 
jinja2-time==0.2.0 # via cookiecutter -joblib==1.1.0 - # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -keyring==23.8.2 +joblib==1.2.0 + # via + # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in + # flytekit +keyring==23.11.0 # via flytekit kiwisolver==1.4.4 # via matplotlib markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.1 +marshmallow==3.19.0 # via # dataclasses-json # marshmallow-enum @@ -92,11 +102,13 @@ marshmallow-jsonschema==0.13.0 # via flytekit matplotlib==3.5.3 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in +more-itertools==9.0.0 + # via jaraco-classes mypy-extensions==0.4.3 # via typing-inspect -natsort==8.1.0 +natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.22.0 # via # flytekit # matplotlib @@ -107,13 +119,14 @@ opencv-python==4.6.0.66 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in packaging==21.3 # via + # docker # marshmallow # matplotlib pandas==1.3.5 # via flytekit -pillow==9.2.0 +pillow==9.3.0 # via matplotlib -protobuf==3.20.1 +protobuf==3.20.3 # via # flyteidl # flytekit @@ -124,11 +137,11 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -pyarrow==6.0.1 +pyarrow==10.0.1 # via flytekit pycparser==2.21 # via cffi -pyopenssl==22.0.0 +pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 # via @@ -143,11 +156,11 @@ python-dateutil==2.8.2 # pandas python-json-logger==2.0.4 # via flytekit -python-slugify==6.1.2 +python-slugify==7.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.2.1 +pytz==2022.6 # via # flytekit # pandas @@ -155,7 +168,7 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.8.17 +regex==2022.10.31 # via docker-image-py requests==2.28.1 # via @@ -163,10 +176,12 @@ requests==2.28.1 # docker # flytekit # responses -responses==0.21.0 +responses==0.22.0 # via flytekit retry==0.9.2 # via flytekit +secretstorage==3.3.3 + # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -179,7 +194,11 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -typing-extensions==4.3.0 +toml==0.10.2 + # via responses +types-toml==0.10.8.1 + # via responses +typing-extensions==4.4.0 # via # arrow # flytekit @@ -189,14 +208,15 @@ typing-extensions==4.3.0 # typing-inspect typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.12 +urllib3==1.26.13 # via + # docker # flytekit # requests # responses -websocket-client==1.4.0 +websocket-client==1.4.2 # via docker -wheel==0.37.1 +wheel==0.38.4 # via # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in # flytekit @@ -204,5 +224,5 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.1 +zipp==3.11.0 # via importlib-metadata diff --git a/tests/flytekit/unit/cli/pyflyte/test_init.py b/tests/flytekit/unit/cli/pyflyte/test_init.py new file mode 100644 index 0000000000..0a66433625 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_init.py @@ -0,0 +1,26 @@ +import tempfile + +import pytest +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte + + +@pytest.mark.parametrize( + "command", + [ + ["example"], + ["example", "--template", "simple-example"], + ["example", "--template", "bayesian-optimization"], + ], +) +def test_pyflyte_init(command, monkeypatch: pytest.MonkeyPatch): + tmp_dir = tempfile.mkdtemp() + monkeypatch.chdir(tmp_dir) + runner = CliRunner() + result = runner.invoke( + pyflyte.init, + command, + catch_exceptions=True, + ) + assert 
result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 364b6b14d9..40d63021f2 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -3,9 +3,6 @@ import pytest from click.testing import CliRunner -from flyteidl.admin.launch_plan_pb2 import LaunchPlan -from flyteidl.admin.task_pb2 import TaskSpec -from flyteidl.admin.workflow_pb2 import WorkflowSpec import flytekit import flytekit.configuration @@ -13,6 +10,9 @@ from flytekit.clis.sdk_in_container import pyflyte from flytekit.core import context_manager from flytekit.exceptions.user import FlyteValidationException +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.launch_plan import LaunchPlan +from flytekit.models.task import TaskSpec sample_file_contents = """ from flytekit import task, workflow diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index d078851e1b..4951d4be46 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,6 +8,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.core import context_manager from flytekit.remote.remote import FlyteRemote sample_file_contents = """ @@ -49,17 +50,57 @@ def test_register_with_no_package_or_module_argument(): @mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_with_no_output_dir_passed(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core1", exist_ok=True) + with open(os.path.join("core1", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core1"]) + assert "Successfully registered 4 entities" in result.output + shutil.rmtree("core1") + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_non_fast_register(mock_client, mock_remote): + mock_remote._client = mock_client + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core2", exist_ok=True) + with open(os.path.join("core2", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "--non-fast", "--version", "a-version", "core2"]) + assert "Successfully registered 4 entities" in result.output + shutil.rmtree("core2") + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def 
test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" runner = CliRunner() + context_manager.FlyteEntities.entities.clear() with runner.isolated_filesystem(): out = subprocess.run(["git", "init"], capture_output=True) assert out.returncode == 0 - os.makedirs("core", exist_ok=True) - with open(os.path.join("core", "sample.py"), "w") as f: + os.makedirs("core3", exist_ok=True) + with open(os.path.join("core3", "sample.py"), "w") as f: f.write(sample_file_contents) f.close() - result = runner.invoke(pyflyte.main, ["register", "core"]) - assert "Output given as None, using a temporary directory at" in result.output - shutil.rmtree("core") + result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"]) + assert result.exit_code == 1 + assert str(result.exception) == "Version is a required parameter in case --non-fast is specified." + shutil.rmtree("core3") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index b7ac80cce4..d5db7296b9 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -78,6 +78,8 @@ def test_pyflyte_run_cli(): "20H", "--k", "RED", + "--l", + '{"hello": "world"}', "--remote", os.path.join(DIR_NAME, "testdata"), "--image", diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 18d29f648d..71d4e29a7c 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -54,8 +54,9 @@ def print_all( i: datetime.datetime, j: datetime.timedelta, k: Color, + l: dict, ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}") @task @@ -81,11 +82,12 @@ def my_wf( i: datetime.datetime, j: datetime.timedelta, k: Color, + l: dict, remote: pd.DataFrame, image: StructuredDataset, ) -> Annotated[StructuredDataset, subset_cols]: x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l) return x diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 6636c1e0c8..b3f1807b96 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -1,13 +1,19 @@ import json from subprocess import CompletedProcess +import grpc import mock import pytest from flyteidl.admin import project_pb2 as _project_pb2 from flyteidl.service import auth_pb2 from mock import MagicMock, patch -from flytekit.clients.raw import RawSynchronousFlyteClient, get_basic_authorization_header, get_token +from flytekit.clients.raw import ( + RawSynchronousFlyteClient, + _handle_invalid_create_request, + get_basic_authorization_header, + get_token, +) from flytekit.configuration import AuthType, PlatformConfig from flytekit.configuration.internal import Credentials @@ -211,3 +217,24 @@ def test_refresh_from_environment_variable(mocked_method, monkeypatch: pytest.Mo cc = RawSynchronousFlyteClient(PlatformConfig(auth_mode=None).auto(None)) cc.refresh_credentials() assert mocked_method.called + + +def 
test__handle_invalid_create_request_decorator_happy(): + client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) + mocked_method = client._stub.CreateWorkflow = mock.Mock() + _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow")) + mocked_method.assert_called_once() + + +@patch("flytekit.clients.raw.cli_logger") +@patch("flytekit.clients.raw._MessageToJson") +def test__handle_invalid_create_request_decorator_raises(mock_to_JSON, mock_logger): + mock_to_JSON(return_value="test") + err = grpc.RpcError() + err.details = "There is already a workflow with different structure." + err.code = lambda: grpc.StatusCode.INVALID_ARGUMENT + client = RawSynchronousFlyteClient(PlatformConfig(auth_mode=AuthType.CLIENT_CREDENTIALS)) + client._stub.CreateWorkflow = mock.Mock(side_effect=err) + with pytest.raises(grpc.RpcError): + _handle_invalid_create_request(client.create_workflow("/flyteidl.service.AdminService/CreateWorkflow")) + mock_logger.error.assert_called_with("There is already a workflow with different structure.") diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml index 9c1ad83a3e..7da41b7c38 100644 --- a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml +++ b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml @@ -1,7 +1,7 @@ admin: # For GRPC endpoints you might want to use dns:///flyte.myexample.com endpoint: dns:///flyte.mycorp.io - clientSecretLocation: ../tests/flytekit/unit/configuration/configs/fake_secret + clientSecretLocation: configs/fake_secret authType: Pkce insecure: true clientId: propeller diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 6ba81f309c..7f6be53a55 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import get_config_file, read_file_if_exists +from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -31,7 +31,20 @@ def test_client_secret_location(): os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_location.yaml") ) secret_location = Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(cfg) - assert secret_location == "../tests/flytekit/unit/configuration/configs/fake_secret" + assert secret_location == "configs/fake_secret" + + # Modify the path to the secret inline + cfg._yaml_config["admin"]["clientSecretLocation"] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs/fake_secret" + ) + + # Assert secret contains a newline + with open(cfg._yaml_config["admin"]["clientSecretLocation"], "rb") as f: + assert f.read().decode().endswith("\n") is True + + # Assert that secret in platform config does not contain a newline + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "hello" def test_read_file_if_exists(): diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index 0199737335..f1dbbbd5ef 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -83,6 +83,16 @@ def test_sync_checkpoint_restore_default_path(tmpdir): assert cp.restore() == cp._prev_download_path +def 
test_sync_checkpoint_read_empty_dir(tmpdir): + td_path = Path(tmpdir) + dest = td_path.joinpath("dest") + dest.mkdir() + src = td_path.joinpath("src") + src.mkdir() + cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) + assert cp.read() is None + + def test_sync_checkpoint_read_multiple_files(tmpdir): """ Read can only work with one file. diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 668ca97dfd..cccf406c71 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -8,9 +8,11 @@ from flytekit.core import context_manager from flytekit.core.context_manager import ExecutionState from flytekit.core.node_creation import create_node +from flytekit.core.resources import Resources from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.literals import LiteralMap settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -153,3 +155,105 @@ def wf() -> str: return dynamic_wf() assert wf() == "hello" + + +def test_dynamic_local_rshift(): + @task + def task1(s: str) -> str: + return s + + @task + def task2(s: str) -> str: + return s + + @dynamic + def dynamic_wf() -> str: + to1 = task1(s="hello").with_overrides(requests=Resources(cpu="3", mem="5Gi")) + to2 = task2(s="world") + to1 >> to2 # noqa + + return to1 + + @workflow + def wf() -> str: + return dynamic_wf() + + assert wf() == "hello" + + with context_manager.FlyteContextManager.with_context( + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) + ) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + ) + ) + ) as ctx: + dynamic_job_spec = dynamic_wf.dispatch_execute(ctx, LiteralMap(literals={})) + assert dynamic_job_spec.nodes[1].upstream_node_ids == ["dn0"] + assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[0].value == "3" + assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[1].value == "5Gi" + + +def test_dynamic_return_dict(): + @dynamic + def t1(v: str) -> typing.Dict[str, str]: + return {"a": v} + + @dynamic + def t2(v: str) -> typing.Dict[str, typing.Dict[str, str]]: + return {"a": {"b": v}} + + @dynamic + def t3(v: str) -> (str, typing.Dict[str, typing.Dict[str, str]]): + return v, {"a": {"b": v}} + + @workflow + def wf(): + t1(v="a") + t2(v="b") + t3(v="c") + + wf() + + +def test_nested_dynamic_locals(): + @task + def t1(a: int) -> str: + a = a + 2 + return "fast-" + str(a) + + @task + def t2(b: str) -> str: + return f"In t2 string is {b}" + + @task + def t3(b: str) -> str: + return f"In t3 string is {b}" + + @workflow() + def normalwf(a: int) -> str: + x = t1(a=a) + return x + + @dynamic + def dt(ss: str) -> typing.List[str]: + if ss == "hello": + bb = t2(b=ss) + bbb = t3(b=bb) + else: + bb = t2(b=ss + "hi again") + bbb = "static" + return [bb, bbb] + + @workflow + def wf(wf_in: str) -> typing.List[str]: + x = dt(ss=wf_in) + return x + + res = wf(wf_in="hello") + assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"] + + res = dt(ss="hello") + assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"] diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 64bd195db8..e2123222e0 100644 --- 
a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -164,7 +164,7 @@ def my_wf() -> FlyteFile: # This second layer should have two dirs, a random one generated by the new_execution_context call # and an empty folder, created by FlyteFile transformer's to_python_value function. This folder will have # something in it after we open() it. - assert len(working_dir) == 2 + assert len(working_dir) == 1 assert not os.path.exists(workflow_output.path) # # The act of opening it should trigger the download, since we do lazy downloading. @@ -214,7 +214,7 @@ def my_wf() -> FlyteFile: # The act of running the workflow should create the engine dir, and the directory that will contain the # file but the file itself isn't downloaded yet. working_dir = os.listdir(os.path.join(random_dir, "local_flytekit")) - assert len(working_dir) == 2 # local flytekit and the downloaded file + assert len(working_dir) == 1 # local flytekit and the downloaded file assert not os.path.exists(workflow_output.path) # # The act of opening it should trigger the download, since we do lazy downloading. @@ -224,7 +224,7 @@ def my_wf() -> FlyteFile: # and an empty folder, created by FlyteFile transformer's to_python_value function. This folder will have # something in it after we open() it. working_dir = os.listdir(os.path.join(random_dir, "local_flytekit")) - assert len(working_dir) == 3 # local flytekit and the downloaded file + assert len(working_dir) == 2 # local flytekit and the downloaded file assert os.path.exists(workflow_output.path) diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 62c45c346e..442851a8a2 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -2,6 +2,8 @@ import typing from typing import Dict, List +from typing_extensions import Annotated # type: ignore + from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -16,11 +18,6 @@ from flytekit.types.file import FlyteFile from flytekit.types.pickle import FlytePickle -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - def test_extract_only(): def x() -> typing.NamedTuple("NT1", x_str=str, y_int=int): diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 674f6176e1..c3baebeace 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -4,6 +4,8 @@ from typing import Dict, List import pandas +import pandas as pd +import pytest from dataclasses_json import dataclass_json from pytest import fixture from typing_extensions import Annotated @@ -261,9 +263,9 @@ def my_wf(a: int, b: str) -> (int, str): assert n_cached_task_calls == 2 -def test_set_integer_literal_hash_is_not_cached(): +def test_set_integer_literal_hash_is_cached(): """ - Test to confirm that the local cache is not set in the case of integers, even if we + Test to confirm that the local cache is set in the case of integers, even if we return an annotated integer. In order to make this very explicit, we define a constant hash function, i.e. the same value is returned by it regardless of the input. 
""" @@ -289,13 +291,13 @@ def wf(a: int) -> int: assert n_cached_task_calls == 0 assert wf(a=3) == 3 assert n_cached_task_calls == 1 - # Confirm that the value is not cached, even though we set a hash function that - # returns a constant value and that the task has only one input. - assert wf(a=2) == 2 - assert n_cached_task_calls == 2 + # Confirm that the value is cached due to the fact the hash value is constant, regardless + # of the value passed to the cacheable task. + assert wf(a=2) == 3 + assert n_cached_task_calls == 1 # Confirm that the cache is hit if we execute the workflow with the same value as previous run. - assert wf(a=2) == 2 - assert n_cached_task_calls == 2 + assert wf(a=2) == 3 + assert n_cached_task_calls == 1 def test_pass_annotated_to_downstream_tasks(): @@ -426,6 +428,14 @@ def test_stable_cache_key(): "b": "abcd", "c": 0.12349, "d": [1, 2, 3], + "e": { + "e_a": 11, + "e_b": list(range(1000)), + "e_c": { + "e_c_a": 12.34, + "e_c_b": "a string", + }, + }, } lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) lm = LiteralMap( @@ -437,4 +447,39 @@ def test_stable_cache_key(): } ) key = _calculate_cache_key("task_name_1", "31415", lm) - assert key == "task_name_1-31415-a291dc6fe0be387c1cfd67b4c6b78259" + assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049" + + +def calculate_cache_key_multiple_times(x, n=1000): + series = pd.Series( + [ + _calculate_cache_key( + task_name="task_name", + cache_version="cache_version", + input_literal_map=LiteralMap( + literals={ + "d": TypeEngine.to_literal( + ctx=FlyteContextManager.current_context(), + expected=TypeEngine.to_literal_type(Dict), + python_type=Dict, + python_val=x, + ), + } + ), + ) + for _ in range(n) + ] + ).value_counts() + return series + + +@pytest.mark.parametrize( + "d", + [ + dict(a=1, b=2, c=3), + dict(x=dict(a=1, b=2, c=3)), + dict(xs=[dict(a=1, b=2, c=3), dict(y=dict(a=10, b=20, c=30))]), + ], +) +def test_cache_key_consistency(d): + assert len(calculate_cache_key_multiple_times(d)) == 1 diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index f6dc9c9ba5..47c8af9830 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -87,6 +87,7 @@ def empty_wf2(): wf_spec = get_serializable(OrderedDict(), serialization_settings, empty_wf2) assert wf_spec.template.nodes[0].upstream_node_ids[0] == "n1" assert wf_spec.template.nodes[0].id == "n0" + assert wf_spec.template.nodes[0].metadata.name == "t2" with pytest.raises(FlyteAssertion): @@ -194,7 +195,12 @@ def wf(x: int) -> str: c >> a return b + @workflow + def wf1(x: int): + task_a(x=x) >> task_b(x=x) >> task_c(x=x) + wf(x=3) + wf1(x=3) def test_resource_request_override(): @@ -329,7 +335,8 @@ def my_wf(a: str) -> str: @pytest.mark.parametrize( - "retries,expected", [(None, _literal_models.RetryStrategy(0)), (3, _literal_models.RetryStrategy(3))] + "retries,expected", + [(None, _literal_models.RetryStrategy(0)), (3, _literal_models.RetryStrategy(3))], ) def test_retries_override(retries, expected): @task @@ -396,3 +403,24 @@ def my_wf(a: str): _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"), _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"), ] + + +def test_name_override(): + @task + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(name="foo") + + serialization_settings = 
flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].metadata.name == "foo" diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 6354ca3247..23b3de4573 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -7,7 +7,12 @@ from flytekit import task from flytekit.core import context_manager from flytekit.core.context_manager import CompilationState -from flytekit.core.promise import VoidPromise, create_and_link_node, translate_inputs_to_literals +from flytekit.core.promise import ( + VoidPromise, + create_and_link_node, + create_and_link_node_from_remote, + translate_inputs_to_literals, +) from flytekit.exceptions.user import FlyteAssertion @@ -27,7 +32,7 @@ def t1(a: typing.Union[int, typing.List[int]]) -> typing.Union[int, typing.List[ assert len(p.ref.node.bindings) == 1 @task - def t2(a: typing.Optional[int] = None) -> typing.Union[int]: + def t2(a: typing.Optional[int] = None) -> typing.Optional[int]: return a p = create_and_link_node(ctx, t2) @@ -35,6 +40,30 @@ def t2(a: typing.Optional[int] = None) -> typing.Union[int]: assert len(p.ref.node.bindings) == 0 +def test_create_and_link_node_from_remote(): + @task + def t1() -> None: + ... + + with pytest.raises(FlyteAssertion, match="Cannot create node when not compiling..."): + ctx = context_manager.FlyteContext.current_context() + create_and_link_node_from_remote(ctx, t1, a=3) + + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + p = create_and_link_node_from_remote(ctx, t1) + assert p.ref.node_id == "n0" + assert p.ref.var == "placeholder" + assert len(p.ref.node.bindings) == 0 + + @task + def t2(a: int) -> int: + return a + + p = create_and_link_node_from_remote(ctx, t2, a=3) + assert p.ref.var == "o0" + assert len(p.ref.node.bindings) == 1 + + @pytest.mark.parametrize( "input", [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], @@ -73,11 +102,17 @@ def t1(a: typing.Union[float, typing.Dict[str, int]]): translate_inputs_to_literals(ctx, {"a": [1, 2, 3]}, t1.interface.inputs, t1.python_interface.inputs) with pytest.raises( - AssertionError, match="Outputs of a non-output producing task n0 cannot be passed to another task" + AssertionError, + match="Outputs of a non-output producing task n0 cannot be passed to another task", ): @task def t1(a: typing.Union[float, typing.Dict[str, int]]): print(a) - translate_inputs_to_literals(ctx, {"a": VoidPromise("n0")}, t1.interface.inputs, t1.python_interface.inputs) + translate_inputs_to_literals( + ctx, + {"a": VoidPromise("n0")}, + t1.interface.inputs, + t1.python_interface.inputs, + ) diff --git a/tests/flytekit/unit/core/test_realworld_examples.py b/tests/flytekit/unit/core/test_realworld_examples.py index 6f2da3efd0..83e859c1da 100644 --- a/tests/flytekit/unit/core/test_realworld_examples.py +++ b/tests/flytekit/unit/core/test_realworld_examples.py @@ -2,11 +2,12 @@ from collections import OrderedDict import pandas as pd +from typing_extensions import Annotated from flytekit import Resources from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.types.file.file import FlyteFile +from flytekit.types.file import 
FileExt, FlyteFile from flytekit.types.schema import FlyteSchema @@ -44,7 +45,7 @@ def test_diabetes(): # the last column is the class CLASSES_COLUMNS = OrderedDict({"class": int}) - MODELSER_JOBLIB = typing.TypeVar("joblib.dat") + MODELSER_JOBLIB = Annotated[str, FileExt("joblib.dat")] class XGBoostModelHyperparams(object): """ @@ -80,12 +81,12 @@ def from_dict(cls, d): @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def split_traintest_dataset( dataset: FlyteFile[typing.TypeVar("csv")], seed: int, test_split_ratio: float - ) -> ( + ) -> typing.Tuple[ FlyteSchema[FEATURE_COLUMNS], FlyteSchema[FEATURE_COLUMNS], FlyteSchema[CLASSES_COLUMNS], FlyteSchema[CLASSES_COLUMNS], - ): + ]: """ Retrieves the training dataset from the given blob location and then splits it using the split ratio and returns the result This splitter is only for the dataset that has the format as specified in the example csv. The last column is assumed to be diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index a69ec6665b..a96a94843b 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -32,6 +32,7 @@ def test_serialization(): inputs=kwtypes(val=int), outputs=kwtypes(out=int), image="alpine", + environment={"a": "b"}, command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"], ) diff --git a/tests/flytekit/unit/core/test_shim_task.py b/tests/flytekit/unit/core/test_shim_task.py index 2cd8d3e668..76188bf557 100644 --- a/tests/flytekit/unit/core/test_shim_task.py +++ b/tests/flytekit/unit/core/test_shim_task.py @@ -4,7 +4,7 @@ import mock import flytekit.configuration -from flytekit import ContainerTask, kwtypes +from flytekit import ContainerTask, Resources, kwtypes from flytekit.configuration import Image, ImageConfig from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask, TaskTemplateResolver from flytekit.core.utils import write_proto_to_file @@ -54,10 +54,17 @@ def test_serialize_to_model(mock_custom, mock_config): mock_custom.return_value = {"a": "custom"} mock_config.return_value = {"a": "config"} ct = PythonCustomizedContainerTask( - name="mytest", task_config=None, container_image="someimage", executor_type=Placeholder + name="mytest", + task_config=None, + container_image="someimage", + executor_type=Placeholder, + requests=Resources(ephemeral_storage="200Mi"), + limits=Resources(ephemeral_storage="300Mi"), ) tt = ct.serialize_to_model(serialization_settings) assert tt.container.image == "someimage" assert len(tt.config) == 1 assert tt.id.name == "mytest" assert len(tt.custom) == 1 + assert tt.container.resources.requests[0].value == "200Mi" + assert tt.container.resources.limits[0].value == "300Mi" diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 0c450bb913..bfb41d0fef 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -7,11 +7,13 @@ from typing_extensions import Annotated import flytekit.configuration -from flytekit import kwtypes, task from flytekit.configuration import Image, ImageConfig -from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.base_task import kwtypes +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import 
FileAccessProvider +from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine +from flytekit.core.workflow import workflow from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType @@ -38,6 +40,7 @@ image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")), env={}, ) +df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) def test_protocol(): @@ -49,6 +52,56 @@ def generate_pandas() -> pd.DataFrame: return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]}) +def test_formats_make_sense(): + @task + def t1(a: pd.DataFrame) -> pd.DataFrame: + print(a) + return generate_pandas() + + # this should be an empty string format + assert t1.interface.outputs["o0"].type.structured_dataset_type.format == "" + assert t1.interface.inputs["a"].type.structured_dataset_type.format == "" + + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) + ) + ): + result = t1(a=generate_pandas()) + val = result.val.scalar.value + assert val.metadata.structured_dataset_type.format == "parquet" + + +def test_setting_of_unset_formats(): + + custom = Annotated[StructuredDataset, "parquet"] + example = custom(dataframe=df, uri="/path") + # It's okay that the annotation is not used here yet. + assert example.file_format == "" + + @task + def t2(path: str) -> StructuredDataset: + sd = StructuredDataset(dataframe=df, uri=path) + return sd + + @workflow + def wf(path: str) -> StructuredDataset: + return t2(path=path) + + res = wf(path="/tmp/somewhere") + # Now that it's passed through an encoder however, it should be set. + assert res.file_format == "parquet" + + +def test_json(): + sd = StructuredDataset(dataframe=df, uri="/some/path") + sd.file_format = "myformat" + json_str = sd.to_json() + new_sd = StructuredDataset.from_json(json_str) + assert new_sd.file_format == "myformat" + + def test_types_pandas(): pt = pd.DataFrame lt = TypeEngine.to_literal_type(pt) @@ -119,6 +172,7 @@ def test_types_sd(): def test_retrieving(): assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None + # Asking for a generic means you're okay with any one registered for that type assuming there's just one. 
assert StructuredDatasetTransformerEngine.get_encoder( pd.DataFrame, "file", "" ) is StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) @@ -209,7 +263,7 @@ def encode( StructuredDatasetTransformerEngine.register(empty_format_temp_encoder, default_for_type=False) res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "tmpfs", "rando") - assert res is default_encoder + assert res is empty_format_temp_encoder def test_slash_register(): @@ -233,7 +287,7 @@ def encode( def test_sd(): sd = StructuredDataset(dataframe="hi") sd.uri = "my uri" - assert sd.file_format == PARQUET + assert sd.file_format == "" with pytest.raises(ValueError, match="No dataframe type set"): sd.all() @@ -388,9 +442,8 @@ def encode( assert df_literal_type.structured_dataset_type.format == "avro" sd = annotated_sd_type(df) - lt = TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type) - # We haven't registered avro encoder, so here use default parquet encoder instead - assert lt.scalar.structured_dataset.metadata.structured_dataset_type.format == PARQUET + with pytest.raises(ValueError, match="Failed to find a handler"): + TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type) StructuredDatasetTransformerEngine.register(TempEncoder(), default_for_type=False) sd2 = annotated_sd_type(df) diff --git a/tests/flytekit/unit/core/test_type_delayed.py b/tests/flytekit/unit/core/test_type_delayed.py index 268bf64285..3e6824788d 100644 --- a/tests/flytekit/unit/core/test_type_delayed.py +++ b/tests/flytekit/unit/core/test_type_delayed.py @@ -4,16 +4,12 @@ from dataclasses import dataclass from dataclasses_json import dataclass_json +from typing_extensions import Annotated # type: ignore from flytekit.core import context_manager from flytekit.core.interface import transform_function_to_interface, transform_inputs_to_parameters from flytekit.core.type_engine import TypeEngine -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - @dataclass_json @dataclass diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0fe2513908..3e813c0fb7 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -47,7 +47,7 @@ from flytekit.models.types import LiteralType, SimpleType, TypeStructure from flytekit.types.directory import TensorboardLogs from flytekit.types.directory.types import FlyteDirectory -from flytekit.types.file import JPEGImageFile +from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle from flytekit.types.pickle.pickle import FlytePickleTransformer @@ -144,6 +144,31 @@ def test_list_of_dict_getting_python_value(): assert isinstance(pv, list) +def test_list_of_single_dataclass(): + @dataclass_json + @dataclass() + class Bar(object): + v: typing.Optional[typing.List[int]] + w: typing.Optional[typing.List[float]] + + @dataclass_json + @dataclass() + class Foo(object): + a: typing.Optional[typing.List[str]] + b: Bar + + foo = Foo(a=["abc", "def"], b=Bar(v=[1, 2, 99], w=[3.1415, 2.7182])) + generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct()) + lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) + + transformer = TypeEngine.get_transformer(typing.List) + ctx = 
+    transformer = TypeEngine.get_transformer(typing.List)
+    ctx = FlyteContext.current_context()
+
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
+    assert pv[0].a == ["abc", "def"]
+    assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182])
+
+
 def test_list_of_dataclass_getting_python_value():
     @dataclass_json
     @dataclass()
@@ -437,8 +462,8 @@ class TestStruct(object):
 class TestStructB(object):
     s: InnerStruct
     m: typing.Dict[int, str]
-    n: typing.List[typing.List[int]] = None
-    o: typing.Dict[int, typing.Dict[int, int]] = None
+    n: typing.Optional[typing.List[typing.List[int]]] = None
+    o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None


 @dataclass_json
@@ -1004,9 +1029,9 @@ def to_literal(
         return Literal(scalar=Scalar(primitive=Primitive(integer=python_val)))

-    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[T]) -> T:
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[T]) -> Literal:
         val = lv.scalar.primitive.integer
-        return UnsignedInt(0 if val < 0 else val)
+        return UnsignedInt(0 if val < 0 else val)  # type: ignore


 TypeEngine.register(UnsignedIntTransformer())
@@ -1254,17 +1279,16 @@ def t1(a: int) -> int:
     assert t1(a=3) == 9


-def test_literal_hash_int_not_set():
+def test_literal_hash_int_can_be_set():
     """
-    Test to confirm that annotating an integer with `HashMethod` does not force the literal to have its
-    hash set.
+    Test to confirm that annotating an integer with `HashMethod` is allowed.
     """
     ctx = FlyteContext.current_context()
     lv = TypeEngine.to_literal(
         ctx, 42, Annotated[int, HashMethod(str)], LiteralType(simple=model_types.SimpleType.INTEGER)
     )
     assert lv.scalar.primitive.integer == 42
-    assert lv.hash is None
+    assert lv.hash == "42"


 def test_literal_hash_to_python_value():
@@ -1399,3 +1423,56 @@ def hello(self):
     lr = LiteralsResolver(lit_dict)
     assert lr.get("a", Foo) == foo
     assert hasattr(lr.get("a", Foo), "hello") is True
+
+
+def test_flyte_dir_in_union():
+    pt = typing.Union[str, FlyteDirectory, FlyteFile]
+    lt = TypeEngine.to_literal_type(pt)
+    ctx = FlyteContext.current_context()
+    tf = UnionTransformer()
+
+    pv = tempfile.mkdtemp(prefix="flyte-")
+    lv = tf.to_literal(ctx, FlyteDirectory(pv), pt, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt)
+    assert ot is not None
+
+    pv = "s3://bucket/key"
+    lv = tf.to_literal(ctx, FlyteFile(pv), pt, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt)
+    assert ot is not None
+
+    pv = "hello"
+    lv = tf.to_literal(ctx, pv, pt, lt)
+    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=pt)
+    assert ot == "hello"
+
+
+def test_file_ext_with_flyte_file_existing_file():
+    assert JPEGImageFile.extension() == "jpeg"
+
+
+def test_file_ext_convert_static_method():
+    TAR_GZ = Annotated[str, FileExt("tar.gz")]
+    item = FileExt.check_and_convert_to_str(TAR_GZ)
+    assert item == "tar.gz"
+
+    str_item = FileExt.check_and_convert_to_str("csv")
+    assert str_item == "csv"
+
+
+def test_file_ext_with_flyte_file_new_file():
+    TAR_GZ = Annotated[str, FileExt("tar.gz")]
+    flyte_file = FlyteFile[TAR_GZ]
+    assert flyte_file.extension() == "tar.gz"
+
+
+class WrongType:
+    def __init__(self, num: int):
+        self.num = num
+
+
+def test_file_ext_with_flyte_file_wrong_type():
+    WRONG_TYPE = Annotated[int, WrongType(2)]
+    with pytest.raises(ValueError) as e:
+        FlyteFile[WRONG_TYPE]
+    assert str(e.value) == "Underlying type of File Extension must be of type <str>"
diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py
index abdf69f5b0..b6d2d77ae5 100644
--- a/tests/flytekit/unit/core/test_type_hints.py
+++ b/tests/flytekit/unit/core/test_type_hints.py
@@ -1613,14 +1613,14 @@ def foo(a: int, b: str) -> typing.Tuple[int, str]:

     @task
     def foo2(a: int, b: str) -> typing.Tuple[int, str]:
-        return "hello", 10
+        return "hello", 10  # type: ignore

     @task
     def foo3(a: typing.Dict) -> typing.Dict:
         return a

     with pytest.raises(TypeError, match="Type of Val 'hello' is not an instance of "):
-        foo(a="hello", b=10)
+        foo(a="hello", b=10)  # type: ignore

     with pytest.raises(
         TypeError,
@@ -1629,7 +1629,7 @@ def foo3(a: typing.Dict) -> typing.Dict:
         foo2(a=10, b="hello")

     with pytest.raises(TypeError, match="Not a collection type simple: STRUCT\n but got a list \\[{'hello': 2}\\]"):
-        foo3(a=[{"hello": 2}])
+        foo3(a=[{"hello": 2}])  # type: ignore


 def test_union_type():
@@ -1648,7 +1648,9 @@ def wf(a: ut) -> ut:
     assert wf(a=2.0) == 2.0
     file = tempfile.NamedTemporaryFile(delete=False)
     assert isinstance(wf(a=FlyteFile(file.name)), FlyteFile)
-    assert isinstance(wf(a=FlyteSchema()), FlyteSchema)
+    flyteSchema = FlyteSchema()
+    flyteSchema.open().write(pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}))
+    assert isinstance(wf(a=flyteSchema), FlyteSchema)
     assert wf(a=[1, 2, 3]) == [1, 2, 3]
     assert wf(a={"a": 1}) == {"a": 1}
@@ -1811,11 +1813,11 @@ def wf(a: int) -> str:
     del TypeEngine._REGISTRY[MyInt]


-def test_task_annotate_primitive_type_has_no_effect():
+def test_task_annotate_primitive_type_is_allowed():
     @task
     def plus_two(
         a: int,
-    ) -> Annotated[int, HashMethod(str)]:  # Note the use of `str` as the hash function for ints. This has no effect.
+    ) -> Annotated[int, HashMethod(lambda x: str(x + 1))]:
         return a + 2

     assert plus_two(a=1) == 3
@@ -1832,7 +1834,7 @@ def plus_two(
         ),
     )
     assert output_lm.literals["o0"].scalar.primitive.integer == 5
-    assert output_lm.literals["o0"].hash is None
+    assert output_lm.literals["o0"].hash == "6"


 def test_task_hash_return_pandas_dataframe():
diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py
index c21284211b..46389daed2 100644
--- a/tests/flytekit/unit/core/test_workflows.py
+++ b/tests/flytekit/unit/core/test_workflows.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import pytest
 from pandas.testing import assert_frame_equal
+from typing_extensions import Annotated  # type: ignore

 import flytekit.configuration
 from flytekit import StructuredDataset, kwtypes
@@ -15,11 +16,6 @@
 from flytekit.tools.translator import get_serializable
 from flytekit.types.schema import FlyteSchema

-try:
-    from typing import Annotated
-except ImportError:
-    from typing_extensions import Annotated
-
 default_img = Image(name="default", fqn="test", tag="tag")
 serialization_settings = flytekit.configuration.SerializationSettings(
     project="project",
diff --git a/tests/flytekit/unit/deck/test_deck.py b/tests/flytekit/unit/deck/test_deck.py
index 3db311653c..a6b00e79e2 100644
--- a/tests/flytekit/unit/deck/test_deck.py
+++ b/tests/flytekit/unit/deck/test_deck.py
@@ -1,4 +1,5 @@
 import pandas as pd
+import pytest
 from mock import mock

 import flytekit
@@ -21,12 +22,53 @@ def test_deck():
     _output_deck("test_task", ctx.user_space_params)

-    @task()
+
+@pytest.mark.parametrize(
+    "disable_deck,expected_decks",
+    [
+        (None, 0),
+        (False, 2),  # input and output decks
+        (True, 0),
+    ],
+)
+def test_deck_for_task(disable_deck, expected_decks):
+    ctx = FlyteContextManager.current_context()
+
+    kwargs = {}
+    if disable_deck is not None:
+        kwargs["disable_deck"] = disable_deck
+
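+    # Pass disable_deck only when the parametrized case sets it, so that None exercises the default behaviour.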
+    @task(**kwargs)
     def t1(a: int) -> str:
         return str(a)

     t1(a=3)
-    assert len(ctx.user_space_params.decks) == 2  # input, output decks
+    assert len(ctx.user_space_params.decks) == expected_decks
+
+
+@pytest.mark.parametrize(
+    "disable_deck, expected_decks",
+    [
+        (None, 1),
+        (False, 1 + 2),  # input and output decks
+        (True, 1),
+    ],
+)
+def test_deck_pandas_dataframe(disable_deck, expected_decks):
+    ctx = FlyteContextManager.current_context()
+
+    kwargs = {}
+    if disable_deck is not None:
+        kwargs["disable_deck"] = disable_deck
+
+    @task(**kwargs)
+    def t_df(a: str) -> int:
+        df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
+        flytekit.current_context().default_deck.append(TopFrameRenderer().to_html(df))
+        return int(a)
+
+    t_df(a="42")
+    assert len(ctx.user_space_params.decks) == expected_decks


 @mock.patch("flytekit.deck.deck._ipython_check")
@@ -35,6 +77,10 @@ def test_deck_in_jupyter(mock_ipython_check):
     ctx = FlyteContextManager.current_context()
     ctx.user_space_params._decks = [ctx.user_space_params.default_deck]

+    v = ctx.get_deck()
+    from IPython.core.display import HTML
+
+    assert isinstance(v, HTML)
     _output_deck("test_task", ctx.user_space_params)

     @task()
diff --git a/tests/flytekit/unit/deck/test_renderer.py b/tests/flytekit/unit/deck/test_renderer.py
index 3f597af416..257e320cfd 100644
--- a/tests/flytekit/unit/deck/test_renderer.py
+++ b/tests/flytekit/unit/deck/test_renderer.py
@@ -1,12 +1,33 @@
 import pandas as pd
 import pyarrow as pa
+import pytest

-from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
+from flytekit.deck.renderer import DEFAULT_MAX_COLS, DEFAULT_MAX_ROWS, ArrowRenderer, TopFrameRenderer


-def test_renderer():
-    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
+@pytest.mark.parametrize(
+    "rows, cols, max_rows, expected_max_rows, max_cols, expected_max_cols",
+    [
+        (1, 1, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (10, 1, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (1, 10, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (DEFAULT_MAX_ROWS + 1, 10, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (1, DEFAULT_MAX_COLS + 1, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (10, DEFAULT_MAX_COLS + 1, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (DEFAULT_MAX_ROWS + 1, DEFAULT_MAX_COLS + 1, None, DEFAULT_MAX_ROWS, None, DEFAULT_MAX_COLS),
+        (100_000, 10, 123, 123, 5, 5),
+        (10_000, 1000, DEFAULT_MAX_ROWS, DEFAULT_MAX_ROWS, DEFAULT_MAX_COLS, DEFAULT_MAX_COLS),
+    ],
+)
+def test_renderer(rows, cols, max_rows, expected_max_rows, max_cols, expected_max_cols):
+    df = pd.DataFrame({f"abc-{k}": list(range(rows)) for k in range(cols)})
     pa_df = pa.Table.from_pandas(df)

-    assert TopFrameRenderer().to_html(df) == df.to_html()
+    kwargs = {}
+    if max_rows is not None:
+        kwargs["max_rows"] = max_rows
+    if max_cols is not None:
+        kwargs["max_cols"] = max_cols
+
+    assert TopFrameRenderer(**kwargs).to_html(df) == df.to_html(max_rows=expected_max_rows, max_cols=expected_max_cols)
     assert ArrowRenderer().to_html(pa_df) == pa_df.to_string()
diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py
index 1a3a83ab93..9724a01182 100644
--- a/tests/flytekit/unit/extras/pytorch/test_transformations.py
+++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py
@@ -72,7 +72,8 @@ def test_get_literal_type(transformer, python_type, format):
 def test_to_python_value_and_literal(transformer, python_type, format, python_val):
     ctx = context_manager.FlyteContext.current_context()
     tf = transformer
-    python_val = python_val
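+    # torch.Tensor exposes .to(device), which returns a copy on the target device; the hasattr guard
+    # lets values without .to (non-tensors) pass through unchanged.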
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    python_val = python_val.to(device) if hasattr(python_val, "to") else python_val

     lt = tf.get_literal_type(python_type)
     lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
diff --git a/tests/flytekit/unit/extras/sklearn/__init__.py b/tests/flytekit/unit/extras/sklearn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/extras/sklearn/test_native.py b/tests/flytekit/unit/extras/sklearn/test_native.py
new file mode 100644
index 0000000000..7d704597c1
--- /dev/null
+++ b/tests/flytekit/unit/extras/sklearn/test_native.py
@@ -0,0 +1,48 @@
+import numpy as np
+from sklearn.linear_model import LinearRegression
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
+
+from flytekit import task, workflow
+
+
+@task
+def get_preprocessor() -> StandardScaler:
+    return StandardScaler()
+
+
+@task
+def get_model() -> LinearRegression:
+    return LinearRegression()
+
+
+@task
+def make_pipeline(preprocessor: StandardScaler, model: LinearRegression) -> Pipeline:
+    return Pipeline([("scaler", preprocessor), ("model", model)])
+
+
+@task
+def fit_pipeline(pipeline: Pipeline) -> Pipeline:
+    x = np.random.normal(size=(10, 2))
+    y = np.random.randint(2, size=(10,))
+    pipeline.fit(x, y)
+    return pipeline
+
+
+@task
+def num_features(pipeline: Pipeline) -> int:
+    return pipeline.n_features_in_
+
+
+@workflow
+def wf():
+    preprocessor = get_preprocessor()
+    model = get_model()
+    pipeline = make_pipeline(preprocessor=preprocessor, model=model)
+    pipeline = fit_pipeline(pipeline=pipeline)
+    num_features(pipeline=pipeline)
+
+
+@workflow
+def test_wf():
+    wf()
diff --git a/tests/flytekit/unit/extras/sklearn/test_transformations.py b/tests/flytekit/unit/extras/sklearn/test_transformations.py
new file mode 100644
index 0000000000..39343f9180
--- /dev/null
+++ b/tests/flytekit/unit/extras/sklearn/test_transformations.py
@@ -0,0 +1,96 @@
+from collections import OrderedDict
+from functools import partial
+
+import numpy as np
+import pytest
+from sklearn.base import BaseEstimator
+from sklearn.linear_model import LinearRegression
+from sklearn.svm import SVC
+
+import flytekit
+from flytekit import task
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core import context_manager
+from flytekit.extras.sklearn import SklearnEstimatorTransformer
+from flytekit.models.core.types import BlobType
+from flytekit.models.literals import BlobMetadata
+from flytekit.models.types import LiteralType
+from flytekit.tools.translator import get_serializable
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+def get_model(model_type: str) -> BaseEstimator:
+    models_map = {
+        "lr": LinearRegression,
+        "svc": partial(SVC, kernel="linear"),
+    }
+    x = np.random.normal(size=(10, 2))
+    y = np.random.randint(2, size=(10,))
+    model = models_map[model_type]()
+    model.fit(x, y)
+    return model
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format",
+    [
+        (SklearnEstimatorTransformer(), BaseEstimator, SklearnEstimatorTransformer.SKLEARN_FORMAT),
+    ],
+)
+def test_get_literal_type(transformer, python_type, format):
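+    # Any sklearn estimator is expected to map to a single-part blob literal tagged with the sklearn format.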
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,python_val",
+    [
+        (
+            SklearnEstimatorTransformer(),
+            BaseEstimator,
+            SklearnEstimatorTransformer.SKLEARN_FORMAT,
+            get_model("lr"),
+        ),
+        (
+            SklearnEstimatorTransformer(),
+            BaseEstimator,
+            SklearnEstimatorTransformer.SKLEARN_FORMAT,
+            get_model("svc"),
+        ),
+    ],
+)
+def test_to_python_value_and_literal(transformer, python_type, format, python_val):
+    ctx = context_manager.FlyteContext.current_context()
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+
+    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
+    assert lv.scalar.blob.metadata == BlobMetadata(
+        type=BlobType(
+            format=format,
+            dimensionality=BlobType.BlobDimensionality.SINGLE,
+        )
+    )
+    assert lv.scalar.blob.uri is not None
+
+    output = tf.to_python_value(ctx, lv, python_type)
+
+    np.testing.assert_array_equal(output.coef_, python_val.coef_)
+
+
+def test_example_estimator():
+    @task
+    def t1() -> BaseEstimator:
+        return get_model("lr")
+
+    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
+    assert task_spec.template.interface.outputs["o0"].type.blob.format is SklearnEstimatorTransformer.SKLEARN_FORMAT
diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py
index f586d94a16..ef7ea491e6 100644
--- a/tests/flytekit/unit/extras/sqlite3/test_task.py
+++ b/tests/flytekit/unit/extras/sqlite3/test_task.py
@@ -1,12 +1,17 @@
 import pandas
+import pytest

 from flytekit import kwtypes, task, workflow
+from flytekit.configuration import DefaultImages
+from flytekit.core import context_manager
 from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task

 # https://www.sqlitetutorial.net/sqlite-sample-database/
 from flytekit.types.schema import FlyteSchema

-EXAMPLE_DB = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"
+ctx = context_manager.FlyteContextManager.current_context()
+EXAMPLE_DB = ctx.file_access.get_random_local_path("chinook.zip")
+ctx.file_access.get_data("https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip", EXAMPLE_DB)

 # This task belongs to test_task_static but is intentionally here to help test tracking
 tk = SQLite3Task(
@@ -28,7 +33,6 @@ def test_task_static():

 def test_task_schema():
     # sqlite3_start
-    DB_LOCATION = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"

     sql_task = SQLite3Task(
         "test",
@@ -36,7 +40,7 @@
         inputs=kwtypes(limit=int),
         output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
         task_config=SQLite3Config(
-            uri=DB_LOCATION,
+            uri=EXAMPLE_DB,
             compressed=True,
         ),
     )
@@ -99,4 +103,42 @@ def test_task_serialization():
     ]

     assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"
-    assert tt.container.image != ""
+    assert tt.container.image == DefaultImages.default_image()
+
+    image = "xyz.io/docker2:latest"
+    sql_task._container_image = image
+    tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)
+    assert tt.container.image == image
+
+
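+# Sketch of the behaviour under test: multi-line query templates, with or without trailing
+# backslashes, should collapse to a single line while the {{.inputs.limit}} placeholder stays intact.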
{{.inputs.limit}}", + ), + ("select * from abc", "select * from abc"), + ], +) +def test_query_sanitization(query_template, expected_query): + sql_task = SQLite3Task( + "test", + query_template=query_template, + inputs=kwtypes(limit=int), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), + ) + assert sql_task.query_template == expected_query diff --git a/tests/flytekit/unit/extras/tensorflow/__init__.py b/tests/flytekit/unit/extras/tensorflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/record/__init__.py b/tests/flytekit/unit/extras/tensorflow/record/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/record/test_record.py b/tests/flytekit/unit/extras/tensorflow/record/test_record.py new file mode 100644 index 0000000000..c523475562 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/record/test_record.py @@ -0,0 +1,117 @@ +from typing import Dict, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow.python.data.ops.readers import TFRecordDatasetV2 +from typing_extensions import Annotated + +from flytekit import task, workflow +from flytekit.extras.tensorflow.record import TFRecordDatasetConfig +from flytekit.types.directory import TFRecordsDirectory +from flytekit.types.file import TFRecordFile + +a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"])) +b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0])) +c = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])) +d = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"ham", b"spam"])) +e = tf.train.Feature(float_list=tf.train.FloatList(value=[8.0, 9.0])) +f = tf.train.Feature(int64_list=tf.train.Int64List(value=[22, 23])) +features1 = tf.train.Features(feature=dict(a=a, b=b, c=c)) +features2 = tf.train.Features(feature=dict(a=d, b=e, c=f)) + + +def decode_fn(dataset: TFRecordDatasetV2) -> Dict[str, np.ndarray]: + examples_list = [] + # parse serialised tensors https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file_2 + for batch in list(dataset.as_numpy_iterator()): + example = tf.train.Example() + example.ParseFromString(batch) + examples_list.append(example) + result = {} + for a in examples_list: + # convert example to dict of numpy arrays + # https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file_2 + for key, feature in a.features.feature.items(): + kind = feature.WhichOneof("kind") + val = np.array(getattr(feature, kind).value) + if key not in result.keys(): + result[key] = val + else: + result.update({key: np.concatenate((result[key], val))}) + return result + + +@task +def generate_tf_record_file() -> TFRecordFile: + return tf.train.Example(features=features1) + + +@task +def generate_tf_record_dir() -> TFRecordsDirectory: + return [tf.train.Example(features=features1), tf.train.Example(features=features2)] + + +@task +def t1( + dataset: Annotated[ + TFRecordFile, + TFRecordDatasetConfig(buffer_size=1024, num_parallel_reads=3, compression_type="GZIP"), + ] +): + assert isinstance(dataset, TFRecordDatasetV2) + assert dataset._compression_type == "GZIP" + assert dataset._buffer_size == 1024 + assert dataset._num_parallel_reads == 3 + + +@task +def t2(dataset: TFRecordFile): + + # if not annotated with TFRecordDatasetConfig, all attributes should default to None + assert isinstance(dataset, TFRecordDatasetV2) + assert dataset._compression_type is None + assert 
+@task
+def t1(
+    dataset: Annotated[
+        TFRecordFile,
+        TFRecordDatasetConfig(buffer_size=1024, num_parallel_reads=3, compression_type="GZIP"),
+    ]
+):
+    assert isinstance(dataset, TFRecordDatasetV2)
+    assert dataset._compression_type == "GZIP"
+    assert dataset._buffer_size == 1024
+    assert dataset._num_parallel_reads == 3
+
+
+@task
+def t2(dataset: TFRecordFile):
+
+    # if not annotated with TFRecordDatasetConfig, all attributes should default to None
+    assert isinstance(dataset, TFRecordDatasetV2)
+    assert dataset._compression_type is None
+    assert dataset._buffer_size is None
+    assert dataset._num_parallel_reads is None
+
+
+@task
+def t3(dataset: TFRecordsDirectory):
+
+    # if not annotated with TFRecordDatasetConfig, all attributes should default to None
+    assert isinstance(dataset, TFRecordDatasetV2)
+    assert dataset._compression_type is None
+    assert dataset._buffer_size is None
+    assert dataset._num_parallel_reads is None
+
+
+@task
+def t4(dataset: Annotated[TFRecordFile, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
+    return decode_fn(dataset)
+
+
+@task
+def t5(dataset: Annotated[TFRecordsDirectory, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
+    return decode_fn(dataset)
+
+
+@workflow
+def wf() -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
+    file = generate_tf_record_file()
+    files = generate_tf_record_dir()
+    t1(dataset=file)
+    t2(dataset=file)
+    t3(dataset=files)
+    files_res = t4(dataset=file)
+    dir_res = t5(dataset=files)
+    return files_res, dir_res
+
+
+def test_wf():
+    file_res, dir_res = wf()
+    assert np.array_equal(file_res["a"], np.array([b"foo", b"bar"]))
+    assert np.array_equal(file_res["b"], np.array([1.0, 2.0]))
+    assert np.array_equal(file_res["c"], np.array([3, 4]))
+
+    assert np.array_equal(np.sort(dir_res["a"]), np.array([b"bar", b"foo", b"ham", b"spam"]))
+    assert np.array_equal(np.sort(dir_res["b"]), np.array([1.0, 2.0, 8.0, 9.0]))
+    assert np.array_equal(np.sort(dir_res["c"]), np.array([3, 4, 22, 23]))
diff --git a/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py
new file mode 100644
index 0000000000..f236696a0c
--- /dev/null
+++ b/tests/flytekit/unit/extras/tensorflow/record/test_transformations.py
@@ -0,0 +1,89 @@
+import pytest
+import tensorflow
+import tensorflow as tf
+from tensorflow.core.example.example_pb2 import Example
+from tensorflow.python.data.ops.readers import TFRecordDatasetV2
+from typing_extensions import Annotated
+
+import flytekit
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core import context_manager
+from flytekit.extras.tensorflow.record import (
+    TensorFlowRecordFileTransformer,
+    TensorFlowRecordsDirTransformer,
+    TFRecordDatasetConfig,
+)
+from flytekit.models.core.types import BlobType
+from flytekit.models.literals import BlobMetadata
+from flytekit.models.types import LiteralType
+from flytekit.types.directory import TFRecordsDirectory
+from flytekit.types.file import TFRecordFile
+
+from .test_record import features1, features2
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,dimensionality",
+    [
+        (TensorFlowRecordFileTransformer(), TFRecordFile, TensorFlowRecordFileTransformer.TENSORFLOW_FORMAT, 0),
+        (TensorFlowRecordsDirTransformer(), TFRecordsDirectory, TensorFlowRecordsDirTransformer.TENSORFLOW_FORMAT, 1),
+    ],
+)
+def test_get_literal_type(transformer, python_type, format, dimensionality):
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=dimensionality))
+
+
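+# Round-trip both a single record file and a directory of records; blob dimensionality
+# (SINGLE vs MULTIPART) is what distinguishes the two types.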
+@pytest.mark.parametrize(
+    "transformer,python_type,format,python_val,dimension",
+    [
+        (
+            TensorFlowRecordFileTransformer(),
+            TFRecordFile,
+            TensorFlowRecordFileTransformer.TENSORFLOW_FORMAT,
+            tf.train.Example(features=features1),
+            BlobType.BlobDimensionality.SINGLE,
+        ),
+        (
+            TensorFlowRecordsDirTransformer(),
+            TFRecordsDirectory,
+            TensorFlowRecordsDirTransformer.TENSORFLOW_FORMAT,
+            [tf.train.Example(features=features1), tf.train.Example(features=features2)],
+            BlobType.BlobDimensionality.MULTIPART,
+        ),
+    ],
+)
+def test_to_python_value_and_literal(transformer, python_type, format, python_val, dimension):
+    ctx = context_manager.FlyteContext.current_context()
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
+    assert lv.scalar.blob.metadata == BlobMetadata(
+        type=BlobType(
+            format=format,
+            dimensionality=dimension,
+        )
+    )
+    assert lv.scalar.blob.uri is not None
+    output = tf.to_python_value(ctx, lv, Annotated[python_type, TFRecordDatasetConfig(buffer_size=1024)])
+    assert isinstance(output, TFRecordDatasetV2)
+    results = []
+    example = tensorflow.train.Example()
+    for raw_record in output:
+        example.ParseFromString(raw_record.numpy())
+        results.append(example)
+    if isinstance(python_val, list):
+        assert len(results) == 2
+        assert all(list(map(lambda x: isinstance(x, Example), python_val)))
+    else:
+        assert results == [python_val]
diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py
index 3db0a8eb23..b327d5e9d6 100644
--- a/tests/flytekit/unit/models/test_execution.py
+++ b/tests/flytekit/unit/models/test_execution.py
@@ -66,16 +66,79 @@ def test_execution_closure_with_error():
     assert obj2.error == test_error


+def test_execution_closure_with_abort_metadata():
+    test_datetime = datetime.datetime(year=2022, month=1, day=1, tzinfo=pytz.UTC)
+    test_timedelta = datetime.timedelta(seconds=10)
+    abort_metadata = _execution.AbortMetadata(cause="cause", principal="skinner")
+
+    obj = _execution.ExecutionClosure(
+        phase=_core_exec.WorkflowExecutionPhase.SUCCEEDED,
+        started_at=test_datetime,
+        duration=test_timedelta,
+        abort_metadata=abort_metadata,
+    )
+    assert obj.phase == _core_exec.WorkflowExecutionPhase.SUCCEEDED
+    assert obj.started_at == test_datetime
+    assert obj.duration == test_timedelta
+    assert obj.abort_metadata == abort_metadata
+    obj2 = _execution.ExecutionClosure.from_flyte_idl(obj.to_flyte_idl())
+    assert obj2 == obj
+    assert obj2.phase == _core_exec.WorkflowExecutionPhase.SUCCEEDED
+    assert obj2.started_at == test_datetime
+    assert obj2.duration == test_timedelta
+    assert obj2.abort_metadata == abort_metadata
+
+
+def test_system_metadata():
+    obj = _execution.SystemMetadata(execution_cluster="my_cluster")
+    assert obj.execution_cluster == "my_cluster"
+    obj2 = _execution.SystemMetadata.from_flyte_idl(obj.to_flyte_idl())
+    assert obj == obj2
+    assert obj2.execution_cluster == "my_cluster"
+
+
 def test_execution_metadata():
-    obj = _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)
+    scheduled_at = datetime.datetime.now()
+    system_metadata = _execution.SystemMetadata(execution_cluster="my_cluster")
+    parent_node_execution = _identifier.NodeExecutionIdentifier(
+        node_id="node_id",
+        execution_id=_identifier.WorkflowExecutionIdentifier(
+            project="project",
+            domain="domain",
+            name="parent",
+        ),
+    )
+    reference_execution = _identifier.WorkflowExecutionIdentifier(
+        project="project",
+        domain="domain",
+        name="reference",
+    )
+
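+    # Populate every optional field so the to_flyte_idl/from_flyte_idl round-trip below is fully exercised.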
+    obj = _execution.ExecutionMetadata(
+        _execution.ExecutionMetadata.ExecutionMode.MANUAL,
+        "tester",
+        1,
+        scheduled_at=scheduled_at,
+        parent_node_execution=parent_node_execution,
+        reference_execution=reference_execution,
+        system_metadata=system_metadata,
+    )
     assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
     assert obj.principal == "tester"
     assert obj.nesting == 1
+    assert obj.scheduled_at == scheduled_at
+    assert obj.parent_node_execution == parent_node_execution
+    assert obj.reference_execution == reference_execution
+    assert obj.system_metadata == system_metadata
     obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl())
     assert obj == obj2
     assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
     assert obj2.principal == "tester"
     assert obj2.nesting == 1
+    assert obj2.scheduled_at == scheduled_at
+    assert obj2.parent_node_execution == parent_node_execution
+    assert obj2.reference_execution == reference_execution
+    assert obj2.system_metadata == system_metadata


 @pytest.mark.parametrize("literal_value_pair", _parameterizers.LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE)
@@ -198,3 +261,13 @@ def test_task_execution_data_response():
     assert obj2.outputs == output_blob
     assert obj2.full_inputs == _INPUT_MAP
     assert obj2.full_outputs == _OUTPUT_MAP
+
+
+def test_abort_metadata():
+    obj = _execution.AbortMetadata(cause="cause", principal="skinner")
+    assert obj.cause == "cause"
+    assert obj.principal == "skinner"
+    obj2 = _execution.AbortMetadata.from_flyte_idl(obj.to_flyte_idl())
+    assert obj == obj2
+    assert obj2.cause == "cause"
+    assert obj2.principal == "skinner"
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index 00258cb85e..dd37b97f87 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -196,18 +196,54 @@ def test_more_stuff(mock_client):

 @patch("flytekit.remote.remote.SynchronousFlyteClient")
-def test_generate_http_domain_sandbox_rewrite(mock_client):
+def test_generate_console_http_domain_sandbox_rewrite(mock_client):
     _, temp_filename = tempfile.mkstemp(suffix=".yaml")
-    with open(temp_filename, "w") as f:
-        # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
-        flytectl_config_file = """admin:
+    try:
+        with open(temp_filename, "w") as f:
+            # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
+            flytectl_config_file = """admin:
+  endpoint: example.com
+  authType: Pkce
+  insecure: false
+"""
+            f.write(flytectl_config_file)
+
+        remote = FlyteRemote(
+            config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain"
+        )
+        assert remote.generate_console_http_domain() == "https://example.com"
+
+        with open(temp_filename, "w") as f:
+            # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox.
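+            # An insecure localhost endpoint is taken to be the local sandbox, so the console domain
+            # below should be rewritten to plain http on port 30080.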
+ flytectl_config_file = """admin: endpoint: localhost:30081 authType: Pkce insecure: true - """ - f.write(flytectl_config_file) +""" + f.write(flytectl_config_file) - remote = FlyteRemote( - config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain" - ) - assert remote.generate_http_domain() == "http://localhost:30080" + remote = FlyteRemote( + config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain" + ) + assert remote.generate_console_http_domain() == "http://localhost:30080" + + with open(temp_filename, "w") as f: + # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox. + flytectl_config_file = """admin: + endpoint: localhost:30081 + authType: Pkce + insecure: true +console: + endpoint: http://localhost:30090 +""" + f.write(flytectl_config_file) + + remote = FlyteRemote( + config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain" + ) + assert remote.generate_console_http_domain() == "http://localhost:30090" + finally: + try: + os.remove(temp_filename) + except OSError: + pass diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 6f881f0e01..792ca0b131 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -10,6 +10,7 @@ from flytekit.core.task import ReferenceTask, task from flytekit.core.workflow import ReferenceWorkflow, workflow from flytekit.models.core import identifier as identifier_models +from flytekit.models.task import Resources as resource_model from flytekit.tools.translator import get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -114,7 +115,7 @@ def t1(a: int) -> (int, str): output_data_dir="/tmp", command=["cat"], arguments=["/tmp/a"], - requests=Resources(mem="400Mi", cpu="1"), + requests=Resources(mem="400Mi", cpu="1", gpu="2"), ) ssettings = ( @@ -124,6 +125,12 @@ def t1(a: int) -> (int, str): ) task_spec = get_serializable(OrderedDict(), ssettings, t2) assert "pyflyte" not in task_spec.template.container.args + assert t2.get_container(ssettings).resources.requests[0].name == resource_model.ResourceName.CPU + assert t2.get_container(ssettings).resources.requests[0].value == "1" + assert t2.get_container(ssettings).resources.requests[1].name == resource_model.ResourceName.GPU + assert t2.get_container(ssettings).resources.requests[1].value == "2" + assert t2.get_container(ssettings).resources.requests[2].name == resource_model.ResourceName.MEMORY + assert t2.get_container(ssettings).resources.requests[2].value == "400Mi" def test_launch_plan_with_fixed_input(): diff --git a/tests/flytekit/unit/types/numpy/test_ndarray.py b/tests/flytekit/unit/types/numpy/test_ndarray.py index 761cb74812..d53979f2a9 100644 --- a/tests/flytekit/unit/types/numpy/test_ndarray.py +++ b/tests/flytekit/unit/types/numpy/test_ndarray.py @@ -1,6 +1,7 @@ import numpy as np +from typing_extensions import Annotated -from flytekit import task, workflow +from flytekit import kwtypes, task, workflow @task @@ -14,8 +15,7 @@ def generate_numpy_2d() -> np.ndarray: @task -def generate_numpy_dtype_object() -> np.ndarray: - # dtype=object cannot be serialized +def generate_numpy_dtype_object() -> Annotated[np.ndarray, kwtypes(allow_pickle=True, mmap_mode="r")]: return np.array( [ [ @@ -33,6 +33,11 @@ def generate_numpy_dtype_object() -> np.ndarray: ) +@task +def generate_numpy_fails() -> Annotated[np.ndarray, 
{"allow_pickle": True}]: + return np.array([1, 2, 3]) + + @task def t1(array: np.ndarray) -> np.ndarray: assert array.dtype == int @@ -53,17 +58,24 @@ def t3(array: np.ndarray) -> np.ndarray: return array.reshape(2, 3) +@task +def t4(array: Annotated[np.ndarray, kwtypes(allow_pickle=True)]) -> int: + return array.size + + @workflow def wf(): array_1d = generate_numpy_1d() array_2d = generate_numpy_2d() - try: - generate_numpy_dtype_object() - except Exception as e: - assert isinstance(e, TypeError) + array_dtype_object = generate_numpy_dtype_object() t1(array=array_1d) t2(array=array_2d) t3(array=array_1d) + t4(array=array_dtype_object) + try: + generate_numpy_fails() + except Exception as e: + assert isinstance(e, TypeError) @workflow diff --git a/tests/flytekit/unit/types/structured_dataset/test_bigquery.py b/tests/flytekit/unit/types/structured_dataset/test_bigquery.py index 4801568d5e..2877ed3743 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_bigquery.py +++ b/tests/flytekit/unit/types/structured_dataset/test_bigquery.py @@ -1,14 +1,9 @@ import mock import pandas as pd +from typing_extensions import Annotated from flytekit import StructuredDataset, kwtypes, task, workflow -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - - pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) my_cols = kwtypes(Name=str, Age=int)