diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index 73274e972e..d922f5e3c1 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -1,5 +1,7 @@ +import sys from typing import Tuple, Union +import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan from flyteidl.admin.task_pb2 import TaskSpec from flyteidl.admin.workflow_pb2 import WorkflowSpec @@ -125,3 +127,9 @@ def hydrate_registration_parameters( del entity.sub_workflows[:] entity.sub_workflows.extend(refreshed_sub_workflows) return identifier, entity + + +def display_help_with_error(ctx: click.Context, message: str): + click.echo(f"{ctx.get_help()}\n") + click.secho(message, fg="red") + sys.exit(1) diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index 71efeab576..2a884e29da 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,8 +1,8 @@ import os -import sys import click +from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants from flytekit.configuration import ( DEFAULT_RUNTIME_PYTHON_INTERPRETER, @@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ pkgs = ctx.obj[constants.CTX_PACKAGES] if not pkgs: - click.secho("No packages to scan for flyte entities. Aborting!", fg="red") - sys.exit(-1) + display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!") try: serialize_and_package(pkgs, serialization_settings, source, output, fast) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 8a3be2dfa6..03e00d7896 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -1,10 +1,10 @@ import os import pathlib -import sys import typing import click +from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings @@ -67,7 +67,7 @@ "--output", required=False, type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True), - default=".", + default=None, help="Directory to write the output zip file containing the protobuf definitions", ) @click.option( @@ -122,6 +122,12 @@ def register( if pkgs: raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") + if len(package_or_module) == 0: + display_help_with_error( + ctx, + "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", + ) + cli_logger.debug( f"Running pyflyte register from {os.getcwd()} " f"with images {image_config} " @@ -162,8 +168,7 @@ def register( serialization_settings, detected_root, list(package_or_module), options ) if len(registerable_entities) == 0: - click.secho("No Flyte entities were detected. Aborting!", fg="red") - sys.exit(1) + display_help_with_error(ctx, "No Flyte entities were detected. Aborting!") cli_logger.info(f"Found and serialized {len(registerable_entities)} entities") if not version: diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index f38c3803a1..c4ac31a01a 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -9,6 +9,8 @@ import tempfile from typing import Optional +import click + from flytekit.core.context_manager import FlyteContextManager from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes @@ -31,8 +33,11 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike: digest = compute_digest(source, ignore.is_ignored) archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" - if output_dir: - archive_fname = os.path.join(output_dir, archive_fname) + if output_dir is None: + output_dir = tempfile.mkdtemp() + click.secho(f"Output given as {None}, using a temporary directory at {output_dir} instead", fg="yellow") + + archive_fname = os.path.join(output_dir, archive_fname) with tempfile.TemporaryDirectory() as tmp_dir: tar_path = os.path.join(tmp_dir, "tmp.tar") diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index ad07ea2111..167c772184 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -91,6 +91,10 @@ def package( # If Fast serialization is enabled, then an archive is also created and packaged if fast: + # If output exists and is a path within source, delete it so as to not re-bundle it again. + if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output): + click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow") + os.remove(output) archive_fname = fast_registration.fast_package(source, output_tmpdir) click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True) diff --git a/tests/flytekit/unit/cli/pyflyte/conftest.py b/tests/flytekit/unit/cli/pyflyte/conftest.py index 723fb4878b..6ce51bd4e1 100644 --- a/tests/flytekit/unit/cli/pyflyte/conftest.py +++ b/tests/flytekit/unit/cli/pyflyte/conftest.py @@ -18,7 +18,7 @@ def _fake_module_load(names): yield simple -@pytest.yield_fixture( +@pytest.fixture( scope="function", params=[ os.path.join( diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index d31b760608..364b6b14d9 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -1,3 +1,6 @@ +import os +import shutil + import pytest from click.testing import CliRunner from flyteidl.admin.launch_plan_pb2 import LaunchPlan @@ -11,6 +14,22 @@ from flytekit.core import context_manager from flytekit.exceptions.user import FlyteValidationException +sample_file_contents = """ +from flytekit import task, workflow + +@task(cache=True, cache_version="1", retries=3) +def sum(x: int, y: int) -> int: + return x + y + +@task(cache=True, cache_version="1", retries=3) +def square(z: int) -> int: + return z*z + +@workflow +def my_workflow(x: int, y: int) -> int: + return sum(x=square(z=x), y=square(z=y)) +""" + @flytekit.task def foo(): @@ -44,6 +63,29 @@ def test_get_registrable_entities(): assert False, f"found unknown entity {type(e)}" +def test_package_with_fast_registration(): + runner = CliRunner() + with runner.isolated_filesystem(): + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"]) + assert result.exit_code == 0 + assert "Successfully serialized" in result.output + assert "Successfully packaged" in result.output + result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"]) + assert result.exit_code == 2 + assert "flyte-package.tgz already exists, specify -f to override" in result.output + result = runner.invoke( + pyflyte.main, + ["--pkgs", "core", "package", "--image", "core:v1", "--fast", "--force"], + ) + assert result.exit_code == 0 + assert "deleting and re-creating it" in result.output + shutil.rmtree("core") + + def test_duplicate_registrable_entities(): @flytekit.task def t_1(): @@ -114,3 +156,11 @@ def test_package(): def test_pkgs(): pp = pyflyte.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"]) assert pp == ["a.b", "a.c", "b.a", "cc.a"] + + +def test_package_with_no_pkgs(): + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke(pyflyte.main, ["package"]) + assert result.exit_code == 1 + assert "No packages to scan for flyte entities. Aborting!" in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 52bf13d0a6..d078851e1b 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -1,6 +1,30 @@ +import os +import shutil +import subprocess + import mock +from click.testing import CliRunner +from flytekit.clients.friendly import SynchronousFlyteClient +from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.remote.remote import FlyteRemote + +sample_file_contents = """ +from flytekit import task, workflow + +@task(cache=True, cache_version="1", retries=3) +def sum(x: int, y: int) -> int: + return x + y + +@task(cache=True, cache_version="1", retries=3) +def square(z: int) -> int: + return z*z + +@workflow +def my_workflow(x: int, y: int) -> int: + return sum(x=square(z=x), y=square(z=y)) +""" @mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") @@ -9,3 +33,33 @@ def test_saving_remote(mock_remote): mock_context.obj = {} get_and_save_remote_with_click_context(mock_context, "p", "d") assert mock_context.obj["flyte_remote"] is not None + + +def test_register_with_no_package_or_module_argument(): + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke(pyflyte.main, ["register"]) + assert result.exit_code == 1 + assert ( + "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed" + in result.output + ) + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_with_no_output_dir_passed(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core"]) + assert "Output given as None, using a temporary directory at" in result.output + shutil.rmtree("core")