Skip to content

Commit

Permalink
Add support for copying all the files in source root (#1622)
Browse files Browse the repository at this point in the history
* Add support for copying all the files in source root

Signed-off-by: Kevin Su <[email protected]>

* Add tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored May 8, 2023
1 parent 7efde40 commit ca46761
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 10 deletions.
1 change: 1 addition & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CTX_PROJECT_ROOT = "project_root"
CTX_MODULE = "module"
CTX_VERBOSE = "verbose"
CTX_COPY_ALL = "copy_all"


project_option = _click.option(
Expand Down
9 changes: 9 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit import BlobType, Literal, Scalar
from flytekit.clis.sdk_in_container.constants import (
CTX_CONFIG_FILE,
CTX_COPY_ALL,
CTX_DOMAIN,
CTX_MODULE,
CTX_PROJECT,
Expand Down Expand Up @@ -512,6 +513,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
default="/root",
help="Directory inside the image where the tar file containing the code will be copied to",
),
click.Option(
param_decls=["--copy-all", "copy_all"],
required=False,
is_flag=True,
default=False,
help="Copy all files in the source root directory to the destination directory",
),
click.Option(
param_decls=["-i", "--image", "image_config"],
required=False,
Expand Down Expand Up @@ -643,6 +651,7 @@ def _run(*args, **kwargs):
destination_dir=run_level_params.get("destination_dir"),
source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT),
module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE),
copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL),
)

options = None
Expand Down
25 changes: 15 additions & 10 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,17 +799,19 @@ def register_script(
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
destination_dir: str = ".",
default_launch_plan: typing.Optional[bool] = True,
copy_all: bool = False,
default_launch_plan: bool = True,
options: typing.Optional[Options] = None,
source_path: typing.Optional[str] = None,
module_name: typing.Optional[str] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask]:
"""
Use this method to register a workflow via script mode.
:param destination_dir:
:param domain:
:param project:
:param image_config:
:param destination_dir: The destination directory where the workflow will be copied to.
:param copy_all: If true, the entire source directory will be copied over to the destination directory.
:param domain: The domain to register the workflow in.
:param project: The project to register the workflow in.
:param image_config: The image config to use for the workflow.
:param version: version for the entity to be registered as
:param entity: The workflow to be registered or the task to be registered
:param default_launch_plan: This should be true if a default launch plan should be created for the workflow
Expand All @@ -822,11 +824,14 @@ def register_script(
image_config = ImageConfig.auto_default_image()

with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)
if copy_all:
md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir)
else:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)

serialization_settings = SerializationSettings(
project=project,
Expand Down
10 changes: 10 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def test_imperative_wf():
assert result.exit_code == 0


def test_copy_all_files():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
catch_exceptions=False,
)
assert result.exit_code == 0


def test_pyflyte_run_cli():
runner = CliRunner()
parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet")
Expand Down

0 comments on commit ca46761

Please sign in to comment.