-
Notifications
You must be signed in to change notification settings - Fork 193
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor!: isolate arrow code (#266)
# Rationale for this change Currently, arrow is one of the elements that must be disabled to enable no_std support. There are many instances of #[cfg(feature = "arrow")] throughout the codebase, which leaves the code cluttered. To make code readable and manageable , arrow module is introduced. # What changes are included in this PR? - Moved all [cfg(feature = "arrow")] related code to separate modules and cracked cyclic dependencies along the way without any massive disruption. # Are these changes tested? Yes
- Loading branch information
Showing
19 changed files
with
337 additions
and
548 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
79 changes: 79 additions & 0 deletions
79
crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
use crate::base::{ | ||
database::{ColumnField, ColumnType}, | ||
math::decimal::Precision, | ||
}; | ||
use alloc::sync::Arc; | ||
use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit}; | ||
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; | ||
|
||
/// Convert [`ColumnType`] values to some arrow [`DataType`] | ||
impl From<&ColumnType> for DataType { | ||
fn from(column_type: &ColumnType) -> Self { | ||
match column_type { | ||
ColumnType::Boolean => DataType::Boolean, | ||
ColumnType::TinyInt => DataType::Int8, | ||
ColumnType::SmallInt => DataType::Int16, | ||
ColumnType::Int => DataType::Int32, | ||
ColumnType::BigInt => DataType::Int64, | ||
ColumnType::Int128 => DataType::Decimal128(38, 0), | ||
ColumnType::Decimal75(precision, scale) => { | ||
DataType::Decimal256(precision.value(), *scale) | ||
} | ||
ColumnType::VarChar => DataType::Utf8, | ||
ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"), | ||
ColumnType::TimestampTZ(timeunit, timezone) => { | ||
let arrow_timezone = Some(Arc::from(timezone.to_string())); | ||
let arrow_timeunit = match timeunit { | ||
PoSQLTimeUnit::Second => ArrowTimeUnit::Second, | ||
PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond, | ||
PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond, | ||
PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond, | ||
}; | ||
DataType::Timestamp(arrow_timeunit, arrow_timezone) | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Convert arrow [`DataType`] values to some [`ColumnType`] | ||
impl TryFrom<DataType> for ColumnType { | ||
type Error = String; | ||
|
||
fn try_from(data_type: DataType) -> Result<Self, Self::Error> { | ||
match data_type { | ||
DataType::Boolean => Ok(ColumnType::Boolean), | ||
DataType::Int8 => Ok(ColumnType::TinyInt), | ||
DataType::Int16 => Ok(ColumnType::SmallInt), | ||
DataType::Int32 => Ok(ColumnType::Int), | ||
DataType::Int64 => Ok(ColumnType::BigInt), | ||
DataType::Decimal128(38, 0) => Ok(ColumnType::Int128), | ||
DataType::Decimal256(precision, scale) if precision <= 75 => { | ||
Ok(ColumnType::Decimal75(Precision::new(precision)?, scale)) | ||
} | ||
DataType::Timestamp(time_unit, timezone_option) => { | ||
let posql_time_unit = match time_unit { | ||
ArrowTimeUnit::Second => PoSQLTimeUnit::Second, | ||
ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond, | ||
ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond, | ||
ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond, | ||
}; | ||
Ok(ColumnType::TimestampTZ( | ||
posql_time_unit, | ||
PoSQLTimeZone::try_from(&timezone_option)?, | ||
)) | ||
} | ||
DataType::Utf8 => Ok(ColumnType::VarChar), | ||
_ => Err(format!("Unsupported arrow data type {data_type:?}")), | ||
} | ||
} | ||
} | ||
/// Convert [`ColumnField`] values to arrow Field | ||
impl From<&ColumnField> for Field { | ||
fn from(column_field: &ColumnField) -> Self { | ||
Field::new( | ||
column_field.name().name(), | ||
(&column_field.data_type()).into(), | ||
false, | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//! This module provides conversions and utilities for working with Arrow data structures. | ||
/// Module for handling conversion from Arrow arrays to columns. | ||
pub mod arrow_array_to_column_conversion; | ||
|
||
/// Module for converting between owned and Arrow data structures. | ||
pub mod owned_and_arrow_conversions; | ||
|
||
#[cfg(test)] | ||
/// Tests for owned and Arrow conversions. | ||
mod owned_and_arrow_conversions_test; | ||
|
||
/// Module for converting record batches. | ||
pub mod record_batch_conversion; | ||
|
||
/// Module for record batch error definitions. | ||
pub mod record_batch_errors; | ||
|
||
/// Utility functions for record batches. | ||
pub mod record_batch_utility; | ||
|
||
/// Module for scalar and i256 conversions. | ||
pub mod scalar_and_i256_conversions; | ||
|
||
/// Module for handling conversions between columns and Arrow arrays. | ||
pub mod column_arrow_conversions; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 2 additions & 2 deletions
4
...abase/owned_and_arrow_conversions_test.rs → ...arrow/owned_and_arrow_conversions_test.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
160 changes: 160 additions & 0 deletions
160
crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
use super::{ | ||
arrow_array_to_column_conversion::ArrayRefExt, | ||
record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError}, | ||
}; | ||
use crate::base::{ | ||
commitment::{ | ||
AppendColumnCommitmentsError, AppendTableCommitmentError, Commitment, TableCommitment, | ||
TableCommitmentFromColumnsError, | ||
}, | ||
database::Column, | ||
scalar::Scalar, | ||
}; | ||
use arrow::record_batch::RecordBatch; | ||
use bumpalo::Bump; | ||
use proof_of_sql_parser::Identifier; | ||
|
||
/// This function will return an error if: | ||
/// - The field name cannot be parsed into an [`Identifier`]. | ||
/// - The conversion of an Arrow array to a [`Column`] fails. | ||
pub fn batch_to_columns<'a, S: Scalar + 'a>( | ||
batch: &'a RecordBatch, | ||
alloc: &'a Bump, | ||
) -> Result<Vec<(Identifier, Column<'a, S>)>, RecordBatchToColumnsError> { | ||
batch | ||
.schema() | ||
.fields() | ||
.into_iter() | ||
.zip(batch.columns()) | ||
.map(|(field, array)| { | ||
let identifier: Identifier = field.name().parse()?; | ||
let column: Column<S> = array.to_column(alloc, &(0..array.len()), None)?; | ||
Ok((identifier, column)) | ||
}) | ||
.collect() | ||
} | ||
|
||
impl<C: Commitment> TableCommitment<C> { | ||
/// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`]. | ||
/// | ||
/// The row offset is assumed to be the end of the [`TableCommitment`]'s current range. | ||
/// | ||
/// Will error on a variety of mismatches, or if the provided columns have mixed length. | ||
#[allow(clippy::missing_panics_doc)] | ||
pub fn try_append_record_batch( | ||
&mut self, | ||
batch: &RecordBatch, | ||
setup: &C::PublicSetup<'_>, | ||
) -> Result<(), AppendRecordBatchTableCommitmentError> { | ||
match self.try_append_rows( | ||
batch_to_columns::<C::Scalar>(batch, &Bump::new())? | ||
.iter() | ||
.map(|(a, b)| (a, b)), | ||
setup, | ||
) { | ||
Ok(()) => Ok(()), | ||
Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => { | ||
panic!("RecordBatches cannot have columns of mixed length") | ||
} | ||
Err(AppendTableCommitmentError::AppendColumnCommitments { | ||
source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. }, | ||
}) => { | ||
panic!("RecordBatches cannot have duplicate identifiers") | ||
} | ||
Err(AppendTableCommitmentError::AppendColumnCommitments { | ||
source: AppendColumnCommitmentsError::Mismatch { source: e }, | ||
}) => Err(e)?, | ||
} | ||
} | ||
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`]. | ||
pub fn try_from_record_batch( | ||
batch: &RecordBatch, | ||
setup: &C::PublicSetup<'_>, | ||
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> { | ||
Self::try_from_record_batch_with_offset(batch, 0, setup) | ||
} | ||
|
||
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset. | ||
#[allow(clippy::missing_panics_doc)] | ||
pub fn try_from_record_batch_with_offset( | ||
batch: &RecordBatch, | ||
offset: usize, | ||
setup: &C::PublicSetup<'_>, | ||
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> { | ||
match Self::try_from_columns_with_offset( | ||
batch_to_columns::<C::Scalar>(batch, &Bump::new())? | ||
.iter() | ||
.map(|(a, b)| (a, b)), | ||
offset, | ||
setup, | ||
) { | ||
Ok(commitment) => Ok(commitment), | ||
Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => { | ||
panic!("RecordBatches cannot have columns of mixed length") | ||
} | ||
Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => { | ||
panic!("RecordBatches cannot have duplicate identifiers") | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(all(test, feature = "blitzar"))] | ||
mod tests { | ||
use super::*; | ||
use crate::{base::scalar::Curve25519Scalar, record_batch}; | ||
use curve25519_dalek::RistrettoPoint; | ||
|
||
#[test] | ||
fn we_can_create_and_append_table_commitments_with_record_batchs() { | ||
let batch = record_batch!( | ||
"a" => [1i64, 2, 3], | ||
"b" => ["1", "2", "3"], | ||
); | ||
|
||
let b_scals = ["1".into(), "2".into(), "3".into()]; | ||
|
||
let columns = [ | ||
( | ||
&"a".parse().unwrap(), | ||
&Column::<Curve25519Scalar>::BigInt(&[1, 2, 3]), | ||
), | ||
( | ||
&"b".parse().unwrap(), | ||
&Column::<Curve25519Scalar>::VarChar((&["1", "2", "3"], &b_scals)), | ||
), | ||
]; | ||
|
||
let mut expected_commitment = | ||
TableCommitment::<RistrettoPoint>::try_from_columns_with_offset(columns, 0, &()) | ||
.unwrap(); | ||
|
||
let mut commitment = | ||
TableCommitment::<RistrettoPoint>::try_from_record_batch(&batch, &()).unwrap(); | ||
|
||
assert_eq!(commitment, expected_commitment); | ||
|
||
let batch2 = record_batch!( | ||
"a" => [4i64, 5, 6], | ||
"b" => ["4", "5", "6"], | ||
); | ||
|
||
let b_scals2 = ["4".into(), "5".into(), "6".into()]; | ||
|
||
let columns2 = [ | ||
( | ||
&"a".parse().unwrap(), | ||
&Column::<Curve25519Scalar>::BigInt(&[4, 5, 6]), | ||
), | ||
( | ||
&"b".parse().unwrap(), | ||
&Column::<Curve25519Scalar>::VarChar((&["4", "5", "6"], &b_scals2)), | ||
), | ||
]; | ||
|
||
expected_commitment.try_append_rows(columns2, &()).unwrap(); | ||
commitment.try_append_record_batch(&batch2, &()).unwrap(); | ||
|
||
assert_eq!(commitment, expected_commitment); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
use super::arrow_array_to_column_conversion::ArrowArrayToColumnConversionError; | ||
use crate::base::commitment::ColumnCommitmentsMismatch; | ||
use proof_of_sql_parser::ParseError; | ||
use snafu::Snafu; | ||
|
||
/// Errors that can occur when trying to create or extend a [`TableCommitment`] from a record batch. | ||
#[derive(Debug, Snafu)] | ||
pub enum RecordBatchToColumnsError { | ||
/// Error converting from arrow array | ||
#[snafu(transparent)] | ||
ArrowArrayToColumnConversionError { | ||
/// The underlying source error | ||
source: ArrowArrayToColumnConversionError, | ||
}, | ||
#[snafu(transparent)] | ||
/// This error occurs when convering from a record batch name to an identifier fails. (Which may be impossible.) | ||
FieldParseFail { | ||
/// The underlying source error | ||
source: ParseError, | ||
}, | ||
} | ||
|
||
/// Errors that can occur when attempting to append a record batch to a [`TableCommitment`]. | ||
#[derive(Debug, Snafu)] | ||
pub enum AppendRecordBatchTableCommitmentError { | ||
/// During commitment operation, metadata indicates that operand tables cannot be the same. | ||
#[snafu(transparent)] | ||
ColumnCommitmentsMismatch { | ||
/// The underlying source error | ||
source: ColumnCommitmentsMismatch, | ||
}, | ||
/// Error converting from arrow array | ||
#[snafu(transparent)] | ||
ArrowBatchToColumnError { | ||
/// The underlying source error | ||
source: RecordBatchToColumnsError, | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.