Skip to content

Commit

Permalink
Fix pod task fast registration template substitution. (#678)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Sep 28, 2021
1 parent c889122 commit c2473c0
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions flytekit/clis/flyte_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,17 @@ def register_files(
_extract_and_register(host, insecure, project, domain, version, files, patches)


def _substitute_fast_register_task_args(args: List[str], full_remote_path: str, dest_dir: str) -> List[str]:
complete_args = []
for arg in args:
if arg == "{{ .remote_package_path }}":
arg = full_remote_path
elif arg == "{{ .dest_dir }}":
arg = dest_dir if dest_dir else "."
complete_args.append(arg)
return complete_args


@_flyte_cli.command("fast-register-files", cls=_FlyteSubCommand)
@_click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to register with.")
@_click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to register with.")
Expand Down Expand Up @@ -2032,18 +2043,25 @@ def fast_register_task(entity: _GeneratedProtocolMessageType) -> _GeneratedProto
task execution.
"""
# entity is of type flyteidl.admin.task_pb2.TaskSpec
if not entity.template.HasField("container") or len(entity.template.container.args) == 0:
# Containerless tasks are always fast registerable without modification
return entity
complete_args = []
for arg in entity.template.container.args:
if arg == "{{ .remote_package_path }}":
arg = full_remote_path
elif arg == "{{ .dest_dir }}":
arg = dest_dir if dest_dir else "."
complete_args.append(arg)
del entity.template.container.args[:]
entity.template.container.args.extend(complete_args)

if entity.template.HasField("container") and len(entity.template.container.args) > 0:
complete_args = _substitute_fast_register_task_args(
entity.template.container.args, full_remote_path, dest_dir
)
# Because we're dealing with a proto list, we have to delete the existing args before we can extend the list
# with the substituted ones.
del entity.template.container.args[:]
entity.template.container.args.extend(complete_args)

if entity.template.HasField("k8s_pod"):
pod_spec_struct = entity.template.k8s_pod.pod_spec
if "containers" in pod_spec_struct:
for idx in range(len(pod_spec_struct["containers"])):
if "args" in pod_spec_struct["containers"][idx]:
# We can directly overwrite the args in the pod spec struct definition.
pod_spec_struct["containers"][idx]["args"] = _substitute_fast_register_task_args(
pod_spec_struct["containers"][idx]["args"], full_remote_path, dest_dir
)
return entity

patches = {
Expand Down

0 comments on commit c2473c0

Please sign in to comment.