Skip to content

Commit

Permalink
feat: allow empty projection in table scan (#677)
Browse files Browse the repository at this point in the history
* fix: allow empty projection in scan

* fix: allow empty projection in scan

* fix: pub get_manifest_list

* Update crates/iceberg/src/scan.rs

Co-authored-by: Renjie Liu <[email protected]>

* chore: remove pub

---------

Co-authored-by: Renjie Liu <[email protected]>
  • Loading branch information
sundy-li and liurenjie1024 authored Oct 25, 2024
1 parent 0c44e50 commit 11e36c0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 35 deletions.
29 changes: 11 additions & 18 deletions crates/iceberg/src/arrow/record_batch_transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;

use arrow_array::{
Array as ArrowArray, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
Int32Array, Int64Array, NullArray, RecordBatch, StringArray,
Int32Array, Int64Array, NullArray, RecordBatch, RecordBatchOptions, StringArray,
};
use arrow_cast::cast;
use arrow_schema::{
Expand Down Expand Up @@ -124,19 +124,7 @@ impl RecordBatchTransformer {
snapshot_schema: Arc<IcebergSchema>,
projected_iceberg_field_ids: &[i32],
) -> Self {
let projected_iceberg_field_ids = if projected_iceberg_field_ids.is_empty() {
// If the list of field ids is empty, this indicates that we
// need to select all fields.
// Project all fields in table schema order
snapshot_schema
.as_struct()
.fields()
.iter()
.map(|field| field.id)
.collect()
} else {
projected_iceberg_field_ids.to_vec()
};
let projected_iceberg_field_ids = projected_iceberg_field_ids.to_vec();

Self {
snapshot_schema,
Expand All @@ -154,10 +142,15 @@ impl RecordBatchTransformer {
Some(BatchTransform::Modify {
ref target_schema,
ref operations,
}) => RecordBatch::try_new(
target_schema.clone(),
self.transform_columns(record_batch.columns(), operations)?,
)?,
}) => {
let options =
RecordBatchOptions::default().with_row_count(Some(record_batch.num_rows()));
RecordBatch::try_new_with_options(
target_schema.clone(),
self.transform_columns(record_batch.columns(), operations)?,
&options,
)?
}
Some(BatchTransform::ModifySchema { target_schema }) => {
record_batch.with_schema(target_schema.clone())?
}
Expand Down
63 changes: 46 additions & 17 deletions crates/iceberg/src/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ pub type ArrowRecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
/// Builder to create table scan.
pub struct TableScanBuilder<'a> {
table: &'a Table,
// Empty column names means to select all columns
column_names: Vec<String>,
// Defaults to none which means select all columns
column_names: Option<Vec<String>>,
snapshot_id: Option<i64>,
batch_size: Option<usize>,
case_sensitive: bool,
Expand All @@ -70,7 +70,7 @@ impl<'a> TableScanBuilder<'a> {

Self {
table,
column_names: vec![],
column_names: None,
snapshot_id: None,
batch_size: None,
case_sensitive: true,
Expand Down Expand Up @@ -106,16 +106,24 @@ impl<'a> TableScanBuilder<'a> {

/// Select all columns.
pub fn select_all(mut self) -> Self {
self.column_names.clear();
self.column_names = None;
self
}

/// Select no columns at all (an empty projection).
///
/// Stores `Some(vec![])` as the column list. This is distinct from `None`,
/// which the builder treats as "select all columns".
pub fn select_empty(mut self) -> Self {
self.column_names = Some(vec![]);
self
}

/// Select some columns of the table.
pub fn select(mut self, column_names: impl IntoIterator<Item = impl ToString>) -> Self {
self.column_names = column_names
.into_iter()
.map(|item| item.to_string())
.collect();
self.column_names = Some(
column_names
.into_iter()
.map(|item| item.to_string())
.collect(),
);
self
}

Expand Down Expand Up @@ -205,8 +213,8 @@ impl<'a> TableScanBuilder<'a> {
let schema = snapshot.schema(self.table.metadata())?;

// Check that all column names exist in the schema.
if !self.column_names.is_empty() {
for column_name in &self.column_names {
if let Some(column_names) = self.column_names.as_ref() {
for column_name in column_names {
if schema.field_by_name(column_name).is_none() {
return Err(Error::new(
ErrorKind::DataInvalid,
Expand All @@ -220,7 +228,16 @@ impl<'a> TableScanBuilder<'a> {
}

let mut field_ids = vec![];
for column_name in &self.column_names {
let column_names = self.column_names.clone().unwrap_or_else(|| {
schema
.as_struct()
.fields()
.iter()
.map(|f| f.name.clone())
.collect()
});

for column_name in column_names.iter() {
let field_id = schema.field_id_by_name(column_name).ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
Expand Down Expand Up @@ -297,7 +314,7 @@ pub struct TableScan {
plan_context: PlanContext,
batch_size: Option<usize>,
file_io: FileIO,
column_names: Vec<String>,
column_names: Option<Vec<String>>,
/// The maximum number of manifest files that will be
/// retrieved from [`FileIO`] concurrently
concurrency_limit_manifest_files: usize,
Expand Down Expand Up @@ -409,9 +426,10 @@ impl TableScan {
}

/// Returns a reference to the column names of the table scan.
pub fn column_names(&self) -> &[String] {
&self.column_names
pub fn column_names(&self) -> Option<&[String]> {
self.column_names.as_deref()
}

/// Returns a reference to the snapshot of the table scan.
pub fn snapshot(&self) -> &SnapshotRef {
&self.plan_context.snapshot
Expand Down Expand Up @@ -1236,23 +1254,26 @@ mod tests {
let table = TableTestFixture::new().table;

let table_scan = table.scan().select(["x", "y"]).build().unwrap();
assert_eq!(vec!["x", "y"], table_scan.column_names);
assert_eq!(
Some(vec!["x".to_string(), "y".to_string()]),
table_scan.column_names
);

let table_scan = table
.scan()
.select(["x", "y"])
.select(["z"])
.build()
.unwrap();
assert_eq!(vec!["z"], table_scan.column_names);
assert_eq!(Some(vec!["z".to_string()]), table_scan.column_names);
}

#[test]
fn test_select_all() {
let table = TableTestFixture::new().table;

let table_scan = table.scan().select_all().build().unwrap();
assert!(table_scan.column_names.is_empty());
assert!(table_scan.column_names.is_none());
}

#[test]
Expand Down Expand Up @@ -1424,6 +1445,14 @@ mod tests {
let col2 = batches[0].column_by_name("z").unwrap();
let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 3);

// test empty scan
let table_scan = fixture.table.scan().select_empty().build().unwrap();
let batch_stream = table_scan.to_arrow().await.unwrap();
let batches: Vec<_> = batch_stream.try_collect().await.unwrap();

assert_eq!(batches[0].num_columns(), 0);
assert_eq!(batches[0].num_rows(), 1024);
}

#[tokio::test]
Expand Down

0 comments on commit 11e36c0

Please sign in to comment.