add gzip and bz2 compressed file format support (#320)
* add gzip and bz2 compressed file format support
grantleehoffman authored Jun 27, 2024
1 parent f6e493a commit 6f28e4d
Showing 7 changed files with 349 additions and 22 deletions.
3 changes: 3 additions & 0 deletions docs/docs/reference/extractors.md
@@ -159,6 +159,9 @@ With the previous minimal configuration, it will use your currently active aws c
| assume_role_external_id | String | The external id that is required to assume role. Only used when `assume_role_arn` is set and only needed when the role is configured to require an external id. |
| **session_args | Any | Any other argument that you want sent to the [boto3.Session](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) that will be used to interact with AWS. |

### Support For Compressed File Formats
The extractor transparently decompresses objects stored in `.gz`
and `.bz2` formats. Files are decompressed automatically and processed according to their underlying content type, which is indicated by the remaining file extension. For instance, a gzip-compressed JSON file should be named with the `.json.gz` extension so that, after decompression, it is correctly identified and read as JSON.
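
As an illustration, the behaviour is roughly equivalent to the following standalone sketch (the object name `events.json.gz` is hypothetical, not part of the API):

```python
import gzip
import json

# Hypothetical local copy of an object named events.json.gz; the .json
# suffix left after stripping .gz is what drives JSON parsing.
with gzip.open("events.json.gz", "rt", encoding="utf-8") as f:
    records = json.load(f)
```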

## `FileExtractor`

9 changes: 9 additions & 0 deletions docs/docs/reference/file-formats.md
@@ -33,3 +33,12 @@ You will get a record in the following shape:

`.yaml` files are loaded using `yaml.load`, and one record is returned per file.
The record is the entire parsed contents of the `.yaml` file.


# Compressed File Formats

## `.gz`
`{File Format Extension}.gz` files are decompressed using `gzip.open`, stripped of the `.gz` extension, and then processed according to the remaining file format extension.

## `.bz2`
`{File Format Extension}.bz2` files are decompressed using `bz2.open`, stripped of the `.bz2` extension, and then processed according to the remaining file format extension.
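
Conceptually, one decompression pass looks like the simplified sketch below; this is an illustration under assumed names (`decompress_once`, `data.csv.bz2`), not the library's actual implementation:

```python
import bz2
import gzip
from pathlib import Path

OPENERS = {".gz": gzip.open, ".bz2": bz2.open}

def decompress_once(path: Path) -> Path:
    # Pick the opener from the outermost suffix, then strip that suffix:
    # data.csv.bz2 -> data.csv
    opener = OPENERS[path.suffix]
    target = path.with_suffix("")
    with opener(path, "rb") as src, open(target, "wb") as dst:
        # Copy in 1 MiB chunks so the archive is never read all at once.
        for chunk in iter(lambda: src.read(1024 * 1024), b""):
            dst.write(chunk)
    return target  # the remaining suffix now selects the file format
```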
141 changes: 135 additions & 6 deletions nodestream/pipeline/extractors/files.py
@@ -1,12 +1,16 @@
import bz2
import gzip
import io
import json
import os
import tempfile
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from csv import DictReader
from glob import glob
from io import BufferedReader, IOBase, StringIO, TextIOWrapper
from pathlib import Path
from typing import Any, AsyncGenerator, Iterable, Union
from typing import Any, AsyncGenerator, Callable, Generator, Iterable

import pandas as pd
from httpx import AsyncClient
@@ -18,13 +22,68 @@
from .extractor import Extractor

SUPPORTED_FILE_FORMAT_REGISTRY = SubclassRegistry()
SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY = SubclassRegistry()


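# A file on disk (possibly a temporary copy) paired with the cleanup
# behaviour to run once its records have been ingested.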
class IngestibleFile:
def __init__(
self,
path: Path,
fp: IOBase | None = None,
delete_on_ingestion: bool = False,
on_ingestion: Callable[[], Any] = lambda: (),
) -> None:
self.extension = path.suffix
self.suffixes = path.suffixes
self.fp = fp
self.path = path
self.delete_on_ingestion = delete_on_ingestion
self.on_ingestion = on_ingestion

@classmethod
def from_file_pointer_and_suffixes(
cls,
fp: IOBase,
suffixes: str | list[str],
on_ingestion: Callable[[], Any] = lambda: (),
) -> "IngestibleFile":
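# Materialize the stream as a temp file that keeps the original suffixes
# (e.g. .json.gz) so the format can still be inferred from the name.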
fd, temp_path = tempfile.mkstemp(suffix="".join(suffixes))
os.close(fd)

with open(temp_path, "wb") as temp_file:
for chunk in iter(lambda: fp.read(1024), b""):
temp_file.write(chunk)
temp_file.flush()

with open(temp_path, "rb+") as fp:
return IngestibleFile(Path(temp_path), fp, True, on_ingestion)

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
if self.fp:
self.fp.close()
if self.delete_on_ingestion:
self.tempfile_cleanup()

def ingested(self):
if self.delete_on_ingestion:
self.tempfile_cleanup()
self.on_ingestion()

def tempfile_cleanup(self):
if self.fp:
self.fp.close()
if os.path.isfile(self.path):
os.remove(self.path)


@SUPPORTED_FILE_FORMAT_REGISTRY.connect_baseclass
class SupportedFileFormat(Pluggable, ABC):
reader = None

def __init__(self, file: Union[Path, IOBase]) -> None:
def __init__(self, file: IngestibleFile) -> None:
self.file = file

@contextmanager
@@ -49,9 +108,15 @@ def read_file(self) -> Iterable[JsonLikeDocument]:

@classmethod
@contextmanager
def open(cls, file: Path) -> "SupportedFileFormat":
with open(file, "rb") as fp:
yield cls.from_file_pointer_and_format(fp, file.suffix)
def open(cls, file: IngestibleFile) -> Generator["SupportedFileFormat", None, None]:
extension = file.extension
# Repeatedly decompress while the outermost extension is a registered compressed format
while extension in SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY:
compressed_file_format = SupportedCompressedFileFormat.open(file)
file = compressed_file_format.decompress_file()
extension = file.extension
with open(file.path, "rb") as fp:
yield cls.from_file_pointer_and_format(fp, extension)

@classmethod
def from_file_pointer_and_format(
@@ -67,6 +132,32 @@ def read_file_from_handle(self, fp: BufferedReader) -> Iterable[JsonLikeDocument
...


@SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY.connect_baseclass
class SupportedCompressedFileFormat(Pluggable, ABC):
def __init__(self, file: IngestibleFile) -> None:
self.file = file

@classmethod
def open(cls, file: IngestibleFile) -> "SupportedCompressedFileFormat":
with open(file.path, "rb") as fp:
return cls.from_file_pointer_and_path(fp, file.path)

@classmethod
def from_file_pointer_and_path(
cls, fp: IOBase, path: Path
) -> "SupportedCompressedFileFormat":
# Import all compression file formats so that they can register themselves
cls.import_all()
file_format = SUPPORTED_COMPRESSED_FILE_FORMAT_REGISTRY.get(path.suffix)
file = IngestibleFile(path, fp)
file.on_ingestion = lambda: file.tempfile_cleanup()
return file_format(file)

@abstractmethod
def decompress_file(self) -> IngestibleFile:
...


class JsonFileFormat(SupportedFileFormat, alias=".json"):
reader = TextIOWrapper

@@ -118,6 +209,44 @@ def read_file_from_handle(
return [safe_load(reader)]


class GzipFileFormat(SupportedCompressedFileFormat, alias=".gz"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
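# Decompress in 1 MiB chunks into an in-memory buffer, which is then
# written back out as a temp file with the .gz suffix stripped.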
with gzip.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
chunk = f_in.read(chunk_size)
if len(chunk) == 0:
break
decompressed_data.write(chunk)
decompressed_data.seek(0)
new_path = self.file.path.with_suffix("")
temp_file = IngestibleFile.from_file_pointer_and_suffixes(
decompressed_data, new_path.suffixes
)
self.file.ingested()
return temp_file


class Bz2FileFormat(SupportedCompressedFileFormat, alias=".bz2"):
def decompress_file(self) -> IngestibleFile:
decompressed_data = io.BytesIO()
with bz2.open(self.file.path, "rb") as f_in:
chunk_size = 1024 * 1024
while True:
chunk = f_in.read(chunk_size)
if len(chunk) == 0:
break
decompressed_data.write(chunk)
decompressed_data.seek(0)
new_path = self.file.path.with_suffix("")
temp_file = IngestibleFile.from_file_pointer_and_suffixes(
decompressed_data, new_path.suffixes
)
self.file.ingested()
return temp_file


class FileExtractor(Extractor):
@classmethod
def from_file_data(cls, globs: Iterable[str]):
@@ -137,7 +266,7 @@ def _ordered_paths(self) -> Iterable[Path]:

async def extract_records(self) -> AsyncGenerator[Any, Any]:
for path in self._ordered_paths():
with SupportedFileFormat.open(path) as file:
with SupportedFileFormat.open(IngestibleFile(path)) as file:
for record in file.read_file():
yield record

31 changes: 20 additions & 11 deletions nodestream/pipeline/extractors/stores/aws/s3_extractor.py
@@ -1,11 +1,11 @@
from io import StringIO
from contextlib import contextmanager
from logging import getLogger
from pathlib import Path
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Generator, Optional

from ...credential_utils import AwsClientFactory
from ...extractor import Extractor
from ...files import SupportedFileFormat
from ...files import IngestibleFile, SupportedFileFormat


class S3Extractor(Extractor):
@@ -41,8 +41,13 @@ def __init__(
self.s3_client = s3_client
self.logger = getLogger(__name__)

def get_object_as_io(self, key: str) -> StringIO:
return self.s3_client.get_object(Bucket=self.bucket, Key=key)["Body"]
@contextmanager
def get_object_as_tempfile(self, key: str):
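# Stream the S3 object body into a local temp file so the file format
# machinery can reopen it by path; archive the object once it is ingested.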
streaming_body = self.s3_client.get_object(Bucket=self.bucket, Key=key)["Body"]
file = IngestibleFile.from_file_pointer_and_suffixes(
streaming_body, Path(key).suffixes, lambda: self.archive_s3_object(key)
)
yield file

def archive_s3_object(self, key: str):
if self.archive_dir:
@@ -63,10 +68,13 @@ def infer_object_format(self, key: str) -> str:
)
return object_format

def get_object_as_file(self, key: str) -> SupportedFileFormat:
io = self.get_object_as_io(key)
object_format = self.infer_object_format(key)
return SupportedFileFormat.from_file_pointer_and_format(io, object_format)
@contextmanager
def get_object_as_file(
self, key: str
) -> Generator[SupportedFileFormat, None, None]:
with self.get_object_as_tempfile(key) as temp_file:
with SupportedFileFormat.open(temp_file) as file_format:
yield file_format

def is_object_in_archive(self, key: str) -> bool:
if self.archive_dir:
@@ -84,7 +92,8 @@ def find_keys_in_bucket(self) -> list[str]:
async def extract_records(self) -> AsyncGenerator[Any, Any]:
for key in self.find_keys_in_bucket():
try:
for record in self.get_object_as_file(key).read_file():
yield record
with self.get_object_as_file(key) as records:
for record in records.read_file():
yield record
finally:
self.archive_s3_object(key)
3 changes: 3 additions & 0 deletions nodestream/subclass_registry.py
@@ -36,6 +36,9 @@ def __init__(self, ignore_overrides: bool = False) -> None:
self.linked_base = None
self.ignore_overrides = ignore_overrides

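# Allow `".gz" in registry` membership checks; used to detect
# registered compressed extensions in SupportedFileFormat.open().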
def __contains__(self, sub_class):
return sub_class in self.registry

def connect_baseclass(self, base_class):
"""Connect a base class to this registry."""
