Fast register for dynamic tasks #437

Merged (13 commits, Apr 13, 2021)
86 changes: 77 additions & 9 deletions flytekit/bin/entrypoint.py
@@ -142,7 +142,14 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str,
_logging.info(f"Engine folder written successfully to the output prefix {output_prefix}")


-def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str, raw_output_data_prefix: str):
+def _handle_annotated_task(
+    task_def: PythonTask,
+    inputs: str,
+    output_prefix: str,
+    raw_output_data_prefix: str,
+    dynamic_addl_distro: str = None,
+    dynamic_dest_dir: str = None,
+):
"""
Entrypoint for all PythonTask extensions
"""
@@ -224,7 +231,9 @@ def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str
with ctx.new_serialization_settings(serialization_settings=serialization_settings) as ctx:
# Because execution states do not look up the context chain, it has to be made last
with ctx.new_execution_context(
-    mode=ExecutionState.Mode.TASK_EXECUTION, execution_params=execution_parameters
+    mode=ExecutionState.Mode.TASK_EXECUTION,
+    execution_params=execution_parameters,
+    additional_context={"dynamic_addl_distro": dynamic_addl_distro, "dynamic_dest_dir": dynamic_dest_dir},
) as ctx:
_dispatch_execute(ctx, task_def, inputs, output_prefix)

@@ -281,7 +290,16 @@ def _load_resolver(resolver_location: str) -> TaskResolverMixin:


@_scopes.system_entry_point
-def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver: str, resolver_args: List[str]):
+def _execute_task(
+    inputs,
+    output_prefix,
+    raw_output_data_prefix,
+    test,
+    resolver: str,
+    resolver_args: List[str],
+    dynamic_addl_distro: str = None,
+    dynamic_dest_dir: str = None,
+):
"""
This function should be called for new API tasks (those only available in 0.16 and later that leverage Python
native typing).
@@ -299,6 +317,10 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver:
:param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be
nested).
:param resolver_args: Args that will be passed to the aforementioned resolver's load_task function
:param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode, this captures where the
compressed code archive has been uploaded.
:param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode, this captures where compressed
code archives should be installed in the flyte task container.
:return:
"""
if len(resolver_args) < 1:
@@ -313,12 +335,22 @@ def _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver:
f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}"
)
return
-    _handle_annotated_task(_task_def, inputs, output_prefix, raw_output_data_prefix)
+    _handle_annotated_task(
+        _task_def, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
+    )


@_scopes.system_entry_point
def _execute_map_task(
-    inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver: str, resolver_args: List[str]
+    inputs,
+    output_prefix,
+    raw_output_data_prefix,
+    max_concurrency,
+    test,
+    dynamic_addl_distro: str,
+    dynamic_dest_dir: str,
+    resolver: str,
+    resolver_args: List[str],
):
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")
@@ -342,7 +374,9 @@ def _execute_map_task(
)
return

-    _handle_annotated_task(map_task, inputs, output_prefix, raw_output_data_prefix)
+    _handle_annotated_task(
+        map_task, inputs, output_prefix, raw_output_data_prefix, dynamic_addl_distro, dynamic_dest_dir
+    )


@_click.group()
@@ -357,14 +391,25 @@ def _pass_through():
@_click.option("--output-prefix", required=True)
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
nargs=-1,
)
def execute_task_cmd(
-    task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args
+    task_module,
+    task_name,
+    inputs,
+    output_prefix,
+    raw_output_data_prefix,
+    test,
+    dynamic_addl_distro,
+    dynamic_dest_dir,
+    resolver,
+    resolver_args,
):
_click.echo(_utils.get_version_message())
# Backwards compatibility - if Propeller hasn't filled this in, then it'll come through here as the original
@@ -382,7 +427,16 @@ def execute_task_cmd(
_legacy_execute_task(task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test)
else:
_click.echo(f"Attempting to run with {resolver}...")
-        _execute_task(inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args)
+        _execute_task(
+            inputs,
+            output_prefix,
+            raw_output_data_prefix,
+            test,
+            resolver,
+            resolver_args,
+            dynamic_addl_distro,
+            dynamic_dest_dir,
+        )


@_pass_through.command("pyflyte-fast-execute")
@@ -405,7 +459,15 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):

# Use the commandline to run the task execute command rather than calling it directly in python code
# since the current runtime bytecode references the older user code, rather than the downloaded distribution.
_os.system(" ".join(task_execute_cmd))

# Insert the call to fast before the unbounded resolver args
cmd = []
for arg in task_execute_cmd:
if arg == "--resolver":
cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir])
cmd.append(arg)

_os.system(" ".join(cmd))


@_pass_through.command("pyflyte-map-execute")
@@ -414,6 +476,8 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd):
@_click.option("--raw-output-data-prefix", required=False)
@_click.option("--max-concurrency", type=int, required=False)
@_click.option("--test", is_flag=True)
@_click.option("--dynamic-addl-distro", required=False)
@_click.option("--dynamic-dest-dir", required=False)
@_click.option("--resolver", required=True)
@_click.argument(
"resolver-args",
@@ -426,6 +490,8 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
):
@@ -437,6 +503,8 @@ def map_execute_task_cmd(
raw_output_data_prefix,
max_concurrency,
test,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
)
2 changes: 1 addition & 1 deletion flytekit/common/translator.py
@@ -169,7 +169,7 @@ def get_serializable_workflow(

# Translate nodes
upstream_sdk_nodes = [
-        get_serializable(entity_mapping, settings, n)
+        get_serializable(entity_mapping, settings, n, fast)
for n in entity.nodes
if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
]
8 changes: 7 additions & 1 deletion flytekit/core/context_manager.py
@@ -412,13 +412,19 @@ def new_execution_context(
working_dir = working_dir or self.file_access.get_random_local_directory()
engine_dir = os.path.join(working_dir, "engine_dir")
pathlib.Path(engine_dir).mkdir(parents=True, exist_ok=True)
if additional_context is None:
additional_context = self.execution_state.additional_context if self.execution_state is not None else None
elif self.execution_state is not None and self.execution_state.additional_context is not None:
additional_context = {**self.execution_state.additional_context, **additional_context}
exec_state = ExecutionState(
mode=mode, working_dir=working_dir, engine_dir=engine_dir, additional_context=additional_context
)

# If a wf_params object was not given, use the default (defined at the bottom of this file)
new_ctx = FlyteContext(
-        parent=self, execution_state=exec_state, user_space_params=execution_params or default_user_space_params
+        parent=self,
+        execution_state=exec_state,
+        user_space_params=execution_params or default_user_space_params,
)
FlyteContext.OBJS.append(new_ctx)
try:
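In words: when a nested execution context supplies no additional_context of its own, it inherits the parent's; when both are present, the dicts are merged with the inner keys winning. A minimal sketch of that merge rule, using plain dicts as stand-ins for the execution-state values:

def merge_additional_context(parent_ctx, new_ctx):
    # No new context supplied: inherit the parent's as-is (possibly None).
    if new_ctx is None:
        return parent_ctx
    # Both present: overlay, with the new (inner) keys taking precedence.
    if parent_ctx is not None:
        return {**parent_ctx, **new_ctx}
    return new_ctx

# Matches the nested-context behavior exercised in test_context_manager.py below:
assert merge_additional_context({1: "outer", 2: "foo"}, {1: "inner", 3: "baz"}) == {1: "inner", 2: "foo", 3: "baz"}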
38 changes: 35 additions & 3 deletions flytekit/core/python_function_task.py
@@ -165,7 +165,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return container_args

def compile_into_workflow(
-        self, ctx: FlyteContext, task_function: Callable, **kwargs
+        self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
with ctx.new_compilation_context(prefix="dynamic"):
# TODO: Resolve circular import
@@ -178,7 +178,7 @@ def compile_into_workflow(
self._wf.compile(**kwargs)

wf = self._wf
-            sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf)
+            sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf, is_fast_execution)

# If no nodes were produced, let's just return the strict outputs
if len(sdk_workflow.nodes) == 0:
@@ -192,6 +192,33 @@ def compile_into_workflow(
for n in sdk_workflow.nodes:
self.aggregate(tasks, sub_workflows, n)

if is_fast_execution:
if (
not ctx.execution_state
or not ctx.execution_state.additional_context
or not ctx.execution_state.additional_context.get("dynamic_addl_distro")
):
raise AssertionError(
"Compilation for a dynamic workflow called in fast execution mode but no additional code "
"distribution could be retrieved"
)
logger.warn(f"ctx.execution_state.additional_context {ctx.execution_state.additional_context}")
sanitized_tasks = set()
[Review comment (Contributor): isn't tasks already a set? why do we need another set?]

for task in tasks:
sanitized_args = []
for arg in task.container.args:
if arg == "{{ .remote_package_path }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_addl_distro"))
elif arg == "{{ .dest_dir }}":
sanitized_args.append(ctx.execution_state.additional_context.get("dynamic_dest_dir", "."))
else:
sanitized_args.append(arg)
del task.container.args[:]
task.container.args.extend(sanitized_args)
sanitized_tasks.add(task)

tasks = sanitized_tasks

dj_spec = _dynamic_job.DynamicJobSpec(
min_successes=len(sdk_workflow.nodes),
tasks=list(tasks),
@@ -241,4 +268,9 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
return task_function(**kwargs)

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
-            return self.compile_into_workflow(ctx, task_function, **kwargs)
+            is_fast_execution = bool(
+                ctx.execution_state
+                and ctx.execution_state.additional_context
+                and ctx.execution_state.additional_context.get("dynamic_addl_distro")
+            )
+            return self.compile_into_workflow(ctx, is_fast_execution, task_function, **kwargs)
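To summarize the sanitization step above: when the parent task ran in fast mode, each serialized subtask's container command still contains the fast-register template placeholders, and compile_into_workflow rewrites them with the uploaded distribution URI (and destination directory, defaulting to ".") carried in the execution context. A standalone sketch of just the substitution, with hypothetical values:

# Hypothetical container args as serialized for fast execution:
args = [
    "pyflyte-fast-execute",
    "--additional-distribution", "{{ .remote_package_path }}",
    "--dest-dir", "{{ .dest_dir }}",
    "--", "pyflyte-execute",
]
additional_context = {"dynamic_addl_distro": "s3://bucket/fast/abc123.tar.gz"}

sanitized_args = []
for arg in args:
    if arg == "{{ .remote_package_path }}":
        sanitized_args.append(additional_context.get("dynamic_addl_distro"))
    elif arg == "{{ .dest_dir }}":
        # Fall back to the container working directory if no dest dir was recorded.
        sanitized_args.append(additional_context.get("dynamic_dest_dir", "."))
    else:
        sanitized_args.append(arg)

# sanitized_args now carries the real URI and "." in place of the templates.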
2 changes: 1 addition & 1 deletion plugins/tests/pod/test_pod.py
@@ -202,7 +202,7 @@ def dynamic_pod_task(a: int) -> List[int]:
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
-        dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, dynamic_pod_task._task_function, a=5)
+        dynamic_job_spec = dynamic_pod_task.compile_into_workflow(ctx, False, dynamic_pod_task._task_function, a=5)
assert len(dynamic_job_spec._nodes) == 5


13 changes: 12 additions & 1 deletion tests/flytekit/unit/core/test_context_manager.py
@@ -1,4 +1,4 @@
-from flytekit.core.context_manager import CompilationState, FlyteContext, look_up_image_info
+from flytekit.core.context_manager import CompilationState, ExecutionState, FlyteContext, look_up_image_info


class SampleTestClass(object):
@@ -42,3 +42,14 @@ def test_look_up_image_info():
assert img.name == "x"
assert img.tag == "latest"
assert img.fqn == "localhost:5000/xyz"


def test_additional_context():
with FlyteContext.current_context() as ctx:
with ctx.new_execution_context(
mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "outer", 2: "foo"}
) as exec_ctx_outer:
with exec_ctx_outer.new_execution_context(
mode=ExecutionState.Mode.TASK_EXECUTION, additional_context={1: "inner", 3: "baz"}
) as exec_ctx_inner:
assert exec_ctx_inner.execution_state.additional_context == {1: "inner", 2: "foo", 3: "baz"}
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_dynamic_conditional.py
@@ -88,6 +88,6 @@ def merge_sort(in1: typing.List[int], count: int) -> typing.List[int]:
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
dynamic_job_spec = merge_sort_remotely.compile_into_workflow(
-                ctx, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
+                ctx, False, merge_sort_remotely._task_function, in1=[2, 3, 4, 5]
)
assert len(dynamic_job_spec.tasks) == 5
45 changes: 44 additions & 1 deletion tests/flytekit/unit/core/test_type_hints.py
@@ -412,10 +412,53 @@ def my_wf(a: int, b: str) -> (str, typing.List[str]):
)
) as ctx:
with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
-        dynamic_job_spec = my_subwf.compile_into_workflow(ctx, my_subwf._task_function, a=5)
+        dynamic_job_spec = my_subwf.compile_into_workflow(ctx, False, my_subwf._task_function, a=5)
assert len(dynamic_job_spec._nodes) == 5


def test_wf1_with_fast_dynamic():
@task
def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
return s

@workflow
def my_wf(a: int) -> typing.List[str]:
v = my_subwf(a=a)
return v

with context_manager.FlyteContext.current_context().new_serialization_settings(
serialization_settings=context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
) as ctx:
with ctx.new_execution_context(
mode=ExecutionState.Mode.TASK_EXECUTION,
additional_context={
"dynamic_addl_distro": "s3::/my-s3-bucket/fast/123",
[Review comment (Contributor), suggested change:
-    "dynamic_addl_distro": "s3::/my-s3-bucket/fast/123",
+    "dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
]

"dynamic_dest_dir": "/User/flyte/workflows",
},
) as ctx:
dynamic_job_spec = my_subwf.compile_into_workflow(ctx, True, my_subwf._task_function, a=5)
assert len(dynamic_job_spec._nodes) == 5
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
assert args.startswith(
"pyflyte-fast-execute --additional-distribution s3::/my-s3-bucket/fast/123 --dest-dir /User/flyte/workflows"
)


def test_list_output():
@task
def t1(a: int) -> str: