From c2473c092a3ea9d0f1deeb9ddb4f2b2ace09f2e3 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 28 Sep 2021 16:03:26 -0700 Subject: [PATCH] Fix pod task fast registration template substitution. (#678) --- flytekit/clis/flyte_cli/main.py | 42 +++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 06edd4bda5..9ed44d776e 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -1950,6 +1950,17 @@ def register_files( _extract_and_register(host, insecure, project, domain, version, files, patches) +def _substitute_fast_register_task_args(args: List[str], full_remote_path: str, dest_dir: str) -> List[str]: + complete_args = [] + for arg in args: + if arg == "{{ .remote_package_path }}": + arg = full_remote_path + elif arg == "{{ .dest_dir }}": + arg = dest_dir if dest_dir else "." + complete_args.append(arg) + return complete_args + + @_flyte_cli.command("fast-register-files", cls=_FlyteSubCommand) @_click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to register with.") @_click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to register with.") @@ -2032,18 +2043,25 @@ def fast_register_task(entity: _GeneratedProtocolMessageType) -> _GeneratedProto task execution. """ # entity is of type flyteidl.admin.task_pb2.TaskSpec - if not entity.template.HasField("container") or len(entity.template.container.args) == 0: - # Containerless tasks are always fast registerable without modification - return entity - complete_args = [] - for arg in entity.template.container.args: - if arg == "{{ .remote_package_path }}": - arg = full_remote_path - elif arg == "{{ .dest_dir }}": - arg = dest_dir if dest_dir else "." - complete_args.append(arg) - del entity.template.container.args[:] - entity.template.container.args.extend(complete_args) + + if entity.template.HasField("container") and len(entity.template.container.args) > 0: + complete_args = _substitute_fast_register_task_args( + entity.template.container.args, full_remote_path, dest_dir + ) + # Because we're dealing with a proto list, we have to delete the existing args before we can extend the list + # with the substituted ones. + del entity.template.container.args[:] + entity.template.container.args.extend(complete_args) + + if entity.template.HasField("k8s_pod"): + pod_spec_struct = entity.template.k8s_pod.pod_spec + if "containers" in pod_spec_struct: + for idx in range(len(pod_spec_struct["containers"])): + if "args" in pod_spec_struct["containers"][idx]: + # We can directly overwrite the args in the pod spec struct definition. + pod_spec_struct["containers"][idx]["args"] = _substitute_fast_register_task_args( + pod_spec_struct["containers"][idx]["args"], full_remote_path, dest_dir + ) return entity patches = {