Skip to content

Commit

Permalink
refactor(rust): Refactor ArrowSchema to use `polars_schema::Schema<D>` (#18564)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Sep 5, 2024
1 parent e4746a5 commit 10fab78
Show file tree
Hide file tree
Showing 66 changed files with 332 additions and 317 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ hashbrown = { workspace = true }
num-traits = { workspace = true }
parking_lot = { workspace = true }
polars-error = { workspace = true }
polars-schema = { workspace = true }
polars-utils = { workspace = true }
serde = { workspace = true, optional = true }
simdutf8 = { workspace = true }
Expand Down Expand Up @@ -153,7 +154,7 @@ compute = [
"compute_take",
"compute_temporal",
]
serde = ["dep:serde"]
serde = ["dep:serde", "polars-schema/serde"]
simd = []

# polars-arrow
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-arrow/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pub struct Field {
pub metadata: Metadata,
}

/// Support for `ArrowSchema::from_iter([field, ..])`
impl From<Field> for (PlSmallStr, Field) {
    /// Converts a field into a `(name, field)` pair, cloning the name so the
    /// field itself can be moved into the pair unchanged.
    fn from(value: Field) -> Self {
        let name = value.name.clone();
        (name, value)
    }
}

impl Field {
/// Creates a new [`Field`].
pub fn new(name: PlSmallStr, data_type: ArrowDataType, is_nullable: bool) -> Self {
Expand Down
64 changes: 1 addition & 63 deletions crates/polars-arrow/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,11 @@
use std::sync::Arc;

use polars_error::{polars_bail, PolarsResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::Field;

/// An ordered sequence of [`Field`]s
///
/// [`ArrowSchema`] is an abstraction used to read from, and write to, Arrow IPC format,
/// Apache Parquet, and Apache Avro. All these formats have a concept of a schema
/// with fields and metadata.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrowSchema {
    /// The fields composing this schema, in declaration order.
    pub fields: Vec<Field>,
}

pub type ArrowSchema = polars_schema::Schema<Field>;
pub type ArrowSchemaRef = Arc<ArrowSchema>;

impl ArrowSchema {
#[inline]
pub fn len(&self) -> usize {
self.fields.len()
}

#[inline]
pub fn is_empty(&self) -> bool {
self.fields.is_empty()
}

/// Returns a new [`ArrowSchema`] with a subset of all fields whose `predicate`
/// evaluates to true.
pub fn filter<F: Fn(usize, &Field) -> bool>(self, predicate: F) -> Self {
let fields = self
.fields
.into_iter()
.enumerate()
.filter_map(|(index, f)| {
if (predicate)(index, &f) {
Some(f)
} else {
None
}
})
.collect();

ArrowSchema { fields }
}

pub fn try_project(&self, indices: &[usize]) -> PolarsResult<Self> {
let fields = indices.iter().map(|&i| {
let Some(out) = self.fields.get(i) else {
polars_bail!(
SchemaFieldNotFound: "projection index {} is out of bounds for schema of length {}",
i, self.fields.len()
);
};

Ok(out.clone())
}).collect::<PolarsResult<Vec<_>>>()?;

Ok(ArrowSchema { fields })
}
}

impl From<Vec<Field>> for ArrowSchema {
fn from(fields: Vec<Field>) -> Self {
Self { fields }
}
}
6 changes: 3 additions & 3 deletions crates/polars-arrow/src/io/avro/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ fn skip_item<'a>(
/// `fields`, `avro_fields` and `projection` must have the same length.
pub fn deserialize(
block: &Block,
fields: &[Field],
fields: &ArrowSchema,
avro_fields: &[AvroField],
projection: &[bool],
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
Expand All @@ -479,7 +479,7 @@ pub fn deserialize(

// create mutables, one per field
let mut arrays: Vec<Box<dyn MutableArray>> = fields
.iter()
.iter_values()
.zip(avro_fields.iter())
.zip(projection.iter())
.map(|((field, avro_field), projection)| {
Expand All @@ -496,7 +496,7 @@ pub fn deserialize(
for _ in 0..rows {
let iter = arrays
.iter_mut()
.zip(fields.iter())
.zip(fields.iter_values())
.zip(avro_fields.iter())
.zip(projection.iter());

Expand Down
8 changes: 4 additions & 4 deletions crates/polars-arrow/src/io/avro/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ mod util;
pub use schema::infer_schema;

use crate::array::Array;
use crate::datatypes::Field;
use crate::datatypes::ArrowSchema;
use crate::record_batch::RecordBatchT;

/// Single threaded, blocking reader of Avro; [`Iterator`] of [`RecordBatchT`].
pub struct Reader<R: Read> {
iter: BlockStreamingIterator<R>,
avro_fields: Vec<AvroField>,
fields: Vec<Field>,
fields: ArrowSchema,
projection: Vec<bool>,
}

Expand All @@ -33,7 +33,7 @@ impl<R: Read> Reader<R> {
pub fn new(
reader: R,
metadata: FileMetadata,
fields: Vec<Field>,
fields: ArrowSchema,
projection: Option<Vec<bool>>,
) -> Self {
let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect());
Expand All @@ -56,7 +56,7 @@ impl<R: Read> Iterator for Reader<R> {
type Item = PolarsResult<RecordBatchT<Box<dyn Array>>>;

fn next(&mut self) -> Option<Self::Item> {
let fields = &self.fields[..];
let fields = &self.fields;
let avro_fields = &self.avro_fields;
let projection = &self.projection;

Expand Down
11 changes: 6 additions & 5 deletions crates/polars-arrow/src/io/avro/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,19 @@ fn external_props(schema: &AvroSchema) -> Metadata {
/// Infers an [`ArrowSchema`] from the root [`Record`].
/// This
pub fn infer_schema(record: &Record) -> PolarsResult<ArrowSchema> {
Ok(record
record
.fields
.iter()
.map(|field| {
schema_to_field(
let field = schema_to_field(
&field.schema,
Some(&field.name),
external_props(&field.schema),
)
)?;

Ok((field.name.clone(), field))
})
.collect::<PolarsResult<Vec<_>>>()?
.into())
.collect::<PolarsResult<ArrowSchema>>()
}

fn schema_to_field(
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-arrow/src/io/avro/write/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use crate::datatypes::*;
pub fn to_record(schema: &ArrowSchema, name: String) -> PolarsResult<Record> {
let mut name_counter: i32 = 0;
let fields = schema
.fields
.iter()
.iter_values()
.map(|f| field_to_field(f, &mut name_counter))
.collect::<PolarsResult<_>>()?;
Ok(Record {
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-arrow/src/io/flight/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub fn serialize_schema_to_info(
let encoded_data = if let Some(ipc_fields) = ipc_fields {
schema_as_encoded_data(schema, ipc_fields)
} else {
let ipc_fields = default_ipc_fields(&schema.fields);
let ipc_fields = default_ipc_fields(schema.iter_values());
schema_as_encoded_data(schema, &ipc_fields)
};

Expand All @@ -92,7 +92,7 @@ fn _serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> V
if let Some(ipc_fields) = ipc_fields {
write::schema_to_bytes(schema, ipc_fields)
} else {
let ipc_fields = default_ipc_fields(&schema.fields);
let ipc_fields = default_ipc_fields(schema.iter_values());
write::schema_to_bytes(schema, &ipc_fields)
}
}
Expand All @@ -113,7 +113,7 @@ pub fn deserialize_schemas(bytes: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema
/// Deserializes [`FlightData`] representing a record batch message to [`RecordBatchT`].
pub fn deserialize_batch(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &read::Dictionaries,
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
Expand Down Expand Up @@ -147,7 +147,7 @@ pub fn deserialize_batch(
/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`.
pub fn deserialize_dictionary(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut read::Dictionaries,
) -> PolarsResult<()> {
Expand Down Expand Up @@ -182,7 +182,7 @@ pub fn deserialize_dictionary(
/// or by upserting into `dictionaries` (when the message is a dictionary)
pub fn deserialize_message(
data: &FlightData,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut Dictionaries,
) -> PolarsResult<Option<RecordBatchT<Box<dyn Array>>>> {
Expand Down
38 changes: 25 additions & 13 deletions crates/polars-arrow/src/io/ipc/read/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use polars_utils::pl_str::PlSmallStr;
use super::deserialize::{read, skip};
use super::Dictionaries;
use crate::array::*;
use crate::datatypes::{ArrowDataType, Field};
use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
use crate::io::ipc::read::OutOfSpecKind;
use crate::io::ipc::{IpcField, IpcSchema};
use crate::record_batch::RecordBatchT;
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<'a, A, I: Iterator<Item = A>> Iterator for ProjectionIter<'a, A, I> {
#[allow(clippy::too_many_arguments)]
pub fn read_record_batch<R: Read + Seek>(
batch: arrow_format::ipc::RecordBatchRef,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
projection: Option<&[usize]>,
limit: Option<usize>,
Expand Down Expand Up @@ -127,8 +127,10 @@ pub fn read_record_batch<R: Read + Seek>(
let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();

let columns = if let Some(projection) = projection {
let projection =
ProjectionIter::new(projection, fields.iter().zip(ipc_schema.fields.iter()));
let projection = ProjectionIter::new(
projection,
fields.iter_values().zip(ipc_schema.fields.iter()),
);

projection
.map(|maybe_field| match maybe_field {
Expand Down Expand Up @@ -163,7 +165,7 @@ pub fn read_record_batch<R: Read + Seek>(
.collect::<PolarsResult<Vec<_>>>()?
} else {
fields
.iter()
.iter_values()
.zip(ipc_schema.fields.iter())
.map(|(field, ipc_field)| {
read(
Expand Down Expand Up @@ -227,11 +229,11 @@ fn find_first_dict_field<'a>(

pub(crate) fn first_dict_field<'a>(
id: i64,
fields: &'a [Field],
fields: &'a ArrowSchema,
ipc_fields: &'a [IpcField],
) -> PolarsResult<(&'a Field, &'a IpcField)> {
assert_eq!(fields.len(), ipc_fields.len());
for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) {
for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
if let Some(field) = find_first_dict_field(id, field, ipc_field) {
return Ok(field);
}
Expand All @@ -246,7 +248,7 @@ pub(crate) fn first_dict_field<'a>(
#[allow(clippy::too_many_arguments)]
pub fn read_dictionary<R: Read + Seek>(
batch: arrow_format::ipc::DictionaryBatchRef,
fields: &[Field],
fields: &ArrowSchema,
ipc_schema: &IpcSchema,
dictionaries: &mut Dictionaries,
reader: &mut R,
Expand Down Expand Up @@ -280,7 +282,11 @@ pub fn read_dictionary<R: Read + Seek>(
};

// Make a fake schema for the dictionary batch.
let fields = vec![Field::new(PlSmallStr::EMPTY, value_type.clone(), false)];
let fields = std::iter::once((
PlSmallStr::EMPTY,
Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
))
.collect();
let ipc_schema = IpcSchema {
fields: vec![first_ipc_field.clone()],
is_little_endian: ipc_schema.is_little_endian,
Expand All @@ -305,10 +311,16 @@ pub fn read_dictionary<R: Read + Seek>(
}

pub fn prepare_projection(
fields: &[Field],
schema: &ArrowSchema,
mut projection: Vec<usize>,
) -> (Vec<usize>, PlHashMap<usize, usize>, Vec<Field>) {
let fields = projection.iter().map(|x| fields[*x].clone()).collect();
) -> (Vec<usize>, PlHashMap<usize, usize>, ArrowSchema) {
let schema = projection
.iter()
.map(|x| {
let (k, v) = schema.get_at_index(*x).unwrap();
(k.clone(), v.clone())
})
.collect();

// todo: find way to do this more efficiently
let mut indices = (0..projection.len()).collect::<Vec<_>>();
Expand All @@ -335,7 +347,7 @@ pub fn prepare_projection(
}
}

(projection, map, fields)
(projection, map, schema)
}

pub fn apply_projection(
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-arrow/src/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn read_dictionary_block<R: Read + Seek>(

read_dictionary(
batch,
&metadata.schema.fields,
&metadata.schema,
&metadata.ipc_schema,
dictionaries,
reader,
Expand Down Expand Up @@ -317,7 +317,7 @@ pub fn read_batch<R: Read + Seek>(

read_record_batch(
batch,
&metadata.schema.fields,
&metadata.schema,
&metadata.ipc_schema,
projection,
limit,
Expand Down
Loading

0 comments on commit 10fab78

Please sign in to comment.