diff --git a/python/ray/data/_internal/datasource/parquet_datasink.py b/python/ray/data/_internal/datasource/parquet_datasink.py
index 4dffa939d772..2b8edc11d531 100644
--- a/python/ray/data/_internal/datasource/parquet_datasink.py
+++ b/python/ray/data/_internal/datasource/parquet_datasink.py
@@ -72,13 +72,16 @@ def write(
         write_kwargs = _resolve_kwargs(
             self.arrow_parquet_args_fn, **self.arrow_parquet_args
         )
+        schema = write_kwargs.pop("schema", None)
+        if schema is None:
+            schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
 
         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
-                schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
+                tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
                 with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
-                    for block in blocks:
-                        table = BlockAccessor.for_block(block).to_arrow()
+                    for table in tables:
+                        table = table.cast(schema)
                         writer.write_table(table)
 
         logger.debug(f"Writing {write_path} file.")
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index 6d3868a58946..3a185fd59644 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -1316,6 +1316,15 @@ def test_count_with_filter(ray_start_regular_shared):
     assert isinstance(ds.count(), int)
 
 
+def test_write_with_schema(ray_start_regular_shared, tmp_path):
+    ds = ray.data.range(1)
+    schema = pa.schema({"id": pa.float32()})
+
+    ds.write_parquet(tmp_path, schema=schema)
+
+    assert pq.read_table(tmp_path).schema == schema
+
+
 if __name__ == "__main__":
     import sys