Skip to content

Commit

Permalink
pyflyte run imperative workflows (#1131)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 29, 2022
1 parent 3cf0639 commit 9792893
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 18 deletions.
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
CTX_PACKAGES = "pkgs"
CTX_NOTIFICATIONS = "notifications"
CTX_CONFIG_FILE = "config_file"
CTX_PROJECT_ROOT = "project_root"
CTX_MODULE = "module"


project_option = _click.option(
Expand Down
29 changes: 20 additions & 9 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
from typing_extensions import get_args

from flytekit import BlobType, Literal, Scalar
from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT
from flytekit.clis.sdk_in_container.constants import (
CTX_CONFIG_FILE,
CTX_DOMAIN,
CTX_MODULE,
CTX_PROJECT,
CTX_PROJECT_ROOT,
)
from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context
from flytekit.configuration import ImageConfig
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager, tracker
from flytekit.core import context_manager
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext
from flytekit.core.data_persistence import FileAccessProvider
Expand Down Expand Up @@ -480,14 +486,12 @@ def get_entities_in_file(filename: str) -> Entities:
workflows = []
tasks = []
module = importlib.import_module(module_name)
for k in dir(module):
o = module.__dict__[k]
if isinstance(o, PythonFunctionWorkflow):
_, _, fn, _ = tracker.extract_task_module(o)
workflows.append(fn)
for name in dir(module):
o = module.__dict__[name]
if isinstance(o, WorkflowBase):
workflows.append(name)
elif isinstance(o, PythonTask):
_, _, fn, _ = tracker.extract_task_module(o)
tasks.append(fn)
tasks.append(name)

return Entities(workflows, tasks)

Expand Down Expand Up @@ -542,6 +546,8 @@ def _run(*args, **kwargs):
domain=domain,
image_config=image_config,
destination_dir=run_level_params.get("destination_dir"),
source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT),
module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE),
)

options = None
Expand Down Expand Up @@ -602,11 +608,16 @@ def get_command(self, ctx, exe_entity):
)

project_root = _find_project_root(self._filename)

# Find the relative path for the filename relative to the root of the project.
# N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as
# a parameter.
rel_path = self._filename.relative_to(project_root)
module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")

ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root
ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module

entity = load_naive_entity(module, exe_entity, project_root)

# If this is a remote execution, which we should know at this point, then create the remote object
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect as _inspect
import os
import typing
from types import ModuleType
from typing import Callable, Tuple, Union

from flytekit.configuration.feature_flags import FeatureFlags
Expand Down Expand Up @@ -239,6 +240,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
if mod_name == "__main__":
return name, "", name, os.path.abspath(inspect.getfile(f))

mod_name = get_full_module_path(mod, mod_name)
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))


def get_full_module_path(mod: ModuleType, mod_name: str) -> str:
if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != ".":
package_root = (
FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None
Expand All @@ -247,4 +253,4 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
# We only replace the mod_name if it is more specific, else we already have a fully resolved path
if len(new_mod_name) > len(mod_name):
mod_name = new_mod_name
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
return mod_name
7 changes: 6 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ def register_script(
destination_dir: str = ".",
default_launch_plan: typing.Optional[bool] = True,
options: typing.Optional[Options] = None,
source_path: typing.Optional[str] = None,
module_name: typing.Optional[str] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask]:
"""
Use this method to register a workflow via script mode.
Expand All @@ -588,13 +590,16 @@ def register_script(
:param entity: The workflow to be registered or the task to be registered
:param default_launch_plan: This should be true if a default launch plan should be created for the workflow
:param options: Additional execution options that can be configured for the default launchplan
:param source_path: The root of the project path
:param module_name: the name of the module
:return:
"""
if image_config is None:
image_config = ImageConfig.auto_default_image()

upload_location, md5_bytes = fast_register_single_script(
entity,
source_path,
module_name,
functools.partial(
self.client.get_upload_signed_url,
project=project or self.default_project,
Expand Down
12 changes: 5 additions & 7 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
Expand All @@ -10,8 +11,7 @@
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2

from flytekit.core import context_manager
from flytekit.core.tracker import extract_task_module
from flytekit.core.workflow import WorkflowBase
from flytekit.core.tracker import get_full_module_path


def compress_single_script(source_path: str, destination: str, full_module_name: str):
Expand Down Expand Up @@ -97,16 +97,14 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:


def fast_register_single_script(
wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable
source_path: str, module_name: str, create_upload_location_fn: typing.Callable
) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes):
_, mod_name, _, script_full_path = extract_task_module(wf_entity)
# Find project root by moving up the folder hierarchy until you cannot find a __init__.py file.
source_path = _find_project_root(script_full_path)

# Open a temp directory and dump the contents of the digest.
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz")
compress_single_script(source_path, archive_fname, mod_name)
mod = importlib.import_module(module_name)
compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__))

flyte_ctx = context_manager.FlyteContextManager.current_context()
md5, _ = hash_file(archive_fname)
Expand Down
39 changes: 39 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/imperative_wf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import typing

from flytekit import Workflow, task


@task
def t1(a: str) -> str:
return a + " world"


@task
def t2():
print("side effect")


@task
def t3(a: typing.List[str]) -> str:
return ",".join(a)


wf = Workflow(name="my.imperative.workflow.example")
wf.add_workflow_input("in1", str)
node_t1 = wf.add_entity(t1, a=wf.inputs["in1"])
wf.add_workflow_output("output_from_t1", node_t1.outputs["o0"])
wf.add_entity(t2)

wf_in2 = wf.add_workflow_input("in2", str)
node_t3 = wf.add_entity(t3, a=[wf.inputs["in1"], wf_in2])

wf.add_workflow_output(
"output_list",
[node_t1.outputs["o0"], node_t3.outputs["o0"]],
python_type=typing.List[str],
)


if __name__ == "__main__":
print(wf)
print(wf(in1="hello", in2="foo"))
11 changes: 11 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from flytekit.core.task import task

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


Expand All @@ -30,6 +31,16 @@ def test_pyflyte_run_wf():
assert result.exit_code == 0


def test_imperative_wf():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
catch_exceptions=False,
)
assert result.exit_code == 0


def test_pyflyte_run_cli():
runner = CliRunner()
result = runner.invoke(
Expand Down

0 comments on commit 9792893

Please sign in to comment.