From 63ad4fc9478d4201156549a4fe5d95d6d92bb419 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 29 Sep 2022 22:13:19 +0800 Subject: [PATCH] pyflyte non-fast register (#1205) * pyflyte run non-fast register Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/register.py | 36 ++++++++++++------- .../unit/cli/pyflyte/test_register.py | 19 ++++++++++ 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 024b70edde..c0bdcd2416 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -13,6 +13,7 @@ from flytekit.tools.fast_registration import fast_package from flytekit.tools.repo import find_common_root, load_packages_and_modules from flytekit.tools.repo import register as repo_register +from flytekit.tools.script_mode import hash_file from flytekit.tools.translator import Options _register_help = """ @@ -105,6 +106,12 @@ is_flag=True, help="Enables symlink dereferencing when packaging files in fast registration", ) +@click.option( + "--non-fast", + default=False, + is_flag=True, + help="Enables to skip zipping and uploading the package", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -118,6 +125,7 @@ def register( raw_data_prefix: str, version: typing.Optional[str], deref_symlinks: bool, + non_fast: bool, package_or_module: typing.Tuple[str], ): """ @@ -138,22 +146,30 @@ def register( cli_logger.debug( f"Running pyflyte register from {os.getcwd()} " f"with images {image_config} " - f"and image destinationfolder {destination_dir} " + f"and image destination folder {destination_dir} " f"on {len(package_or_module)} package(s) {package_or_module}" ) # Create and save FlyteRemote, remote = get_and_save_remote_with_click_context(ctx, project, domain) - # Todo: add switch for non-fast - skip the zipping and uploading and no fastserializationsettings - # Create a zip file containing all the entries. detected_root = find_common_root(package_or_module) cli_logger.debug(f"Using {detected_root} as root folder for project") - zip_file = fast_package(detected_root, output, deref_symlinks) + fast_serialization_settings = None - # Upload zip file to Admin using FlyteRemote. - md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) - cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}") + # Create a zip file containing all the entries. + zip_file = fast_package(detected_root, output, deref_symlinks) + md5_bytes, _ = hash_file(pathlib.Path(zip_file)) + + if non_fast is False: + # Upload zip file to Admin using FlyteRemote. + md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) + cli_logger.debug(f"Uploaded zip {zip_file} to {native_url}") + fast_serialization_settings = FastSerializationSettings( + enabled=not non_fast, + destination_dir=destination_dir, + distribution_location=native_url, + ) # Create serialization settings # Todo: Rely on default Python interpreter for now, this will break custom Spark containers @@ -161,11 +177,7 @@ def register( project=project, domain=domain, image_config=image_config, - fast_serialization_settings=FastSerializationSettings( - enabled=True, - destination_dir=destination_dir, - distribution_location=native_url, - ), + fast_serialization_settings=fast_serialization_settings, ) options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index d078851e1b..e9661dff6a 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -63,3 +63,22 @@ def test_register_with_no_output_dir_passed(mock_client, mock_remote): result = runner.invoke(pyflyte.main, ["register", "core"]) assert "Output given as None, using a temporary directory at" in result.output shutil.rmtree("core") + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_non_fast_register(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core"]) + assert "Output given as None, using a temporary directory at" in result.output + shutil.rmtree("core")