Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor]: Convert Vec<PhysicalExpr> to HashSet<PhysicalExpr> #13612

Merged
merged 5 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,24 @@ pub fn with_new_children_if_necessary(
/// Returns a [`Display`]-able wrapper around a list of [`PhysicalExpr`]s
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list(exprs: &[Arc<dyn PhysicalExpr>]) -> impl Display + '_ {
struct DisplayWrapper<'a>(&'a [Arc<dyn PhysicalExpr>]);
impl Display for DisplayWrapper<'_> {
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
T: IntoIterator,
T::Item: Display,
T::IntoIter: Clone,
{
struct DisplayWrapper<I>(I)
where
I: Iterator + Clone,
I::Item: Display;

impl<I> Display for DisplayWrapper<I>
where
I: Iterator + Clone,
I::Item: Display,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut iter = self.0.iter();
let mut iter = self.0.clone();
write!(f, "[")?;
if let Some(expr) = iter.next() {
write!(f, "{}", expr)?;
Expand All @@ -233,5 +246,6 @@ pub fn format_physical_expr_list(exprs: &[Arc<dyn PhysicalExpr>]) -> impl Displa
Ok(())
}
}
DisplayWrapper(exprs)

DisplayWrapper(exprs.into_iter())
}
35 changes: 17 additions & 18 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::fmt::Display;
use std::sync::Arc;

use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping};
use crate::{
expressions::Column, physical_expr::deduplicate_physical_exprs,
physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexRequirement,
expressions::Column, physical_exprs_contains, LexOrdering, LexRequirement,
PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement,
};
use indexmap::IndexSet;
use std::fmt::Display;
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
Expand Down Expand Up @@ -190,47 +189,47 @@ pub struct EquivalenceClass {
/// The expressions in this equivalence class. The order doesn't
/// matter for equivalence purposes
///
/// TODO: use a HashSet for this instead of a Vec
exprs: Vec<Arc<dyn PhysicalExpr>>,
exprs: IndexSet<Arc<dyn PhysicalExpr>>,
}

impl PartialEq for EquivalenceClass {
/// Returns true if other is equal in the sense
/// of bags (multi-sets), disregarding their orderings.
fn eq(&self, other: &Self) -> bool {
physical_exprs_bag_equal(&self.exprs, &other.exprs)
self.exprs.eq(&other.exprs)
}
}

impl EquivalenceClass {
/// Create a new empty equivalence class
pub fn new_empty() -> Self {
Self { exprs: vec![] }
Self {
exprs: IndexSet::new(),
}
}

// Create a new equivalence class from a pre-existing `Vec`
pub fn new(mut exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
deduplicate_physical_exprs(&mut exprs);
Self { exprs }
pub fn new(exprs: Vec<Arc<dyn PhysicalExpr>>) -> Self {
Self {
exprs: exprs.into_iter().collect(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

}
}

/// Return the inner vector of expressions
pub fn into_vec(self) -> Vec<Arc<dyn PhysicalExpr>> {
self.exprs
self.exprs.into_iter().collect()
}

/// Return the "canonical" expression for this class (the first element)
/// if any
fn canonical_expr(&self) -> Option<Arc<dyn PhysicalExpr>> {
self.exprs.first().cloned()
self.exprs.iter().next().cloned()
}

/// Insert the expression into this class, meaning it is known to be equal to
/// all other expressions in this class
pub fn push(&mut self, expr: Arc<dyn PhysicalExpr>) {
if !self.contains(&expr) {
self.exprs.push(expr);
}
self.exprs.insert(expr);
}

/// Inserts all the expressions from other into this class
Expand All @@ -243,7 +242,7 @@ impl EquivalenceClass {

/// Returns true if this equivalence class contains the expression
pub fn contains(&self, expr: &Arc<dyn PhysicalExpr>) -> bool {
physical_exprs_contains(&self.exprs, expr)
self.exprs.contains(expr)
}

/// Returns true if this equivalence class has any entries in common with `other`
Expand Down
61 changes: 2 additions & 59 deletions datafusion/physical-expr/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,34 +65,14 @@ pub fn physical_exprs_bag_equal(
}
}

/// Removes duplicate expressions from `exprs` in place, keeping one
/// representative of every group of equal expressions.
///
/// Duplicates are dropped with `Vec::swap_remove`, which moves the last
/// element into the vacated slot, so the relative order of the surviving
/// elements may differ from the input order.
pub fn deduplicate_physical_exprs(exprs: &mut Vec<Arc<dyn PhysicalExpr>>) {
    // TODO: Once we can use `HashSet`s with `Arc<dyn PhysicalExpr>`, this
    //       function should use a `HashSet` to reduce computational complexity.
    //       See issue: https://github.com/apache/datafusion/issues/8027
    let mut anchor = 0;
    while anchor < exprs.len() {
        // Compare every element after `anchor` against it.
        let mut probe = anchor + 1;
        while let Some(candidate) = exprs.get(probe) {
            let is_duplicate = exprs[anchor].eq(candidate);
            if is_duplicate {
                // The former last element now sits at `probe`; do not advance
                // so that it gets compared against the anchor as well.
                exprs.swap_remove(probe);
            } else {
                probe += 1;
            }
        }
        anchor += 1;
    }
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::expressions::{Column, Literal};
use crate::physical_expr::{
deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains,
physical_exprs_equal, PhysicalExpr,
physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal,
PhysicalExpr,
};

use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -208,41 +188,4 @@ mod tests {
assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice()));
assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice()));
}

/// Checks that `deduplicate_physical_exprs` leaves exactly one copy of each
/// distinct expression, in the order produced by its swap-remove strategy.
#[test]
fn test_deduplicate_physical_exprs() {
    // Small builders for the fixture expressions used below.
    let literal = |value: ScalarValue| -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(value))
    };
    let column = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };

    let lit_true = literal(ScalarValue::Boolean(Some(true)));
    let lit_false = literal(ScalarValue::Boolean(Some(false)));
    let lit4 = literal(ScalarValue::Int32(Some(4)));
    let lit2 = literal(ScalarValue::Int32(Some(2)));
    let col_a_expr = column("a", 0);
    let col_b_expr = column("b", 1);

    // Each entry pairs the input list with the list expected after
    // deduplication.
    let test_cases = vec![
        // ---------- TEST CASE 1----------//
        (
            vec![
                &lit_true,
                &lit_false,
                &lit4,
                &lit2,
                &col_a_expr,
                &col_a_expr,
                &col_b_expr,
                &lit_true,
                &lit2,
            ],
            vec![
                &lit_true,
                &lit_false,
                &lit4,
                &lit2,
                &col_a_expr,
                &col_b_expr,
            ],
        ),
        // ---------- TEST CASE 2----------//
        (
            vec![&lit_true, &lit_true, &lit_false, &lit4],
            vec![&lit_true, &lit4, &lit_false],
        ),
    ];

    for (input, expected) in test_cases {
        let mut actual: Vec<Arc<dyn PhysicalExpr>> =
            input.into_iter().map(Arc::clone).collect();
        let expected: Vec<Arc<dyn PhysicalExpr>> =
            expected.into_iter().map(Arc::clone).collect();
        deduplicate_physical_exprs(&mut actual);
        assert!(physical_exprs_equal(&actual, &expected));
    }
}
}