add gzip and bz2 compressed file format support #320

Merged · 14 commits · Jun 27, 2024
3 changes: 3 additions & 0 deletions docs/docs/reference/extractors.md
@@ -159,6 +159,9 @@ With the previous minimal configuration, it will use your currently active aws c
| assume_role_external_id | String | The external id that is required to assume role. Only used when `assume_role_arn` is set and only needed when the role is configured to require an external id. |
| **session_args | Any | Any other argument that you want sent to the [boto3.Session](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) that will be used to interact with AWS. |

### Support For Compressed File Formats
The system seamlessly handles decompression of objects stored in the `.gz`
and `.bz2` file formats. Files are automatically decompressed and processed based on their underlying content type, which is indicated by the file extension. For instance, a gzip-compressed JSON file should be named with the `.json.gz` extension so that it is correctly identified and read as JSON after decompression.
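As a rough illustration of how the layered extensions are interpreted (the object name below is invented), the compression suffix is stripped first and the remaining suffix selects the reader:

```python
from pathlib import Path

# Hypothetical object name, used only to show how the suffix chain is read.
key = Path("events.json.gz")
print(key.suffixes)                # ['.json', '.gz'] -> decompress the .gz first
print(key.with_suffix("").suffix)  # '.json' -> parsed as JSON after decompression
```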

## `FileExtractor`

9 changes: 9 additions & 0 deletions docs/docs/reference/file-formats.md
@@ -33,3 +33,12 @@ You will get a record in the following shape:

`.yaml` files are loaded using `yaml.load` and the one record is returned per file.
The record is the entire parsed contents of the `.yaml` file.


# Compressed File Formats

## `.gz`
`{File Format Extension}.gz` files are decompressed using `gzip.open`, stripped of the `.gz` extension, and then processed according to the remaining file format extension.

## `.bz2`
`{File Format Extension}.bz2` files are decompressed using `bz2.open`, stripped of the `.bz2` extension, and then processed according to the remaining file format extension.
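
A minimal sketch of the equivalent behavior for a compressed JSON file (this is not the library code, and the helper name is invented): pick the opener by the compression suffix, strip it, and parse by the remaining extension.

```python
import bz2
import gzip
import json
from pathlib import Path


# Sketch only: mirrors the documented decompress-then-parse behavior.
def read_compressed_json(path: Path):
    opener = gzip.open if path.suffix == ".gz" else bz2.open
    inner = path.with_suffix("")       # e.g. data.json.gz -> data.json
    assert inner.suffix == ".json"     # the inner extension selects the parser
    with opener(path, "rt", encoding="utf-8") as handle:
        return json.load(handle)
```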
141 changes: 135 additions & 6 deletions nodestream/pipeline/extractors/files.py
@@ -1,12 +1,16 @@
import bz2
import gzip
import io
import json
import os
import tempfile
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from csv import DictReader
from glob import glob
from io import BufferedReader, IOBase, StringIO, TextIOWrapper
from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Union
from typing import Any, AsyncGenerator, Callable, Generator, Iterable

import pandas as pd
from httpx import AsyncClient
@@ -18,13 +22,68 @@
from .extractor import Extractor

SUPPORTED_FILE_FORMAT_REGISTRY = SubclassRegistry()
SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY = SubclassRegistry()


class IngestibleFile:
def __init__(
self,
path: Path,
fp: IOBase | None = None,
delete_on_ingestion: bool = False,
on_ingestion: Callable[[Any], Any] = lambda: (),
) -> None:
self.extension = path.suffix
self.suffixes = path.suffixes
self.fp = fp
self.path = path
self.delete_on_ingestion = delete_on_ingestion
self.on_ingestion = on_ingestion

@classmethod
def from_file_pointer_and_suffixes(
cls,
fp: IOBase,
suffixes: str | list[str],
on_ingestion: Callable[[Any], Any] = lambda: (),
) -> "IngestibleFile":
fd, temp_path = tempfile.mkstemp(suffix="".join(suffixes))
os.close(fd)

with open(temp_path, "wb") as temp_file:
for chunk in iter(lambda: fp.read(1024), b""):
temp_file.write(chunk)
temp_file.flush()

with open(temp_path, "rb+") as fp:
return IngestibleFile(Path(temp_path), fp, True, on_ingestion)

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
if self.fp:
self.fp.close()
if self.delete_on_ingestion:
self.tempfile_cleanup()

def ingested(self):
if self.delete_on_ingestion:
self.tempfile_cleanup()
self.on_ingestion()

def tempfile_cleanup(self):
if self.fp:
self.fp.close()
if os.path.isfile(self.path):
os.remove(self.path)
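
A brief usage sketch of `IngestibleFile` as defined in this diff (the payload bytes are placeholders, and the import path is assumed from the module shown here): a raw file pointer is spooled into a temporary file that keeps the original suffixes, and the context manager cleans it up afterwards.

```python
import io

from nodestream.pipeline.extractors.files import IngestibleFile

# Sketch: spool an in-memory payload to a temporary *.json.gz file.
payload = io.BytesIO(b"placeholder bytes, not a real gzip stream")
with IngestibleFile.from_file_pointer_and_suffixes(payload, [".json", ".gz"]) as file:
    print(file.path)       # temporary path ending in .json.gz
    print(file.extension)  # '.gz'
# On exit the file pointer is closed and the temporary file is removed.
```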


@SUPPORTED_FILE_FORMAT_REGISTRY.connect_baseclass
class SupportedFileFormat(Pluggable, ABC):
reader = None

def __init__(self, file: Union[Path, IOBase]) -> None:
def __init__(self, file: IngestibleFile) -> None:
self.file = file

@contextmanager
@@ -49,9 +108,15 @@

@classmethod
@contextmanager
def open(cls, file: Path) -> "SupportedFileFormat":
with open(file, "rb") as fp:
yield cls.from_file_pointer_and_format(fp, file.suffix)
def open(cls, file: IngestibleFile) -> Generator["SupportedFileFormat", None, None]:
extension = file.extension
# Decompress file if in Supported Compressed File Format Registry
while extension in SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY:
compressed_file_format = SupportedCompressedFileFormat.open(file)
file = compressed_file_format.decompress_file()
extension = file.extension
with open(file.path, "rb") as fp:
yield cls.from_file_pointer_and_format(fp, extension)

@classmethod
def from_file_pointer_and_format(
@@ -67,6 +132,32 @@
...


@SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY.connect_baseclass
class SupportedCompressedFileFormat(Pluggable, ABC):
def __init__(self, file: IngestibleFile) -> None:
self.file = file

@classmethod
def open(cls, file: IngestibleFile) -> "SupportedCompressedFileFormat":
with open(file.path, "rb") as fp:
return cls.from_file_pointer_and_path(fp, file.path)

@classmethod
def from_file_pointer_and_path(
cls, fp: IOBase, path: Path
) -> "SupportedCompressedFileFormat":
# Import all compression file formats so that they can register themselves
cls.import_all()
file_format = SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY.get(path.suffix)
file = IngestibleFile(path, fp)
file.on_ingestion = lambda: file.tempfile_cleanup()
return file_format(file)

@abstractmethod
def decompress_file(self) -> IngestibleFile:
...

Check warning (Codecov / codecov/patch) on line 158 in nodestream/pipeline/extractors/files.py: added line #L158 was not covered by tests.
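
The abstract `decompress_file` hook above is the extension point that concrete compression formats implement; the `.gz` and `.bz2` classes further down follow it. Purely as a hypothetical illustration, not part of this PR, an `.xz` handler could mirror the same shape, assuming the classes defined in this module:

```python
import io
import lzma


class XzFileFormat(SupportedCompressedFileFormat, alias=".xz"):  # hypothetical
    def decompress_file(self) -> IngestibleFile:
        decompressed_data = io.BytesIO()
        with lzma.open(self.file.path, "rb") as f_in:
            decompressed_data.write(f_in.read())
        decompressed_data.seek(0)
        new_path = self.file.path.with_suffix("")  # strip the .xz suffix
        temp_file = IngestibleFile.from_file_pointer_and_suffixes(
            decompressed_data, new_path.suffixes
        )
        self.file.ingested()
        return temp_file
```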


class JsonFileFormat(SupportedFileFormat, alias=".json"):
reader = TextIOWrapper

@@ -118,6 +209,44 @@
return [safe_load(reader)]


class GzipFileFormat(SupportedCompressedFileFormat, alias=".gz"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
with gzip.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
chunk = f_in.read(chunk_size)
if len(chunk) == 0:
break
decompressed_data.write(chunk)
decompressed_data.seek(0)
new_path = self.file.path.with_suffix("")
temp_file = IngestibleFile.from_file_pointer_and_suffixes(
decompressed_data, new_path.suffixes
)
self.file.ingested()
return temp_file


class Bz2FileFormat(SupportedCompressedFileFormat, alias=".bz2"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
with bz2.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
chunk = f_in.read(chunk_size)
if len(chunk) == 0:
break
decompressed_data.write(chunk)
decompressed_data.seek(0)
new_path = self.file.path.with_suffix("")
temp_file = IngestibleFile.from_file_pointer_and_suffixes(
decompressed_data, new_path.suffixes
)
self.file.ingested()
return temp_file


class FileExtractor(Extractor):
@classmethod
def from_file_data(cls, globs: Iterable[str]):
@@ -137,7 +266,7 @@

async def extract_records(self) -> AsyncGenerator[Any, Any]:
for path in self._ordered_paths():
with SupportedFileFormat.open(path) as file:
with SupportedFileFormat.open(IngestibleFile(path)) as file:
for record in file.read_file():
yield record

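The core of the change in this file is the loop in `SupportedFileFormat.open` that keeps peeling compression suffixes before dispatching on the remaining extension. A standalone sketch of that idea, with invented helper names and none of the registry machinery:

```python
import bz2
import gzip
import json
from pathlib import Path

# Illustrative only: decompress while the outermost suffix is a known
# compression format, then parse the payload by the innermost extension.
DECOMPRESS = {".gz": gzip.decompress, ".bz2": bz2.decompress}


def read_json_records(path: Path):
    data = path.read_bytes()
    suffixes = list(path.suffixes)  # e.g. ['.json', '.gz']
    while suffixes and suffixes[-1] in DECOMPRESS:
        data = DECOMPRESS[suffixes.pop()](data)
    if suffixes and suffixes[-1] == ".json":
        return [json.loads(data)]
    raise ValueError(f"unsupported format: {path.name}")
```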
31 changes: 20 additions & 11 deletions nodestream/pipeline/extractors/stores/aws/s3_extractor.py
@@ -1,11 +1,11 @@
from io import StringIO
from contextlib import contextmanager
from logging import getLogger
from pathlib import Path
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Generator, Optional

from ...credential_utils import AwsClientFactory
from ...extractor import Extractor
from ...files import SupportedFileFormat
from ...files import IngestibleFile, SupportedFileFormat


class S3Extractor(Extractor):
@@ -41,8 +41,13 @@ def __init__(
self.s3_client = s3_client
self.logger = getLogger(__name__)

def get_object_as_io(self, key: str) -> StringIO:
return self.s3_client.get_object(Bucket=self.bucket, Key=key)["Body"]
@contextmanager
def get_object_as_tempfile(self, key: str):
streaming_body = self.s3_client.get_object(Bucket=self.bucket, Key=key)["Body"]
file = IngestibleFile.from_file_pointer_and_suffixes(
streaming_body, Path(key).suffixes, lambda: self.archive_s3_object(key)
)
yield file

def archive_s3_object(self, key: str):
if self.archive_dir:
@@ -63,10 +68,13 @@ def infer_object_format(self, key: str) -> str:
)
return object_format

def get_object_as_file(self, key: str) -> SupportedFileFormat:
io = self.get_object_as_io(key)
object_format = self.infer_object_format(key)
return SupportedFileFormat.from_file_pointer_and_format(io, object_format)
@contextmanager
def get_object_as_file(
self, key: str
) -> Generator[SupportedFileFormat, None, None]:
with self.get_object_as_tempfile(key) as temp_file:
with SupportedFileFormat.open(temp_file) as file_format:
yield file_format

def is_object_in_archive(self, key: str) -> bool:
if self.archive_dir:
@@ -84,7 +92,8 @@ def find_keys_in_bucket(self) -> list[str]:
async def extract_records(self) -> AsyncGenerator[Any, Any]:
for key in self.find_keys_in_bucket():
try:
for record in self.get_object_as_file(key).read_file():
yield record
with self.get_object_as_file(key) as records:
for record in records.read_file():
yield record
finally:
self.archive_s3_object(key)
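
The reworked extractor no longer parses the streaming body directly; each object is spooled to a temporary file first so the compression handling can reopen it by path and suffix. A rough sketch of that spooling step (the helper name is invented; the boto3 calls mirror those used in the diff):

```python
import tempfile
from pathlib import Path


# Sketch only: stream an S3 object's body to a temp file that keeps the
# original suffixes (e.g. ".json.gz") so extension-based dispatch still works.
def s3_body_to_tempfile(s3_client, bucket: str, key: str) -> Path:
    body = s3_client.get_object(Bucket=bucket, Key=key)["Body"]
    suffix = "".join(Path(key).suffixes)
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        for chunk in iter(lambda: body.read(1024), b""):
            tmp.write(chunk)
    return Path(tmp.name)
```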
3 changes: 3 additions & 0 deletions nodestream/subclass_registry.py
@@ -36,6 +36,9 @@ def __init__(self, ignore_overrides: bool = False) -> None:
self.linked_base = None
self.ignore_overrides = ignore_overrides

def __contains__(self, sub_class):
return sub_class in self.registry

def connect_baseclass(self, base_class):
"""Connect a base class to this registry."""

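The new `__contains__` is what allows `SupportedFileFormat.open` to write `extension in SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY`. A toy stand-in for the membership behavior it adds (names are illustrative, not the real class):

```python
class TinyRegistry:
    """Toy stand-in for SubclassRegistry's new membership check."""

    def __init__(self):
        self.registry = {}

    def __contains__(self, alias):
        return alias in self.registry


registry = TinyRegistry()
registry.registry[".gz"] = object()  # real registration happens via class aliases
assert ".gz" in registry
assert ".json" not in registry
```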