From 4f864fced6457573b2a58643f5def85e7c2e1180 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:28:21 -0400 Subject: [PATCH 01/11] Fix docker warnings (#2683) * Remove warnings from dockerfiles Signed-off-by: Eduardo Apolinario * use 1.13.3 as default value in dev image Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- Dockerfile | 10 +++++----- Dockerfile.agent | 8 ++++---- Dockerfile.dev | 10 +++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2f7429c4ec..13277d7279 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,12 +1,12 @@ -ARG PYTHON_VERSION +ARG PYTHON_VERSION=3.12 FROM python:${PYTHON_VERSION}-slim-bookworm -MAINTAINER Flyte Team +LABEL org.opencontainers.image.authors="Flyte Team " LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit WORKDIR /root -ENV PYTHONPATH /root -ENV FLYTE_SDK_RICH_TRACEBACKS 0 +ENV PYTHONPATH=/root +ENV FLYTE_SDK_RICH_TRACEBACKS=0 ARG VERSION ARG DOCKER_IMAGE @@ -35,4 +35,4 @@ RUN apt-get update && apt-get install build-essential -y \ USER flytekit -ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE" +ENV FLYTE_INTERNAL_IMAGE="$DOCKER_IMAGE" diff --git a/Dockerfile.agent b/Dockerfile.agent index f9ff2ada76..e2d106f7c2 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -1,6 +1,6 @@ -FROM python:3.10-slim-bookworm as agent-slim +FROM python:3.10-slim-bookworm AS agent-slim -MAINTAINER Flyte Team +LABEL org.opencontainers.image.authors="Flyte Team " LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit ARG VERSION @@ -19,9 +19,9 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ && : -CMD pyflyte serve agent --port 8000 +CMD ["pyflyte", "serve", "agent", "--port", "8000"] -FROM agent-slim as agent-all +FROM agent-slim AS agent-all ARG VERSION RUN pip install 
--no-cache-dir -U \ diff --git a/Dockerfile.dev b/Dockerfile.dev index 406740de27..7b32939d39 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -5,17 +5,17 @@ # From your test user code # $ pyflyte run --image localhost:30000/flytekittest:someversion -ARG PYTHON_VERSION +ARG PYTHON_VERSION=3.12 FROM python:${PYTHON_VERSION}-slim-bookworm -MAINTAINER Flyte Team +LABEL org.opencontainers.image.authors="Flyte Team " LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit WORKDIR /root -ENV FLYTE_SDK_RICH_TRACEBACKS 0 +ENV FLYTE_SDK_RICH_TRACEBACKS=0 # Flytekit version of flytekit to be installed in the image -ARG PSEUDO_VERSION +ARG PSEUDO_VERSION=1.13.3 # Note: Pod tasks should be exposed in the default image @@ -51,7 +51,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ && : -ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" +ENV PYTHONPATH="/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" # Switch to the 'flytekit' user for better security. USER flytekit From 1cd8160a0552c308b18c210a4e11303fb645d5c0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 14 Aug 2024 12:45:56 -0400 Subject: [PATCH 02/11] Move UV install to after the ENV is set (#2681) Signed-off-by: Thomas J. 
Fan --- flytekit/image_spec/default_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index 89bb8bd1b3..3b35214c22 100644 --- a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -59,8 +59,6 @@ -c conda-forge $CONDA_CHANNELS \ python=$PYTHON_VERSION $CONDA_PACKAGES -$UV_PYTHON_INSTALL_COMMAND - # Configure user space ENV PATH="/opt/micromamba/envs/runtime/bin:$$PATH" \ UV_LINK_MODE=copy \ @@ -69,6 +67,8 @@ SSL_CERT_DIR=/etc/ssl/certs \ $ENV +$UV_PYTHON_INSTALL_COMMAND + # Adds nvidia just in case it exists ENV PATH="$$PATH:/usr/local/nvidia/bin:/usr/local/cuda/bin" \ LD_LIBRARY_PATH="/usr/local/nvidia/lib64:$$LD_LIBRARY_PATH" From 03d23011fcf955838669bd5058c8ced17c6de3ee Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 15 Aug 2024 01:47:38 +0800 Subject: [PATCH 03/11] Remove false error inside dynamic task in local executions (#2675) Signed-off-by: Kevin Su --- flytekit/core/node_creation.py | 9 ++++++--- flytekit/core/promise.py | 3 +++ flytekit/core/python_function_task.py | 7 ++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 58a72f357a..791480435f 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import BranchEvalMode, FlyteContext +from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import VoidPromise @@ -129,9 +129,12 @@ def create_node( return node # Handling local execution - # Note: execution state is set to TASK_EXECUTION when running dynamic task locally + # Note: execution state is set to 
DYNAMIC_TASK_EXECUTION when running a dynamic task locally # https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262 - elif ctx.execution_state and ctx.execution_state.is_local_execution(): + elif ctx.execution_state and ( + ctx.execution_state.is_local_execution() + or ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION + ): if isinstance(entity, RemoteEntity): raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6bb07fee3e..847d727948 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1270,6 +1270,9 @@ def flyte_entity_call_handler( if inspect.iscoroutine(result): return result + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION: + return result + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 2c01723bdd..a1b863a092 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -308,7 +308,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. 
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) + if self.execution_mode == self.ExecutionBehavior.DYNAMIC: + es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION) + else: + es = cast(ExecutionState, ctx.execution_state) + with FlyteContextManager.with_context(ctx.with_execution_state(es)): + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) From 556dad2550890fd6d9ba8570b864279096c773a8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 15 Aug 2024 11:33:02 -0400 Subject: [PATCH 04/11] Create duckdb connection during execution (#2684) Signed-off-by: Thomas J. Fan --- .../flytekitplugins/duckdb/task.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py b/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py index 71c15481f4..eda750fd33 100644 --- a/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py +++ b/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py @@ -34,9 +34,6 @@ def __init__( inputs: The query parameters to be used while executing the query """ self._query = query - # create an in-memory database that's non-persistent - self._con = duckdb.connect(":memory:") - outputs = {"result": StructuredDataset} super(DuckDBQuery, self).__init__( @@ -47,7 +44,9 @@ def __init__( **kwargs, ) - def _execute_query(self, params: list, query: str, counter: int, multiple_params: bool): + def _execute_query( + self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool + ): """ This method runs the DuckDBQuery. 
@@ -64,28 +63,32 @@ def _execute_query(self, params: list, query: str, counter: int, multiple_params raise ValueError("Parameter doesn't exist.") if "insert" in query.lower(): # run executemany disregarding the number of entries to store for an insert query - yield QueryOutput(output=self._con.executemany(query, params[counter]), counter=counter) + yield QueryOutput(output=con.executemany(query, params[counter]), counter=counter) else: - yield QueryOutput(output=self._con.execute(query, params[counter]), counter=counter) + yield QueryOutput(output=con.execute(query, params[counter]), counter=counter) else: if params: - yield QueryOutput(output=self._con.execute(query, params), counter=counter) + yield QueryOutput(output=con.execute(query, params), counter=counter) else: raise ValueError("Parameter not specified.") else: - yield QueryOutput(output=self._con.execute(query), counter=counter) + yield QueryOutput(output=con.execute(query), counter=counter) def execute(self, **kwargs) -> StructuredDataset: # TODO: Enable iterative download after adding the functionality to structured dataset code. 
+ + # create an in-memory database that's non-persistent + con = duckdb.connect(":memory:") + params = None for key in self.python_interface.inputs.keys(): val = kwargs.get(key) if isinstance(val, StructuredDataset): # register structured dataset - self._con.register(key, val.open(pa.Table).all()) + con.register(key, val.open(pa.Table).all()) elif isinstance(val, (pd.DataFrame, pa.Table)): # register pandas dataframe/arrow table - self._con.register(key, val) + con.register(key, val) elif isinstance(val, list): # copy val into params params = val @@ -105,7 +108,11 @@ def execute(self, **kwargs) -> StructuredDataset: for query in self._query[:-1]: query_output = next( self._execute_query( - params=params, query=query, counter=query_output.counter, multiple_params=multiple_params + con=con, + params=params, + query=query, + counter=query_output.counter, + multiple_params=multiple_params, ) ) final_query = self._query[-1] @@ -114,7 +121,7 @@ def execute(self, **kwargs) -> StructuredDataset: # expecting a SELECT query dataframe = next( self._execute_query( - params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params + con=con, params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params ) ).output.arrow() From abb5219dc2a543efa0d6d6130f4f48f419604de9 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Thu, 15 Aug 2024 14:39:21 -0400 Subject: [PATCH 05/11] Fix None deserialization bug in dataclass outputs (#2610) Signed-off-by: JackUrb --- flytekit/core/type_engine.py | 2 +- tests/flytekit/unit/core/test_type_engine.py | 61 ++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d66bc8a956..1ce6a05488 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1066,7 +1066,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want 
to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and expected and expected.union_type is None: + if (python_val is None and python_type != type(None)) and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0cde27c619..a215b969b5 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3075,3 +3075,64 @@ def test_union_file_directory(): pv = union_trans.to_python_value(ctx, lv, typing.Union[FlyteFile, FlyteDirectory]) assert pv._remote_source == s3_dir + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") +def test_dataclass_none_output_input_deserialization(): + @dataclass + class OuterWorkflowInput(DataClassJSONMixin): + input: float + + @dataclass + class OuterWorkflowOutput(DataClassJSONMixin): + nullable_output: float | None = None + + + @dataclass + class InnerWorkflowInput(DataClassJSONMixin): + input: float + + @dataclass + class InnerWorkflowOutput(DataClassJSONMixin): + nullable_output: float | None = None + + + @task + def inner_task(input: float) -> float | None: + if input == 0: + return None + return input + + @task + def wrap_inner_inputs(input: float) -> InnerWorkflowInput: + return InnerWorkflowInput(input=input) + + @task + def wrap_inner_outputs(output: float | None) -> InnerWorkflowOutput: + return InnerWorkflowOutput(nullable_output=output) + + @task + def wrap_outer_outputs(output: float | None) -> OuterWorkflowOutput: + return OuterWorkflowOutput(nullable_output=output) + + @workflow + def inner_workflow(input: InnerWorkflowInput) -> InnerWorkflowOutput: + return 
wrap_inner_outputs( + output=inner_task( + input=input.input + ) + ) + + @workflow + def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput: + inner_outputs = inner_workflow( + input=wrap_inner_inputs(input=input.input) + ) + return wrap_outer_outputs( + output=inner_outputs.nullable_output + ) + + float_value_output = outer_workflow(OuterWorkflowInput(input=1.0)).nullable_output + assert float_value_output == 1.0, f"Float value was {float_value_output}, not 1.0 as expected" + none_value_output = outer_workflow(OuterWorkflowInput(input=0)).nullable_output + assert none_value_output is None, f"None value was {none_value_output}, not None as expected" From 6ababc901801f49ef9d88289c10b61dfe61cffef Mon Sep 17 00:00:00 2001 From: rdeaton-freenome <134093844+rdeaton-freenome@users.noreply.github.com> Date: Thu, 15 Aug 2024 12:48:27 -0700 Subject: [PATCH 06/11] Fix race conditions in the Authentication client (#2635) * Fix race conditions in the Authentication client Signed-off-by: Robert Deaton * Update flytekit/clients/auth/auth_client.py Co-authored-by: Thomas J. Fan --------- Signed-off-by: Robert Deaton Co-authored-by: Thomas J. 
Fan --- flytekit/clients/auth/auth_client.py | 42 ++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index f989736289..71cd8f0f37 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -6,6 +6,8 @@ import logging import os import re +import threading +import time import typing import urllib.parse as _urlparse import webbrowser @@ -236,6 +238,9 @@ def __init__( self._verify = verify self._headers = {"content-type": "application/x-www-form-urlencoded"} self._session = session or requests.Session() + self._lock = threading.Lock() + self._cached_credentials = None + self._cached_credentials_ts = None self._request_auth_code_params = { "client_id": client_id, # This must match the Client ID of the OAuth application. @@ -339,25 +344,38 @@ def _request_access_token(self, auth_code) -> Credentials: def get_creds_from_remote(self) -> Credentials: """ - This is the entrypoint method. It will kickoff the full authentication flow and trigger a web-browser to - retrieve credentials + This is the entrypoint method. It will kickoff the full authentication + flow and trigger a web-browser to retrieve credentials. Because this + needs to open a port on localhost and may be called from a + multithreaded context (e.g. pyflyte register), this call may block + multiple threads and return a cached result for up to 60 seconds. 
""" # In the absence of globally-set token values, initiate the token request flow - q = Queue() + with self._lock: + # Clear cache if it's been more than 60 seconds since the last check + cache_ttl_s = 60 + if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic(): + self._cached_credentials = None - # First prepare the callback server in the background - server = self._create_callback_server() + if self._cached_credentials is not None: + return self._cached_credentials + q = Queue() - self._request_authorization_code() + # First prepare the callback server in the background + server = self._create_callback_server() - server.handle_request(q) - server.server_close() + self._request_authorization_code() - # Send the call to request the authorization code in the background + server.handle_request(q) + server.server_close() - # Request the access token once the auth code has been received. - auth_code = q.get() - return self._request_access_token(auth_code) + # Send the call to request the authorization code in the background + + # Request the access token once the auth code has been received. + auth_code = q.get() + self._cached_credentials = self._request_access_token(auth_code) + self._cached_credentials_ts = time.monotonic() + return self._cached_credentials def refresh_access_token(self, credentials: Credentials) -> Credentials: if credentials.refresh_token is None: From 620a449b4a256b3bcb251fb1899950953a0906f0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 16 Aug 2024 11:08:35 -0400 Subject: [PATCH 07/11] Update uv to 0.2.37 (#2687) Signed-off-by: Thomas J. 
Fan --- flytekit/image_spec/default_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index 3b35214c22..50fcc4ea8a 100644 --- a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -39,7 +39,7 @@ DOCKER_FILE_TEMPLATE = Template( """\ #syntax=docker/dockerfile:1.5 -FROM ghcr.io/astral-sh/uv:0.2.35 as uv +FROM ghcr.io/astral-sh/uv:0.2.37 as uv FROM mambaorg/micromamba:1.5.8-bookworm-slim as micromamba FROM $BASE_IMAGE From a8f68d724ff59585d45e4448025ffc2fd6864c1b Mon Sep 17 00:00:00 2001 From: Vincent Chen <62143443+mao3267@users.noreply.github.com> Date: Sun, 18 Aug 2024 03:46:43 +0800 Subject: [PATCH 08/11] Input through file and pipe (#2552) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: mao3267 --------- Signed-off-by: mao3267 Signed-off-by: Kevin Su Signed-off-by: pryce-turner Signed-off-by: ggydush Signed-off-by: Eduardo Apolinario Signed-off-by: ddl-rliu Signed-off-by: Thomas J. 
Fan Signed-off-by: Future-Outlier Signed-off-by: novahow Signed-off-by: Mecoli1219 Signed-off-by: Fabio Grätz Signed-off-by: bugra.gedik Signed-off-by: Thomas Newton Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Signed-off-by: dependabot[bot] Signed-off-by: Samhita Alla Signed-off-by: Peeter Piegaze <1153481+ppiegaze@users.noreply.github.com> Signed-off-by: Felix Ruess Signed-off-by: Ketan Umare Signed-off-by: Yee Hing Tong Signed-off-by: aditya7302 Signed-off-by: Jan Fiedler Signed-off-by: JackUrb Signed-off-by: Paul Dittamo Signed-off-by: Robert Deaton Co-authored-by: Kevin Su Co-authored-by: pryce-turner <31577879+pryce-turner@users.noreply.github.com> Co-authored-by: Greg Gydush <35151789+ggydush@users.noreply.github.com> Co-authored-by: Eduardo Apolinario Co-authored-by: ddl-rliu <140021987+ddl-rliu@users.noreply.github.com> Co-authored-by: Chi-Sheng Liu Co-authored-by: Thomas J. Fan Co-authored-by: Future-Outlier Co-authored-by: novahow <58504997+novahow@users.noreply.github.com> Co-authored-by: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Co-authored-by: Fabio M. 
Graetz, Ph.D Co-authored-by: Fabio Grätz Co-authored-by: Buğra Gedik Co-authored-by: bugra.gedik Co-authored-by: Thomas Newton Co-authored-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Samhita Alla Co-authored-by: Peeter Piegaze <1153481+ppiegaze@users.noreply.github.com> Co-authored-by: Felix Ruess Co-authored-by: Ketan Umare <16888709+kumare3@users.noreply.github.com> Co-authored-by: Ketan Umare Co-authored-by: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Co-authored-by: Yee Hing Tong Co-authored-by: Aditya Garg <110886184+aditya7302@users.noreply.github.com> Co-authored-by: Jan Fiedler <89976021+fiedlerNr9@users.noreply.github.com> Co-authored-by: Jack Urbanek Co-authored-by: rdeaton-freenome <134093844+rdeaton-freenome@users.noreply.github.com> --- flytekit/clis/sdk_in_container/run.py | 107 ++++++++++++++++-- flytekit/core/interface.py | 21 +++- flytekit/image_spec/default_builder.py | 18 ++- .../flytekitplugins/kfpytorch/task.py | 22 +++- .../tests/test_elastic_task.py | 25 ++-- .../integration/remote/test_remote.py | 21 ++-- .../unit/cli/pyflyte/my_wf_input.json | 47 ++++++++ .../unit/cli/pyflyte/my_wf_input.yaml | 34 ++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 98 +++++++++++++++- tests/flytekit/unit/cli/pyflyte/workflow.py | 8 ++ 10 files changed, 349 insertions(+), 52 deletions(-) create mode 100644 tests/flytekit/unit/cli/pyflyte/my_wf_input.json create mode 100644 tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index d8c215a598..ed46a29583 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -7,10 +7,13 @@ import sys import tempfile import typing +import typing as t from dataclasses import dataclass, field, 
fields from typing import Iterator, get_args import rich_click as click +import yaml +from click import Context from mashumaro.codecs.json import JSONEncoder from rich.progress import Progress from typing_extensions import get_origin @@ -25,7 +28,12 @@ pretty_print_exception, project_option, ) -from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration import ( + DefaultImages, + FastSerializationSettings, + ImageConfig, + SerializationSettings, +) from flytekit.configuration.plugin import get_plugin from flytekit.core import context_manager from flytekit.core.artifact import ArtifactQuery @@ -34,14 +42,24 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException -from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback, labels_callback +from flytekit.interaction.click_types import ( + FlyteLiteralConverter, + key_value_callback, + labels_callback, +) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security from flytekit.models.common import RawOutputDataConfig from flytekit.models.interface import Parameter, Variable from flytekit.models.types import SimpleType -from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow, remote_fs +from flytekit.remote import ( + FlyteLaunchPlan, + FlyteRemote, + FlyteTask, + FlyteWorkflow, + remote_fs, +) from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules @@ -489,7 +507,8 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder: return ctx.current_context().new_builder() file_access = FileAccessProvider( - 
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=output_prefix + local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), + raw_output_prefix=output_prefix, ) # The task might run on a remote machine if raw_output_prefix is a remote path, @@ -539,7 +558,10 @@ def _run(*args, **kwargs): entity_type = "workflow" if isinstance(entity, PythonFunctionWorkflow) else "task" logger.debug(f"Running {entity_type} {entity.name} with input {kwargs}") - click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan") + click.secho( + f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", + fg="cyan", + ) try: inputs = {} for input_name, v in entity.python_interface.inputs_with_defaults.items(): @@ -576,6 +598,8 @@ def _run(*args, **kwargs): ) if processed_click_value is not None or optional_v: inputs[input_name] = processed_click_value + if processed_click_value is None and v[0] == bool: + inputs[input_name] = False if not run_level_params.is_remote: with FlyteContextManager.with_context(_update_flyte_context(run_level_params)): @@ -755,7 +779,10 @@ def list_commands(self, ctx): run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", total=None) + task = progress.add_task( + f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", + total=None, + ) with progress: progress.start_task(task) try: @@ -783,6 +810,70 @@ def get_command(self, ctx, name): ) +class YamlFileReadingCommand(click.RichCommand): + def __init__( + self, + name: str, + params: typing.List[click.Option], + help: str, + callback: typing.Callable = None, + ): + params.append( + click.Option( + ["--inputs-file"], + required=False, + type=click.Path(exists=True, dir_okay=False, resolve_path=True), + help="Path to a YAML | JSON file containing inputs for the workflow.", + 
) + ) + super().__init__(name=name, params=params, callback=callback, help=help) + + def parse_args(self, ctx: Context, args: t.List[str]) -> t.List[str]: + def load_inputs(f: str) -> t.Dict[str, str]: + try: + inputs = yaml.safe_load(f) + except yaml.YAMLError as e: + yaml_e = e + try: + inputs = json.loads(f) + except json.JSONDecodeError as e: + raise click.BadParameter( + message=f"Could not load the inputs file. Please make sure it is a valid JSON or YAML file." + f"\n json error: {e}," + f"\n yaml error: {yaml_e}", + param_hint="--inputs-file", + ) + + return inputs + + inputs = {} + if "--inputs-file" in args: + idx = args.index("--inputs-file") + args.pop(idx) + f = args.pop(idx) + with open(f, "r") as f: + inputs = load_inputs(f.read()) + elif not sys.stdin.isatty(): + f = sys.stdin.read() + if f != "": + inputs = load_inputs(f) + + new_args = [] + for k, v in inputs.items(): + if isinstance(v, str): + new_args.extend([f"--{k}", v]) + elif isinstance(v, bool): + if v: + new_args.append(f"--{k}") + else: + v = json.dumps(v) + new_args.extend([f"--{k}", v]) + new_args.extend(args) + args = new_args + + return super().parse_args(ctx, args) + + class WorkflowCommand(click.RichGroup): """ click multicommand at the python file layer, subcommands should be all the workflows in the file. 
@@ -837,11 +928,11 @@ def _create_command( h = f"{click.style(entity_type, bold=True)} ({run_level_params.computed_params.module}.{entity_name})" if loaded_entity.__doc__: h = h + click.style(f"{loaded_entity.__doc__}", dim=True) - cmd = click.RichCommand( + cmd = YamlFileReadingCommand( name=entity_name, params=params, - callback=run_command(ctx, loaded_entity), help=h, + callback=run_command(ctx, loaded_entity), ) return cmd diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 8124f617b3..cbfd08ae2f 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -6,7 +6,18 @@ import sys import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) from flyteidl.core import artifact_id_pb2 as art_id from typing_extensions import get_args, get_type_hints @@ -370,7 +381,9 @@ def transform_interface_to_list_interface( def transform_function_to_interface( - fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False + fn: typing.Callable, + docstring: Optional[Docstring] = None, + is_reference_entity: bool = False, ) -> Interface: """ From the annotations on a task function that the user should have provided, and the output names they want to use @@ -463,7 +476,9 @@ def transform_type(x: type, description: Optional[str] = None) -> _interface_mod if artifact_id: logger.debug(f"Found artifact id spec: {artifact_id}") return _interface_models.Variable( - type=TypeEngine.to_literal_type(x), description=description, artifact_partial_id=artifact_id + type=TypeEngine.to_literal_type(x), + description=description, + artifact_partial_id=artifact_id, ) diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index 50fcc4ea8a..32f20d6373 100644 --- 
a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -19,25 +19,24 @@ ) from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore -UV_PYTHON_INSTALL_COMMAND_TEMPLATE = Template("""\ +UV_PYTHON_INSTALL_COMMAND_TEMPLATE = Template( + """\ RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/uv,id=uv \ --mount=from=uv,source=/uv,target=/usr/bin/uv \ --mount=type=bind,target=requirements_uv.txt,src=requirements_uv.txt \ /usr/bin/uv \ pip install --python /opt/micromamba/envs/runtime/bin/python $PIP_EXTRA \ --requirement requirements_uv.txt -""") +""" +) -APT_INSTALL_COMMAND_TEMPLATE = Template( - """\ +APT_INSTALL_COMMAND_TEMPLATE = Template("""\ RUN --mount=type=cache,sharing=locked,mode=0777,target=/var/cache/apt,id=apt \ apt-get update && apt-get install -y --no-install-recommends \ $APT_PACKAGES -""" -) +""") -DOCKER_FILE_TEMPLATE = Template( - """\ +DOCKER_FILE_TEMPLATE = Template("""\ #syntax=docker/dockerfile:1.5 FROM ghcr.io/astral-sh/uv:0.2.37 as uv FROM mambaorg/micromamba:1.5.8-bookworm-slim as micromamba @@ -84,8 +83,7 @@ USER flytekit RUN mkdir -p $$HOME && \ echo "export PATH=$$PATH" >> $$HOME/.profile -""" -) +""") def get_flytekit_for_pypi(): diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 966425f901..c50d7f0984 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -206,7 +206,7 @@ def _convert_replica_spec( replicas=replicas, image=replica_config.image, resources=resources.to_flyte_idl() if resources else None, - restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + restart_policy=(replica_config.restart_policy.value if replica_config.restart_policy else None), ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: @@ -289,9 
+289,11 @@ def spawn_helper( return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om) -def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy: +def _convert_run_policy_to_flyte_idl( + run_policy: RunPolicy, +) -> kubeflow_common.RunPolicy: return kubeflow_common.RunPolicy( - clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + clean_pod_policy=(run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None), ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, active_deadline_seconds=run_policy.active_deadline_seconds, backoff_limit=run_policy.backoff_limit, @@ -416,7 +418,13 @@ def _execute(self, **kwargs) -> Any: checkpoint_dest = None checkpoint_src = None - launcher_args = (dumped_target_function, ctx.raw_output_prefix, checkpoint_dest, checkpoint_src, kwargs) + launcher_args = ( + dumped_target_function, + ctx.raw_output_prefix, + checkpoint_dest, + checkpoint_src, + kwargs, + ) elif self.task_config.start_method == "fork": """ The torch elastic launcher doesn't support passing kwargs to the target function, @@ -440,7 +448,11 @@ def fn_partial(): if isinstance(e, FlyteRecoverableException): create_recoverable_error_file() raise - return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om) + return ElasticWorkerResult( + return_value=return_val, + decks=flytekit.current_context().decks, + om=om, + ) launcher_target_func = fn_partial launcher_args = () diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index 39f1e0bb80..faadc1019f 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -62,7 +62,7 @@ def test_end_to_end(start_method: str) -> None: """Test that the workflow with elastic task runs end to end.""" world_size = 2 - 
train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + train_task = task(train,task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) @workflow def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: @@ -89,9 +89,7 @@ def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, ("fork", "local", False), ], ) -def test_execution_params( - start_method: str, target_exec_id: str, monkeypatch_exec_id_env_var: bool, monkeypatch -) -> None: +def test_execution_params(start_method: str, target_exec_id: str, monkeypatch_exec_id_env_var: bool, monkeypatch) -> None: """Test that execution parameters are set in the worker processes.""" if monkeypatch_exec_id_env_var: monkeypatch.setenv("FLYTE_INTERNAL_EXECUTION_ID", target_exec_id) @@ -117,7 +115,7 @@ def test_rdzv_configs(start_method: str) -> None: rdzv_configs = {"join_timeout": 10} - @task(task_config=Elastic(nnodes=1, nproc_per_node=2, start_method=start_method, rdzv_configs=rdzv_configs)) + @task(task_config=Elastic(nnodes=1,nproc_per_node=2,start_method=start_method,rdzv_configs=rdzv_configs)) def test_task(): pass @@ -131,15 +129,12 @@ def test_deck(start_method: str) -> None: """Test that decks created in the main worker process are transferred to the parent process.""" world_size = 2 - @task( - task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), - enable_deck=True, - ) + @task(task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), enable_deck=True) def train(): import os ctx = flytekit.current_context() - deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}") + deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}",) ctx.decks.append(deck) default_deck = ctx.default_deck default_deck.append("Hello from default deck") @@ 
-189,9 +184,7 @@ def wf(): ctx = FlyteContext.current_context() omt = OutputMetadataTracker() - with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt) - ) as child_ctx: + with FlyteContextManager.with_context(ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt)) as child_ctx: cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] # call execute directly so as to be able to get at the same FlyteContext object. res = train2.execute() @@ -215,9 +208,7 @@ def test_recoverable_error(recoverable: bool, start_method: str) -> None: class CustomRecoverableException(FlyteRecoverableException): pass - @task( - task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), - ) + @task(task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) def train(recoverable: bool): if recoverable: raise CustomRecoverableException("Recoverable error") @@ -244,7 +235,6 @@ def test_task(): assert test_task.task_config.rdzv_configs == {"join_timeout": 900, "timeout": 900} - def test_run_policy() -> None: """Test that run policy is propagated to custom spec.""" @@ -268,6 +258,7 @@ def test_task(): "activeDeadlineSeconds": 36000, } + @pytest.mark.parametrize("start_method", ["spawn", "fork"]) def test_omp_num_threads(start_method: str) -> None: """Test that the env var OMP_NUM_THREADS is set by default and not overwritten if set.""" diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 7e0661f808..ef47aa3529 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -100,7 +100,10 @@ def test_fetch_execute_launch_plan_with_args(register): flyte_launch_plan = 
remote.fetch_launch_plan(name="basic.basic_workflow.my_wf", version=VERSION) execution = remote.execute(flyte_launch_plan, inputs={"a": 10, "b": "foobar"}, wait=True) assert execution.node_executions["n0"].inputs == {"a": 10} - assert execution.node_executions["n0"].outputs == {"t1_int_output": 12, "c": "world"} + assert execution.node_executions["n0"].outputs == { + "t1_int_output": 12, + "c": "world", + } assert execution.node_executions["n1"].inputs == {"a": "world", "b": "foobar"} assert execution.node_executions["n1"].outputs == {"o0": "foobarworld"} assert execution.node_executions["n0"].task_executions[0].inputs == {"a": 10} @@ -130,7 +133,7 @@ def test_monitor_workflow_execution(register): break with pytest.raises( - FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs." + FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.", ): execution.outputs @@ -241,7 +244,11 @@ def test_execute_python_workflow_and_launch_plan(register): launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name) execution = remote.execute( - launch_plan, name="basic.basic_workflow.my_wf", inputs={"a": 14, "b": "foobar"}, version=VERSION, wait=True + launch_plan, + name="basic.basic_workflow.my_wf", + inputs={"a": 14, "b": "foobar"}, + version=VERSION, + wait=True, ) assert execution.outputs["o0"] == 16 assert execution.outputs["o1"] == "foobarworld" @@ -269,7 +276,9 @@ def test_fetch_execute_task_list_of_floats(register): def test_fetch_execute_task_convert_dict(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) - flyte_task = remote.fetch_task(name="basic.dict_str_wf.convert_to_string", version=VERSION) + flyte_task = remote.fetch_task( + name="basic.dict_str_wf.convert_to_string", version=VERSION + ) d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} execution = remote.execute(flyte_task, inputs={"d": d}, wait=True) 
remote.sync_execution(execution, sync_nodes=True) @@ -374,9 +383,7 @@ def test_execute_with_default_launch_plan(register): from .workflows.basic.subworkflows import parent_wf remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) - execution = remote.execute( - parent_wf, inputs={"a": 101}, version=VERSION, wait=True, image_config=ImageConfig.auto(img_name=IMAGE) - ) + execution = remote.execute(parent_wf, inputs={"a": 101}, version=VERSION, wait=True, image_config=ImageConfig.auto(img_name=IMAGE)) # check node execution inputs and outputs assert execution.node_executions["n0"].inputs == {"a": 101} assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"} diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.json b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json new file mode 100644 index 0000000000..c20081f3b2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.json @@ -0,0 +1,47 @@ +{ + "a": 1, + "b": "Hello", + "c": 1.1, + "d": { + "i": 1, + "a": [ + "h", + "e" + ] + }, + "e": [ + 1, + 2, + 3 + ], + "f": { + "x": 1.0, + "y": 2.0 + }, + "g": "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet", + "h": true, + "i": "2020-05-01", + "j": "20H", + "k": "RED", + "l": { + "hello": "world" + }, + "m": { + "a": "b", + "c": "d" + }, + "n": [ + { + "x": "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" + } + ], + "o": { + "x": [ + "tests/flytekit/unit/cli/pyflyte/testdata/df.parquet" + ] + }, + "p": "None", + "q": "tests/flytekit/unit/cli/pyflyte/testdata", + "remote": "tests/flytekit/unit/cli/pyflyte/testdata", + "image": "tests/flytekit/unit/cli/pyflyte/testdata" +} diff --git a/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml new file mode 100644 index 0000000000..678f5331c8 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml @@ -0,0 +1,34 @@ +a: 1 +b: Hello +c: 1.1 +d: + i: 1 + a: + - h + - e +e: + - 1 + - 2 + - 3 +f: + x: 1.0 + y: 2.0 
+g: tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +h: true +i: '2020-05-01' +j: 20H +k: RED +l: + hello: world +m: + a: b + c: d +n: + - x: tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +o: + x: + - tests/flytekit/unit/cli/pyflyte/testdata/df.parquet +p: 'None' +q: tests/flytekit/unit/cli/pyflyte/testdata +remote: tests/flytekit/unit/cli/pyflyte/testdata +image: tests/flytekit/unit/cli/pyflyte/testdata diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 3eb3062de9..475fb42ff1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -4,6 +4,7 @@ import pathlib import shutil import sys +import io import mock import pytest @@ -39,6 +40,8 @@ ) DIR_NAME = os.path.dirname(os.path.realpath(__file__)) +monkeypatch = pytest.MonkeyPatch() + class WorkflowFileLocation(enum.Enum): NORMAL = enum.auto() @@ -230,6 +233,92 @@ def test_union_type1(input): assert result.exit_code == 0 +def test_all_types_with_json_input(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + "--inputs-file", + os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +def test_all_types_with_yaml_input(): + runner = CliRunner() + + result = runner.invoke( + pyflyte.main, + ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +def test_all_types_with_pipe_input(monkeypatch): + runner = CliRunner() + input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r"))) + monkeypatch.setattr("sys.stdin", io.StringIO(input)) + result = runner.invoke( + pyflyte.main, + [ + "run", + 
os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + ], + input=input, + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + + +@pytest.mark.parametrize( + "pipe_input, option_input", + [ + ( + str( + json.load( + open( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "my_wf_input.json", + ), + "r", + ) + ) + ), + "GREEN", + ) + ], +) +def test_replace_file_inputs(monkeypatch, pipe_input, option_input): + runner = CliRunner() + monkeypatch.setattr("sys.stdin", io.StringIO(pipe_input)) + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + "--inputs-file", + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json" + ), + "--k", + option_input, + ], + input=pipe_input, + ) + + assert result.exit_code == 0 + assert option_input in result.output + + @pytest.mark.parametrize( "input", [2.0, '{"i":1,"a":["h","e"]}', "[1, 2, 3]"], @@ -276,7 +365,9 @@ def test_union_type_with_invalid_input(): assert result.exit_code == 2 -@pytest.mark.skipif(sys.version_info < (3, 9), reason="listing entities requires python>=3.9") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="listing entities requires python>=3.9" +) @pytest.mark.parametrize( "workflow_file", [ @@ -287,12 +378,13 @@ def test_union_type_with_invalid_input(): ) def test_get_entities_in_file(workflow_file): e = get_entities_in_file(pathlib.Path(workflow_file), False) - assert e.workflows == ["my_wf", "wf_with_env_vars", "wf_with_none"] + assert e.workflows == ["my_wf", "wf_with_env_vars", "wf_with_list", "wf_with_none"] assert e.tasks == [ "get_subset_df", "print_all", "show_sd", "task_with_env_vars", + "task_with_list", "task_with_optional", "test_union1", "test_union2", @@ -300,11 +392,13 @@ def test_get_entities_in_file(workflow_file): assert e.all() == [ "my_wf", "wf_with_env_vars", + "wf_with_list", "wf_with_none", "get_subset_df", "print_all", "show_sd", "task_with_env_vars", + 
"task_with_list", "task_with_optional", "test_union1", "test_union2", diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 95535d2fc0..accebf82df 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -125,3 +125,11 @@ def task_with_env_vars(env_vars: typing.List[str]) -> str: @workflow def wf_with_env_vars(env_vars: typing.List[str]) -> str: return task_with_env_vars(env_vars=env_vars) + +@task +def task_with_list(a: typing.List[int]) -> typing.List[int]: + return a + +@workflow +def wf_with_list(a: typing.List[int]) -> typing.List[int]: + return task_with_list(a=a) From 172af7aede6eb94a79cdc3c0a446f69906e7832c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 20 Aug 2024 12:06:29 -0700 Subject: [PATCH 09/11] Improve error message for get signed url failure (#2679) Signed-off-by: Kevin Su --- flytekit/clients/friendly.py | 2 +- flytekit/clients/grpc_utils/wrap_exception_interceptor.py | 4 +++- flytekit/clis/sdk_in_container/utils.py | 8 ++++---- flytekit/exceptions/system.py | 7 +++++++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 58038d12ec..2110dc3d08 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1021,7 +1021,7 @@ def get_upload_signed_url( ) ) except Exception as e: - raise RuntimeError(f"Failed to get signed url for {filename}, reason: {e}") + raise RuntimeError(f"Failed to get signed url for {filename}.") from e def get_download_signed_url( self, native_url: str, expires_in: datetime.timedelta = None diff --git a/flytekit/clients/grpc_utils/wrap_exception_interceptor.py b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py index ea796f464a..bae147659e 100644 --- a/flytekit/clients/grpc_utils/wrap_exception_interceptor.py +++ b/flytekit/clients/grpc_utils/wrap_exception_interceptor.py @@ -4,7 +4,7 @@ import grpc from 
flytekit.exceptions.base import FlyteException -from flytekit.exceptions.system import FlyteSystemException +from flytekit.exceptions.system import FlyteSystemException, FlyteSystemUnavailableException from flytekit.exceptions.user import ( FlyteAuthenticationException, FlyteEntityAlreadyExistsException, @@ -28,6 +28,8 @@ def _raise_if_exc(request: typing.Any, e: Union[grpc.Call, grpc.Future]): raise FlyteEntityNotExistException() from e elif e.code() == grpc.StatusCode.INVALID_ARGUMENT: raise FlyteInvalidInputException(request) from e + elif e.code() == grpc.StatusCode.UNAVAILABLE: + raise FlyteSystemUnavailableException() from e raise FlyteSystemException() from e def intercept_unary_unary(self, continuation, client_call_details, request): diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index 5b89870d45..c31b1e6502 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -81,7 +81,7 @@ def pretty_print_grpc_error(e: grpc.RpcError): """ if isinstance(e, grpc._channel._InactiveRpcError): # noqa click.secho(f"RPC Failed, with Status: {e.code()}", fg="red", bold=True) - click.secho(f"\tdetails: {e.details()}", fg="magenta", bold=True) + click.secho(f"\tDetails: {e.details()}", fg="magenta", bold=True) return @@ -113,7 +113,6 @@ def pretty_print_traceback(e: Exception, verbosity: int = 1): Print the traceback in a nice formatted way if verbose is set to True. 
""" console = Console() - tb = e.__cause__.__traceback__ if e.__cause__ else e.__traceback__ if verbosity == 0: console.print(Traceback.from_exception(type(e), e, None)) @@ -124,10 +123,11 @@ def pretty_print_traceback(e: Exception, verbosity: int = 1): f" For more verbose output, use the flags -vv or -vvv.", fg="yellow", ) - new_tb = remove_unwanted_traceback_frames(tb, unwanted_module_names) + + new_tb = remove_unwanted_traceback_frames(e.__traceback__, unwanted_module_names) console.print(Traceback.from_exception(type(e), e, new_tb)) elif verbosity >= 2: - console.print(Traceback.from_exception(type(e), e, tb)) + console.print(Traceback.from_exception(type(e), e, e.__traceback__)) else: raise ValueError(f"Verbosity level must be between 0 and 2. Got {verbosity}") diff --git a/flytekit/exceptions/system.py b/flytekit/exceptions/system.py index 63fe55f0b9..d965d129d7 100644 --- a/flytekit/exceptions/system.py +++ b/flytekit/exceptions/system.py @@ -5,6 +5,13 @@ class FlyteSystemException(_base_exceptions.FlyteRecoverableException): _ERROR_CODE = "SYSTEM:Unknown" +class FlyteSystemUnavailableException(FlyteSystemException): + _ERROR_CODE = "SYSTEM:Unavailable" + + def __str__(self): + return "Flyte cluster is currently unavailable. Please make sure the cluster is up and running." + + class FlyteNotImplementedException(FlyteSystemException, NotImplementedError): _ERROR_CODE = "SYSTEM:NotImplemented" From 6bcedc366e1e03c61ffb13a6b650df40dfd9156f Mon Sep 17 00:00:00 2001 From: arbaobao Date: Wed, 21 Aug 2024 04:51:40 +0800 Subject: [PATCH 10/11] Add pythonpath "." 
before loading modules (#2673) Signed-off-by: Nelson Chen --- Dockerfile.dev | 3 --- flytekit/bin/entrypoint.py | 7 +++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Dockerfile.dev b/Dockerfile.dev index 7b32939d39..760648d110 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -50,8 +50,5 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \ && chown flytekit: /home \ && : - -ENV PYTHONPATH="/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" - # Switch to the 'flytekit' user for better security. USER flytekit diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index e13650ee63..edbd0c10ea 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -6,6 +6,7 @@ import pathlib import signal import subprocess +import sys import tempfile import traceback from sys import exit @@ -376,6 +377,9 @@ def _execute_task( dynamic_addl_distro, dynamic_dest_dir, ) as ctx: + working_dir = os.getcwd() + if all(os.path.realpath(path) != working_dir for path in sys.path): + sys.path.append(working_dir) resolver_obj = load_object_from_module(resolver) # Use the resolver to load the actual task object _task_def = resolver_obj.load_task(loader_args=resolver_args) @@ -424,6 +428,9 @@ def _execute_map_task( with setup_execution( raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir ) as ctx: + working_dir = os.getcwd() + if all(os.path.realpath(path) != working_dir for path in sys.path): + sys.path.append(working_dir) task_index = _compute_array_job_index() mtr = load_object_from_module(resolver)() map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) From e3036f0d82ef9c73d1095d32cc65088066b784f8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 20 Aug 2024 18:16:57 -0700 Subject: [PATCH 11/11] Better error message for FailureNodeInputMismatch error (#2693) Signed-off-by: Kevin Su --- flytekit/core/workflow.py | 
21 ++++++++++- flytekit/exceptions/user.py | 22 ++++++++++++ tests/flytekit/unit/core/test_type_hints.py | 39 ++++++++++++++++++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 5d2ef6f2a5..4abd07a007 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -8,6 +8,8 @@ from functools import update_wrapper from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload +from typing_inspect import is_optional_type + try: from typing import ParamSpec except ImportError: @@ -47,7 +49,11 @@ from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import scopes as exception_scopes -from flytekit.exceptions.user import FlyteValidationException, FlyteValueException +from flytekit.exceptions.user import ( + FlyteFailureNodeInputMismatchException, + FlyteValidationException, + FlyteValueException, +) from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -689,6 +695,19 @@ def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_ar ) as inner_comp_ctx: # Now lets compile the failure-node if it exists if self.on_failure: + if self.on_failure.python_interface and self.python_interface: + workflow_inputs = self.python_interface.inputs + failure_node_inputs = self.on_failure.python_interface.inputs + + # Workflow inputs should be a subset of failure node inputs. + if (failure_node_inputs | workflow_inputs) != failure_node_inputs: + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + additional_keys = failure_node_inputs.keys() - workflow_inputs.keys() + # Raising an error if the additional inputs in the failure node are not optional. 
+ for k in additional_keys: + if not is_optional_type(failure_node_inputs[k]): + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + c = wf_args.copy() exception_scopes.user_entry_point(self.on_failure)(**c) inner_nodes = None diff --git a/flytekit/exceptions/user.py b/flytekit/exceptions/user.py index 645754dc35..6637c8d573 100644 --- a/flytekit/exceptions/user.py +++ b/flytekit/exceptions/user.py @@ -3,6 +3,10 @@ from flytekit.exceptions.base import FlyteException as _FlyteException from flytekit.exceptions.base import FlyteRecoverableException as _Recoverable +if typing.TYPE_CHECKING: + from flytekit.core.base_task import Task + from flytekit.core.workflow import WorkflowBase + class FlyteUserException(_FlyteException): _ERROR_CODE = "USER:Unknown" @@ -68,6 +72,24 @@ class FlyteValidationException(FlyteAssertion): _ERROR_CODE = "USER:ValidationError" +class FlyteFailureNodeInputMismatchException(FlyteAssertion): + _ERROR_CODE = "USER:FailureNodeInputMismatch" + + def __init__(self, failure_node_node: typing.Union["WorkflowBase", "Task"], workflow: "WorkflowBase"): + self.failure_node_node = failure_node_node + self.workflow = workflow + + def __str__(self): + return ( + f"Mismatched Inputs Detected\n" + f"The failure node `{self.failure_node_node.name}` has inputs that do not align with those expected by the workflow `{self.workflow.name}`.\n" + f"Failure Node's Inputs: {self.failure_node_node.python_interface.inputs}\n" + f"Workflow's Inputs: {self.workflow.python_interface.inputs}\n" + "Action Required:\n" + "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow." 
+ ) + + class FlyteDisapprovalException(FlyteAssertion): _ERROR_CODE = "USER:ResultNotApproved" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 0a3501665c..9601ab6763 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -33,7 +33,7 @@ from flytekit.core.testing import patch, task_mock from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine from flytekit.core.workflow import workflow -from flytekit.exceptions.user import FlyteValidationException +from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter @@ -1635,6 +1635,7 @@ def foo4(input: DC1=DC1(1, 'a')) -> DC2: ): foo4() + def test_failure_node(): @task def run(a: int, b: str) -> typing.Tuple[int, str]: @@ -1686,6 +1687,42 @@ def wf2(a: int, b: str) -> typing.Tuple[int, str]: assert wf2.failure_node.flyte_entity == failure_handler +def test_failure_node_mismatch_inputs(): + @task() + def t1(a: int) -> int: + return a + 3 + + @workflow(on_failure=t1) + def wf1(a: int = 3, b: str = "hello"): + t1(a=a) + + # pytest-xdist uses `__channelexec__` as the top-level module + running_xdist = os.environ.get("PYTEST_XDIST_WORKER") is not None + prefix = "__channelexec__." 
if running_xdist else "" + + with pytest.raises( + FlyteFailureNodeInputMismatchException, + match="Mismatched Inputs Detected\n" + f"The failure node `{prefix}tests.flytekit.unit.core.test_type_hints.t1` has " + "inputs that do not align with those expected by the workflow `tests.flytekit.unit.core.test_type_hints.wf1`.\n" + "Failure Node's Inputs: {'a': }\n" + "Workflow's Inputs: {'a': , 'b': }\n" + "Action Required:\n" + "Please ensure that all input arguments in the failure node are provided and match the expected arguments specified in the workflow.", + ): + wf1() + + @task() + def t2(a: int, b: typing.Optional[int] = None) -> int: + return a + 3 + + @workflow(on_failure=t2) + def wf2(a: int = 3): + t2(a=a) + + wf2() + + @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") def test_union_type(): import pandas as pd