Skip to content

Commit

Permalink
[Data] Allow user to specify schema for write_parquet (ray-project#…
Browse files Browse the repository at this point in the history
…48631)

Fixes ray-project#48630

Signed-off-by: Balaji Veeramani <[email protected]>
Signed-off-by: mohitjain2504 <[email protected]>
  • Loading branch information
bveeramani authored and mohitjain2504 committed Nov 15, 2024
1 parent 3be1cf1 commit f77ec0b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/ray/data/_internal/datasource/parquet_datasink.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,16 @@ def write(
write_kwargs = _resolve_kwargs(
self.arrow_parquet_args_fn, **self.arrow_parquet_args
)
schema = write_kwargs.pop("schema", None)
if schema is None:
schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema

def write_blocks_to_path():
with self.open_output_stream(write_path) as file:
schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
for block in blocks:
table = BlockAccessor.for_block(block).to_arrow()
for table in tables:
table = table.cast(schema)
writer.write_table(table)

logger.debug(f"Writing {write_path} file.")
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,15 @@ def test_count_with_filter(ray_start_regular_shared):
assert isinstance(ds.count(), int)


def test_write_with_schema(ray_start_regular_shared, tmp_path):
ds = ray.data.range(1)
schema = pa.schema({"id": pa.float32()})

ds.write_parquet(tmp_path, schema=schema)

assert pq.read_table(tmp_path).schema == schema


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit f77ec0b

Please sign in to comment.