Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyflyte run imperative workflows #1131

Merged
merged 8 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
CTX_PACKAGES = "pkgs"
CTX_NOTIFICATIONS = "notifications"
CTX_CONFIG_FILE = "config_file"
CTX_PROJECT_ROOT = "project_root"
CTX_MODULE = "module"


project_option = _click.option(
Expand Down
15 changes: 14 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
from typing_extensions import get_args

from flytekit import BlobType, Literal, Scalar
from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT
from flytekit.clis.sdk_in_container.constants import (
CTX_CONFIG_FILE,
CTX_DOMAIN,
CTX_MODULE,
CTX_PROJECT,
CTX_PROJECT_ROOT,
)
from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
Expand Down Expand Up @@ -542,6 +548,8 @@ def _run(*args, **kwargs):
domain=domain,
image_config=image_config,
destination_dir=run_level_params.get("destination_dir"),
source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT),
module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE),
)

options = None
Expand Down Expand Up @@ -602,11 +610,16 @@ def get_command(self, ctx, exe_entity):
)

project_root = _find_project_root(self._filename)

# Find the relative path for the filename relative to the root of the project.
# N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as
# a parameter.
rel_path = self._filename.relative_to(project_root)
module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")

ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root
ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module

entity = load_naive_entity(module, exe_entity, project_root)

# If this is a remote execution, which we should know at this point, then create the remote object
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect as _inspect
import os
import typing
from types import ModuleType
from typing import Callable, Tuple, Union

from flytekit.configuration.feature_flags import FeatureFlags
Expand Down Expand Up @@ -239,6 +240,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
if mod_name == "__main__":
return name, "", name, os.path.abspath(inspect.getfile(f))

mod_name = get_full_module_path(mod, mod_name)
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))


def get_full_module_path(mod: ModuleType, mod_name: str) -> str:
if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != ".":
package_root = (
FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None
Expand All @@ -247,4 +253,4 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
# We only replace the mod_name if it is more specific, else we already have a fully resolved path
if len(new_mod_name) > len(mod_name):
mod_name = new_mod_name
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
return mod_name
7 changes: 6 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ def register_script(
destination_dir: str = ".",
default_launch_plan: typing.Optional[bool] = True,
options: typing.Optional[Options] = None,
source_path: typing.Optional[str] = None,
module_name: typing.Optional[str] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask]:
"""
Use this method to register a workflow via script mode.
Expand All @@ -588,13 +590,16 @@ def register_script(
:param entity: The workflow to be registered or the task to be registered
:param default_launch_plan: This should be true if a default launch plan should be created for the workflow
:param options: Additional execution options that can be configured for the default launchplan
:param source_path: The root of the project path
:param module_name: the name of the module
:return:
"""
if image_config is None:
image_config = ImageConfig.auto_default_image()

upload_location, md5_bytes = fast_register_single_script(
entity,
source_path,
module_name,
functools.partial(
self.client.get_upload_signed_url,
project=project or self.default_project,
Expand Down
12 changes: 5 additions & 7 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
Expand All @@ -10,8 +11,7 @@
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2

from flytekit.core import context_manager
from flytekit.core.tracker import extract_task_module
from flytekit.core.workflow import WorkflowBase
from flytekit.core.tracker import get_full_module_path


def compress_single_script(source_path: str, destination: str, full_module_name: str):
Expand Down Expand Up @@ -97,16 +97,14 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:


def fast_register_single_script(
wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable
source_path: str, module_name: str, create_upload_location_fn: typing.Callable
) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes):
_, mod_name, _, script_full_path = extract_task_module(wf_entity)
Copy link
Member Author

@pingsutw pingsutw Aug 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't get the module name of the imperative workflow by using inspect.getmodule because inspect.getmodule(<imperative_wf>).__name__ is always equal to flytekit.core.workflow

To address this issue, we can use the module name passed in the pyflyte run instead.
pyflyte run --remote imperative_wf.py (model_name) wf

# Find project root by moving up the folder hierarchy until you cannot find a __init__.py file.
source_path = _find_project_root(script_full_path)

# Open a temp directory and dump the contents of the digest.
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz")
compress_single_script(source_path, archive_fname, mod_name)
mod = importlib.import_module(module_name)
compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__))

flyte_ctx = context_manager.FlyteContextManager.current_context()
md5, _ = hash_file(archive_fname)
Expand Down
39 changes: 39 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/imperative_wf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import typing

from flytekit import Workflow, task


@task
def t1(a: str) -> str:
return a + " world"


@task
def t2():
print("side effect")


@task
def t3(a: typing.List[str]) -> str:
return ",".join(a)


wf = Workflow(name="my.imperative.workflow.example")
wf.add_workflow_input("in1", str)
node_t1 = wf.add_entity(t1, a=wf.inputs["in1"])
wf.add_workflow_output("output_from_t1", node_t1.outputs["o0"])
wf.add_entity(t2)

wf_in2 = wf.add_workflow_input("in2", str)
node_t3 = wf.add_entity(t3, a=[wf.inputs["in1"], wf_in2])

wf.add_workflow_output(
"output_list",
[node_t1.outputs["o0"], node_t3.outputs["o0"]],
python_type=typing.List[str],
)


if __name__ == "__main__":
print(wf)
print(wf(in1="hello", in2="foo"))
11 changes: 11 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from flytekit.core.task import task

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


Expand All @@ -30,6 +31,16 @@ def test_pyflyte_run_wf():
assert result.exit_code == 0


def test_imperative_wf():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
catch_exceptions=False,
)
assert result.exit_code == 0


def test_pyflyte_run_cli():
runner = CliRunner()
result = runner.invoke(
Expand Down