From 022e24653290d976e046125ca8787f6892f42057 Mon Sep 17 00:00:00 2001 From: Grant Hoffman Date: Tue, 2 Jul 2024 09:23:43 -0700 Subject: [PATCH] fix message deletion (#328) * fix sqs message deletion in extractor --- nodestream/pipeline/extractors/files.py | 7 +++---- nodestream/pipeline/extractors/queues/sqs.py | 10 ++++++++-- tests/unit/pipeline/extractors/queues/test_sqs.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nodestream/pipeline/extractors/files.py b/nodestream/pipeline/extractors/files.py index 1aebb767f..19fa1af5d 100644 --- a/nodestream/pipeline/extractors/files.py +++ b/nodestream/pipeline/extractors/files.py @@ -1,6 +1,5 @@ import bz2 import gzip -import io import json import os import tempfile @@ -8,7 +7,7 @@ from contextlib import asynccontextmanager, contextmanager from csv import DictReader from glob import glob -from io import BufferedReader, IOBase, StringIO, TextIOWrapper +from io import BufferedReader, BytesIO, IOBase, StringIO, TextIOWrapper from pathlib import Path from typing import Any, AsyncGenerator, Callable, Generator, Iterable @@ -211,7 +210,7 @@ def read_file_from_handle( class GzipFileFormat(SupportedCompressedFileFormat, alias=".gz"): def decompress_file(self) -> IngestibleFile: - decompressed_data = io.BytesIO() + decompressed_data = BytesIO() with gzip.open(self.file.path, "rb") as f_in: chunk_size = 1024 * 1024 while True: @@ -230,7 +229,7 @@ def decompress_file(self) -> IngestibleFile: class Bz2FileFormat(SupportedCompressedFileFormat, alias=".bz2"): def decompress_file(self) -> IngestibleFile: - decompressed_data = io.BytesIO() + decompressed_data = BytesIO() with bz2.open(self.file.path, "rb") as f_in: chunk_size = 1024 * 1024 while True: diff --git a/nodestream/pipeline/extractors/queues/sqs.py b/nodestream/pipeline/extractors/queues/sqs.py index ba0b96e42..a90e05f00 100644 --- a/nodestream/pipeline/extractors/queues/sqs.py +++ b/nodestream/pipeline/extractors/queues/sqs.py @@ -74,8 +74,8 @@ async def get_next_messsage_batch(self): def process_messages(self, messages): for message in messages: yield message["Body"] - if self.delete_after_read: - messages.delete() + if self.delete_after_read: + self.delete_message(message) def get_message_batch(self) -> AsyncGenerator[Any, Any]: return self.sqs_client.receive_message( @@ -84,3 +84,9 @@ def get_message_batch(self) -> AsyncGenerator[Any, Any]: MessageAttributeNames=self.message_attribute_names, MaxNumberOfMessages=self.max_batch_size, ) + + def delete_message(self, msg): + receipt_handle = msg.get("ReceiptHandle") + return self.sqs_client.delete_message( + QueueUrl=self.queue_url, ReceiptHandle=receipt_handle + ) diff --git a/tests/unit/pipeline/extractors/queues/test_sqs.py b/tests/unit/pipeline/extractors/queues/test_sqs.py index 75d26e8d7..9ab8c8a96 100644 --- a/tests/unit/pipeline/extractors/queues/test_sqs.py +++ b/tests/unit/pipeline/extractors/queues/test_sqs.py @@ -107,6 +107,7 @@ async def test_poll(subject): @pytest.mark.asyncio async def test_poll_batch_1(subject): subject.max_batch_size = 1 + subject.max_batches = 1 results = await subject.poll() assert results == ["test-message-1"] results = await subject.poll()