Skip to content

Commit

Permalink
pyflyte run remote file (#1670)
Browse files: browse the repository at this point in the history
Signed-off-by: ChungYujoyce <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
2 people authored and eapolinario committed Jun 29, 2023
1 parent 37e9f3a commit cbb968c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, filename: str, *args, **kwargs):
self._filename = pathlib.Path(filename).resolve()

def list_commands(self, ctx):
entities = get_entities_in_file(self._filename.__str__())
entities = get_entities_in_file(self._filename.__str__(), False)
return entities.all()

def get_command(self, ctx, exe_entity):
Expand Down
1 change: 1 addition & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CTX_MODULE = "module"
CTX_VERBOSE = "verbose"
CTX_COPY_ALL = "copy_all"
CTX_FILE_NAME = "file_name"


project_option = _click.option(
Expand Down
26 changes: 22 additions & 4 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CTX_CONFIG_FILE,
CTX_COPY_ALL,
CTX_DOMAIN,
CTX_FILE_NAME,
CTX_MODULE,
CTX_PROJECT,
CTX_PROJECT_ROOT,
Expand Down Expand Up @@ -595,7 +596,7 @@ def all(self) -> typing.List[str]:
return e


def get_entities_in_file(filename: str) -> Entities:
def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entities:
"""
Returns a list of flyte workflow names and list of Flyte tasks in a file.
"""
Expand All @@ -615,6 +616,8 @@ def get_entities_in_file(filename: str) -> Entities:
elif isinstance(o, PythonTask):
tasks.append(name)

if should_delete and os.path.exists(filename):
os.remove(filename)
return Entities(workflows, tasks)


Expand All @@ -635,6 +638,8 @@ def _run(*args, **kwargs):
if not ctx.obj[REMOTE_FLAG_KEY]:
output = entity(**inputs)
click.echo(output)
if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME):
os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME))
return

remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY]
Expand Down Expand Up @@ -678,6 +683,9 @@ def _run(*args, **kwargs):
if run_level_params.get("dump_snippet"):
dump_flyte_remote_snippet(execution, project, domain)

if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME):
os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME))

return _run


Expand All @@ -688,10 +696,19 @@ class WorkflowCommand(click.RichGroup):

def __init__(self, filename: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self._filename = pathlib.Path(filename).resolve()

ctx = context_manager.FlyteContextManager.current_context()
if ctx.file_access.is_remote(filename):
local_path = os.path.join(os.path.curdir, filename.rsplit("/", 1)[1])
ctx.file_access.download(filename, local_path)
self._filename = pathlib.Path(local_path).resolve()
self._should_delete = True
else:
self._filename = pathlib.Path(filename).resolve()
self._should_delete = False

def list_commands(self, ctx):
entities = get_entities_in_file(self._filename)
entities = get_entities_in_file(self._filename, self._should_delete)
return entities.all()

def get_command(self, ctx, exe_entity):
Expand Down Expand Up @@ -721,7 +738,8 @@ def get_command(self, ctx, exe_entity):

ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root
ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module

if self._should_delete:
ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_FILE_NAME] = self._filename
entity = load_naive_entity(module, exe_entity, project_root)

# If this is a remote execution, which we should know at this point, then create the remote object
Expand Down
13 changes: 12 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from flytekit.remote import FlyteRemote

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py"
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
DIR_NAME = os.path.dirname(os.path.realpath(__file__))

Expand Down Expand Up @@ -77,6 +78,16 @@ def test_copy_all_files():
assert result.exit_code == 0


def test_remote_files():
    """Invoke ``pyflyte run`` on a workflow file given as an HTTPS URL and expect exit code 0."""
    cli_args = ["run", REMOTE_WORKFLOW_FILE, "my_wf", "--a", "1", "--b", "Hello"]
    # catch_exceptions=False: presumably click's CliRunner, so errors raised by the
    # command propagate into the test instead of being folded into the result.
    outcome = CliRunner().invoke(pyflyte.main, cli_args, catch_exceptions=False)
    assert outcome.exit_code == 0


def test_pyflyte_run_cli():
runner = CliRunner()
parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet")
Expand Down Expand Up @@ -184,7 +195,7 @@ def test_union_type_with_invalid_input():


def test_get_entities_in_file():
e = get_entities_in_file(WORKFLOW_FILE)
e = get_entities_in_file(WORKFLOW_FILE, False)
assert e.workflows == ["my_wf"]
assert e.tasks == ["get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"]
assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"]
Expand Down

0 comments on commit cbb968c

Please sign in to comment.