Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into thomasjpfan/neptu…
Browse files Browse the repository at this point in the history
…ne_pr

Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan committed Aug 26, 2024
2 parents 5727be2 + 83b90fa commit fa6b350
Show file tree
Hide file tree
Showing 21 changed files with 421 additions and 26 deletions.
3 changes: 2 additions & 1 deletion Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
uv pip install --system --no-cache-dir -U \
"git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl" \
-e /flytekit \
-e /flytekit/plugins/flytekit-k8s-pod \
-e /flytekit/plugins/flytekit-deck-standard \
-e /flytekit/plugins/flytekit-flyteinteractive \
scikit-learn \
Expand All @@ -50,5 +49,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
&& chown flytekit: /home \
&& :

ENV PYTHONPATH="/flytekit:"

# Switch to the 'flytekit' user for better security.
USER flytekit
12 changes: 10 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,16 @@ def __init__(
actual_task = python_function_task

# TODO: add support for other Flyte entities
if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)):
raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.")
if not (
(
isinstance(actual_task, PythonFunctionTask)
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT
)
or isinstance(actual_task, PythonInstanceTask)
):
raise ValueError(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
)

for k, v in actual_task.python_interface.inputs.items():
if bound_inputs and k in bound_inputs:
Expand Down
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):

expected_type = get_underlying_type(expected_type)
expected_fields_dict = {}

for f in dataclasses.fields(expected_type):
expected_fields_dict[f.name] = f.type

Expand Down Expand Up @@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
field.type = self._get_origin_type_in_annotation(field.type)
return python_type

def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit.types.structured import StructuredDataset

if python_val is None:
return python_val
if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
elif get_origin(python_type) is list:
Expand Down Expand Up @@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
return None
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)]

if hasattr(python_type, "__origin__") and get_origin(python_type) is dict:
if python_val is None:
return None
return {
k: self._make_dataclass_serializable(v, get_args(python_type)[1])
for k, v in cast(dict, python_val).items()
Expand Down
2 changes: 1 addition & 1 deletion flytekit/image_spec/default_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
# Configure user space
ENV PATH="/opt/micromamba/envs/runtime/bin:$$PATH" \
UV_LINK_MODE=copy \
UV_PRERELEASE=allow \
FLYTE_SDK_RICH_TRACEBACKS=0 \
SSL_CERT_DIR=/etc/ssl/certs \
$ENV
Expand Down Expand Up @@ -245,6 +244,7 @@ class DefaultImageBuilder(ImageSpecBuilder):
"cudnn",
"base_image",
"pip_index",
"pip_extra_index_url",
# "registry_config",
"commands",
}
Expand Down
5 changes: 5 additions & 0 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import os
import pathlib
import re
import typing
from abc import abstractmethod
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -143,6 +144,10 @@ def exist(self) -> Optional[bool]:
if e.response.status_code == 404:
return False

if re.match(f"unknown: repository .*{self.name} not found", e.explanation):
click.secho(f"Received 500 error with explanation: {e.explanation}", fg="yellow")
return False

click.secho(f"Failed to check if the image exists with error:\n {e}", fg="red")
return None
except ImageNotFound:
Expand Down
43 changes: 42 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
import datetime
import enum
import json
import logging
import os
import pathlib
import typing
from typing import cast
from typing import cast, get_args

import rich_click as click
import yaml
Expand All @@ -22,6 +23,7 @@
from flytekit.types.file import FlyteFile
from flytekit.types.iterator.json_iterator import JSONIteratorTransformer
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema.types import FlyteSchema


def is_pydantic_basemodel(python_type: typing.Type) -> bool:
Expand Down Expand Up @@ -305,11 +307,50 @@ def convert(
if value is None:
raise click.BadParameter("None value cannot be converted to a Json type.")

FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema]

def has_nested_dataclass(t: typing.Type) -> bool:
"""
Recursively checks whether the given type or its nested types contain any dataclass.
This function is typically called with a dictionary or list type and will return True if
any of the nested types within the dictionary or list is a dataclass.
Note:
- A single dataclass will return True.
- The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory,
StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because
these types are handled separately by Flyte and do not need to be converted to dataclasses.
Args:
t (typing.Type): The type to check for nested dataclasses.
Returns:
bool: True if the type or its nested types contain a dataclass, False otherwise.
"""

if dataclasses.is_dataclass(t):
# FlyteTypes is not supported now, we can support it in the future.
return t not in FLYTE_TYPES

return any(has_nested_dataclass(arg) for arg in get_args(t))

parsed_value = self._parse(value, param)

# We compare the origin type because the json parsed value for list or dict is always a list or dict without
# the covariant type information.
if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type:
# Indexing the return value of get_args will raise an error for native dict and list types.
# We don't support native list/dict types with nested dataclasses.
if get_args(self._python_type) == ():
return parsed_value
elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]):
j = JsonParamType(get_args(self._python_type)[0])
return [j.convert(v, param, ctx) for v in parsed_value]
elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]):
j = JsonParamType(get_args(self._python_type)[1])
return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()}

return parsed_value

if is_pydantic_basemodel(self._python_type):
Expand Down
5 changes: 3 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,8 +1144,9 @@ def _execute(
"""
if execution_name is not None and execution_name_prefix is not None:
raise ValueError("Only one of execution_name and execution_name_prefix can be set, but got both set")
execution_name_prefix = execution_name_prefix + "-" if execution_name_prefix is not None else None
execution_name = execution_name or (execution_name_prefix or "f") + uuid.uuid4().hex[:19]
# todo: The prefix should be passed to the backend
if execution_name_prefix is not None:
execution_name = execution_name_prefix + "-" + uuid.uuid4().hex[:19]
if not options:
options = Options()
if options.disable_notifications is not None:
Expand Down
6 changes: 6 additions & 0 deletions plugins/flytekit-k8s-pod/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Flytekit Kubernetes Pod Plugin

> [!IMPORTANT]
> This plugin is no longer needed and is here only for backwards compatibility. No new versions will be published after v1.13.x
> Please use the `pod_template` and `pod_template_name` args to `@task` as described in https://docs.flyte.org/en/latest/deployment/configuration/general.html#configuring-task-pods-with-k8s-podtemplates
> instead.

By default, Flyte tasks decorated with `@task` are essentially single functions that are loaded in one container. But often, there is a need to run a job with more than one container.

In this case, a regular task is not enough. Hence, Flyte provides a Kubernetes pod abstraction to execute multiple containers, which can be accomplished using Pod's `task_config`. The `task_config` can be leveraged to fully customize the pod spec used to run the task.
Expand Down
10 changes: 9 additions & 1 deletion plugins/flytekit-k8s-pod/flytekitplugins/pod/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import warnings

from .task import Pod

"""
.. currentmodule:: flytekitplugins.pod
Expand All @@ -10,4 +14,8 @@
Pod
"""

from .task import Pod
warnings.warn(
"This pod plugin is no longer necessary, please use the pod_template and pod_template_name options to @task as described "
"in https://docs.flyte.org/en/latest/deployment/configuration/general.html#configuring-task-pods-with-k8s-podtemplates",
FutureWarning,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class PolarsDataFrameRenderer:

def to_html(self, df: pl.DataFrame) -> str:
assert isinstance(df, pl.DataFrame)
describe_df = df.describe()
return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False)
return df.describe().to_pandas().to_html(index=False)


class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder):
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-polars/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "polars>=0.8.27,<0.17.0", "pandas"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "polars>=0.8.27", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
16 changes: 10 additions & 6 deletions plugins/flytekit-polars/tests/test_polars_plugin_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import polars as pl
from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
from typing_extensions import Annotated
from packaging import version
from polars.testing import assert_frame_equal

from flytekit import kwtypes, task, workflow
from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset

subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET]
full_schema = Annotated[StructuredDataset, PARQUET]

polars_version = pl.__version__


def test_polars_workflow_subset():
@task
Expand Down Expand Up @@ -65,9 +69,9 @@ def wf() -> full_schema:

def test_polars_renderer():
df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame(
df.describe().transpose(), columns=df.describe().columns
).to_html(index=False)
assert PolarsDataFrameRenderer().to_html(df) == df.describe().to_pandas().to_html(
index=False
)


def test_parquet_to_polars():
Expand All @@ -80,7 +84,7 @@ def create_sd() -> StructuredDataset:

sd = create_sd()
polars_df = sd.open(pl.DataFrame).all()
assert pl.DataFrame(data).frame_equal(polars_df)
assert_frame_equal(polars_df, pl.DataFrame(data))

tmp = tempfile.mktemp()
pl.DataFrame(data).write_parquet(tmp)
Expand All @@ -90,11 +94,11 @@ def t1(sd: StructuredDataset) -> pl.DataFrame:
return sd.open(pl.DataFrame).all()

sd = StructuredDataset(uri=tmp)
assert t1(sd=sd).frame_equal(polars_df)
assert_frame_equal(t1(sd=sd), polars_df)

@task
def t2(sd: StructuredDataset) -> StructuredDataset:
return StructuredDataset(dataframe=sd.open(pl.DataFrame).all())

sd = StructuredDataset(uri=tmp)
assert t2(sd=sd).open(pl.DataFrame).all().frame_equal(polars_df)
assert_frame_equal(t2(sd=sd).open(pl.DataFrame).all(), polars_df)
3 changes: 3 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
},
"p": "None",
"q": "tests/flytekit/unit/cli/pyflyte/testdata",
"r": [{"i": 1, "a": ["h", "e"]}],
"s": {"x": {"i": 1, "a": ["h", "e"]}},
"t": {"i": [{"i":1,"a":["h","e"]}]},
"remote": "tests/flytekit/unit/cli/pyflyte/testdata",
"image": "tests/flytekit/unit/cli/pyflyte/testdata"
}
17 changes: 17 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,22 @@ o:
- tests/flytekit/unit/cli/pyflyte/testdata/df.parquet
p: 'None'
q: tests/flytekit/unit/cli/pyflyte/testdata
r:
- i: 1
a:
- h
- e
s:
x:
i: 1
a:
- h
- e
t:
i:
- i: 1
a:
- h
- e
remote: tests/flytekit/unit/cli/pyflyte/testdata
image: tests/flytekit/unit/cli/pyflyte/testdata
8 changes: 6 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,21 @@ def reset_flytectl_config_env_var() -> pytest.fixture():
return os.environ[FLYTECTL_CONFIG_ENV_VAR]


@mock.patch("flytekit.configuration.plugin.get_config_file")
@mock.patch("flytekit.configuration.plugin.FlyteRemote")
def test_get_remote(mock_remote, reset_flytectl_config_env_var):
def test_get_remote(mock_remote, mock_config_file, reset_flytectl_config_env_var):
mock_config_file.return_value = None
r = FlytekitPlugin.get_remote(None, "p", "d")
assert r is not None
mock_remote.assert_called_once_with(
Config.for_sandbox(), default_project="p", default_domain="d", data_upload_location=None
)


@mock.patch("flytekit.configuration.plugin.get_config_file")
@mock.patch("flytekit.configuration.plugin.FlyteRemote")
def test_saving_remote(mock_remote):
def test_saving_remote(mock_remote, mock_config_file):
mock_config_file.return_value = None
mock_context = mock.MagicMock
mock_context.obj = {}
get_and_save_remote_with_click_context(mock_context, "p", "d")
Expand Down
6 changes: 6 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def test_pyflyte_run_cli(workflow_file):
"Any",
"--q",
DIR_NAME,
"--r",
json.dumps([{"i": 1, "a": ["h", "e"]}]),
"--s",
json.dumps({"x": {"i": 1, "a": ["h", "e"]}}),
"--t",
json.dumps({"i": [{"i":1,"a":["h","e"]}]}),
],
catch_exceptions=False,
)
Expand Down
Loading

0 comments on commit fa6b350

Please sign in to comment.