Skip to content

Commit

Permalink
Merge branch 'flyteorg:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
yubofredwang authored May 13, 2023
2 parents 7f3b6ae + dab1eed commit a6119a8
Show file tree
Hide file tree
Showing 27 changed files with 685 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
run: |
sudo apt-get install python3-sphinx
pip install -r doc-requirements.txt
SPHINXOPTS="-W" cd docs && make html
cd docs && SPHINXOPTS="-W" make html
3 changes: 1 addition & 2 deletions Dockerfile.external-plugin-service
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
RUN pip install -U flytekit==$VERSION \
flytekitplugins-bigquery==$VERSION \
RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION

CMD pyflyte serve --port 8000
4 changes: 0 additions & 4 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
:template: custom.rst
:nosignatures:
DataPersistence
DataPersistencePlugins
DiskPersistence
FileAccessProvider
UnsupportedPersistenceOp
"""
import os
Expand Down
18 changes: 15 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import literals as _literals_models
Expand Down Expand Up @@ -618,10 +619,21 @@ def binding_data_from_python_std(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif isinstance(t_value, list):
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")
elif expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
)
raise AssertionError(
f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}."
)

elif isinstance(t_value, list):
sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
Expand Down
65 changes: 60 additions & 5 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
Expand Down Expand Up @@ -75,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
return PythonFunctionTask


T = TypeVar("T")


@overload
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]:
...


@overload
def task(
_task_function: Callable[..., Any],
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> PythonFunctionTask[T]:
...


def task(
_task_function: Optional[Callable] = None,
task_config: Optional[Any] = None,
_task_function: Optional[Callable[..., Any]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache_serialize: bool = False,
cache_version: str = "",
Expand All @@ -96,7 +151,7 @@ def task(
disable_deck: bool = True,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -190,7 +245,7 @@ def foo2():
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""

def wrapper(fn) -> PythonFunctionTask:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
Expand Down
16 changes: 12 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,24 +718,32 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
d = dictionary of registered transformers, where is a python `type`
v = lookup type
Step 1:
find a transformer that matches v exactly
If the type is annotated with a TypeTransformer instance, use that.
Step 2:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
find a transformer that matches v exactly
Step 3:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
Step 4:
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
This is potentially non-deterministic - will depend on the registration pattern.
TODO lets make this deterministic by using an ordered dict
Step 4:
Step 5:
if v is of type data class, use the dataclass transformer
"""
cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
args = get_args(python_type)
for annotation in args:
if isinstance(annotation, TypeTransformer):
return annotation

python_type = args[0]

if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
Expand Down
104 changes: 85 additions & 19 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import get_args

from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
Expand Down Expand Up @@ -32,14 +35,16 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.models.types import TypeStructure

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand All @@ -49,6 +54,8 @@
flyte_entity=None,
)

T = typing.TypeVar("T")


class WorkflowFailurePolicy(Enum):
"""
Expand Down Expand Up @@ -272,24 +279,63 @@ def execute(self, **kwargs):
def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
if not isinstance(v, Promise):
t = self.python_interface.inputs[k]
def ensure_literal(
self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any
) -> _literal_models.Literal:
"""
This function will attempt to convert a python value to a literal. If the python value is a promise, it will
return the promise's value.
"""
if input_type.union_type is not None:
if python_value is None and UnionTransformer.is_optional_type(py_type):
return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void()))
for i in range(len(input_type.union_type.variants)):
lt_type = input_type.union_type.variants[i]
python_type = get_args(py_type)[i]
try:
final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value)
lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name)
return _literal_models.Literal(
scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type))
)
except Exception as e:
logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}")
raise TypeError(f"Failed to convert {python_value} to {input_type}")
if isinstance(python_value, list) and input_type.collection_type:
collection_lit_type = input_type.collection_type
collection_py_type = get_args(py_type)[0]
xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx))
elif isinstance(python_value, dict) and input_type.map_value_type:
mapped_lit_type = input_type.map_value_type
mapped_py_type = get_args(py_type)[1]
xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore
return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx))
# It is a scalar, convert to Promise if necessary.
else:
if isinstance(python_value, Promise):
return python_value.val
if not isinstance(python_value, Promise):
try:
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))
res = TypeEngine.to_literal(ctx, python_value, py_type, input_type)
return res
except TypeTransformerFailedError as exc:
raise TypeError(
f"Failed to convert input argument '{k}' of workflow '{self.name}':\n {exc}"
f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}"
) from exc

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
py_type = self.python_interface.inputs[k]
lit_type = self.interface.inputs[k].type
kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v))

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
# A workflow function may return a task that doesn't return anything
# def wf():
Expand Down Expand Up @@ -607,7 +653,7 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver):

def __init__(
self,
workflow_function: Callable,
workflow_function: Callable[..., Any],
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
Expand Down Expand Up @@ -731,12 +777,32 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


@overload
def workflow(
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]:
...


@overload
def workflow(
_workflow_function: Callable[..., Any],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> PythonFunctionWorkflow:
...


def workflow(
_workflow_function=None,
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
docs: Optional[Documentation] = None,
) -> WorkflowBase:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -767,7 +833,7 @@ def workflow(
:param docs: Description entity for the workflow
"""

def wrapper(fn):
def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand All @@ -782,10 +848,10 @@ def wrapper(fn):
update_wrapper(workflow_instance, fn)
return workflow_instance

if _workflow_function:
if _workflow_function is not None:
return wrapper(_workflow_function)
else:
return wrapper # type: ignore
return wrapper


class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore
Expand Down
4 changes: 3 additions & 1 deletion flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ImageSpec:
packages: list of python packages to install.
apt_packages: list of apt packages to install.
base_image: base image of the image.
platform: Specify the target platforms for the build output (for example, windows/amd64 or linux/amd64,darwin/arm64
"""

name: str = "flytekit"
Expand All @@ -43,6 +44,7 @@ class ImageSpec:
packages: Optional[List[str]] = None
apt_packages: Optional[List[str]] = None
base_image: Optional[str] = None
platform: str = "linux/amd64"

def image_name(self) -> str:
"""
Expand Down Expand Up @@ -147,7 +149,7 @@ def calculate_hash_from_image_spec(image_spec: ImageSpec):
# copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different.
spec = copy(image_spec)
spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b""
image_spec_bytes = bytes(image_spec.to_json(), "utf-8")
image_spec_bytes = bytes(spec.to_json(), "utf-8")
tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii")
# replace "=" with "." to make it a valid tag
return tag.replace("=", ".")
Expand Down
Loading

0 comments on commit a6119a8

Please sign in to comment.