Skip to content

Commit

Permalink
fix message deletion (#328)
Browse files Browse the repository at this point in the history
* Fix SQS message deletion in extractor
  • Loading branch information
grantleehoffman authored Jul 2, 2024
1 parent 8cd5d5a commit 022e246
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
7 changes: 3 additions & 4 deletions nodestream/pipeline/extractors/files.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import bz2
import gzip
import io
import json
import os
import tempfile
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from csv import DictReader
from glob import glob
from io import BufferedReader, IOBase, StringIO, TextIOWrapper
from io import BufferedReader, BytesIO, IOBase, StringIO, TextIOWrapper
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Generator, Iterable

Expand Down Expand Up @@ -211,7 +210,7 @@ def read_file_from_handle(

class GzipFileFormat(SupportedCompressedFileFormat, alias=".gz"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
decompressed_data = BytesIO()
with gzip.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
Expand All @@ -230,7 +229,7 @@ def decompress_file(self) -> IngestibleFile:

class Bz2FileFormat(SupportedCompressedFileFormat, alias=".bz2"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
decompressed_data = BytesIO()
with bz2.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
Expand Down
10 changes: 8 additions & 2 deletions nodestream/pipeline/extractors/queues/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ async def get_next_messsage_batch(self):
def process_messages(self, messages):
    """Yield the body of each received message, deleting it once consumed.

    Args:
        messages: Iterable of message mappings, each with a "Body" entry.

    Yields:
        The "Body" value of every message, in order.

    When ``self.delete_after_read`` is truthy, each message is removed
    individually through ``self.delete_message``. The batch itself is a
    plain collection of mappings and has no ``delete()`` method, so the
    deletion has to happen per message.
    """
    for message in messages:
        yield message["Body"]
        # Runs only when the consumer resumes the generator, i.e. after the
        # yielded body has been handed off and the next item is requested.
        if self.delete_after_read:
            self.delete_message(message)

def get_message_batch(self) -> AsyncGenerator[Any, Any]:
return self.sqs_client.receive_message(
Expand All @@ -84,3 +84,9 @@ def get_message_batch(self) -> AsyncGenerator[Any, Any]:
MessageAttributeNames=self.message_attribute_names,
MaxNumberOfMessages=self.max_batch_size,
)

def delete_message(self, msg):
    """Delete one received message from the queue at ``self.queue_url``.

    Args:
        msg: Message mapping; its "ReceiptHandle" entry (``None`` when the
            key is absent) is forwarded to the client.

    Returns:
        Whatever ``self.sqs_client.delete_message`` returns.
    """
    handle = msg.get("ReceiptHandle")
    request = {"QueueUrl": self.queue_url, "ReceiptHandle": handle}
    return self.sqs_client.delete_message(**request)
1 change: 1 addition & 0 deletions tests/unit/pipeline/extractors/queues/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def test_poll(subject):
@pytest.mark.asyncio
async def test_poll_batch_1(subject):
subject.max_batch_size = 1
subject.max_batches = 1
results = await subject.poll()
assert results == ["test-message-1"]
results = await subject.poll()
Expand Down

0 comments on commit 022e246

Please sign in to comment.