Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

submitted state lock #291

Merged
merged 6 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/pr291.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
enhancement:
- "Add a submitted state lock to prevent the same flow from being run by multiple agents - [#291](https://github.com/PrefectHQ/server/pull/291)"
11 changes: 11 additions & 0 deletions src/prefect_server/api/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
from prefect.engine.state import Cancelled, Cancelling, State
from prefect.utilities.plugins import register_api
from prefect_server.utilities import events
from prefect_server.utilities.collections import ExpiringSet
from prefect_server.utilities.logging import get_logger

logger = get_logger("api")

state_schema = prefect.serialization.state.StateSchema()

submitted_state_lock = ExpiringSet(duration_seconds=30)


@register_api("states.set_flow_run_state")
async def set_flow_run_state(
Expand Down Expand Up @@ -52,6 +55,14 @@ async def set_flow_run_state(
if not flow_run:
raise ValueError(f"State update failed for flow run ID {flow_run_id}")

if state.is_submitted():
lock_name = f"{flow_run.id}-submitted-lock"
submitted_lock = submitted_state_lock.add(lock_name)
if submitted_lock is False:
raise ValueError(
"State update failed: this run has already been submitted."
)

# --------------------------------------------------------
# apply downstream updates
# --------------------------------------------------------
Expand Down
65 changes: 64 additions & 1 deletion src/prefect_server/utilities/collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from typing import Iterable
from typing import Hashable, Iterable, Optional, Union
import threading


def chunked_iterable(iterable: Iterable, size: int):
Expand All @@ -19,3 +20,65 @@ def chunked_iterable(iterable: Iterable, size: int):
if not chunk:
break
yield chunk


class ExpiringSet:
"""
Stores values in a set for a specific amount of time.

Args:
- duration_seconds (float, int): the number of seconds to put a value in the set before removing it
"""

def __init__(self, duration_seconds: Union[float, int] = 10):
self.values = set()
self.lock = threading.Lock()
self.duration_seconds = duration_seconds

def add(
self, value: Hashable, duration_seconds: Optional[Union[float, int]] = None
) -> bool:
"""
Add a value to set and kick off a timer to remove it

Args:
- value (Hashable): value to add, must be hashable and
able to be stored in a set
- duration_seconds (float, int, optional): the number of seconds to add a value for
before removing it. Defaults to self.duration if not provided
Returns:
- bool: `True` if the value did not already exist in the set
and is successfully added, otherwise `False`
"""
if duration_seconds is None:
duration_seconds = self.duration_seconds

with self.lock:
jakekaplan marked this conversation as resolved.
Show resolved Hide resolved
if value in self.values:
return False
else:
self.values.add(value)
threading.Timer(
duration_seconds, self.remove, kwargs={"value": value}
).start()
return True

def remove(self, value: Hashable):
"""
Remove a value from the set

Args:
- value (Hashable): value to remove from set
"""
with self.lock:
self.values.discard(value)

def exists(self, value: Hashable):
jakekaplan marked this conversation as resolved.
Show resolved Hide resolved
"""
Check if a value exists in the set

Args:
- value (Hashable): value to remove from set
"""
with self.lock:
return value in self.values
10 changes: 10 additions & 0 deletions tests/api/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ async def test_state_does_not_set_heartbeat_unless_running_or_submitted(
flow_run = await models.FlowRun.where(id=flow_run_id).first({"heartbeat"})
assert flow_run.heartbeat is None

async def test_state_submitted_lock_on_same_flow_run_id(self, flow_run_id):
await api.states.set_flow_run_state(flow_run_id=flow_run_id, state=Submitted())
with pytest.raises(
ValueError,
match="State update failed: this run has already been submitted.",
):
await api.states.set_flow_run_state(
flow_run_id=flow_run_id, state=Submitted()
)

@pytest.mark.parametrize("state", [Running(), Submitted()])
async def test_running_and_submitted_state_sets_heartbeat(self, state, flow_run_id):
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/graphql/test_states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
import uuid

import pytest
Expand Down Expand Up @@ -176,6 +177,36 @@ async def test_set_flow_run_states_rejects_states_with_large_payloads(
)
assert "State payload is too large" in result.errors[0].message

async def test_set_flow_run_states_submitted_lock_on_same_flow_run_id(
self, run_query, flow_run_id
):
result = await run_query(
query=self.mutation,
variables=dict(
input=dict(
states=[
dict(flow_run_id=flow_run_id, state=Submitted().serialize())
]
)
),
)
bad_result = await run_query(
query=self.mutation,
variables=dict(
input=dict(
states=[
dict(flow_run_id=flow_run_id, state=Submitted().serialize())
]
)
),
)

assert result.data.set_flow_run_states.states[0].status == "SUCCESS"
assert (
"State update failed: this run has already been submitted."
in bad_result.errors[0].message
)


# ---------------------------------------------------------------
# Task runs
Expand Down
22 changes: 21 additions & 1 deletion tests/utilities/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from prefect_server.utilities.collections import chunked_iterable
import time

from prefect_server.utilities.collections import chunked_iterable, ExpiringSet


def test_chunked_iterable_of_list():
Expand All @@ -10,3 +12,21 @@ def test_chunked_iterable_of_list():
def test_chunked_iterable_of_empty_iterable():
chunks = [chunk for chunk in chunked_iterable([], 4)]
assert len(chunks) == 0


class TestExpiringSet:
def test_add_and_expire(self):
store = ExpiringSet()
added = store.add("value1", duration_seconds=2)
assert added
assert store.exists("value1")
time.sleep(3)
assert not store.exists("value1")

def test_add_and_expire_unique(self):
store = ExpiringSet()
added = store.add("value1", duration_seconds=2)
assert added
assert store.exists("value1")
added = store.add("value1", duration_seconds=3)
assert not added