Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
samhita-alla committed Jul 16, 2021
2 parents 52f89af + 6875a5f commit 11f8720
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 78 deletions.
6 changes: 2 additions & 4 deletions flytekit/clis/sdk_in_container/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ def _should_register_with_admin(entity) -> bool:
This is used in the code below. The translator.py module produces lots of objects (namely nodes and BranchNodes)
that do not/should not be written to .pb file to send to admin. This function filters them out.
"""
return entity is not None and (
isinstance(entity, task_models.TaskSpec)
or isinstance(entity, _launch_plan_models.LaunchPlan)
or isinstance(entity, admin_workflow_models.WorkflowSpec)
return isinstance(
entity, (task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec)
)


Expand Down
49 changes: 27 additions & 22 deletions flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.node import Node
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
from flytekit.core.task import ReferenceTask
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.models import common as _common_models
Expand Down Expand Up @@ -138,15 +138,17 @@ def get_serializable_workflow(
sub_wfs = []
for n in entity.nodes:
if isinstance(n.flyte_entity, WorkflowBase):
if isinstance(n.flyte_entity, ReferenceEntity):
# We are currently not supporting reference workflows since these will
# require a network call to flyteadmin to populate the WorkflowTemplate
# object
if isinstance(n.flyte_entity, ReferenceWorkflow):
raise Exception(
f"Sorry, reference subworkflows do not work right now, please use the launch plan instead for the "
f"subworkflow you're trying to invoke. Node: {n}"
"Reference sub-workflows are currently unsupported. Use reference launch plans instead."
)
sub_wf_spec = get_serializable(entity_mapping, settings, n.flyte_entity)
if not isinstance(sub_wf_spec, admin_workflow_models.WorkflowSpec):
raise Exception(
f"Serialized form of a workflow should be an admin.WorkflowSpec but {type(sub_wf_spec)} found instead"
raise TypeError(
f"Unexpected type for serialized form of workflow. Expected {admin_workflow_models.WorkflowSpec}, but got {type(sub_wf_spec)}"
)
sub_wfs.append(sub_wf_spec.template)
sub_wfs.extend(sub_wf_spec.sub_workflows)
Expand Down Expand Up @@ -245,26 +247,25 @@ def get_serializable_node(

# Reference entities also inherit from the classes in the second if statement so address them first.
if isinstance(entity.flyte_entity, ReferenceEntity):
# This is a throw away call.
# See the comment in compile_into_workflow in python_function_task. This is just used to place a None value
# in the entity_mapping.
get_serializable(entity_mapping, settings, entity.flyte_entity)
ref = entity.flyte_entity
ref_spec = get_serializable(entity_mapping, settings, entity.flyte_entity)
ref_template = ref_spec.template
node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
upstream_node_ids=[n.id for n in upstream_sdk_nodes],
output_aliases=[],
)
if ref.reference.resource_type == _identifier_model.ResourceType.TASK:
node_model._task_node = workflow_model.TaskNode(reference_id=ref.id)
elif ref.reference.resource_type == _identifier_model.ResourceType.WORKFLOW:
node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref.id)
elif ref.reference.resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref.id)
if ref_template.resource_type == _identifier_model.ResourceType.TASK:
node_model._task_node = workflow_model.TaskNode(reference_id=ref_template.id)
elif ref_template.resource_type == _identifier_model.ResourceType.WORKFLOW:
node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref_template.id)
elif ref_template.resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref_template.id)
else:
raise Exception(f"Unexpected reference type {ref}")
raise Exception(
f"Unexpected resource type for reference entity {entity.flyte_entity}: {ref_template.resource_type}"
)
return node_model

if isinstance(entity.flyte_entity, PythonTask):
Expand Down Expand Up @@ -342,6 +343,13 @@ def get_serializable_branch_node(
)


def get_reference_spec(
entity_mapping: OrderedDict, settings: SerializationSettings, entity: ReferenceEntity
) -> ReferenceSpec:
template = ReferenceTemplate(entity.id, entity.reference.resource_type)
return ReferenceSpec(template)


def get_serializable(
entity_mapping: OrderedDict,
settings: SerializationSettings,
Expand All @@ -367,10 +375,7 @@ def get_serializable(
return entity_mapping[entity]

if isinstance(entity, ReferenceEntity):
# TODO: Create a non-registerable model class comparable to TaskSpec or WorkflowSpec to replace None as a
# keystone value. The purpose is only to store something so that we can check for it when compiling
# dynamic tasks. See comment in compile_into_workflow.
cp_entity = None
cp_entity = get_reference_spec(entity_mapping, settings, entity)

elif isinstance(entity, PythonTask):
cp_entity = get_serializable_task(entity_mapping, settings, entity)
Expand Down
47 changes: 26 additions & 21 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,8 @@
from typing import Any, Callable, List, Optional, TypeVar, Union

from flytekit.common.exceptions import scopes as exception_scopes
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import (
ExecutionState,
FastSerializationSettings,
FlyteContext,
FlyteContextManager,
SerializationSettings,
)
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_signature_to_interface
from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver
from flytekit.core.tracker import isnested, istestfunction
Expand Down Expand Up @@ -168,6 +162,9 @@ def compile_into_workflow(
In the case of dynamic workflows, this function will produce a workflow definition at execution time which will
then proceed to be executed.
"""
# TODO: circular import
from flytekit.core.task import ReferenceTask

if not ctx.compilation_state:
cs = ctx.new_compilation_state("dynamic")
else:
Expand Down Expand Up @@ -202,20 +199,28 @@ def compile_into_workflow(
}
)

# This is not great. The translator.py module is relied on here (see comment above) to get the tasks and
# subworkflow definitions. However we want to ensure that reference tasks and reference sub workflows are
# not used.
# TODO: Replace None with a class.
for value in model_entities.values():
if value is None:
raise Exception(
"Reference tasks are not allowed in the dynamic - a network call is necessary "
"in order to retrieve the structure of the reference task."
# Gather underlying TaskTemplates that get referenced.
tts = []
for entity, model in model_entities.items():
# We only care about gathering tasks here. Launch plans are handled by
# propeller. Subworkflows should already be in the workflow spec.
if not isinstance(entity, Task):
continue

# We are currently not supporting reference tasks since these will
# require a network call to flyteadmin to populate the TaskTemplate
# model
if isinstance(entity, ReferenceTask):
raise Exception("Reference tasks are currently unsupported within dynamic tasks")

if not isinstance(model, task_models.TaskSpec):
raise TypeError(
f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}"
)

# Gather underlying TaskTemplates that get referenced. Launch plans are handled by propeller. Subworkflows
# should already be in the workflow spec.
tts = [v.template for v in model_entities.values() if isinstance(v, task_models.TaskSpec)]
# Store the valid task template so that we can pass it to the
# DynamicJobSpec later
tts.append(model.template)

if ctx.serialization_settings.should_fast_serialize():
if (
Expand Down Expand Up @@ -278,7 +283,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
)
if is_fast_execution:
ctx = ctx.with_serialization_settings(
SerializationSettings.new_builder()
ctx.serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=True))
.build()
)
Expand Down
47 changes: 47 additions & 0 deletions flytekit/core/reference_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,50 @@ def __call__(self, *args, **kwargs):
else:
logger.debug("Reference entity - running raw execute")
return self.execute(**kwargs)


# ReferenceEntity is not a registerable entity and therefore the below classes do not need to inherit from
# flytekit.models.common.FlyteIdlEntity.
class ReferenceTemplate(object):
def __init__(self, id: _identifier_model.Identifier, resource_type: int) -> None:
"""
A reference template encapsulates all the information necessary to use reference entities within other
workflows or dynamic tasks.
:param flytekit.models.core.identifier.Identifier id: User-specified information that uniquely
identifies this reference.
:param int resource_type: The type of reference. See: flytekit.models.core.identifier.ResourceType
"""
self._id = id
self._resource_type = resource_type

@property
def id(self) -> _identifier_model.Identifier:
"""
User-specified information that uniquely identifies this reference.
:rtype: flytekit.models.core.identifier.Identifier
"""
return self._id

@property
def resource_type(self) -> int:
"""
The type of reference.
:rtype: flytekit.models.core.identifier.ResourceType
"""
return self._resource_type


class ReferenceSpec(object):
def __init__(self, template: ReferenceTemplate) -> None:
"""
:param ReferenceTemplate template:
"""
self._template = template

@property
def template(self) -> ReferenceTemplate:
"""
:rtype: ReferenceTemplate
"""
return self._template
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
flytekit>=0.20.1
flytekitplugins-sqlalchemy
sqlalchemy
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ docker-image-py==0.1.10
flyteidl==0.19.5
# via flytekit
flytekit==0.20.1
# via
# -r requirements.in
# flytekitplugins-sqlalchemy
flytekitplugins-sqlalchemy==0.20.1
# via -r requirements.in
greenlet==1.1.0
# via sqlalchemy
Expand Down Expand Up @@ -107,7 +103,7 @@ six==1.16.0
sortedcontainers==2.4.0
# via flytekit
sqlalchemy==1.4.19
# via flytekitplugins-sqlalchemy
# via -r requirements.in
statsd==3.3.0
# via flytekit
stringcase==1.2.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
super().__init__(
name=name,
task_config=task_config,
container_image="ghcr.io/flyteorg/flytekit:sqlalchemy-b7ccc96d46a239f12e0b65ad749d1e11d5d20f46",
container_image="ghcr.io/flyteorg/flytekit:sqlalchemy-6deb81af74ce8f3768553c188ab35660c717420a",
executor_type=SQLAlchemyTaskExecutor,
task_type=self._SQLALCHEMY_TASK_TYPE,
query_template=query_template,
Expand Down
16 changes: 13 additions & 3 deletions tests/flytekit/unit/common_tests/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.core.base_task import kwtypes
from flytekit.core.context_manager import FastSerializationSettings, Image, ImageConfig
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.reference_entity import ReferenceSpec, ReferenceTemplate
from flytekit.core.task import ReferenceTask, task
from flytekit.core.workflow import ReferenceWorkflow, workflow
from flytekit.models.core import identifier as identifier_models
Expand All @@ -24,15 +25,24 @@
def test_references():
rlp = ReferenceLaunchPlan("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
lp_model = get_serializable(OrderedDict(), serialization_settings, rlp)
assert lp_model is None
assert isinstance(lp_model, ReferenceSpec)
assert isinstance(lp_model.template, ReferenceTemplate)
assert lp_model.template.id == rlp.reference.id
assert lp_model.template.resource_type == identifier_models.ResourceType.LAUNCH_PLAN

rt = ReferenceTask("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
task_spec = get_serializable(OrderedDict(), serialization_settings, rt)
assert task_spec is None
assert isinstance(task_spec, ReferenceSpec)
assert isinstance(task_spec.template, ReferenceTemplate)
assert task_spec.template.id == rt.reference.id
assert task_spec.template.resource_type == identifier_models.ResourceType.TASK

rw = ReferenceWorkflow("media", "stg", "some.name", "cafe", inputs=kwtypes(in1=str), outputs=kwtypes())
wf_spec = get_serializable(OrderedDict(), serialization_settings, rw)
assert wf_spec is None
assert isinstance(wf_spec, ReferenceSpec)
assert isinstance(wf_spec.template, ReferenceTemplate)
assert wf_spec.template.id == rw.reference.id
assert wf_spec.template.resource_type == identifier_models.ResourceType.WORKFLOW


def test_basics():
Expand Down
62 changes: 62 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import typing

from flytekit import dynamic
from flytekit.core import context_manager
from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow


def test_wf1_with_fast_dynamic():
@task
def t1(a: int) -> str:
a = a + 2
return "fast-" + str(a)

@dynamic
def my_subwf(a: int) -> typing.List[str]:
s = []
for i in range(a):
s.append(t1(a=i))
return s

@workflow
def my_wf(a: int) -> typing.List[str]:
v = my_subwf(a=a)
return v

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(
context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
fast_serialization_settings=FastSerializationSettings(enabled=True),
)
)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
additional_context={
"dynamic_addl_distro": "s3://my-s3-bucket/fast/123",
"dynamic_dest_dir": "/User/flyte/workflows",
},
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec._nodes) == 5
assert len(dynamic_job_spec.tasks) == 1
args = " ".join(dynamic_job_spec.tasks[0].container.args)
assert args.startswith(
"pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
"--dest-dir /User/flyte/workflows"
)

assert context_manager.FlyteContextManager.size() == 1
Loading

0 comments on commit 11f8720

Please sign in to comment.