From 90324e07f06da2fc82bc52f5db2442411fcdf42f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 29 Feb 2024 11:16:15 -0800 Subject: [PATCH] Add launch condition column (#2204) Signed-off-by: Yee Hing Tong --- flytekit/core/artifact.py | 12 ++++++----- flytekit/core/launch_plan.py | 17 +++++++++++---- flytekit/core/schedule.py | 20 +++++++++++++++++- flytekit/models/schedule.py | 22 ++++++++++---------- flytekit/tools/translator.py | 11 +++++++++- pyproject.toml | 2 +- tests/flytekit/unit/core/test_artifacts.py | 24 +++++++++++++++++++++- 7 files changed, 84 insertions(+), 24 deletions(-) diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 6c709f59a1..27d16b4822 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -147,7 +147,13 @@ def __init__( self.time_partition = time_partition self.partitions = partitions self.tag = tag - self.bindings = bindings + if len(bindings) > 0: + b = set(bindings) + if len(b) > 1: + raise ValueError(f"Multiple bindings found in query {self}") + self.binding: Optional[Artifact] = bindings[0] + else: + self.binding = None def to_flyte_idl( self, @@ -391,23 +397,19 @@ def concrete_artifact_id(self) -> art_id.ArtifactID: def embed_as_query( self, - bindings: typing.List[Artifact], partition: Optional[str] = None, bind_to_time_partition: Optional[bool] = None, expr: Optional[str] = None, ) -> art_id.ArtifactQuery: """ This should only be called in the context of a Trigger - :param bindings: The list of artifacts in trigger_on :param partition: Can embed a time partition :param bind_to_time_partition: Set to true if you want to bind to a time partition :param expr: Only valid if there's a time partition. """ # Find self in the list, raises ValueError if not there. - idx = bindings.index(self) aq = art_id.ArtifactQuery( binding=art_id.ArtifactBindingData( - index=idx, partition_key=partition, bind_to_time_partition=bind_to_time_partition, transform=str(expr) if expr and (partition or bind_to_time_partition) else None, diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 7f45287428..0b097ad847 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -8,6 +8,7 @@ from flytekit.core.interface import Interface, transform_function_to_interface, transform_inputs_to_parameters from flytekit.core.promise import create_and_link_node, translate_inputs_to_literals from flytekit.core.reference_entity import LaunchPlanReference, ReferenceEntity +from flytekit.core.schedule import LaunchPlanTriggerBase from flytekit.models import common as _common_models from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models @@ -123,6 +124,7 @@ def create( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, auth_role: Optional[_common_models.AuthRole] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() @@ -174,6 +176,7 @@ def create( raw_output_data_config=raw_output_data_config, max_parallelism=max_parallelism, security_context=security_context, + trigger=trigger, overwrite_cache=overwrite_cache, ) @@ -203,6 +206,7 @@ def get_or_create( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, auth_role: Optional[_common_models.AuthRole] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: """ @@ -229,6 +233,7 @@ def get_or_create( :param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and parallelism/concurrency of MapTasks is independent from this. + :param trigger: [alpha] This is a new syntax for specifying schedules. """ if name is None and ( default_inputs is not None @@ -241,6 +246,7 @@ def get_or_create( or auth_role is not None or max_parallelism is not None or security_context is not None + or trigger is not None or overwrite_cache is not None ): raise ValueError( @@ -299,6 +305,7 @@ def get_or_create( max_parallelism, auth_role=auth_role, security_context=security_context, + trigger=trigger, overwrite_cache=overwrite_cache, ) LaunchPlan.CACHE[name or workflow.name] = lp @@ -317,8 +324,8 @@ def __init__( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, - additional_metadata: Optional[Any] = None, ): self._name = name self._workflow = workflow @@ -336,8 +343,8 @@ def __init__( self._raw_output_data_config = raw_output_data_config self._max_parallelism = max_parallelism self._security_context = security_context + self._trigger = trigger self._overwrite_cache = overwrite_cache - self._additional_metadata = additional_metadata FlyteEntities.entities.append(self) @@ -353,6 +360,7 @@ def clone_with( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: return LaunchPlan( @@ -367,6 +375,7 @@ def clone_with( raw_output_data_config=raw_output_data_config or self.raw_output_data_config, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, + trigger=trigger, overwrite_cache=overwrite_cache or self.overwrite_cache, ) @@ -435,8 +444,8 @@ def security_context(self) -> Optional[security.SecurityContext]: return self._security_context @property - def additional_metadata(self) -> Optional[Any]: - return self._additional_metadata + def trigger(self) -> Optional[LaunchPlanTriggerBase]: + return self._trigger def construct_node_metadata(self) -> _workflow_model.NodeMetadata: return self.workflow.construct_node_metadata() diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 93116d0720..5ce0948cfd 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,13 +6,20 @@ import datetime import re as _re -from typing import Optional +from typing import Optional, Protocol, Union import croniter as _croniter +from flyteidl.admin import schedule_pb2 +from google.protobuf import message as google_message from flytekit.models import schedule as _schedule_models +class LaunchPlanTriggerBase(Protocol): + def to_flyte_idl(self, *args, **kwargs) -> google_message.Message: + ... + + # Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass. class CronSchedule(_schedule_models.Schedule): """ @@ -202,3 +209,14 @@ def _translate_duration(duration: datetime.timedelta): int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, ) + + +class OnSchedule(LaunchPlanTriggerBase): + def __init__(self, schedule: Union[CronSchedule, FixedRate]): + """ + :param Union[CronSchedule, FixedRate] schedule: Either a cron or a fixed rate + """ + self._schedule = schedule + + def to_flyte_idl(self) -> schedule_pb2.Schedule: + return self._schedule.to_flyte_idl() diff --git a/flytekit/models/schedule.py b/flytekit/models/schedule.py index a6be2a58ee..65d3f477ac 100644 --- a/flytekit/models/schedule.py +++ b/flytekit/models/schedule.py @@ -1,13 +1,13 @@ -from flyteidl.admin import schedule_pb2 as _schedule_pb2 +from flyteidl.admin import schedule_pb2 -from flytekit.models import common as _common +from flytekit.models import common -class Schedule(_common.FlyteIdlEntity): +class Schedule(common.FlyteIdlEntity): class FixedRateUnit(object): - MINUTE = _schedule_pb2.MINUTE - HOUR = _schedule_pb2.HOUR - DAY = _schedule_pb2.DAY + MINUTE = schedule_pb2.MINUTE + HOUR = schedule_pb2.HOUR + DAY = schedule_pb2.DAY @classmethod def enum_to_string(cls, int_value): @@ -24,7 +24,7 @@ def enum_to_string(cls, int_value): else: return "{}".format(int_value) - class FixedRate(_common.FlyteIdlEntity): + class FixedRate(common.FlyteIdlEntity): def __init__(self, value, unit): """ :param int value: @@ -51,7 +51,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.schedule_pb2.FixedRate """ - return _schedule_pb2.FixedRate(value=self.value, unit=self.unit) + return schedule_pb2.FixedRate(value=self.value, unit=self.unit) @classmethod def from_flyte_idl(cls, pb2_object): @@ -61,7 +61,7 @@ def from_flyte_idl(cls, pb2_object): """ return cls(pb2_object.value, pb2_object.unit) - class CronSchedule(_common.FlyteIdlEntity): + class CronSchedule(common.FlyteIdlEntity): def __init__(self, schedule, offset): """ :param Text schedule: cron expression or aliases @@ -88,7 +88,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.schedule_pb2.FixedRate """ - return _schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset) + return schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset) @classmethod def from_flyte_idl(cls, pb2_object): @@ -145,7 +145,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.schedule_pb2.Schedule """ - return _schedule_pb2.Schedule( + return schedule_pb2.Schedule( kickoff_time_input_arg=self.kickoff_time_input_arg, cron_expression=self.cron_expression, rate=self.rate.to_flyte_idl() if self.rate is not None else None, diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 6d696bc4d6..7bc719cef8 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union +from flyteidl.admin import schedule_pb2 + from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants @@ -368,12 +370,19 @@ def get_serializable_launch_plan( else: raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("") + if entity.trigger: + lc = entity.trigger.to_flyte_idl(entity) + if isinstance(lc, schedule_pb2.Schedule): + raise ValueError("Please continue to use the schedule arg, the trigger arg is not implemented yet") + else: + lc = None + lps = _launch_plan_models.LaunchPlanSpec( workflow_id=wf_id, entity_metadata=_launch_plan_models.LaunchPlanMetadata( schedule=entity.schedule, notifications=options.notifications or entity.notifications, - launch_conditions=entity.additional_metadata, + launch_conditions=lc, ), default_inputs=entity.parameters, fixed_inputs=entity.fixed_inputs, diff --git a/pyproject.toml b/pyproject.toml index 1f83e2b4e3..c4b6c03e97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.10.7", + "flyteidl>=1.11.0b0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index abd9e456c2..b3fc5b5f64 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -194,7 +194,7 @@ def test_query_basic(): partition_keys=["region"], ) data_query = aa.query(time_partition=Inputs.dt, region=Inputs.blah) - assert data_query.bindings == [] + assert data_query.binding is None assert data_query.artifact is aa dq_idl = data_query.to_flyte_idl() assert dq_idl.HasField("artifact_id") @@ -271,6 +271,28 @@ def wf2(a: CustomReturn = wf_artifact): assert aq.artifact_id.partitions.value["region"].static_value == "LAX" +def test_query_basic_query_bindings(): + # Note these bindings don't really work yet. + aa = Artifact( + name="ride_count_data", + time_partitioned=True, + partition_keys=["region"], + ) + bb = Artifact( + name="driver_data", + time_partitioned=True, + partition_keys=["region"], + ) + cc = Artifact( + name="passenger_data", + time_partitioned=True, + partition_keys=["region"], + ) + aa.query(time_partition=Inputs.dt, region=bb.partitions.region) + with pytest.raises(ValueError): + aa.query(time_partition=cc.time_partition, region=bb.partitions.region) + + def test_partition_none(): # confirm that we can distinguish between partitions being set to empty, and not being set # though this is not currently used.