From 0bf22da36d19b1b367c78aa709b0c8c10890431a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 22 Sep 2021 21:57:48 +0800 Subject: [PATCH] Init python pickle Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 7 ++-- flytekit/core/interface.py | 41 +++++++++++++++---- flytekit/core/python_function_task.py | 3 +- flytekit/core/task.py | 5 +++ flytekit/core/type_engine.py | 58 +++++++++++++++++++++++++-- flytekit/core/workflow.py | 18 +++++++-- flytekit/models/types.py | 4 ++ flytekit/types/file/__init__.py | 8 +++- flytekit/types/file/file.py | 26 ++++++++++++ 9 files changed, 152 insertions(+), 18 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 5018941882..d533299f8d 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -230,7 +230,6 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr native_types=self.get_input_types(), ) input_literal_map = _literal_models.LiteralMap(literals=kwargs) - # if metadata.cache is set, check memoized version if self.metadata.cache: # TODO: how to get a nice `native_inputs` here? @@ -370,7 +369,7 @@ def __init__( super().__init__( task_type=task_type, name=name, - interface=transform_interface_to_typed_interface(interface), + interface=transform_interface_to_typed_interface(interface, name), **kwargs, ) self._python_interface = interface if interface else Interface() @@ -453,7 +452,9 @@ def dispatch_execute( ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now # Translate the input literals to Python native - native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs) + native_inputs = TypeEngine.literal_map_to_kwargs( + exec_ctx, input_literal_map, self.python_interface.inputs, self.interface.inputs + ) # TODO: Logger should auto inject the current context information to indicate if the task is running within # a workflow or a subworkflow etc diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 842a4198b5..d422af0479 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -13,6 +13,8 @@ from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.types.file import PythonPickle +from flytekit.types.file.file import FlyteFilePathTransformer class Interface(object): @@ -188,7 +190,7 @@ def transform_inputs_to_parameters( def transform_interface_to_typed_interface( - interface: typing.Optional[Interface], + interface: typing.Optional[Interface], name: str = "" ) -> typing.Optional[_interface_models.TypedInterface]: """ Transform the given simple python native interface to FlyteIDL's interface @@ -204,8 +206,8 @@ def transform_interface_to_typed_interface( interface.docstring.output_descriptions, interface.outputs ) - inputs_map = transform_variable_map(interface.inputs, input_descriptions) - outputs_map = transform_variable_map(interface.outputs, output_descriptions) + inputs_map = transform_variable_map(interface.inputs, input_descriptions, name) + outputs_map = transform_variable_map(interface.outputs, output_descriptions, name) return _interface_models.TypedInterface(inputs_map, outputs_map) @@ -244,7 +246,9 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: return Interface(inputs=map_inputs, outputs=map_outputs) -def transform_signature_to_interface(signature: inspect.Signature, docstring: Optional[Docstring] = None) -> Interface: +def transform_signature_to_interface( + signature: inspect.Signature, docstring: Optional[Docstring] = None, contain_return: bool = False +) -> Interface: """ From the annotations on a task function that the user should have provided, and the output names they want to use for each output parameter, construct the TypedInterface object @@ -252,12 +256,29 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op For now the fancy object, maybe in the future a dumb object. """ + outputs = extract_return_annotation(signature.return_annotation) + # [WIP] Handle multiple outputs + if outputs == {} and contain_return: + outputs["o0"] = PythonPickle + for k, v in outputs.items(): + try: + TypeEngine.get_transformer(v) + except ValueError: + # We change the output type to the "PythonPickle" if we can't find the transformer for the original type + outputs[k] = PythonPickle inputs = OrderedDict() for k, v in signature.parameters.items(): + annotation = v.annotation + try: + TypeEngine.get_transformer(annotation) + except ValueError: + # We change the input type to the "PythonPickle" if we can't find the transformer for the original type + annotation = PythonPickle + default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future - inputs[k] = (v.annotation, v.default if v.default is not inspect.Parameter.empty else None) + inputs[k] = (annotation, default) # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We # would like to preserve that name in our custom collections.namedtuple. @@ -273,7 +294,9 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op def transform_variable_map( - variable_map: Dict[str, type], descriptions: Dict[str, str] = {} + variable_map: Dict[str, type], + descriptions: Dict[str, str] = {}, + name: str = "", ) -> Dict[str, _interface_models.Variable]: """ Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a @@ -283,6 +306,9 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) + # flytekit will lookup LiteralType metadata when task output type is PythonPickle, + # and write a pickle file to FlyteFilePathTransformer.PICKLE_PATH + res[k].type.metadata = {FlyteFilePathTransformer.PICKLE_PATH: name + "." + k} return res @@ -355,7 +381,8 @@ def t(a: int, b: str) -> Dict[str, int]: ... "Tuples should be used to indicate multiple return values, found only one return variable." ) return OrderedDict( - zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) # type: ignore + zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) + # type: ignore ) elif isinstance(return_annotation, tuple): if len(return_annotation) == 1: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 56613e1716..4f77c6a0f8 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -101,6 +101,7 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: Optional[ExecutionBehavior] = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + contain_return: bool = False, **kwargs, ): """ @@ -115,7 +116,7 @@ def __init__( if task_function is None: raise ValueError("TaskFunction is a required parameter for PythonFunctionTask") self._native_interface = transform_signature_to_interface( - inspect.signature(task_function), Docstring(callable_=task_function) + inspect.signature(task_function), Docstring(callable_=task_function), contain_return ) mutated_interface = self._native_interface.remove_inputs(ignore_input_vars) super().__init__( diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 24fe283399..30d4e4e511 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,3 +1,4 @@ +import ast import datetime as _datetime import inspect from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -182,6 +183,9 @@ def wrapper(fn) -> PythonFunctionTask: timeout=timeout, ) + def _contains_explicit_return(f): + return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f)))) + task_instance = TaskPlugins.find_pythontask_plugin(type(task_config))( task_config, fn, @@ -192,6 +196,7 @@ def wrapper(fn) -> PythonFunctionTask: limits=limits, secret_requests=secret_requests, execution_mode=execution_mode, + contain_return=_contains_explicit_return(fn), ) return task_instance diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 26107f1e82..c6733f04b3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -6,6 +6,7 @@ import inspect import json as _json import mimetypes +import pickle import typing from abc import ABC, abstractmethod from typing import Optional, Type, cast @@ -26,6 +27,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types +from flytekit.models.interface import Variable from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar from flytekit.models.types import LiteralType, SimpleType @@ -407,7 +409,11 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. @classmethod def literal_map_to_kwargs( - cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type] + cls, + ctx: FlyteContext, + lm: LiteralMap, + python_types: typing.Dict[str, type], + literal_type: typing.Dict[str, Variable] = None, ) -> typing.Dict[str, typing.Any]: """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task @@ -416,8 +422,54 @@ def literal_map_to_kwargs( raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + python_value_map = {} + for k, v in python_types.items(): + # We don't need to use TypeEngine to translate literal to python value if input type is PYTHON_PICKLE_FORMAT + # Just use the data in the pickle file + if ( + lm.literals[k].scalar + and lm.literals[k].scalar.blob + and lm.literals[k].scalar.blob.metadata.type.format == "python-pickle" + ): + uri = "" + # Download pickle file to local first if file is not in the local file systems. + if ctx.file_access.is_remote(lm.literals[k].scalar.blob.uri): + ctx.file_access.get_data(lm.literals[k].scalar.blob.uri, "./pickle-file", False) + uri = "./pickle-file" + else: + uri = lm.literals[k].scalar.blob.uri + infile = open(uri, "rb") + v = pickle.load(infile) + infile.close() + python_value_map[k] = v + # Handle special case here. + # We can't use the transformer to convert Literal to python value since we can't know literal's python type + # Task A output type is str, but Task B expect input type is Any which will fall back to PythonPickle + # In this scenario, We directly extract value from scalar, collection, or map. + # e.g. + # @task + # def A(name: str) -> str: + # return f"Welcome, {name}!" + # + # @task + # def B(greeting: typing.Any) -> str: + # return f"{greeting} How are you?" + # + # @workflow + # def welcome(name: str) -> str: + # greeting = A(name=name) + # return B(greeting=greeting) + # + elif literal_type[k].type.blob and literal_type[k].type.blob.format == "python-pickle": + if lm.literals[k].scalar: + python_value_map[k] = lm.literals[k].scalar.primitive.value + elif lm.literals[k].collection: + python_value_map[k] = [l.scalar.primitive.value for l in lm.literals[k].collection.literals] + elif lm.literals[k].map: + python_value_map[k] = {x: y for x, y in lm.literals[k].map.literals.items()} + else: + python_value_map[k] = TypeEngine.to_python_value(ctx, lm.literals[k], v) + return python_value_map @classmethod def dict_to_literal_map( diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index d7038ec908..8ca70aa80d 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import inspect from dataclasses import dataclass from enum import Enum @@ -172,7 +173,7 @@ def __init__( self._workflow_metadata = workflow_metadata self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface - self._interface = transform_interface_to_typed_interface(python_interface) + self._interface = transform_interface_to_typed_interface(python_interface, name) self._inputs = {} self._unbound_inputs = set() self._nodes = [] @@ -249,7 +250,6 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. function_outputs = self.execute(**kwargs) - # First handle the empty return case. # A workflow function may return a task that doesn't return anything # def wf(): @@ -572,10 +572,13 @@ def __init__( metadata: Optional[WorkflowMetadata], default_metadata: Optional[WorkflowMetadataDefaults], docstring: Docstring = None, + contain_return: bool = False, ): name = f"{workflow_function.__module__}.{workflow_function.__name__}" self._workflow_function = workflow_function - native_interface = transform_signature_to_interface(inspect.signature(workflow_function), docstring=docstring) + native_interface = transform_signature_to_interface( + inspect.signature(workflow_function), docstring=docstring, contain_return=contain_return + ) # TODO do we need this - can this not be in launchplan only? # This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or @@ -721,6 +724,14 @@ def workflow( """ def wrapper(fn): + def _contains_explicit_return(f): + try: + return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f)))) + except IndentationError: + # Avoid the test errors when ast parse the workflow in test file. + # IndentationError: unexpected unindent + return False + workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -730,6 +741,7 @@ def wrapper(fn): metadata=workflow_metadata, default_metadata=workflow_metadata_defaults, docstring=Docstring(callable_=fn), + contain_return=_contains_explicit_return(fn), ) workflow_instance.compile() return workflow_instance diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 03b71ef44e..5b6b98ae7c 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -165,6 +165,10 @@ def metadata(self): """ return self._metadata + @metadata.setter + def metadata(self, value): + self._metadata = value + def to_flyte_idl(self): """ :rtype: flyteidl.core.types_pb2.LiteralType diff --git a/flytekit/types/file/__init__.py b/flytekit/types/file/__init__.py index 2b65efbcd6..755f111365 100644 --- a/flytekit/types/file/__init__.py +++ b/flytekit/types/file/__init__.py @@ -63,11 +63,17 @@ decoration and useful for attaching content type information with the file and automatically documenting code. """ -PythonPickledFile = FlyteFile[typing.TypeVar("python-pickle")] +PythonPickledFile = FlyteFile[typing.TypeVar("python-pickled-file")] """ This type can be used when a serialized python pickled object is returned and shared between tasks. This only adds metadata to the file in Flyte, but does not really carry any object information """ +PYTHON_PICKLE_FORMAT = "python-pickle" +PythonPickle = FlyteFile[typing.TypeVar(PYTHON_PICKLE_FORMAT)] +""" + This type is only used by flytekit. User should not use this type. + Any type that flyte can't recognize will fall back on a PythonPickle +""" PythonNotebook = FlyteFile[typing.TypeVar("ipynb")] """ diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index ae08118bae..9699c33e13 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -2,6 +2,7 @@ import os import pathlib +import pickle import typing from flytekit.core.context_manager import FlyteContext @@ -224,6 +225,10 @@ def __str__(self): class FlyteFilePathTransformer(TypeTransformer[FlyteFile]): + BASE_DIR = ".flyte/" + PICKLE_PATH = "pickle-path" + PYTHON_PICKLE_FORMAT = "python-pickle" + def __init__(self): super().__init__(name="FlyteFilePath", t=FlyteFile) @@ -258,6 +263,17 @@ def to_literal( # information used by all cases meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type))) + # Dump the task output into pickle + if expected.blob.format == self.PYTHON_PICKLE_FORMAT: + os.makedirs(self.BASE_DIR, exist_ok=True) + uri = self.BASE_DIR + expected.metadata.get(self.PICKLE_PATH) + outfile = open(uri, "wb") + pickle.dump(python_val, outfile) + outfile.close() + remote_path = ctx.file_access.get_random_remote_path(uri) + ctx.file_access.put_data(uri, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + if isinstance(python_val, FlyteFile): source_path = python_val.path @@ -318,6 +334,16 @@ def to_python_value( ) -> FlyteFile: uri = lv.scalar.blob.uri + # Deserialize the pickle, and return data in the pickle, + # and download pickle file to local first if file is not in the local file systems. + if lv.scalar.blob.metadata.type.format == self.PYTHON_PICKLE_FORMAT: + if ctx.file_access.is_remote(uri): + ctx.file_access.get_data(uri, "./pickle-file", False) + uri = "./pickle-file" + infile = open(uri, "rb") + data = pickle.load(infile) + infile.close() + return data # In this condition, we still return a FlyteFile instance, but it's a simple one that has no downloading tricks # Using is instead of issubclass because FlyteFile does actually subclass it