Skip to content

Commit

Permalink
Introduce FieldPath abstraction, restrict predicates to Field, Op, (F…
Browse files Browse the repository at this point in the history
…ield|Scalar) (#324)

We need an abstraction to handle addressing nested fields, hence
FieldPath.

Also disallows scalar-scalar comparisons, which are useless.
  • Loading branch information
jdcasale authored May 16, 2024
1 parent 75245f9 commit e7e9952
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 60 deletions.
100 changes: 100 additions & 0 deletions vortex-dtype/src/field_paths.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use core::fmt;
use std::fmt::{Display, Formatter};

use vortex_error::{vortex_bail, VortexResult};

/// An ordered list of [`FieldIdentifier`]s addressing a (possibly nested) field.
///
/// Construct one through [`FieldPath::builder`], the [`field`] helper, or by
/// converting a single [`FieldIdentifier`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldPath {
    /// Path segments, stored in the order they were joined.
    field_names: Vec<FieldIdentifier>,
}

impl FieldPath {
    /// Returns an empty [`FieldPathBuilder`] for assembling a path segment by segment.
    pub fn builder() -> FieldPathBuilder {
        FieldPathBuilder::new()
    }
}

/// A single segment of a [`FieldPath`].
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` (both payload types
/// support them) so identifiers can be used as keys in hash-based collections.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum FieldIdentifier {
    /// A segment addressed by name; displayed as `$name`.
    Name(String),
    /// A segment addressed by numeric index; displayed as `[index]`.
    ListIndex(u64),
}

/// Incrementally assembles a [`FieldPath`] from individual [`FieldIdentifier`]s.
///
/// Obtain one via [`FieldPath::builder`] or [`FieldPathBuilder::new`], add
/// segments with [`FieldPathBuilder::join`], and finish with
/// [`FieldPathBuilder::build`].
#[derive(Clone, Debug)]
pub struct FieldPathBuilder {
    /// Segments collected so far, in the order they were joined.
    field_names: Vec<FieldIdentifier>,
}

impl FieldPathBuilder {
    /// Creates a builder that holds no path segments yet.
    #[must_use]
    pub fn new() -> Self {
        Self {
            field_names: Vec::new(),
        }
    }

    /// Appends `identifier` as the next (innermost) segment and returns the builder.
    ///
    /// The builder is taken by value, so the returned builder must be used:
    /// dropping it discards the segment that was just added.
    #[must_use = "`join` returns the extended builder; dropping it discards the added segment"]
    pub fn join<T: Into<FieldIdentifier>>(mut self, identifier: T) -> Self {
        self.field_names.push(identifier.into());
        self
    }

    /// Finishes the builder and produces the [`FieldPath`].
    ///
    /// # Errors
    ///
    /// Returns an error if no segment was ever joined.
    pub fn build(self) -> VortexResult<FieldPath> {
        if self.field_names.is_empty() {
            vortex_bail!("Cannot build empty path");
        }
        Ok(FieldPath {
            field_names: self.field_names,
        })
    }
}

impl Default for FieldPathBuilder {
fn default() -> Self {
Self::new()
}
}

pub fn field(x: impl Into<FieldIdentifier>) -> FieldPath {
x.into().into()
}

impl From<FieldIdentifier> for FieldPath {
fn from(value: FieldIdentifier) -> Self {
FieldPath {
field_names: vec![value],
}
}
}

impl From<&str> for FieldIdentifier {
fn from(value: &str) -> Self {
FieldIdentifier::Name(value.to_string())
}
}

impl From<u64> for FieldIdentifier {
fn from(value: u64) -> Self {
FieldIdentifier::ListIndex(value)
}
}

impl Display for FieldIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
FieldIdentifier::Name(name) => write!(f, "${name}"),
FieldIdentifier::ListIndex(idx) => write!(f, "[{idx}]"),
}
}
}

impl Display for FieldPath {
    /// Renders every segment with its own `Display` form, separated by `.`
    /// (for example `$field.[0]`). An empty path renders as the empty string.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Write straight into the formatter instead of building a `String` per
        // segment, a `Vec` of them, and a joined copy.
        let mut segments = self.field_names.iter();
        if let Some(first) = segments.next() {
            write!(f, "{first}")?;
            for segment in segments {
                write!(f, ".{segment}")?;
            }
        }
        Ok(())
    }
}
3 changes: 3 additions & 0 deletions vortex-dtype/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ pub use extension::*;
pub use half;
pub use nullability::*;
pub use ptype::*;

mod dtype;
mod extension;
pub mod field_paths;
mod nullability;
mod ptype;
mod serde;
Expand All @@ -28,5 +30,6 @@ pub mod flatbuffers {
mod generated {
include!(concat!(env!("OUT_DIR"), "/flatbuffers/dtype.rs"));
}

pub use generated::vortex::dtype::*;
}
39 changes: 24 additions & 15 deletions vortex-expr/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ impl Display for Predicate {
impl Display for Value {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Value::Field(expr) => std::fmt::Display::fmt(expr, f),
Value::Literal(scalar) => scalar.fmt(f),
Value::Field(field_path) => Display::fmt(field_path, f),
Value::Literal(scalar) => Display::fmt(&scalar, f),
}
}
}
Expand All @@ -55,32 +55,40 @@ impl Display for Operator {

#[cfg(test)]
mod tests {
use vortex_dtype::field_paths::{field, FieldPath};

use crate::expressions::{lit, Conjunction, Disjunction};
use crate::field_paths::FieldPathOperations;

#[test]
fn test_predicate_formatting() {
// And
assert_eq!(format!("{}", lit(1u32).lt(lit(2u32))), "(1 < 2)");
// Or
assert_eq!(format!("{}", lit(1u32).gte(lit(2u32))), "(1 >= 2)");
// Not
assert_eq!(format!("{}", !lit(1u32).lte(lit(2u32))), "(1 > 2)");
let f1 = field("field");
assert_eq!(format!("{}", f1.clone().lt(lit(1u32))), "($field < 1)");
assert_eq!(format!("{}", f1.clone().gte(lit(1u32))), "($field >= 1)");
assert_eq!(format!("{}", !f1.clone().lte(lit(1u32))), "($field > 1)");
assert_eq!(format!("{}", !lit(1u32).lte(f1)), "($field <= 1)");

// nested field path
let f2 = FieldPath::builder().join("field").join(0).build().unwrap();
assert_eq!(format!("{}", !f2.lte(lit(1u32))), "($field.[0] > 1)");
}

#[test]
fn test_dnf_formatting() {
let path = FieldPath::builder().join(2).join("col1").build().unwrap();
let d1 = Conjunction {
predicates: vec![
lit(1u32).lt(lit(2u32)),
lit(1u32).gte(lit(2u32)),
!lit(1u32).lte(lit(2u32)),
lit(1u32).lt(path.clone()),
path.clone().gte(lit(1u32)),
!lit(1u32).lte(path),
],
};
let path2 = FieldPath::builder().join("col1").join(2).build().unwrap();
let d2 = Conjunction {
predicates: vec![
lit(2u32).lt(lit(3u32)),
lit(3u32).gte(lit(4u32)),
!lit(5u32).lte(lit(6u32)),
lit(2u32).lt(path2),
lit(3u32).gte(field(2)),
!lit(5u32).lte(field("col2")),
],
};

Expand All @@ -92,7 +100,8 @@ mod tests {
print!("{}", string);
assert_eq!(
string,
"(1 < 2) AND (1 >= 2) AND (1 > 2)\nOR \n(2 < 3) AND (3 >= 4) AND (5 > 6)"
"([2].$col1 >= 1) AND ([2].$col1 >= 1) AND ([2].$col1 <= 1)\nOR \
\n($col1.[2] >= 2) AND ([2] < 3) AND ($col2 <= 5)"
);
}
}
95 changes: 50 additions & 45 deletions vortex-expr/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use vortex_dtype::FieldName;
use vortex_dtype::field_paths::FieldPath;
use vortex_scalar::Scalar;

use crate::expressions::Value::Field;
use crate::operators::Operator;

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize),
Expand All @@ -27,86 +27,91 @@ pub struct Conjunction {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Value {
/// A named reference to a qualified field in a dtype.
Field(FieldName),
Field(FieldPath),
/// A constant scalar value.
Literal(Scalar),
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Predicate {
pub left: FieldPath,
pub op: Operator,
pub right: Value,
}

pub fn lit<T: Into<Scalar>>(n: T) -> Value {
Value::Literal(n.into())
}

impl Value {
pub fn field(field_name: impl Into<FieldName>) -> Value {
Field(field_name.into())
}
// comparisons
pub fn eq(self, other: Value) -> Predicate {
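// NB: We rewrite predicates into Field-op-Value form (the field moves to the left-hand
// side and `self` becomes the right-hand side), so these methods all must use the
// inverse operator.
// NOTE(review): `inverse()` appears to negate rather than swap operands — the display
// tests expect `lit(1).lt(path)` to render as `path >= 1`, whereas `1 < path` is
// `path > 1`. Confirm against `Operator::inverse`.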
pub fn eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
left: field.into(),
op: Operator::EqualTo,
right: other,
right: self,
}
}

pub fn not_eq(self, other: Value) -> Predicate {
pub fn not_eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
op: Operator::NotEqualTo,
right: other,
left: field.into(),
op: Operator::NotEqualTo.inverse(),
right: self,
}
}

pub fn gt(self, other: Value) -> Predicate {
pub fn gt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThan,
right: other,
left: field.into(),
op: Operator::GreaterThan.inverse(),
right: self,
}
}

pub fn gte(self, other: Value) -> Predicate {
pub fn gte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThanOrEqualTo,
right: other,
left: field.into(),
op: Operator::GreaterThanOrEqualTo.inverse(),
right: self,
}
}

pub fn lt(self, other: Value) -> Predicate {
pub fn lt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
op: Operator::LessThan,
right: other,
left: field.into(),
op: Operator::LessThan.inverse(),
right: self,
}
}

pub fn lte(self, other: Value) -> Predicate {
pub fn lte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: self,
op: Operator::LessThanOrEqualTo,
right: other,
left: field.into(),
op: Operator::LessThanOrEqualTo.inverse(),
right: self,
}
}
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Predicate {
pub left: Value,
pub op: Operator,
pub right: Value,
}

pub fn lit<T: Into<Scalar>>(n: T) -> Value {
Value::Literal(n.into())
}

#[cfg(test)]
mod test {
use vortex_dtype::field_paths::field;

use super::*;

#[test]
fn test_lit() {
let scalar: Scalar = 1.into();
let rhs: Value = lit(scalar);
let expr = Value::field("id").eq(rhs);
assert_eq!(format!("{}", expr), "(id = 1)");
let value: Value = lit(scalar);
let field = field("id");
let expr = Predicate {
left: field,
op: Operator::EqualTo,
right: value,
};
assert_eq!(format!("{}", expr), "($id = 1)");
}
}
Loading

0 comments on commit e7e9952

Please sign in to comment.