Skip to content

Commit

Permalink
Housekeeping for pyflyte package and register commands (#1084)
Browse files Browse the repository at this point in the history
Signed-off-by: Madhur Tandon <[email protected]>
  • Loading branch information
madhur-tandon authored Jul 5, 2022
1 parent a968a5a commit 3137609
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 10 deletions.
8 changes: 8 additions & 0 deletions flytekit/clis/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
from typing import Tuple, Union

import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan
from flyteidl.admin.task_pb2 import TaskSpec
from flyteidl.admin.workflow_pb2 import WorkflowSpec
Expand Down Expand Up @@ -125,3 +127,9 @@ def hydrate_registration_parameters(
del entity.sub_workflows[:]
entity.sub_workflows.extend(refreshed_sub_workflows)
return identifier, entity


def display_help_with_error(ctx: click.Context, message: str):
click.echo(f"{ctx.get_help()}\n")
click.secho(message, fg="red")
sys.exit(1)
5 changes: 2 additions & 3 deletions flytekit/clis/sdk_in_container/package.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import sys

import click

from flytekit.clis.helpers import display_help_with_error
from flytekit.clis.sdk_in_container import constants
from flytekit.configuration import (
DEFAULT_RUNTIME_PYTHON_INTERPRETER,
Expand Down Expand Up @@ -100,8 +100,7 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_

pkgs = ctx.obj[constants.CTX_PACKAGES]
if not pkgs:
click.secho("No packages to scan for flyte entities. Aborting!", fg="red")
sys.exit(-1)
display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!")

try:
serialize_and_package(pkgs, serialization_settings, source, output, fast)
Expand Down
13 changes: 9 additions & 4 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import pathlib
import sys
import typing

import click

from flytekit.clis.helpers import display_help_with_error
from flytekit.clis.sdk_in_container import constants
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
Expand Down Expand Up @@ -67,7 +67,7 @@
"--output",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=".",
default=None,
help="Directory to write the output zip file containing the protobuf definitions",
)
@click.option(
Expand Down Expand Up @@ -122,6 +122,12 @@ def register(
if pkgs:
raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")

if len(package_or_module) == 0:
display_help_with_error(
ctx,
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
)

cli_logger.debug(
f"Running pyflyte register from {os.getcwd()} "
f"with images {image_config} "
Expand Down Expand Up @@ -162,8 +168,7 @@ def register(
serialization_settings, detected_root, list(package_or_module), options
)
if len(registerable_entities) == 0:
click.secho("No Flyte entities were detected. Aborting!", fg="red")
sys.exit(1)
display_help_with_error(ctx, "No Flyte entities were detected. Aborting!")
cli_logger.info(f"Found and serialized {len(registerable_entities)} entities")

if not version:
Expand Down
9 changes: 7 additions & 2 deletions flytekit/tools/fast_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import tempfile
from typing import Optional

import click

from flytekit.core.context_manager import FlyteContextManager
from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore
from flytekit.tools.script_mode import tar_strip_file_attributes
Expand All @@ -31,8 +33,11 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike:
digest = compute_digest(source, ignore.is_ignored)
archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}"

if output_dir:
archive_fname = os.path.join(output_dir, archive_fname)
if output_dir is None:
output_dir = tempfile.mkdtemp()
click.secho(f"Output given as {None}, using a temporary directory at {output_dir} instead", fg="yellow")

archive_fname = os.path.join(output_dir, archive_fname)

with tempfile.TemporaryDirectory() as tmp_dir:
tar_path = os.path.join(tmp_dir, "tmp.tar")
Expand Down
4 changes: 4 additions & 0 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def package(

# If Fast serialization is enabled, then an archive is also created and packaged
if fast:
# If output exists and is a path within source, delete it so as to not re-bundle it again.
if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output):
click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow")
os.remove(output)
archive_fname = fast_registration.fast_package(source, output_tmpdir)
click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True)

Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/cli/pyflyte/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _fake_module_load(names):
yield simple


@pytest.yield_fixture(
@pytest.fixture(
scope="function",
params=[
os.path.join(
Expand Down
50 changes: 50 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_package.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import shutil

import pytest
from click.testing import CliRunner
from flyteidl.admin.launch_plan_pb2 import LaunchPlan
Expand All @@ -11,6 +14,22 @@
from flytekit.core import context_manager
from flytekit.exceptions.user import FlyteValidationException

sample_file_contents = """
from flytekit import task, workflow
@task(cache=True, cache_version="1", retries=3)
def sum(x: int, y: int) -> int:
return x + y
@task(cache=True, cache_version="1", retries=3)
def square(z: int) -> int:
return z*z
@workflow
def my_workflow(x: int, y: int) -> int:
return sum(x=square(z=x), y=square(z=y))
"""


@flytekit.task
def foo():
Expand Down Expand Up @@ -44,6 +63,29 @@ def test_get_registrable_entities():
assert False, f"found unknown entity {type(e)}"


def test_package_with_fast_registration():
runner = CliRunner()
with runner.isolated_filesystem():
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"])
assert result.exit_code == 0
assert "Successfully serialized" in result.output
assert "Successfully packaged" in result.output
result = runner.invoke(pyflyte.main, ["--pkgs", "core", "package", "--image", "core:v1", "--fast"])
assert result.exit_code == 2
assert "flyte-package.tgz already exists, specify -f to override" in result.output
result = runner.invoke(
pyflyte.main,
["--pkgs", "core", "package", "--image", "core:v1", "--fast", "--force"],
)
assert result.exit_code == 0
assert "deleting and re-creating it" in result.output
shutil.rmtree("core")


def test_duplicate_registrable_entities():
@flytekit.task
def t_1():
Expand Down Expand Up @@ -114,3 +156,11 @@ def test_package():
def test_pkgs():
pp = pyflyte.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"])
assert pp == ["a.b", "a.c", "b.a", "cc.a"]


def test_package_with_no_pkgs():
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(pyflyte.main, ["package"])
assert result.exit_code == 1
assert "No packages to scan for flyte entities. Aborting!" in result.output
54 changes: 54 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
import os
import shutil
import subprocess

import mock
from click.testing import CliRunner

from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
from flytekit.remote.remote import FlyteRemote

sample_file_contents = """
from flytekit import task, workflow
@task(cache=True, cache_version="1", retries=3)
def sum(x: int, y: int) -> int:
return x + y
@task(cache=True, cache_version="1", retries=3)
def square(z: int) -> int:
return z*z
@workflow
def my_workflow(x: int, y: int) -> int:
return sum(x=square(z=x), y=square(z=y))
"""


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote")
Expand All @@ -9,3 +33,33 @@ def test_saving_remote(mock_remote):
mock_context.obj = {}
get_and_save_remote_with_click_context(mock_context, "p", "d")
assert mock_context.obj["flyte_remote"] is not None


def test_register_with_no_package_or_module_argument():
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(pyflyte.main, ["register"])
assert result.exit_code == 1
assert (
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed"
in result.output
)


@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_with_no_output_dir_passed(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core", exist_ok=True)
with open(os.path.join("core", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["register", "core"])
assert "Output given as None, using a temporary directory at" in result.output
shutil.rmtree("core")

0 comments on commit 3137609

Please sign in to comment.