diff --git a/flytekit/clis/sdk_in_container/build.py b/flytekit/clis/sdk_in_container/build.py index 3e18535268..33b8346f10 100644 --- a/flytekit/clis/sdk_in_container/build.py +++ b/flytekit/clis/sdk_in_container/build.py @@ -63,7 +63,7 @@ def __init__(self, filename: str, *args, **kwargs): self._filename = pathlib.Path(filename).resolve() def list_commands(self, ctx): - entities = get_entities_in_file(self._filename.__str__()) + entities = get_entities_in_file(self._filename.__str__(), False) return entities.all() def get_command(self, ctx, exe_entity): diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 67391abb4d..8059d4d14d 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -11,6 +11,7 @@ CTX_MODULE = "module" CTX_VERBOSE = "verbose" CTX_COPY_ALL = "copy_all" +CTX_FILE_NAME = "file_name" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 336ffbdad6..3753b237f9 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -21,6 +21,7 @@ CTX_CONFIG_FILE, CTX_COPY_ALL, CTX_DOMAIN, + CTX_FILE_NAME, CTX_MODULE, CTX_PROJECT, CTX_PROJECT_ROOT, @@ -626,7 +627,7 @@ def all(self) -> typing.List[str]: return e -def get_entities_in_file(filename: str) -> Entities: +def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entities: """ Returns a list of flyte workflow names and list of Flyte tasks in a file. """ @@ -646,6 +647,8 @@ def get_entities_in_file(filename: str) -> Entities: elif isinstance(o, PythonTask): tasks.append(name) + if should_delete and os.path.exists(filename): + os.remove(filename) return Entities(workflows, tasks) @@ -666,6 +669,8 @@ def _run(*args, **kwargs): if not ctx.obj[REMOTE_FLAG_KEY]: output = entity(**inputs) click.echo(output) + if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME): + os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME)) return remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] @@ -711,6 +716,9 @@ def _run(*args, **kwargs): if run_level_params.get("dump_snippet"): dump_flyte_remote_snippet(execution, project, domain) + if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME): + os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME)) + return _run @@ -721,10 +729,19 @@ class WorkflowCommand(click.RichGroup): def __init__(self, filename: str, *args, **kwargs): super().__init__(*args, **kwargs) - self._filename = pathlib.Path(filename).resolve() + + ctx = context_manager.FlyteContextManager.current_context() + if ctx.file_access.is_remote(filename): + local_path = os.path.join(os.path.curdir, filename.rsplit("/", 1)[1]) + ctx.file_access.download(filename, local_path) + self._filename = pathlib.Path(local_path).resolve() + self._should_delete = True + else: + self._filename = pathlib.Path(filename).resolve() + self._should_delete = False def list_commands(self, ctx): - entities = get_entities_in_file(self._filename) + entities = get_entities_in_file(self._filename, self._should_delete) return entities.all() def get_command(self, ctx, exe_entity): @@ -754,7 +771,8 @@ def get_command(self, ctx, exe_entity): ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module - + if self._should_delete: + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_FILE_NAME] = self._filename entity = load_naive_entity(module, exe_entity, project_root) # If this is a remote execution, which we should know at this point, then create the remote object diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index e33aeebe56..6629eb245b 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -36,6 +36,7 @@ from flytekit.remote import FlyteRemote WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py" IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -77,6 +78,16 @@ def test_copy_all_files(): assert result.exit_code == 0 +def test_remote_files(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", REMOTE_WORKFLOW_FILE, "my_wf", "--a", "1", "--b", "Hello"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") @@ -181,7 +192,7 @@ def test_union_type_with_invalid_input(): def test_get_entities_in_file(): - e = get_entities_in_file(WORKFLOW_FILE) + e = get_entities_in_file(WORKFLOW_FILE, False) assert e.workflows == ["my_wf"] assert e.tasks == ["get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"] assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"]