diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 8f131aace0..67391abb4d 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -10,6 +10,7 @@ CTX_PROJECT_ROOT = "project_root" CTX_MODULE = "module" CTX_VERBOSE = "verbose" +CTX_COPY_ALL = "copy_all" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 6689b44c43..9c7228ec46 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -18,6 +18,7 @@ from flytekit import BlobType, Literal, Scalar from flytekit.clis.sdk_in_container.constants import ( CTX_CONFIG_FILE, + CTX_COPY_ALL, CTX_DOMAIN, CTX_MODULE, CTX_PROJECT, @@ -512,6 +513,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: default="/root", help="Directory inside the image where the tar file containing the code will be copied to", ), + click.Option( + param_decls=["--copy-all", "copy_all"], + required=False, + is_flag=True, + default=False, + help="Copy all files in the source root directory to the destination directory", + ), click.Option( param_decls=["-i", "--image", "image_config"], required=False, @@ -643,6 +651,7 @@ def _run(*args, **kwargs): destination_dir=run_level_params.get("destination_dir"), source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), + copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL), ) options = None diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 8716504dc1..91189ede74 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -799,17 +799,19 @@ def register_script( project: typing.Optional[str] = None, domain: typing.Optional[str] = None, destination_dir: str = ".", - default_launch_plan: typing.Optional[bool] = True, + copy_all: bool = False, + default_launch_plan: bool = True, options: typing.Optional[Options] = None, source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. - :param destination_dir: - :param domain: - :param project: - :param image_config: + :param destination_dir: The destination directory where the workflow will be copied to. + :param copy_all: If true, the entire source directory will be copied over to the destination directory. + :param domain: The domain to register the workflow in. + :param project: The project to register the workflow in. + :param image_config: The image config to use for the workflow. :param version: version for the entity to be registered as :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow @@ -822,11 +824,14 @@ def register_script( image_config = ImageConfig.auto_default_image() with tempfile.TemporaryDirectory() as tmp_dir: - archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) - compress_scripts(source_path, str(archive_fname), module_name) - md5_bytes, upload_native_url = self.upload_file( - archive_fname, project or self.default_project, domain or self.default_domain - ) + if copy_all: + md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir) + else: + archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) + compress_scripts(source_path, str(archive_fname), module_name) + md5_bytes, upload_native_url = self.upload_file( + archive_fname, project or self.default_project, domain or self.default_domain + ) serialization_settings = SerializationSettings( project=project, diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index d839ab474b..735df4af2c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -67,6 +67,16 @@ def test_imperative_wf(): assert result.exit_code == 0 +def test_copy_all_files(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet")