Skip to content

Commit

Permalink
Fast register for dynamic tasks (flyteorg#437)
Browse files Browse the repository at this point in the history
Signed-off-by: Max Hoffman <[email protected]>
  • Loading branch information
Katrina Rogan authored and max-hoffman committed Apr 29, 2021
1 parent 6a745c1 commit 5ee9bf0
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 18 deletions.
86 changes: 77 additions & 9 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")


def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
def _handle_annotated_task(
task_def: PythonTask,
inputs: str,
output_prefix: str,
raw_output_data_prefix: str,
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
):
"""
Entrypoint for all PythonTask extensions
"""
Expand Down Expand Up @@ -224,7 +231,9 @@ def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str
with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
# Because execution states do not look up the context chain, it has to be made last
with ctx.new_execution_context(
mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
mode=ExecutionState.Mode.TASK_EXECUTION,
execution_params=execution_parameters,
additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir},
) as ctx:
_dispatch_execute(ctx, task_def, inputs, output_prefix)

Expand Down Expand Up @@ -281,7 +290,16 @@ def _load_resolver(resolver_location: str) -> TaskResolverMixin:


@_scopes.system_entry_point
def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver: str, resolver_args: List[str]):
def _execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
resolver: str,
resolver_args: List[str],
dynamic_addl_distro: str = None,
dynamic_dest_dir: str = None,
):
"""
This function should be called for new API tasks (those only available in 0.16 and later that leverage Python
native typing).
Expand All @@ -299,6 +317,10 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver:
:param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be
nested).
:param resolver_args: Args that will be passed to the aforementioned resolver's load_task function
:param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the
compressed code archive has been uploaded.
:param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed
code archives should be installed in the flyte task container.
:return:
"""
if len(resolver_args) < 1:
Expand All @@ -313,12 +335,22 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver:
f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
)
return
_handle_annotated_task(_task_def, inputs, output_prefix, raw_output_data_prefix)
_handle_annotated_task(
_task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
)


@_scopes.system_entry_point
def _execute_map_task(
inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver: str, resolver_args: List[str]
inputs,
output_prefix,
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro: str,
dynamic_dest_dir: str,
resolver: str,
resolver_args: List[str],
):
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")
Expand All @@ -342,7 +374,9 @@ def _execute_map_task(
)
return

_handle_annotated_task(map_task, inputs, output_prefix, raw_output_data_prefix)
_handle_annotated_task(
map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
)


@_click.group()
Expand All @@ -357,14 +391,25 @@ def _pass_through():
@_click.option("--output-prefix", required=True)
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
nargs=-1,
)
def execute_task_cmd(
task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args
task_module,
task_name,
inputs,
output_prefix,
raw_output_data_prefix,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
):
_click.echo(_utils.get_version_message())
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
Expand All @@ -382,7 +427,16 @@ def execute_task_cmd(
_legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test)
else:
_click.echo(f"Attempting to run with {resolver}...")
_execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args)
_execute_task(
inputs,
output_prefix,
raw_output_data_prefix,
test,
resolver,
resolver_args,
dynamic_addl_distro,
dynamic_dest_dir,
)


@_pass_through.command("pyflyte-fast-execute")
Expand All @@ -405,7 +459,15 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):

# Use the commandline to run the task execute command rather than calling it directly in python code
# since the current runtime bytecode references the older user code, rather than the downloaded distribution.
_os.system(" ".join(task_execute_cmd))

# Insert the call to fast before the unbounded resolver args
cmd = []
for arg in task_execute_cmd:
if arg == "--resolver":
cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir])
cmd.append(arg)

_os.system(" ".join(cmd))


@_pass_through.command("pyflyte-map-execute")
Expand All @@ -414,6 +476,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--max-concurrency", type=int, required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
@_click.argument(
"resolver-args",
Expand All @@ -426,6 +490,8 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
):
Expand All @@ -437,6 +503,8 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_serializable_workflow(

# Translate nodes
upstream_sdk_nodes = [
get_serializable(entity_mapping, settings, n)
get_serializable(entity_mapping, settings, n, fast)
for n in entity.nodes
if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
]
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,19 @@ def new_execution_context(
working_dir = working_dir or self.file_access.get_random_local_directory()
engine_dir = os.path.join(working_dir, "engine_dir")
pathlib.Path(engine_dir).mkdir(parents=True, exist_ok=True)
if additional_context is None:
additional_context = self.execution_state.additional_context if self.execution_state is not None else None
elif self.execution_state is not None and self.execution_state.additional_context is not None:
additional_context = {**self.execution_state.additional_context, **additional_context}
exec_state = ExecutionState(
mode=mode, working_dir=working_dir, engine_dir=engine_dir, additional_context=additional_context
)

# If a wf_params object was not given, use the default (defined at the bottom of this file)
new_ctx = FlyteContext(
parent=self, execution_state=exec_state, user_space_params=execution_params or default_user_space_params
parent=self,
execution_state=exec_state,
user_space_params=execution_params or default_user_space_params,
)
FlyteContext.OBJS.append(new_ctx)
try:
Expand Down
38 changes: 35 additions & 3 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return container_args

def compile_into_workflow(
self, ctx: FlyteContext, task_function: Callable, **kwargs
self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
with ctx.new_compilation_context(prefix="dynamic"):
# TODO: Resolve circular import
Expand All @@ -178,7 +178,7 @@ def compile_into_workflow(
self._wf.compile(**kwargs)

wf = self._wf
sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf)
sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf, is_fast_execution)

# If no nodes were produced, let's just return the strict outputs
if len(sdk_workflow.nodes) == 0:
Expand All @@ -192,6 +192,33 @@ def compile_into_workflow(
for n in sdk_workflow.nodes:
self.aggregate(tasks, sub_workflows, n)

if is_fast_execution:
if (
not ctx.execution_state
or not ctx.execution_state.additional_context
or not ctx.execution_state.additional_context.get("dynamic_addl_distro")
):
raise AssertionError(
"Compilation for a dynamic workflow called in fast execution mode but no additional code "
"distribution could be retrieved"
)
logger.warn(f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}")
sanitized_tasks = set()
for task in tasks:
sanitized_args = []
for arg in task.container.args:
if arg == "{{ .remote_package_path }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_addl_distro"))
elif arg == "{{ .dest_dir }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_dest_dir", "."))
else:
sanitized_args.append(arg)
del task.container.args[:]
task.container.args.extend(sanitized_args)
sanitized_tasks.add(task)

tasks = sanitized_tasks

dj_spec = _dynamic_job.DynamicJobSpec(
min_successes=len(sdk_workflow.nodes),
tasks=list(tasks),
Expand Down Expand Up @@ -241,4 +268,9 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
return task_function(**kwargs)

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
return self.compile_into_workflow(ctx, task_function, **kwargs)
is_fast_execution = bool(
ctx.execution_state
and ctx.execution_state.additional_context
and ctx.execution_state.additional_context.get("dynamic_addl_distro")
)
return self.compile_into_workflow(ctx, is_fast_execution, task_function, **kwargs)
2 changes: 1 addition & 1 deletion plugins/tests/pod/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def dynamic_pod_task(a: int) -> List[int]:
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, False, dynamic_pod_task._task_function, a=5)
assert len(dynamic_job_spec._nodes) == 5


Expand Down
13 changes: 12 additions & 1 deletion tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flytekit.core.context_manager import CompilationState, FlyteContext, look_up_image_info
from flytekit.core.context_manager import CompilationState, ExecutionState, FlyteContext, look_up_image_info


class SampleTestClass(object):
Expand Down Expand Up @@ -42,3 +42,14 @@ def test_look_up_image_info():
assert img.name == "x"
assert img.tag == "latest"
assert img.fqn == "localhost:5000/xyz"


def test_additional_context():
    """Nested execution contexts must merge their additional_context dicts,
    with keys from the innermost context taking precedence on collision."""
    outer_extras = {1: "outer", 2: "foo"}
    inner_extras = {1: "inner", 3: "baz"}
    with FlyteContext.current_context() as base_ctx:
        with base_ctx.new_execution_context(
            mode=ExecutionState.Mode.TASK_EXECUTION, additional_context=outer_extras
        ) as outer_ctx:
            with outer_ctx.new_execution_context(
                mode=ExecutionState.Mode.TASK_EXECUTION, additional_context=inner_extras
            ) as inner_ctx:
                merged = inner_ctx.execution_state.additional_context
                # key 1 is overridden by the inner context; keys 2 and 3 survive the merge
                assert merged == {1: "inner", 2: "foo", 3: "baz"}
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_dynamic_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
dynamic_job_spec = merge_sort_remotely.compile_into_workflow(
ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
ctx, False, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
)
assert len(dynamic_job_spec.tasks) == 5
45 changes: 44 additions & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,53 @@ def my_wf(a: int, b: str) -> (str, typing.List[str]):
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
dynamic_job_spec = my_subwf.compile_into_workflow(ctx, False, my_subwf._task_function, a=5)
assert len(dynamic_job_spec._nodes) == 5


def test_wf1_with_fast_dynamic():
    """Compiling a dynamic task in fast-execution mode must rewrite the child
    task's container args so they point at the uploaded code distribution
    (--additional-distribution / --dest-dir) taken from additional_context."""

    @task
    def t1(a: int) -> str:
        return f"fast-{a + 2}"

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        # Fan out one t1 invocation per index.
        return [t1(a=i) for i in range(a)]

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        return my_subwf(a=a)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # Simulates what the fast-execute entrypoint injects for a parent task.
    fast_extras = {
        "dynamic_addl_distro": "s3::/my-s3-bucket/fast/123",
        "dynamic_dest_dir": "/User/flyte/workflows",
    }
    with context_manager.FlyteContext.current_context().new_serialization_settings(
        serialization_settings=settings
    ) as ser_ctx:
        with ser_ctx.new_execution_context(
            mode=ExecutionState.Mode.TASK_EXECUTION,
            additional_context=fast_extras,
        ) as exec_ctx:
            spec = my_subwf.compile_into_workflow(exec_ctx, True, my_subwf._task_function, a=5)
            assert len(spec._nodes) == 5
            assert len(spec.tasks) == 1
            joined = " ".join(spec.tasks[0].container.args)
            expected_prefix = (
                "pyflyte-fast-execute --additional-distribution s3::/my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows"
            )
            assert joined.startswith(expected_prefix)


def test_list_output():
@task
def t1(a: int) -> str:
Expand Down

0 comments on commit 5ee9bf0

Please sign in to comment.