Skip to content

Commit

Permalink
Init python pickle
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Sep 22, 2021
1 parent d04e9f6 commit 0bf22da
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 18 deletions.
7 changes: 4 additions & 3 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
native_types=self.get_input_types(),
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

# if metadata.cache is set, check memoized version
if self.metadata.cache:
# TODO: how to get a nice `native_inputs` here?
Expand Down Expand Up @@ -370,7 +369,7 @@ def __init__(
super().__init__(
task_type=task_type,
name=name,
interface=transform_interface_to_typed_interface(interface),
interface=transform_interface_to_typed_interface(interface, name),
**kwargs,
)
self._python_interface = interface if interface else Interface()
Expand Down Expand Up @@ -453,7 +452,9 @@ def dispatch_execute(
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
# Translate the input literals to Python native
native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs)
native_inputs = TypeEngine.literal_map_to_kwargs(
exec_ctx, input_literal_map, self.python_interface.inputs, self.interface.inputs
)

# TODO: Logger should auto inject the current context information to indicate if the task is running within
# a workflow or a subworkflow etc
Expand Down
41 changes: 34 additions & 7 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.types.file import PythonPickle
from flytekit.types.file.file import FlyteFilePathTransformer


class Interface(object):
Expand Down Expand Up @@ -188,7 +190,7 @@ def transform_inputs_to_parameters(


def transform_interface_to_typed_interface(
interface: typing.Optional[Interface],
interface: typing.Optional[Interface], name: str = ""
) -> typing.Optional[_interface_models.TypedInterface]:
"""
Transform the given simple python native interface to FlyteIDL's interface
Expand All @@ -204,8 +206,8 @@ def transform_interface_to_typed_interface(
interface.docstring.output_descriptions, interface.outputs
)

inputs_map = transform_variable_map(interface.inputs, input_descriptions)
outputs_map = transform_variable_map(interface.outputs, output_descriptions)
inputs_map = transform_variable_map(interface.inputs, input_descriptions, name)
outputs_map = transform_variable_map(interface.outputs, output_descriptions, name)
return _interface_models.TypedInterface(inputs_map, outputs_map)


Expand Down Expand Up @@ -244,20 +246,39 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface:
return Interface(inputs=map_inputs, outputs=map_outputs)


def transform_signature_to_interface(signature: inspect.Signature, docstring: Optional[Docstring] = None) -> Interface:
def transform_signature_to_interface(
signature: inspect.Signature, docstring: Optional[Docstring] = None, contain_return: bool = False
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
For now the fancy object, maybe in the future a dumb object.
"""

outputs = extract_return_annotation(signature.return_annotation)
# [WIP] Handle multiple outputs
if outputs == {} and contain_return:
outputs["o0"] = PythonPickle
for k, v in outputs.items():
try:
TypeEngine.get_transformer(v)
except ValueError:
# We change the output type to the "PythonPickle" if we can't find the transformer for the original type
outputs[k] = PythonPickle

inputs = OrderedDict()
for k, v in signature.parameters.items():
annotation = v.annotation
try:
TypeEngine.get_transformer(annotation)
except ValueError:
# We change the input type to the "PythonPickle" if we can't find the transformer for the original type
annotation = PythonPickle
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (v.annotation, v.default if v.default is not inspect.Parameter.empty else None)
inputs[k] = (annotation, default)

# This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We
# would like to preserve that name in our custom collections.namedtuple.
Expand All @@ -273,7 +294,9 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op


def transform_variable_map(
variable_map: Dict[str, type], descriptions: Dict[str, str] = {}
variable_map: Dict[str, type],
descriptions: Dict[str, str] = {},
name: str = "",
) -> Dict[str, _interface_models.Variable]:
"""
Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a
Expand All @@ -283,6 +306,9 @@ def transform_variable_map(
if variable_map:
for k, v in variable_map.items():
res[k] = transform_type(v, descriptions.get(k, k))
# flytekit will lookup LiteralType metadata when task output type is PythonPickle,
# and write a pickle file to FlyteFilePathTransformer.PICKLE_PATH
res[k].type.metadata = {FlyteFilePathTransformer.PICKLE_PATH: name + "." + k}

return res

Expand Down Expand Up @@ -355,7 +381,8 @@ def t(a: int, b: str) -> Dict[str, int]: ...
"Tuples should be used to indicate multiple return values, found only one return variable."
)
return OrderedDict(
zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) # type: ignore
zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__)
# type: ignore
)
elif isinstance(return_annotation, tuple):
if len(return_annotation) == 1:
Expand Down
3 changes: 2 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
ignore_input_vars: Optional[List[str]] = None,
execution_mode: Optional[ExecutionBehavior] = ExecutionBehavior.DEFAULT,
task_resolver: Optional[TaskResolverMixin] = None,
contain_return: bool = False,
**kwargs,
):
"""
Expand All @@ -115,7 +116,7 @@ def __init__(
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_signature_to_interface(
inspect.signature(task_function), Docstring(callable_=task_function)
inspect.signature(task_function), Docstring(callable_=task_function), contain_return
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import datetime as _datetime
import inspect
from typing import Any, Callable, Dict, List, Optional, Type, Union
Expand Down Expand Up @@ -182,6 +183,9 @@ def wrapper(fn) -> PythonFunctionTask:
timeout=timeout,
)

def _contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))

task_instance = TaskPlugins.find_pythontask_plugin(type(task_config))(
task_config,
fn,
Expand All @@ -192,6 +196,7 @@ def wrapper(fn) -> PythonFunctionTask:
limits=limits,
secret_requests=secret_requests,
execution_mode=execution_mode,
contain_return=_contains_explicit_return(fn),
)

return task_instance
Expand Down
58 changes: 55 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import json as _json
import mimetypes
import pickle
import typing
from abc import ABC, abstractmethod
from typing import Optional, Type, cast
Expand All @@ -26,6 +27,7 @@
from flytekit.models import interface as _interface_models
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Variable
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.types import LiteralType, SimpleType

Expand Down Expand Up @@ -407,7 +409,11 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models.

@classmethod
def literal_map_to_kwargs(
cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type]
cls,
ctx: FlyteContext,
lm: LiteralMap,
python_types: typing.Dict[str, type],
literal_type: typing.Dict[str, Variable] = None,
) -> typing.Dict[str, typing.Any]:
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
Expand All @@ -416,8 +422,54 @@ def literal_map_to_kwargs(
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
)

return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()}
python_value_map = {}
for k, v in python_types.items():
# We don't need to use TypeEngine to translate literal to python value if input type is PYTHON_PICKLE_FORMAT
# Just use the data in the pickle file
if (
lm.literals[k].scalar
and lm.literals[k].scalar.blob
and lm.literals[k].scalar.blob.metadata.type.format == "python-pickle"
):
uri = ""
# Download pickle file to local first if file is not in the local file systems.
if ctx.file_access.is_remote(lm.literals[k].scalar.blob.uri):
ctx.file_access.get_data(lm.literals[k].scalar.blob.uri, "./pickle-file", False)
uri = "./pickle-file"
else:
uri = lm.literals[k].scalar.blob.uri
infile = open(uri, "rb")
v = pickle.load(infile)
infile.close()
python_value_map[k] = v
# Special case: the declared input type fell back to PythonPickle, but the incoming literal is not a pickle blob.
# We can't use a transformer to convert the Literal to a Python value because the literal's Python type is unknown.
# For example, Task A's output type is str, but Task B declares its input as Any, which falls back to PythonPickle.
# In this scenario, we extract the value directly from the scalar, collection, or map.
# e.g.
# @task
# def A(name: str) -> str:
# return f"Welcome, {name}!"
#
# @task
# def B(greeting: typing.Any) -> str:
# return f"{greeting} How are you?"
#
# @workflow
# def welcome(name: str) -> str:
# greeting = A(name=name)
# return B(greeting=greeting)
#
elif literal_type[k].type.blob and literal_type[k].type.blob.format == "python-pickle":
if lm.literals[k].scalar:
python_value_map[k] = lm.literals[k].scalar.primitive.value
elif lm.literals[k].collection:
python_value_map[k] = [l.scalar.primitive.value for l in lm.literals[k].collection.literals]
elif lm.literals[k].map:
python_value_map[k] = {x: y for x, y in lm.literals[k].map.literals.items()}
else:
python_value_map[k] = TypeEngine.to_python_value(ctx, lm.literals[k], v)
return python_value_map

@classmethod
def dict_to_literal_map(
Expand Down
18 changes: 15 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ast
import inspect
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -172,7 +173,7 @@ def __init__(
self._workflow_metadata = workflow_metadata
self._workflow_metadata_defaults = workflow_metadata_defaults
self._python_interface = python_interface
self._interface = transform_interface_to_typed_interface(python_interface)
self._interface = transform_interface_to_typed_interface(python_interface, name)
self._inputs = {}
self._unbound_inputs = set()
self._nodes = []
Expand Down Expand Up @@ -249,7 +250,6 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
# A workflow function may return a task that doesn't return anything
# def wf():
Expand Down Expand Up @@ -572,10 +572,13 @@ def __init__(
metadata: Optional[WorkflowMetadata],
default_metadata: Optional[WorkflowMetadataDefaults],
docstring: Docstring = None,
contain_return: bool = False,
):
name = f"{workflow_function.__module__}.{workflow_function.__name__}"
self._workflow_function = workflow_function
native_interface = transform_signature_to_interface(inspect.signature(workflow_function), docstring=docstring)
native_interface = transform_signature_to_interface(
inspect.signature(workflow_function), docstring=docstring, contain_return=contain_return
)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -721,6 +724,14 @@ def workflow(
"""

def wrapper(fn):
def _contains_explicit_return(f):
try:
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
except IndentationError:
# Avoid the test errors when ast parse the workflow in test file.
# IndentationError: unexpected unindent
return False

workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand All @@ -730,6 +741,7 @@ def wrapper(fn):
metadata=workflow_metadata,
default_metadata=workflow_metadata_defaults,
docstring=Docstring(callable_=fn),
contain_return=_contains_explicit_return(fn),
)
workflow_instance.compile()
return workflow_instance
Expand Down
4 changes: 4 additions & 0 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def metadata(self):
"""
return self._metadata

# Setter for the ``metadata`` property: stores ``value`` on ``self._metadata`` as-is,
# with no validation or copying.
@metadata.setter
def metadata(self, value):
self._metadata = value

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.types_pb2.LiteralType
Expand Down
8 changes: 7 additions & 1 deletion flytekit/types/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@
decoration and useful for attaching content type information with the file and automatically documenting code.
"""

PythonPickledFile = FlyteFile[typing.TypeVar("python-pickle")]
PythonPickledFile = FlyteFile[typing.TypeVar("python-pickled-file")]
"""
This type can be used when a serialized python pickled object is returned and shared between tasks. This only
adds metadata to the file in Flyte, but does not really carry any object information
"""
PYTHON_PICKLE_FORMAT = "python-pickle"
PythonPickle = FlyteFile[typing.TypeVar(PYTHON_PICKLE_FORMAT)]
"""
This type is used internally by flytekit; users should not use it directly.
Any type that Flyte can't recognize falls back to PythonPickle.
"""

PythonNotebook = FlyteFile[typing.TypeVar("ipynb")]
"""
Expand Down
Loading

0 comments on commit 0bf22da

Please sign in to comment.