Support storing data in parquet and utsv bytes
raghumdani committed Aug 7, 2023
1 parent 707506a commit 544fd9e
Showing 1 changed file with 66 additions and 14 deletions.
deltacat/tests/local_deltacat_storage/__init__.py (66 additions, 14 deletions)
@@ -2,9 +2,10 @@
 
 import pyarrow as pa
 import json
+import sqlite3
 from sqlite3 import Cursor, Connection
 import uuid
-from ray import cloudpickle
+import io
 from deltacat.utils.common import current_time_ms
 
 from deltacat.storage import (
@@ -38,6 +39,10 @@
 from deltacat.types.media import ContentType, StorageType, TableType, ContentEncoding
 from deltacat.utils.common import ReadKwargsProvider
 
+SQLITE_CUR_ARG = "sqlite3_cur"
+SQLITE_CON_ARG = "sqlite3_con"
+DB_FILE_PATH_ARG = "db_file_path"
+
 STORAGE_TYPE = "SQLITE3"
 STREAM_ID_PROPERTY = "stream_id"
 CREATE_NAMESPACES_TABLE = (
@@ -67,7 +72,14 @@
 
 
 def _get_sqlite3_cursor_con(kwargs) -> Tuple[Cursor, Connection]:
-    return kwargs["sqlite3_cur"], kwargs["sqlite3_con"]
+    if SQLITE_CUR_ARG in kwargs and SQLITE_CON_ARG in kwargs:
+        return kwargs[SQLITE_CUR_ARG], kwargs[SQLITE_CON_ARG]
+    elif DB_FILE_PATH_ARG in kwargs:
+        con = sqlite3.connect(kwargs[DB_FILE_PATH_ARG])
+        cur = con.cursor()
+        return cur, con
+
+    raise ValueError(f"Invalid local db connection kwargs: {kwargs}")
 
 
 def _get_manifest_entry_uri(manifest_entry_id: str) -> str:
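
The hunk above lets tests pass either live sqlite3 objects or just a database file path. A minimal usage sketch under that assumption (the in-memory database and file path below are illustrative):

import sqlite3

# Existing cursor/connection objects take precedence when both styles are given.
con = sqlite3.connect(":memory:")
cur = con.cursor()
cur1, con1 = _get_sqlite3_cursor_con({SQLITE_CUR_ARG: cur, SQLITE_CON_ARG: con})

# A bare file path opens a fresh connection instead.
cur2, con2 = _get_sqlite3_cursor_con({DB_FILE_PATH_ARG: "/tmp/deltacat_test.db"})

# Anything else is rejected:
# _get_sqlite3_cursor_con({})  # raises ValueError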
@@ -332,19 +344,46 @@ def download_delta_manifest_entry(
     entry = manifest.entries[entry_index]
 
     res = cur.execute("SELECT value FROM data WHERE uri = ?", (entry.uri,))
-    pickled = res.fetchone()
+    serialized_data = res.fetchone()
 
-    if pickled is None:
+    if serialized_data is None:
         raise ValueError(
             f"Invalid value of delta locator: {delta_like.canonical_string()}"
         )
 
-    table = cloudpickle.loads(pickled[0])
-
-    if columns:
-        return table.select(columns)
+    serialized_data = serialized_data[0]
+
+    if entry.meta.content_type == ContentType.PARQUET:
+        if table_type == TableType.PYARROW_PARQUET:
+            table = pa.parquet.ParquetFile(io.BytesIO(serialized_data))
+        else:
+            table = pa.parquet.read_table(io.BytesIO(serialized_data), columns=columns)
+    elif entry.meta.content_type == ContentType.UNESCAPED_TSV:
+        assert (
+            table_type != TableType.PYARROW_PARQUET
+        ), f"uTSV table cannot be read as {table_type}"
+        parse_options = pa.csv.ParseOptions(delimiter="\t")
+        convert_options = pa.csv.ConvertOptions(
+            null_values=[""], strings_can_be_null=True, include_columns=columns
+        )
+        table = pa.csv.read_csv(
+            io.BytesIO(serialized_data),
+            parse_options=parse_options,
+            convert_options=convert_options,
+        )
+    else:
+        raise ValueError(f"Content type: {entry.meta.content_type} not supported.")
+
+    if table_type == TableType.PYARROW:
+        return table
+    elif table_type == TableType.PYARROW_PARQUET:
+        return table
+    elif table_type == TableType.NUMPY:
+        raise NotImplementedError(f"Table type={table_type} not supported")
+    elif table_type == TableType.PANDAS:
+        return table.to_pandas()
 
     return table
 
 
 def get_delta_manifest(
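
For context, here is a minimal sketch of the parquet read path this hunk introduces, runnable outside the storage layer (the sample table is illustrative). Note that TableType.PYARROW_PARQUET yields a lazy ParquetFile handle rather than a materialized Table, so column projection only applies on the read_table branch:

import io
import pyarrow as pa
import pyarrow.parquet

# Serialize a table to parquet bytes, as stage_delta does below.
table = pa.table({"id": [1, 2], "name": ["a", None]})
buf = io.BytesIO()
pa.parquet.write_table(table, buf)
serialized_data = buf.getvalue()

# TableType.PYARROW: eagerly materialize, optionally projecting columns.
restored = pa.parquet.read_table(io.BytesIO(serialized_data), columns=["id"])
assert restored.column_names == ["id"]

# TableType.PYARROW_PARQUET: hand back a lazy ParquetFile handle instead.
pq_file = pa.parquet.ParquetFile(io.BytesIO(serialized_data))
assert pq_file.metadata.num_rows == 2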
@@ -427,7 +466,7 @@ def create_table_version(
     if (
         table_version is not None
        and latest_version
-        and int(latest_version.table_version) != 1 + int(table_version)
+        and int(latest_version.table_version) + 1 != int(table_version)
     ):
         raise AssertionError(
             f"Table version can only be incremented. Last version={latest_version.table_version}"
@@ -799,17 +838,30 @@ def stage_delta(
     manifest_entry_id = uuid.uuid4().__str__()
     uri = _get_manifest_entry_uri(manifest_entry_id)
 
-    pickled = cloudpickle.dumps(data)
+    serialized_data = None
+    if content_type == ContentType.PARQUET:
+        buffer = io.BytesIO()
+        pa.parquet.write_table(data, buffer)
+        serialized_data = buffer.getvalue()
+    elif content_type == ContentType.UNESCAPED_TSV:
+        buffer = io.BytesIO()
+        write_options = pa.csv.WriteOptions(
+            include_header=True, delimiter="\t", quoting_style="none"
+        )
+        pa.csv.write_csv(data, buffer, write_options=write_options)
+        serialized_data = buffer.getvalue()
+    else:
+        raise ValueError(f"Unsupported content type: {content_type}")
 
     stream_position = current_time_ms()
     delta_locator = DeltaLocator.of(partition.locator, stream_position=stream_position)
 
     meta = ManifestMeta.of(
         len(data),
-        len(pickled),
-        content_type=ContentType.PARQUET,
+        len(serialized_data),
+        content_type=content_type,
         content_encoding=ContentEncoding.IDENTITY,
-        source_content_length=len(pickled),
+        source_content_length=data.nbytes,
     )
 
     manifest = Manifest.of(
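
The uTSV write path pairs with the read path in download_delta_manifest_entry: nulls are written as empty strings and recovered via ConvertOptions on read. A self-contained sketch of that round trip, assuming a pyarrow version that supports WriteOptions(quoting_style=...):

import io
import pyarrow as pa
import pyarrow.csv

data = pa.table({"k": ["x", "y"], "v": ["1", None]})

# Write: tab-delimited, unquoted, header included; nulls become empty strings.
buf = io.BytesIO()
write_options = pa.csv.WriteOptions(
    include_header=True, delimiter="\t", quoting_style="none"
)
pa.csv.write_csv(data, buf, write_options=write_options)

# Read: empty strings are mapped back to nulls.
restored = pa.csv.read_csv(
    io.BytesIO(buf.getvalue()),
    parse_options=pa.csv.ParseOptions(delimiter="\t"),
    convert_options=pa.csv.ConvertOptions(null_values=[""], strings_can_be_null=True),
)
assert restored.column("v").null_count == 1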
@@ -833,7 +885,7 @@ def stage_delta(
         previous_stream_position=partition.stream_position,
     )
 
-    params = (uri, pickled)
+    params = (uri, serialized_data)
     cur.execute("INSERT OR IGNORE INTO data VALUES (?, ?)", params)
 
     params = (delta_locator.canonical_string(), "staged_delta", json.dumps(delta))
