refactor!: isolate arrow code (#266)
# Rationale for this change
Currently, `arrow` is one of the features that must be disabled to enable
`no_std` support. There are many instances of `#[cfg(feature = "arrow")]`
scattered throughout the codebase, which leaves the code cluttered. To make
the code more readable and maintainable, a dedicated `arrow` module is introduced.
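
As a rough sketch of the pattern this refactor moves toward (illustrative only; the item and module names below are hypothetical), the feature gate moves from individual items to a single module boundary:

```rust
// Hypothetical sketch: gate arrow-specific code at one module boundary
// instead of repeating #[cfg(feature = "arrow")] on every item.

// Before: each arrow-related item carries its own attribute.
#[cfg(feature = "arrow")]
fn convert_record_batch_old() { /* arrow-specific logic */ }

// After: one gate on the module covers everything inside it.
#[cfg(feature = "arrow")]
mod arrow_support {
    pub fn convert_record_batch() { /* arrow-specific logic */ }
}

fn main() {
    #[cfg(feature = "arrow")]
    {
        convert_record_batch_old();
        arrow_support::convert_record_batch();
    }
}
```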


# What changes are included in this PR?

- Moved all `#[cfg(feature = "arrow")]`-gated code into separate modules,
breaking cyclic dependencies along the way without any major disruption.

# Are these changes tested?
Yes
JayWhite2357 authored Oct 27, 2024
2 parents e803222 + d6a1eda commit d8f1377
Showing 19 changed files with 337 additions and 548 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -20,8 +20,8 @@ ark-poly = { version = "0.4.0" }
ark-serialize = { version = "0.4.0" }
ark-std = { version = "0.4.0", default-features = false }
arrayvec = { version = "0.7", default-features = false }
arrow = { version = "51.0" }
arrow-csv = { version = "51.0" }
arrow = { version = "51.0.0" }
arrow-csv = { version = "51.0.0" }
bit-iter = { version = "1.1.1" }
bigdecimal = { version = "0.4.5", default-features = false, features = ["serde"] }
blake3 = { version = "1.3.3", default-features = false }
3 changes: 2 additions & 1 deletion crates/proof-of-sql/examples/posql_db/main.rs
@@ -5,6 +5,7 @@ mod commit_accessor;
mod csv_accessor;
/// TODO: add docs
mod record_batch_accessor;

use arrow::{
datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
@@ -273,7 +274,7 @@ fn main() {
end_timer(timer);
println!(
"Verified Result: {:?}",
RecordBatch::try_from(query_result).unwrap()
RecordBatch::try_from(query_result.table).unwrap()
);
}
}
@@ -2,9 +2,9 @@ use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use indexmap::IndexMap;
use proof_of_sql::base::{
arrow::arrow_array_to_column_conversion::ArrayRefExt,
database::{
ArrayRefExt, Column, ColumnRef, ColumnType, DataAccessor, MetadataAccessor, SchemaAccessor,
TableRef,
Column, ColumnRef, ColumnType, DataAccessor, MetadataAccessor, SchemaAccessor, TableRef,
},
scalar::Scalar,
};
79 changes: 79 additions & 0 deletions crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs
@@ -0,0 +1,79 @@
use crate::base::{
database::{ColumnField, ColumnType},
math::decimal::Precision,
};
use alloc::sync::Arc;
use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};

/// Convert [`ColumnType`] values to some arrow [`DataType`]
impl From<&ColumnType> for DataType {
fn from(column_type: &ColumnType) -> Self {
match column_type {
ColumnType::Boolean => DataType::Boolean,
ColumnType::TinyInt => DataType::Int8,
ColumnType::SmallInt => DataType::Int16,
ColumnType::Int => DataType::Int32,
ColumnType::BigInt => DataType::Int64,
ColumnType::Int128 => DataType::Decimal128(38, 0),
ColumnType::Decimal75(precision, scale) => {
DataType::Decimal256(precision.value(), *scale)
}
ColumnType::VarChar => DataType::Utf8,
ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"),
ColumnType::TimestampTZ(timeunit, timezone) => {
let arrow_timezone = Some(Arc::from(timezone.to_string()));
let arrow_timeunit = match timeunit {
PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
};
DataType::Timestamp(arrow_timeunit, arrow_timezone)
}
}
}
}

/// Convert arrow [`DataType`] values to some [`ColumnType`]
impl TryFrom<DataType> for ColumnType {
type Error = String;

fn try_from(data_type: DataType) -> Result<Self, Self::Error> {
match data_type {
DataType::Boolean => Ok(ColumnType::Boolean),
DataType::Int8 => Ok(ColumnType::TinyInt),
DataType::Int16 => Ok(ColumnType::SmallInt),
DataType::Int32 => Ok(ColumnType::Int),
DataType::Int64 => Ok(ColumnType::BigInt),
DataType::Decimal128(38, 0) => Ok(ColumnType::Int128),
DataType::Decimal256(precision, scale) if precision <= 75 => {
Ok(ColumnType::Decimal75(Precision::new(precision)?, scale))
}
DataType::Timestamp(time_unit, timezone_option) => {
let posql_time_unit = match time_unit {
ArrowTimeUnit::Second => PoSQLTimeUnit::Second,
ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond,
ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond,
ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond,
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
_ => Err(format!("Unsupported arrow data type {data_type:?}")),
}
}
}
/// Convert [`ColumnField`] values to arrow Field
impl From<&ColumnField> for Field {
fn from(column_field: &ColumnField) -> Self {
Field::new(
column_field.name().name(),
(&column_field.data_type()).into(),
false,
)
}
}
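
For context, the new conversions above can be exercised roughly like this (a usage sketch, not part of the diff; it assumes `ColumnType` remains re-exported from `proof_of_sql::base::database` as in the example code earlier in this commit):

```rust
use arrow::datatypes::DataType;
use proof_of_sql::base::database::ColumnType;

fn main() {
    // ColumnType -> DataType is infallible for supported types.
    let dt: DataType = (&ColumnType::BigInt).into();
    assert_eq!(dt, DataType::Int64);

    // DataType -> ColumnType is fallible; supported arrow types map back cleanly...
    let ct = ColumnType::try_from(DataType::Utf8).expect("Utf8 maps to VarChar");
    assert!(matches!(ct, ColumnType::VarChar));

    // ...and unsupported ones (e.g. Float64) are rejected rather than coerced.
    assert!(ColumnType::try_from(DataType::Float64).is_err());
}
```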
26 changes: 26 additions & 0 deletions crates/proof-of-sql/src/base/arrow/mod.rs
@@ -0,0 +1,26 @@
//! This module provides conversions and utilities for working with Arrow data structures.
/// Module for handling conversion from Arrow arrays to columns.
pub mod arrow_array_to_column_conversion;

/// Module for converting between owned and Arrow data structures.
pub mod owned_and_arrow_conversions;

#[cfg(test)]
/// Tests for owned and Arrow conversions.
mod owned_and_arrow_conversions_test;

/// Module for converting record batches.
pub mod record_batch_conversion;

/// Module for record batch error definitions.
pub mod record_batch_errors;

/// Utility functions for record batches.
pub mod record_batch_utility;

/// Module for scalar and i256 conversions.
pub mod scalar_and_i256_conversions;

/// Module for handling conversions between columns and Arrow arrays.
pub mod column_arrow_conversions;
@@ -12,12 +12,9 @@
//! This is because there is no `Int128` type in Arrow.
//! This does not check that the values are less than 39 digits.
//! However, the actual arrow backing `i128` is the correct value.
use super::scalar_and_i256_conversions::convert_scalar_to_i256;
use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256};
use crate::base::{
database::{
scalar_and_i256_conversions::convert_i256_to_scalar, OwnedColumn, OwnedTable,
OwnedTableError,
},
database::{OwnedColumn, OwnedTable, OwnedTableError},
map::IndexMap,
math::decimal::Precision,
scalar::Scalar,
@@ -1,7 +1,7 @@
use super::{OwnedColumn, OwnedTable};
use super::owned_and_arrow_conversions::OwnedArrowConversionError;
use crate::{
base::{
database::{owned_table_utility::*, OwnedArrowConversionError},
database::{owned_table_utility::*, OwnedColumn, OwnedTable},
map::IndexMap,
scalar::Curve25519Scalar,
},
160 changes: 160 additions & 0 deletions crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs
@@ -0,0 +1,160 @@
use super::{
arrow_array_to_column_conversion::ArrayRefExt,
record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError},
};
use crate::base::{
commitment::{
AppendColumnCommitmentsError, AppendTableCommitmentError, Commitment, TableCommitment,
TableCommitmentFromColumnsError,
},
database::Column,
scalar::Scalar,
};
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use proof_of_sql_parser::Identifier;

/// This function will return an error if:
/// - The field name cannot be parsed into an [`Identifier`].
/// - The conversion of an Arrow array to a [`Column`] fails.
pub fn batch_to_columns<'a, S: Scalar + 'a>(
batch: &'a RecordBatch,
alloc: &'a Bump,
) -> Result<Vec<(Identifier, Column<'a, S>)>, RecordBatchToColumnsError> {
batch
.schema()
.fields()
.into_iter()
.zip(batch.columns())
.map(|(field, array)| {
let identifier: Identifier = field.name().parse()?;
let column: Column<S> = array.to_column(alloc, &(0..array.len()), None)?;
Ok((identifier, column))
})
.collect()
}

impl<C: Commitment> TableCommitment<C> {
/// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`].
///
/// The row offset is assumed to be the end of the [`TableCommitment`]'s current range.
///
/// Will error on a variety of mismatches, or if the provided columns have mixed length.
#[allow(clippy::missing_panics_doc)]
pub fn try_append_record_batch(
&mut self,
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<(), AppendRecordBatchTableCommitmentError> {
match self.try_append_rows(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
setup,
) {
Ok(()) => Ok(()),
Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. },
}) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::Mismatch { source: e },
}) => Err(e)?,
}
}
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`].
pub fn try_from_record_batch(
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
Self::try_from_record_batch_with_offset(batch, 0, setup)
}

/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset.
#[allow(clippy::missing_panics_doc)]
pub fn try_from_record_batch_with_offset(
batch: &RecordBatch,
offset: usize,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
match Self::try_from_columns_with_offset(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
offset,
setup,
) {
Ok(commitment) => Ok(commitment),
Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
}
}
}

#[cfg(all(test, feature = "blitzar"))]
mod tests {
use super::*;
use crate::{base::scalar::Curve25519Scalar, record_batch};
use curve25519_dalek::RistrettoPoint;

#[test]
fn we_can_create_and_append_table_commitments_with_record_batchs() {
let batch = record_batch!(
"a" => [1i64, 2, 3],
"b" => ["1", "2", "3"],
);

let b_scals = ["1".into(), "2".into(), "3".into()];

let columns = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[1, 2, 3]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["1", "2", "3"], &b_scals)),
),
];

let mut expected_commitment =
TableCommitment::<RistrettoPoint>::try_from_columns_with_offset(columns, 0, &())
.unwrap();

let mut commitment =
TableCommitment::<RistrettoPoint>::try_from_record_batch(&batch, &()).unwrap();

assert_eq!(commitment, expected_commitment);

let batch2 = record_batch!(
"a" => [4i64, 5, 6],
"b" => ["4", "5", "6"],
);

let b_scals2 = ["4".into(), "5".into(), "6".into()];

let columns2 = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[4, 5, 6]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["4", "5", "6"], &b_scals2)),
),
];

expected_commitment.try_append_rows(columns2, &()).unwrap();
commitment.try_append_record_batch(&batch2, &()).unwrap();

assert_eq!(commitment, expected_commitment);
}
}
38 changes: 38 additions & 0 deletions crates/proof-of-sql/src/base/arrow/record_batch_errors.rs
@@ -0,0 +1,38 @@
use super::arrow_array_to_column_conversion::ArrowArrayToColumnConversionError;
use crate::base::commitment::ColumnCommitmentsMismatch;
use proof_of_sql_parser::ParseError;
use snafu::Snafu;

/// Errors that can occur when trying to create or extend a [`TableCommitment`] from a record batch.
#[derive(Debug, Snafu)]
pub enum RecordBatchToColumnsError {
/// Error converting from arrow array
#[snafu(transparent)]
ArrowArrayToColumnConversionError {
/// The underlying source error
source: ArrowArrayToColumnConversionError,
},
#[snafu(transparent)]
/// This error occurs when conversion of a record batch field name to an identifier fails. (Which may be impossible in practice.)
FieldParseFail {
/// The underlying source error
source: ParseError,
},
}

/// Errors that can occur when attempting to append a record batch to a [`TableCommitment`].
#[derive(Debug, Snafu)]
pub enum AppendRecordBatchTableCommitmentError {
/// During the commitment operation, the metadata indicates that the operand tables do not match.
#[snafu(transparent)]
ColumnCommitmentsMismatch {
/// The underlying source error
source: ColumnCommitmentsMismatch,
},
/// Error converting from arrow array
#[snafu(transparent)]
ArrowBatchToColumnError {
/// The underlying source error
source: RecordBatchToColumnsError,
},
}
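
As a rough illustration (not part of the diff), downstream code might handle the new error types like this; the paths `proof_of_sql::base::commitment::{Commitment, TableCommitment}` and `proof_of_sql::base::arrow::record_batch_errors` are assumed based on the modules shown in this commit:

```rust
use arrow::record_batch::RecordBatch;
use proof_of_sql::base::{
    arrow::record_batch_errors::AppendRecordBatchTableCommitmentError,
    commitment::{Commitment, TableCommitment},
};

// Hypothetical helper: append a batch and report which kind of failure occurred.
fn append_and_report<C: Commitment>(
    commitment: &mut TableCommitment<C>,
    batch: &RecordBatch,
    setup: &C::PublicSetup<'_>,
) {
    match commitment.try_append_record_batch(batch, setup) {
        Ok(()) => println!("batch appended"),
        Err(AppendRecordBatchTableCommitmentError::ColumnCommitmentsMismatch { source }) => {
            eprintln!("schema mismatch with existing commitment: {source}");
        }
        Err(AppendRecordBatchTableCommitmentError::ArrowBatchToColumnError { source }) => {
            eprintln!("failed to convert record batch columns: {source}");
        }
    }
}
```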
@@ -169,7 +169,7 @@ macro_rules! record_batch {
use arrow::datatypes::Field;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use $crate::base::database::ToArrow;
use $crate::base::arrow::record_batch_utility::ToArrow;

let schema = Arc::new(Schema::new(
vec![$(
@@ -54,12 +54,10 @@ pub fn convert_i256_to_scalar<S: Scalar>(value: &i256) -> Option<S> {

#[cfg(test)]
mod tests {

use super::{convert_i256_to_scalar, convert_scalar_to_i256};
use crate::base::{
database::scalar_and_i256_conversions::{MAX_SUPPORTED_I256, MIN_SUPPORTED_I256},
scalar::{Curve25519Scalar, Scalar},
use super::{
convert_i256_to_scalar, convert_scalar_to_i256, MAX_SUPPORTED_I256, MIN_SUPPORTED_I256,
};
use crate::base::scalar::{Curve25519Scalar, Scalar};
use arrow::datatypes::i256;
use num_traits::Zero;
use rand::RngCore;
