Skip to content

Commit

Permalink
Update pyflyte defaults to use --copy behavior (#2755)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored and kumare3 committed Nov 8, 2024
1 parent cf69651 commit 5bc72c5
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 47 deletions.
32 changes: 19 additions & 13 deletions flytekit/clis/sdk_in_container/package.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import typing

import rich_click as click
Expand Down Expand Up @@ -54,17 +55,18 @@
is_flag=True,
default=False,
required=False,
help="[Will be deprecated, see --copy] This flag enables fast packaging, that allows `no container build`"
" deploys of flyte workflows and tasks. You can specify --copy all/auto instead"
help="[Deprecated, see --copy] This flag enables fast packaging, that allows `no container build`"
" deploys of flyte workflows and tasks. You should specify --copy all/auto instead"
" Note this needs additional configuration, refer to the docs.",
)
@click.option(
"--copy",
required=False,
type=click.Choice(["all", "auto", "none"], case_sensitive=False),
default=None, # this will be changed to "none" after removing fast option
default="none",
show_default=True,
callback=parse_copy,
help="[Beta] Specify whether local files should be copied and uploaded so task containers have up-to-date code"
help="Specify whether local files should be copied and uploaded so task containers have up-to-date code"
" 'all' will behave as the current 'fast' flag, copying all files, 'auto' copies only loaded Python modules",
)
@click.option(
Expand Down Expand Up @@ -128,11 +130,17 @@ def package(
object contains the WorkflowTemplate, along with the relevant tasks for that workflow.
This serialization step will set the name of the tasks to the fully qualified name of the task function.
"""
if copy is not None and fast:
raise ValueError("--fast and --copy cannot be used together. Please use --copy all instead.")
elif copy == CopyFileDetection.ALL or copy == CopyFileDetection.LOADED_MODULES:
# for those migrating, who only set --copy all/auto but don't have --fast set.
fast = True
# Ensure that the two flags are consistent
if fast:
if "--copy" in sys.argv:
raise click.BadParameter(
click.style(
"Cannot use both --fast and --copy flags together. Please move to --copy",
fg="red",
)
)
click.secho("The --fast flag is deprecated, please use --copy all instead", fg="yellow")
copy = CopyFileDetection.ALL

if os.path.exists(output) and not force:
raise click.BadParameter(
Expand All @@ -145,7 +153,7 @@ def package(
serialization_settings = SerializationSettings(
image_config=image_config,
fast_serialization_settings=FastSerializationSettings(
enabled=fast,
enabled=copy != CopyFileDetection.NO_COPY,
destination_dir=in_container_source_path,
),
python_interpreter=python_interpreter,
Expand All @@ -161,8 +169,6 @@ def package(
show_files = ctx.obj[constants.CTX_VERBOSE] > 0

fast_options = FastPackageOptions([], copy_style=copy, show_files=show_files)
serialize_and_package(
pkgs, serialization_settings, source, output, fast, deref_symlinks, fast_options=fast_options
)
serialize_and_package(pkgs, serialization_settings, source, output, deref_symlinks, fast_options=fast_options)
except NoSerializableEntitiesError:
click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow")
33 changes: 19 additions & 14 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import typing

import rich_click as click
Expand Down Expand Up @@ -98,15 +99,16 @@
"--non-fast",
default=False,
is_flag=True,
help="[Will be deprecated, see --copy] Skip zipping and uploading the package. You can specify --copy none instead",
help="[Deprecated, see --copy] Skip zipping and uploading the package. You should specify --copy none instead",
)
@click.option(
"--copy",
required=False,
type=click.Choice(["all", "auto", "none"], case_sensitive=False),
default=None, # this will be changed to "all" after removing non-fast option
default="all",
show_default=True,
callback=parse_copy,
help="[Beta] Specify how and whether to use fast register"
help="Specify how and whether to use fast register"
" 'all' is the current behavior copying all files from root, 'auto' copies only loaded Python modules"
" 'none' means no files are copied, i.e. don't use fast register",
)
Expand Down Expand Up @@ -164,14 +166,21 @@ def register(
"""
see help
"""
if copy is not None and non_fast:
raise ValueError("--non-fast and --copy cannot be used together. Use --copy none instead.")
# Set the relevant copy option if non_fast is set, this enables the individual file listing behavior
# that the copy flag uses.
if non_fast:
click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow")
if "--copy" in sys.argv:
raise click.BadParameter(
click.style(
"Cannot use both --non-fast and --copy flags together. Please move to --copy.",
fg="red",
)
)
copy = CopyFileDetection.NO_COPY
if copy == CopyFileDetection.NO_COPY and not version:
raise ValueError("Version is a required parameter in case --copy none is specified.")

# Handle the new case where the copy flag is used instead of non-fast
if copy == CopyFileDetection.NO_COPY:
non_fast = True
# Set this to None because downstream logic currently detects None to mean old logic.
copy = None
show_files = ctx.obj[constants.CTX_VERBOSE] > 0

pkgs = ctx.obj[constants.CTX_PACKAGES]
Expand All @@ -180,9 +189,6 @@ def register(
if pkgs:
raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command")

if non_fast and not version:
raise ValueError("Version is a required parameter in case --non-fast/--copy none is specified.")

if len(package_or_module) == 0:
display_help_with_error(
ctx,
Expand Down Expand Up @@ -215,7 +221,6 @@ def register(
raw_data_prefix,
version,
deref_symlinks,
fast=not non_fast,
copy_style=copy,
package_or_module=package_or_module,
remote=remote,
Expand Down
24 changes: 18 additions & 6 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,20 @@ class RunLevelParams(PyFlyteParams):
is_flag=True,
default=False,
show_default=True,
help="[Will be deprecated, see --copy] Copy all files in the source root directory to"
help="[Deprecated, see --copy] Copy all files in the source root directory to"
" the destination directory. You can specify --copy all instead",
)
)
copy: typing.Optional[CopyFileDetection] = make_click_option_field(
click.Option(
param_decls=["--copy"],
required=False,
default=None, # this will change to "auto" after removing copy_all option
default="auto",
type=click.Choice(["all", "auto"], case_sensitive=False),
show_default=True,
callback=parse_copy,
help="[Beta] Specifies how to detect which files to copy into image."
" 'all' will behave as the current copy-all flag, 'auto' copies only loaded Python modules",
help="Specifies how to detect which files to copy into image."
" 'all' will behave as the deprecated copy-all flag, 'auto' copies only loaded Python modules",
)
)
image_config: ImageConfig = make_click_option_field(
Expand Down Expand Up @@ -650,14 +650,27 @@ def _run(*args, **kwargs):

image_config = run_level_params.image_config
image_config = patch_image_config(config_file, image_config)
if run_level_params.copy_all:
click.secho(
"The --copy_all flag is now deprecated. Please use --copy all instead.",
fg="yellow",
)
if "--copy" in sys.argv:
raise click.BadParameter(
click.style(
"Cannot use both --copy-all and --copy flags together. Please move to --copy.",
fg="red",
)
)

with context_manager.FlyteContextManager.with_context(remote.context.new_builder()):
show_files = run_level_params.verbose > 0
fast_package_options = FastPackageOptions(
[],
copy_style=run_level_params.copy,
copy_style=CopyFileDetection.ALL if run_level_params.copy_all else run_level_params.copy,
show_files=show_files,
)

remote_entity = remote.register_script(
entity,
project=run_level_params.project,
Expand All @@ -666,7 +679,6 @@ def _run(*args, **kwargs):
destination_dir=run_level_params.destination_dir,
source_path=run_level_params.computed_params.project_root,
module_name=run_level_params.computed_params.module,
copy_all=run_level_params.copy_all,
fast_package_options=fast_package_options,
)

Expand Down
16 changes: 14 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from base64 import b64encode
from collections import OrderedDict
from dataclasses import asdict, dataclass
from dataclasses import replace as dc_replace
from datetime import datetime, timedelta
from typing import Dict

Expand All @@ -34,6 +35,7 @@
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.constants import CopyFileDetection
from flytekit.core import constants, utils
from flytekit.core.artifact import Artifact
from flytekit.core.base_task import PythonTask
Expand Down Expand Up @@ -1048,7 +1050,7 @@ def register_script(
"""
Use this method to register a workflow via script mode.
:param destination_dir: The destination directory where the workflow will be copied to.
:param copy_all: If true, the entire source directory will be copied over to the destination directory.
:param copy_all: [deprecated] Please use the copy_style field in fast_package_options instead.
:param domain: The domain to register the workflow in.
:param project: The project to register the workflow in.
:param image_config: The image config to use for the workflow.
Expand All @@ -1062,11 +1064,21 @@ def register_script(
:param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False.
:return:
"""
if copy_all:
logger.info(
"The copy_all flag to FlyteRemote.register_script is deprecated. Please use"
" the copy_style field in fast_package_options instead."
)
if not fast_package_options:
fast_package_options = FastPackageOptions([], copy_style=CopyFileDetection.ALL)
else:
fast_package_options = dc_replace(fast_package_options, copy_style=CopyFileDetection.ALL)

if image_config is None:
image_config = ImageConfig.auto_default_image()

with tempfile.TemporaryDirectory() as tmp_dir:
if copy_all or (fast_package_options and fast_package_options.copy_style):
if fast_package_options and fast_package_options.copy_style != CopyFileDetection.NO_COPY:
md5_bytes, upload_native_url = self.fast_package(
pathlib.Path(source_path), False, tmp_dir, fast_package_options
)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/fast_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def fast_package(

compress_tarball(tar_path, archive_fname)

# Original tar command - This condition to be removed in the future.
# Original tar command - This condition to be removed in the future after serialize is removed.
else:
# Compute where the archive should be written
archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}"
Expand Down
17 changes: 6 additions & 11 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

import click

import flytekit.configuration
import flytekit.constants
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.constants import CopyFileDetection
from flytekit.core.context_manager import FlyteContextManager
from flytekit.loggers import logger
from flytekit.models import launch_plan, task
Expand Down Expand Up @@ -90,7 +89,6 @@ def package(
serializable_entities: typing.List[FlyteControlPlaneEntity],
source: str = ".",
output: str = "./flyte-package.tgz",
fast: bool = False,
deref_symlinks: bool = False,
fast_options: typing.Optional[fast_registration.FastPackageOptions] = None,
):
Expand All @@ -99,7 +97,6 @@ def package(
:param serializable_entities: Entities that can be serialized
:param source: source folder
:param output: output package name with suffix
:param fast: fast enabled implies source code is bundled
:param deref_symlinks: if enabled then symlinks are dereferenced during packaging
:param fast_options:
Expand All @@ -114,7 +111,7 @@ def package(
persist_registrable_entities(serializable_entities, output_tmpdir)

# If Fast serialization is enabled, then an archive is also created and packaged
if fast:
if fast_options and fast_options.copy_style != CopyFileDetection.NO_COPY:
# If output exists and is a path within source, delete it so as to not re-bundle it again.
if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output):
click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow")
Expand All @@ -135,7 +132,6 @@ def serialize_and_package(
settings: SerializationSettings,
source: str = ".",
output: str = "./flyte-package.tgz",
fast: bool = False,
deref_symlinks: bool = False,
options: typing.Optional[Options] = None,
fast_options: typing.Optional[fast_registration.FastPackageOptions] = None,
Expand All @@ -147,7 +143,7 @@ def serialize_and_package(
"""
serialize_load_only(pkgs, settings, source)
serializable_entities = serialize_get_control_plane_entities(settings, source, options=options)
package(serializable_entities, source, output, fast, deref_symlinks, fast_options)
package(serializable_entities, source, output, deref_symlinks, fast_options)


def find_common_root(
Expand Down Expand Up @@ -234,10 +230,9 @@ def register(
raw_data_prefix: str,
version: typing.Optional[str],
deref_symlinks: bool,
fast: bool,
package_or_module: typing.Tuple[str],
remote: FlyteRemote,
copy_style: typing.Optional[flytekit.constants.CopyFileDetection],
copy_style: CopyFileDetection,
env: typing.Optional[typing.Dict[str, str]],
dry_run: bool = False,
activate_launchplans: bool = False,
Expand All @@ -262,7 +257,7 @@ def register(
env=env,
)

if not version and not fast:
if not version and copy_style == CopyFileDetection.NO_COPY:
click.secho("Version is required.", fg="red")
return

Expand All @@ -281,7 +276,7 @@ def register(
serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root))

# Fast registration is handled after module loading
if fast:
if copy_style != CopyFileDetection.NO_COPY:
md5_bytes, native_url = remote.fast_package(
detected_root,
deref_symlinks,
Expand Down

0 comments on commit 5bc72c5

Please sign in to comment.