diff --git a/src/prefect/client/schemas/actions.py b/src/prefect/client/schemas/actions.py index 1188e292662f..01ca752b6153 100644 --- a/src/prefect/client/schemas/actions.py +++ b/src/prefect/client/schemas/actions.py @@ -281,7 +281,7 @@ class TaskRunCreate(ActionBaseModel): ) name: str = FieldFrom(objects.TaskRun) - flow_run_id: UUID = FieldFrom(objects.TaskRun) + flow_run_id: Optional[UUID] = FieldFrom(objects.TaskRun) task_key: str = FieldFrom(objects.TaskRun) dynamic_key: str = FieldFrom(objects.TaskRun) cache_key: Optional[str] = FieldFrom(objects.TaskRun) @@ -461,7 +461,7 @@ class LogCreate(ActionBaseModel): level: int = FieldFrom(objects.Log) message: str = FieldFrom(objects.Log) timestamp: objects.DateTimeTZ = FieldFrom(objects.Log) - flow_run_id: UUID = FieldFrom(objects.Log) + flow_run_id: Optional[UUID] = FieldFrom(objects.Log) task_run_id: Optional[UUID] = FieldFrom(objects.Log) diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 150d6ed8039f..fa4e90871ce6 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -620,8 +620,8 @@ class Constant(TaskRunInput): class TaskRun(ObjectBaseModel): name: str = Field(default_factory=lambda: generate_slug(2), example="my-task-run") - flow_run_id: UUID = Field( - default=..., description="The flow run id of the task run." + flow_run_id: Optional[UUID] = Field( + default=None, description="The flow run id of the task run." ) task_key: str = Field( default=..., description="A unique identifier for the task being run." @@ -1132,8 +1132,8 @@ class Log(ObjectBaseModel): level: int = Field(default=..., description="The log level.") message: str = Field(default=..., description="The log message.") timestamp: DateTimeTZ = Field(default=..., description="The log timestamp.") - flow_run_id: UUID = Field( - default=..., description="The flow run ID associated with the log." 
+ flow_run_id: Optional[UUID] = Field( + default=None, description="The flow run ID associated with the log." ) task_run_id: Optional[UUID] = Field( default=None, description="The task run ID associated with the log." diff --git a/src/prefect/server/database/migrations/versions/postgresql/2023_09_25_121806_05ea6f882b1d_remove_flow_run_id_requirement_from_task_run.py b/src/prefect/server/database/migrations/versions/postgresql/2023_09_25_121806_05ea6f882b1d_remove_flow_run_id_requirement_from_task_run.py new file mode 100644 index 000000000000..243b41e5fc85 --- /dev/null +++ b/src/prefect/server/database/migrations/versions/postgresql/2023_09_25_121806_05ea6f882b1d_remove_flow_run_id_requirement_from_task_run.py @@ -0,0 +1,30 @@ +"""Make flow_run_id nullable on task_run and log tables + +Revision ID: 05ea6f882b1d +Revises: 4e9a6f93eb6c +Create Date: 2023-09-25 12:18:06.722322 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "05ea6f882b1d" +down_revision = "4e9a6f93eb6c" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("task_run") as batch_op: + batch_op.alter_column("flow_run_id", nullable=True) + + with op.batch_alter_table("log") as batch_op: + batch_op.alter_column("flow_run_id", nullable=True) + + +def downgrade(): + with op.batch_alter_table("task_run") as batch_op: + batch_op.alter_column("flow_run_id", nullable=False) + + with op.batch_alter_table("log") as batch_op: + batch_op.alter_column("flow_run_id", nullable=False) diff --git a/src/prefect/server/database/migrations/versions/sqlite/2023_09_25_121806_8167af8df781_remove_flow_run_id_requirement_from_task_run.py b/src/prefect/server/database/migrations/versions/sqlite/2023_09_25_121806_8167af8df781_remove_flow_run_id_requirement_from_task_run.py new file mode 100644 index 000000000000..510f2222dcb6 --- /dev/null +++
b/src/prefect/server/database/migrations/versions/sqlite/2023_09_25_121806_8167af8df781_remove_flow_run_id_requirement_from_task_run.py @@ -0,0 +1,30 @@ +"""Make flow_run_id nullable on task_run and log tables + +Revision ID: 05ea6f882b1d +Revises: 8167af8df781 +Create Date: 2023-09-25 12:18:06.722322 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "05ea6f882b1d" +down_revision = "8167af8df781" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("task_run") as batch_op: + batch_op.alter_column("flow_run_id", nullable=True) + + with op.batch_alter_table("log") as batch_op: + batch_op.alter_column("flow_run_id", nullable=True) + + +def downgrade(): + with op.batch_alter_table("task_run") as batch_op: + batch_op.alter_column("flow_run_id", nullable=False) + + with op.batch_alter_table("log") as batch_op: + batch_op.alter_column("flow_run_id", nullable=False) diff --git a/src/prefect/server/database/orm_models.py b/src/prefect/server/database/orm_models.py index a4d48f8488c2..143e2a87ed9f 100644 --- a/src/prefect/server/database/orm_models.py +++ b/src/prefect/server/database/orm_models.py @@ -684,7 +684,7 @@ def flow_run_id(cls): return sa.Column( UUID(), sa.ForeignKey("flow_run.id", ondelete="cascade"), - nullable=False, + nullable=True, index=True, ) diff --git a/src/prefect/server/models/task_runs.py b/src/prefect/server/models/task_runs.py index e459b47392d9..1c9128e1b41e 100644 --- a/src/prefect/server/models/task_runs.py +++ b/src/prefect/server/models/task_runs.py @@ -51,34 +51,63 @@ async def create_task_run( now = pendulum.now("UTC") # if a dynamic key exists, we need to guard against conflicts - insert_stmt = ( - (await db.insert(db.TaskRun)) - .values( - created=now, - **task_run.dict( - shallow=True, exclude={"state", "created"}, exclude_unset=True - ), + if task_run.flow_run_id: + insert_stmt = ( + (await db.insert(db.TaskRun)) + .values( + created=now, + **task_run.dict( + 
shallow=True, exclude={"state", "created"}, exclude_unset=True + ), + ) + .on_conflict_do_nothing( + index_elements=db.task_run_unique_upsert_columns, + ) ) - .on_conflict_do_nothing( - index_elements=db.task_run_unique_upsert_columns, + await session.execute(insert_stmt) + + query = ( + sa.select(db.TaskRun) + .where( + sa.and_( + db.TaskRun.flow_run_id == task_run.flow_run_id, + db.TaskRun.task_key == task_run.task_key, + db.TaskRun.dynamic_key == task_run.dynamic_key, + ) + ) + .limit(1) + .execution_options(populate_existing=True) ) - ) - await session.execute(insert_stmt) - - query = ( - sa.select(db.TaskRun) - .where( - sa.and_( - db.TaskRun.flow_run_id == task_run.flow_run_id, - db.TaskRun.task_key == task_run.task_key, - db.TaskRun.dynamic_key == task_run.dynamic_key, + result = await session.execute(query) + model = result.scalar() + else: + # Upsert on (task_key, dynamic_key) application logic. + query = ( + sa.select(db.TaskRun) + .where( + sa.and_( + db.TaskRun.flow_run_id.is_(None), + db.TaskRun.task_key == task_run.task_key, + db.TaskRun.dynamic_key == task_run.dynamic_key, + ) ) + .limit(1) + .execution_options(populate_existing=True) ) - .limit(1) - .execution_options(populate_existing=True) - ) - result = await session.execute(query) - model = result.scalar() + + result = await session.execute(query) + model = result.scalar() + + if model is None: + model = db.TaskRun( + created=now, + **task_run.dict( + shallow=True, exclude={"state", "created"}, exclude_unset=True + ), + state=None, + ) + session.add(model) + await session.flush() if model.created == now and task_run.state: await models.task_runs.set_task_run_state( diff --git a/src/prefect/server/orchestration/core_policy.py b/src/prefect/server/orchestration/core_policy.py index 139570f5cc4e..299c615f519f 100644 --- a/src/prefect/server/orchestration/core_policy.py +++ b/src/prefect/server/orchestration/core_policy.py @@ -596,16 +596,17 @@ async def after_transition( validated_state: 
Optional[states.State], context: TaskOrchestrationContext, ) -> None: - self.flow_run = await context.flow_run() - if self.flow_run: - context.run.flow_run_run_count = self.flow_run.run_count - else: - raise ObjectNotFoundError( - ( - "Unable to read flow run associated with task run:" - f" {context.run.id}, this flow run might have been deleted" - ), - ) + if context.run.flow_run_id is not None: + self.flow_run = await context.flow_run() + if self.flow_run: + context.run.flow_run_run_count = self.flow_run.run_count + else: + raise ObjectNotFoundError( + ( + "Unable to read flow run associated with task run:" + f" {context.run.id}, this flow run might have been deleted" + ), + ) class HandleTaskTerminalStateTransitions(BaseOrchestrationRule): @@ -652,7 +653,7 @@ async def before_transition( context.run.run_count = 0 # Change the name of the state to retrying if its a flow run retry - if proposed_state.is_running(): + if proposed_state.is_running() and context.run.flow_run_id is not None: self.flow_run = await context.flow_run() flow_retrying = context.run.flow_run_run_count < self.flow_run.run_count if flow_retrying: @@ -799,23 +800,26 @@ async def before_transition( context: TaskOrchestrationContext, ) -> None: flow_run = await context.flow_run() - if flow_run.state is None: - await self.abort_transition( - reason="The enclosing flow must be running to begin task execution." - ) - elif flow_run.state.type == StateType.PAUSED: - await self.reject_transition( - state=states.Paused(name="NotReady"), - reason=( - "The flow is paused, new tasks can execute after resuming flow" - f" run: {flow_run.id}." 
- ), - ) - elif not flow_run.state.type == StateType.RUNNING: - # task runners should abort task run execution - await self.abort_transition( - reason="The enclosing flow must be running to begin task execution.", - ) + if flow_run is not None: + if flow_run.state is None: + await self.abort_transition( + reason="The enclosing flow must be running to begin task execution." + ) + elif flow_run.state.type == StateType.PAUSED: + await self.reject_transition( + state=states.Paused(name="NotReady"), + reason=( + "The flow is paused, new tasks can execute after resuming flow" + f" run: {flow_run.id}." + ), + ) + elif not flow_run.state.type == StateType.RUNNING: + # task runners should abort task run execution + await self.abort_transition( + reason=( + "The enclosing flow must be running to begin task execution." + ), + ) class EnforceCancellingToCancelledTransition(BaseOrchestrationRule): diff --git a/src/prefect/server/orchestration/rules.py b/src/prefect/server/orchestration/rules.py index c6df4f8abdf4..541f2db629d2 100644 --- a/src/prefect/server/orchestration/rules.py +++ b/src/prefect/server/orchestration/rules.py @@ -429,8 +429,9 @@ async def _validate_proposed_state( state_result_artifact = core.Artifact.from_result(state_data) state_result_artifact.task_run_id = self.run.id - flow_run = await self.flow_run() - state_result_artifact.flow_run_id = flow_run.id + if self.run.flow_run_id is not None: + flow_run = await self.flow_run() + state_result_artifact.flow_run_id = flow_run.id await artifacts.create_artifact(self.session, state_result_artifact) state_payload["result_artifact_id"] = state_result_artifact.id diff --git a/src/prefect/server/schemas/actions.py b/src/prefect/server/schemas/actions.py index b7bcceee4dd3..4689f0b4329c 100644 --- a/src/prefect/server/schemas/actions.py +++ b/src/prefect/server/schemas/actions.py @@ -331,7 +331,7 @@ class TaskRunCreate(ActionBaseModel): ) name: str = FieldFrom(schemas.core.TaskRun) - flow_run_id: UUID = 
FieldFrom(schemas.core.TaskRun) + flow_run_id: Optional[UUID] = FieldFrom(schemas.core.TaskRun) task_key: str = FieldFrom(schemas.core.TaskRun) dynamic_key: str = FieldFrom(schemas.core.TaskRun) cache_key: Optional[str] = FieldFrom(schemas.core.TaskRun) @@ -551,7 +551,7 @@ class LogCreate(ActionBaseModel): level: int = FieldFrom(schemas.core.Log) message: str = FieldFrom(schemas.core.Log) timestamp: schemas.core.DateTimeTZ = FieldFrom(schemas.core.Log) - flow_run_id: UUID = FieldFrom(schemas.core.Log) + flow_run_id: Optional[UUID] = FieldFrom(schemas.core.Log) task_run_id: Optional[UUID] = FieldFrom(schemas.core.Log) diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py index 61cfbbc41350..f3786ae3651a 100644 --- a/src/prefect/server/schemas/core.py +++ b/src/prefect/server/schemas/core.py @@ -373,8 +373,8 @@ class TaskRun(ORMBaseModel): """An ORM representation of task run data.""" name: str = Field(default_factory=lambda: generate_slug(2), example="my-task-run") - flow_run_id: UUID = Field( - default=..., description="The flow run id of the task run." + flow_run_id: Optional[UUID] = Field( + default=None, description="The flow run id of the task run." ) task_key: str = Field( default=..., description="A unique identifier for the task being run." @@ -859,8 +859,8 @@ class Log(ORMBaseModel): level: int = Field(default=..., description="The log level.") message: str = Field(default=..., description="The log message.") timestamp: DateTimeTZ = Field(default=..., description="The log timestamp.") - flow_run_id: UUID = Field( - default=..., description="The flow run ID associated with the log." + flow_run_id: Optional[UUID] = Field( + default=None, description="The flow run ID associated with the log." ) task_run_id: Optional[UUID] = Field( default=None, description="The task run ID associated with the log." 
diff --git a/tests/server/api/test_task_runs.py b/tests/server/api/test_task_runs.py index 5e2b007ad412..4159690d3bea 100644 --- a/tests/server/api/test_task_runs.py +++ b/tests/server/api/test_task_runs.py @@ -42,6 +42,29 @@ async def test_create_task_run_gracefully_upserts(self, flow_run, client): assert response.status_code == status.HTTP_200_OK assert response.json()["id"] == task_run_response.json()["id"] + async def test_create_task_run_without_flow_run_id(self, flow_run, client, session): + task_run_data = { + "flow_run_id": None, + "task_key": "my-task-key", + "name": "my-cool-task-run-name", + "dynamic_key": "0", + } + response = await client.post("/task_runs/", json=task_run_data) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["flow_run_id"] is None + assert response.json()["id"] + assert response.json()["name"] == "my-cool-task-run-name" + + task_run = await models.task_runs.read_task_run( + session=session, task_run_id=response.json()["id"] + ) + assert task_run.flow_run_id is None + + # Posting the same data twice should result in an upsert + response_2 = await client.post("/task_runs/", json=task_run_data) + assert response_2.status_code == status.HTTP_200_OK + assert response.json()["id"] == response_2.json()["id"] + async def test_create_task_run_without_state(self, flow_run, client, session): task_run_data = dict( flow_run_id=str(flow_run.id), task_key="task-key", dynamic_key="0" @@ -345,6 +368,43 @@ async def test_set_task_run_state(self, task_run, client, session): assert run.state.name == "Test State" assert run.run_count == 1 + async def test_set_task_run_state_without_flow_run_id(self, client, session): + task_run_data = { + "flow_run_id": None, + "task_key": "my-task-key", + "name": "my-cool-task-run-name", + "dynamic_key": "0", + } + response = await client.post("/task_runs/", json=task_run_data) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["flow_run_id"] is None + 
assert response.json()["id"] + assert response.json()["name"] == "my-cool-task-run-name" + + task_run = await models.task_runs.read_task_run( + session=session, task_run_id=response.json()["id"] + ) + assert task_run.flow_run_id is None + + orchestration_response = await client.post( + f"/task_runs/{task_run.id}/set_state", + json=dict(state=dict(type="RUNNING", name="Test State")), + ) + assert orchestration_response.status_code == status.HTTP_201_CREATED + + api_response = OrchestrationResult.parse_obj(orchestration_response.json()) + assert api_response.status == responses.SetStateStatus.ACCEPT + + task_run_id = task_run.id + session.expire_all() + run = await models.task_runs.read_task_run( + session=session, task_run_id=task_run_id + ) + assert run.state.type == states.StateType.RUNNING + assert run.state.name == "Test State" + assert run.run_count == 1 + assert run.flow_run_id is None + @pytest.mark.parametrize("proposed_state", ["PENDING", "RUNNING"]) async def test_setting_task_run_state_twice_works( self, task_run, client, session, proposed_state