diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 5ec4b9b262..6ed5072c36 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -7,6 +7,7 @@ from flytekit.configuration import ImageConfig from flytekit.configuration.plugin import get_plugin from flytekit.remote.remote import FlyteRemote +from flytekit.tools.fast_registration import CopyFileDetection FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" @@ -61,3 +62,17 @@ def patch_image_config(config_file: Optional[str], image_config: ImageConfig) -> if addl.name not in additional_image_names: new_additional_images.append(addl) return replace(image_config, default_image=new_default, images=new_additional_images) + + +def parse_copy(ctx, param, value) -> Optional[CopyFileDetection]: + """Helper function to parse cmd line args into enum""" + if value == "auto": + copy_style = CopyFileDetection.LOADED_MODULES + elif value == "all": + copy_style = CopyFileDetection.ALL + elif value == "none": + copy_style = CopyFileDetection.NO_COPY + else: + copy_style = None + + return copy_style diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index c61b02a16d..6decbc32e1 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,9 +1,11 @@ import os +import typing import rich_click as click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants +from flytekit.clis.sdk_in_container.helpers import parse_copy from flytekit.configuration import ( DEFAULT_RUNTIME_PYTHON_INTERPRETER, FastSerializationSettings, @@ -11,6 +13,7 @@ SerializationSettings, ) from flytekit.interaction.click_types import key_value_callback +from flytekit.tools.fast_registration import CopyFileDetection, FastPackageOptions from flytekit.tools.repo import NoSerializableEntitiesError, serialize_and_package @@ -50,8 +53,18 @@ is_flag=True, default=False, required=False, - help="This flag enables fast packaging, that allows `no container build` deploys of flyte workflows and tasks. " - "Note this needs additional configuration, refer to the docs.", + help="[Will be deprecated, see --copy] This flag enables fast packaging, that allows `no container build`" + " deploys of flyte workflows and tasks. You can specify --copy all/auto instead" + " Note this needs additional configuration, refer to the docs.", +) +@click.option( + "--copy", + required=False, + type=click.Choice(["all", "auto", "none"], case_sensitive=False), + default=None, # this will be changed to "none" after removing fast option + callback=parse_copy, + help="[Beta] Specify whether local files should be copied and uploaded so task containers have up-to-date code" + " 'all' will behave as the current 'fast' flag, copying all files, 'auto' copies only loaded Python modules", ) @click.option( "-f", @@ -100,6 +113,7 @@ def package( source, output, force, + copy: typing.Optional[CopyFileDetection], fast, in_container_source_path, python_interpreter, @@ -113,6 +127,12 @@ def package( object contains the WorkflowTemplate, along with the relevant tasks for that workflow. This serialization step will set the name of the tasks to the fully qualified name of the task function. """ + if copy is not None and fast: + raise ValueError("--fast and --copy cannot be used together. Please use --copy all instead.") + elif copy == CopyFileDetection.ALL or copy == CopyFileDetection.LOADED_MODULES: + # for those migrating, who only set --copy all/auto but don't have --fast set. + fast = True + if os.path.exists(output) and not force: raise click.BadParameter( click.style( @@ -136,6 +156,12 @@ def package( display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!") try: - serialize_and_package(pkgs, serialization_settings, source, output, fast, deref_symlinks) + # verbosity greater than 0 means to print the files + show_files = ctx.obj[constants.CTX_VERBOSE] > 0 + + fast_options = FastPackageOptions([], copy_style=copy, show_files=show_files) + serialize_and_package( + pkgs, serialization_settings, source, output, fast, deref_symlinks, fast_options=fast_options + ) except NoSerializableEntitiesError: click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow") diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index e578f06a17..dfbbd23d00 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -5,13 +5,18 @@ from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants -from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, patch_image_config +from flytekit.clis.sdk_in_container.helpers import ( + get_and_save_remote_with_click_context, + parse_copy, + patch_image_config, +) from flytekit.clis.sdk_in_container.utils import domain_option_dec, project_option_dec from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.interaction.click_types import key_value_callback from flytekit.loggers import logger from flytekit.tools import repo +from flytekit.tools.fast_registration import CopyFileDetection _register_help = """ This command is similar to ``package`` but instead of producing a zip file, all your Flyte entities are compiled, @@ -93,7 +98,17 @@ "--non-fast", default=False, is_flag=True, - help="Skip zipping and uploading the package", + help="[Will be deprecated, see --copy] Skip zipping and uploading the package. You can specify --copy none instead", +) +@click.option( + "--copy", + required=False, + type=click.Choice(["all", "auto", "none"], case_sensitive=False), + default=None, # this will be changed to "all" after removing non-fast option + callback=parse_copy, + help="[Beta] Specify how and whether to use fast register" + " 'all' is the current behavior copying all files from root, 'auto' copies only loaded Python modules" + " 'none' means no files are copied, i.e. don't use fast register", ) @click.option( "--dry-run", @@ -139,6 +154,7 @@ def register( version: typing.Optional[str], deref_symlinks: bool, non_fast: bool, + copy: typing.Optional[CopyFileDetection], package_or_module: typing.Tuple[str], dry_run: bool, activate_launchplans: bool, @@ -148,6 +164,16 @@ def register( """ see help """ + if copy is not None and non_fast: + raise ValueError("--non-fast and --copy cannot be used together. Use --copy none instead.") + + # Handle the new case where the copy flag is used instead of non-fast + if copy == CopyFileDetection.NO_COPY: + non_fast = True + # Set this to None because downstream logic currently detects None to mean old logic. + copy = None + show_files = ctx.obj[constants.CTX_VERBOSE] > 0 + pkgs = ctx.obj[constants.CTX_PACKAGES] if not pkgs: logger.debug("No pkgs") @@ -155,7 +181,7 @@ def register( raise ValueError("Unimplemented, just specify pkgs like folder/files as args at the end of the command") if non_fast and not version: - raise ValueError("Version is a required parameter in case --non-fast is specified.") + raise ValueError("Version is a required parameter in case --non-fast/--copy none is specified.") if len(package_or_module) == 0: display_help_with_error( @@ -190,10 +216,12 @@ def register( version, deref_symlinks, fast=not non_fast, + copy_style=copy, package_or_module=package_or_module, remote=remote, env=env, dry_run=dry_run, activate_launchplans=activate_launchplans, skip_errors=skip_errors, + show_files=show_files, ) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 5e99c8740b..1ab04452ee 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -19,7 +19,10 @@ from typing_extensions import get_origin from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal -from flytekit.clis.sdk_in_container.helpers import patch_image_config +from flytekit.clis.sdk_in_container.helpers import ( + parse_copy, + patch_image_config, +) from flytekit.clis.sdk_in_container.utils import ( PyFlyteParams, domain_option, @@ -63,6 +66,7 @@ ) from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader +from flytekit.tools.fast_registration import CopyFileDetection, FastPackageOptions from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules from flytekit.tools.translator import Options @@ -104,7 +108,20 @@ class RunLevelParams(PyFlyteParams): is_flag=True, default=False, show_default=True, - help="Copy all files in the source root directory to the destination directory", + help="[Will be deprecated, see --copy] Copy all files in the source root directory to" + " the destination directory. You can specify --copy all instead", + ) + ) + copy: typing.Optional[CopyFileDetection] = make_click_option_field( + click.Option( + param_decls=["--copy"], + required=False, + default=None, # this will change to "auto" after removing copy_all option + type=click.Choice(["all", "auto"], case_sensitive=False), + show_default=True, + callback=parse_copy, + help="[Beta] Specifies how to detect which files to copy into image." + " 'all' will behave as the current copy-all flag, 'auto' copies only loaded Python modules", ) ) image_config: ImageConfig = make_click_option_field( @@ -626,6 +643,12 @@ def _run(*args, **kwargs): image_config = patch_image_config(config_file, image_config) with context_manager.FlyteContextManager.with_context(remote.context.new_builder()): + show_files = run_level_params.verbose > 0 + fast_package_options = FastPackageOptions( + [], + copy_style=run_level_params.copy, + show_files=show_files, + ) remote_entity = remote.register_script( entity, project=run_level_params.project, @@ -635,6 +658,7 @@ def _run(*args, **kwargs): source_path=run_level_params.computed_params.project_root, module_name=run_level_params.computed_params.module, copy_all=run_level_params.copy_all, + fast_package_options=fast_package_options, ) run_remote( diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index f28f3ca3e2..2cb8103647 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1062,7 +1062,7 @@ def register_script( image_config = ImageConfig.auto_default_image() with tempfile.TemporaryDirectory() as tmp_dir: - if copy_all: + if copy_all or (fast_package_options and fast_package_options.copy_style): md5_bytes, upload_native_url = self.fast_package( pathlib.Path(source_path), False, tmp_dir, fast_package_options ) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index d17bbe8994..a65d24a740 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -3,27 +3,43 @@ import gzip import hashlib import os +import pathlib import posixpath import subprocess +import sys import tarfile import tempfile import typing from dataclasses import dataclass +from enum import Enum from typing import Optional import click +from rich import print as rich_print +from rich.tree import Tree from flytekit.core.context_manager import FlyteContextManager from flytekit.core.utils import timeit from flytekit.exceptions.user import FlyteDataNotFoundException from flytekit.loggers import logger from flytekit.tools.ignore import DockerIgnore, FlyteIgnore, GitIgnore, Ignore, IgnoreGroup, StandardIgnore -from flytekit.tools.script_mode import tar_strip_file_attributes +from flytekit.tools.script_mode import _filehash_update, _pathhash_update, ls_files, tar_strip_file_attributes FAST_PREFIX = "fast" FAST_FILEENDING = ".tar.gz" +class CopyFileDetection(Enum): + LOADED_MODULES = 1 + ALL = 2 + # This option's meaning will change in the future. In the future this will mean that no files should be copied + # (i.e. no fast registration is used). For now, both this value and setting this Enum to Python None are both + # valid to distinguish between users explicitly setting --copy none and not setting the flag. + # Currently, this is only used for register, not for package or run because run doesn't have a no-fast-register + # option and package is by default non-fast. + NO_COPY = 3 + + @dataclass(frozen=True) class FastPackageOptions: """ @@ -32,6 +48,31 @@ class FastPackageOptions: ignores: list[Ignore] keep_default_ignores: bool = True + copy_style: Optional[CopyFileDetection] = None + show_files: bool = False + + +def print_ls_tree(source: os.PathLike, ls: typing.List[str]): + click.secho("Files to be copied for fast registration...", fg="bright_blue") + + tree_root = Tree( + f":open_file_folder: [link file://{source}]{source} (detected source root)", + guide_style="bold bright_blue", + ) + trees = {pathlib.Path(source): tree_root} + + for f in ls: + fpp = pathlib.Path(f) + if fpp.parent not in trees: + # add trees for all intermediate folders + current = tree_root + current_path = pathlib.Path(source) + for subdir in fpp.parent.relative_to(source).parts: + current = current.add(f"{subdir}", guide_style="bold bright_blue") + current_path = current_path / subdir + trees[current_path] = current + trees[fpp.parent].add(f"{fpp.name}", guide_style="bold bright_blue") + rich_print(tree_root) def fast_package( @@ -46,6 +87,7 @@ def fast_package( :param os.PathLike source: :param os.PathLike output_dir: :param bool deref_symlinks: Enables dereferencing symlinks when packaging directory + :param options: The CopyFileDetection option set to None :return os.PathLike: """ default_ignores = [GitIgnore, DockerIgnore, StandardIgnore, FlyteIgnore] @@ -58,28 +100,73 @@ def fast_package( ignores = default_ignores ignore = IgnoreGroup(source, ignores) + # Remove this after original tar command is removed. digest = compute_digest(source, ignore.is_ignored) - archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" - - if output_dir is None: - output_dir = tempfile.mkdtemp() - click.secho(f"No output path provided, using a temporary directory at {output_dir} instead", fg="yellow") - - archive_fname = os.path.join(output_dir, archive_fname) - - with tempfile.TemporaryDirectory() as tmp_dir: - tar_path = os.path.join(tmp_dir, "tmp.tar") - with tarfile.open(tar_path, "w", dereference=deref_symlinks) as tar: - files: typing.List[str] = os.listdir(source) - for ws_file in files: - tar.add( - os.path.join(source, ws_file), - arcname=ws_file, - filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x)), - ) - with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped: - with open(tar_path, "rb") as tar_file: - gzipped.write(tar_file.read()) + + # This function is temporarily split into two, to support the creation of the tar file in both the old way, + # copying the underlying items in the source dir by doing a listdir, and the new way, relying on a list of files. + if options and ( + options.copy_style == CopyFileDetection.LOADED_MODULES or options.copy_style == CopyFileDetection.ALL + ): + if options.copy_style == CopyFileDetection.LOADED_MODULES: + # This is the 'auto' semantic by default used for pyflyte run, it only copies loaded .py files. + sys_modules = list(sys.modules.values()) + ls, ls_digest = ls_files(str(source), sys_modules, deref_symlinks, ignore) + else: + # This triggers listing of all files, mimicking the old way of creating the tar file. + ls, ls_digest = ls_files(str(source), [], deref_symlinks, ignore) + + logger.debug(f"Hash digest: {ls_digest}", fg="green") + + if options.show_files: + print_ls_tree(source, ls) + + # Compute where the archive should be written + archive_fname = f"{FAST_PREFIX}{ls_digest}{FAST_FILEENDING}" + if output_dir is None: + output_dir = tempfile.mkdtemp() + click.secho(f"No output path provided, using a temporary directory at {output_dir} instead", fg="yellow") + archive_fname = os.path.join(output_dir, archive_fname) + + with tempfile.TemporaryDirectory() as tmp_dir: + tar_path = os.path.join(tmp_dir, "tmp.tar") + with tarfile.open(tar_path, "w", dereference=True) as tar: + for ws_file in ls: + rel_path = os.path.relpath(ws_file, start=source) + tar.add( + os.path.join(source, ws_file), + arcname=rel_path, + filter=lambda x: tar_strip_file_attributes(x), + ) + + with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped: + with open(tar_path, "rb") as tar_file: + gzipped.write(tar_file.read()) + + # Original tar command - This condition to be removed in the future. + else: + # Compute where the archive should be written + archive_fname = f"{FAST_PREFIX}{digest}{FAST_FILEENDING}" + if output_dir is None: + output_dir = tempfile.mkdtemp() + click.secho(f"No output path provided, using a temporary directory at {output_dir} instead", fg="yellow") + archive_fname = os.path.join(output_dir, archive_fname) + + with tempfile.TemporaryDirectory() as tmp_dir: + tar_path = os.path.join(tmp_dir, "tmp.tar") + with tarfile.open(tar_path, "w", dereference=deref_symlinks) as tar: + files: typing.List[str] = os.listdir(source) + for ws_file in files: + tar.add( + os.path.join(source, ws_file), + arcname=ws_file, + filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x)), + ) + # tar.list(verbose=True) + + with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped: + with open(tar_path, "rb") as tar_file: + gzipped.write(tar_file.read()) return archive_fname @@ -112,20 +199,6 @@ def compute_digest(source: os.PathLike, filter: Optional[callable] = None) -> st return hasher.hexdigest() -def _filehash_update(path: os.PathLike, hasher: hashlib._Hash) -> None: - blocksize = 65536 - with open(path, "rb") as f: - bytes = f.read(blocksize) - while bytes: - hasher.update(bytes) - bytes = f.read(blocksize) - - -def _pathhash_update(path: os.PathLike, hasher: hashlib._Hash) -> None: - path_list = path.split(os.sep) - hasher.update("".join(path_list).encode("utf-8")) - - def get_additional_distribution_loc(remote_location: str, identifier: str) -> str: """ :param Text remote_location: diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index dc3a6bb9f4..977a194fbd 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -17,6 +17,12 @@ def add_sys_path(path: Union[str, os.PathLike]) -> Iterator[None]: sys.path.remove(path) +def module_load_error_handler(*args, **kwargs): + from flytekit import logger + + logger.info(f"Error walking package structure when loading: {args}, {kwargs}") + + def just_load_modules(pkgs: List[str]): """ This one differs from the above in that we don't yield anything, just load all the modules. @@ -29,7 +35,9 @@ def just_load_modules(pkgs: List[str]): continue # Note that walk_packages takes an onerror arg and swallows import errors silently otherwise - for _, name, _ in pkgutil.walk_packages(package.__path__, prefix=f"{package_name}."): + for _, name, _ in pkgutil.walk_packages( + package.__path__, prefix=f"{package_name}.", onerror=module_load_error_handler + ): importlib.import_module(name) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 5dd68b4261..6160823920 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -25,28 +25,43 @@ class NoSerializableEntitiesError(Exception): pass -def serialize( +def serialize_load_only( pkgs: typing.List[str], settings: SerializationSettings, local_source_root: typing.Optional[str] = None, - options: typing.Optional[Options] = None, -) -> typing.List[FlyteControlPlaneEntity]: +): """ See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the entity type. - :param options: :param settings: SerializationSettings to be used :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization. :param local_source_root: Where to start looking for the code. """ settings.source_root = local_source_root - ctx = FlyteContextManager.current_context().with_serialization_settings(settings) - with FlyteContextManager.with_context(ctx) as ctx: + ctx_builder = FlyteContextManager.current_context().with_serialization_settings(settings) + with FlyteContextManager.with_context(ctx_builder): # Scan all modules. the act of loading populates the global singleton that contains all objects with module_loader.add_sys_path(local_source_root): click.secho(f"Loading packages {pkgs} under source root {local_source_root}", fg="yellow") module_loader.just_load_modules(pkgs=pkgs) + +def serialize_get_control_plane_entities( + settings: SerializationSettings, + local_source_root: typing.Optional[str] = None, + options: typing.Optional[Options] = None, +) -> typing.List[FlyteControlPlaneEntity]: + """ + See :py:class:`flytekit.models.core.identifier.ResourceType` to match the trailing index in the file name with the + entity type. + :param options: + :param settings: SerializationSettings to be used + :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization. + :param local_source_root: Where to start looking for the code. + """ + settings.source_root = local_source_root + ctx_builder = FlyteContextManager.current_context().with_serialization_settings(settings) + with FlyteContextManager.with_context(ctx_builder) as ctx: registrable_entities = get_registrable_entities(ctx, options=options) click.secho(f"Successfully serialized {len(registrable_entities)} flyte objects", fg="green") return registrable_entities @@ -64,7 +79,8 @@ def serialize_to_folder( """ if folder is None: folder = "." - loaded_entities = serialize(pkgs, settings, local_source_root, options=options) + serialize_load_only(pkgs, settings, local_source_root) + loaded_entities = serialize_get_control_plane_entities(settings, local_source_root, options=options) persist_registrable_entities(loaded_entities, folder) @@ -74,6 +90,7 @@ def package( output: str = "./flyte-package.tgz", fast: bool = False, deref_symlinks: bool = False, + fast_options: typing.Optional[fast_registration.FastPackageOptions] = None, ): """ Package the given entities and the source code (if fast is enabled) into a package with the given name in output @@ -82,6 +99,11 @@ def package( :param output: output package name with suffix :param fast: fast enabled implies source code is bundled :param deref_symlinks: if enabled then symlinks are dereferenced during packaging + :param fast_options: + + Temporarily, for fast register, specify both the fast arg as well as copy_style fast == True with + copy_style == None means use the old fast register tar'ring method. + In the future the fast bool will be removed, and copy_style == None will mean do not fast register. """ if not serializable_entities: raise NoSerializableEntitiesError("Nothing to package") @@ -95,7 +117,7 @@ def package( if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output): click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow") os.remove(output) - archive_fname = fast_registration.fast_package(source, output_tmpdir, deref_symlinks) + archive_fname = fast_registration.fast_package(source, output_tmpdir, deref_symlinks, options=fast_options) click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True) with tarfile.open(output, "w:gz") as tar: @@ -114,12 +136,16 @@ def serialize_and_package( fast: bool = False, deref_symlinks: bool = False, options: typing.Optional[Options] = None, + fast_options: typing.Optional[fast_registration.FastPackageOptions] = None, ): """ Fist serialize and then package all entities + Temporarily for fast package, specify both the fast arg as well as copy_style. + fast == True with copy_style == None means use the old fast register tar'ring method. """ - serializable_entities = serialize(pkgs, settings, source, options=options) - package(serializable_entities, source, output, fast, deref_symlinks) + serialize_load_only(pkgs, settings, source) + serializable_entities = serialize_get_control_plane_entities(settings, source, options=options) + package(serializable_entities, source, output, fast, deref_symlinks, fast_options) def find_common_root( @@ -147,29 +173,19 @@ def find_common_root( return project_root -def load_packages_and_modules( - ss: SerializationSettings, +def list_packages_and_modules( project_root: Path, pkgs_or_mods: typing.List[str], - options: typing.Optional[Options] = None, -) -> typing.List[FlyteControlPlaneEntity]: +) -> typing.List[str]: """ - The project root is added as the first entry to sys.path, and then all the specified packages and modules - given are loaded with all submodules. The reason for prepending the entry is to ensure that the name that - the various modules are loaded under are the fully-resolved name. + This is a helper function that returns the input list of python packages/modules as a dot delinated list + relative to the given project_root. - For example, using flytesnacks cookbook, if you are in core/ and you call this function with - ``flyte_basics/hello_world.py control_flow/``, the ``hello_world`` module would be loaded - as ``core.flyte_basics.hello_world`` even though you're already in the core/ folder. - - :param ss: :param project_root: :param pkgs_or_mods: - :param options: - :return: The common detected root path, the output of _find_project_root + :return: List of packages/modules, dot delineated. """ - ss.git_repo = _get_git_repo_url(project_root) - pkgs_and_modules = [] + pkgs_and_modules: typing.List[str] = [] for pm in pkgs_or_mods: p = Path(pm).resolve() rel_path_from_root = p.relative_to(project_root) @@ -182,9 +198,7 @@ def load_packages_and_modules( ) pkgs_and_modules.append(dot_delineated) - registrable_entities = serialize(pkgs_and_modules, ss, str(project_root), options) - - return registrable_entities + return pkgs_and_modules def secho(i: Identifier, state: str = "success", reason: str = None, op: str = "Registration"): @@ -221,21 +235,19 @@ def register( fast: bool, package_or_module: typing.Tuple[str], remote: FlyteRemote, + copy_style: typing.Optional[fast_registration.CopyFileDetection], env: typing.Optional[typing.Dict[str, str]], dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, + show_files: bool = False, ): + """ + Temporarily, for fast register, specify both the fast arg as well as copy_style. + fast == True with copy_style == None means use the old fast register tar'ring method. + """ detected_root = find_common_root(package_or_module) click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") - fast_serialization_settings = None - if fast: - md5_bytes, native_url = remote.fast_package(detected_root, deref_symlinks, output) - fast_serialization_settings = FastSerializationSettings( - enabled=True, - destination_dir=destination_dir, - distribution_location=native_url, - ) # Create serialization settings # Todo: Rely on default Python interpreter for now, this will break custom Spark containers @@ -244,28 +256,50 @@ def register( domain=domain, version=version, image_config=image_config, - fast_serialization_settings=fast_serialization_settings, + fast_serialization_settings=None, # should probably add incomplete fast settings env=env, ) - if not version and fast: - version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa - click.secho(f"Computed version is {version}", fg="yellow") - elif not version: + if not version and not fast: click.secho("Version is required.", fg="red") return b = serialization_settings.new_builder() - b.version = version serialization_settings = b.build() options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) # Load all the entities FlyteContextManager.push_context(remote.context) - registrable_entities = load_packages_and_modules( - serialization_settings, detected_root, list(package_or_module), options - ) + serialization_settings.git_repo = _get_git_repo_url(str(detected_root)) + pkgs_and_modules = list_packages_and_modules(detected_root, list(package_or_module)) + + # NB: The change here is that the loading of user code _cannot_ depend on fast register information (the computed + # version, upload native url, hash digest, etc.). + serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root)) + + # Fast registration is handled after module loading + if fast: + md5_bytes, native_url = remote.fast_package( + detected_root, + deref_symlinks, + output, + options=fast_registration.FastPackageOptions([], copy_style=copy_style, show_files=show_files), + ) + # update serialization settings from fast register output + fast_serialization_settings = FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ) + serialization_settings.fast_serialization_settings = fast_serialization_settings + if not version: + version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa + serialization_settings.version = version + click.secho(f"Computed version is {version}", fg="yellow") + + registrable_entities = serialize_get_control_plane_entities(serialization_settings, str(detected_root), options) + FlyteContextManager.pop_context() if len(registrable_entities) == 0: click.secho("No Flyte entities were detected. Aborting!", fg="red") diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 9d91731389..adbcd313f4 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gzip import hashlib import os @@ -9,7 +11,10 @@ import typing from pathlib import Path from types import ModuleType -from typing import List, Optional +from typing import List, Optional, Tuple, Union + +from flytekit.loggers import logger +from flytekit.tools.ignore import IgnoreGroup def compress_scripts(source_path: str, destination: str, modules: List[ModuleType]): @@ -79,17 +84,114 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: return tar_info -def add_imported_modules_from_source(source_path: str, destination: str, modules: List[ModuleType]): +def ls_files( + source_path: str, + modules: List[ModuleType], + deref_symlinks: bool = False, + ignore_group: Optional[IgnoreGroup] = None, +) -> Tuple[List[str], str]: + """ + user_modules_and_packages is a list of the Python modules and packages, expressed as absolute paths, that the + user has run this pyflyte command with. For pyflyte run for instance, this is just a list of one. + This is used for two reasons. + - Everything in this list needs to be returned. Files are returned and folders are walked. + - A common source path is derived from this is, which is just the common folder that contains everything in the + list. For ex. if you do + $ pyflyte --pkgs a.b,a.c package + Then the common root is just the folder a/. The modules list is filtered against this root. Only files + representing modules under this root are included + + + If the modules list should be a list of all the + + needs to compute digest as well. + """ + + # Unlike the below, the value error here is useful and should be returned to the user, like if absolute and + # relative paths are mixed. + + # This is --copy auto + if modules: + all_files = list_imported_modules_as_files(source_path, modules) + # this is --copy all + else: + all_files = list_all_files(source_path, deref_symlinks, ignore_group) + + hasher = hashlib.md5() + for abspath in all_files: + relpath = os.path.relpath(abspath, source_path) + _filehash_update(abspath, hasher) + _pathhash_update(relpath, hasher) + + digest = hasher.hexdigest() + + return all_files, digest + + +def _filehash_update(path: Union[os.PathLike, str], hasher: hashlib._Hash) -> None: + blocksize = 65536 + with open(path, "rb") as f: + bytes = f.read(blocksize) + while bytes: + hasher.update(bytes) + bytes = f.read(blocksize) + + +def _pathhash_update(path: Union[os.PathLike, str], hasher: hashlib._Hash) -> None: + path_list = path.split(os.sep) + hasher.update("".join(path_list).encode("utf-8")) + + +def list_all_files(source_path: str, deref_symlinks, ignore_group: Optional[IgnoreGroup] = None) -> List[str]: + all_files = [] + + # This is needed to prevent infinite recursion when walking with followlinks + visited_inodes = set() + + for root, dirnames, files in os.walk(source_path, topdown=True, followlinks=deref_symlinks): + if deref_symlinks: + inode = os.stat(root).st_ino + if inode in visited_inodes: + continue + visited_inodes.add(inode) + + ff = [] + files.sort() + for fname in files: + abspath = os.path.join(root, fname) + # Only consider files that exist (e.g. disregard symlinks that point to non-existent files) + if not os.path.exists(abspath): + logger.info(f"Skipping non-existent file {abspath}") + continue + if ignore_group: + if ignore_group.is_ignored(abspath): + continue + + ff.append(abspath) + all_files.extend(ff) + + # Remove directories that we've already visited from dirnames + if deref_symlinks: + dirnames[:] = [d for d in dirnames if os.stat(os.path.join(root, d)).st_ino not in visited_inodes] + + return all_files + + +def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) -> List[str]: """Copies modules into destination that are in modules. The module files are copied only if: 1. Not a site-packages. These are installed packages and not user files. 2. Not in the bin. These are also installed and not user files. 3. Does not share a common path with the source_path. """ + # source path is the folder holding the main script. + # but in register/package case, there are multiple folders. + # identify a common root amongst the packages listed? site_packages = site.getsitepackages() site_packages_set = set(site_packages) bin_directory = os.path.dirname(sys.executable) + files = [] for mod in modules: try: @@ -129,7 +231,25 @@ def add_imported_modules_from_source(source_path: str, destination: str, modules # so we do not upload the file. continue - relative_path = os.path.relpath(mod_file, start=source_path) + files.append(mod_file) + + return files + + +def add_imported_modules_from_source(source_path: str, destination: str, modules: List[ModuleType]): + """Copies modules into destination that are in modules. The module files are copied only if: + + 1. Not a site-packages. These are installed packages and not user files. + 2. Not in the bin. These are also installed and not user files. + 3. Does not share a common path with the source_path. + """ + # source path is the folder holding the main script. + # but in register/package case, there are multiple folders. + # identify a common root amongst the packages listed? + + files = list_imported_modules_as_files(source_path, modules) + for file in files: + relative_path = os.path.relpath(file, start=source_path) new_destination = os.path.join(destination, relative_path) if os.path.exists(new_destination): @@ -137,7 +257,7 @@ def add_imported_modules_from_source(source_path: str, destination: str, modules continue os.makedirs(os.path.dirname(new_destination), exist_ok=True) - shutil.copy(mod_file, new_destination) + shutil.copy(file, new_destination) def get_all_modules(source_path: str, module_name: Optional[str]) -> List[ModuleType]: @@ -154,12 +274,14 @@ def get_all_modules(source_path: str, module_name: Optional[str]) -> List[Module if not is_python_file: return sys_modules + # should move it here probably from flytekit.core.tracker import import_module_from_file try: new_module = import_module_from_file(module_name, full_module_path) return sys_modules + [new_module] - except Exception: + except Exception as exc: + logger.error(f"Using system modules, failed to import {module_name} from {full_module_path}: {str(exc)}") # Import failed so we fallback to `sys_modules` return sys_modules diff --git a/tests/flytekit/unit/cli/pyflyte/test_script_mode.py b/tests/flytekit/unit/cli/pyflyte/test_script_mode.py new file mode 100644 index 0000000000..dcccda0cd2 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_script_mode.py @@ -0,0 +1,51 @@ +import os +import pathlib +import pytest +import tempfile + +from flytekit.tools.script_mode import ls_files + + +# a pytest fixture that creates a tmp directory and creates +# a small file structure in it +@pytest.fixture +def dummy_dir_structure(): + # Create a temporary directory + with tempfile.TemporaryDirectory() as tmp_path: + + # Create directories + tmp_path = pathlib.Path(tmp_path) + subdir1 = tmp_path / "subdir1" + subdir2 = tmp_path / "subdir2" + subdir1.mkdir() + subdir2.mkdir() + + # Create files in the root of the temporary directory + (tmp_path / "file1.txt").write_text("This is file 1") + (tmp_path / "file2.txt").write_text("This is file 2") + + # Create files in subdir1 + (subdir1 / "file3.txt").write_text("This is file 3 in subdir1") + (subdir1 / "file4.txt").write_text("This is file 4 in subdir1") + + # Create files in subdir2 + (subdir2 / "file5.txt").write_text("This is file 5 in subdir2") + + # Return the path to the temporary directory + yield tmp_path + + +def test_list_dir(dummy_dir_structure): + files, d = ls_files(str(dummy_dir_structure), []) + assert len(files) == 5 + if os.name != "nt": + assert d == "c092f1b85f7c6b2a71881a946c00a855" + + +def test_list_filtered_on_modules(dummy_dir_structure): + import sys # any module will do + files, d = ls_files(str(dummy_dir_structure), [sys]) + # because none of the files are python modules, nothing should be returned + assert len(files) == 0 + if os.name != "nt": + assert d == "d41d8cd98f00b204e9800998ecf8427e" diff --git a/tests/flytekit/unit/cli/test_cli_helpers.py b/tests/flytekit/unit/cli/test_cli_helpers.py index 455979943c..af0c63a312 100644 --- a/tests/flytekit/unit/cli/test_cli_helpers.py +++ b/tests/flytekit/unit/cli/test_cli_helpers.py @@ -1,3 +1,4 @@ +import mock import flyteidl.admin.launch_plan_pb2 as _launch_plan_pb2 import flyteidl.admin.task_pb2 as _task_pb2 import flyteidl.admin.workflow_pb2 as _workflow_pb2 @@ -8,6 +9,8 @@ from flytekit.clis import helpers from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template_nodes, hydrate_registration_parameters +from flytekit.clis.sdk_in_container.helpers import parse_copy +from flytekit.tools.fast_registration import CopyFileDetection def test_parse_args_into_dict(): @@ -426,3 +429,9 @@ def test_hydrate_registration_parameters__subworkflows(): name="subworkflow", version="12345", ) + + +def test_parse_copy(): + click_current_ctx = mock.MagicMock + assert parse_copy(click_current_ctx, None, "auto") == CopyFileDetection.LOADED_MODULES + assert parse_copy(click_current_ctx, None, "all") == CopyFileDetection.ALL diff --git a/tests/flytekit/unit/tools/test_repo.py b/tests/flytekit/unit/tools/test_repo.py index 8bb6bd773a..eefcaeb3be 100644 --- a/tests/flytekit/unit/tools/test_repo.py +++ b/tests/flytekit/unit/tools/test_repo.py @@ -7,7 +7,7 @@ import flytekit.configuration from flytekit.configuration import DefaultImages, ImageConfig -from flytekit.tools.repo import find_common_root, load_packages_and_modules +from flytekit.tools.repo import find_common_root, list_packages_and_modules task_text = """ from flytekit import task @@ -66,5 +66,5 @@ def test_module_loading(mock_entities, mock_entities_2): image_config=ImageConfig.auto(img_name=DefaultImages.default_image()), ) - x = load_packages_and_modules(serialization_settings, pathlib.Path(root), [bottom_level]) + x = list_packages_and_modules(pathlib.Path(root), [bottom_level]) assert len(x) == 1