Skip to content

Commit

Permalink
Allow task runs without a flow run (#10816)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Streed <[email protected]>
  • Loading branch information
zangell44 and desertaxle authored Sep 28, 2023
1 parent 6fe0e35 commit 6b6482b
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 67 deletions.
4 changes: 2 additions & 2 deletions src/prefect/client/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class TaskRunCreate(ActionBaseModel):
)

name: str = FieldFrom(objects.TaskRun)
flow_run_id: UUID = FieldFrom(objects.TaskRun)
flow_run_id: Optional[UUID] = FieldFrom(objects.TaskRun)
task_key: str = FieldFrom(objects.TaskRun)
dynamic_key: str = FieldFrom(objects.TaskRun)
cache_key: Optional[str] = FieldFrom(objects.TaskRun)
Expand Down Expand Up @@ -461,7 +461,7 @@ class LogCreate(ActionBaseModel):
level: int = FieldFrom(objects.Log)
message: str = FieldFrom(objects.Log)
timestamp: objects.DateTimeTZ = FieldFrom(objects.Log)
flow_run_id: UUID = FieldFrom(objects.Log)
flow_run_id: Optional[UUID] = FieldFrom(objects.Log)
task_run_id: Optional[UUID] = FieldFrom(objects.Log)


Expand Down
8 changes: 4 additions & 4 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,8 @@ class Constant(TaskRunInput):

class TaskRun(ObjectBaseModel):
name: str = Field(default_factory=lambda: generate_slug(2), example="my-task-run")
flow_run_id: UUID = Field(
default=..., description="The flow run id of the task run."
flow_run_id: Optional[UUID] = Field(
default=None, description="The flow run id of the task run."
)
task_key: str = Field(
default=..., description="A unique identifier for the task being run."
Expand Down Expand Up @@ -1132,8 +1132,8 @@ class Log(ObjectBaseModel):
level: int = Field(default=..., description="The log level.")
message: str = Field(default=..., description="The log message.")
timestamp: DateTimeTZ = Field(default=..., description="The log timestamp.")
flow_run_id: UUID = Field(
default=..., description="The flow run ID associated with the log."
flow_run_id: Optional[UUID] = Field(
default=None, description="The flow run ID associated with the log."
)
task_run_id: Optional[UUID] = Field(
default=None, description="The task run ID associated with the log."
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Make flow_run_id nullable on task_run and log tables
Revision ID: 05ea6f882b1d
Revises: 4e9a6f93eb6c
Create Date: 2023-09-25 12:18:06.722322
"""
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's ID; `down_revision` is the migration it
# revises (matches the "Revises:" line in the module docstring).
revision = "05ea6f882b1d"
down_revision = "4e9a6f93eb6c"
# No branch labels or cross-revision dependencies are declared.
branch_labels = None
depends_on = None


def upgrade():
    """Make ``flow_run_id`` nullable on the ``task_run`` and ``log`` tables."""
    # task_run is altered first, then log; each table gets its own batch
    # context so the column change is applied per table.
    for table_name in ("task_run", "log"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("flow_run_id", nullable=True)


def downgrade():
    """Restore the NOT NULL constraint on ``flow_run_id`` for ``task_run`` and ``log``."""
    # NOTE(review): presumably this fails if rows with a NULL flow_run_id
    # already exist (task runs or logs created without a flow run) — confirm
    # whether such rows must be cleaned up before downgrading.
    for table_name in ("task_run", "log"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("flow_run_id", nullable=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Make flow_run_id nullable on task_run and log tables
Revision ID: 05ea6f882b1d
Revises: 8167af8df781
Create Date: 2023-09-25 12:18:06.722322
"""
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's ID; `down_revision` is the migration it
# revises (matches the "Revises:" line in the module docstring).
revision = "05ea6f882b1d"
down_revision = "8167af8df781"
# No branch labels or cross-revision dependencies are declared.
branch_labels = None
depends_on = None


def upgrade():
    """Make ``flow_run_id`` nullable on the ``task_run`` and ``log`` tables."""
    # task_run is altered first, then log; each table gets its own batch
    # context so the column change is applied per table.
    for table_name in ("task_run", "log"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("flow_run_id", nullable=True)


def downgrade():
    """Restore the NOT NULL constraint on ``flow_run_id`` for ``task_run`` and ``log``."""
    # NOTE(review): presumably this fails if rows with a NULL flow_run_id
    # already exist (task runs or logs created without a flow run) — confirm
    # whether such rows must be cleaned up before downgrading.
    for table_name in ("task_run", "log"):
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.alter_column("flow_run_id", nullable=False)
2 changes: 1 addition & 1 deletion src/prefect/server/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def flow_run_id(cls):
return sa.Column(
UUID(),
sa.ForeignKey("flow_run.id", ondelete="cascade"),
nullable=False,
nullable=True,
index=True,
)

Expand Down
77 changes: 53 additions & 24 deletions src/prefect/server/models/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,63 @@ async def create_task_run(
now = pendulum.now("UTC")

# if a dynamic key exists, we need to guard against conflicts
insert_stmt = (
(await db.insert(db.TaskRun))
.values(
created=now,
**task_run.dict(
shallow=True, exclude={"state", "created"}, exclude_unset=True
),
if task_run.flow_run_id:
insert_stmt = (
(await db.insert(db.TaskRun))
.values(
created=now,
**task_run.dict(
shallow=True, exclude={"state", "created"}, exclude_unset=True
),
)
.on_conflict_do_nothing(
index_elements=db.task_run_unique_upsert_columns,
)
)
.on_conflict_do_nothing(
index_elements=db.task_run_unique_upsert_columns,
await session.execute(insert_stmt)

query = (
sa.select(db.TaskRun)
.where(
sa.and_(
db.TaskRun.flow_run_id == task_run.flow_run_id,
db.TaskRun.task_key == task_run.task_key,
db.TaskRun.dynamic_key == task_run.dynamic_key,
)
)
.limit(1)
.execution_options(populate_existing=True)
)
)
await session.execute(insert_stmt)

query = (
sa.select(db.TaskRun)
.where(
sa.and_(
db.TaskRun.flow_run_id == task_run.flow_run_id,
db.TaskRun.task_key == task_run.task_key,
db.TaskRun.dynamic_key == task_run.dynamic_key,
result = await session.execute(query)
model = result.scalar()
else:
# Upsert on (task_key, dynamic_key) application logic.
query = (
sa.select(db.TaskRun)
.where(
sa.and_(
db.TaskRun.flow_run_id.is_(None),
db.TaskRun.task_key == task_run.task_key,
db.TaskRun.dynamic_key == task_run.dynamic_key,
)
)
.limit(1)
.execution_options(populate_existing=True)
)
.limit(1)
.execution_options(populate_existing=True)
)
result = await session.execute(query)
model = result.scalar()

result = await session.execute(query)
model = result.scalar()

if model is None:
model = db.TaskRun(
created=now,
**task_run.dict(
shallow=True, exclude={"state", "created"}, exclude_unset=True
),
state=None,
)
session.add(model)
await session.flush()

if model.created == now and task_run.state:
await models.task_runs.set_task_run_state(
Expand Down
60 changes: 32 additions & 28 deletions src/prefect/server/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,16 +596,17 @@ async def after_transition(
validated_state: Optional[states.State],
context: TaskOrchestrationContext,
) -> None:
self.flow_run = await context.flow_run()
if self.flow_run:
context.run.flow_run_run_count = self.flow_run.run_count
else:
raise ObjectNotFoundError(
(
"Unable to read flow run associated with task run:"
f" {context.run.id}, this flow run might have been deleted"
),
)
if context.run.flow_run_id is not None:
self.flow_run = await context.flow_run()
if self.flow_run:
context.run.flow_run_run_count = self.flow_run.run_count
else:
raise ObjectNotFoundError(
(
"Unable to read flow run associated with task run:"
f" {context.run.id}, this flow run might have been deleted"
),
)


class HandleTaskTerminalStateTransitions(BaseOrchestrationRule):
Expand Down Expand Up @@ -652,7 +653,7 @@ async def before_transition(
context.run.run_count = 0

# Change the name of the state to retrying if its a flow run retry
if proposed_state.is_running():
if proposed_state.is_running() and context.run.flow_run_id is not None:
self.flow_run = await context.flow_run()
flow_retrying = context.run.flow_run_run_count < self.flow_run.run_count
if flow_retrying:
Expand Down Expand Up @@ -799,23 +800,26 @@ async def before_transition(
context: TaskOrchestrationContext,
) -> None:
flow_run = await context.flow_run()
if flow_run.state is None:
await self.abort_transition(
reason="The enclosing flow must be running to begin task execution."
)
elif flow_run.state.type == StateType.PAUSED:
await self.reject_transition(
state=states.Paused(name="NotReady"),
reason=(
"The flow is paused, new tasks can execute after resuming flow"
f" run: {flow_run.id}."
),
)
elif not flow_run.state.type == StateType.RUNNING:
# task runners should abort task run execution
await self.abort_transition(
reason="The enclosing flow must be running to begin task execution.",
)
if flow_run is not None:
if flow_run.state is None:
await self.abort_transition(
reason="The enclosing flow must be running to begin task execution."
)
elif flow_run.state.type == StateType.PAUSED:
await self.reject_transition(
state=states.Paused(name="NotReady"),
reason=(
"The flow is paused, new tasks can execute after resuming flow"
f" run: {flow_run.id}."
),
)
elif not flow_run.state.type == StateType.RUNNING:
# task runners should abort task run execution
await self.abort_transition(
reason=(
"The enclosing flow must be running to begin task execution."
),
)


class EnforceCancellingToCancelledTransition(BaseOrchestrationRule):
Expand Down
5 changes: 3 additions & 2 deletions src/prefect/server/orchestration/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,9 @@ async def _validate_proposed_state(
state_result_artifact = core.Artifact.from_result(state_data)
state_result_artifact.task_run_id = self.run.id

flow_run = await self.flow_run()
state_result_artifact.flow_run_id = flow_run.id
if self.run.flow_run_id is not None:
flow_run = await self.flow_run()
state_result_artifact.flow_run_id = flow_run.id

await artifacts.create_artifact(self.session, state_result_artifact)
state_payload["result_artifact_id"] = state_result_artifact.id
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/server/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class TaskRunCreate(ActionBaseModel):
)

name: str = FieldFrom(schemas.core.TaskRun)
flow_run_id: UUID = FieldFrom(schemas.core.TaskRun)
flow_run_id: Optional[UUID] = FieldFrom(schemas.core.TaskRun)
task_key: str = FieldFrom(schemas.core.TaskRun)
dynamic_key: str = FieldFrom(schemas.core.TaskRun)
cache_key: Optional[str] = FieldFrom(schemas.core.TaskRun)
Expand Down Expand Up @@ -551,7 +551,7 @@ class LogCreate(ActionBaseModel):
level: int = FieldFrom(schemas.core.Log)
message: str = FieldFrom(schemas.core.Log)
timestamp: schemas.core.DateTimeTZ = FieldFrom(schemas.core.Log)
flow_run_id: UUID = FieldFrom(schemas.core.Log)
flow_run_id: Optional[UUID] = FieldFrom(schemas.core.Log)
task_run_id: Optional[UUID] = FieldFrom(schemas.core.Log)


Expand Down
8 changes: 4 additions & 4 deletions src/prefect/server/schemas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ class TaskRun(ORMBaseModel):
"""An ORM representation of task run data."""

name: str = Field(default_factory=lambda: generate_slug(2), example="my-task-run")
flow_run_id: UUID = Field(
default=..., description="The flow run id of the task run."
flow_run_id: Optional[UUID] = Field(
default=None, description="The flow run id of the task run."
)
task_key: str = Field(
default=..., description="A unique identifier for the task being run."
Expand Down Expand Up @@ -859,8 +859,8 @@ class Log(ORMBaseModel):
level: int = Field(default=..., description="The log level.")
message: str = Field(default=..., description="The log message.")
timestamp: DateTimeTZ = Field(default=..., description="The log timestamp.")
flow_run_id: UUID = Field(
default=..., description="The flow run ID associated with the log."
flow_run_id: Optional[UUID] = Field(
default=None, description="The flow run ID associated with the log."
)
task_run_id: Optional[UUID] = Field(
default=None, description="The task run ID associated with the log."
Expand Down
Loading

0 comments on commit 6b6482b

Please sign in to comment.