diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 3affeecc4c..0a7d2de10a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -582,6 +582,12 @@ jobs: run: | uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall rm -rf daft + - name: Install ODBC Driver 18 for SQL Server + run: | + curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add - + sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod + sudo apt-get update + sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 - name: Spin up services run: | pushd ./tests/integration/sql/docker-compose/ diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index cd96a9a480..8dbb33111e 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1659,8 +1659,8 @@ def repartition(self, num: Optional[int], *partition_by: ColumnInputType) -> "Da def into_partitions(self, num: int) -> "DataFrame": """Splits or coalesces DataFrame to ``num`` partitions. Order is preserved. - No rebalancing is done; the minimum number of splits or merges are applied. - (i.e. if there are 2 partitions, and change it into 3, this function will just split the bigger one) + This will naively greedily split partitions in a round-robin fashion to hit the targeted number of partitions. + The number of rows/size in a given partition is not taken into account during the splitting. Example: >>> import daft diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index a5c7b94b29..f1357c2fa7 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -17,6 +17,7 @@ ScanTask, StorageConfig, ) +from daft.io.aws_config import boto3_client_from_s3_config from daft.io.object_store_options import io_config_to_storage_options from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema @@ -43,6 +44,24 @@ def __init__( deltalake_sdk_io_config = storage_config.config.io_config scheme = urlparse(table_uri).scheme if scheme == "s3" or scheme == "s3a": + # Try to get region from boto3 + if deltalake_sdk_io_config.s3.region_name is None: + from botocore.exceptions import BotoCoreError + + try: + client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3) + response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc) + except BotoCoreError as e: + logger.warning( + "Failed to get the S3 bucket region using existing storage config, will attempt to get it from the environment instead. 
Error from boto3: %s", + e, + ) + else: + deltalake_sdk_io_config = deltalake_sdk_io_config.replace( + s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"]) + ) + + # Try to get config from the environment if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]): try: s3_config_from_env = S3Config.from_env() diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index daa9afa289..ea490073ea 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -1044,3 +1044,33 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ) return results + + +@dataclass(frozen=True) +class FanoutEvenSlices(FanoutInstruction): + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + [input] = inputs + results = [] + + input_length = len(input) + num_outputs = self.num_outputs() + + chunk_size, remainder = divmod(input_length, num_outputs) + ptr = 0 + for output_idx in range(self.num_outputs()): + end = ptr + chunk_size + 1 if output_idx < remainder else ptr + chunk_size + results.append(input.slice(ptr, end)) + ptr = end + assert ptr == input_length + + return results + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + # TODO: Derive this based on the ratios of num rows + return [ + PartialPartitionMetadata( + num_rows=None, + size_bytes=None, + ) + for _ in range(self._num_outputs) + ] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 2b65a35f21..e3a8501ce5 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -250,10 +250,6 @@ def actor_pool_project( actor_pool_name = f"{stateful_udf_names}-stage={stage_id}" # Keep track of materializations of the children tasks - # - # Our goal here is to saturate the actors, and so we need a sufficient number of completed child tasks to do so. However - # we do not want too many child tasks to be running (potentially starving our actors) and hence place an upper bound of `num_actors * 2` - child_materializations_buffer_len = num_actors * 2 child_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() # Keep track of materializations of the actor_pool tasks @@ -313,8 +309,8 @@ def actor_pool_project( if len(child_materializations) > 0 or len(actor_pool_materializations) > 0: yield None - # If there is capacity in the pipeline, attempt to schedule child work - elif len(child_materializations) < child_materializations_buffer_len: + # Attempt to schedule child work + else: try: child_step = next(child_plan) except StopIteration: @@ -326,10 +322,6 @@ def actor_pool_project( child_materializations.append(child_step) yield child_step - # Otherwise, indicate that we need to wait for work to complete - else: - yield None - def monotonically_increasing_id( child_plan: InProgressPhysicalPlan[PartitionT], column_name: str @@ -1351,61 +1343,30 @@ def split( num_input_partitions: int, num_output_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: - """Repartition the child_plan into more partitions by splitting partitions only. Preserves order.""" + """Repartition the child_plan into more partitions by splitting partitions only. Preserves order. + This performs a naive split, which might lead to data skews but does not require a full materialization of + input partitions when performing the split. 
+ """ assert ( num_output_partitions >= num_input_partitions ), f"Cannot split from {num_input_partitions} to {num_output_partitions}." - # Materialize the input partitions so we can see the number of rows and try to split evenly. - # Splitting evenly is fairly important if this operation is to be used for parallelism. - # (optimization TODO: don't materialize if num_rows is already available in physical plan metadata.) - materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() - stage_id = next(stage_id_counter) + base_splits_per_partition, num_partitions_with_extra_output = divmod(num_output_partitions, num_input_partitions) + + input_partition_idx = 0 for step in child_plan: if isinstance(step, PartitionTaskBuilder): - step = step.finalize_partition_task_single_output(stage_id=stage_id) - materializations.append(step) - yield step - - while any(not _.done() for _ in materializations): - logger.debug("split_to blocked on completion of all sources: %s", materializations) - yield None - - splits_per_partition = deque([1 for _ in materializations]) - num_splits_to_apply = num_output_partitions - num_input_partitions - - # Split by rows for now. - # In the future, maybe parameterize to allow alternatively splitting by size. - rows_by_partitions = [task.partition_metadata().num_rows for task in materializations] - - # Calculate how to spread the required splits across all the partitions. - # Iteratively apply a split and update how many rows would be in the resulting partitions. - # After this loop, splits_per_partition has the final number of splits to apply to each partition. - rows_after_splitting = [float(_) for _ in rows_by_partitions] - for _ in range(num_splits_to_apply): - _, split_at = max((rows, index) for (index, rows) in enumerate(rows_after_splitting)) - splits_per_partition[split_at] += 1 - rows_after_splitting[split_at] = float(rows_by_partitions[split_at] / splits_per_partition[split_at]) - - # Emit the split partitions. 
- for task, num_out, num_rows in zip(consume_deque(materializations), splits_per_partition, rows_by_partitions): - if num_out == 1: - yield PartitionTaskBuilder[PartitionT]( - inputs=[task.partition()], - partial_metadatas=[task.partition_metadata()], - resource_request=ResourceRequest(memory_bytes=task.partition_metadata().size_bytes), + num_out = ( + base_splits_per_partition + 1 + if input_partition_idx < num_partitions_with_extra_output + else base_splits_per_partition ) + step = step.add_instruction(instruction=execution_step.FanoutEvenSlices(_num_outputs=num_out)) + input_partition_idx += 1 + yield step else: - boundaries = [math.ceil(num_rows * i / num_out) for i in range(num_out + 1)] - starts, ends = boundaries[:-1], boundaries[1:] - yield PartitionTaskBuilder[PartitionT]( - inputs=[task.partition()], - partial_metadatas=[task.partition_metadata()], - resource_request=ResourceRequest(memory_bytes=task.partition_metadata().size_bytes), - ).add_instruction( - instruction=execution_step.FanoutSlices(_num_outputs=num_out, slices=list(zip(starts, ends))) - ) + yield step def coalesce( diff --git a/daft/io/__init__.py b/daft/io/__init__.py index 7d36701886..09647f21f2 100644 --- a/daft/io/__init__.py +++ b/daft/io/__init__.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys - from daft.daft import ( AzureConfig, GCSConfig, @@ -21,21 +19,6 @@ from daft.io.catalog import DataCatalogTable, DataCatalogType from daft.io.file_path import from_glob_path - -def _set_linux_cert_paths(): - import os - import ssl - - paths = ssl.get_default_verify_paths() - if paths.cafile: - os.environ[paths.openssl_cafile_env] = paths.openssl_cafile - if paths.capath: - os.environ[paths.openssl_capath_env] = paths.openssl_capath - - -if sys.platform == "linux": - _set_linux_cert_paths() - __all__ = [ "read_csv", "read_json", diff --git a/daft/io/aws_config.py b/daft/io/aws_config.py new file mode 100644 index 0000000000..7f0e9e3dff --- /dev/null +++ b/daft/io/aws_config.py @@ -0,0 +1,21 @@ +from typing import TYPE_CHECKING + +from daft.daft import S3Config + +if TYPE_CHECKING: + import boto3 + + +def boto3_client_from_s3_config(service: str, s3_config: S3Config) -> "boto3.client": + import boto3 + + return boto3.client( + service, + region_name=s3_config.region_name, + use_ssl=s3_config.use_ssl, + verify=s3_config.verify_ssl, + endpoint_url=s3_config.endpoint_url, + aws_access_key_id=s3_config.key_id, + aws_secret_access_key=s3_config.access_key, + aws_session_token=s3_config.session_token, + ) diff --git a/daft/io/catalog.py b/daft/io/catalog.py index 1183caa8ab..62cb16e672 100644 --- a/daft/io/catalog.py +++ b/daft/io/catalog.py @@ -5,6 +5,7 @@ from typing import Optional from daft.daft import IOConfig +from daft.io.aws_config import boto3_client_from_s3_config class DataCatalogType(Enum): @@ -42,20 +43,8 @@ def table_uri(self, io_config: IOConfig) -> str: """ if self.catalog == DataCatalogType.GLUE: # Use boto3 to get the table from AWS Glue Data Catalog. 
- import boto3 + glue = boto3_client_from_s3_config("glue", io_config.s3) - s3_config = io_config.s3 - - glue = boto3.client( - "glue", - region_name=s3_config.region_name, - use_ssl=s3_config.use_ssl, - verify=s3_config.verify_ssl, - endpoint_url=s3_config.endpoint_url, - aws_access_key_id=s3_config.key_id, - aws_secret_access_key=s3_config.access_key, - aws_session_token=s3_config.session_token, - ) if self.catalog_id is not None: # Allow cross account access, table.catalog_id should be the target account id glue_table = glue.get_table( diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 4f0f9a35c7..4d3156ae80 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -161,14 +161,18 @@ def _get_num_rows(self) -> int: def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: try: - # Try to get percentiles using percentile_cont + # Try to get percentiles using percentile_disc. + # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons. percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)] + # Use the OVER clause for SQL Server + over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else "" percentile_sql = self.conn.construct_sql_query( self.sql, projection=[ - f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}" for i, percentile in enumerate(percentiles) ], + limit=1, ) pa_table = self.conn.execute_sql_query(percentile_sql) return pa_table, PartitionBoundStrategy.PERCENTILE diff --git a/requirements-dev.txt b/requirements-dev.txt index 9c7809ac80..3ab91623eb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -66,6 +66,7 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8' PyMySQL==1.1.0; python_version >= '3.8' psycopg2-binary==2.9.9; python_version >= '3.8' sqlglot==23.3.0; python_version >= '3.8' +pyodbc==5.1.0; python_version >= '3.8' # AWS s3fs==2023.12.0; python_version >= '3.8' diff --git a/src/arrow2/src/array/growable/primitive.rs b/src/arrow2/src/array/growable/primitive.rs index e443756cb9..4083cb49db 100644 --- a/src/arrow2/src/array/growable/primitive.rs +++ b/src/arrow2/src/array/growable/primitive.rs @@ -1,10 +1,7 @@ use std::sync::Arc; use crate::{ - array::{Array, PrimitiveArray}, - bitmap::MutableBitmap, - datatypes::DataType, - types::NativeType, + array::{Array, PrimitiveArray}, bitmap::MutableBitmap, datatypes::DataType, types::NativeType }; use super::{ diff --git a/src/daft-core/src/prelude.rs b/src/daft-core/src/prelude.rs index 3b71045ddd..6f6ecaf5a5 100644 --- a/src/daft-core/src/prelude.rs +++ b/src/daft-core/src/prelude.rs @@ -2,6 +2,8 @@ //! //! This module re-exports commonly used items from the Daft core library. 
+// Re-export arrow2 bitmap +pub use arrow2::bitmap; // Re-export core series structures pub use daft_schema::schema::{Schema, SchemaRef}; diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 873f9013bd..567a2d35d8 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -990,21 +990,9 @@ impl Expr { to_sql_inner(inner, buffer)?; write!(buffer, ") IS NOT NULL") } - Expr::IfElse { - if_true, - if_false, - predicate, - } => { - write!(buffer, "CASE WHEN ")?; - to_sql_inner(predicate, buffer)?; - write!(buffer, " THEN ")?; - to_sql_inner(if_true, buffer)?; - write!(buffer, " ELSE ")?; - to_sql_inner(if_false, buffer)?; - write!(buffer, " END") - } // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) + Expr::IfElse { .. } + | Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Between(..) diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 525e308ebe..bdcebecab6 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; @@ -43,14 +44,18 @@ impl IntermediateOperatorState for AntiSemiProbeState { pub struct AntiSemiProbeOperator { probe_on: Vec, - join_type: JoinType, + is_semi: bool, + output_schema: SchemaRef, } impl AntiSemiProbeOperator { - pub fn new(probe_on: Vec, join_type: JoinType) -> Self { + const DEFAULT_GROWABLE_SIZE: usize = 20; + + pub fn new(probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef) -> Self { Self { probe_on, - join_type, + is_semi: *join_type == JoinType::Semi, + output_schema: output_schema.clone(), } } @@ -65,8 +70,11 @@ impl AntiSemiProbeOperator { let input_tables = input.get_tables()?; - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; drop(_growables); { @@ -76,7 +84,7 @@ impl AntiSemiProbeOperator { let iter = probe_set.probe_exists(&join_keys)?; for (probe_row_idx, matched) in iter.enumerate() { - match (self.join_type == JoinType::Semi, matched) { + match (self.is_semi, matched) { (true, true) | (false, false) => { probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -109,15 +117,16 @@ impl IntermediateOperator for AntiSemiProbeOperator { .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); if idx == 0 { - let (probe_table, _) = input.as_probe_table(); - state.set_table(probe_table); + let probe_state = input.as_probe_state(); + state.set_table(probe_state.get_probeable()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } else { let input = input.as_data(); - let out = match self.join_type { - JoinType::Semi | JoinType::Anti => self.probe_anti_semi(input, state), - _ => unreachable!("Only Semi and Anti joins are supported"), - }?; + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } + let out = self.probe_anti_semi(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } } diff --git 
a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs deleted file mode 100644 index dd53b9eac4..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ /dev/null @@ -1,268 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use daft_core::prelude::SchemaRef; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use daft_plan::JoinType; -use daft_table::{GrowableTable, Probeable, Table}; -use indexmap::IndexSet; -use tracing::{info_span, instrument}; - -use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, -}; -use crate::pipeline::PipelineResultType; - -enum HashJoinProbeState { - Building, - ReadyToProbe(Arc, Arc>), -} - -impl HashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(table.clone(), tables.clone()); - } else { - panic!("HashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { - if let Self::ReadyToProbe(probe_table, tables) = self { - (probe_table, tables) - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") - } - } -} - -impl IntermediateOperatorState for HashJoinProbeState { - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -pub struct HashJoinProbeOperator { - probe_on: Vec, - common_join_keys: Vec, - left_non_join_columns: Vec, - right_non_join_columns: Vec, - join_type: JoinType, - build_on_left: bool, -} - -impl HashJoinProbeOperator { - pub fn new( - probe_on: Vec, - left_schema: &SchemaRef, - right_schema: &SchemaRef, - join_type: JoinType, - build_on_left: bool, - common_join_keys: IndexSet, - ) -> Self { - let (common_join_keys, left_non_join_columns, right_non_join_columns) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right => { - let left_non_join_columns = left_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - let right_non_join_columns = right_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - ( - common_join_keys.into_iter().collect(), - left_non_join_columns, - right_non_join_columns, - ) - } - _ => { - panic!("Semi, Anti, and join are not supported in HashJoinProbeOperator") - } - }; - Self { - probe_on, - common_join_keys, - left_non_join_columns, - right_non_join_columns, - join_type, - build_on_left, - } - } - - fn probe_inner( - &self, - input: &Arc, - state: &HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), false, 20)?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - // we should emit one table at a time when this is streaming - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let 
Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - // we can perform run length compression for this to make this more efficient - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let (left_table, right_table) = if self.build_on_left { - (build_side_table, probe_side_table) - } else { - (probe_side_table, build_side_table) - }; - - let join_keys_table = left_table.get_columns(&self.common_join_keys)?; - let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; - let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; - let final_table = join_keys_table - .union(&left_non_join_columns)? - .union(&right_non_join_columns)?; - - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } - - fn probe_left_right( - &self, - input: &Arc, - state: &HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = GrowableTable::new( - &tables.iter().collect::>(), - true, - tables.iter().map(daft_table::Table::len).sum(), - )?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } else { - // if there's no match, we should still emit the probe side and fill the build side with nulls - build_side_growable.add_nulls(1); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let final_table = if self.join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns(&self.left_non_join_columns)?; - let right = build_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - } else { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? 
- }; - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } -} - -impl IntermediateOperator for HashJoinProbeOperator { - #[instrument(skip_all, name = "HashJoinOperator::execute")] - fn execute( - &self, - idx: usize, - input: &PipelineResultType, - state: Option<&mut Box>, - ) -> DaftResult { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - - if idx == 0 { - let (probe_table, tables) = input.as_probe_table(); - state.set_table(probe_table, tables); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } else { - let input = input.as_data(); - let out = match self.join_type { - JoinType::Inner => self.probe_inner(input, state), - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - _ => { - unimplemented!( - "Only Inner, Left, and Right joins are supported in HashJoinProbeOperator" - ) - } - }?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } - } - - fn name(&self) -> &'static str { - "HashJoinProbeOperator" - } - - fn make_state(&self) -> Option> { - Some(Box::new(HashJoinProbeState::Building)) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs new file mode 100644 index 0000000000..a208efea6c --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -0,0 +1,199 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::prelude::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::{GrowableTable, ProbeState}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +enum InnerHashJoinProbeState { + Building, + ReadyToProbe(Arc), +} + +impl InnerHashJoinProbeState { + fn set_probe_state(&mut self, probe_state: Arc) { + if matches!(self, Self::Building) { + *self = Self::ReadyToProbe(probe_state); + } else { + panic!("InnerHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probe_state(&self) -> &Arc { + if let Self::ReadyToProbe(probe_state) = self { + probe_state + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } +} + +impl IntermediateOperatorState for InnerHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub struct InnerHashJoinProbeOperator { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + build_on_left: bool, + output_schema: SchemaRef, +} + +impl InnerHashJoinProbeOperator { + const DEFAULT_GROWABLE_SIZE: usize = 20; + + pub fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + build_on_left: bool, + common_join_keys: IndexSet, + output_schema: &SchemaRef, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let right_non_join_columns = right_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + 
right_non_join_columns, + build_on_left, + output_schema: output_schema.clone(), + } + } + + fn probe_inner( + &self, + input: &Arc, + state: &InnerHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; + + let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); + + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; + + let input_tables = input.get_tables()?; + + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; + + drop(_growables); + { + let _loop = info_span!("InnerHashJoinOperator::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + // we should emit one table at a time when this is streaming + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + // we can perform run length compression for this to make this more efficient + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let (left_table, right_table) = if self.build_on_left { + (build_side_table, probe_side_table) + } else { + (probe_side_table, build_side_table) + }; + + let join_keys_table = left_table.get_columns(&self.common_join_keys)?; + let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; + let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; + let final_table = join_keys_table + .union(&left_non_join_columns)? 
+ .union(&right_non_join_columns)?; + + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } +} + +impl IntermediateOperator for InnerHashJoinProbeOperator { + #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: Option<&mut Box>, + ) -> DaftResult { + let state = state + .expect("InnerHashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("InnerHashJoinProbeOperator state should be InnerHashJoinProbeState"); + match idx { + 0 => { + let probe_state = input.as_probe_state(); + state.set_probe_state(probe_state.clone()); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } + _ => { + let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } + let out = self.probe_inner(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "InnerHashJoinProbeOperator" + } + + fn make_state(&self) -> Option> { + Some(Box::new(InnerHashJoinProbeState::Building)) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index e9523abea3..7d97464e24 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -3,7 +3,9 @@ pub mod anti_semi_hash_join_probe; pub mod buffer; pub mod explode; pub mod filter; -pub mod hash_join_probe; +pub mod inner_hash_join_probe; pub mod intermediate_op; +pub mod pivot; pub mod project; pub mod sample; +pub mod unpivot; diff --git a/src/daft-local-execution/src/intermediate_ops/pivot.rs b/src/daft-local-execution/src/intermediate_ops/pivot.rs new file mode 100644 index 0000000000..d942053dd9 --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/pivot.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_dsl::ExprRef; +use tracing::instrument; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +pub struct PivotOperator { + group_by: Vec, + pivot_col: ExprRef, + values_col: ExprRef, + names: Vec, +} + +impl PivotOperator { + pub fn new( + group_by: Vec, + pivot_col: ExprRef, + values_col: ExprRef, + names: Vec, + ) -> Self { + Self { + group_by, + pivot_col, + values_col, + names, + } + } +} + +impl IntermediateOperator for PivotOperator { + #[instrument(skip_all, name = "PivotOperator::execute")] + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let out = input.as_data().pivot( + &self.group_by, + self.pivot_col.clone(), + self.values_col.clone(), + self.names.clone(), + )?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) + } + + fn name(&self) -> &'static str { + "PivotOperator" + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/unpivot.rs b/src/daft-local-execution/src/intermediate_ops/unpivot.rs new file mode 100644 index 0000000000..746d0563c8 --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/unpivot.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_dsl::ExprRef; +use tracing::instrument; + +use super::intermediate_op::{ + 
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +pub struct UnpivotOperator { + ids: Vec, + values: Vec, + variable_name: String, + value_name: String, +} + +impl UnpivotOperator { + pub fn new( + ids: Vec, + values: Vec, + variable_name: String, + value_name: String, + ) -> Self { + Self { + ids, + values, + variable_name, + value_name, + } + } +} + +impl IntermediateOperator for UnpivotOperator { + #[instrument(skip_all, name = "UnpivotOperator::execute")] + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let out = input.as_data().unpivot( + &self.ids, + &self.values, + &self.variable_name, + &self.value_name, + )?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) + } + + fn name(&self) -> &'static str { + "UnpivotOperator" + } +} diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 591a7859ac..eccece1a56 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,11 +10,11 @@ use daft_core::{ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ - EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, - Project, Sample, Sort, UnGroupedAggregate, + Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, + LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; use daft_plan::{populate_aggregation_stages, JoinType}; -use daft_table::{Probeable, Table}; +use daft_table::ProbeState; use indexmap::IndexSet; use snafu::ResultExt; @@ -22,12 +22,15 @@ use crate::{ channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - explode::ExplodeOperator, filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, - intermediate_op::IntermediateNode, project::ProjectOperator, sample::SampleOperator, + explode::ExplodeOperator, filter::FilterOperator, + inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, + pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator, + unpivot::UnpivotOperator, }, sinks::{ - aggregate::AggregateSink, blocking_sink::BlockingSinkNode, - hash_join_build::HashJoinBuildSink, limit::LimitSink, sort::SortSink, + aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, + hash_join_build::HashJoinBuildSink, limit::LimitSink, + outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, streaming_sink::StreamingSinkNode, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, @@ -37,7 +40,7 @@ use crate::{ #[derive(Clone)] pub enum PipelineResultType { Data(Arc), - ProbeTable(Arc, Arc>), + ProbeState(Arc), } impl From> for PipelineResultType { @@ -46,9 +49,9 @@ impl From> for PipelineResultType { } } -impl From<(Arc, Arc>)> for PipelineResultType { - fn from((probe_table, tables): (Arc, Arc>)) -> Self { - Self::ProbeTable(probe_table, tables) +impl From> for PipelineResultType { + fn from(probe_state: Arc) -> Self { + Self::ProbeState(probe_state) } } @@ -60,15 +63,15 @@ impl PipelineResultType { } } - pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + pub fn as_probe_state(&self) -> &Arc { match self { - Self::ProbeTable(probe_table, tables) => (probe_table, tables), + Self::ProbeState(probe_state) => 
probe_state, _ => panic!("Expected probe table"), } } pub fn should_broadcast(&self) -> bool { - matches!(self, Self::ProbeTable(_, _)) + matches!(self, Self::ProbeState(_)) } } @@ -155,14 +158,13 @@ pub fn physical_plan_to_pipeline( }) => { let sink = LimitSink::new(*num_rows as usize); let child_node = physical_plan_to_pipeline(input, psets)?; - StreamingSinkNode::new(sink.boxed(), vec![child_node]).boxed() + StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } - LocalPhysicalPlan::Concat(_) => { - todo!("concat") - // let sink = ConcatSink::new(); - // let left_child = physical_plan_to_pipeline(input, psets)?; - // let right_child = physical_plan_to_pipeline(other, psets)?; - // PipelineNode::double_sink(sink, left_child, right_child) + LocalPhysicalPlan::Concat(Concat { input, other, .. }) => { + let left_child = physical_plan_to_pipeline(input, psets)?; + let right_child = physical_plan_to_pipeline(other, psets)?; + let sink = ConcatSink {}; + StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed() } LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate { input, @@ -241,6 +243,40 @@ pub fn physical_plan_to_pipeline( IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() } + LocalPhysicalPlan::Unpivot(Unpivot { + input, + ids, + values, + variable_name, + value_name, + .. + }) => { + let child_node = physical_plan_to_pipeline(input, psets)?; + let unpivot_op = UnpivotOperator::new( + ids.clone(), + values.clone(), + variable_name.clone(), + value_name.clone(), + ); + IntermediateNode::new(Arc::new(unpivot_op), vec![child_node]).boxed() + } + LocalPhysicalPlan::Pivot(Pivot { + input, + group_by, + pivot_column, + value_column, + names, + .. + }) => { + let pivot_op = PivotOperator::new( + group_by.clone(), + pivot_column.clone(), + value_column.clone(), + names.clone(), + ); + let child_node = physical_plan_to_pipeline(input, psets)?; + IntermediateNode::new(Arc::new(pivot_op), vec![child_node]).boxed() + } LocalPhysicalPlan::Sort(Sort { input, sort_by, @@ -258,7 +294,7 @@ pub fn physical_plan_to_pipeline( left_on, right_on, join_type, - .. 
+ schema, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -269,11 +305,9 @@ pub fn physical_plan_to_pipeline( let build_on_left = match join_type { JoinType::Inner => true, JoinType::Right => true, + JoinType::Outer => true, JoinType::Left => false, JoinType::Anti | JoinType::Semi => false, - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); - } }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { true => (left_on, right_on, left, right), @@ -282,7 +316,7 @@ pub fn physical_plan_to_pipeline( let build_schema = build_child.schema(); let probe_schema = probe_child.schema(); - let probe_node = || -> DaftResult<_> { + || -> DaftResult<_> { let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) .map(std::string::ToString::to_string) .collect(); @@ -327,32 +361,46 @@ pub fn physical_plan_to_pipeline( let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; match join_type { - JoinType::Anti | JoinType::Semi => DaftResult::Ok(IntermediateNode::new( - Arc::new(AntiSemiProbeOperator::new(casted_probe_on, *join_type)), + JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( + Arc::new(AntiSemiProbeOperator::new( + casted_probe_on, + join_type, + schema, + )), vec![build_node, probe_child_node], - )), - JoinType::Inner | JoinType::Left | JoinType::Right => { - DaftResult::Ok(IntermediateNode::new( - Arc::new(HashJoinProbeOperator::new( + ) + .boxed()), + JoinType::Inner => Ok(IntermediateNode::new( + Arc::new(InnerHashJoinProbeOperator::new( + casted_probe_on, + left_schema, + right_schema, + build_on_left, + common_join_keys, + schema, + )), + vec![build_node, probe_child_node], + ) + .boxed()), + JoinType::Left | JoinType::Right | JoinType::Outer => { + Ok(StreamingSinkNode::new( + Arc::new(OuterHashJoinProbeSink::new( casted_probe_on, left_schema, right_schema, *join_type, - build_on_left, common_join_keys, + schema, )), vec![build_node, probe_child_node], - )) - } - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); + ) + .boxed()) } } }() .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), - })?; - probe_node.boxed() + })? 
} _ => { unimplemented!("Physical plan not supported: {}", physical_plan.name()); diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index de1f657273..566d253e9c 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -124,8 +124,8 @@ impl CountingSender { ) -> Result<(), SendError> { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(daft_table::Table::len).sum() + PipelineResultType::ProbeState(ref state) => { + state.get_tables().iter().map(|t| t.len()).sum() } }; self.sender.send(v).await?; @@ -149,8 +149,8 @@ impl CountingReceiver { if let Some(ref v) = v { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(daft_table::Table::len).sum() + PipelineResultType::ProbeState(state) => { + state.get_tables().iter().map(|t| t.len()).sum() } }; self.rt.mark_rows_received(len as u64); diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 010bed0aaf..5b98cb84c6 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -1,61 +1,68 @@ -// use std::sync::Arc; - -// use common_error::DaftResult; -// use daft_micropartition::MicroPartition; -// use tracing::instrument; - -// use super::sink::{Sink, SinkResultType}; - -// #[derive(Clone)] -// pub struct ConcatSink { -// result_left: Vec>, -// result_right: Vec>, -// } - -// impl ConcatSink { -// pub fn new() -> Self { -// Self { -// result_left: Vec::new(), -// result_right: Vec::new(), -// } -// } - -// #[instrument(skip_all, name = "ConcatSink::sink")] -// fn sink_left(&mut self, input: &Arc) -> DaftResult { -// self.result_left.push(input.clone()); -// Ok(SinkResultType::NeedMoreInput) -// } - -// #[instrument(skip_all, name = "ConcatSink::sink")] -// fn sink_right(&mut self, input: &Arc) -> DaftResult { -// self.result_right.push(input.clone()); -// Ok(SinkResultType::NeedMoreInput) -// } -// } - -// impl Sink for ConcatSink { -// fn sink(&mut self, index: usize, input: &Arc) -> DaftResult { -// match index { -// 0 => self.sink_left(input), -// 1 => self.sink_right(input), -// _ => panic!("concat only supports 2 inputs, got {index}"), -// } -// } - -// fn in_order(&self) -> bool { -// true -// } - -// fn num_inputs(&self) -> usize { -// 2 -// } - -// #[instrument(skip_all, name = "ConcatSink::finalize")] -// fn finalize(self: Box) -> DaftResult>> { -// Ok(self -// .result_left -// .into_iter() -// .chain(self.result_right.into_iter()) -// .collect()) -// } -// } +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use daft_micropartition::MicroPartition; +use tracing::instrument; + +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct ConcatSinkState { + // The index of the last morsel of data that was received, which should be strictly non-decreasing. + pub curr_idx: usize, +} +impl StreamingSinkState for ConcatSinkState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub struct ConcatSink {} + +impl StreamingSink for ConcatSink { + /// Execute for the ConcatSink operator does not do any computation and simply returns the input data. + /// It only expects that the indices of the input data are strictly non-decreasing. 
+ /// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later. + #[instrument(skip_all, name = "ConcatSink::sink")] + fn execute( + &self, + index: usize, + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("ConcatSink should have ConcatSinkState"); + + // If the index is the same as the current index or one more than the current index, then we can accept the morsel. + if state.curr_idx == index || state.curr_idx + 1 == index { + state.curr_idx = index; + Ok(StreamingSinkOutput::NeedMoreInput(Some( + input.as_data().clone(), + ))) + } else { + Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index))) + } + } + + fn name(&self) -> &'static str { + "Concat" + } + + fn finalize( + &self, + _states: Vec>, + ) -> DaftResult>> { + Ok(None) + } + + fn make_state(&self) -> Box { + Box::new(ConcatSinkState { curr_idx: 0 }) + } + + /// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers. + fn max_concurrency(&self) -> usize { + 1 + } +} diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 3af65702cd..c8258e281a 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -5,7 +5,7 @@ use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; -use daft_table::{make_probeable_builder, Probeable, ProbeableBuilder, Table}; +use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; @@ -17,8 +17,7 @@ enum ProbeTableState { tables: Vec, }, Done { - probe_table: Arc, - tables: Arc>, + probe_state: Arc, }, } @@ -66,8 +65,7 @@ impl ProbeTableState { let pt = ptb.build(); *self = Self::Done { - probe_table: pt, - tables: Arc::new(tables.clone()), + probe_state: Arc::new(ProbeState::new(pt, Arc::new(tables.clone()))), }; Ok(()) } else { @@ -108,12 +106,8 @@ impl BlockingSink for HashJoinBuildSink { fn finalize(&mut self) -> DaftResult> { self.probe_table_state.finalize()?; - if let ProbeTableState::Done { - probe_table, - tables, - } = &self.probe_table_state - { - Ok(Some((probe_table.clone(), tables.clone()).into())) + if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 40b4d1538f..633c3511c1 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -4,51 +4,69 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::streaming_sink::{StreamSinkOutput, StreamingSink}; +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct LimitSinkState { + remaining: usize, +} + +impl LimitSinkState { + fn new(remaining: usize) -> Self { + Self { remaining } + } + + fn get_remaining_mut(&mut self) -> &mut usize { + &mut self.remaining + } +} + +impl 
StreamingSinkState for LimitSinkState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} pub struct LimitSink { - #[allow(dead_code)] limit: usize, - remaining: usize, } impl LimitSink { pub fn new(limit: usize) -> Self { - Self { - limit, - remaining: limit, - } - } - pub fn boxed(self) -> Box { - Box::new(self) + Self { limit } } } impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( - &mut self, + &self, index: usize, - input: &Arc, - ) -> DaftResult { + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { assert_eq!(index, 0); - + let state = state + .as_any_mut() + .downcast_mut::() + .expect("Limit Sink should have LimitSinkState"); + let input = input.as_data(); let input_num_rows = input.len(); - + let remaining = state.get_remaining_mut(); use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(&self.remaining) { + match input_num_rows.cmp(remaining) { Less => { - self.remaining -= input_num_rows; - Ok(StreamSinkOutput::NeedMoreInput(Some(input.clone()))) + *remaining -= input_num_rows; + Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) } Equal => { - self.remaining = 0; - Ok(StreamSinkOutput::Finished(Some(input.clone()))) + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(input.clone()))) } Greater => { - let taken = input.head(self.remaining)?; - self.remaining -= taken.len(); - Ok(StreamSinkOutput::Finished(Some(Arc::new(taken)))) + let taken = input.head(*remaining)?; + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) } } } @@ -56,4 +74,19 @@ impl StreamingSink for LimitSink { fn name(&self) -> &'static str { "Limit" } + + fn finalize( + &self, + _states: Vec>, + ) -> DaftResult>> { + Ok(None) + } + + fn make_state(&self) -> Box { + Box::new(LimitSinkState::new(self.limit)) + } + + fn max_concurrency(&self) -> usize { + 1 + } } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 39910e7995..7960e55a7c 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -3,5 +3,6 @@ pub mod blocking_sink; pub mod concat; pub mod hash_join_build; pub mod limit; +pub mod outer_hash_join_probe; pub mod sort; pub mod streaming_sink; diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs new file mode 100644 index 0000000000..ab5ffa8cb0 --- /dev/null +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -0,0 +1,419 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::{ + prelude::{ + bitmap::{and, Bitmap, MutableBitmap}, + BooleanArray, Schema, SchemaRef, + }, + series::{IntoSeries, Series}, +}; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_plan::JoinType; +use daft_table::{GrowableTable, ProbeState, Table}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct IndexBitmapBuilder { + mutable_bitmaps: Vec, +} + +impl IndexBitmapBuilder { + fn new(tables: &[Table]) -> Self { + Self { + mutable_bitmaps: tables + .iter() + .map(|t| MutableBitmap::from_len_set(t.len())) + .collect(), + } + } + + #[inline] + fn mark_used(&mut self, table_idx: usize, row_idx: usize) { + self.mutable_bitmaps[table_idx].set(row_idx, false); + } + + fn build(self) -> 
IndexBitmap { + IndexBitmap { + bitmaps: self.mutable_bitmaps.into_iter().map(|b| b.into()).collect(), + } + } +} + +struct IndexBitmap { + bitmaps: Vec, +} + +impl IndexBitmap { + fn merge(&self, other: &Self) -> Self { + Self { + bitmaps: self + .bitmaps + .iter() + .zip(other.bitmaps.iter()) + .map(|(a, b)| and(a, b)) + .collect(), + } + } + + fn convert_to_boolean_arrays(self) -> impl Iterator { + self.bitmaps + .into_iter() + .map(|b| BooleanArray::from(("bitmap", b))) + } +} + +enum OuterHashJoinProbeState { + Building, + ReadyToProbe(Arc, Option), +} + +impl OuterHashJoinProbeState { + fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { + let tables = probe_state.get_tables(); + if matches!(self, Self::Building) { + *self = Self::ReadyToProbe( + probe_state.clone(), + if needs_bitmap { + Some(IndexBitmapBuilder::new(tables)) + } else { + None + }, + ); + } else { + panic!("OuterHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probe_state(&self) -> &ProbeState { + if let Self::ReadyToProbe(probe_state, _) = self { + probe_state + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } + + fn get_bitmap_builder(&mut self) -> &mut Option { + if let Self::ReadyToProbe(_, bitmap_builder) = self { + bitmap_builder + } else { + panic!("get_bitmap can only be used during the ReadyToProbe Phase") + } + } +} + +impl StreamingSinkState for OuterHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct OuterHashJoinProbeSink { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + right_non_join_schema: SchemaRef, + join_type: JoinType, + output_schema: SchemaRef, +} + +impl OuterHashJoinProbeSink { + pub(crate) fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + join_type: JoinType, + common_join_keys: IndexSet, + output_schema: &SchemaRef, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let right_non_join_fields = right_schema + .fields + .values() + .filter(|f| !common_join_keys.contains(&f.name)) + .cloned() + .collect(); + let right_non_join_schema = + Arc::new(Schema::new(right_non_join_fields).expect("right schema should be valid")); + let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + right_non_join_schema, + join_type, + output_schema: output_schema.clone(), + } + } + + fn probe_left_right( + &self, + input: &Arc, + state: &OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; + + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let 
join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let final_table = if self.join_type == JoinType::Left { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = probe_side_table.get_columns(&self.left_non_join_columns)?; + let right = build_side_table.get_columns(&self.right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + } else { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + }; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn probe_outer( + &self, + input: &Arc, + state: &mut OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + ( + probe_state.get_probeable().clone(), + probe_state.get_tables().clone(), + ) + }; + let bitmap_builder = state.get_bitmap_builder(); + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + // Need to set use_validity to true here because we add nulls to the build side + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + let left_idx_used = bitmap_builder + .as_mut() + .expect("bitmap should be set in outer join"); + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + let build_side_table_idx = build_side_table_idx as usize; + let build_row_idx = build_row_idx as usize; + left_idx_used.mark_used(build_side_table_idx, build_row_idx); + build_side_growable.extend(build_side_table_idx, build_row_idx, 1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let 
join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn finalize_outer( + &self, + mut states: Vec>, + ) -> DaftResult>> { + let states = states + .iter_mut() + .map(|s| { + s.as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState") + }) + .collect::>(); + let tables = states + .first() + .expect("at least one state should be present") + .get_probe_state() + .get_tables() + .clone(); + + let merged_bitmap = { + let bitmaps = states.into_iter().map(|s| { + if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { + bitmap + .take() + .expect("bitmap should be present in outer join") + .build() + } else { + panic!("OuterHashJoinProbeState should be in ReadyToProbe state") + } + }); + bitmaps.fold(None, |acc, x| match acc { + None => Some(x), + Some(acc) => Some(acc.merge(&x)), + }) + } + .expect("at least one bitmap should be present"); + + let leftovers = merged_bitmap + .convert_to_boolean_arrays() + .zip(tables.iter()) + .map(|(bitmap, table)| table.mask_filter(&bitmap.into_series())) + .collect::>>()?; + + let build_side_table = Table::concat(&leftovers)?; + + let join_table = build_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = { + let columns = self + .right_non_join_schema + .fields + .values() + .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) + .collect::>(); + Table::new_unchecked(self.right_non_join_schema.clone(), columns, left.len()) + }; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Some(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + )))) + } +} + +impl StreamingSink for OuterHashJoinProbeSink { + #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { + match idx { + 0 => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let probe_state = input.as_probe_state(); + state + .initialize_probe_state(probe_state.clone(), self.join_type == JoinType::Outer); + Ok(StreamingSinkOutput::NeedMoreInput(None)) + } + _ => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); + } + let out = match self.join_type { + JoinType::Left | JoinType::Right => self.probe_left_right(input, state), + JoinType::Outer => self.probe_outer(input, state), + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + }?; + Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "OuterHashJoinProbeSink" + } + + fn make_state(&self) -> Box { + Box::new(OuterHashJoinProbeState::Building) + } + + fn finalize( + &self, + states: Vec>, + ) -> 
DaftResult<Option<Arc<MicroPartition>>> { + if self.join_type == JoinType::Outer { + self.finalize_outer(states) + } else { + Ok(None) + } + } +} diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index f18a7efca0..6e8a022cdb 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -3,14 +3,22 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, NUM_CPUS, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + create_task_set, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, }; -pub enum StreamSinkOutput { +pub trait StreamingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + +pub enum StreamingSinkOutput { NeedMoreInput(Option<Arc<MicroPartition>>), #[allow(dead_code)] HasMoreOutput(Arc<MicroPartition>), @@ -18,36 +26,136 @@ } pub trait StreamingSink: Send + Sync { + /// Execute the StreamingSink operator on the morsel of input data, + /// received from the child with the given index, + /// with the given state. fn execute( - &mut self, + &self, index: usize, - input: &Arc<MicroPartition>, - ) -> DaftResult<StreamSinkOutput>; - #[allow(dead_code)] + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult<StreamingSinkOutput>; + + /// Finalize the StreamingSink operator, with the given states from each worker. + fn finalize( + &self, + states: Vec<Box<dyn StreamingSinkState>>, + ) -> DaftResult<Option<Arc<MicroPartition>>>; + + /// The name of the StreamingSink operator. fn name(&self) -> &'static str; + + /// Create a new worker-local state for this StreamingSink. + fn make_state(&self) -> Box<dyn StreamingSinkState>; + + /// The maximum number of concurrent workers that can be spawned for this sink. + /// Each worker will have its own StreamingSinkState.
+ fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } pub struct StreamingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, children: Vec>, runtime_stats: Arc, } impl StreamingSinkNode { - pub(crate) fn new(op: Box, children: Vec>) -> Self { + pub(crate) fn new(op: Arc, children: Vec>) -> Self { let name = op.name(); Self { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, children, runtime_stats: RuntimeStatsContext::new(), } } + pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver<(usize, PipelineResultType)>, + output_sender: Sender>, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("StreamingSink::Execute"); + let mut state = op.make_state(); + while let Some((idx, morsel)) = input_receiver.recv().await { + loop { + let result = + rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + StreamingSinkOutput::NeedMoreInput(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } + break; + } + StreamingSinkOutput::HasMoreOutput(mp) => { + let _ = output_sender.send(mp).await; + } + StreamingSinkOutput::Finished(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } + return Ok(state); + } + } + } + } + Ok(state) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + task_set: &mut TaskSet>>, + stats: Arc, + ) -> Receiver> { + let (output_sender, output_receiver) = create_channel(input_receivers.len()); + for input_receiver in input_receivers { + task_set.spawn(Self::run_worker( + op.clone(), + input_receiver, + output_sender.clone(), + stats.clone(), + )); + } + output_receiver + } + + // Forwards input from the children to the workers in a round-robin fashion. + // Always exhausts the input from one child before moving to the next. 
+ async fn forward_input_to_workers( + receivers: Vec, + worker_senders: Vec>, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |idx, data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send((idx, data)) + }; + + for (idx, mut receiver) in receivers.into_iter().enumerate() { + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send((idx, morsel.clone())).await; + } + } else { + let _ = send_to_next_worker(idx, morsel.clone()).await; + } + } + } + Ok(()) + } } impl TreeDisplay for StreamingSinkNode { @@ -88,50 +196,49 @@ impl PipelineNode for StreamingSinkNode { maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { - let child = self - .children - .get_mut(0) - .expect("we should only have 1 child"); - let child_results_channel = child.start(true, runtime_handle)?; - let mut child_results_receiver = - child_results_channel.get_receiver_with_stats(&self.runtime_stats); - - let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let mut child_result_receivers = Vec::with_capacity(self.children.len()); + for child in &mut self.children { + let child_result_channel = child.start(maintain_order, runtime_handle)?; + child_result_receivers + .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + } + + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + runtime_handle.spawn( + Self::forward_input_to_workers(child_result_receivers, input_senders), + self.name(), + ); runtime_handle.spawn( async move { - // this should be a RWLock and run in concurrent workers - let span = info_span!("StreamingSink::execute"); - - let mut sink = op.lock().await; - let mut is_active = true; - while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val.as_data(); - loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; - match result { - StreamSinkOutput::HasMoreOutput(mp) => { - sender.send(mp.into()).await.unwrap(); - } - StreamSinkOutput::NeedMoreInput(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - break; - } - StreamSinkOutput::Finished(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - is_active = false; - break; - } - } - } + let mut task_set = create_task_set(); + let mut output_receiver = Self::spawn_workers( + op.clone(), + input_receivers, + &mut task_set, + runtime_stats.clone(), + ); + + while let Some(morsel) = output_receiver.recv().await { + let _ = destination_sender.send(morsel.into()).await; + } + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = task_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); + } + + if let Some(finalized_result) = op.finalize(finished_states)? 
{ + let _ = destination_sender.send(finalized_result.into()).await; } - DaftResult::Ok(()) + Ok(()) }, self.name(), ); diff --git a/src/daft-physical-plan/src/lib.rs b/src/daft-physical-plan/src/lib.rs index 824fc6c099..ba20720855 100644 --- a/src/daft-physical-plan/src/lib.rs +++ b/src/daft-physical-plan/src/lib.rs @@ -4,7 +4,7 @@ mod translate; pub use local_plan::{ Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, - LocalPhysicalPlan, LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Project, Sample, Sort, - UnGroupedAggregate, + LocalPhysicalPlan, LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, + Sort, UnGroupedAggregate, Unpivot, }; pub use translate::translate; diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 05fa44eda5..94672c2463 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -16,7 +16,7 @@ pub enum LocalPhysicalPlan { Filter(Filter), Limit(Limit), Explode(Explode), - // Unpivot(Unpivot), + Unpivot(Unpivot), Sort(Sort), // Split(Split), Sample(Sample), @@ -29,7 +29,7 @@ pub enum LocalPhysicalPlan { // ReduceMerge(ReduceMerge), UnGroupedAggregate(UnGroupedAggregate), HashAggregate(HashAggregate), - // Pivot(Pivot), + Pivot(Pivot), Concat(Concat), HashJoin(HashJoin), // SortMergeJoin(SortMergeJoin), @@ -165,6 +165,46 @@ impl LocalPhysicalPlan { .arced() } + pub(crate) fn unpivot( + input: LocalPhysicalPlanRef, + ids: Vec, + values: Vec, + variable_name: String, + value_name: String, + schema: SchemaRef, + ) -> LocalPhysicalPlanRef { + Self::Unpivot(Unpivot { + input, + ids, + values, + variable_name, + value_name, + schema, + plan_stats: PlanStats {}, + }) + .arced() + } + + pub(crate) fn pivot( + input: LocalPhysicalPlanRef, + group_by: Vec, + pivot_column: ExprRef, + value_column: ExprRef, + names: Vec, + schema: SchemaRef, + ) -> LocalPhysicalPlanRef { + Self::Pivot(Pivot { + input, + group_by, + pivot_column, + value_column, + names, + schema, + plan_stats: PlanStats {}, + }) + .arced() + } + pub(crate) fn sort( input: LocalPhysicalPlanRef, sort_by: Vec, @@ -242,10 +282,12 @@ impl LocalPhysicalPlan { | Self::Project(Project { schema, .. }) | Self::UnGroupedAggregate(UnGroupedAggregate { schema, .. }) | Self::HashAggregate(HashAggregate { schema, .. }) + | Self::Pivot(Pivot { schema, .. }) | Self::Sort(Sort { schema, .. }) | Self::Sample(Sample { schema, .. }) | Self::HashJoin(HashJoin { schema, .. }) | Self::Explode(Explode { schema, .. }) + | Self::Unpivot(Unpivot { schema, .. }) | Self::Concat(Concat { schema, .. }) => schema, Self::InMemoryScan(InMemoryScan { info, .. 
}) => &info.source_schema, _ => todo!("{:?}", self), @@ -340,6 +382,28 @@ pub struct HashAggregate { pub plan_stats: PlanStats, } +#[derive(Debug)] +pub struct Unpivot { + pub input: LocalPhysicalPlanRef, + pub ids: Vec, + pub values: Vec, + pub variable_name: String, + pub value_name: String, + pub schema: SchemaRef, + pub plan_stats: PlanStats, +} + +#[derive(Debug)] +pub struct Pivot { + pub input: LocalPhysicalPlanRef, + pub group_by: Vec, + pub pivot_column: ExprRef, + pub value_column: ExprRef, + pub names: Vec, + pub schema: SchemaRef, + pub plan_stats: PlanStats, +} + #[derive(Debug)] pub struct HashJoin { pub left: LocalPhysicalPlanRef, diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index 69744dd5ce..7dcb0f552b 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -1,5 +1,5 @@ use common_error::DaftResult; -use daft_core::join::JoinStrategy; +use daft_core::{join::JoinStrategy, prelude::Schema}; use daft_dsl::ExprRef; use daft_plan::{LogicalPlan, LogicalPlanRef, SourceInfo}; @@ -70,6 +70,46 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { )) } } + LogicalPlan::Unpivot(unpivot) => { + let input = translate(&unpivot.input)?; + Ok(LocalPhysicalPlan::unpivot( + input, + unpivot.ids.clone(), + unpivot.values.clone(), + unpivot.variable_name.clone(), + unpivot.value_name.clone(), + unpivot.output_schema.clone(), + )) + } + LogicalPlan::Pivot(pivot) => { + let input = translate(&pivot.input)?; + let groupby_with_pivot = pivot + .group_by + .iter() + .chain(std::iter::once(&pivot.pivot_column)) + .cloned() + .collect::>(); + let aggregate_fields = groupby_with_pivot + .iter() + .map(|expr| expr.to_field(input.schema())) + .chain(std::iter::once(pivot.aggregation.to_field(input.schema()))) + .collect::>>()?; + let aggregate_schema = Schema::new(aggregate_fields)?; + let aggregate = LocalPhysicalPlan::hash_aggregate( + input, + vec![pivot.aggregation.clone(); 1], + groupby_with_pivot, + aggregate_schema.into(), + ); + Ok(LocalPhysicalPlan::pivot( + aggregate, + pivot.group_by.clone(), + pivot.pivot_column.clone(), + pivot.value_column.clone(), + pivot.names.clone(), + pivot.output_schema.clone(), + )) + } LogicalPlan::Sort(sort) => { let input = translate(&sort.input)?; Ok(LocalPhysicalPlan::sort( diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index e651b6528f..55823e5843 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -11,7 +11,7 @@ use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, ExcludeSelectItem, + ArrayElemTypeDef, BinaryOperator, CastKind, Distinct, ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, WildcardAdditionalOptions, }, @@ -202,6 +202,15 @@ impl SQLPlanner { } } + match &selection.distinct { + Some(Distinct::Distinct) => { + let rel = self.relation_mut(); + rel.inner = rel.inner.distinct()?; + } + Some(Distinct::On(_)) => unsupported_sql_err!("DISTINCT ON"), + None => {} + } + if let Some(order_by) = &query.order_by { if order_by.interpolate.is_some() { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); @@ -1186,9 +1195,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult if selection.top.is_some() { 
unsupported_sql_err!("TOP"); } - if selection.distinct.is_some() { - unsupported_sql_err!("DISTINCT"); - } + if selection.into.is_some() { unsupported_sql_err!("INTO"); } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 1d84e3d7b1..e93f4d7a77 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -29,7 +29,7 @@ mod probeable; mod repr_html; pub use growable::GrowableTable; -pub use probeable::{make_probeable_builder, Probeable, ProbeableBuilder}; +pub use probeable::{make_probeable_builder, ProbeState, Probeable, ProbeableBuilder}; #[cfg(feature = "python")] pub mod python; diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index a3d935246e..3346bd9869 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -77,3 +77,23 @@ pub trait Probeable: Send + Sync { table: &'a Table, ) -> DaftResult + 'a>>; } + +#[derive(Clone)] +pub struct ProbeState { + probeable: Arc, + tables: Arc>, +} + +impl ProbeState { + pub fn new(probeable: Arc, tables: Arc>) -> Self { + Self { probeable, tables } + } + + pub fn get_probeable(&self) -> &Arc { + &self.probeable + } + + pub fn get_tables(&self) -> &Arc> { + &self.tables + } +} diff --git a/tests/cookbook/test_joins.py b/tests/cookbook/test_joins.py index b51c863100..d80dce72a2 100644 --- a/tests/cookbook/test_joins.py +++ b/tests/cookbook/test_joins.py @@ -6,16 +6,20 @@ from daft.expressions import col from tests.conftest import assert_df_equals -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) + +def skip_invalid_join_strategies(join_strategy): + if context.get_context().daft_execution_config.enable_native_executor is True: + if join_strategy not in [None, "hash"]: + pytest.skip("Native executor fails for these tests") @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df.select(col("Unique Key"), col("Created Date")) @@ -33,9 +37,12 @@ def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, reparti @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df = daft_df.select(col("Unique Key"), col("Borough")) @@ -44,7 +51,11 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re service_requests_csv_pd_df = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df = ( service_requests_csv_pd_df.set_index("Unique Key") - .join(service_requests_csv_pd_df.set_index("Unique Key"), how="inner", rsuffix="_right") + .join( + service_requests_csv_pd_df.set_index("Unique Key"), + how="inner", + rsuffix="_right", + ) 
.reset_index() ) service_requests_csv_pd_df = service_requests_csv_pd_df.rename({"Borough_right": "right.Borough"}, axis=1) @@ -53,9 +64,12 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts) daft_df_left = daft_df.repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) @@ -76,9 +90,12 @@ def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_cs @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_lvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) diff --git a/tests/dataframe/test_approx_count_distinct.py b/tests/dataframe/test_approx_count_distinct.py index 78d2a7b181..68d7057ca0 100644 --- a/tests/dataframe/test_approx_count_distinct.py +++ b/tests/dataframe/test_approx_count_distinct.py @@ -2,12 +2,7 @@ import pytest import daft -from daft import col, context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) +from daft import col TESTS = [ [[], 0], diff --git a/tests/dataframe/test_concat.py b/tests/dataframe/test_concat.py index 07e06df59c..f3caf56bb1 100644 --- a/tests/dataframe/test_concat.py +++ b/tests/dataframe/test_concat.py @@ -2,13 +2,6 @@ import pytest -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - def test_simple_concat(make_df): df1 = make_df({"foo": [1, 2, 3]}) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 980d19aa3f..4b08abea61 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -12,7 +12,7 @@ def skip_invalid_join_strategies(join_strategy, join_type): if context.get_context().daft_execution_config.enable_native_executor is True: - if join_type == "outer" or join_strategy not in [None, "hash"]: + if join_strategy not in [None, "hash"]: pytest.skip("Native executor fails for these tests") else: if (join_strategy == "sort_merge" or join_strategy == "sort_merge_aligned_boundaries") and join_type != "inner": diff --git a/tests/dataframe/test_pivot.py b/tests/dataframe/test_pivot.py index 232d8b0b45..fcd88c9c51 100644 --- a/tests/dataframe/test_pivot.py +++ b/tests/dataframe/test_pivot.py @@ -1,12 +1,5 @@ import pytest -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - 
reason="Native executor fails for these tests", -) - @pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) def test_pivot(make_df, repartition_nparts): diff --git a/tests/dataframe/test_sample.py b/tests/dataframe/test_sample.py index 791e2a2211..109b9f332b 100644 --- a/tests/dataframe/test_sample.py +++ b/tests/dataframe/test_sample.py @@ -2,8 +2,6 @@ import pytest -from daft import context - def test_sample_fraction(make_df, valid_data: list[dict[str, float]]) -> None: df = make_df(valid_data) @@ -100,10 +98,6 @@ def test_sample_without_replacement(make_df, valid_data: list[dict[str, float]]) assert pylist[0] != pylist[1] -@pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for concat", -) def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None: df1 = make_df(valid_data) df2 = make_df(valid_data) diff --git a/tests/dataframe/test_transform.py b/tests/dataframe/test_transform.py index 277c378bad..a698b6e7fd 100644 --- a/tests/dataframe/test_transform.py +++ b/tests/dataframe/test_transform.py @@ -3,12 +3,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) def add_1(df): diff --git a/tests/dataframe/test_unpivot.py b/tests/dataframe/test_unpivot.py index e40edb0008..b4c7a84cc5 100644 --- a/tests/dataframe/test_unpivot.py +++ b/tests/dataframe/test_unpivot.py @@ -1,13 +1,8 @@ import pytest -from daft import col, context +from daft import col from daft.datatype import DataType -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.mark.parametrize("n_partitions", [1, 2, 4]) def test_unpivot(make_df, n_partitions): diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f5c01dccc6..e202eed471 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -26,6 +26,7 @@ "trino://user@localhost:8080/memory/default", "postgresql://username:password@localhost:5432/postgres", "mysql+pymysql://username:password@localhost:3306/mysql", + "mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes", ] TEST_TABLE_NAME = "example" EMPTY_TEST_TABLE_NAME = "empty_table" diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml index 11c391b0d3..b8eb8c3eba 100644 --- a/tests/integration/sql/docker-compose/docker-compose.yml +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -31,6 +31,18 @@ services: volumes: - mysql_data:/var/lib/mysql + azuresqledge: + image: mcr.microsoft.com/azure-sql-edge + container_name: azuresqledge + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "StrongPassword!" 
+ ports: + - 1433:1433 + volumes: + - azuresqledge_data:/var/opt/mssql + volumes: postgres_data: mysql_data: + azuresqledge_data: diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index ff02ebaac4..7983be00c7 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -141,6 +141,10 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: ) @pytest.mark.parametrize("num_partitions", [1, 2]) def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None: + # Skip invalid comparisons for bool_col + if column == "bool_col" and operator not in ("=", "!="): + pytest.skip(f"Operator {operator} not valid for bool_col") + df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, @@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) - @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None: +def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None: df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions, ) + + # If_else is not supported as a pushdown to read_sql, but it should still work df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)] diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 8b8cce43b5..6bcd716854 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -214,3 +214,10 @@ def test_sql_tbl_alias(): catalog = SQLCatalog({"df": daft.from_pydict({"n": [1, 2, 3]})}) df = daft.sql("SELECT df_alias.n FROM df AS df_alias where df_alias.n = 2", catalog) assert df.collect().to_pydict() == {"n": [2]} + + +def test_sql_distinct(): + df = daft.from_pydict({"n": [1, 1, 2, 2]}) + actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict() + expected = df.distinct().collect().to_pydict() + assert actual == expected
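For illustration, the outer-join probe in the patch above needs to remember which build-side rows ever found a match so that finalization can emit the leftovers with the probe side filled with nulls. The sketch below is a minimal stand-in for that bookkeeping: plain `Vec<bool>` masks replace the arrow2 bitmaps, `IndexBitmapSketch` is a hypothetical name, and only `mark_used` and `merge` mirror methods that appear in the patch.

```rust
// Sketch: per-worker tracking of unmatched build-side rows for an outer join.
// Plain Vec<bool> stands in for the mutable bitmaps used in the patch.

#[derive(Clone, Debug)]
struct IndexBitmapSketch {
    // One mask per build-side table; true = row not matched yet by this worker.
    unused: Vec<Vec<bool>>,
}

impl IndexBitmapSketch {
    fn new(table_lens: &[usize]) -> Self {
        Self {
            unused: table_lens.iter().map(|&n| vec![true; n]).collect(),
        }
    }

    // Called from the probe loop whenever a build row participates in a match.
    fn mark_used(&mut self, table_idx: usize, row_idx: usize) {
        self.unused[table_idx][row_idx] = false;
    }

    // Combine two workers' views: a row is a leftover only if *no* worker
    // matched it, i.e. an element-wise AND of the "unused" masks.
    fn merge(&self, other: &Self) -> Self {
        Self {
            unused: self
                .unused
                .iter()
                .zip(&other.unused)
                .map(|(a, b)| a.iter().zip(b).map(|(x, y)| *x && *y).collect())
                .collect(),
        }
    }
}

fn main() {
    // Two workers probing against a single 4-row build table.
    let mut worker_a = IndexBitmapSketch::new(&[4]);
    let mut worker_b = IndexBitmapSketch::new(&[4]);
    worker_a.mark_used(0, 1);
    worker_b.mark_used(0, 3);

    // Rows 0 and 2 were never matched by either worker, so finalization would
    // emit them, with nulls substituted for the probe-side columns.
    let merged = worker_a.merge(&worker_b);
    assert_eq!(merged.unused[0], vec![true, false, true, false]);
    println!("leftover mask: {:?}", merged.unused[0]);
}
```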
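The reworked `StreamingSink` trait separates per-morsel work (`execute`, run against a worker-local state created by `make_state`) from a one-shot `finalize` that receives every worker's state. The following is a deliberately tiny, single-threaded sketch of that contract only; `RowCountSink` and `RowCountState` are hypothetical, and the real trait goes through `Box<dyn StreamingSinkState>`, `PipelineResultType`, and `DaftResult`, all omitted here.

```rust
// Sketch of the execute/finalize contract behind a streaming sink:
// every worker gets its own state, execute is called once per morsel against
// that state, and finalize sees all worker states after the input is drained.

struct RowCountSink;

#[derive(Default)]
struct RowCountState {
    rows: usize,
}

impl RowCountSink {
    fn make_state(&self) -> RowCountState {
        RowCountState::default()
    }

    // Per-morsel work, mutating only this worker's state.
    fn execute(&self, morsel: &[i64], state: &mut RowCountState) {
        state.rows += morsel.len();
    }

    // Global wrap-up once every worker has finished.
    fn finalize(&self, states: Vec<RowCountState>) -> usize {
        states.iter().map(|s| s.rows).sum()
    }
}

fn main() {
    let sink = RowCountSink;
    // Pretend two workers each received a share of the morsels.
    let mut s0 = sink.make_state();
    let mut s1 = sink.make_state();
    sink.execute(&[1, 2, 3], &mut s0);
    sink.execute(&[4, 5], &mut s1);
    assert_eq!(sink.finalize(vec![s0, s1]), 5);
    println!("total rows: 5");
}
```

Keeping the state worker-local is what lets the sink run on several workers without locks; only `finalize` ever sees all the states together.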
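`forward_input_to_workers` in the patch hands morsels to the worker channels round-robin, except that broadcast inputs (for example, a probe table every worker needs a copy of) are sent to all workers. Below is a synchronous approximation of that dispatch policy using `std::sync::mpsc`; `Morsel` and its `should_broadcast` field are stand-ins for daft's `PipelineResultType`, and the real code is async.

```rust
use std::sync::mpsc;

// Sketch: round-robin dispatch with a broadcast path.
#[derive(Clone, Debug)]
struct Morsel {
    rows: Vec<i64>,
    should_broadcast: bool,
}

fn forward_to_workers(inputs: Vec<Vec<Morsel>>, workers: &[mpsc::Sender<(usize, Morsel)>]) {
    let mut next_worker = 0;
    // Exhaust each child's stream fully before moving on to the next child.
    for (child_idx, child_stream) in inputs.into_iter().enumerate() {
        for morsel in child_stream {
            if morsel.should_broadcast {
                // Every worker gets its own copy of broadcast data.
                for tx in workers {
                    tx.send((child_idx, morsel.clone())).unwrap();
                }
            } else {
                workers[next_worker].send((child_idx, morsel)).unwrap();
                next_worker = (next_worker + 1) % workers.len();
            }
        }
    }
}

fn main() {
    let (tx0, rx0) = mpsc::channel();
    let (tx1, rx1) = mpsc::channel();
    let workers = vec![tx0, tx1];

    let inputs = vec![vec![
        Morsel { rows: vec![1, 2], should_broadcast: true },
        Morsel { rows: vec![3], should_broadcast: false },
        Morsel { rows: vec![4], should_broadcast: false },
    ]];
    forward_to_workers(inputs, &workers);
    drop(workers); // close the channels so the receivers below stop iterating

    // Each worker sees the broadcast morsel plus one round-robin morsel.
    assert_eq!(rx0.iter().count(), 2);
    assert_eq!(rx1.iter().count(), 2);
}
```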
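In `translate.rs`, a logical `Pivot` is lowered to a hash aggregate keyed on `group_by` plus the pivot column, and the actual reshape into one column per pivot value happens in the `Pivot` operator afterwards. The toy example below walks the same two steps on plain `BTreeMap`s; the data, the sum aggregation, and the null filling are illustrative only and no daft APIs are involved.

```rust
use std::collections::BTreeMap;

fn main() {
    // Input rows: (group key, pivot value, measure), e.g. (city, month, sales).
    let rows = vec![
        ("sf", "jan", 10),
        ("sf", "jan", 5),
        ("sf", "feb", 7),
        ("nyc", "jan", 3),
        ("nyc", "feb", 4),
    ];

    // Step 1: aggregate on (group_by..., pivot_column), i.e. the hash
    // aggregate node the translator inserts underneath the Pivot.
    let mut agg: BTreeMap<(&str, &str), i64> = BTreeMap::new();
    for (group, pivot, value) in rows {
        *agg.entry((group, pivot)).or_insert(0) += value; // sum aggregation
    }

    // Step 2: reshape the aggregated pairs into one output column per pivot
    // value, which is the job of the Pivot operator itself.
    let names = ["jan", "feb"]; // the `names` list carried on the Pivot node
    let mut pivoted: BTreeMap<&str, Vec<Option<i64>>> = BTreeMap::new();
    for ((group, pivot), value) in &agg {
        let row = pivoted
            .entry(*group)
            .or_insert_with(|| vec![None; names.len()]);
        let col = names.iter().position(|n| n == pivot).unwrap();
        row[col] = Some(*value);
    }

    assert_eq!(pivoted["sf"], vec![Some(15), Some(7)]);
    assert_eq!(pivoted["nyc"], vec![Some(3), Some(4)]);
    println!("{pivoted:?}");
}
```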