Skip to content

Commit

Permalink
Add a task to Redis if it is no longer tracked after NACKing
Browse files Browse the repository at this point in the history
  • Loading branch information
catileptic authored and stchris committed Jun 19, 2024
1 parent 9eab35a commit c8fb3a6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
26 changes: 25 additions & 1 deletion servicelayer/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def mark_for_retry(self, task):
f" for retry after NACK"
)

pipe.sadd(make_key(stage_key, "pending"), task.task_id)
pipe.srem(make_key(stage_key, "running"), task.task_id)
pipe.delete(task.retry_key)
pipe.srem(stage_key, task.task_id)
Expand All @@ -336,6 +335,28 @@ def __str__(self):
def get_stage_key(self, stage):
return make_key(PREFIX, "qds", self.name, stage)

def is_task_tracked(self, task: Task):
tracked = True

pipe = self.conn.pipeline()
stage_key = self.get_stage_key(task.operation)
dataset = dataset = dataset_from_collection_id(task.collection_id)
task_id = task.task_id
stage = task.operation

# A task is considered tracked if
# the dataset is in the list of active datasets
if dataset not in self.conn.smembers(self.key):
tracked = False
# and the stage is in the list of active stages
elif stage not in self.conn.smembers(self.active_stages_key):
tracked = False
# and the task_id is in the list of task_ids per stage
elif task_id not in self.conn.smembers(stage_key):
tracked = False

return tracked


def get_task(body, delivery_tag) -> Task:
body = json.loads(body)
Expand Down Expand Up @@ -624,10 +645,13 @@ def ack_message(self, task, channel):

def nack_message(self, task, channel, requeue=True):
"""NACK task and update status."""

apply_task_context(task, v=self.version)
log.info(f"NACKing message {task.delivery_tag} for task_id {task.task_id}")
dataset = task.get_dataset(conn=self.conn)
# Sync state to redis
if not dataset.is_task_tracked(task):
dataset.add_task()
dataset.mark_for_retry(task)
if channel.is_open:
channel.basic_nack(delivery_tag=task.delivery_tag, requeue=requeue)
Expand Down
7 changes: 5 additions & 2 deletions tests/test_taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,16 @@ def test_task_that_shouldnt_execute(self, mock_should_execute):
task_id = "test-task"
priority = randrange(1, settings.RABBITMQ_MAX_PRIORITY + 1)
body = {
"collection_id": 2,
"job_id": "test-job",
"task_id": "test-task",
"job_id": "test-job",
"delivery_tag": 0,
"operation": "test-op",
"context": {},
"payload": {},
"priority": priority,
"collection_id": 2,
}

connection = get_rabbitmq_connection()
channel = connection.channel()
declare_rabbitmq_queue(channel, test_queue_name)
Expand Down Expand Up @@ -180,3 +182,4 @@ def did_nack():
stage = status["datasets"]["2"]["stages"][0]
assert stage["pending"] == 1
assert stage["running"] == 0
assert dataset.is_task_tracked(Task(**body))

0 comments on commit c8fb3a6

Please sign in to comment.