From 0e40d3fd7ef816b04b5744c338de6c82c7ff7b2c Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Wed, 28 Aug 2024 20:21:45 +0800
Subject: [PATCH] improve shuffle performance

---
 native-engine/blaze-jni-bridge/src/conf.rs    |   4 +-
 .../src/streams/coalesce_stream.rs            |  46 ++-
 .../src/common/ipc_compression.rs             | 344 ++++++++----------
 .../src/ipc_writer_exec.rs                    |   4 +-
 .../src/shuffle/buffered_data.rs              | 102 ++++--
 .../src/shuffle/rss_single_repartitioner.rs   |  10 +-
 .../src/shuffle/single_repartitioner.rs       |   9 +-
 .../src/shuffle/sort_repartitioner.rs         |   7 +-
 8 files changed, 256 insertions(+), 270 deletions(-)

diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs
index e2db103d4..41908ec0e 100644
--- a/native-engine/blaze-jni-bridge/src/conf.rs
+++ b/native-engine/blaze-jni-bridge/src/conf.rs
@@ -79,13 +79,13 @@ pub trait DoubleConf {
 pub trait StringConf {
     fn key(&self) -> &'static str;
 
-    fn value(&self) -> Result<&'static str> {
+    fn value(&self) -> Result<String> {
         let key = jni_new_string!(self.key())?;
         let value = jni_get_string!(
             jni_call_static!(BlazeConf.stringConf(key.as_obj()) -> JObject)?
                 .as_obj()
                 .into()
         )?;
-        Ok(Box::leak(value.into_boxed_str()))
+        Ok(value)
     }
 }

diff --git a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs
index 4b9dadd26..fc0374548 100644
--- a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs
+++ b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs
@@ -101,26 +101,8 @@ impl CoalesceStream {
     fn coalesce(&mut self) -> Result<RecordBatch> {
         // better concat_batches() implementation that releases old batch columns asap.
         let schema = self.input.schema();
-
-        // collect all columns
-        let mut all_cols = schema.fields().iter().map(|_| vec![]).collect::<Vec<_>>();
-        for batch in std::mem::take(&mut self.staging_batches) {
-            for i in 0..all_cols.len() {
-                all_cols[i].push(batch.column(i).clone());
-            }
-        }
-
-        // coalesce each column
-        let mut coalesced_cols = vec![];
-        for (cols, field) in all_cols.into_iter().zip(schema.fields()) {
-            let dt = field.data_type();
-            coalesced_cols.push(coalesce_arrays_unchecked(dt, &cols));
-        }
-        let coalesced_batch = RecordBatch::try_new_with_options(
-            schema,
-            coalesced_cols,
-            &RecordBatchOptions::new().with_row_count(Some(self.staging_rows)),
-        )?;
+        let coalesced_batch = coalesce_batches_unchecked(schema, &self.staging_batches);
+        self.staging_batches.clear();
         self.staging_rows = 0;
         self.staging_batches_mem_size = 0;
         Ok(coalesced_batch)
@@ -177,6 +159,30 @@ impl Stream for CoalesceStream {
     }
 }
 
+/// coalesce batches without checking their schemas, invokers must make
+/// sure all batches have the same schema
+pub fn coalesce_batches_unchecked(schema: SchemaRef, batches: &[RecordBatch]) -> RecordBatch {
+    let num_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
+    let num_fields = schema.fields().len();
+    let mut coalesced_cols = vec![];
+
+    for i in 0..num_fields {
+        let data_type = schema.field(i).data_type();
+        let mut cols = Vec::with_capacity(batches.len());
+        for j in 0..batches.len() {
+            cols.push(batches[j].column(i).clone());
+        }
+        coalesced_cols.push(coalesce_arrays_unchecked(data_type, &cols));
+    }
+
+    RecordBatch::try_new_with_options(
+        schema,
+        coalesced_cols,
+        &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+    )
+    .expect("error coalescing record batch")
+}
+
 /// coalesce arrays without checking their data types, invokers must make
 /// sure all arrays have the same data type
 pub fn coalesce_arrays_unchecked(data_type: &DataType, arrays: &[ArrayRef]) -> ArrayRef {

diff --git a/native-engine/datafusion-ext-plans/src/common/ipc_compression.rs b/native-engine/datafusion-ext-plans/src/common/ipc_compression.rs
index af60e19f3..37dce42db 100644
--- a/native-engine/datafusion-ext-plans/src/common/ipc_compression.rs
+++ b/native-engine/datafusion-ext-plans/src/common/ipc_compression.rs
@@ -18,60 +18,69 @@ use std::io::{BufReader, Read, Take, Write};
 
 use arrow::array::ArrayRef;
-use blaze_jni_bridge::{conf, conf::StringConf};
-use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use blaze_jni_bridge::{conf, conf::StringConf, is_jni_bridge_inited};
 use datafusion::common::Result;
 use datafusion_ext_commons::{
     df_execution_err,
-    io::{read_one_batch, write_one_batch},
+    io::{read_len, read_one_batch, write_len, write_one_batch},
 };
+use once_cell::sync::OnceCell;
 
 pub const DEFAULT_SHUFFLE_COMPRESSION_TARGET_BUF_SIZE: usize = 4194304;
 const ZSTD_LEVEL: i32 = 1;
 
 pub struct IpcCompressionWriter<W: Write> {
     output: W,
-    compressed: bool,
-    buf: Box<dyn CompressibleBlockWriter>,
-    buf_empty: bool,
+    shared_buf: VecBuffer,
+    block_writer: IoCompressionWriter<VecBufferWrite>,
+    block_empty: bool,
 }
 unsafe impl<W: Write> Send for IpcCompressionWriter<W> {}
 
 impl<W: Write> IpcCompressionWriter<W> {
-    pub fn new(output: W, compressed: bool) -> Self {
+    pub fn new(output: W) -> Self {
+        let mut shared_buf = VecBuffer::default();
+        let block_writer = IoCompressionWriter::new_with_configured_codec(shared_buf.writer());
         Self {
             output,
-            compressed,
-            buf: create_block_writer(compressed),
-            buf_empty: true,
+            shared_buf,
+            block_writer,
+            block_empty: true,
         }
     }
 
     /// Write a batch, returning uncompressed bytes size
     pub fn write_batch(&mut self, num_rows: usize, cols: &[ArrayRef]) -> Result<()> {
-        write_one_batch(num_rows, cols, &mut self.buf)?;
-        self.buf_empty = false;
-        if self.buf.buf_len() as f64 >= DEFAULT_SHUFFLE_COMPRESSION_TARGET_BUF_SIZE as f64 * 0.9 {
-            self.flush()?;
+        write_one_batch(num_rows, cols, &mut self.block_writer)?;
+        self.block_empty = false;
+
+        let buf_len = self.shared_buf.inner().len();
+        if buf_len as f64 >= DEFAULT_SHUFFLE_COMPRESSION_TARGET_BUF_SIZE as f64 * 0.9 {
+            self.finish_current_buf()?;
         }
         Ok(())
     }
 
-    pub fn flush(&mut self) -> Result<()> {
-        if !self.buf_empty {
-            // finish current buf and open next
-            let next_buf = create_block_writer(self.compressed);
-            let block_data = std::mem::replace(&mut self.buf, next_buf).finish()?;
-            self.output.write_all(&block_data)?;
-            self.output.flush()?;
-            self.buf_empty = true;
+    pub fn finish_current_buf(&mut self) -> Result<()> {
+        if !self.block_empty {
+            // finish current buf
+            self.block_writer.finish()?;
+
+            // write block length and data to output
+            write_len(self.shared_buf.inner().len(), &mut self.output)?;
+            self.output.write_all(self.shared_buf.inner())?;
+            self.shared_buf.inner_mut().clear();
+
+            // open next buf
+            self.block_writer =
+                IoCompressionWriter::new_with_configured_codec(self.shared_buf.writer());
+            self.block_empty = true;
         }
         Ok(())
     }
 
-    pub fn finish_into_inner(mut self) -> Result<W> {
-        self.flush()?;
-        Ok(self.output)
+    pub fn inner(&self) -> &W {
+        &self.output
     }
 }
 
@@ -85,7 +94,7 @@ enum InputState<R: Read> {
     #[default]
     Unreachable,
     BlockStart(R),
-    BlockContent(Box<dyn CompressibleBlockReader<R>>),
+    BlockContent(IoCompressionReader<Take<R>>),
 }
 
 impl<R: Read> IpcCompressionReader<R> {
@@ -100,13 +109,22 @@ impl<R: Read> IpcCompressionReader<R> {
     impl<'a, R: Read> Read for Reader<'a, R> {
         fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
             match std::mem::take(&mut self.0.input) {
-                InputState::Unreachable => unreachable!(),
-                InputState::BlockStart(input) => {
-                    let block_reader = match create_block_reader(input)? {
-                        Some(reader) => reader,
-                        None => return Ok(0),
+                InputState::BlockStart(mut input) => {
+                    let block_len = match read_len(&mut input) {
+                        Ok(block_len) => block_len,
+                        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
+                            return Ok(0);
+                        }
+                        Err(err) => {
+                            return Err(err);
+                        }
                     };
-                    self.0.input = InputState::BlockContent(block_reader);
+                    let taken = input.take(block_len as u64);
+
+                    self.0.input = InputState::BlockContent(IoCompressionReader::try_new(
+                        io_compression_codec(),
+                        taken,
+                    )?);
                     self.read(buf)
                 }
                 InputState::BlockContent(mut block_reader) => match block_reader.read(buf) {
                     Ok(len) if len > 0 => {
                         self.0.input = InputState::BlockContent(block_reader);
                         Ok(len)
                     }
                     Ok(_zero) => {
                         let input = block_reader.finish_into_inner()?;
-                        self.0.input = InputState::BlockStart(input);
+                        self.0.input = InputState::BlockStart(input.into_inner());
                         self.read(buf)
                     }
                     Err(err) => Err(err),
                 },
+                _ => unreachable!(),
             }
         }
     }
 }
 
-#[derive(Clone, Copy)]
-struct Header {
-    compressed: bool,
-    block_len: usize,
-}
-
-impl Header {
-    fn new(compressed: bool, block_len: usize) -> Self {
-        Self {
-            compressed,
-            block_len,
-        }
-    }
-
-    fn from_u32(value: u32) -> Self {
-        let compressed = (value & 0x8000_0000) > 0;
-        let block_len = (value & 0x7fff_ffff) as usize;
-        Self::new(compressed, block_len)
-    }
-
-    fn to_u32(&self) -> u32 {
-        (self.compressed as u32) << 31 | (self.block_len as u32)
-    }
-}
-
-trait CompressibleBlockWriter: Write {
-    fn buf_len(&self) -> usize;
-    fn finish(self: Box<Self>) -> Result<Vec<u8>>;
-}
-
-struct ZWriter(IoCompressionWriter<Vec<u8>>);
-
-impl ZWriter {
-    fn new() -> Self {
-        Self(
-            IoCompressionWriter::try_new(io_compression_codec(), vec![0u8; 4])
-                .expect("error creating compression encoder"),
-        )
-    }
-}
-
-impl Write for ZWriter {
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        self.0.write(buf)
-    }
-
-    fn flush(&mut self) -> std::io::Result<()> {
-        self.0.flush()
-    }
-}
-
-impl CompressibleBlockWriter for ZWriter {
-    fn buf_len(&self) -> usize {
-        self.0.get_ref().len()
-    }
-
-    fn finish(self: Box<Self>) -> Result<Vec<u8>> {
-        let mut block_data = self.0.finish()?;
-        let header = Header::new(true, block_data.len() - 4);
-        block_data[0..4]
-            .as_mut()
-            .write_u32::<LittleEndian>(header.to_u32())?;
-        Ok(block_data)
-    }
-}
-
-struct UncompressedWriter(Vec<u8>);
-
-impl UncompressedWriter {
-    fn new() -> Self {
-        Self(vec![0u8; 4])
-    }
-}
-
-impl Write for UncompressedWriter {
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        self.0.write(buf)
-    }
-
-    fn flush(&mut self) -> std::io::Result<()> {
-        self.0.flush()
-    }
-}
-
-impl CompressibleBlockWriter for UncompressedWriter {
-    fn buf_len(&self) -> usize {
-        self.0.len()
-    }
-
-    fn finish(self: Box<Self>) -> Result<Vec<u8>> {
-        let mut block_data = self.0;
-        let header = Header::new(false, block_data.len() - 4);
-        block_data[0..4]
-            .as_mut()
-            .write_u32::<LittleEndian>(header.to_u32())?;
-        Ok(block_data)
-    }
-}
-
-trait CompressibleBlockReader<R>: Read {
-    fn finish_into_inner(self: Box<Self>) -> Result<R>;
-}
-
-impl<R: Read> CompressibleBlockReader<R> for IoCompressionReader<'_, Take<R>> {
-    fn finish_into_inner(self: Box<Self>) -> Result<R> {
-        let mut r = (*self).finish_into_inner()?;
-        std::io::copy(&mut r, &mut std::io::sink())?; // skip to end
-        Ok(r.into_inner())
-    }
-}
-
-impl<R: Read> CompressibleBlockReader<R> for Take<R> {
-    fn finish_into_inner(self: Box<Self>) -> Result<R> {
-        Ok(self.into_inner())
-    }
-}
-
-fn create_block_writer(compressed: bool) -> Box<dyn CompressibleBlockWriter> {
-    if compressed {
-        Box::new(ZWriter::new())
-    } else {
-        Box::new(UncompressedWriter::new())
-    }
-}
-
-fn create_block_reader<R: Read>(
-    mut input: R,
-) -> Result<Option<Box<dyn CompressibleBlockReader<R>>>> {
-    let header = match input.read_u32::<LittleEndian>() {
-        Ok(value) => Header::from_u32(value),
-        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
-            return Ok(None);
-        }
-        Err(err) => {
-            return df_execution_err!("{err}");
-        }
-    };
-
-    let taken = input.take(header.block_len as u64);
-    if !header.compressed {
-        return Ok(Some(Box::new(taken)));
-    }
-    Ok(Some(Box::new(
-        IoCompressionReader::try_new(io_compression_codec(), taken)
-            .expect("error creating compression decoder"),
-    )))
-}
-
 enum IoCompressionWriter<W: Write> {
     LZ4(lz4_flex::frame::FrameEncoder<W>),
     ZSTD(zstd::Encoder<'static, W>),
 }
 
 impl<W: Write> IoCompressionWriter<W> {
+    fn new_with_configured_codec(inner: W) -> Self {
+        Self::try_new(io_compression_codec(), inner).expect("error creating compression encoder")
+    }
+
     fn try_new(codec: &str, inner: W) -> Result<Self> {
         match codec {
             "lz4" => Ok(Self::LZ4(lz4_flex::frame::FrameEncoder::new(inner))),
@@ -290,18 +165,17 @@ impl<W: Write> IoCompressionWriter<W> {
         }
     }
 
-    fn get_ref(&self) -> &W {
-        match self {
-            IoCompressionWriter::LZ4(w) => w.get_ref(),
-            IoCompressionWriter::ZSTD(w) => w.get_ref(),
-        }
-    }
-
-    fn finish(self) -> Result<W> {
+    fn finish(&mut self) -> Result<()> {
         match self {
-            IoCompressionWriter::LZ4(w) => Ok(w.finish().or_else(|e| df_execution_err!("{e}"))?),
-            IoCompressionWriter::ZSTD(w) => Ok(w.finish().or_else(|e| df_execution_err!("{e}"))?),
+            IoCompressionWriter::LZ4(w) => {
+                w.try_finish()
+                    .or_else(|_| df_execution_err!("ipc compression error"))?;
+            }
+            IoCompressionWriter::ZSTD(w) => {
+                w.do_finish()?;
+            }
         }
+        Ok(())
     }
 }
 
@@ -321,12 +195,12 @@ impl<W: Write> Write for IoCompressionWriter<W> {
     }
 }
 
-enum IoCompressionReader<'a, R: Read> {
+enum IoCompressionReader<R: Read> {
     LZ4(lz4_flex::frame::FrameDecoder<R>),
-    ZSTD(zstd::Decoder<'a, BufReader<R>>),
+    ZSTD(zstd::Decoder<'static, BufReader<R>>),
 }
 
-impl<R: Read> IoCompressionReader<'_, R> {
+impl<R: Read> IoCompressionReader<R> {
     fn try_new(codec: &str, inner: R) -> Result<Self> {
         match codec {
             "lz4" => Ok(Self::LZ4(lz4_flex::frame::FrameDecoder::new(inner))),
@@ -343,7 +217,7 @@ impl<R: Read> IoCompressionReader<R> {
     }
 }
 
-impl<R: Read> Read for IoCompressionReader<'_, R> {
+impl<R: Read> Read for IoCompressionReader<R> {
     fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
         match self {
             Self::LZ4(r) => r.read(buf),
             Self::ZSTD(r) => r.read(buf),
         }
     }
 }
 
 fn io_compression_codec() -> &'static str {
-    conf::SPARK_IO_COMPRESSION_CODEC.value().unwrap()
+    static CODEC: OnceCell<String> = OnceCell::new();
+    CODEC
+        .get_or_try_init(|| {
+            if is_jni_bridge_inited() {
+                conf::SPARK_IO_COMPRESSION_CODEC.value()
+            } else {
+                Ok(format!("lz4")) // for testing
+            }
+        })
+        .expect("error reading spark.io.compression.codec")
+        .as_str()
+}
+
+#[derive(Default)]
+struct VecBuffer {
+    vec: Box<Vec<u8>>,
+}
+
+struct VecBufferWrite {
+    unsafe_inner: *mut Vec<u8>,
+}
+
+impl Write for VecBufferWrite {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let inner = unsafe { &mut *self.unsafe_inner };
+        inner.extend_from_slice(buf);
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        Ok(())
+    }
+}
+
+impl VecBuffer {
+    fn inner(&self) -> &Vec<u8> {
+        &self.vec
+    }
+
+    fn inner_mut(&mut self) -> &mut Vec<u8> {
+        &mut self.vec
+    }
+
+    fn writer(&mut self) -> VecBufferWrite {
+        VecBufferWrite {
+            unsafe_inner: &mut *self.vec as *mut Vec<u8>,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{error::Error, io::Cursor, sync::Arc};
+
+    use arrow::array::StringArray;
+
+    use super::*;
+
+    #[test]
+    fn test_ipc_compression() -> Result<(), Box<dyn Error>> {
+        let mut buf = vec![];
+        let mut writer = IpcCompressionWriter::new(&mut buf);
+
+        let test_array1: ArrayRef = Arc::new(StringArray::from(vec![Some("hello"), Some("world")]));
+        writer.write_batch(2, &[test_array1.clone()])?;
+        let test_array2: ArrayRef = Arc::new(StringArray::from(vec![Some("foo"), Some("bar")]));
+        writer.write_batch(2, &[test_array2.clone()])?;
+        writer.finish_current_buf()?;
+
+        let mut reader = IpcCompressionReader::new(Cursor::new(buf));
+        let (num_rows1, arrays1) = reader.read_batch()?.unwrap();
+        assert_eq!(num_rows1, 2);
+        assert_eq!(arrays1, &[test_array1]);
+        let (num_rows2, arrays2) = reader.read_batch()?.unwrap();
+        assert_eq!(num_rows2, 2);
+        assert_eq!(arrays2, &[test_array2]);
+        assert!(reader.read_batch()?.is_none());
+        Ok(())
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs
index 20fe5a292..9a7eea97d 100644
--- a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs
@@ -149,7 +149,7 @@ pub async fn write_ipc(
         }
     }
 
-    let mut writer = IpcCompressionWriter::new(IpcConsumerWrite(ipc_consumer), true);
+    let mut writer = IpcCompressionWriter::new(IpcConsumerWrite(ipc_consumer));
     while let Some(batch) = input.next().await.transpose()? {
         let _timer = metrics.elapsed_compute().timer();
         writer.write_batch(batch.num_rows(), batch.columns())?;
@@ -157,7 +157,7 @@ pub async fn write_ipc(
     }
 
     let _timer = metrics.elapsed_compute().timer();
-    writer.finish_into_inner()?;
+    writer.finish_current_buf()?;
     Ok(())
 })
 }
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs b/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
index 28ba8700b..593051aec 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use std::{io::Write, mem::size_of};
+use std::io::Write;
 
 use arrow::record_batch::RecordBatch;
 use blaze_jni_bridge::jni_call;
@@ -25,7 +25,9 @@ use datafusion_ext_commons::{
     ds::rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree},
     rdxsort::radix_sort_unstable_by_key,
     staging_mem_size_for_partial_sort,
+    streams::coalesce_stream::coalesce_batches_unchecked,
 };
+use itertools::Itertools;
 use jni::objects::GlobalRef;
 
 use crate::{
@@ -37,7 +39,7 @@ pub struct BufferedData {
     partition_id: usize,
     staging_batches: Vec<RecordBatch>,
     sorted_batches: Vec<RecordBatch>,
-    sorted_partition_indices: Vec<Vec<u32>>,
+    sorted_parts: Vec<Vec<PartitionInBatch>>,
     num_rows: usize,
     staging_mem_used: usize,
     sorted_mem_used: usize,
@@ -49,7 +51,7 @@ impl BufferedData {
             partition_id,
             staging_batches: vec![],
             sorted_batches: vec![],
-            sorted_partition_indices: vec![],
+            sorted_parts: vec![],
             num_rows: 0,
             staging_mem_used: 0,
             sorted_mem_used: 0,
@@ -81,13 +83,12 @@ impl BufferedData {
         let staging_batches = std::mem::take(&mut self.staging_batches);
         self.staging_mem_used = 0;
 
-        let (partition_indices, sorted_batch) =
-            sort_batches_by_partition_id(staging_batches, partitioning)?;
+        let (parts, sorted_batch) = sort_batches_by_partition_id(staging_batches, partitioning)?;
 
         self.sorted_mem_used +=
-            sorted_batch.get_array_mem_size() + partition_indices.len() * size_of::<u32>();
+            sorted_batch.get_array_mem_size() + parts.len() * size_of::<PartitionInBatch>();
         self.sorted_batches.push(sorted_batch);
-        self.sorted_partition_indices.push(partition_indices);
+        self.sorted_parts.push(parts);
         Ok(())
     }
 
@@ -103,6 +104,7 @@ impl BufferedData {
         if self.num_rows == 0 {
             return Ok(vec![0; partitioning.partition_count() + 1]);
         }
+        let mut writer = IpcCompressionWriter::new(CountWrite::from(&mut w));
         let mut offsets = vec![];
         let mut offset = 0;
         let mut iter = self.into_sorted_batches(partitioning)?;
@@ -114,12 +116,12 @@ impl BufferedData {
             }
 
             // write all batches with this part id
-            let mut writer = IpcCompressionWriter::new(CountWrite::from(&mut w), true);
             while iter.cur_part_id() == cur_part_id {
                 let batch = iter.next_batch();
                 writer.write_batch(batch.num_rows(), batch.columns())?;
             }
-            offset += writer.finish_into_inner()?.count();
+            writer.finish_current_buf()?;
+            offset = writer.inner().count();
             offsets.push(offset);
         }
         while offsets.len() <= partitioning.partition_count() {
@@ -150,17 +152,17 @@ impl BufferedData {
 
         while (iter.cur_part_id() as usize) < partitioning.partition_count() {
             let cur_part_id = iter.cur_part_id();
-            let mut writer = IpcCompressionWriter::new(
-                RssWriter::new(rss_partition_writer.clone(), cur_part_id as usize),
-                true,
-            );
+            let mut writer = IpcCompressionWriter::new(RssWriter::new(
+                rss_partition_writer.clone(),
+                cur_part_id as usize,
+            ));
 
             // write all batches with this part id
             while iter.cur_part_id() == cur_part_id {
                 let batch = iter.next_batch();
                 writer.write_batch(batch.num_rows(), batch.columns())?;
             }
-            writer.finish_into_inner()?;
+            writer.finish_current_buf()?;
         }
         jni_call!(BlazeRssPartitionWriterBase(rss_partition_writer.as_obj()).flush() -> ())?;
 
@@ -182,14 +184,13 @@ impl BufferedData {
         Ok(PartitionedBatchesIterator {
             batches: self.sorted_batches.clone(),
             cursors: RadixTournamentTree::new(
-                self.sorted_partition_indices
+                self.sorted_parts
                     .into_iter()
                     .enumerate()
                     .map(|(idx, partition_indices)| PartCursor {
                         idx,
-                        part_id: partition_indices[0],
-                        row_idx: 0,
-                        partition_indices,
+                        parts: partition_indices,
+                        parts_idx: 0,
                     })
                     .collect(),
                 partitioning.partition_count(),
@@ -215,31 +216,30 @@ struct PartitionedBatchesIterator {
 impl PartitionedBatchesIterator {
     pub fn cur_part_id(&self) -> u32 {
-        self.cursors.peek().part_id
+        self.cursors.peek().rdx() as u32
     }
 
     fn next_batch(&mut self) -> RecordBatch {
         let cur_batch_size = self.batch_size.min(self.num_rows - self.num_output_rows);
         let cur_part_id = self.cur_part_id();
-        let mut indices = Vec::with_capacity(cur_batch_size);
+        let mut slices = vec![];
+        let mut slices_len = 0;
 
         // add rows with same partition id under this cursor
-        while indices.len() < cur_batch_size {
+        while slices_len < cur_batch_size {
             let mut min_cursor = self.cursors.peek_mut();
-            if min_cursor.part_id != cur_part_id {
+            if min_cursor.rdx() as u32 != cur_part_id {
                 break;
             }
-            while indices.len() < cur_batch_size && min_cursor.part_id == cur_part_id {
-                indices.push((min_cursor.idx, min_cursor.row_idx));
-                min_cursor.row_idx += 1;
-                min_cursor.part_id = *min_cursor
-                    .partition_indices
-                    .get(min_cursor.row_idx)
-                    .unwrap_or(&u32::MAX);
-            }
+
+            let cur_part = min_cursor.parts[min_cursor.parts_idx];
+            let cur_slice =
+                self.batches[min_cursor.idx].slice(cur_part.start as usize, cur_part.len as usize);
+            slices_len += cur_slice.num_rows();
+            slices.push(cur_slice);
+            min_cursor.parts_idx += 1;
         }
 
-        let output_batch = interleave_batches(self.batches[0].schema(), &self.batches, &indices)
-            .expect("error merging sorted batches: interleaving error");
+        let output_batch = coalesce_batches_unchecked(self.batches[0].schema(), &slices);
         self.num_output_rows += output_batch.num_rows();
         output_batch
     }
 }
@@ -247,21 +247,30 @@ impl PartitionedBatchesIterator {
 
 struct PartCursor {
     idx: usize,
-    partition_indices: Vec<u32>,
-    row_idx: usize,
-    part_id: u32,
+    parts: Vec<PartitionInBatch>,
+    parts_idx: usize,
 }
 
 impl KeyForRadixTournamentTree for PartCursor {
     fn rdx(&self) -> usize {
-        self.part_id as usize
+        if self.parts_idx < self.parts.len() {
+            return self.parts[self.parts_idx].part_id as usize;
+        }
+        u32::MAX as usize
     }
 }
 
+#[derive(Clone, Copy)]
+struct PartitionInBatch {
+    part_id: u32,
+    start: u32,
+    len: u32,
+}
+
 fn sort_batches_by_partition_id(
     batches: Vec<RecordBatch>,
     partitioning: &Partitioning,
-) -> Result<(Vec<u32>, RecordBatch)> {
+) -> Result<(Vec<PartitionInBatch>, RecordBatch)> {
     let num_partitions = partitioning.partition_count();
     let schema = batches[0].schema();
 
@@ -291,5 +300,22 @@ fn sort_batches_by_partition_id(
         .map(|(part_id, batch_idx, row_idx)| (part_id, (batch_idx as usize, row_idx as usize)))
         .unzip();
     let sorted_batch = interleave_batches(schema, &batches, &sorted_row_indices)?;
-    return Ok((sorted_partition_indices, sorted_batch));
+
+    let mut start = 0;
+    let partitions = sorted_partition_indices
+        .into_iter()
+        .chunk_by(|part_id| *part_id)
+        .into_iter()
+        .map(|(part_id, chunk)| {
+            let partition = PartitionInBatch {
+                part_id,
+                start,
+                len: chunk.count() as u32,
+            };
+            start += partition.len;
+            partition
+        })
+        .collect();
+
+    return Ok((partitions, sorted_batch));
 }
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs
index b350c2635..d43cc4004 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs
@@ -32,10 +32,10 @@ pub struct RssSingleShuffleRepartitioner {
 impl RssSingleShuffleRepartitioner {
     pub fn new(rss_partition_writer: GlobalRef) -> Self {
         Self {
-            rss_partition_writer: Arc::new(Mutex::new(IpcCompressionWriter::new(
-                RssWriter::new(rss_partition_writer, 0),
-                true,
-            ))),
+            rss_partition_writer: Arc::new(Mutex::new(IpcCompressionWriter::new(RssWriter::new(
+                rss_partition_writer,
+                0,
+            )))),
         }
     }
 }
@@ -55,7 +55,7 @@ impl ShuffleRepartitioner for RssSingleShuffleRepartitioner {
     }
 
     async fn shuffle_write(&self) -> Result<()> {
-        self.rss_partition_writer.lock().flush()?;
+        self.rss_partition_writer.lock().finish_current_buf()?;
         Ok(())
     }
 }
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs
index 810b94dd6..4c159c28a 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs
@@ -57,7 +57,6 @@ impl SingleShuffleRepartitioner {
                     .create(true)
                     .truncate(true)
                     .open(&self.output_data_file)?,
-                true,
             ));
         }
         Ok(output_data.as_mut().unwrap())
@@ -76,12 +75,12 @@ impl ShuffleRepartitioner for SingleShuffleRepartitioner {
     async fn shuffle_write(&self) -> Result<()> {
         let _timer = self.metrics.elapsed_compute().timer();
-        let output_data = std::mem::take(&mut *self.output_data.lock().await);
+        let mut output_data = std::mem::take(&mut *self.output_data.lock().await);
 
         // write index file
-        if let Some(output_writer) = output_data {
-            let mut output_file = output_writer.finish_into_inner()?;
-            let offset = output_file.stream_position()?;
+        if let Some(output_writer) = output_data.as_mut() {
+            output_writer.finish_current_buf()?;
+            let offset = output_writer.inner().stream_position()?;
             let mut output_index = File::create(&self.output_index_file)?;
             output_index.write_all(&[0u8; 8])?;
             output_index.write_all(&(offset as i64).to_le_bytes()[..])?;
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
index bf6b16b26..2e8e19c26 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
@@ -164,10 +164,13 @@ impl ShuffleRepartitioner for SortShuffleRepartitioner {
             .create(true)
             .truncate(true)
             .open(&data_file)?;
+        let mut output_index = OpenOptions::new()
+            .write(true)
+            .create(true)
+            .truncate(true)
+            .open(&index_file)?;
 
         let offsets = data.write(&mut output_data, &partitioning)?;
-
-        let mut output_index = File::create(&index_file)?;
         for offset in offsets {
            output_index.write_all(&(offset as i64).to_le_bytes()[..])?;
        }
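---

A minimal standalone sketch of the length-prefixed block framing this patch moves to (replacing the old 4-byte compressed-flag/length header). The exact varint encoding behind `write_len`/`read_len` lives in `datafusion_ext_commons::io` and is not shown in the diff, so a fixed 8-byte little-endian length is used here as a stand-in, and `lz4_flex`'s frame API stands in for the configured codec:

```rust
use std::io::{Cursor, Read, Write};

fn write_block(out: &mut Vec<u8>, payload: &[u8]) -> std::io::Result<()> {
    let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
    encoder.write_all(payload)?;
    let block = encoder
        .finish()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    // stand-in for write_len: prefix each compressed block with its byte length
    out.write_all(&(block.len() as u64).to_le_bytes())?;
    out.write_all(&block)
}

fn read_block(input: &mut impl Read) -> std::io::Result<Option<Vec<u8>>> {
    let mut len_buf = [0u8; 8];
    match input.read_exact(&mut len_buf) {
        Ok(()) => {}
        // a clean EOF at a block boundary terminates the stream, as in the patch
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
        Err(e) => return Err(e),
    }
    let block_len = u64::from_le_bytes(len_buf);
    // bound the decoder to exactly one block, mirroring `input.take(block_len)`
    let mut decoder = lz4_flex::frame::FrameDecoder::new(input.take(block_len));
    let mut payload = Vec::new();
    decoder.read_to_end(&mut payload)?;
    Ok(Some(payload))
}

fn main() -> std::io::Result<()> {
    let mut buf = Vec::new();
    write_block(&mut buf, b"hello")?;
    write_block(&mut buf, b"world")?;
    let mut cursor = Cursor::new(buf);
    while let Some(payload) = read_block(&mut cursor)? {
        println!("{}", String::from_utf8_lossy(&payload));
    }
    Ok(())
}
```

This framing is what lets `IpcCompressionWriter` keep one long-lived compression stream per ~4 MiB buffer instead of allocating a fresh block writer per flush.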
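A second sketch, for the `buffered_data.rs` change: rows already sorted by partition id are collapsed into `(part_id, start, len)` runs, so each partition can later be emitted as a zero-copy `batch.slice(start, len)` and coalesced, instead of being interleaved row by row. `PartitionInBatch` and the `chunk_by` chain mirror the patch; the harness around them is illustrative only and assumes itertools 0.13+ (`chunk_by` was named `group_by` in earlier versions):

```rust
use itertools::Itertools;

#[derive(Clone, Copy, Debug, PartialEq)]
struct PartitionInBatch {
    part_id: u32,
    start: u32,
    len: u32,
}

// collapse a sorted list of partition ids into contiguous runs
fn group_sorted_partition_ids(sorted_partition_ids: Vec<u32>) -> Vec<PartitionInBatch> {
    let mut start = 0;
    sorted_partition_ids
        .into_iter()
        .chunk_by(|part_id| *part_id)
        .into_iter()
        .map(|(part_id, chunk)| {
            let partition = PartitionInBatch {
                part_id,
                start,
                len: chunk.count() as u32,
            };
            start += partition.len;
            partition
        })
        .collect()
}

fn main() {
    // three runs: partition 0 covers rows 0..3, partition 2 rows 3..5, partition 5 row 5
    let parts = group_sorted_partition_ids(vec![0, 0, 0, 2, 2, 5]);
    assert_eq!(parts[0], PartitionInBatch { part_id: 0, start: 0, len: 3 });
    assert_eq!(parts[1], PartitionInBatch { part_id: 2, start: 3, len: 2 });
    assert_eq!(parts[2], PartitionInBatch { part_id: 5, start: 5, len: 1 });
    println!("{parts:?}");
}
```

Storing one struct per run rather than one `u32` per row is also what shrinks the tracked memory from `num_rows * size_of::<u32>()` to `parts.len() * size_of::<PartitionInBatch>()`.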