Fan out writing to multiple Parquet files (apache#444)
* bin pack write

* add write target file size config

* test

* add test for multiple data files

* parquet writer write once

* parallelize write tasks

* refactor

* chunk correctly using to_batches

* change variable names

* get rid of assert

* configure PackingIterator

* add more tests

* rewrite set_properties

* set int property
kevinjqliu authored Mar 28, 2024
1 parent 4c1cfdc commit 6aeb126
Showing 5 changed files with 192 additions and 45 deletions.
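
For orientation, a usage sketch of what this change enables (not part of the diff; the catalog name, table identifier, sample data, and the 64 MB override below are illustrative assumptions):

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.table import TableProperties

catalog = load_catalog("default")             # assumed catalog name
tbl = catalog.load_table("default.my_table")  # assumed existing table whose schema matches df

# Lower the target file size so a single overwrite fans out into multiple Parquet files.
tbl = (
    tbl.transaction()
    .set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: 64 * 1024 * 1024})  # 64 MB
    .commit_transaction()
)

df = pa.Table.from_pydict({"n": list(range(1_000_000))})  # assumes the table has a single long column "n"
tbl.overwrite(df)  # bin-packed into roughly df.nbytes / target-file-size data files
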
91 changes: 52 additions & 39 deletions pyiceberg/io/pyarrow.py
@@ -1761,54 +1761,67 @@ def data_file_statistics_from_parquet_metadata(


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
    task = next(tasks)

    try:
        _ = next(tasks)
        # If there are more tasks, raise an exception
        raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208")
    except StopIteration:
        pass

    parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)

    file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
    schema = table_metadata.schema()
    arrow_file_schema = schema.as_arrow()
    parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)

    fo = io.new_output(file_path)
    row_group_size = PropertyUtil.property_as_int(
        properties=table_metadata.properties,
        property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
        default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
    )
    with fo.create(overwrite=True) as fos:
        with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
            writer.write_table(task.df, row_group_size=row_group_size)

    statistics = data_file_statistics_from_parquet_metadata(
        parquet_metadata=writer.writer.metadata,
        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
        parquet_column_mapping=parquet_path_to_id_mapping(schema),
    )
    data_file = DataFile(
        content=DataFileContent.DATA,
        file_path=file_path,
        file_format=FileFormat.PARQUET,
        partition=Record(),
        file_size_in_bytes=len(fo),
        # After this has been fixed:
        # https://github.com/apache/iceberg-python/issues/271
        # sort_order_id=task.sort_order_id,
        sort_order_id=None,
        # Just copy these from the table for now
        spec_id=table_metadata.default_spec_id,
        equality_ids=None,
        key_metadata=None,
        **statistics.to_serialized_dict(),
    )

    return iter([data_file])
    def write_parquet(task: WriteTask) -> DataFile:
        file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
        fo = io.new_output(file_path)
        with fo.create(overwrite=True) as fos:
            with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
                writer.write(pa.Table.from_batches(task.record_batches), row_group_size=row_group_size)

        statistics = data_file_statistics_from_parquet_metadata(
            parquet_metadata=writer.writer.metadata,
            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
            parquet_column_mapping=parquet_path_to_id_mapping(schema),
        )
        data_file = DataFile(
            content=DataFileContent.DATA,
            file_path=file_path,
            file_format=FileFormat.PARQUET,
            partition=Record(),
            file_size_in_bytes=len(fo),
            # After this has been fixed:
            # https://github.com/apache/iceberg-python/issues/271
            # sort_order_id=task.sort_order_id,
            sort_order_id=None,
            # Just copy these from the table for now
            spec_id=table_metadata.default_spec_id,
            equality_ids=None,
            key_metadata=None,
            **statistics.to_serialized_dict(),
        )

        return data_file

    executor = ExecutorFactory.get_or_create()
    data_files = executor.map(write_parquet, tasks)

    return iter(data_files)


def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[pa.RecordBatch]]:
    from pyiceberg.utils.bin_packing import PackingIterator

    avg_row_size_bytes = tbl.nbytes / tbl.num_rows
    target_rows_per_file = target_file_size // avg_row_size_bytes
    batches = tbl.to_batches(max_chunksize=target_rows_per_file)
    bin_packed_record_batches = PackingIterator(
        items=batches,
        target_weight=target_file_size,
        lookback=len(batches),  # ignore lookback
        weight_func=lambda x: x.nbytes,
        largest_bin_first=False,
    )
    return bin_packed_record_batches


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
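
A minimal sketch of the bin_pack_arrow_table helper added above, run on its own (the sample table and the 64 KB target are illustrative assumptions):

import pyarrow as pa

from pyiceberg.io.pyarrow import bin_pack_arrow_table

tbl = pa.Table.from_pydict({"id": list(range(10_000)), "payload": ["x" * 32] * 10_000})

# Each yielded item is a List[pa.RecordBatch] destined for one Parquet file.
for i, batches in enumerate(bin_pack_arrow_table(tbl, target_file_size=64 * 1024)):
    print(f"bin {i}: {len(batches)} batch(es), ~{sum(b.nbytes for b in batches)} bytes")
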
19 changes: 16 additions & 3 deletions pyiceberg/table/__init__.py
@@ -215,6 +215,9 @@ class TableProperties:

    PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX = "write.parquet.bloom-filter-enabled.column"

    WRITE_TARGET_FILE_SIZE_BYTES = "write.target-file-size-bytes"
    WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT = 512 * 1024 * 1024  # 512 MB

    DEFAULT_WRITE_METRICS_MODE = "write.metadata.metrics.default"
    DEFAULT_WRITE_METRICS_MODE_DEFAULT = "truncate(16)"

@@ -2486,7 +2489,7 @@ def _add_and_move_fields(
class WriteTask:
    write_uuid: uuid.UUID
    task_id: int
    df: pa.Table
    record_batches: List[pa.RecordBatch]
    sort_order_id: Optional[int] = None

    # Later to be extended with partition information
@@ -2521,17 +2524,27 @@ def _dataframe_to_data_files(
    Returns:
        An iterable that supplies datafiles that represent the table.
    """
    from pyiceberg.io.pyarrow import write_file
    from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file

    if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
        raise ValueError("Cannot write to partitioned tables")

    counter = itertools.count(0)
    write_uuid = write_uuid or uuid.uuid4()

    target_file_size = PropertyUtil.property_as_int(
        properties=table_metadata.properties,
        property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
        default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
    )

    # This is an iter, so we don't have to materialize everything every time
    # This will be more relevant when we start doing partitioned writes
    yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)]))
    yield from write_file(
        io=io,
        table_metadata=table_metadata,
        tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]),  # type: ignore
    )


def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
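
A condensed sketch of the fan-out wiring shown above in _dataframe_to_data_files: the target size comes from table properties and each bin of record batches becomes one WriteTask (the sample table is an illustrative assumption):

import itertools
import uuid

import pyarrow as pa

from pyiceberg.io.pyarrow import bin_pack_arrow_table
from pyiceberg.table import TableProperties, WriteTask

df = pa.Table.from_pydict({"id": list(range(100_000))})                   # assumed sample data
target_file_size = TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT  # 512 MB unless overridden

counter = itertools.count(0)
write_uuid = uuid.uuid4()
tasks = [
    WriteTask(write_uuid, next(counter), batches)
    for batches in bin_pack_arrow_table(df, target_file_size)
]
print(len(tasks))  # a small table under a 512 MB target packs into a single task / data file
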
59 changes: 58 additions & 1 deletion tests/conftest.py
@@ -30,7 +30,7 @@
import socket
import string
import uuid
from datetime import datetime
from datetime import date, datetime
from pathlib import Path
from random import choice
from tempfile import TemporaryDirectory
@@ -1987,3 +1987,60 @@ def spark() -> SparkSession:
    )

    return spark


TEST_DATA_WITH_NULL = {
    'bool': [False, None, True],
    'string': ['a', None, 'z'],
    # Go over the 16 bytes to kick in truncation
    'string_long': ['a' * 22, None, 'z' * 22],
    'int': [1, None, 9],
    'long': [1, None, 9],
    'float': [0.0, None, 0.9],
    'double': [0.0, None, 0.9],
    'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
    'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
    'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
    # Not supported by Spark
    # 'time': [time(1, 22, 0), None, time(19, 25, 0)],
    # Not natively supported by Arrow
    # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
    'binary': [b'\01', None, b'\22'],
    'fixed': [
        uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
        None,
        uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
    ],
}


@pytest.fixture(scope="session")
def pa_schema() -> "pa.Schema":
    import pyarrow as pa

    return pa.schema([
        ("bool", pa.bool_()),
        ("string", pa.string()),
        ("string_long", pa.string()),
        ("int", pa.int32()),
        ("long", pa.int64()),
        ("float", pa.float32()),
        ("double", pa.float64()),
        ("timestamp", pa.timestamp(unit="us")),
        ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
        ("date", pa.date32()),
        # Not supported by Spark
        # ("time", pa.time64("us")),
        # Not natively supported by Arrow
        # ("uuid", pa.fixed(16)),
        ("binary", pa.large_binary()),
        ("fixed", pa.binary(16)),
    ])


@pytest.fixture(scope="session")
def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table":
    import pyarrow as pa

    """PyArrow table with all kinds of columns"""
    return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)
43 changes: 42 additions & 1 deletion tests/integration/test_writes.py
@@ -37,7 +37,7 @@
from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.table import Table, _dataframe_to_data_files
from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files
from pyiceberg.typedef import Properties
from pyiceberg.types import (
BinaryType,
@@ -383,6 +383,47 @@ def get_current_snapshot_id(identifier: str) -> int:
    assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier)  # type: ignore


@pytest.mark.integration
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
    identifier = "default.write_bin_pack_data_files"
    tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

    def get_data_files_count(identifier: str) -> int:
        return spark.sql(
            f"""
            SELECT *
            FROM {identifier}.files
            """
        ).count()

    # writes 1 data file since the table is smaller than default target file size
    assert arrow_table_with_null.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
    tbl.overwrite(arrow_table_with_null)
    assert get_data_files_count(identifier) == 1

    # writes 1 data file as long as table is smaller than default target file size
    bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
    assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
    tbl.overwrite(bigger_arrow_tbl)
    assert get_data_files_count(identifier) == 1

    # writes multiple data files once target file size is overridden
    target_file_size = arrow_table_with_null.nbytes
    tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction()
    assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
    assert target_file_size < bigger_arrow_tbl.nbytes
    tbl.overwrite(bigger_arrow_tbl)
    assert get_data_files_count(identifier) == 10

    # writes half the number of data files when target file size doubles
    target_file_size = arrow_table_with_null.nbytes * 2
    tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction()
    assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES)
    assert target_file_size < bigger_arrow_tbl.nbytes
    tbl.overwrite(bigger_arrow_tbl)
    assert get_data_files_count(identifier) == 5


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
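
The assertions above lean on a rough proportionality rather than an exact guarantee of the packing algorithm; the back-of-the-envelope model is simply:

import math

copy_bytes = 5_000               # assumed size of one arrow_table_with_null copy
table_bytes = 10 * copy_bytes    # bigger_arrow_tbl is ten concatenated copies

print(math.ceil(table_bytes / copy_bytes))        # target = one copy   -> ~10 data files
print(math.ceil(table_bytes / (2 * copy_bytes)))  # target = two copies -> ~5 data files
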
25 changes: 24 additions & 1 deletion tests/io/test_pyarrow.py
@@ -64,14 +64,15 @@
    _ConvertToArrowSchema,
    _primitive_to_physical,
    _read_deletes,
    bin_pack_arrow_table,
    expression_to_pyarrow,
    project_table,
    schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table import FileScanTask, Table, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
@@ -1710,3 +1711,25 @@ def test_stats_aggregator_update_max(vals: List[Any], primitive_type: PrimitiveT
        stats.update_max(val)

    assert stats.current_max == expected_result


def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
    # default packs to 1 bin since the table is small
    bin_packed = bin_pack_arrow_table(
        arrow_table_with_null, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
    )
    assert len(list(bin_packed)) == 1

    # as long as table is smaller than default target size, it should pack to 1 bin
    bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10)
    assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT
    bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT)
    assert len(list(bin_packed)) == 1

    # unless we override the target size to be smaller
    bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes)
    assert len(list(bin_packed)) == 10

    # and will produce half the number of files if we double the target size
    bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
    assert len(list(bin_packed)) == 5
