diff --git a/deltacat/aws/s3u.py b/deltacat/aws/s3u.py
index 5ce4ecaa..49ae3c4b 100644
--- a/deltacat/aws/s3u.py
+++ b/deltacat/aws/s3u.py
@@ -3,6 +3,8 @@
 from functools import partial
 from typing import Any, Callable, Dict, Generator, List, Optional, Union
 from uuid import uuid4
+from botocore.config import Config
+from deltacat.aws.constants import BOTO_MAX_RETRIES
 
 import pyarrow as pa
 import ray
@@ -385,14 +387,16 @@ def download_manifest_entry(
     content_encoding: Optional[ContentEncoding] = None,
 ) -> LocalTable:
 
+    conf = Config(retries={"max_attempts": BOTO_MAX_RETRIES, "mode": "adaptive"})
     s3_client_kwargs = (
         {
             "aws_access_key_id": token_holder["accessKeyId"],
             "aws_secret_access_key": token_holder["secretAccessKey"],
             "aws_session_token": token_holder["sessionToken"],
+            "config": conf,
         }
         if token_holder
-        else {}
+        else {"config": conf}
     )
     if not content_type:
         content_type = manifest_entry.meta.content_type
diff --git a/deltacat/exceptions.py b/deltacat/exceptions.py
index 0f3db5ad..44a40da4 100644
--- a/deltacat/exceptions.py
+++ b/deltacat/exceptions.py
@@ -8,3 +8,7 @@ class NonRetryableError(Exception):
 
 class ConcurrentModificationError(Exception):
     pass
+
+
+class ValidationError(NonRetryableError):
+    pass
diff --git a/deltacat/storage/model/types.py b/deltacat/storage/model/types.py
index 9e0db24a..fe1acd7a 100644
--- a/deltacat/storage/model/types.py
+++ b/deltacat/storage/model/types.py
@@ -1,6 +1,7 @@
 from enum import Enum
 from typing import List, Union, Any
 
+from pyarrow.parquet import ParquetFile
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -8,7 +9,7 @@
 from ray.data._internal.arrow_block import ArrowRow
 from ray.data.dataset import Dataset
 
-LocalTable = Union[pa.Table, pd.DataFrame, np.ndarray]
+LocalTable = Union[pa.Table, pd.DataFrame, np.ndarray, ParquetFile]
 LocalDataset = List[LocalTable]
 # Starting Ray 2.5.0, Dataset follows a strict mode (https://docs.ray.io/en/latest/data/faq.html#migrating-to-strict-mode),
 # and generic annotation is removed. So add a version checker to determine whether to use the old or new definition.
diff --git a/deltacat/types/media.py b/deltacat/types/media.py
index 155d73dd..1e502b98 100644
--- a/deltacat/types/media.py
+++ b/deltacat/types/media.py
@@ -41,6 +41,7 @@ class TableType(str, Enum):
     PYARROW = "pyarrow"
     PANDAS = "pandas"
     NUMPY = "numpy"
+    PYARROW_PARQUET = "pyarrow_parquet"
 
 
 class SchemaType(str, Enum):
diff --git a/deltacat/types/tables.py b/deltacat/types/tables.py
index f4b259da..dbb824fc 100644
--- a/deltacat/types/tables.py
+++ b/deltacat/types/tables.py
@@ -21,6 +21,7 @@
 from deltacat.utils.ray_utils import dataset as ds_utils
 
 TABLE_TYPE_TO_READER_FUNC: Dict[int, Callable] = {
+    TableType.PYARROW_PARQUET.value: pa_utils.s3_file_to_parquet,
     TableType.PYARROW.value: pa_utils.s3_file_to_table,
     TableType.PANDAS.value: pd_utils.s3_file_to_dataframe,
     TableType.NUMPY.value: np_utils.s3_file_to_ndarray,
diff --git a/deltacat/utils/pyarrow.py b/deltacat/utils/pyarrow.py
index ec95ebc3..31a413cc 100644
--- a/deltacat/utils/pyarrow.py
+++ b/deltacat/utils/pyarrow.py
@@ -7,6 +7,9 @@
 import logging
 from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional
+from pyarrow.parquet import ParquetFile
+import s3fs
+from deltacat.exceptions import ValidationError
 
 import pyarrow as pa
 from fsspec import AbstractFileSystem
@@ -272,6 +275,55 @@ def s3_file_to_table(
     return table
 
 
+def s3_file_to_parquet(
+    s3_url: str,
+    content_type: str,
+    content_encoding: str,
+    column_names: Optional[List[str]] = None,
+    include_columns: Optional[List[str]] = None,
+    pa_read_func_kwargs_provider: Optional[ReadKwargsProvider] = None,
+    **s3_client_kwargs,
+) -> ParquetFile:
+    logger.debug(
+        f"Reading {s3_url} to PyArrow ParquetFile. "
+        f"Content type: {content_type}. Encoding: {content_encoding}"
+    )
+
+    if (
+        content_type != ContentType.PARQUET.value
+        or content_encoding != ContentEncoding.IDENTITY.value
+    ):
+        raise ValidationError(
+            f"S3 file with content type: {content_type} and "
+            f"content encoding: {content_encoding} cannot be read "
+            "into pyarrow.parquet.ParquetFile"
+        )
+
+    if s3_client_kwargs is None:
+        s3_client_kwargs = {}
+
+    s3_file_system = s3fs.S3FileSystem(
+        key=s3_client_kwargs.get("aws_access_key_id"),
+        secret=s3_client_kwargs.get("aws_secret_access_key"),
+        token=s3_client_kwargs.get("aws_session_token"),
+        client_kwargs=s3_client_kwargs,
+    )
+
+    kwargs = {}
+    if pa_read_func_kwargs_provider:
+        kwargs = pa_read_func_kwargs_provider(content_type, kwargs)
+
+    logger.debug(
+        f"Reading the file from {s3_url} into ParquetFile with kwargs: {kwargs}"
+    )
+    pq_file, latency = timed_invocation(
+        lambda: ParquetFile(s3_url, filesystem=s3_file_system, **kwargs)
+    )
+    logger.debug(f"Time to get {s3_url} into parquet file: {latency}s")
+
+    return pq_file
+
+
 def table_size(table: pa.Table) -> int:
     return table.nbytes
 
diff --git a/requirements.txt b/requirements.txt
index 802e87bc..2bff2c42 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@
 boto3 ~= 1.20
 numpy == 1.21.5
 pandas == 1.3.5
-pyarrow == 10.0.1
+pyarrow == 12.0.1
 pydantic == 1.10.4
 pymemcache == 4.0.0
 ray[default] ~= 2.0
diff --git a/setup.py b/setup.py
index 4ed7da30..5c30baeb 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,7 @@ def find_version(*paths):
         "boto3 ~= 1.20",
         "numpy == 1.21.5",
         "pandas == 1.3.5",
-        "pyarrow == 10.0.1",
+        "pyarrow == 12.0.1",
         "pydantic == 1.10.4",
         "ray[default] ~= 2.0",
         "s3fs == 2022.2.0",
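For reference, the new `TableType.PYARROW_PARQUET` entry in `TABLE_TYPE_TO_READER_FUNC` routes reads through `s3_file_to_parquet`, which returns a lazy `pyarrow.parquet.ParquetFile` handle instead of a fully materialized table, so callers can inspect Parquet footer metadata and pull down only the columns or row groups they need. Below is a minimal usage sketch, not part of the change set: the bucket, key, and column names are placeholders, and AWS credentials are assumed to come from the environment.

```python
from deltacat.types.media import ContentEncoding, ContentType
from deltacat.utils.pyarrow import s3_file_to_parquet

# Placeholder S3 URL; any Parquet object readable with ambient credentials works.
s3_url = "s3://my-bucket/path/to/file.parquet"

# Returns a pyarrow.parquet.ParquetFile handle without materializing any data.
pq_file = s3_file_to_parquet(
    s3_url=s3_url,
    content_type=ContentType.PARQUET.value,
    content_encoding=ContentEncoding.IDENTITY.value,
)

# Footer metadata is available without reading any row groups.
print(pq_file.schema_arrow)
print(pq_file.metadata.num_rows, pq_file.metadata.num_row_groups)

# Materialize only the columns of interest into a pyarrow.Table.
table = pq_file.read(columns=["id", "timestamp"])
```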