Skip to content

Commit

Permalink
Add launch condition column (#2204)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
wild-endeavor authored and fiedlerNr9 committed Jul 25, 2024
1 parent cdb82d1 commit d03886a
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 24 deletions.
12 changes: 7 additions & 5 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ def __init__(
self.time_partition = time_partition
self.partitions = partitions
self.tag = tag
self.bindings = bindings
if len(bindings) > 0:
b = set(bindings)
if len(b) > 1:
raise ValueError(f"Multiple bindings found in query {self}")
self.binding: Optional[Artifact] = bindings[0]
else:
self.binding = None

def to_flyte_idl(
self,
Expand Down Expand Up @@ -391,23 +397,19 @@ def concrete_artifact_id(self) -> art_id.ArtifactID:

def embed_as_query(
self,
bindings: typing.List[Artifact],
partition: Optional[str] = None,
bind_to_time_partition: Optional[bool] = None,
expr: Optional[str] = None,
) -> art_id.ArtifactQuery:
"""
This should only be called in the context of a Trigger
:param bindings: The list of artifacts in trigger_on
:param partition: Can embed a time partition
:param bind_to_time_partition: Set to true if you want to bind to a time partition
:param expr: Only valid if there's a time partition.
"""
# Find self in the list, raises ValueError if not there.
idx = bindings.index(self)
aq = art_id.ArtifactQuery(
binding=art_id.ArtifactBindingData(
index=idx,
partition_key=partition,
bind_to_time_partition=bind_to_time_partition,
transform=str(expr) if expr and (partition or bind_to_time_partition) else None,
Expand Down
17 changes: 13 additions & 4 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.interface import Interface, transform_function_to_interface, transform_inputs_to_parameters
from flytekit.core.promise import create_and_link_node, translate_inputs_to_literals
from flytekit.core.reference_entity import LaunchPlanReference, ReferenceEntity
from flytekit.core.schedule import LaunchPlanTriggerBase
from flytekit.models import common as _common_models
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -123,6 +124,7 @@ def create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
Expand Down Expand Up @@ -174,6 +176,7 @@ def create(
raw_output_data_config=raw_output_data_config,
max_parallelism=max_parallelism,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)

Expand Down Expand Up @@ -203,6 +206,7 @@ def get_or_create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
"""
Expand All @@ -229,6 +233,7 @@ def get_or_create(
:param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param trigger: [alpha] This is a new syntax for specifying schedules.
"""
if name is None and (
default_inputs is not None
Expand All @@ -241,6 +246,7 @@ def get_or_create(
or auth_role is not None
or max_parallelism is not None
or security_context is not None
or trigger is not None
or overwrite_cache is not None
):
raise ValueError(
Expand Down Expand Up @@ -299,6 +305,7 @@ def get_or_create(
max_parallelism,
auth_role=auth_role,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)
LaunchPlan.CACHE[name or workflow.name] = lp
Expand All @@ -317,8 +324,8 @@ def __init__(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
additional_metadata: Optional[Any] = None,
):
self._name = name
self._workflow = workflow
Expand All @@ -336,8 +343,8 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._security_context = security_context
self._trigger = trigger
self._overwrite_cache = overwrite_cache
self._additional_metadata = additional_metadata

FlyteEntities.entities.append(self)

Expand All @@ -353,6 +360,7 @@ def clone_with(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
return LaunchPlan(
Expand All @@ -367,6 +375,7 @@ def clone_with(
raw_output_data_config=raw_output_data_config or self.raw_output_data_config,
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
trigger=trigger,
overwrite_cache=overwrite_cache or self.overwrite_cache,
)

Expand Down Expand Up @@ -435,8 +444,8 @@ def security_context(self) -> Optional[security.SecurityContext]:
return self._security_context

@property
def additional_metadata(self) -> Optional[Any]:
return self._additional_metadata
def trigger(self) -> Optional[LaunchPlanTriggerBase]:
return self._trigger

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
return self.workflow.construct_node_metadata()
Expand Down
20 changes: 19 additions & 1 deletion flytekit/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@

import datetime
import re as _re
from typing import Optional
from typing import Optional, Protocol, Union

import croniter as _croniter
from flyteidl.admin import schedule_pb2
from google.protobuf import message as google_message

from flytekit.models import schedule as _schedule_models


class LaunchPlanTriggerBase(Protocol):
    """Structural (duck-typed) interface for launch-plan triggers.

    Any object that exposes a ``to_flyte_idl`` method returning a protobuf
    message satisfies this protocol and can be passed as the ``trigger``
    argument of a launch plan. Implementations may accept extra positional
    and keyword arguments (callers may pass contextual data such as the
    launch plan entity itself).
    """

    def to_flyte_idl(self, *args, **kwargs) -> google_message.Message:
        ...


# Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass.
class CronSchedule(_schedule_models.Schedule):
"""
Expand Down Expand Up @@ -202,3 +209,14 @@ def _translate_duration(duration: datetime.timedelta):
int(duration.total_seconds() / _SECONDS_TO_MINUTES),
_schedule_models.Schedule.FixedRateUnit.MINUTE,
)


class OnSchedule(LaunchPlanTriggerBase):
    """[alpha] Trigger that attaches a cron or fixed-rate schedule to a launch plan.

    Wraps a :class:`CronSchedule` or :class:`FixedRate` so it can be supplied
    via the launch plan ``trigger`` argument.
    """

    def __init__(self, schedule: Union[CronSchedule, FixedRate]):
        """
        :param Union[CronSchedule, FixedRate] schedule: Either a cron or a fixed rate
        """
        self._schedule = schedule

    def to_flyte_idl(self, *args, **kwargs) -> schedule_pb2.Schedule:
        """Serialize the wrapped schedule to its admin Schedule protobuf.

        Accepts and ignores extra positional/keyword arguments to honor the
        LaunchPlanTriggerBase protocol: callers (e.g. the serialization
        translator) invoke ``trigger.to_flyte_idl(entity)``, which would
        otherwise raise TypeError against a zero-argument signature.
        """
        return self._schedule.to_flyte_idl()
22 changes: 11 additions & 11 deletions flytekit/models/schedule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from flyteidl.admin import schedule_pb2 as _schedule_pb2
from flyteidl.admin import schedule_pb2

from flytekit.models import common as _common
from flytekit.models import common


class Schedule(_common.FlyteIdlEntity):
class Schedule(common.FlyteIdlEntity):
class FixedRateUnit(object):
MINUTE = _schedule_pb2.MINUTE
HOUR = _schedule_pb2.HOUR
DAY = _schedule_pb2.DAY
MINUTE = schedule_pb2.MINUTE
HOUR = schedule_pb2.HOUR
DAY = schedule_pb2.DAY

@classmethod
def enum_to_string(cls, int_value):
Expand All @@ -24,7 +24,7 @@ def enum_to_string(cls, int_value):
else:
return "{}".format(int_value)

class FixedRate(_common.FlyteIdlEntity):
class FixedRate(common.FlyteIdlEntity):
def __init__(self, value, unit):
"""
:param int value:
Expand All @@ -51,7 +51,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.FixedRate
"""
return _schedule_pb2.FixedRate(value=self.value, unit=self.unit)
return schedule_pb2.FixedRate(value=self.value, unit=self.unit)

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand All @@ -61,7 +61,7 @@ def from_flyte_idl(cls, pb2_object):
"""
return cls(pb2_object.value, pb2_object.unit)

class CronSchedule(_common.FlyteIdlEntity):
class CronSchedule(common.FlyteIdlEntity):
def __init__(self, schedule, offset):
"""
:param Text schedule: cron expression or aliases
Expand All @@ -88,7 +88,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.FixedRate
"""
return _schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)
return schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand Down Expand Up @@ -145,7 +145,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.Schedule
"""
return _schedule_pb2.Schedule(
return schedule_pb2.Schedule(
kickoff_time_input_arg=self.kickoff_time_input_arg,
cron_expression=self.cron_expression,
rate=self.rate.to_flyte_idl() if self.rate is not None else None,
Expand Down
11 changes: 10 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit.core import constants as _common_constants
Expand Down Expand Up @@ -368,12 +370,19 @@ def get_serializable_launch_plan(
else:
raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("")

if entity.trigger:
lc = entity.trigger.to_flyte_idl(entity)
if isinstance(lc, schedule_pb2.Schedule):
raise ValueError("Please continue to use the schedule arg, the trigger arg is not implemented yet")
else:
lc = None

lps = _launch_plan_models.LaunchPlanSpec(
workflow_id=wf_id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(
schedule=entity.schedule,
notifications=options.notifications or entity.notifications,
launch_conditions=entity.additional_metadata,
launch_conditions=lc,
),
default_inputs=entity.parameters,
fixed_inputs=entity.fixed_inputs,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0,<7.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.10.7",
"flyteidl>=1.11.0b0",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
24 changes: 23 additions & 1 deletion tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_query_basic():
partition_keys=["region"],
)
data_query = aa.query(time_partition=Inputs.dt, region=Inputs.blah)
assert data_query.bindings == []
assert data_query.binding is None
assert data_query.artifact is aa
dq_idl = data_query.to_flyte_idl()
assert dq_idl.HasField("artifact_id")
Expand Down Expand Up @@ -271,6 +271,28 @@ def wf2(a: CustomReturn = wf_artifact):
assert aq.artifact_id.partitions.value["region"].static_value == "LAX"


def test_query_basic_query_bindings():
    # Note these bindings don't really work yet.
    def _time_partitioned(name: str) -> Artifact:
        # All three artifacts share the same shape: time-partitioned + region key.
        return Artifact(name=name, time_partitioned=True, partition_keys=["region"])

    ride_counts = _time_partitioned("ride_count_data")
    drivers = _time_partitioned("driver_data")
    passengers = _time_partitioned("passenger_data")

    # Binding a query to a single other artifact is accepted.
    ride_counts.query(time_partition=Inputs.dt, region=drivers.partitions.region)
    # Binding to two distinct artifacts in one query must be rejected.
    with pytest.raises(ValueError):
        ride_counts.query(time_partition=passengers.time_partition, region=drivers.partitions.region)


def test_partition_none():
# confirm that we can distinguish between partitions being set to empty, and not being set
# though this is not currently used.
Expand Down

0 comments on commit d03886a

Please sign in to comment.