diff --git a/dev-requirements.txt b/dev-requirements.txt
index 645ffbeb18..13f2006571 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -8,6 +8,8 @@
     # via
     #   -c requirements.txt
     #   pytest-flyte
+appnope==0.1.3
+    # via ipython
 arrow==1.2.2
     # via
     #   -c requirements.txt
@@ -20,7 +22,7 @@ attrs==20.3.0
     #   pytest-docker
 backcall==0.2.0
     # via ipython
-bcrypt==3.2.2
+bcrypt==4.0.0
     # via paramiko
 binaryornot==0.4.4
     # via
@@ -37,7 +39,6 @@ certifi==2022.6.15
 cffi==1.15.1
     # via
     #   -c requirements.txt
-    #   bcrypt
     #   cryptography
     #   pynacl
 cfgv==3.3.1
@@ -46,7 +47,7 @@ chardet==5.0.0
     # via
     #   -c requirements.txt
     #   binaryornot
-charset-normalizer==2.1.0
+charset-normalizer==2.1.1
     # via
     #   -c requirements.txt
     #   requests
@@ -59,13 +60,13 @@ cloudpickle==2.1.0
     # via
     #   -c requirements.txt
     #   flytekit
-codespell==2.1.0
+codespell==2.2.1
     # via -r dev-requirements.in
 cookiecutter==2.1.1
     # via
     #   -c requirements.txt
     #   flytekit
-coverage[toml]==6.4.1
+coverage[toml]==6.4.4
     # via
     #   -r dev-requirements.in
     #   pytest-cov
@@ -78,7 +79,6 @@ cryptography==37.0.4
     #   -c requirements.txt
     #   paramiko
     #   pyopenssl
-    #   secretstorage
 dataclasses-json==0.5.7
     # via
     #   -c requirements.txt
@@ -96,19 +96,17 @@ diskcache==5.4.0
     # via
     #   -c requirements.txt
     #   flytekit
-distlib==0.3.4
+distlib==0.3.6
     # via virtualenv
 distro==1.7.0
     # via docker-compose
-docker[ssh]==5.0.3
+docker[ssh]==6.0.0
     # via
     #   -c requirements.txt
     #   docker-compose
     #   flytekit
 docker-compose==1.29.2
-    # via
-    #   pytest-docker
-    #   pytest-flyte
+    # via pytest-flyte
 docker-image-py==0.1.12
     # via
     #   -c requirements.txt
@@ -121,9 +119,9 @@ docstring-parser==0.14.1
     # via
     #   -c requirements.txt
     #   flytekit
-filelock==3.7.1
+filelock==3.8.0
     # via virtualenv
-flyteidl==1.1.8
+flyteidl==1.1.12
     # via
     #   -c requirements.txt
     #   flytekit
@@ -132,23 +130,23 @@ google-api-core[grpc]==2.8.2
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.9.0
+google-auth==2.11.0
     # via
     #   google-api-core
     #   google-cloud-core
-google-cloud-bigquery==3.2.0
+google-cloud-bigquery==3.3.2
     # via -r dev-requirements.in
-google-cloud-bigquery-storage==2.14.0
+google-cloud-bigquery-storage==2.14.2
     # via
     #   -r dev-requirements.in
     #   google-cloud-bigquery
-google-cloud-core==2.3.1
+google-cloud-core==2.3.2
     # via google-cloud-bigquery
 google-crc32c==1.3.0
     # via google-resumable-media
 google-resumable-media==2.3.3
     # via google-cloud-bigquery
-googleapis-common-protos==1.56.3
+googleapis-common-protos==1.56.4
     # via
     #   -c requirements.txt
     #   flyteidl
@@ -166,7 +164,7 @@ grpcio-status==1.47.0
     #   -c requirements.txt
     #   flytekit
     #   google-api-core
-identify==2.5.1
+identify==2.5.3
     # via pre-commit
 idna==3.3
     # via
@@ -189,11 +187,6 @@ ipython==7.34.0
     # via -r dev-requirements.in
 jedi==0.18.1
     # via ipython
-jeepney==0.8.0
-    # via
-    #   -c requirements.txt
-    #   keyring
-    #   secretstorage
 jinja2==3.1.2
     # via
     #   -c requirements.txt
@@ -205,12 +198,15 @@ jinja2-time==0.2.0
     #   -c requirements.txt
     #   cookiecutter
 joblib==1.1.0
-    # via -r dev-requirements.in
+    # via
+    #   -c requirements.txt
+    #   -r dev-requirements.in
+    #   flytekit
 jsonschema==3.2.0
     # via
     #   -c requirements.txt
     #   docker-compose
-keyring==23.6.0
+keyring==23.8.2
     # via
     #   -c requirements.txt
     #   flytekit
@@ -218,7 +214,7 @@ markupsafe==2.1.1
     # via
     #   -c requirements.txt
     #   jinja2
-marshmallow==3.17.0
+marshmallow==3.17.1
     # via
     #   -c requirements.txt
     #   dataclasses-json
@@ -232,11 +228,11 @@ marshmallow-jsonschema==0.13.0
     # via
     #   -c requirements.txt
     #   flytekit
-matplotlib-inline==0.1.3
+matplotlib-inline==0.1.6
     # via ipython
 mock==4.0.3
     # via -r dev-requirements.in
-mypy==0.961
+mypy==0.971
     # via -r dev-requirements.in
 mypy-extensions==0.4.3
     # via
@@ -258,6 +254,7 @@ numpy==1.21.6
 packaging==21.3
     # via
     #   -c requirements.txt
+    #   docker
     #   google-cloud-bigquery
     #   marshmallow
     #   pytest
@@ -281,7 +278,7 @@ pre-commit==2.20.0
     # via -r dev-requirements.in
 prompt-toolkit==3.0.30
     # via ipython
-proto-plus==1.20.6
+proto-plus==1.22.1
     # via
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
@@ -323,7 +320,7 @@ pycparser==2.21
     # via
     #   -c requirements.txt
     #   cffi
-pygments==2.12.0
+pygments==2.13.0
     # via ipython
 pynacl==1.5.0
     # via paramiko
@@ -347,7 +344,7 @@ pytest==7.1.2
     #   pytest-flyte
 pytest-cov==3.0.0
     # via -r dev-requirements.in
-pytest-docker==0.12.0
+pytest-docker==1.0.0
     # via pytest-flyte
 pytest-flyte @ git+https://github.com/flyteorg/pytest-flyte@main
     # via -r dev-requirements.in
@@ -373,7 +370,7 @@ pytimeparse==1.1.8
     # via
     #   -c requirements.txt
     #   flytekit
-pytz==2022.1
+pytz==2022.2.1
     # via
     #   -c requirements.txt
     #   flytekit
@@ -385,7 +382,7 @@ pyyaml==5.4.1
     #   docker-compose
     #   flytekit
     #   pre-commit
-regex==2022.7.9
+regex==2022.8.17
     # via
     #   -c requirements.txt
     #   docker-image-py
@@ -407,12 +404,8 @@ retry==0.9.2
     # via
     #   -c requirements.txt
     #   flytekit
-rsa==4.8
+rsa==4.9
     # via google-auth
-secretstorage==3.3.2
-    # via
-    #   -c requirements.txt
-    #   keyring
 singledispatchmethod==1.0
     # via
     #   -c requirements.txt
@@ -426,7 +419,6 @@ six==1.16.0
     #   jsonschema
     #   paramiko
     #   python-dateutil
-    #   virtualenv
     #   websocket-client
 sortedcontainers==2.4.0
     # via
@@ -449,7 +441,7 @@ tomli==2.0.1
     #   coverage
     #   mypy
     #   pytest
-torch==1.12.0
+torch==1.12.1
     # via -r dev-requirements.in
 traitlets==5.3.0
     # via
@@ -467,17 +459,18 @@ typing-extensions==4.3.0
     #   responses
     #   torch
     #   typing-inspect
-typing-inspect==0.7.1
+typing-inspect==0.8.0
     # via
     #   -c requirements.txt
     #   dataclasses-json
-urllib3==1.26.10
+urllib3==1.26.12
     # via
     #   -c requirements.txt
+    #   docker
     #   flytekit
     #   requests
     #   responses
-virtualenv==20.15.1
+virtualenv==20.16.4
     # via pre-commit
 wcwidth==0.2.5
     # via prompt-toolkit
@@ -495,7 +488,7 @@ wrapt==1.14.1
     #   -c requirements.txt
     #   deprecated
     #   flytekit
-zipp==3.8.0
+zipp==3.8.1
     # via
     #   -c requirements.txt
     #   importlib-metadata
diff --git a/doc-requirements.txt b/doc-requirements.txt
index e9f16c89a0..6febe9d508 100644
--- a/doc-requirements.txt
+++ b/doc-requirements.txt
@@ -18,9 +18,11 @@ argon2-cffi-bindings==21.2.0
     # via argon2-cffi
 arrow==1.2.2
     # via jinja2-time
-astroid==2.11.6
+astroid==2.12.2
     # via sphinx-autoapi
-attrs==21.4.0
+astunparse==1.6.3
+    # via tensorflow
+attrs==22.1.0
     # via
     #   jsonschema
     #   visions
@@ -42,7 +44,7 @@ binaryornot==0.4.4
     # via cookiecutter
 bleach==5.0.1
     # via nbconvert
-botocore==1.27.22
+botocore==1.27.63
     # via -r doc-requirements.in
 cachetools==5.2.0
     # via google-auth
@@ -86,7 +88,7 @@ dataclasses-json==0.5.7
     # via
     #   dolt-integrations
     #   flytekit
-debugpy==1.6.0
+debugpy==1.6.3
     # via ipykernel
 decorator==5.1.1
     # via
@@ -98,7 +100,7 @@ deprecated==1.2.13
     # via flytekit
 diskcache==5.4.0
     # via flytekit
-docker==5.0.3
+docker==6.0.0
     # via flytekit
 docker-image-py==0.1.12
     # via flytekit
@@ -118,13 +120,13 @@ entrypoints==0.4
     #   jupyter-client
     #   nbconvert
     #   papermill
-fastjsonschema==2.15.3
+fastjsonschema==2.16.1
     # via nbformat
-flyteidl==1.1.8
+flyteidl==1.1.12
     # via flytekit
-fonttools==4.33.3
+fonttools==4.35.0
     # via matplotlib
-fsspec==2022.5.0
+fsspec==2022.8.0
     # via
     #   -r doc-requirements.in
     #   modin
@@ -135,29 +137,29 @@ google-api-core[grpc]==2.8.2
     #   google-cloud-bigquery
     #   google-cloud-bigquery-storage
     #   google-cloud-core
-google-auth==2.9.0
+google-auth==2.10.0
    # via
    #   google-api-core
    #   google-cloud-core
    #   kubernetes
 google-cloud==0.34.0
    # via -r doc-requirements.in
-google-cloud-bigquery==3.2.0
+google-cloud-bigquery==3.3.1
    # via -r doc-requirements.in
-google-cloud-bigquery-storage==2.13.2
+google-cloud-bigquery-storage==2.14.2
    # via google-cloud-bigquery
-google-cloud-core==2.3.1
+google-cloud-core==2.3.2
    # via google-cloud-bigquery
 google-crc32c==1.3.0
    # via google-resumable-media
 google-resumable-media==2.3.3
    # via google-cloud-bigquery
-googleapis-common-protos==1.56.3
+googleapis-common-protos==1.56.4
    # via
    #   flyteidl
    #   google-api-core
    #   grpcio-status
-great-expectations==0.15.12
+great-expectations==0.15.18
    # via -r doc-requirements.in
 greenlet==1.1.2
    # via sqlalchemy
@@ -190,9 +192,9 @@ importlib-metadata==4.12.0
    #   markdown
    #   sphinx
    #   sqlalchemy
-importlib-resources==5.8.0
+importlib-resources==5.9.0
    # via jsonschema
-ipykernel==6.15.0
+ipykernel==6.15.2
    # via
    #   ipywidgets
    #   jupyter
@@ -211,7 +213,9 @@ ipython-genutils==0.2.0
    #   notebook
    #   qtconsole
 ipywidgets==7.7.1
-    # via jupyter
+    # via
+    #   great-expectations
+    #   jupyter
 jedi==0.18.1
    # via ipython
 jeepney==0.8.0
@@ -235,13 +239,14 @@ jmespath==1.0.1
    # via botocore
 joblib==1.1.0
    # via
+    #   flytekit
    #   pandas-profiling
    #   phik
 jsonpatch==1.32
    # via great-expectations
 jsonpointer==2.3
    # via jsonpatch
-jsonschema==4.6.1
+jsonschema==4.10.0
    # via
    #   altair
    #   great-expectations
@@ -257,7 +262,7 @@ jupyter-client==7.3.4
    #   qtconsole
 jupyter-console==6.4.4
    # via jupyter
-jupyter-core==4.10.0
+jupyter-core==4.11.1
    # via
    #   jupyter-client
    #   nbconvert
@@ -268,17 +273,21 @@ jupyterlab-pygments==0.2.2
    # via nbconvert
 jupyterlab-widgets==1.1.1
    # via ipywidgets
-keyring==23.6.0
+keyring==23.8.2
    # via flytekit
-kiwisolver==1.4.3
+kiwisolver==1.4.4
    # via matplotlib
 kubernetes==24.2.0
    # via -r doc-requirements.in
 lazy-object-proxy==1.7.1
    # via astroid
 lxml==4.9.1
-    # via sphinx-material
-markdown==3.3.7
+    # via
+    #   nbconvert
+    #   sphinx-material
+makefun==1.14.0
+    # via great-expectations
+markdown==3.4.1
    # via -r doc-requirements.in
 markupsafe==2.1.1
    # via
@@ -294,13 +303,13 @@ marshmallow-enum==1.5.1
    # via dataclasses-json
 marshmallow-jsonschema==0.13.0
    # via flytekit
-matplotlib==3.5.2
+matplotlib==3.5.3
    # via
    #   missingno
    #   pandas-profiling
    #   phik
    #   seaborn
-matplotlib-inline==0.1.3
+matplotlib-inline==0.1.5
    # via
    #   ipykernel
    #   ipython
@@ -324,7 +333,7 @@ nbclient==0.6.6
    # via
    #   nbconvert
    #   papermill
-nbconvert==6.5.0
+nbconvert==6.5.3
    # via
    #   jupyter
    #   notebook
@@ -347,7 +356,6 @@ notebook==6.4.12
    # via
    #   great-expectations
    #   jupyter
-    #   widgetsnbextension
 numpy==1.21.6
    # via
    #   altair
@@ -370,6 +378,7 @@ oauthlib==3.2.0
    # via requests-oauthlib
 packaging==21.3
    # via
+    #   docker
    #   google-cloud-bigquery
    #   great-expectations
    #   ipykernel
@@ -380,6 +389,7 @@ packaging==21.3
    #   pandera
    #   qtpy
    #   sphinx
+    #   tensorflow
 pandas==1.3.5
    # via
    #   altair
@@ -398,7 +408,7 @@ pandera==0.9.0
    # via -r doc-requirements.in
 pandocfilters==1.5.0
    # via nbconvert
-papermill==2.3.4
+papermill==2.4.0
    # via -r doc-requirements.in
 parso==0.8.3
    # via jedi
@@ -413,7 +423,9 @@ pillow==9.2.0
    #   imagehash
    #   matplotlib
    #   visions
-plotly==5.9.0
+pkgutil-resolve-name==1.3.10
+    # via jsonschema
+plotly==5.10.0
    # via -r doc-requirements.in
 prometheus-client==0.14.1
    # via notebook
@@ -421,7 +433,7 @@ prompt-toolkit==3.0.30
    # via
    #   ipython
    #   jupyter-console
-proto-plus==1.20.6
+proto-plus==1.22.1
    # via
    #   google-cloud-bigquery
    #   google-cloud-bigquery-storage
@@ -461,11 +473,11 @@ pyasn1-modules==0.2.8
    # via google-auth
 pycparser==2.21
    # via cffi
-pydantic==1.9.1
+pydantic==1.10.0
    # via
    #   pandas-profiling
    #   pandera
-pygments==2.12.0
+pygments==2.13.0
    # via
    #   furo
    #   ipython
@@ -497,7 +509,7 @@ python-dateutil==2.8.2
    #   kubernetes
    #   matplotlib
    #   pandas
-python-json-logger==2.0.2
+python-json-logger==2.0.4
    # via flytekit
 python-slugify[unidecode]==6.1.2
    # via
@@ -505,7 +517,7 @@ python-slugify[unidecode]==6.1.2
    #   sphinx-material
 pytimeparse==1.1.8
    # via flytekit
-pytz==2022.1
+pytz==2022.2.1
    # via
    #   babel
    #   flytekit
@@ -523,7 +535,7 @@ pyyaml==6.0
    #   pandas-profiling
    #   papermill
    #   sphinx-autoapi
-pyzmq==23.2.0
+pyzmq==23.2.1
    # via
    #   ipykernel
    #   jupyter-client
@@ -531,9 +543,9 @@ pyzmq==23.2.0
    #   qtconsole
 qtconsole==5.3.1
    # via jupyter
-qtpy==2.1.0
+qtpy==2.2.0
    # via qtconsole
-regex==2022.6.2
+regex==2022.7.25
    # via docker-image-py
 requests==2.28.1
    # via
@@ -555,12 +567,14 @@ responses==0.21.0
    # via flytekit
 retry==0.9.2
    # via flytekit
-rsa==4.8
+rsa==4.9
    # via google-auth
 ruamel-yaml==0.17.17
    # via great-expectations
 ruamel-yaml-clib==0.2.6
    # via ruamel-yaml
+scikit-learn==1.0.2
+    # via skl2onnx
 scipy==1.7.3
    # via
    #   great-expectations
@@ -573,7 +587,7 @@ seaborn==0.11.2
    # via
    #   missingno
    #   pandas-profiling
-secretstorage==3.3.2
+secretstorage==3.3.3
    # via keyring
 send2trash==1.8.0
    # via notebook
@@ -581,6 +595,7 @@ singledispatchmethod==1.0
    # via flytekit
 six==1.16.0
    # via
+    #   astunparse
    #   bleach
    #   google-auth
    #   grpcio
@@ -608,7 +623,7 @@ sphinx==4.5.0
    #   sphinx-panels
    #   sphinx-prompt
    #   sphinxcontrib-yt
-sphinx-autoapi==1.8.4
+sphinx-autoapi==1.9.0
    # via -r doc-requirements.in
 sphinx-basic-ng==0.0.1a12
    # via furo
@@ -618,7 +633,7 @@ sphinx-copybutton==0.5.0
    # via -r doc-requirements.in
 sphinx-fontawesome==0.0.6
    # via -r doc-requirements.in
-sphinx-gallery==0.10.1
+sphinx-gallery==0.11.0
    # via -r doc-requirements.in
 sphinx-material==0.0.35
    # via -r doc-requirements.in
@@ -640,7 +655,7 @@ sphinxcontrib-serializinghtml==1.1.5
    # via sphinx
 sphinxcontrib-yt==0.2.2
    # via -r doc-requirements.in
-sqlalchemy==1.4.39
+sqlalchemy==1.4.40
    # via -r doc-requirements.in
 statsd==3.3.0
    # via flytekit
@@ -662,9 +677,9 @@ textwrap3==0.9.2
    # via ansiwrap
 tinycss2==1.1.1
    # via nbconvert
-toolz==0.11.2
+toolz==0.12.0
    # via altair
-torch==1.11.0
+torch==1.12.1
    # via -r doc-requirements.in
 tornado==6.2
    # via
@@ -702,16 +717,18 @@ typing-extensions==4.3.0
    #   importlib-metadata
    #   jsonschema
    #   kiwisolver
+    #   onnx
    #   pandera
    #   pydantic
    #   responses
+    #   tensorflow
    #   torch
    #   typing-inspect
 typing-inspect==0.7.1
    # via
    #   dataclasses-json
    #   pandera
-tzdata==2022.1
+tzdata==2022.2
    # via pytz-deprecation-shim
 tzlocal==4.2
    # via great-expectations
@@ -719,9 +736,10 @@ unidecode==1.3.4
    # via
    #   python-slugify
    #   sphinx-autoapi
-urllib3==1.26.9
+urllib3==1.26.11
    # via
    #   botocore
+    #   docker
    #   flytekit
    #   great-expectations
    #   kubernetes
@@ -749,7 +767,7 @@ wrapt==1.14.1
    #   deprecated
    #   flytekit
    #   pandera
-zipp==3.8.0
+zipp==3.8.1
    # via
    #   importlib-metadata
    #   importlib-resources
diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py
index d0d7f7a229..46513553b9 100644
--- a/flytekit/clis/sdk_in_container/constants.py
+++ b/flytekit/clis/sdk_in_container/constants.py
@@ -7,6 +7,8 @@
 CTX_PACKAGES = "pkgs"
 CTX_NOTIFICATIONS = "notifications"
 CTX_CONFIG_FILE = "config_file"
+CTX_PROJECT_ROOT = "project_root"
+CTX_MODULE = "module"


 project_option = _click.option(
diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py
index 2a884e29da..1a849d0681 100644
--- a/flytekit/clis/sdk_in_container/package.py
+++ b/flytekit/clis/sdk_in_container/package.py
@@ -77,8 +77,16 @@
     default="/root",
     help="Filesystem path to where the code is copied into within the Dockerfile. look for `COPY . /root` like command.",
 )
+@click.option(
+    "--deref-symlinks",
+    default=False,
+    is_flag=True,
+    help="Enables symlink dereferencing when packaging files in fast registration",
+)
 @click.pass_context
-def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter):
+def package(
+    ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter, deref_symlinks
+):
     """
     This command produces a Flyte backend registrable package of all entities in Flyte.
     For tasks, one pb file is produced for each task, representing one TaskTemplate object.
@@ -103,6 +111,6 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_
         display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

     try:
-        serialize_and_package(pkgs, serialization_settings, source, output, fast)
+        serialize_and_package(pkgs, serialization_settings, source, output, fast, deref_symlinks)
     except NoSerializableEntitiesError:
         click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
index 03e00d7896..024b70edde 100644
--- a/flytekit/clis/sdk_in_container/register.py
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -99,6 +99,12 @@
     type=str,
     help="Version the package or module is registered with",
 )
+@click.option(
+    "--deref-symlinks",
+    default=False,
+    is_flag=True,
+    help="Enables symlink dereferencing when packaging files in fast registration",
+)
 @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
 @click.pass_context
 def register(
@@ -111,6 +117,7 @@ def register(
     service_account: str,
     raw_data_prefix: str,
     version: typing.Optional[str],
+    deref_symlinks: bool,
     package_or_module: typing.Tuple[str],
 ):
     """
@@ -142,7 +149,7 @@ def register(
     # Create a zip file containing all the entries.
     detected_root = find_common_root(package_or_module)
     cli_logger.debug(f"Using {detected_root} as root folder for project")
-    zip_file = fast_package(detected_root, output)
+    zip_file = fast_package(detected_root, output, deref_symlinks)

     # Upload zip file to Admin using FlyteRemote.
     md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file))
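[Editor's aside: the new `--deref-symlinks` flag is threaded down to `tarfile.open(..., dereference=deref_symlinks)` in `fast_registration.py`, shown later in this diff. A minimal, self-contained sketch of what that stdlib flag changes; all file names below are hypothetical and exist only for the demonstration.]

```python
import os
import tarfile
import tempfile

with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as out:
    target = os.path.join(src, "data.txt")
    with open(target, "w") as f:
        f.write("hello")
    # A symlink inside the source tree, like one pointing at shared code
    os.symlink(target, os.path.join(src, "link.txt"))

    for deref in (False, True):
        tar_path = os.path.join(out, f"archive-{deref}.tar")
        # dereference=False stores the symlink itself; dereference=True stores
        # a copy of the file the link points to
        with tarfile.open(tar_path, "w", dereference=deref) as tar:
            tar.add(os.path.join(src, "link.txt"), arcname="link.txt")
        with tarfile.open(tar_path) as tar:
            member = tar.getmember("link.txt")
            print(f"dereference={deref}: issym={member.issym()} isfile={member.isfile()}")
```

Without dereferencing, a fast-registration archive built from a tree of symlinks would ship dangling links; the flag trades archive size for portability.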
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index 6fece1e7d2..d0b890ba7b 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -15,11 +15,17 @@
 from typing_extensions import get_args

 from flytekit import BlobType, Literal, Scalar
-from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT
+from flytekit.clis.sdk_in_container.constants import (
+    CTX_CONFIG_FILE,
+    CTX_DOMAIN,
+    CTX_MODULE,
+    CTX_PROJECT,
+    CTX_PROJECT_ROOT,
+)
 from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
 from flytekit.configuration import ImageConfig
 from flytekit.configuration.default_images import DefaultImages
-from flytekit.core import context_manager, tracker
+from flytekit.core import context_manager
 from flytekit.core.base_task import PythonTask
 from flytekit.core.context_manager import FlyteContext
 from flytekit.core.data_persistence import FileAccessProvider
@@ -27,7 +33,7 @@
 from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
 from flytekit.models import literals
 from flytekit.models.interface import Variable
-from flytekit.models.literals import Blob, BlobMetadata, Primitive
+from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union
 from flytekit.models.types import LiteralType, SimpleType
 from flytekit.remote.executions import FlyteWorkflowExecution
 from flytekit.tools import module_loader, script_mode
@@ -88,7 +94,7 @@ def convert(
         self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
     ) -> typing.Any:
         if FileAccessProvider.is_remote(value):
-            return FileParam(filepath=value)
+            return FileParam(filepath=value, local=False)
         p = pathlib.Path(value)
         if p.exists() and p.is_file():
             return FileParam(filepath=str(p.resolve()))
@@ -264,8 +270,7 @@ def convert_to_union(
             # and then use flyte converter to convert it to literal.
             python_val = converter._click_type.convert(value, param, ctx)
             literal = converter.convert_to_literal(ctx, param, python_val)
-            self._python_type = python_type
-            return literal
+            return Literal(scalar=Scalar(union=Union(literal, variant)))
         except (Exception or AttributeError) as e:
             logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
             raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")
@@ -480,14 +485,12 @@ def get_entities_in_file(filename: str) -> Entities:
     workflows = []
     tasks = []
     module = importlib.import_module(module_name)
-    for k in dir(module):
-        o = module.__dict__[k]
-        if isinstance(o, PythonFunctionWorkflow):
-            _, _, fn, _ = tracker.extract_task_module(o)
-            workflows.append(fn)
+    for name in dir(module):
+        o = module.__dict__[name]
+        if isinstance(o, WorkflowBase):
+            workflows.append(name)
         elif isinstance(o, PythonTask):
-            _, _, fn, _ = tracker.extract_task_module(o)
-            tasks.append(fn)
+            tasks.append(name)

     return Entities(workflows, tasks)

@@ -542,6 +545,8 @@ def _run(*args, **kwargs):
             domain=domain,
             image_config=image_config,
             destination_dir=run_level_params.get("destination_dir"),
+            source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT),
+            module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE),
         )

         options = None
@@ -602,11 +607,16 @@ def get_command(self, ctx, exe_entity):
             )

         project_root = _find_project_root(self._filename)
+
         # Find the relative path for the filename relative to the root of the project.
         # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as
         # a parameter.
         rel_path = self._filename.relative_to(project_root)
         module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")
+
+        ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root
+        ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module
+
         entity = load_naive_entity(module, exe_entity, project_root)

         # If this is a remote execution, which we should know at this point, then create the remote object
diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py
index 0b12d6b406..33c0b47940 100644
--- a/flytekit/clis/sdk_in_container/serialize.py
+++ b/flytekit/clis/sdk_in_container/serialize.py
@@ -155,16 +155,22 @@ def fast(ctx):


 @click.command("workflows")
+@click.option(
+    "--deref-symlinks",
+    default=False,
+    is_flag=True,
+    help="Enables symlink dereferencing when packaging files in fast registration",
+)
 @click.option("-f", "--folder", type=click.Path(exists=True))
 @click.pass_context
-def fast_workflows(ctx, folder=None):
+def fast_workflows(ctx, folder=None, deref_symlinks=False):
     if folder:
         click.echo(f"Writing output to {folder}")

     source_dir = ctx.obj[CTX_LOCAL_SRC_ROOT]
     # Write using gzip
-    archive_fname = fast_package(source_dir, folder)
+    archive_fname = fast_package(source_dir, folder, deref_symlinks)
     click.echo(f"Wrote compressed archive to {archive_fname}")

     pkgs = ctx.obj[CTX_PACKAGES]
diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index c0edf07ab2..047ce4b3d3 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -94,6 +94,7 @@
 from flytekit.configuration import internal as _internal
 from flytekit.configuration.default_images import DefaultImages
 from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists
+from flytekit.loggers import logger

 PROJECT_PLACEHOLDER = "{{ registration.project }}"
 DOMAIN_PLACEHOLDER = "{{ registration.domain }}"
@@ -336,10 +337,16 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
         kwargs = set_if_exists(
             kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file)
         )
+        client_credentials_secret = read_file_if_exists(
+            _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file)
+        )
+        if client_credentials_secret and client_credentials_secret.endswith("\n"):
+            logger.info("Newline stripped from client secret")
+            client_credentials_secret = client_credentials_secret.strip()
         kwargs = set_if_exists(
             kwargs,
             "client_credentials_secret",
-            read_file_if_exists(_internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file)),
+            client_credentials_secret,
         )
         kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file))
         kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
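[Editor's aside: the change above guards against a trailing newline when the client secret is read from a file, which is a common artifact of `echo` and of most mounted secret stores. A tiny reproduction of the pitfall; the path is hypothetical.]

```python
from pathlib import Path

# Secrets written with `echo secret > file` or mounted by a secret store
# usually end with "\n", which would otherwise be sent verbatim to the
# auth server and fail authentication in a hard-to-debug way.
secret_file = Path("/tmp/client_secret")
secret_file.write_text("super-secret\n")

raw = secret_file.read_text()
print(repr(raw))          # 'super-secret\n' - trailing newline included
print(repr(raw.strip()))  # 'super-secret'   - what should actually be sent
```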
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index e69e3f6476..a2ad5311f1 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -171,6 +171,10 @@ def is_supported_protocol(cls, protocol: str) -> bool:
         """
         return protocol in cls._PLUGINS

+    @classmethod
+    def supported_protocols(cls) -> typing.List[str]:
+        return [k for k in cls._PLUGINS.keys()]
+

 class DiskPersistence(DataPersistence):
     """
diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py
index 0f651410bf..c721e2e160 100644
--- a/flytekit/core/interface.py
+++ b/flytekit/core/interface.py
@@ -7,7 +7,7 @@
 from collections import OrderedDict
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union

-from typing_extensions import get_args, get_origin, get_type_hints
+from typing_extensions import Annotated, get_args, get_origin, get_type_hints

 from flytekit.core import context_manager
 from flytekit.core.docstring import Docstring
@@ -259,12 +259,16 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface:
 def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
     try:
         if hasattr(t, "__origin__") and hasattr(t, "__args__"):
-            if t.__origin__ == list:
+            if get_origin(t) is list:
                 return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])]
-            elif t.__origin__ == dict and t.__args__[0] == str:
+            elif get_origin(t) is dict and t.__args__[0] == str:
                 return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])]
-            else:
-                TypeEngine.get_transformer(t)
+            elif get_origin(t) is typing.Union:
+                return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))]
+            elif get_origin(t) is Annotated:
+                base_type, *config = get_args(t)
+                return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)]
         TypeEngine.get_transformer(t)
     except ValueError:
         logger.warning(
             f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
@@ -329,7 +333,11 @@ def transform_variable_map(
         elif v.__origin__ is dict:
             sub_type = v.__args__[1]
             if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle:
-                res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}
+                if hasattr(sub_type.python_type(), "__name__"):
+                    res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}
+                elif hasattr(sub_type.python_type(), "_name"):
+                    # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead.
+                    res[k].type.metadata = {"python_class_name": sub_type.python_type()._name}

     return res
diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py
index 4afab73955..48b6f9c7da 100644
--- a/flytekit/core/local_cache.py
+++ b/flytekit/core/local_cache.py
@@ -1,8 +1,8 @@
-import base64
 from typing import Optional

-import cloudpickle
 from diskcache import Cache
+from google.protobuf.struct_pb2 import Struct
+from joblib.hashing import NumpyHasher

 from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

@@ -28,15 +28,26 @@ def _recursive_hash_placement(literal: Literal) -> Literal:
         return literal


+class ProtoJoblibHasher(NumpyHasher):
+    def save(self, obj):
+        if isinstance(obj, Struct):
+            obj = dict(
+                rewrite_rule="google.protobuf.struct_pb2.Struct",
+                cls=obj.__class__,
+                obj=dict(sorted(obj.fields.items())),
+            )
+        NumpyHasher.save(self, obj)
+
+
 def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
     # Traverse the literals and replace the literal with a new literal that only contains the hash
     literal_map_overridden = {}
     for key, literal in input_literal_map.literals.items():
         literal_map_overridden[key] = _recursive_hash_placement(literal)

-    # Pickle the literal map and use base64 encoding to generate a representation of it
-    b64_encoded = base64.b64encode(cloudpickle.dumps(LiteralMap(literal_map_overridden)))
-    return f"{task_name}-{cache_version}-{b64_encoded}"
+    # Generate a hash key of inputs with joblib
+    hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden)
+    return f"{task_name}-{cache_version}-{hashed_inputs}"


 class LocalTaskCache(object):
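[Editor's aside: the local-cache change replaces a base64-encoded pickle (long, and not guaranteed stable across interpreter or library versions) with joblib's content hash (short and deterministic; the PR's `ProtoJoblibHasher` additionally canonicalizes protobuf `Struct` fields so key ordering cannot change the digest). A simplified comparison on a plain dict standing in for the literal map; not flytekit's actual code path.]

```python
import base64
import pickle

import joblib  # joblib.hash() gives a stable, fixed-length md5-based digest

payload = {"a": 1, "b": [1, 2, 3]}

# Old-style key material: grows with the input and embeds pickle opcodes.
pickled = base64.b64encode(pickle.dumps(payload))
print(len(pickled), pickled[:20])

# New-style key material: a short 32-hex-character digest of the object graph,
# suitable for use directly inside a cache key string.
digest = joblib.hash(payload)
print(len(digest), digest)
```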
diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py
index 4c9150881d..4fe8e669ab 100644
--- a/flytekit/core/promise.py
+++ b/flytekit/core/promise.py
@@ -679,6 +679,11 @@ def __rshift__(self, other: typing.Union[Promise, VoidPromise]):
         if self.ref:
             self.ref.node.runs_before(other.ref.node)

+    def with_overrides(self, *args, **kwargs):
+        if self.ref:
+            self.ref.node.with_overrides(*args, **kwargs)
+        return self
+
     @property
     def task_name(self):
         return self._task_name
diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py
index 84a8eaedef..e3a10afdf3 100644
--- a/flytekit/core/python_function_task.py
+++ b/flytekit/core/python_function_task.py
@@ -19,6 +19,8 @@
 from enum import Enum
 from typing import Any, Callable, List, Optional, TypeVar, Union

+from flytekit.configuration import SerializationSettings
+from flytekit.configuration.default_images import DefaultImages
 from flytekit.core.base_task import Task, TaskResolverMixin
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.docstring import Docstring
@@ -257,10 +259,17 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
         representing that newly generated workflow, instead of executing it.
         """
         ctx = FlyteContextManager.current_context()
+        # This is a placeholder SerializationSettings object and is only used to test compilation for dynamic tasks
+        # when run locally. The output of the compilation should never actually be used anywhere.
+        _LOCAL_ONLY_SS = SerializationSettings.for_image(DefaultImages.default_image(), "v", "p", "d")

         if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
             updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
-            with FlyteContextManager.with_context(ctx.with_execution_state(updated_exec_state)):
+            with FlyteContextManager.with_context(
+                ctx.with_execution_state(updated_exec_state).with_serialization_settings(_LOCAL_ONLY_SS)
+            ) as ctx:
+                logger.debug(f"Running compilation for {self} as part of local run as check")
+                self.compile_into_workflow(ctx, task_function, **kwargs)
                 logger.info("Executing Dynamic workflow, using raw inputs")
                 return exception_scopes.user_entry_point(task_function)(**kwargs)
diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py
index 0fad8335c2..9851e2e98b 100644
--- a/flytekit/core/tracker.py
+++ b/flytekit/core/tracker.py
@@ -4,6 +4,7 @@
 import inspect as _inspect
 import os
 import typing
+from types import ModuleType
 from typing import Callable, Tuple, Union

 from flytekit.configuration.feature_flags import FeatureFlags
@@ -239,6 +240,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
     if mod_name == "__main__":
         return name, "", name, os.path.abspath(inspect.getfile(f))

+    mod_name = get_full_module_path(mod, mod_name)
+    return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
+
+
+def get_full_module_path(mod: ModuleType, mod_name: str) -> str:
     if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != ".":
         package_root = (
             FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None
@@ -247,4 +253,4 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
         # We only replace the mod_name if it is more specific, else we already have a fully resolved path
         if len(new_mod_name) > len(mod_name):
             mod_name = new_mod_name
-    return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
+    return mod_name
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index b3851e77ce..686b03bb8c 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -270,6 +270,46 @@ class Test():
     def __init__(self):
         super().__init__("Object-Dataclass-Transformer", object)

+    def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
+        # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
+        if type(v) == expected_type:
+            return
+
+        # @dataclass_json
+        # @dataclass
+        # class Foo(object):
+        #     a: int = 0
+        #
+        # @task
+        # def t1(a: Foo):
+        #     ...
+        #
+        # In the above example, the type of v may not be equal to the expected_type in some cases. For example:
+        # 1. The input of t1 is another dataclass (bar); then we should raise an error.
+        # 2. When using flyte remote to execute the above task, the expected_type is guess_python_type (FooSchema) by default.
+        #    However, FooSchema is created by flytekit and is not equal to the user-defined dataclass (Foo).
+        #    Therefore, we should iterate over all attributes in the dataclass and check that the type of each value matches the expected_type.
+        expected_fields_dict = {}
+        for f in dataclasses.fields(expected_type):
+            expected_fields_dict[f.name] = f.type
+
+        for f in dataclasses.fields(type(v)):
+            original_type = f.type
+            expected_type = expected_fields_dict[f.name]
+
+            if UnionTransformer.is_optional_type(original_type):
+                original_type = UnionTransformer.get_sub_type_in_optional(original_type)
+            if UnionTransformer.is_optional_type(expected_type):
+                expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)
+
+            val = v.__getattribute__(f.name)
+            if dataclasses.is_dataclass(val):
+                self.assert_type(expected_type, val)
+            elif original_type != expected_type:
+                raise TypeTransformerFailedError(f"Type of Val '{original_type}' is not an instance of {expected_type}")
+
     def get_literal_type(self, t: Type[T]) -> LiteralType:
         """
         Extracts the Literal type definition for a Dataclass and returns a type Struct.
@@ -449,7 +489,9 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
         return python_val

     def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
-        if t == int:
+        if val is None:
+            return val
+        if t == int or t == typing.Optional[int]:
             return int(val)

         if isinstance(val, list):
@@ -461,6 +503,13 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
             # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
             return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()}

+        if get_origin(t) is typing.Union and type(None) in get_args(t):
+            # Handle optional type. e.g. Optional[int], Optional[dataclass]
+            # Marshmallow doesn't support union type, so the type here is always an optional type.
+            # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796
+            # Note: Union[None, int] is also an optional type, but Marshmallow does not support it.
+            return self._fix_val_int(get_args(t)[0], val)
+
         if dataclasses.is_dataclass(t):
             return self._fix_dataclass_int(t, val)  # type: ignore
@@ -975,6 +1024,17 @@ class UnionTransformer(TypeTransformer[T]):
     def __init__(self):
         super().__init__("Typed Union", typing.Union)

+    @staticmethod
+    def is_optional_type(t: Type[T]) -> bool:
+        return get_origin(t) is typing.Union and type(None) in get_args(t)
+
+    @staticmethod
+    def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
+        """
+        Return the generic Type T of the Optional type
+        """
+        return get_args(t)[0]
+
     def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
         if get_origin(t) is Annotated:
             t = get_args(t)[0]
@@ -1309,6 +1369,11 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac
 def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]:
     element_type = element_property["type"]
     element_format = element_property["format"] if "format" in element_property else None
+
+    if type(element_type) == list:
+        # Element type of Optional[int] is [integer, None]
+        return typing.Optional[_get_element_type({"type": element_type[0]})]
+
     if element_type == "string":
         return str
     elif element_type == "integer":
diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py
index 8617ae4d12..0cf781d3da 100644
--- a/flytekit/deck/renderer.py
+++ b/flytekit/deck/renderer.py
@@ -1,6 +1,7 @@
 from typing import Any, Optional

 import pandas
+import pyarrow
 from typing_extensions import Protocol, runtime_checkable


@@ -24,3 +25,13 @@ def __init__(self, max_rows: Optional[int] = None):
     def to_html(self, df: pandas.DataFrame) -> str:
         assert isinstance(df, pandas.DataFrame)
         return df.to_html(max_rows=self._max_rows)
+
+
+class ArrowRenderer:
+    """
+    Render an Arrow dataframe as an HTML table.
+    """
+
+    def to_html(self, df: pyarrow.Table) -> str:
+        assert isinstance(df, pyarrow.Table)
+        return df.to_string()
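[Editor's aside: renderers are duck-typed against the `Renderable` protocol (anything with a `to_html` method), so new dataframe presentations can be plugged in without touching the engine. A hedged sketch; `SummaryRenderer` is hypothetical and not part of this PR, it merely mirrors the shape of `TopFrameRenderer`/`ArrowRenderer` above.]

```python
import pandas


class SummaryRenderer:
    """Render summary statistics instead of raw rows."""

    def to_html(self, df: pandas.DataFrame) -> str:
        assert isinstance(df, pandas.DataFrame)
        # describe() keeps the deck small even for very wide/long frames
        return df.describe().to_html()


# Usage sketch: registration hooks the renderer up to a python type, so plain
# pd.DataFrame task outputs get this HTML in decks (see the
# register_renderer calls later in this diff):
# StructuredDatasetTransformerEngine.register_renderer(pandas.DataFrame, SummaryRenderer())
print(SummaryRenderer().to_html(pandas.DataFrame({"a": [1, 2, 3]}))[:60])
```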
diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py
index ae077d9755..770fe11b73 100644
--- a/flytekit/extras/pytorch/__init__.py
+++ b/flytekit/extras/pytorch/__init__.py
@@ -11,10 +11,21 @@
 """
 from flytekit.loggers import logger

+# TODO: abstract this out so that there's an established pattern for registering plugins
+# that have soft dependencies
 try:
+    # isolate the exception to the torch import
+    import torch
+
+    _torch_installed = True
+except (ImportError, OSError):
+    _torch_installed = False
+
+
+if _torch_installed:
     from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
     from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
-except ImportError:
+else:
     logger.info(
         "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
     )
diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py
index 812c0a3749..be7cda0a17 100644
--- a/flytekit/extras/tasks/shell.py
+++ b/flytekit/extras/tasks/shell.py
@@ -132,7 +132,8 @@ def __init__(
             script_file = os.path.abspath(script_file)

         if task_config is not None:
-            if str(type(task_config)) != "flytekitplugins.pod.task.Pod":
+            fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__
+            if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod":
                 raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

         # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index 99f54b7933..f02226decc 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -577,6 +577,8 @@ def register_script(
         destination_dir: str = ".",
         default_launch_plan: typing.Optional[bool] = True,
         options: typing.Optional[Options] = None,
+        source_path: typing.Optional[str] = None,
+        module_name: typing.Optional[str] = None,
     ) -> typing.Union[FlyteWorkflow, FlyteTask]:
         """
         Use this method to register a workflow via script mode.
@@ -588,13 +590,16 @@ def register_script(
         :param entity: The workflow to be registered or the task to be registered
         :param default_launch_plan: This should be true if a default launch plan should be created for the workflow
         :param options: Additional execution options that can be configured for the default launchplan
+        :param source_path: The root of the project path
+        :param module_name: The name of the module
         :return:
         """
         if image_config is None:
             image_config = ImageConfig.auto_default_image()

         upload_location, md5_bytes = fast_register_single_script(
-            entity,
+            source_path,
+            module_name,
             functools.partial(
                 self.client.get_upload_signed_url,
                 project=project or self.default_project,
diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py
index c4ac31a01a..34faadc58c 100644
--- a/flytekit/tools/fast_registration.py
+++ b/flytekit/tools/fast_registration.py
@@ -21,12 +21,13 @@
 file_access = FlyteContextManager.current_context().file_access


-def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike:
+def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: bool = False) -> os.PathLike:
     """
     Takes a source directory and packages everything not covered by common ignores into a tarball
     named after a hexdigest of the included files.
    :param os.PathLike source:
    :param os.PathLike output_dir:
+    :param bool deref_symlinks: Enables dereferencing symlinks when packaging directory
    :return os.PathLike:
    """
    ignore = IgnoreGroup(source, [GitIgnore, DockerIgnore, StandardIgnore])
@@ -41,7 +42,7 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike:
     with tempfile.TemporaryDirectory() as tmp_dir:
         tar_path = os.path.join(tmp_dir, "tmp.tar")
-        with tarfile.open(tar_path, "w") as tar:
+        with tarfile.open(tar_path, "w", dereference=deref_symlinks) as tar:
             tar.add(source, arcname="", filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x)))
         with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped:
             with open(tar_path, "rb") as tar_file:
diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py
index 167c772184..ceaee36435 100644
--- a/flytekit/tools/repo.py
+++ b/flytekit/tools/repo.py
@@ -75,6 +75,7 @@ def package(
     source: str = ".",
     output: str = "./flyte-package.tgz",
     fast: bool = False,
+    deref_symlinks: bool = False,
 ):
     """
     Package the given entities and the source code (if fast is enabled) into a package with the given name in output
@@ -82,6 +83,7 @@ def package(
     :param source: source folder
     :param output: output package name with suffix
     :param fast: fast enabled implies source code is bundled
+    :param deref_symlinks: if enabled then symlinks are dereferenced during packaging
     """
     if not registrable_entities:
         raise NoSerializableEntitiesError("Nothing to package")
@@ -95,7 +97,7 @@ def package(
         if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output):
             click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow")
             os.remove(output)
-        archive_fname = fast_registration.fast_package(source, output_tmpdir)
+        archive_fname = fast_registration.fast_package(source, output_tmpdir, deref_symlinks)
         click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True)

         with tarfile.open(output, "w:gz") as tar:
@@ -110,13 +112,14 @@ def serialize_and_package(
     source: str = ".",
     output: str = "./flyte-package.tgz",
     fast: bool = False,
+    deref_symlinks: bool = False,
     options: typing.Optional[Options] = None,
 ):
     """
     First serialize and then package all entities
     """
     registrable_entities = serialize(pkgs, settings, source, options=options)
-    package(registrable_entities, source, output, fast)
+    package(registrable_entities, source, output, fast, deref_symlinks)


 def register(
diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py
index f837447637..29b617824c 100644
--- a/flytekit/tools/script_mode.py
+++ b/flytekit/tools/script_mode.py
@@ -1,5 +1,6 @@
 import gzip
 import hashlib
+import importlib
 import os
 import shutil
 import tarfile
@@ -10,8 +11,7 @@
 from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2

 from flytekit.core import context_manager
-from flytekit.core.tracker import extract_task_module
-from flytekit.core.workflow import WorkflowBase
+from flytekit.core.tracker import get_full_module_path


 def compress_single_script(source_path: str, destination: str, full_module_name: str):
@@ -97,16 +97,14 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:


 def fast_register_single_script(
-    wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable
+    source_path: str, module_name: str, create_upload_location_fn: typing.Callable
 ) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes):
-    _, mod_name, _, script_full_path = extract_task_module(wf_entity)
-    # Find project root by moving up the folder hierarchy until you cannot find a __init__.py file.
-    source_path = _find_project_root(script_full_path)

     # Open a temp directory and dump the contents of the digest.
     with tempfile.TemporaryDirectory() as tmp_dir:
         archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz")
-        compress_single_script(source_path, archive_fname, mod_name)
+        mod = importlib.import_module(module_name)
+        compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__))

     flyte_ctx = context_manager.FlyteContextManager.current_context()
     md5, _ = hash_file(archive_fname)
diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py
index cb1cf2a900..b4f67b94f1 100644
--- a/flytekit/types/numpy/ndarray.py
+++ b/flytekit/types/numpy/ndarray.py
@@ -52,7 +52,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         try:
             uri = lv.scalar.blob.uri
         except AttributeError:
-            TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
+            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

         local_path = ctx.file_access.get_random_local_path()
         ctx.file_access.get_data(uri, local_path, is_multipart=False)
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index 82680d2787..71dff61c5e 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -7,16 +7,13 @@
 import pyarrow.parquet as pq

 from flytekit import FlyteContext
-from flytekit.core.data_persistence import DataPersistencePlugins
+from flytekit.deck import TopFrameRenderer
+from flytekit.deck.renderer import ArrowRenderer
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
-    ABFS,
-    GCS,
-    LOCAL,
     PARQUET,
-    S3,
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
@@ -27,10 +24,8 @@


 class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
-    def __init__(self, protocol: str):
-        super().__init__(pd.DataFrame, protocol, PARQUET)
-        # todo: Use this somehow instead of relaying ont he ctx file_access
-        self._persistence = DataPersistencePlugins.find_plugin(protocol)()
+    def __init__(self):
+        super().__init__(pd.DataFrame, None, PARQUET)

     def encode(
         self,
@@ -50,8 +45,8 @@ def encode(


 class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
-    def __init__(self, protocol: str):
-        super().__init__(pd.DataFrame, protocol, PARQUET)
+    def __init__(self):
+        super().__init__(pd.DataFrame, None, PARQUET)

     def decode(
         self,
@@ -69,8 +64,8 @@ def decode(


 class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
-    def __init__(self, protocol: str):
-        super().__init__(pa.Table, protocol, PARQUET)
+    def __init__(self):
+        super().__init__(pa.Table, None, PARQUET)

     def encode(
         self,
@@ -88,8 +83,8 @@ def encode(


 class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
-    def __init__(self, protocol: str):
-        super().__init__(pa.Table, protocol, PARQUET)
+    def __init__(self):
+        super().__init__(pa.Table, None, PARQUET)

     def decode(
         self,
@@ -106,9 +101,10 @@ def decode(
         return pq.read_table(local_dir)


-# Don't override default protocol
-for protocol in [LOCAL, S3, GCS, ABFS]:
-    StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False)
-    StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False)
-    StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False)
-    StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False)
+StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
+StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
+
+StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
+StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py
index 92d203e25d..85cede1544 100644
--- a/flytekit/types/structured/bigquery.py
+++ b/flytekit/types/structured/bigquery.py
@@ -10,7 +10,6 @@
 from flytekit.models import literals
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
-    BIGQUERY,
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
@@ -18,6 +17,8 @@
     StructuredDatasetTransformerEngine,
 )

+BIGQUERY = "bq"
+

 def _write_to_bq(structured_dataset: StructuredDataset):
     table_id = typing.cast(str, structured_dataset.uri).split("://", 1)[1].replace(":", ".")
@@ -111,7 +112,7 @@ def decode(
         return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata))


-StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False)
-StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False)
-StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False)
-StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False)
+StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers())
+StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler())
+StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers())
+StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler())
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index cdb26a87c2..bdad752b16 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -1,9 +1,7 @@
 from __future__ import annotations

 import collections
-import importlib
 import os
-import re
 import types
 import typing
 from abc import ABC, abstractmethod
@@ -12,21 +10,16 @@

 import _datetime
 import numpy as _np
-import pandas
 import pandas as pd
-import pyarrow
 import pyarrow as pa
-
-if importlib.util.find_spec("pyspark") is not None:
-    import pyspark
-if importlib.util.find_spec("polars") is not None:
-    import polars as pl
 from dataclasses_json import config, dataclass_json
 from marshmallow import fields
 from typing_extensions import Annotated, TypeAlias, get_args, get_origin

 from flytekit.core.context_manager import FlyteContext, FlyteContextManager
+from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence
 from flytekit.core.type_engine import TypeEngine, TypeTransformer
+from flytekit.deck.renderer import Renderable
 from flytekit.loggers import logger
 from flytekit.models import literals
 from flytekit.models import types as type_models
@@ -36,13 +29,6 @@
 T = typing.TypeVar("T")  # StructuredDataset type or a dataframe type
 DF = typing.TypeVar("DF")  # Dataframe type

-# Protocols
-BIGQUERY = "bq"
-S3 = "s3"
-ABFS = "abfs"
-GCS = "gs"
-LOCAL = "/"
-
 # For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy.
 StructuredDatasetFormat: TypeAlias = str

@@ -156,7 +142,7 @@ def extract_cols_and_format(
             if ordered_dict_cols is not None:
                 raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}")
             ordered_dict_cols = aa
-        elif isinstance(aa, pyarrow.Schema):
+        elif isinstance(aa, pa.Schema):
             if pa_schema is not None:
                 raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}")
             pa_schema = aa
@@ -168,7 +154,7 @@


 class StructuredDatasetEncoder(ABC):
-    def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None):
+    def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
         """
         Extend this abstract class, implement the encode function, and register your concrete class with the
         StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
@@ -179,12 +165,14 @@ def __init__(self, python_type: Type[T], protocol: str, supported_format: Option
         :param python_type: The dataframe class in question that you want to register this encoder with
         :param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). You can use either
           "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
+          If None, this encoder will be registered with all protocols that flytekit's data persistence layer
+          is capable of handling.
         :param supported_format: Arbitrary string representing the format. If not supplied then an empty string
           will be used. An empty string implies that the encoder works with any format. If the format being asked
           for does not exist, the transformer enginer will look for the "" endcoder instead and write a warning.
         """
         self._python_type = python_type
-        self._protocol = protocol.replace("://", "")
+        self._protocol = protocol.replace("://", "") if protocol else None
         self._supported_format = supported_format or ""

     @property
     def python_type(self) -> Type[T]:
         return self._python_type

     @property
-    def protocol(self) -> str:
+    def protocol(self) -> Optional[str]:
         return self._protocol

     @property
@@ -228,7 +216,7 @@ def encode(


 class StructuredDatasetDecoder(ABC):
-    def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optional[str] = None):
+    def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
         """
         Extend this abstract class, implement the decode function, and register your concrete class with the
         StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
@@ -238,12 +226,14 @@ def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optio
         :param python_type: The dataframe class in question that you want to register this decoder with
         :param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). You can use either
           "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
+          If None, this decoder will be registered with all protocols that flytekit's data persistence layer
+          is capable of handling.
         :param supported_format: Arbitrary string representing the format. If not supplied then an empty string
           will be used. An empty string implies that the decoder works with any format. If the format being asked
           for does not exist, the transformer enginer will look for the "" decoder instead and write a warning.
         """
         self._python_type = python_type
-        self._protocol = protocol.replace("://", "")
+        self._protocol = protocol.replace("://", "") if protocol else None
         self._supported_format = supported_format or ""

     @property
     def python_type(self) -> Type[DF]:
         return self._python_type

     @property
-    def protocol(self) -> str:
+    def protocol(self) -> Optional[str]:
         return self._protocol

     @property
@@ -281,10 +271,8 @@ def decode(


 def protocol_prefix(uri: str) -> str:
-    g = re.search(r"([\w]+)://.*", uri)
-    if g and g.groups():
-        return g.groups()[0]
-    return LOCAL
+    p = DataPersistencePlugins.get_protocol(uri)
+    return p


 def convert_schema_type_to_structured_dataset_type(
@@ -306,6 +294,10 @@ def convert_schema_type_to_structured_dataset_type(
     raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}")


+class DuplicateHandlerError(ValueError):
+    ...
+
+
 class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     """
     Think of this transformer as a higher-level meta transformer that is used for all the dataframe types.
@@ -340,6 +332,7 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     DEFAULT_FORMATS: Dict[Type, str] = {}

     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
+    Renderers: Dict[Type, Renderable] = {}

     @staticmethod
     def _finder(handler_map, df_type: Type, protocol: str, format: str):
@@ -366,8 +359,7 @@ def get_decoder(cls, df_type: Type, protocol: str, format: str):
         return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format)

     @classmethod
-    def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
-        # Maybe think about default dict in the future, but is typing as nice?
+    def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]:
         if isinstance(h, StructuredDatasetEncoder):
             top_level = cls.ENCODERS
         elif isinstance(h, StructuredDatasetDecoder):
@@ -376,9 +368,9 @@ def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
             raise TypeError(f"We don't support this type of handler {h}")
         if h.python_type not in top_level:
             top_level[h.python_type] = {}
-        if h.protocol not in top_level[h.python_type]:
-            top_level[h.python_type][h.protocol] = {}
-        return top_level[h.python_type][h.protocol]
+        if protocol not in top_level[h.python_type]:
+            top_level[h.python_type][protocol] = {}
+        return top_level[h.python_type][protocol]

     def __init__(self):
         super().__init__("StructuredDataset Transformer", StructuredDataset)
@@ -388,22 +380,69 @@ def __init__(self):
         self._hash_overridable = True

     @classmethod
-    def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
+    def register_renderer(cls, python_type: Type, renderer: Renderable):
+        cls.Renderers[python_type] = renderer
+
+    @classmethod
+    def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
         """
-        Call this with any handler to register it with this dataframe meta-transformer
+        Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
field, then it will be registered with all the protocols that flytekit's data persistence layer is capable of handling. + + :param h: The StructuredDatasetEncoder or StructuredDatasetDecoder you wish to register with this transformer. + :param default_for_type: If set, when a user returns from a task an instance of the dataframe the handler + handles, e.g. ``return pd.DataFrame(...)``, not wrapped around the ``StructuredDataset`` object, we will + use this handler's protocol and format as the default, effectively saying that this handler will be called. + Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler + is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of. + In these cases, the protocol is determined by the raw output data prefix set in the active context. + :param override: Override any previous registrations. If default_for_type is also set, this will also override + the default. + """ + if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)): + raise TypeError(f"We don't support this type of handler {h}") - The string "://" should not be present in any handler's protocol so we don't check for it. + if h.protocol is None: + if default_for_type: + raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") + for persistence_protocol in DataPersistencePlugins.supported_protocols(): + # TODO: Clean this up when we get to replacing the persistence layer. + # The behavior of the protocols given in the supported_protocols and is_supported_protocol + # is not actually the same as the one returned in get_protocol. + stripped = DataPersistencePlugins.get_protocol(persistence_protocol) + logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") + try: + cls.register_for_protocol(h, stripped, False, override) + except DuplicateHandlerError: + logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") + + elif h.protocol == "": + raise ValueError(f"Use None instead of empty string for registering handler {h}") + else: + cls.register_for_protocol(h, h.protocol, default_for_type, override) + + @classmethod + def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool): + """ + See the main register function instead. """ - lowest_level = cls._handler_finder(h) + if protocol == "/": + # TODO: Special fix again, because get_protocol returns file, instead of file:// + protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: - raise ValueError(f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}") + raise DuplicateHandlerError( + f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}" + ) lowest_level[h.supported_format] = h - logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}") + logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}") if default_for_type: - # TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist.
+ logger.debug( + f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}" + ) cls.DEFAULT_FORMATS[h.python_type] = h.supported_format - cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol + cls.DEFAULT_PROTOCOLS[h.python_type] = protocol # Register with the type engine as well # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as @@ -657,19 +696,10 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ else: df = python_val - if isinstance(df, pandas.DataFrame): - return df.describe().to_html() - elif isinstance(df, pa.Table): - return df.to_string() - elif isinstance(df, _np.ndarray): - return pd.DataFrame(df).describe().to_html() - elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame): - return pd.DataFrame(df.schema, columns=["StructField"]).to_html() - elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame): - describe_df = df.describe() - return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + if type(df) in self.Renderers: + return self.Renderers[type(df)].to_html(df) else: - raise NotImplementedError("Conversion to html string should be implemented") + raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}") def open_as( self, diff --git a/plugins/flytekit-aws-athena/setup.py b/plugins/flytekit-aws-athena/setup.py index 1164b99d00..0cea449a97 100644 --- a/plugins/flytekit-aws-athena/setup.py +++ b/plugins/flytekit-aws-athena/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-batch/setup.py b/plugins/flytekit-aws-batch/setup.py index e176e35aae..68ad62750c 100644 --- a/plugins/flytekit-aws-batch/setup.py +++ b/plugins/flytekit-aws-batch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 2781192f41..76a816fe06 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "sagemaker-training>=3.6.2,<4.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index e84cd6ce2b..0e7eed5d9d 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "google-cloud-bigquery"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "google-cloud-bigquery"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index e8d88f26c3..68ee456ed6 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -25,13 +25,15 @@ import importlib from flytekit import 
StructuredDatasetTransformerEngine, logger -from flytekit.configuration import internal -from flytekit.types.structured.structured_dataset import ABFS, GCS, S3 from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler from .persist import FSSpecPersistence +S3 = "s3" +ABFS = "abfs" +GCS = "gs" + def _register(protocol: str): logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.") diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py index 65a440b785..e4986ed9f6 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py @@ -13,7 +13,6 @@ from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -22,7 +21,7 @@ def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]: protocol = FSSpecPersistence.get_protocol(uri) - if protocol == S3: + if protocol == "s3": kwargs = s3_setup_args(cfg.s3) if kwargs: return kwargs diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index 3678a0b518..f7d03690a8 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "fsspec>=2021.7.0", "botocore>=1.7.48"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "fsspec>=2021.7.0", "botocore>=1.7.48", "pandas>=1.2.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-deck-standard/setup.py b/plugins/flytekit-deck-standard/setup.py index fe04ab5434..a47bf0f0d0 100644 --- a/plugins/flytekit-deck-standard/setup.py +++ b/plugins/flytekit-deck-standard/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}-standard" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "markdown", "plotly", "pandas_profiling"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "markdown", "plotly", "pandas_profiling"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-dolt/setup.py b/plugins/flytekit-dolt/setup.py index bb8b572bc7..ce6abbc64b 100644 --- a/plugins/flytekit-dolt/setup.py +++ b/plugins/flytekit-dolt/setup.py @@ -6,7 +6,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "dolt_integrations>=0.1.5"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "dolt_integrations>=0.1.5"] dev_requires = ["pytest-mock>=3.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index e50f707624..93c73d5416 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "great-expectations>=0.13.30", "sqlalchemy>=1.4.23"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "great-expectations>=0.13.30", "sqlalchemy>=1.4.23"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-hive/setup.py b/plugins/flytekit-hive/setup.py index f9602500f8..a2f67d982f 100644 --- a/plugins/flytekit-hive/setup.py +++ b/plugins/flytekit-hive/setup.py @@ -4,7 +4,7 @@ 
microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-k8s-pod/setup.py b/plugins/flytekit-k8s-pod/setup.py index 01704e1a6a..29c56922b5 100644 --- a/plugins/flytekit-k8s-pod/setup.py +++ b/plugins/flytekit-k8s-pod/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "kubernetes>=12.0.1", ] diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index 18f168af18..c8a845fb13 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "flyteidl>=0.21.4"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "flyteidl>=0.21.4"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index 2e0b57a7f8..dc10722bd9 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index 45d8fe6b2e..5ec98ea74b 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" # TODO: Requirements are missing, add them back in later. -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-modin/setup.py b/plugins/flytekit-modin/setup.py index 777a19db47..46c5dbc02e 100644 --- a/plugins/flytekit-modin/setup.py +++ b/plugins/flytekit-modin/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "modin>=0.13.0", "fsspec", "ray", diff --git a/plugins/flytekit-onnx-pytorch/setup.py b/plugins/flytekit-onnx-pytorch/setup.py index 74e3b940ec..0642054565 100644 --- a/plugins/flytekit-onnx-pytorch/setup.py +++ b/plugins/flytekit-onnx-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "torch>=1.11.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "torch>=1.11.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py index 9815bedaf2..46a2bceaf7 100644 --- a/plugins/flytekit-onnx-scikitlearn/setup.py +++ b/plugins/flytekit-onnx-scikitlearn/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "skl2onnx>=1.10.3"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "skl2onnx>=1.10.3"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-onnx-tensorflow/setup.py b/plugins/flytekit-onnx-tensorflow/setup.py index d2865b083d..53d35e7fbd 100644 --- a/plugins/flytekit-onnx-tensorflow/setup.py +++ b/plugins/flytekit-onnx-tensorflow/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "tf2onnx>=1.9.3", "tensorflow>=2.7.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", 
"tf2onnx>=1.9.3", "tensorflow>=2.7.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-pandera/setup.py b/plugins/flytekit-pandera/setup.py index cbe9bf4061..0625c138d2 100644 --- a/plugins/flytekit-pandera/setup.py +++ b/plugins/flytekit-pandera/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "pandera>=0.7.1"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pandera>=0.7.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index 98b7896e22..15889bb4ce 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,3 +1,3 @@ flyteidl>=1.0.0 +flytekitplugins-pod==v1.2.0b0 git+https://github.com/flyteorg/flytekit@master#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark -# vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 4b5cde2509..c8294ca254 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.8 +# This file is autogenerated by pip-compile with python 3.9 # To update, run: # # pip-compile dev-requirements.in @@ -8,8 +8,14 @@ arrow==1.2.1 # via jinja2-time binaryornot==0.4.4 # via cookiecutter +cachetools==5.2.0 + # via google-auth certifi==2021.10.8 - # via requests + # via + # kubernetes + # requests +cffi==1.15.1 + # via cryptography chardet==4.0.0 # via binaryornot charset-normalizer==2.0.10 @@ -20,10 +26,12 @@ click==7.1.2 # flytekit cloudpickle==2.0.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit croniter==1.2.0 # via flytekit +cryptography==37.0.4 + # via secretstorage dataclasses-json==0.5.6 # via flytekit decorator==5.1.1 @@ -43,9 +51,15 @@ flyteidl==1.0.0.post1 # -r dev-requirements.in # flytekit flytekit==1.1.0b0 - # via flytekitplugins-spark + # via + # flytekitplugins-pod + # flytekitplugins-spark +flytekitplugins-pod==v1.2.0b0 + # via -r dev-requirements.in flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@master#subdirectory=plugins/flytekit-spark # via -r dev-requirements.in +google-auth==2.11.0 + # via kubernetes googleapis-common-protos==1.55.0 # via # flyteidl @@ -60,6 +74,10 @@ idna==3.3 # via requests importlib-metadata==4.10.1 # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.0.3 # via # cookiecutter @@ -68,6 +86,8 @@ jinja2-time==0.2.0 # via cookiecutter keyring==23.5.0 # via flytekit +kubernetes==24.2.0 + # via flytekitplugins-pod markupsafe==2.0.1 # via jinja2 marshmallow==3.14.1 @@ -87,6 +107,8 @@ numpy==1.22.1 # via # pandas # pyarrow +oauthlib==3.2.1 + # via requests-oauthlib pandas==1.3.5 # via flytekit poyo==0.5.0 @@ -102,17 +124,26 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -py4j==0.10.9.3 +py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 # via flytekit -pyspark==3.2.1 +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pycparser==2.21 + # via cffi +pyspark==3.3.0 # via flytekitplugins-spark python-dateutil==2.8.1 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.2 # via flytekit @@ -125,7 +156,9 @@ pytz==2021.3 # flytekit # pandas pyyaml==6.0 - # via flytekit + # via + # flytekit + # kubernetes 
regex==2021.11.10 # via docker-image-py requests==2.27.1 @@ -133,15 +166,25 @@ # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.17.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth +secretstorage==3.3.3 + # via keyring six==1.16.0 # via # cookiecutter + # google-auth # grpcio + # kubernetes # python-dateutil # responses sortedcontainers==2.4.0 @@ -159,10 +202,13 @@ typing-inspect==0.7.1 urllib3==1.26.8 # via # flytekit + # kubernetes # requests # responses websocket-client==1.3.2 - # via docker + # via + # docker + # kubernetes wheel==0.37.1 # via flytekit wrapt==1.13.3 @@ -171,3 +217,6 @@ # flytekit zipp==3.7.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index a58b01d482..0721c39a37 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -10,9 +10,12 @@ from nbconvert import HTMLExporter from flytekit import FlyteContext, PythonInstanceTask +from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ExecutionParameters +from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger +from flytekit.models import task as task_models from flytekit.models.literals import LiteralMap from flytekit.types.file import HTMLPage, PythonNotebook @@ -63,6 +66,7 @@ class NotebookTask(PythonInstanceTask[T]): name="modulename.my_notebook_task", # the name should be unique within all your tasks, usually it is a good # idea to use the modulename notebook_path="../path/to/my_notebook", + render_deck=True, inputs=kwtypes(v=int), outputs=kwtypes(x=int, y=str), metadata=TaskMetadata(retries=3, cache=True, cache_version="1.0"), @@ -76,7 +80,7 @@ class NotebookTask(PythonInstanceTask[T]): #. It captures the executed notebook in its entirety and is available from Flyte with the name ``out_nb``. #. It also converts the captured notebook into an ``html`` page, which the FlyteConsole will render called - - ``out_rendered_nb`` + ``out_rendered_nb``. If ``render_deck=True`` is passed, this html content will be inserted into a deck. .. note: @@ -109,6 +113,7 @@ def __init__( self, name: str, notebook_path: str, + render_deck: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -120,14 +125,17 @@ def __init__( # errors. # This seems like a hack. We should use a plugin_class that doesn't require a fake-function to make this work. plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config)) - self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func) + self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func, **kwargs) # Rename the internal task so that there are no conflicts at serialization time. Technically these internal # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities # at serialization time.
self._config_task_instance._name = f"{PAPERMILL_TASK_PREFIX}.{name}" - task_type = f"nb-{self._config_task_instance.task_type}" + task_type = f"{self._config_task_instance.task_type}" + task_type_version = self._config_task_instance.task_type_version self._notebook_path = os.path.abspath(notebook_path) + self._render_deck = render_deck + if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -139,7 +147,12 @@ def __init__( } ) super().__init__( - name, task_config, task_type=task_type, interface=Interface(inputs=inputs, outputs=outputs), **kwargs + name, + task_config, + task_type=task_type, + task_type_version=task_type_version, + interface=Interface(inputs=inputs, outputs=outputs), + **kwargs, ) @property @@ -154,6 +167,21 @@ def output_notebook_path(self) -> str: def rendered_output_path(self) -> str: return self._notebook_path.split(".ipynb")[0] + "-out.html" + def get_container(self, settings: SerializationSettings) -> task_models.Container: + return self._config_task_instance.get_container(settings) + + def get_k8s_pod(self, settings: SerializationSettings) -> task_models.K8sPod: + # The task name in the original command is incorrect because we use _dummy_task_func to construct the _config_task_instance. + # Therefore, here we replace the primary container's command with NotebookTask's command. + def fn(settings: SerializationSettings) -> typing.List[str]: + return self.get_command(settings) + + self._config_task_instance.set_command_fn(fn) + return self._config_task_instance.get_k8s_pod(settings) + + def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]: + return self._config_task_instance.get_config(settings) + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return self._config_task_instance.pre_execute(user_params) @@ -225,6 +253,15 @@ def execute(self, **kwargs) -> Any: return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + if self._render_deck: + nb_deck = Deck(self._IMPLICIT_RENDERED_NOTEBOOK) + with open(self.rendered_output_path, "r") as f: + notebook_html = f.read() + nb_deck.append(notebook_html) + # Since user_params is passed by reference, this modifies the object in the outside scope + # which then causes the deck to be rendered later during the dispatch_execute function.
+ user_params.decks.append(nb_deck) + return self._config_task_instance.post_execute(user_params, rval) diff --git a/plugins/flytekit-papermill/setup.py b/plugins/flytekit-papermill/setup.py index 26a3f1b705..46d6296f55 100644 --- a/plugins/flytekit-papermill/setup.py +++ b/plugins/flytekit-papermill/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index ca25eea028..4b456f7fdc 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -2,8 +2,12 @@ import os from flytekitplugins.papermill import NotebookTask +from flytekitplugins.pod import Pod +from kubernetes.client import V1Container, V1PodSpec +import flytekit from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig from flytekit.types.file import PythonNotebook from .testdata.datatype import X @@ -69,3 +73,50 @@ def test_notebook_task_complex(): assert nb.python_interface.outputs.keys() == {"h", "w", "x", "out_nb", "out_rendered_nb"} assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") + + +def test_notebook_deck_local_execution_doesnt_fail(): + nb_name = "nb-simple" + nb = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + render_deck=True, + inputs=kwtypes(pi=float), + outputs=kwtypes(square=float), + ) + sqr, out, render = nb.execute(pi=4) + # This is largely a no assert test to ensure render_deck never inhibits local execution. 
+ assert nb._render_deck, "Passing render deck to init should result in private attribute being set" + + +def generate_pod_spec_for_task(): + primary_container = V1Container(name="primary") + pod_spec = V1PodSpec(containers=[primary_container]) + + return pod_spec + + +nb = NotebookTask( + name="test", + task_config=Pod(pod_spec=generate_pod_spec_for_task(), primary_container_name="primary"), + notebook_path=_get_nb_path("nb-simple", abs=False), + inputs=kwtypes(h=str, n=int, w=str), + outputs=kwtypes(h=str, w=PythonNotebook, x=X), +) + + +def test_notebook_pod_task(): + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + ) + + assert nb.get_container(serialization_settings) is None + assert nb.get_config(serialization_settings)["primary_container_name"] == "primary" + assert ( + nb.get_command(serialization_settings) + == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] + ) diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 06b1127504..0dfd0c6516 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import pandas as pd import polars as pl from flytekit import FlyteContext @@ -7,11 +8,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - ABFS, - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -19,9 +16,20 @@ ) +class PolarsDataFrameRenderer: + """ + The Polars DataFrame summary statistics are rendered as an HTML table.
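A rough usage sketch (the tiny frame is illustrative only; any ``pl.DataFrame`` works):

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3]})
# to_html returns an HTML table built from df.describe(), transposed via pandas.
html = PolarsDataFrameRenderer().to_html(df)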
+ """ + + def to_html(self, df: pl.DataFrame) -> str: + assert isinstance(df, pl.DataFrame) + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + + class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pl.DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(pl.DataFrame, None, PARQUET) def encode( self, @@ -45,8 +53,8 @@ def encode( class ParquetToPolarsDataFrameDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pl.DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(pl.DataFrame, None, PARQUET) def decode( self, @@ -63,10 +71,6 @@ def decode( return pl.read_parquet(path) -for protocol in [LOCAL, S3, GCS, ABFS]: - StructuredDatasetTransformerEngine.register( - PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False - ) - StructuredDatasetTransformerEngine.register( - ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=False - ) +StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(pl.DataFrame, PolarsDataFrameRenderer()) diff --git a/plugins/flytekit-polars/setup.py b/plugins/flytekit-polars/setup.py index ea3feb8582..f4086babb7 100644 --- a/plugins/flytekit-polars/setup.py +++ b/plugins/flytekit-polars/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "polars>=0.8.27", ] diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 3c9c2613ae..b991cd5d13 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,10 +1,7 @@ -import flytekitplugins.polars # noqa F401 +import pandas as pd import polars as pl - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated +from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer +from typing_extensions import Annotated from flytekit import kwtypes, task, workflow from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset @@ -62,3 +59,10 @@ def wf() -> full_schema: result = wf() assert result is not None + + +def test_polars_renderer(): + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( + df.describe().transpose(), columns=df.describe().columns + ).to_html(index=False) diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index ebea48f304..e82bb7268f 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 9fef590bcc..46079f40dd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import 
pandas as pd from pyspark.sql.dataframe import DataFrame from flytekit import FlyteContext @@ -7,11 +8,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - ABFS, - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -19,9 +16,19 @@ ) +class SparkDataFrameRenderer: + """ + Render a Spark dataframe schema as an HTML table. + """ + + def to_html(self, df: DataFrame) -> str: + assert isinstance(df, DataFrame) + return pd.DataFrame(df.schema, columns=["StructField"]).to_html() + + class SparkToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(DataFrame, None, PARQUET) def encode( self, @@ -36,8 +43,8 @@ def encode( class ParquetToSparkDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(DataFrame, None, PARQUET) def decode( self, @@ -52,6 +59,6 @@ def decode( return user_ctx.spark_session.read.parquet(flyte_value.uri) -for protocol in [LOCAL, S3, GCS, ABFS]: - StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False) - StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False) +StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer()) diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 9fdffe6c22..108fbb1169 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "pyspark>=3.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pyspark>=3.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-sqlalchemy/setup.py b/plugins/flytekit-sqlalchemy/setup.py index 6bf0a8e1ab..aa13aa8fbc 100644 --- a/plugins/flytekit-sqlalchemy/setup.py +++ b/plugins/flytekit-sqlalchemy/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "sqlalchemy>=1.4.7"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "sqlalchemy>=1.4.7"] __version__ = "0.0.0+develop" diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 8398d30d1e..6543d204fd 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -22,7 +22,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -35,9 +35,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -46,15 +44,15 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.0 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -72,23 +70,21 @@ 
importlib-metadata==4.12.0 # flytekit # jsonschema # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter +joblib==1.1.0 + # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -108,7 +104,9 @@ numpy==1.21.6 # pandas # pyarrow packaging==21.3 - # via marshmallow + # via + # docker + # marshmallow pandas==1.3.5 # via # -r requirements.in @@ -146,7 +144,7 @@ python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -155,7 +153,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.9 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -167,8 +165,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -190,10 +186,11 @@ typing-extensions==4.3.0 # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.12 # via + # docker # flytekit # requests # responses @@ -207,7 +204,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.txt b/requirements.txt index 153da8b4d2..32b4ae49a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -33,9 +33,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,15 +42,15 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.0 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -70,23 +68,21 @@ importlib-metadata==4.12.0 # flytekit # jsonschema # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter # jinja2-time jinja2-time==0.2.0 # via cookiecutter +joblib==1.1.0 + # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -106,7 +102,9 @@ numpy==1.21.6 # pandas # pyarrow packaging==21.3 - # via marshmallow + # via + # docker + # marshmallow pandas==1.3.5 # via # -r requirements.in @@ -144,7 +142,7 @@ python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -153,7 +151,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.9 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -165,8 +163,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -188,10 +184,11 @@ 
typing-extensions==4.3.0 # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.12 # via + # docker # flytekit # requests # responses @@ -205,7 +202,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/setup.py b/setup.py index 85af8691b6..e5bc3cfa33 100644 --- a/setup.py +++ b/setup.py @@ -46,12 +46,13 @@ "click>=6.6,<9.0", "croniter>=0.3.20,<4.0.0", "deprecated>=1.0,<2.0", - "docker>=5.0.3,<6.0.0", + "docker>=5.0.3,<7.0.0", "python-dateutil>=2.1", "grpcio>=1.43.0,!=1.45.0,<2.0", "grpcio-status>=1.43,!=1.45.0", "importlib-metadata", "pyopenssl", + "joblib", "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index c93c56435c..57e35e64b3 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -14,7 +14,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -27,9 +27,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 @@ -46,13 +44,13 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -flytekit==1.1.0 +flytekit==1.1.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.33.3 +fonttools==4.37.1 # via matplotlib -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -69,10 +67,6 @@ importlib-metadata==4.12.0 # click # flytekit # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter @@ -81,13 +75,13 @@ jinja2-time==0.2.0 # via cookiecutter joblib==1.1.0 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit -kiwisolver==1.4.3 +kiwisolver==1.4.4 # via matplotlib markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -96,7 +90,7 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.2 +matplotlib==3.5.3 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in mypy-extensions==0.4.3 # via typing-inspect @@ -142,17 +136,18 @@ pyparsing==3.0.9 # packaging python-dateutil==2.8.2 # via + # arrow # croniter # flytekit # matplotlib # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.4 # via flytekit python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -160,7 +155,7 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -172,8 +167,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -194,14 +187,14 @@ 
typing-extensions==4.3.0 # kiwisolver # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.12 # via # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.4.0 # via docker wheel==0.37.1 # via @@ -211,5 +204,5 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata diff --git a/tests/flytekit/unit/cli/pyflyte/imperative_wf.py b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py new file mode 100644 index 0000000000..12d7f2e3a3 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py @@ -0,0 +1,39 @@ +import typing + +from flytekit import Workflow, task + + +@task +def t1(a: str) -> str: + return a + " world" + + +@task +def t2(): + print("side effect") + + +@task +def t3(a: typing.List[str]) -> str: + return ",".join(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("in1", str) +node_t1 = wf.add_entity(t1, a=wf.inputs["in1"]) +wf.add_workflow_output("output_from_t1", node_t1.outputs["o0"]) +wf.add_entity(t2) + +wf_in2 = wf.add_workflow_input("in2", str) +node_t3 = wf.add_entity(t3, a=[wf.inputs["in1"], wf_in2]) + +wf.add_workflow_output( + "output_list", + [node_t1.outputs["o0"], node_t3.outputs["o0"]], + python_type=typing.List[str], +) + + +if __name__ == "__main__": + print(wf) + print(wf(in1="hello", in2="foo")) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index ec35f5362d..b6bb3d44c4 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,13 +1,34 @@ +import functools import os import pathlib +import typing +from enum import Enum +import click +import mock import pytest from click.testing import CliRunner +from flytekit import FlyteContextManager from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.run import get_entities_in_file +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE +from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY +from flytekit.clis.sdk_in_container.run import ( + REMOTE_FLAG_KEY, + RUN_LEVEL_PARAMS_KEY, + FileParamType, + FlyteLiteralConverter, + get_entities_in_file, + run_command, +) +from flytekit.configuration import Config, Image, ImageConfig +from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteRemote WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -19,6 +40,16 @@ def test_pyflyte_run_wf(): assert result.exit_code == 0 +def test_imperative_wf(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() result = runner.invoke( @@ -172,3 +203,129 @@ def test_list_default_arguments(wf_path): ) print(result.stdout) assert result.exit_code == 0 + + +# default case, what comes from click if no image is specified, the click param is configured to use the default. 
+ic_result_1 = ImageConfig( + default_image=Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"), + images=[Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")], +) +# test that command line args are merged with the file +ic_result_2 = ImageConfig( + default_image=None, + images=[ + Image(name="asdf", fqn="ghcr.io/asdf/asdf", tag="latest"), + Image(name="xyz", fqn="docker.io/xyz", tag="latest"), + Image(name="abc", fqn="docker.io/abc", tag=None), + ], +) +# test that command line args override the file +ic_result_3 = ImageConfig( + default_image=None, + images=[Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"), Image(name="abc", fqn="docker.io/abc", tag=None)], +) + + +@pytest.mark.parametrize( + "image_string, leaf_configuration_file_name, final_image_config", + [ + ("ghcr.io/flyteorg/mydefault:py3.9-latest", "no_images.yaml", ic_result_1), + ("asdf=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_2), + ("xyz=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_3), + ], +) +def test_pyflyte_run_run(image_string, leaf_configuration_file_name, final_image_config): + @task + def a(): + ... + + mock_click_ctx = mock.MagicMock() + mock_remote = mock.MagicMock() + image_tuple = (image_string,) + image_config = ImageConfig.validate_image(None, "", image_tuple) + + run_level_params = { + "project": "p", + "domain": "d", + "image_config": image_config, + } + + pp = pathlib.Path.joinpath( + pathlib.Path(__file__).parent.parent.parent, "configuration/configs/", leaf_configuration_file_name + ) + + obj = { + RUN_LEVEL_PARAMS_KEY: run_level_params, + REMOTE_FLAG_KEY: True, + FLYTE_REMOTE_INSTANCE_KEY: mock_remote, + CTX_CONFIG_FILE: str(pp), + } + mock_click_ctx.obj = obj + + def check_image(*args, **kwargs): + assert kwargs["image_config"] == final_image_config + + mock_remote.register_script.side_effect = check_image + + run_command(mock_click_ctx, a)() + + +def test_file_param(): + m = mock.MagicMock() + l = FileParamType().convert(__file__, m, m) + assert l.local + r = FileParamType().convert("https://tmp/file", m, m) + assert r.local is False + + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +@pytest.mark.parametrize( + "python_type, python_value", + [ + (typing.Union[typing.List[int], str, Color], "flyte"), + (typing.Union[typing.List[int], str, Color], "red"), + (typing.Union[typing.List[int], str, Color], [1, 2, 3]), + (typing.List[int], [1, 2, 3]), + (typing.Dict[str, int], {"flyte": 2}), + ], +) +def test_literal_converter(python_type, python_value): + get_upload_url_fn = functools.partial( + FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d" + ) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(python_type) + + lc = FlyteLiteralConverter( + click_ctx, ctx, literal_type=lt, python_type=python_type, get_upload_url_fn=get_upload_url_fn + ) + + assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) + + +def test_enum_converter(): + pt = typing.Union[str, Color] + + get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(pt) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, 
get_upload_url_fn=get_upload_url_fn) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple == SimpleType.STRING + assert union_lt.stored_type.enum_type is None + + pt = typing.Union[Color, str] + lt = TypeEngine.to_literal_type(typing.Union[Color, str]) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple is None + assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml index 9c1ad83a3e..7da41b7c38 100644 --- a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml +++ b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml @@ -1,7 +1,7 @@ admin: # For GRPC endpoints you might want to use dns:///flyte.myexample.com endpoint: dns:///flyte.mycorp.io - clientSecretLocation: ../tests/flytekit/unit/configuration/configs/fake_secret + clientSecretLocation: configs/fake_secret authType: Pkce insecure: true clientId: propeller diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 6ba81f309c..7f6be53a55 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import get_config_file, read_file_if_exists +from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -31,7 +31,20 @@ def test_client_secret_location(): os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_location.yaml") ) secret_location = Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(cfg) - assert secret_location == "../tests/flytekit/unit/configuration/configs/fake_secret" + assert secret_location == "configs/fake_secret" + + # Modify the path to the secret inline + cfg._yaml_config["admin"]["clientSecretLocation"] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs/fake_secret" + ) + + # Assert secret contains a newline + with open(cfg._yaml_config["admin"]["clientSecretLocation"], "rb") as f: + assert f.read().decode().endswith("\n") is True + + # Assert that secret in platform config does not contain a newline + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "hello" def test_read_file_if_exists(): diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index e61350a7ed..af39e9e852 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,4 +1,4 @@ -from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider def test_get_random_remote_path(): @@ -14,3 +14,10 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True + + +def test_lister(): + x = DataPersistencePlugins.supported_protocols() + main_protocols = {"file", "/", "gs", "http", "https", "s3"} + all_protocols = set([y.replace("://", "") for y in x]) + assert 
main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 365ce4c25f..668ca97dfd 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -1,5 +1,7 @@ import typing +import pytest + import flytekit.configuration from flytekit import dynamic from flytekit.configuration import FastSerializationSettings, Image, ImageConfig @@ -10,6 +12,19 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + distribution_location="s3://my-s3-bucket/fast/123", + ), +) + def test_wf1_with_fast_dynamic(): @task @@ -30,20 +45,7 @@ def my_wf(a: int) -> typing.List[str]: return v with context_manager.FlyteContextManager.with_context( - context_manager.FlyteContextManager.current_context().with_serialization_settings( - flytekit.configuration.SerializationSettings( - project="test_proj", - domain="test_domain", - version="abc", - image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), - env={}, - fast_serialization_settings=FastSerializationSettings( - enabled=True, - destination_dir="/User/flyte/workflows", - distribution_location="s3://my-s3-bucket/fast/123", - ), - ) - ) + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) ) as ctx: with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( @@ -111,6 +113,24 @@ def wf(a: int, b: int) -> typing.List[str]: assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"] +def test_dynamic_local_use(): + @task + def t1(a: int) -> str: + a = a + 2 + return "fast-" + str(a) + + @dynamic + def use_result(a: int) -> int: + x = t1(a=a) + if len(x) > 6: + return 5 + else: + return 0 + + with pytest.raises(TypeError): + use_result(a=6) + + def test_create_node_dynamic_local(): @task def task1(s: str) -> str: diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index a3ec6d17ce..318a6b76f3 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -1,5 +1,10 @@ from collections import OrderedDict -from typing import Dict, List +from collections.abc import Sequence +from typing import Dict, List, Union + +import numpy as np +import pandas as pd +from typing_extensions import Annotated import flytekit.configuration from flytekit.configuration import Image, ImageConfig @@ -80,3 +85,15 @@ def t1(a: int) -> List[Dict[str, Foo]]: task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format is FlytePickleTransformer.PYTHON_PICKLE_FORMAT ) + + +def test_union(): + @task + def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotation"]): + print(data) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + variants = task_spec.template.interface.inputs["data"].type.union_type.variants + assert variants[0].blob.format == "NumpyArray" + assert variants[1].structured_dataset_type.format == "parquet" + assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT diff --git a/tests/flytekit/unit/core/test_local_cache.py 
b/tests/flytekit/unit/core/test_local_cache.py index 3f3e56de88..674f6176e1 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -1,19 +1,25 @@ import datetime import typing from dataclasses import dataclass -from typing import List +from typing import Dict, List import pandas from dataclasses_json import dataclass_json from pytest import fixture from typing_extensions import Annotated -from flytekit import SQLTask, dynamic, kwtypes +from flytekit.core.base_sql_task import SQLTask +from flytekit.core.base_task import kwtypes +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.hash import HashMethod -from flytekit.core.local_cache import LocalTaskCache +from flytekit.core.local_cache import LocalTaskCache, _calculate_cache_key from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import task_mock +from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.literals import LiteralMap +from flytekit.models.types import LiteralType, SimpleType from flytekit.types.schema import FlyteSchema # Global counter used to validate number of calls to cache @@ -309,13 +315,13 @@ def t1(a: int) -> int: # We should have a cache miss in the first call to downstream_t and have a cache hit # on the second call. - v_1 = downstream_t(a=v) + downstream_t(a=v) v_2 = downstream_t(a=v) - return v_1 + v_2 + return v_2 assert n_cached_task_calls == 0 - assert t1(a=3) == (6 + 6) + assert t1(a=3) == 6 assert n_cached_task_calls == 1 @@ -383,3 +389,52 @@ def my_workflow(): # Confirm that we see a cache hit in the case of annotated dataframes. my_workflow() assert n_cached_task_calls == 1 + + +def test_cache_key_repetition(): + pt = Dict + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContextManager.current_context() + kwargs = { + "a": 0.41083513079747874, + "b": 0.7773927872515183, + "c": 17, + } + keys = set() + for i in range(0, 100): + lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) + lm = LiteralMap( + literals={ + "d": lit, + } + ) + key = _calculate_cache_key("t1", "007", lm) + keys.add(key) + + assert len(keys) == 1 + + +def test_stable_cache_key(): + """ + The intent of this test is to ensure cache keys are stable across releases and python versions. 
+ """ + pt = Dict + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContextManager.current_context() + kwargs = { + "a": 42, + "b": "abcd", + "c": 0.12349, + "d": [1, 2, 3], + } + lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) + lm = LiteralMap( + literals={ + "lit_1": lit, + "lit_2": TypeEngine.to_literal(ctx, 99, int, LiteralType(simple=SimpleType.INTEGER)), + "lit_3": TypeEngine.to_literal(ctx, 3.14, float, LiteralType(simple=SimpleType.FLOAT)), + "lit_4": TypeEngine.to_literal(ctx, True, bool, LiteralType(simple=SimpleType.BOOLEAN)), + } + ) + key = _calculate_cache_key("task_name_1", "31415", lm) + assert key == "task_name_1-31415-a291dc6fe0be387c1cfd67b4c6b78259" diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index a303230386..f6dc9c9ba5 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -372,3 +372,27 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.interruptible == interruptible + + +def test_void_promise_override(): + @task + def t1(a: str): + print(f"*~*~*~{a}*~*~*~") + + @workflow + def my_wf(a: str): + t1(a=a).with_overrides(requests=Resources(cpu="1", mem="100")) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"), + _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"), + ] diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 773d30b5ae..7793df430f 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,9 +1,13 @@ import tempfile import typing +import pandas as pd +import pyarrow as pa import pytest +from typing_extensions import Annotated import flytekit.configuration +from flytekit import kwtypes, task from flytekit.configuration import Image, ImageConfig from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider @@ -11,18 +15,7 @@ from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - -import pandas as pd -import pyarrow as pa - -from flytekit import kwtypes, task from flytekit.types.structured.structured_dataset import ( - LOCAL, PARQUET, StructuredDataset, StructuredDatasetDecoder, @@ -49,7 +42,7 @@ def test_protocol(): assert protocol_prefix("s3://my-s3-bucket/file") == "s3" - assert protocol_prefix("/file") == "/" + assert protocol_prefix("/file") == "file" def generate_pandas() -> pd.DataFrame: @@ -121,10 +114,10 @@ def test_types_sd(): def test_retrieving(): - assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "/", PARQUET) is not None + assert 
+def test_void_promise_override():
+    @task
+    def t1(a: str):
+        print(f"*~*~*~{a}*~*~*~")
+
+    @workflow
+    def my_wf(a: str):
+        t1(a=a).with_overrides(requests=Resources(cpu="1", mem="100"))
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="test_proj",
+        domain="test_domain",
+        version="abc",
+        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
+        env={},
+    )
+    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
+    assert len(wf_spec.template.nodes) == 1
+    assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [
+        _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"),
+        _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"),
+    ]
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index 773d30b5ae..7793df430f 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -1,9 +1,13 @@
 import tempfile
 import typing
 
+import pandas as pd
+import pyarrow as pa
 import pytest
+from typing_extensions import Annotated
 
 import flytekit.configuration
+from flytekit import kwtypes, task
 from flytekit.configuration import Image, ImageConfig
 from flytekit.core.context_manager import FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import FileAccessProvider
@@ -11,18 +15,7 @@
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType
-
-try:
-    from typing import Annotated
-except ImportError:
-    from typing_extensions import Annotated
-
-import pandas as pd
-import pyarrow as pa
-
-from flytekit import kwtypes, task
 from flytekit.types.structured.structured_dataset import (
-    LOCAL,
     PARQUET,
     StructuredDataset,
     StructuredDatasetDecoder,
@@ -49,7 +42,7 @@ def test_protocol():
     assert protocol_prefix("s3://my-s3-bucket/file") == "s3"
-    assert protocol_prefix("/file") == "/"
+    assert protocol_prefix("/file") == "file"
 
 
 def generate_pandas() -> pd.DataFrame:
@@ -121,10 +114,10 @@ def test_types_sd():
 
 def test_retrieving():
-    assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "/", PARQUET) is not None
+    assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None
     with pytest.raises(ValueError):
         # We don't have a default "" format encoder
-        StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "/", "")
+        StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "")
 
     class TempEncoder(StructuredDatasetEncoder):
         def __init__(self, protocol):
@@ -137,6 +130,11 @@ def encode(self):
     with pytest.raises(ValueError):
         StructuredDatasetTransformerEngine.register(TempEncoder("gs://"), default_for_type=False)
 
+    with pytest.raises(ValueError, match="Use None instead"):
+        e = TempEncoder("")
+        e._protocol = ""
+        StructuredDatasetTransformerEngine.register(e)
+
     class TempEncoder:
         pass
@@ -209,6 +207,24 @@ def encode(
     assert res is empty_format_temp_encoder
+
+
+def test_slash_register():
+    class TempEncoder(StructuredDatasetEncoder):
+        def __init__(self, fmt: str):
+            super().__init__(MyDF, None, supported_format=fmt)
+
+        def encode(
+            self,
+            ctx: FlyteContext,
+            structured_dataset: StructuredDataset,
+            structured_dataset_type: StructuredDatasetType,
+        ) -> literals.StructuredDataset:
+            return literals.StructuredDataset(uri="")
+
+    # Check that registering with a / triggers the file protocol instead.
+    StructuredDatasetTransformerEngine.register(TempEncoder("/"))
+    assert StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file") is not None
+
 
 def test_sd():
     sd = StructuredDataset(dataframe="hi")
     sd.uri = "my uri"
@@ -273,6 +289,9 @@ def test_convert_schema_type_to_structured_dataset_type():
     with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"):
         convert_schema_type_to_structured_dataset_type(int)
 
+    with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"):
+        convert_schema_type_to_structured_dataset_type(20)
+
 
 def test_to_python_value_with_incoming_columns():
     # make a literal with a type that has two columns
@@ -338,7 +357,7 @@ def test_to_python_value_without_incoming_columns():
 def test_format_correct():
     class TempEncoder(StructuredDatasetEncoder):
         def __init__(self):
-            super().__init__(pd.DataFrame, LOCAL, "avro")
+            super().__init__(pd.DataFrame, "/", "avro")
 
         def encode(
             self,
@@ -385,7 +404,7 @@ def test_protocol_detection():
     e = StructuredDatasetTransformerEngine()
     ctx = FlyteContextManager.current_context()
     protocol = e._protocol_from_type_or_prefix(ctx, pd.DataFrame)
-    assert protocol == "/"
+    assert protocol == "file"
 
     with tempfile.TemporaryDirectory() as tmp_dir:
         fs = FileAccessProvider(local_sandbox_dir=tmp_dir, raw_output_prefix="s3://fdsa")
@@ -395,3 +414,18 @@
         protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo")
         assert protocol == "bq"
+
+
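+# Renderers are keyed by python type; the pandas and pyarrow renderers are
+# registered by default, so both should already be present alongside the
+# DummyRenderer added here.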
+def test_register_renderers():
+    class DummyRenderer:
+        def to_html(self, input: str) -> str:
+            return "hello " + input
+
+    renderers = StructuredDatasetTransformerEngine.Renderers
+    StructuredDatasetTransformerEngine.register_renderer(str, DummyRenderer())
+    assert renderers[str].to_html("flyte") == "hello flyte"
+    assert pd.DataFrame in renderers
+    assert pa.Table in renderers
+
+    with pytest.raises(NotImplementedError, match="Could not find a renderer for <class 'int'> in"):
+        StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int)
diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py
index ada7483a0f..c7aa5563f9 100644
--- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py
+++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py
@@ -13,6 +13,7 @@
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
+    StructuredDatasetTransformerEngine,
 )
 
 my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str)
@@ -23,8 +24,8 @@
 
 def test_pandas():
     df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
-    encoder = basic_dfs.PandasToParquetEncodingHandler("/")
-    decoder = basic_dfs.ParquetToPandasDecodingHandler("/")
+    encoder = basic_dfs.PandasToParquetEncodingHandler()
+    decoder = basic_dfs.ParquetToPandasDecodingHandler()
     ctx = context_manager.FlyteContextManager.current_context()
     sd = StructuredDataset(dataframe=df)
@@ -41,3 +42,13 @@ def test_base_isnt_instantiable():
 
     with pytest.raises(TypeError):
         StructuredDatasetDecoder(pd.DataFrame, "", "")
+
+
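+# The Arrow handlers register with protocol None, meaning any storage backend,
+# so a decoder lookup under a concrete prefix such as "s3" should still resolve.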
+def test_arrow():
+    encoder = basic_dfs.ArrowToParquetEncodingHandler()
+    decoder = basic_dfs.ParquetToArrowDecodingHandler()
+    assert encoder.protocol is None
+    assert decoder.protocol is None
+    assert encoder.python_type is decoder.python_type
+    d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["s3"]["parquet"]
+    assert d is not None
diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index df8e14d7cb..34251f7d0a 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -35,6 +35,7 @@
     TypeEngine,
     TypeTransformer,
     TypeTransformerFailedError,
+    UnionTransformer,
     convert_json_schema_to_python_class,
     dataclass_from_dict,
 )
@@ -147,6 +148,7 @@ def test_list_of_dataclass_getting_python_value():
     @dataclass_json
     @dataclass()
     class Bar(object):
+        v: typing.Union[int, None]
         w: typing.Optional[str]
         x: float
         y: str
@@ -155,12 +157,14 @@ class Bar(object):
 
     @dataclass_json
     @dataclass()
     class Foo(object):
+        u: typing.Optional[int]
+        v: typing.Optional[int]
         w: int
         x: typing.List[int]
         y: typing.Dict[str, str]
         z: Bar
 
-    foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
+    foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False}))
     generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
     lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
@@ -170,16 +174,23 @@ class Foo(object):
     schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
     foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")
 
-    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
-    assert isinstance(pv, list)
-    assert pv[0].w == foo.w
-    assert pv[0].x == foo.x
-    assert pv[0].y == foo.y
-    assert pv[0].z.x == foo.z.x
-    assert type(pv[0].z.x) == float
-    assert pv[0].z.y == foo.z.y
-    assert pv[0].z.z == foo.z.z
-    assert foo == dataclass_from_dict(Foo, asdict(pv[0]))
+    guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
+    pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
+    assert isinstance(guessed_pv, list)
+    assert guessed_pv[0].u == pv[0].u
+    assert guessed_pv[0].v == pv[0].v
+    assert guessed_pv[0].w == pv[0].w
+    assert guessed_pv[0].x == pv[0].x
+    assert guessed_pv[0].y == pv[0].y
+    assert guessed_pv[0].z.x == pv[0].z.x
+    assert type(guessed_pv[0].u) == int
+    assert guessed_pv[0].v is None
+    assert type(guessed_pv[0].w) == int
+    assert type(guessed_pv[0].z.x) == float
+    assert guessed_pv[0].z.y == pv[0].z.y
+    assert guessed_pv[0].z.z == pv[0].z.z
+    assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0]))
 
 
 def test_file_no_downloader_default():
@@ -782,6 +793,43 @@ def test_union_type():
     assert v == "hello"
 
 
+def test_assert_dataclass_type():
+    @dataclass_json
+    @dataclass
+    class Args(object):
+        x: int
+        y: typing.Optional[str]
+
+    @dataclass_json
+    @dataclass
+    class Schema(object):
+        x: typing.Optional[Args] = None
+
+    pt = Schema
+    lt = TypeEngine.to_literal_type(pt)
+    gt = TypeEngine.guess_python_type(lt)
+    pv = Schema(x=Args(x=3, y="hello"))
+    DataclassTransformer().assert_type(gt, pv)
+    DataclassTransformer().assert_type(Schema, pv)
+
+    @dataclass_json
+    @dataclass
+    class Bar(object):
+        x: int
+
+    pv = Bar(x=3)
+    with pytest.raises(
+        TypeTransformerFailedError, match="Type of Val '<class 'int'>' is not an instance of"
+    ):
+        DataclassTransformer().assert_type(gt, pv)
+
+
+def test_union_transformer():
+    assert UnionTransformer.is_optional_type(typing.Optional[int])
+    assert not UnionTransformer.is_optional_type(str)
+    assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int
+
+
 def test_union_type_with_annotated():
     pt = typing.Union[
         Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})]
@@ -1174,6 +1222,7 @@ def test_pass_annotated_to_downstream_tasks():
     """
     Test to confirm that the loaded dataframe is not affected and can be used in @dynamic.
     """
+    # pandas dataframe hash function
     def hash_pandas_dataframe(df: pd.DataFrame) -> str:
         return str(pd.util.hash_pandas_object(df))
@@ -1197,11 +1246,11 @@ def t1(a: int) -> int:
 
         # We should have a cache miss in the first call to downstream_t
         v_1 = downstream_t(a=v, df=df)
-        v_2 = downstream_t(a=v, df=df)
+        downstream_t(a=v, df=df)
 
-        return v_1 + v_2
+        return v_1
 
-    assert t1(a=3) == (6 + 6 + 6)
+    assert t1(a=3) == 9
 
 
 def test_literal_hash_int_not_set():
diff --git a/tests/flytekit/unit/deck/test_renderer.py b/tests/flytekit/unit/deck/test_renderer.py
index f1ebbcd873..3f597af416 100644
--- a/tests/flytekit/unit/deck/test_renderer.py
+++ b/tests/flytekit/unit/deck/test_renderer.py
@@ -1,9 +1,12 @@
 import pandas as pd
+import pyarrow as pa
 
-from flytekit.deck.renderer import TopFrameRenderer
+from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer
 
 
-def test_frame_profiling_renderer():
+def test_renderer():
     df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
-    renderer = TopFrameRenderer()
-    assert renderer.to_html(df) == df.to_html()
+    pa_df = pa.Table.from_pandas(df)
+
+    assert TopFrameRenderer().to_html(df) == df.to_html()
+    assert ArrowRenderer().to_html(pa_df) == pa_df.to_string()
diff --git a/tests/flytekit/unit/tools/test_fast_registration.py b/tests/flytekit/unit/tools/test_fast_registration.py
index 0b50d6fdcf..aae3995bcb 100644
--- a/tests/flytekit/unit/tools/test_fast_registration.py
+++ b/tests/flytekit/unit/tools/test_fast_registration.py
@@ -23,7 +23,10 @@ def flyte_project(tmp_path):
             "workflows": {
                 "__pycache__": {"some.pyc": ""},
                 "hello_world.py": "print('Hello World!')",
-            }
+            },
+        },
+        "utils": {
+            "util.py": "print('Hello from utils!')",
         },
         ".venv": {"lots": "", "of": "", "packages": ""},
         ".env": "supersecret",
@@ -35,6 +38,7 @@
     }
     make_tree(tmp_path, tree)
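+    # Symlink src/util -> utils/util.py so the packaging tests below can cover
+    # both preserved-symlink and dereferenced archives.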
+    os.symlink(str(tmp_path) + "/utils/util.py", str(tmp_path) + "/src/util")
 
     subprocess.run(["git", "init", str(tmp_path)])
     return tmp_path
@@ -48,9 +52,29 @@ def test_package(flyte_project, tmp_path):
         ".gitignore",
         "keep.foo",
         "src",
+        "src/util",
         "src/workflows",
         "src/workflows/hello_world.py",
+        "utils",
+        "utils/util.py",
     ]
+    util = tar.getmember("src/util")
+    assert util.issym()
     assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX)
     assert str(archive_fname).endswith(FAST_FILEENDING)
+
+
+def test_package_with_symlink(flyte_project, tmp_path):
+    archive_fname = fast_package(source=flyte_project / "src", output_dir=tmp_path, deref_symlinks=True)
+    with tarfile.open(archive_fname, dereference=True) as tar:
+        assert tar.getnames() == [
+            "",  # tar root, output removes leading '/'
+            "util",
+            "workflows",
+            "workflows/hello_world.py",
+        ]
+        util = tar.getmember("util")
+        assert util.isfile()
+    assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX)
+    assert str(archive_fname).endswith(FAST_FILEENDING)
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
index f0d58eb36d..5d04a12e7b 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
@@ -1,28 +1,20 @@
 import os
 import typing
 
-import pytest
-
-try:
-    from typing import Annotated
-except ImportError:
-    from typing_extensions import Annotated
-
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pytest
+from typing_extensions import Annotated
 
 from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
-    BIGQUERY,
     DF,
-    LOCAL,
     PARQUET,
-    S3,
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
@@ -41,7 +33,7 @@
 
 class MockBQEncodingHandlers(StructuredDatasetEncoder):
     def __init__(self):
-        super().__init__(pd.DataFrame, BIGQUERY, "")
+        super().__init__(pd.DataFrame, "bq", "")
 
     def encode(
         self,
@@ -56,7 +48,7 @@ def encode(
 
 class MockBQDecodingHandlers(StructuredDatasetDecoder):
     def __init__(self):
-        super().__init__(pd.DataFrame, BIGQUERY, "")
+        super().__init__(pd.DataFrame, "bq", "")
 
     def decode(
         self,
@@ -71,6 +63,15 @@ def decode(
 
 StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True)
 
+
+class NumpyRenderer:
+    """
+    Render the summary statistics of a NumPy array as an HTML table, via pandas describe().
+    """
+
+    def to_html(self, array: np.ndarray) -> str:
+        return pd.DataFrame(array).describe().to_html()
+
 
 @pytest.fixture(autouse=True)
 def numpy_type():
     class NumpyEncodingHandlers(StructuredDatasetEncoder):
@@ -104,9 +105,9 @@ def decode(
         table = pq.read_table(local_dir)
         return table.to_pandas().to_numpy()
 
-    for protocol in [LOCAL, S3]:
-        StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET))
-        StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET))
+    StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray))
+    StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray))
+    StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer())
 
 
 @task