From 731cabb06f8c1da75a46c3ad22c992d3333941e8 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 19 Dec 2024 17:13:42 +0800 Subject: [PATCH] Implement a naive join order --- Cargo.lock | 1 + src/daft-logical-plan/Cargo.toml | 1 + .../rules/reorder_joins/greedy_join_order.rs | 153 ------------- .../rules/reorder_joins/join_graph.rs | 205 ++++++++++-------- .../optimization/rules/reorder_joins/mod.rs | 4 +- .../rules/reorder_joins/naive_join_order.rs | 181 ++++++++++++++++ 6 files changed, 298 insertions(+), 247 deletions(-) delete mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs create mode 100644 src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs diff --git a/Cargo.lock b/Cargo.lock index 34b37bc81f..fa690309ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2337,6 +2337,7 @@ dependencies = [ "log", "pretty_assertions", "pyo3", + "rand 0.8.5", "rstest", "serde", "snafu", diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 1b4dab023f..dadde8b7f8 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -24,6 +24,7 @@ uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} pretty_assertions = {workspace = true} +rand = "0.8" rstest = {workspace = true} test-log = {workspace = true} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs deleted file mode 100644 index 0f5a12592d..0000000000 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use common_error::DaftResult; -use daft_dsl::{col, ExprRef}; - -use super::join_graph::{JoinCondition, JoinGraph}; -use crate::{LogicalPlanBuilder, LogicalPlanRef}; - -// This is an implementation of the Greedy Operator Ordering algorithm (GOO) [1] for join selection. This algorithm -// selects join edges greedily by picking the edge with the smallest cost at each step. This is similar to Kruskal's -// minimum spanning tree algorithm, with the caveat that edge costs update at each step, due to changing cardinalities -// and selectivities between join nodes. -// -// Compared to DP-based algorithms, GOO is not always optimal. However, GOO has a complexity of O(n^3) and is more viable -// than DP-based algorithms when performing join ordering on many relations. DP Connected subgraph Complement Pairs (DPccp) [2] -// is the DP-based algorithm widely used in database systems today and has a O(3^n) complexity, although the latest -// literature does offer a super-polynomially faster DP-algorithm but that still has a O(2^n) to O(2^n * n^3) complexity [3]. -// -// For this reason, we maintain a greedy-based join ordering algorithm to use when the number of relations is large, and default -// to DP-based algorithms otherwise. -// -// [1]: Fegaras, L. (1998). A New Heuristic for Optimizing Large Queries. International Conference on Database and Expert Systems Applications. -// [2]: Moerkotte, G., & Neumann, T. (2006). Analysis of two existing and one new dynamic programming algorithm for the generation of optimal bushy join trees without cross products. Very Large Data Bases Conference. -// [3]: Stoian, M., & Kipf, A. (2024). DPconv: Super-Polynomially Faster Join Ordering. ArXiv, abs/2409.08013. -pub(crate) struct GreedyJoinOrderer {} - -impl GreedyJoinOrderer { - /// Consumes the join graph and transforms it into a logical plan with joins reordered. - pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult { - // While the join graph consists of more than one join node, select the edge that has the smallest cost, - // then join the left and right nodes connected by this edge. - while join_graph.adj_list.0.len() > 1 { - let selected_pair = GreedyJoinOrderer::find_minimum_cost_join(&join_graph.adj_list.0); - if let Some((left, right, join_conds)) = selected_pair { - // Join the left and right relations using the given join conditions. - let (left_on, right_on) = join_conds - .iter() - .map(|join_cond| { - ( - col(join_cond.left_on.clone()), - col(join_cond.right_on.clone()), - ) - }) - .collect::<(Vec, Vec)>(); - let left_builder = LogicalPlanBuilder::from(left.clone()); - let join = left_builder - .inner_join(right.clone(), left_on, right_on)? - .build(); - let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats()); - - // Add the new node into the adjacency list. - let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap(); - let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap(); - let mut new_join_edges = HashMap::new(); - - // Helper function that takes in neighbors to the left and right nodes, then combines edges that point - // back to the left and/or right nodes into edges that point to the new join node. - let mut update_neighbors = - |neighbors: HashMap>| { - for (neighbor, _) in neighbors { - if neighbor == right || neighbor == left { - // Skip the nodes that we just joined. - continue; - } - let mut join_conditions = Vec::new(); - // If this neighbor was connected to left or right nodes, collect the join conditions. - let neighbor_edges = join_graph - .adj_list - .0 - .get_mut(&neighbor) - .expect("The neighbor should still be in the join graph"); - if let Some(left_conds) = neighbor_edges.remove(&left) { - join_conditions.extend(left_conds); - } - if let Some(right_conds) = neighbor_edges.remove(&right) { - join_conditions.extend(right_conds); - } - // If this neighbor had any connections to left or right, create a new edge to the new join node. - if !join_conditions.is_empty() { - neighbor_edges.insert(join.clone(), join_conditions.clone()); - new_join_edges.insert( - neighbor.clone(), - join_conditions.iter().map(|cond| cond.flip()).collect(), - ); - } - } - }; - - // Process all neighbors from both the left and right sides. - update_neighbors(left_neighbors); - update_neighbors(right_neighbors); - - // Add the new join node and its edges to the graph. - join_graph.adj_list.0.insert(join, new_join_edges); - } else { - panic!( - "No valid join edge selected despite join graph containing more than one relation" - ); - } - } - // Apply projections and filters on top of the fully joined plan. - if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() { - join_graph.apply_projections_and_filters_to_plan(joined_plan) - } else { - panic!("No valid logical plan after join reordering") - } - } - - /// Helper functions that finds the next join edge in the adjacency list that has the smallest cost. - /// Currently cost is determined based on the max size in bytes of the candidate left and right relations. - fn find_minimum_cost_join( - adj_list: &HashMap>>, - ) -> Option<(LogicalPlanRef, LogicalPlanRef, Vec)> { - let mut min_cost = None; - let mut selected_pair = None; - - for (candidate_left, neighbors) in adj_list { - for (candidate_right, join_conds) in neighbors { - let left_stats = candidate_left.materialized_stats(); - let right_stats = candidate_right.materialized_stats(); - - // Assume primary key foreign key join which would have a size bounded by the foreign key relation, - // which is typically larger. - let cur_cost = left_stats - .approx_stats - .upper_bound_bytes - .max(right_stats.approx_stats.upper_bound_bytes); - - if let Some(existing_min) = min_cost { - if let Some(current) = cur_cost { - if current < existing_min { - min_cost = Some(current); - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); - } - } - } else { - min_cost = cur_cost; - selected_pair = Some(( - candidate_left.clone(), - candidate_right.clone(), - join_conds.clone(), - )); - } - } - } - - selected_pair - } -} diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs index e7c09f3174..2aa5234ff0 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs @@ -13,8 +13,42 @@ use crate::{ LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, }; -#[derive(Debug)] -struct JoinNode { +// TODO(desmond): In the future these trees should keep track of current cost estimates. +#[derive(Clone, Debug)] +pub(super) enum JoinOrderTree { + Relation(usize), // (id). + Join(Box, Box, Vec), // (subtree, subtree, nodes involved). +} + +impl JoinOrderTree { + pub(super) fn join(self: Box, right: Box) -> Box { + let mut nodes = self.nodes(); + nodes.append(&mut right.nodes()); + Box::new(JoinOrderTree::Join(self, right, nodes)) + } + + pub(super) fn nodes(&self) -> Vec { + match self { + Self::Relation(id) => vec![*id], + Self::Join(_, _, nodes) => nodes.clone(), + } + } + + // Helper function that checks if the join order tree contains a given id. + pub(super) fn contains(&self, target_id: usize) -> bool { + match self { + Self::Relation(id) => *id == target_id, + Self::Join(left, right, _) => left.contains(target_id) || right.contains(target_id), + } + } +} + +pub(super) trait JoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box; +} + +#[derive(Clone, Debug)] +pub(super) struct JoinNode { relation_name: String, plan: LogicalPlanRef, final_name: String, @@ -27,7 +61,7 @@ struct JoinNode { /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column /// that's being accessed from the relation, and the final name of the column in the output. impl JoinNode { - fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { + pub(super) fn new(relation_name: String, plan: LogicalPlanRef, final_name: String) -> Self { Self { relation_name, plan, @@ -52,31 +86,31 @@ impl Display for JoinNode { } #[derive(Clone, Debug)] -pub(crate) struct JoinCondition { +pub(super) struct JoinCondition { pub left_on: String, pub right_on: String, } -impl JoinCondition { - pub(crate) fn flip(&self) -> Self { - JoinCondition { - left_on: self.right_on.clone(), - right_on: self.left_on.clone(), - } - } +pub(super) struct JoinAdjList { + pub max_id: usize, + plan_to_id: HashMap<*const LogicalPlan, usize>, + id_to_plan: HashMap, + pub edges: HashMap>>, } -pub(crate) struct JoinAdjList( - pub HashMap>>, -); - impl std::fmt::Display for JoinAdjList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "Join Graph Adjacency List:")?; - for (node, neighbors) in &self.0 { - writeln!(f, "Node {}:", node.name())?; - for (neighbor, join_conds) in neighbors { - writeln!(f, " -> {} with conditions:", neighbor.name())?; + for (node_id, neighbors) in &self.edges { + let node = self.id_to_plan.get(node_id).unwrap(); + writeln!(f, "Node {} (id = {node_id}):", node.name())?; + for (neighbor_id, join_conds) in neighbors { + let neighbor = self.id_to_plan.get(neighbor_id).unwrap(); + writeln!( + f, + " -> {} (id = {neighbor_id}) with conditions:", + neighbor.name() + )?; for (i, cond) in join_conds.iter().enumerate() { writeln!(f, " {}: {} = {}", i, cond.left_on, cond.right_on)?; } @@ -87,28 +121,66 @@ impl std::fmt::Display for JoinAdjList { } impl JoinAdjList { + pub(super) fn empty() -> Self { + Self { + max_id: 0, + plan_to_id: HashMap::new(), + id_to_plan: HashMap::new(), + edges: HashMap::new(), + } + } + + pub(super) fn get_plan_id(&mut self, plan: &LogicalPlanRef) -> usize { + let ptr = Arc::as_ptr(plan); + if let Some(id) = self.plan_to_id.get(&ptr) { + *id + } else { + let id = self.max_id; + self.max_id += 1; + self.plan_to_id.insert(ptr, id); + self.id_to_plan.insert(id, plan.clone()); + id + } + } + fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) { // TODO(desmond): We should also keep track of projections that we need to do. let join_condition = JoinCondition { left_on: left.final_name.clone(), right_on: right.final_name.clone(), }; - if let Some(neighbors) = self.0.get_mut(&left.plan) { - if let Some(join_conditions) = neighbors.get_mut(&right.plan) { + let left_id = self.get_plan_id(&left.plan); + let right_id = self.get_plan_id(&right.plan); + if let Some(neighbors) = self.edges.get_mut(&left_id) { + if let Some(join_conditions) = neighbors.get_mut(&right_id) { join_conditions.push(join_condition); } else { - neighbors.insert(right.plan.clone(), vec![join_condition]); + neighbors.insert(right_id, vec![join_condition]); } } else { let mut neighbors = HashMap::new(); - neighbors.insert(right.plan.clone(), vec![join_condition]); - self.0.insert(left.plan.clone(), neighbors); + neighbors.insert(right_id, vec![join_condition]); + self.edges.insert(left_id, neighbors); } } - fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { + + pub(super) fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) { self.add_unidirectional_edge(&node1, &node2); self.add_unidirectional_edge(&node2, &node1); } + + pub(super) fn connected(&self, left_nodes: &Vec, right_nodes: &Vec) -> bool { + for left_node in left_nodes { + if let Some(neighbors) = self.edges.get(left_node) { + for right_node in right_nodes { + if let Some(_) = neighbors.get(right_node) { + return true; + } + } + } + } + return false; + } } #[derive(Debug)] @@ -138,42 +210,10 @@ impl JoinGraph { } } - pub(crate) fn apply_projections_and_filters_to_plan( - &mut self, - plan: LogicalPlanRef, - ) -> DaftResult { - let mut plan = LogicalPlanBuilder::from(plan); - // Apply projections and filters in post-traversal order. - let mut reversed_items = self - .final_projections_and_filters - .drain(..) - .rev() - .peekable(); - while let Some(projection_or_filter) = reversed_items.next() { - let is_last = reversed_items.peek().is_none(); - - match projection_or_filter { - ProjectionOrFilter::Projection(projections) => { - if is_last { - // The final projection is the output projection, so here we select the final projection. - plan = plan.select(projections)?; - } else { - // Intermediate projections might only transform a subset of columns, so we use `with_columns()` instead of `select()`. - plan = plan.with_columns(projections)?; - } - } - ProjectionOrFilter::Filter(predicate) => { - plan = plan.filter(predicate)?; - } - } - } - Ok(plan.build()) - } - /// Test helper function to get the number of edges that the current graph contains. pub(crate) fn num_edges(&self) -> usize { let mut num_edges = 0; - for (_, edges) in &self.adj_list.0 { + for (_, edges) in &self.adj_list.edges { num_edges += edges.len(); } // Each edge is bidirectional, so we divide by 2 to get the correct number of edges. @@ -182,7 +222,7 @@ impl JoinGraph { /// Test helper function to check that all relations in this graph are connected. pub(crate) fn fully_connected(&self) -> bool { - let start = if let Some((node, _)) = self.adj_list.0.iter().next() { + let start = if let Some((node, _)) = self.adj_list.edges.iter().next() { node } else { // There are no nodes. The empty graph is fully connected. @@ -195,7 +235,7 @@ impl JoinGraph { while let Some(current) = stack.pop() { if seen.insert(current) { // If this is a new node, add all its neighbors to the stack. - if let Some(neighbors) = self.adj_list.0.get(current) { + if let Some(neighbors) = self.adj_list.edges.get(current) { stack.extend(neighbors.iter().filter_map(|(neighbor, _)| { if !seen.contains(neighbor) { Some(neighbor) @@ -206,7 +246,7 @@ impl JoinGraph { } } } - seen.len() == self.adj_list.0.len() + seen.len() == self.adj_list.max_id } /// Test helper function that checks if the graph contains the given projection/filter expressions @@ -236,8 +276,10 @@ impl JoinGraph { /// exists in the current graph. pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool { let mut edge_strings = HashSet::new(); - for (left, neighbors) in &self.adj_list.0 { - for (right, join_conds) in neighbors { + for (left_id, neighbors) in &self.adj_list.edges { + for (right_id, join_conds) in neighbors { + let left = self.adj_list.id_to_plan.get(left_id).unwrap(); + let right = self.adj_list.id_to_plan.get(right_id).unwrap(); for join_cond in join_conds { edge_strings.insert(format!( "{}({}) <-> {}({})", @@ -288,7 +330,7 @@ impl JoinGraphBuilder { plan, join_conds_to_resolve: vec![], final_name_map: HashMap::new(), - adj_list: JoinAdjList(HashMap::new()), + adj_list: JoinAdjList::empty(), final_projections_and_filters: vec![ProjectionOrFilter::Projection(output_projection)], } } @@ -487,10 +529,7 @@ mod tests { use super::JoinGraphBuilder; use crate::{ - optimization::rules::{ - reorder_joins::greedy_join_order::GreedyJoinOrderer, EnrichWithStats, MaterializeScans, - OptimizerRule, - }, + optimization::rules::{EnrichWithStats, MaterializeScans, OptimizerRule}, test::{ dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size, }, @@ -553,7 +592,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -565,9 +604,6 @@ mod tests { "Project(c) <-> Source(d)", "Source(a) <-> Source(d)" ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -627,7 +663,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -639,9 +675,6 @@ mod tests { "Project(c) <-> Source(d)", "Source(b) <-> Source(d)", ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -695,7 +728,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a_beta <-> b @@ -705,9 +738,6 @@ mod tests { "Project(a_beta) <-> Source(b)", "Project(a_beta) <-> Source(c)", ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -768,7 +798,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -784,9 +814,6 @@ mod tests { // `c_prime` gets renamed to `c` in the final projection let double_proj = col("c").add(col("c")).alias("double"); assert!(join_graph.contains_projections_and_filters(vec![&double_proj])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -865,7 +892,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -887,9 +914,6 @@ mod tests { &double_proj, &filter_c_prime, ])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } #[test] @@ -963,7 +987,7 @@ mod tests { .unwrap(); let stats_enricher = EnrichWithStats::new(); let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap(); - let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); + let join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build(); assert!(join_graph.fully_connected()); // There should be edges between: // - a <-> b @@ -977,8 +1001,5 @@ mod tests { ])); // Projections below the aggregation should not be part of the final projections. assert!(!join_graph.contains_projections_and_filters(vec![&a_proj])); - // Test greedy join reordering. - let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap(); - assert!(reordered_plan.schema() == original_plan.schema()); } } diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs index 58987555ab..c8644b620e 100644 --- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs @@ -1,4 +1,4 @@ #[cfg(test)] -mod greedy_join_order; -#[cfg(test)] mod join_graph; +#[cfg(test)] +mod naive_join_order; diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs new file mode 100644 index 0000000000..0fba6e4fa3 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/naive_join_order.rs @@ -0,0 +1,181 @@ +use super::join_graph::{JoinGraph, JoinOrderTree, JoinOrderer}; + +pub(crate) struct NaiveJoinOrderer {} + +impl NaiveJoinOrderer { + fn extend_order( + graph: &JoinGraph, + current_order: Box, + mut available: Vec, + ) -> Box { + if available.is_empty() { + return current_order; + } + for (index, candidate_node_id) in available.iter().enumerate() { + let right = Box::new(JoinOrderTree::Relation(*candidate_node_id)); + if graph + .adj_list + .connected(¤t_order.nodes(), &right.nodes()) + { + let new_order = current_order.join(right); + available.remove(index); + return Self::extend_order(graph, new_order, available); + } + } + panic!("There should be at least one naive join order."); + } +} + +impl JoinOrderer for NaiveJoinOrderer { + fn order(&self, graph: &JoinGraph) -> Box { + let available: Vec = (1..graph.adj_list.max_id).collect(); + // Take a starting order of the node with id 0. + let starting_order = Box::new(JoinOrderTree::Relation(0)); + Self::extend_order(graph, starting_order, available) + } +} + +#[cfg(test)] +mod tests { + use common_scan_info::Pushdowns; + use daft_schema::{dtype::DataType, field::Field}; + use rand::{seq::SliceRandom, Rng}; + + use super::{JoinGraph, JoinOrderTree, JoinOrderer, NaiveJoinOrderer}; + use crate::{ + optimization::rules::reorder_joins::join_graph::{JoinAdjList, JoinNode}, + test::{dummy_scan_node_with_pushdowns, dummy_scan_operator_with_size}, + LogicalPlanRef, + }; + + fn assert_order_contains_all_nodes(order: &Box, graph: &JoinGraph) { + for id in 0..graph.adj_list.max_id { + assert!( + order.contains(id), + "Graph id {} not found in order {:?}.\n{}", + id, + order, + graph.adj_list + ); + } + } + + fn create_scan_node(name: &str, size: Option) -> LogicalPlanRef { + dummy_scan_node_with_pushdowns( + dummy_scan_operator_with_size(vec![Field::new(name, DataType::Int64)], size), + Pushdowns::default(), + ) + .build() + } + + fn create_join_graph_with_edges(nodes: Vec, edges: Vec<(usize, usize)>) -> JoinGraph { + let mut adj_list = JoinAdjList::empty(); + for (from, to) in edges { + adj_list.add_bidirectional_edge(nodes[from].clone(), nodes[to].clone()); + } + JoinGraph::new(adj_list, vec![]) + } + + macro_rules! create_and_test_join_graph { + ($nodes:expr, $edges:expr, $orderer:expr) => { + let nodes: Vec = $nodes + .iter() + .map(|name| { + let scan_node = create_scan_node(name, Some(100)); + JoinNode::new(name.to_string(), scan_node, name.to_string()) + }) + .collect(); + let graph = create_join_graph_with_edges(nodes.clone(), $edges); + let order = $orderer.order(&graph); + assert_order_contains_all_nodes(&order, &graph); + }; + } + + #[test] + fn test_order_basic_join_graph() { + let nodes = vec!["a", "b", "c", "d"]; + let edges = vec![ + (0, 2), // node_a <-> node_c + (1, 2), // node_b <-> node_c + (2, 3), // node_c <-> node_d + ]; + create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + } + + pub struct UnionFind { + parent: Vec, + size: Vec, + } + + impl UnionFind { + pub fn create(num_nodes: usize) -> Self { + UnionFind { + parent: (0..num_nodes).collect(), + size: vec![1; num_nodes], + } + } + + pub fn find(&mut self, node: usize) -> usize { + if self.parent[node] != node { + self.parent[node] = self.find(self.parent[node]); + } + self.parent[node] + } + + pub fn union(&mut self, node1: usize, node2: usize) { + let root1 = self.find(node1); + let root2 = self.find(node2); + + if root1 != root2 { + let (small, big) = if self.size[root1] < self.size[root2] { + (root1, root2) + } else { + (root2, root1) + }; + self.parent[small] = big; + self.size[big] += self.size[small]; + } + } + } + + fn create_random_connected_graph(num_nodes: usize) -> Vec<(usize, usize)> { + let mut rng = rand::thread_rng(); + let mut edges = Vec::new(); + let mut uf = UnionFind::create(num_nodes); + + // Get a random order of all possible edges. + let mut all_edges: Vec<(usize, usize)> = (0..num_nodes) + .flat_map(|i| (0..i).chain(i + 1..num_nodes).map(move |j| (i, j))) + .collect(); + all_edges.shuffle(&mut rng); + + // Select edges to form a minimum spanning tree + a random number of extra edges. + for (a, b) in all_edges { + if uf.find(a) != uf.find(b) { + uf.union(a, b); + edges.push((a, b)); + } + // Check if we have a minimum spanning tree. + if edges.len() >= num_nodes - 1 { + // Once we have a minimum spanning tree, we let a random number of extra edges be added to the graph. + if rng.gen_bool(0.3) { + break; + } + edges.push((a, b)); + } + } + + edges + } + + const NUM_RANDOM_NODES: usize = 100; + + #[test] + fn test_order_random_join_graph() { + let nodes: Vec = (0..NUM_RANDOM_NODES) + .map(|i| format!("node_{}", i)) + .collect(); + let edges = create_random_connected_graph(NUM_RANDOM_NODES); + create_and_test_join_graph!(nodes, edges, NaiveJoinOrderer {}); + } +}