Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add launch condition column #2204

Merged
merged 11 commits into from
Feb 29, 2024
12 changes: 7 additions & 5 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ def __init__(
self.time_partition = time_partition
self.partitions = partitions
self.tag = tag
self.bindings = bindings
if len(bindings) > 0:
b = set(bindings)
if len(b) > 1:
raise ValueError(f"Multiple bindings found in query {self}")
self.binding: Optional[Artifact] = bindings[0]
else:
self.binding = None

def to_flyte_idl(
self,
Expand Down Expand Up @@ -391,23 +397,19 @@ def concrete_artifact_id(self) -> art_id.ArtifactID:

def embed_as_query(
self,
bindings: typing.List[Artifact],
partition: Optional[str] = None,
bind_to_time_partition: Optional[bool] = None,
expr: Optional[str] = None,
) -> art_id.ArtifactQuery:
"""
This should only be called in the context of a Trigger
:param bindings: The list of artifacts in trigger_on
:param partition: Can embed a time partition
:param bind_to_time_partition: Set to true if you want to bind to a time partition
:param expr: Only valid if there's a time partition.
"""
# Find self in the list, raises ValueError if not there.
idx = bindings.index(self)
aq = art_id.ArtifactQuery(
binding=art_id.ArtifactBindingData(
index=idx,
partition_key=partition,
bind_to_time_partition=bind_to_time_partition,
transform=str(expr) if expr and (partition or bind_to_time_partition) else None,
Expand Down
17 changes: 13 additions & 4 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.interface import Interface, transform_function_to_interface, transform_inputs_to_parameters
from flytekit.core.promise import create_and_link_node, translate_inputs_to_literals
from flytekit.core.reference_entity import LaunchPlanReference, ReferenceEntity
from flytekit.core.schedule import LaunchPlanTriggerBase
from flytekit.models import common as _common_models
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -123,6 +124,7 @@ def create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
Expand Down Expand Up @@ -174,6 +176,7 @@ def create(
raw_output_data_config=raw_output_data_config,
max_parallelism=max_parallelism,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)

Expand Down Expand Up @@ -203,6 +206,7 @@ def get_or_create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
"""
Expand All @@ -229,6 +233,7 @@ def get_or_create(
:param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param trigger: [alpha] This is a new syntax for specifying schedules.
"""
if name is None and (
default_inputs is not None
Expand All @@ -241,6 +246,7 @@ def get_or_create(
or auth_role is not None
or max_parallelism is not None
or security_context is not None
or trigger is not None
or overwrite_cache is not None
):
raise ValueError(
Expand Down Expand Up @@ -299,6 +305,7 @@ def get_or_create(
max_parallelism,
auth_role=auth_role,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)
LaunchPlan.CACHE[name or workflow.name] = lp
Expand All @@ -317,8 +324,8 @@ def __init__(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
additional_metadata: Optional[Any] = None,
):
self._name = name
self._workflow = workflow
Expand All @@ -336,8 +343,8 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._security_context = security_context
self._trigger = trigger
self._overwrite_cache = overwrite_cache
self._additional_metadata = additional_metadata

FlyteEntities.entities.append(self)

Expand All @@ -353,6 +360,7 @@ def clone_with(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
return LaunchPlan(
Expand All @@ -367,6 +375,7 @@ def clone_with(
raw_output_data_config=raw_output_data_config or self.raw_output_data_config,
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
trigger=trigger,
overwrite_cache=overwrite_cache or self.overwrite_cache,
)

Expand Down Expand Up @@ -435,8 +444,8 @@ def security_context(self) -> Optional[security.SecurityContext]:
return self._security_context

@property
def additional_metadata(self) -> Optional[Any]:
return self._additional_metadata
def trigger(self) -> Optional[LaunchPlanTriggerBase]:
return self._trigger

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
return self.workflow.construct_node_metadata()
Expand Down
20 changes: 19 additions & 1 deletion flytekit/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@

import datetime
import re as _re
from typing import Optional
from typing import Optional, Protocol, Union

import croniter as _croniter
from flyteidl.admin import schedule_pb2
from google.protobuf import message as google_message

from flytekit.models import schedule as _schedule_models


class LaunchPlanTriggerBase(Protocol):
    """Structural interface for launch-plan triggers.

    Any object exposing a ``to_flyte_idl`` method that returns a protobuf
    message satisfies this protocol and can be supplied as the ``trigger``
    argument when creating a LaunchPlan (e.g. ``OnSchedule`` below).
    """

    def to_flyte_idl(self, *args, **kwargs) -> google_message.Message:
        # Implementations convert themselves to their flyteidl protobuf form.
        # NOTE(review): signature is intentionally loose (*args/**kwargs) —
        # concrete triggers may accept extra context (e.g. the launch plan).
        ...

Check warning on line 20 in flytekit/core/schedule.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/schedule.py#L20

Added line #L20 was not covered by tests


# Duplicates flytekit.common.schedules.Schedule to avoid using the ExtendedSdkType metaclass.
class CronSchedule(_schedule_models.Schedule):
"""
Expand Down Expand Up @@ -202,3 +209,14 @@
int(duration.total_seconds() / _SECONDS_TO_MINUTES),
_schedule_models.Schedule.FixedRateUnit.MINUTE,
)


class OnSchedule(LaunchPlanTriggerBase):
    """Trigger that runs a launch plan on a schedule.

    Thin wrapper around either a :class:`CronSchedule` or a
    :class:`FixedRate`; protobuf serialization is delegated entirely to
    the wrapped schedule object.
    """

    def __init__(self, schedule: Union[CronSchedule, FixedRate]):
        """
        :param schedule: Either a cron schedule or a fixed rate.
        """
        self._schedule = schedule

    def to_flyte_idl(self) -> schedule_pb2.Schedule:
        """Return the flyteidl protobuf form of the wrapped schedule."""
        return self._schedule.to_flyte_idl()

Check warning on line 222 in flytekit/core/schedule.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/schedule.py#L222

Added line #L222 was not covered by tests
22 changes: 11 additions & 11 deletions flytekit/models/schedule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from flyteidl.admin import schedule_pb2 as _schedule_pb2
from flyteidl.admin import schedule_pb2

from flytekit.models import common as _common
from flytekit.models import common


class Schedule(_common.FlyteIdlEntity):
class Schedule(common.FlyteIdlEntity):
class FixedRateUnit(object):
MINUTE = _schedule_pb2.MINUTE
HOUR = _schedule_pb2.HOUR
DAY = _schedule_pb2.DAY
MINUTE = schedule_pb2.MINUTE
HOUR = schedule_pb2.HOUR
DAY = schedule_pb2.DAY

@classmethod
def enum_to_string(cls, int_value):
Expand All @@ -24,7 +24,7 @@ def enum_to_string(cls, int_value):
else:
return "{}".format(int_value)

class FixedRate(_common.FlyteIdlEntity):
class FixedRate(common.FlyteIdlEntity):
def __init__(self, value, unit):
"""
:param int value:
Expand All @@ -51,7 +51,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.FixedRate
"""
return _schedule_pb2.FixedRate(value=self.value, unit=self.unit)
return schedule_pb2.FixedRate(value=self.value, unit=self.unit)

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand All @@ -61,7 +61,7 @@ def from_flyte_idl(cls, pb2_object):
"""
return cls(pb2_object.value, pb2_object.unit)

class CronSchedule(_common.FlyteIdlEntity):
class CronSchedule(common.FlyteIdlEntity):
def __init__(self, schedule, offset):
"""
:param Text schedule: cron expression or aliases
Expand All @@ -88,7 +88,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.FixedRate
"""
return _schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)
return schedule_pb2.CronSchedule(schedule=self.schedule, offset=self.offset)

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand Down Expand Up @@ -145,7 +145,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.schedule_pb2.Schedule
"""
return _schedule_pb2.Schedule(
return schedule_pb2.Schedule(
kickoff_time_input_arg=self.kickoff_time_input_arg,
cron_expression=self.cron_expression,
rate=self.rate.to_flyte_idl() if self.rate is not None else None,
Expand Down
11 changes: 10 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit.core import constants as _common_constants
Expand Down Expand Up @@ -368,12 +370,19 @@
else:
raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("")

if entity.trigger:
lc = entity.trigger.to_flyte_idl(entity)

Check warning on line 374 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L374

Added line #L374 was not covered by tests
if isinstance(lc, schedule_pb2.Schedule):
raise ValueError("Please continue to use the schedule arg, the trigger arg is not implemented yet")

Check warning on line 376 in flytekit/tools/translator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/translator.py#L376

Added line #L376 was not covered by tests
else:
lc = None

lps = _launch_plan_models.LaunchPlanSpec(
workflow_id=wf_id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(
schedule=entity.schedule,
notifications=options.notifications or entity.notifications,
launch_conditions=entity.additional_metadata,
launch_conditions=lc,
),
default_inputs=entity.parameters,
fixed_inputs=entity.fixed_inputs,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0,<7.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.10.7",
"flyteidl>=1.11.0b0",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
24 changes: 23 additions & 1 deletion tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def test_query_basic():
partition_keys=["region"],
)
data_query = aa.query(time_partition=Inputs.dt, region=Inputs.blah)
assert data_query.bindings == []
assert data_query.binding is None
assert data_query.artifact is aa
dq_idl = data_query.to_flyte_idl()
assert dq_idl.HasField("artifact_id")
Expand Down Expand Up @@ -271,6 +271,28 @@ def wf2(a: CustomReturn = wf_artifact):
assert aq.artifact_id.partitions.value["region"].static_value == "LAX"


def test_query_basic_query_bindings():
    # Note these bindings don't really work yet.
    ride_counts, drivers, passengers = (
        Artifact(
            name=artifact_name,
            time_partitioned=True,
            partition_keys=["region"],
        )
        for artifact_name in ("ride_count_data", "driver_data", "passenger_data")
    )
    # Binding a query to a single other artifact is accepted.
    ride_counts.query(time_partition=Inputs.dt, region=drivers.partitions.region)
    # Binding to two different artifacts in one query must raise.
    with pytest.raises(ValueError):
        ride_counts.query(time_partition=passengers.time_partition, region=drivers.partitions.region)


def test_partition_none():
# confirm that we can distinguish between partitions being set to empty, and not being set
# though this is not currently used.
Expand Down
Loading