
Change parquet writers to use standard std::io::Write rather than custom ParquetWriter trait (#1717) (#1163) #1719

Merged: 9 commits, May 25, 2022
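At a glance, the core change (paraphrased from the `arrow_writer.rs` hunks below): the writer types are now generic over the standard `std::io::Write` trait instead of the crate's custom `ParquetWriter` trait.

```rust
// Before this PR: a custom trait bound, satisfied only by crate-provided wrappers
pub struct ArrowWriter<W: ParquetWriter> { /* ... */ }

// After this PR: any standard I/O sink works, e.g. File, Vec<u8>, or &mut Cursor<Vec<u8>>
pub struct ArrowWriter<W: std::io::Write> { /* ... */ }
```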
8 changes: 3 additions & 5 deletions parquet/benches/arrow_writer.rs
@@ -26,9 +26,7 @@ use std::sync::Arc;

use arrow::datatypes::*;
use arrow::{record_batch::RecordBatch, util::data_gen::*};
use parquet::{
arrow::ArrowWriter, errors::Result, file::writer::InMemoryWriteableCursor,
};
use parquet::{arrow::ArrowWriter, errors::Result};

fn create_primitive_bench_batch(
size: usize,
@@ -278,8 +276,8 @@ fn _create_nested_bench_batch(
#[inline]
fn write_batch(batch: &RecordBatch) -> Result<()> {
// Write batch to an in-memory writer
let cursor = InMemoryWriteableCursor::default();
let mut writer = ArrowWriter::try_new(cursor, batch.schema(), None)?;
let buffer = vec![];
Contributor comment: here is a nice example of the new API in action; just use anything that implements std::io::Write (see the sketch below this hunk).

let mut writer = ArrowWriter::try_new(buffer, batch.schema(), None)?;

writer.write(batch)?;
writer.close()?;
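As the review comment above notes, writing into a plain `Vec<u8>` now needs no special cursor type. A minimal self-contained sketch of the benchmark's new write path, assuming this PR's API (the one-column schema and data here are hypothetical):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::errors::Result;

fn write_batch_to_memory() -> Result<Vec<u8>> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // Vec<u8> implements std::io::Write; borrowing it keeps the buffer usable
    // after the writer is closed.
    let mut buffer = Vec::new();
    let mut writer = ArrowWriter::try_new(&mut buffer, schema, None)?;
    writer.write(&batch)?;
    writer.close()?; // finalizes the Parquet footer and releases the borrow
    Ok(buffer)
}
```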
2 changes: 1 addition & 1 deletion parquet/src/arrow/array_reader/list_array.rs
@@ -564,7 +564,7 @@ mod tests {
.set_max_row_group_size(200)
.build();

let mut writer = ArrowWriter::try_new(
let writer = ArrowWriter::try_new(
file.try_clone().unwrap(),
Arc::new(arrow_schema),
Some(props),
39 changes: 22 additions & 17 deletions parquet/src/arrow/arrow_reader.rs
@@ -263,15 +263,14 @@ mod tests {
use crate::arrow::schema::add_encoded_arrow_schema_to_metadata;
use crate::arrow::{ArrowWriter, ProjectionMask};
use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType};
use crate::column::writer::get_typed_column_writer_mut;
use crate::data_type::{
BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray,
FixedLenByteArrayType, Int32Type, Int64Type,
};
use crate::errors::Result;
use crate::file::properties::{WriterProperties, WriterVersion};
use crate::file::reader::{FileReader, SerializedFileReader};
use crate::file::writer::{FileWriter, SerializedFileWriter};
use crate::file::writer::SerializedFileWriter;
use crate::schema::parser::parse_message_type;
use crate::schema::types::{Type, TypePtr};
use crate::util::cursor::SliceableCursor;
@@ -936,21 +935,24 @@ mod tests {
for (idx, v) in values.iter().enumerate() {
let def_levels = def_levels.map(|d| d[idx].as_slice());
let mut row_group_writer = writer.next_row_group()?;
let mut column_writer = row_group_writer
.next_column()?
.expect("Column writer is none!");
{
let mut column_writer = row_group_writer
.next_column()?
.expect("Column writer is none!");

get_typed_column_writer_mut::<T>(&mut column_writer)
.write_batch(v, def_levels, None)?;
column_writer
.typed::<T>()
.write_batch(v, def_levels, None)?;

row_group_writer.close_column(column_writer)?;
writer.close_row_group(row_group_writer)?
column_writer.close()?;
}
row_group_writer.close()?;
}

writer.close()
}

fn get_test_reader(file_name: &str) -> Arc<dyn FileReader> {
fn get_test_reader(file_name: &str) -> Arc<SerializedFileReader<File>> {
let file = get_test_file(file_name);

let reader =
@@ -1094,15 +1096,18 @@ mod tests {
)
.unwrap();

let mut row_group_writer = writer.next_row_group().unwrap();
let mut column_writer = row_group_writer.next_column().unwrap().unwrap();
{
let mut row_group_writer = writer.next_row_group().unwrap();
let mut column_writer = row_group_writer.next_column().unwrap().unwrap();

get_typed_column_writer_mut::<Int32Type>(&mut column_writer)
.write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None)
.unwrap();
column_writer
.typed::<Int32Type>()
.write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None)
.unwrap();

row_group_writer.close_column(column_writer).unwrap();
writer.close_row_group(row_group_writer).unwrap();
column_writer.close().unwrap();
row_group_writer.close().unwrap();
}

writer.close().unwrap();
}
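The test rewrites above show the new ownership flow: each writer is closed through its own `close` method instead of being handed back to its parent (`close_column`/`close_row_group` are gone). A minimal sketch of the full pattern, assuming this PR's API and a hypothetical single INT32 column:

```rust
use std::fs::File;
use std::sync::Arc;

use parquet::data_type::Int32Type;
use parquet::errors::Result;
use parquet::file::properties::WriterProperties;
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::parser::parse_message_type;

fn write_one_column(file: File) -> Result<()> {
    let schema = Arc::new(parse_message_type("message schema { REQUIRED INT32 v; }")?);
    let props = Arc::new(WriterProperties::builder().build());
    let mut writer = SerializedFileWriter::new(file, schema, props)?;

    let mut row_group_writer = writer.next_row_group()?;
    if let Some(mut column_writer) = row_group_writer.next_column()? {
        column_writer
            .typed::<Int32Type>()
            .write_batch(&[1, 2, 3], None, None)?; // REQUIRED column: no def/rep levels
        column_writer.close()?; // was: row_group_writer.close_column(column_writer)
    }
    row_group_writer.close()?; // was: writer.close_row_group(row_group_writer)
    writer.close()?;
    Ok(())
}
```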
44 changes: 21 additions & 23 deletions parquet/src/arrow/arrow_writer.rs
@@ -18,6 +18,7 @@
//! Contains writer which writes arrow data into parquet data.

use std::collections::VecDeque;
use std::io::Write;
use std::sync::Arc;

use arrow::array as arrow_array;
@@ -35,18 +36,16 @@ use super::schema::{
use crate::column::writer::ColumnWriter;
use crate::errors::{ParquetError, Result};
use crate::file::properties::WriterProperties;
use crate::{
data_type::*,
file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter},
};
use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter};
use crate::{data_type::*, file::writer::SerializedFileWriter};

/// Arrow writer
///
/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order
/// to produce row groups with `max_row_group_size` rows. Any remaining rows will be
/// flushed on close, leading the final row group in the output file to potentially
/// contain fewer than `max_row_group_size` rows
pub struct ArrowWriter<W: ParquetWriter> {
pub struct ArrowWriter<W: Write> {
/// Underlying Parquet writer
writer: SerializedFileWriter<W>,

@@ -65,7 +64,7 @@ pub struct ArrowWriter<W: Write> {
max_row_group_size: usize,
}

impl<W: 'static + ParquetWriter> ArrowWriter<W> {
impl<W: Write> ArrowWriter<W> {
Contributor (author) comment: this seemingly small change means you can pass in &mut std::io::Cursor<Vec<_>>, or any other Write implementation, which makes this much easier to use.
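For instance, a minimal sketch (assuming this PR's API) of writing through a mutable borrow of a `std::io::Cursor` and taking the buffer back afterwards; the schema, data, and row-group cap are illustrative:

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
    )
    .unwrap();

    // Optional writer configuration, e.g. capping rows per row group.
    let props = WriterProperties::builder().set_max_row_group_size(1024).build();

    let mut cursor = Cursor::new(Vec::new());
    let mut writer = ArrowWriter::try_new(&mut cursor, schema, Some(props)).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap(); // close() consumes the writer, ending the &mut borrow

    let bytes: Vec<u8> = cursor.into_inner(); // the cursor is usable again
    assert!(!bytes.is_empty());
}
```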

/// Try to create a new Arrow writer
///
/// The writer will fail if:
@@ -185,33 +184,35 @@ impl<W: 'static + ParquetWriter> ArrowWriter<W> {
})
.collect();

write_leaves(row_group_writer.as_mut(), &arrays, &mut levels)?;
write_leaves(&mut row_group_writer, &arrays, &mut levels)?;
}

self.writer.close_row_group(row_group_writer)?;
row_group_writer.close().unwrap();
self.buffered_rows -= num_rows;

Ok(())
}

/// Close and finalize the underlying Parquet writer
pub fn close(&mut self) -> Result<parquet_format::FileMetaData> {
pub fn close(mut self) -> Result<parquet_format::FileMetaData> {
self.flush()?;
self.writer.close()
}
}

/// Convenience method to get the next ColumnWriter from the RowGroupWriter
#[inline]
fn get_col_writer(row_group_writer: &mut dyn RowGroupWriter) -> Result<ColumnWriter> {
fn get_col_writer<'a, W: Write>(
row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>,
) -> Result<SerializedColumnWriter<'a>> {
let col_writer = row_group_writer
.next_column()?
.expect("Unable to get column writer");
Ok(col_writer)
}

fn write_leaves(
row_group_writer: &mut dyn RowGroupWriter,
fn write_leaves<W: Write>(
row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
arrays: &[ArrayRef],
levels: &mut [Vec<LevelInfo>],
) -> Result<()> {
@@ -250,12 +251,12 @@ fn write_leaves(
let mut col_writer = get_col_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
write_leaf(
&mut col_writer,
col_writer.untyped(),
array,
levels.pop().expect("Levels exhausted"),
)?;
}
row_group_writer.close_column(col_writer)?;
col_writer.close()?;
Ok(())
}
ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
@@ -313,12 +314,12 @@ fn write_leaves(
// cast dictionary to a primitive
let array = arrow::compute::cast(array, value_type)?;
write_leaf(
&mut col_writer,
col_writer.untyped(),
&array,
levels.pop().expect("Levels exhausted"),
)?;
}
row_group_writer.close_column(col_writer)?;
col_writer.close()?;
Ok(())
}
ArrowDataType::Float16 => Err(ParquetError::ArrowError(
Expand All @@ -336,8 +337,8 @@ fn write_leaves(
}

fn write_leaf(
writer: &mut ColumnWriter,
column: &arrow_array::ArrayRef,
writer: &mut ColumnWriter<'_>,
column: &ArrayRef,
levels: LevelInfo,
) -> Result<i64> {
let indices = levels.filter_array_indices();
@@ -705,7 +706,6 @@ mod tests {
use crate::file::{
reader::{FileReader, SerializedFileReader},
statistics::Statistics,
writer::InMemoryWriteableCursor,
};

#[test]
@@ -744,16 +744,14 @@ mod tests {
let expected_batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap();

let cursor = InMemoryWriteableCursor::default();
let mut buffer = vec![];

{
let mut writer = ArrowWriter::try_new(cursor.clone(), schema, None).unwrap();
let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
writer.write(&expected_batch).unwrap();
writer.close().unwrap();
}

let buffer = cursor.into_inner().unwrap();

let cursor = crate::file::serialized_reader::SliceableCursor::new(buffer);
let reader = SerializedFileReader::new(cursor).unwrap();
let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader));
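Putting this test's pieces together, a round-trip sketch under the same assumptions (this PR's API, where `SliceableCursor` still backs the in-memory read path):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader};
use parquet::file::reader::SerializedFileReader;
use parquet::file::serialized_reader::SliceableCursor;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    // Write into a plain Vec<u8> through the new std::io::Write-based API.
    let mut buffer = Vec::new();
    let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    // Read back through the existing in-memory read path.
    let reader = SerializedFileReader::new(SliceableCursor::new(buffer)).unwrap();
    let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader));
    let mut record_reader = arrow_reader.get_record_reader(1024).unwrap();
    assert_eq!(record_reader.next().unwrap().unwrap(), batch);
}
```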
4 changes: 2 additions & 2 deletions parquet/src/arrow/schema.rs
@@ -1591,7 +1591,7 @@ mod tests {

// write to an empty parquet file so that schema is serialized
let file = tempfile::tempfile().unwrap();
let mut writer = ArrowWriter::try_new(
let writer = ArrowWriter::try_new(
file.try_clone().unwrap(),
Arc::new(schema.clone()),
None,
@@ -1660,7 +1660,7 @@ mod tests {

// write to an empty parquet file so that schema is serialized
let file = tempfile::tempfile().unwrap();
let mut writer = ArrowWriter::try_new(
let writer = ArrowWriter::try_new(
file.try_clone().unwrap(),
Arc::new(schema.clone()),
None,
22 changes: 10 additions & 12 deletions parquet/src/column/mod.rs
@@ -40,10 +40,11 @@
//!
//! use parquet::{
//! column::{reader::ColumnReader, writer::ColumnWriter},
//! data_type::Int32Type,
//! file::{
//! properties::WriterProperties,
//! reader::{FileReader, SerializedFileReader},
//! writer::{FileWriter, SerializedFileWriter},
//! writer::SerializedFileWriter,
//! },
//! schema::parser::parse_message_type,
//! };
@@ -65,20 +66,17 @@
//! let props = Arc::new(WriterProperties::builder().build());
//! let file = fs::File::create(path).unwrap();
//! let mut writer = SerializedFileWriter::new(file, schema, props).unwrap();
//!
//! let mut row_group_writer = writer.next_row_group().unwrap();
//! while let Some(mut col_writer) = row_group_writer.next_column().unwrap() {
//! match col_writer {
//! // You can also use `get_typed_column_writer` method to extract typed writer.
//! ColumnWriter::Int32ColumnWriter(ref mut typed_writer) => {
//! typed_writer
//! .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1]))
//! .unwrap();
//! }
//! _ => {}
//! }
//! row_group_writer.close_column(col_writer).unwrap();
//! col_writer
//! .typed::<Int32Type>()
//! .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1]))
//! .unwrap();
//! col_writer.close().unwrap();
//! }
//! writer.close_row_group(row_group_writer).unwrap();
//! row_group_writer.close().unwrap();
//!
//! writer.close().unwrap();
//!
//! // Reading data using column reader API.