ARROW-15523: [Python] Support for Datasets as inputs of Joins
This adds support for providing Datasets as inputs to the join operation.
Until https://issues.apache.org/jira/browse/ARROW-15526 is completed, the join will temporarily return a Table for Dataset inputs as well.
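
For illustration, a minimal sketch of what this enables from Python. The dataset path and column names are hypothetical; `Table.join` is the existing public entry point, which delegates to the helper touched by this change:

```python
import pyarrow as pa
import pyarrow.dataset as ds

# Hypothetical data: an in-memory table joined against an on-disk dataset.
orders = pa.table({"order_id": [1, 2, 3], "customer_id": [10, 20, 10]})
customers = ds.dataset("data/customers/", format="parquet")  # assumed path

# The right operand can now be a Dataset; until ARROW-15526 is completed
# the result is still materialized as a Table.
joined = orders.join(customers, keys="customer_id", join_type="left outer")
print(joined.column_names)
```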

Closes #12765 from amol-/ARROW-15523

Authored-by: Alessandro Molina <[email protected]>
Signed-off-by: Alessandro Molina <[email protected]>
amol- committed Apr 7, 2022
1 parent 8ea2c93 commit dd42155
Showing 11 changed files with 128 additions and 52 deletions.
1 change: 1 addition & 0 deletions python/CMakeLists.txt
@@ -627,6 +627,7 @@ endif()

if(PYARROW_BUILD_DATASET)
target_link_libraries(_dataset PRIVATE ${DATASET_LINK_LIBS})
target_link_libraries(_exec_plan PRIVATE ${DATASET_LINK_LIBS})
if(PYARROW_BUILD_ORC)
target_link_libraries(_dataset_orc PRIVATE ${DATASET_LINK_LIBS})
endif()
14 changes: 14 additions & 0 deletions python/pyarrow/_dataset.pxd
@@ -42,6 +42,20 @@ cdef class DatasetFactory(_Weakrefable):
cdef inline shared_ptr[CDatasetFactory] unwrap(self) nogil


cdef class Dataset(_Weakrefable):

cdef:
shared_ptr[CDataset] wrapped
CDataset* dataset

cdef void init(self, const shared_ptr[CDataset]& sp)

@staticmethod
cdef wrap(const shared_ptr[CDataset]& sp)

cdef shared_ptr[CDataset] unwrap(self) nogil


cdef class FragmentScanOptions(_Weakrefable):

cdef:
4 changes: 0 additions & 4 deletions python/pyarrow/_dataset.pyx
@@ -148,10 +148,6 @@ cdef class Dataset(_Weakrefable):
can accelerate queries that only touch some partitions (files).
"""

cdef:
shared_ptr[CDataset] wrapped
CDataset* dataset

def __init__(self):
_forbid_instantiation(self.__class__)

77 changes: 54 additions & 23 deletions python/pyarrow/_exec_plan.pyx
@@ -26,9 +26,13 @@ from cython.operator cimport dereference as deref, preincrement as inc

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_dataset cimport *
from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table)
from pyarrow.lib import tobytes, _pc
from pyarrow._compute cimport Expression, _true
from pyarrow._dataset cimport Dataset

Initialize() # Initialise support for Datasets in ExecPlan


cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True):
Expand Down Expand Up @@ -61,10 +65,13 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads
CTable* c_table
shared_ptr[CTable] c_out_table
shared_ptr[CSourceNodeOptions] c_sourceopts
shared_ptr[CScanNodeOptions] c_scanopts
shared_ptr[CExecNodeOptions] c_input_node_opts
shared_ptr[CSinkNodeOptions] c_sinkopts
shared_ptr[CAsyncExecBatchGenerator] c_async_exec_batch_gen
shared_ptr[CRecordBatchReader] c_recordbatchreader
vector[CDeclaration].iterator plan_iter
vector[CDeclaration.Input] no_c_inputs

if use_threads:
c_executor = GetCpuThreadPool()
@@ -81,21 +88,34 @@
# Create source nodes for each input
for ipt in inputs:
if isinstance(ipt, Table):
node_factory = "source"
c_in_table = pyarrow_unwrap_table(ipt).get()
c_sourceopts = GetResultValue(
CSourceNodeOptions.FromTable(deref(c_in_table), deref(c_exec_context).executor()))
c_input_node_opts = static_pointer_cast[CExecNodeOptions, CSourceNodeOptions](
c_sourceopts)
elif isinstance(ipt, Dataset):
node_factory = "scan"
c_in_dataset = (<Dataset>ipt).unwrap()
c_scanopts = make_shared[CScanNodeOptions](
c_in_dataset, make_shared[CScanOptions]())
deref(deref(c_scanopts).scan_options).use_threads = use_threads
c_input_node_opts = static_pointer_cast[CExecNodeOptions, CScanNodeOptions](
c_scanopts)
else:
raise TypeError("Unsupported type")

if plan_iter != plan.end():
# Flag the source as the input of the first plan node.
deref(plan_iter).inputs.push_back(CDeclaration.Input(
CDeclaration(tobytes("source"), deref(c_sourceopts))
CDeclaration(tobytes(node_factory),
no_c_inputs, c_input_node_opts)
))
else:
# Empty plan, make the source the first plan node.
c_decls.push_back(
CDeclaration(tobytes("source"), deref(c_sourceopts))
CDeclaration(tobytes(node_factory),
no_c_inputs, c_input_node_opts)
)

# Add Here additional nodes
@@ -139,33 +159,33 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads
return output


def tables_join(join_type, left_table not None, left_keys,
right_table not None, right_keys,
left_suffix=None, right_suffix=None,
use_threads=True, coalesce_keys=False):
def _perform_join(join_type, left_operand not None, left_keys,
right_operand not None, right_keys,
left_suffix=None, right_suffix=None,
use_threads=True, coalesce_keys=False):
"""
Perform join of two tables.
Perform join of two tables or datasets.
The result will be an output table with the result of the join operation
Parameters
----------
join_type : str
One of supported join types.
left_table : Table
The left table for the join operation.
left_operand : Table or Dataset
The left operand for the join operation.
left_keys : str or list[str]
The left table key (or keys) on which the join operation should be performed.
right_table : Table
The right table for the join operation.
The left key (or keys) on which the join operation should be performed.
right_operand : Table or Dataset
The right operand for the join operation.
right_keys : str or list[str]
The right table key (or keys) on which the join operation should be performed.
The right key (or keys) on which the join operation should be performed.
left_suffix : str, default None
Which suffix to add to left column names. This prevents confusion
when the columns in left and right tables have colliding names.
when the columns in left and right operands have colliding names.
right_suffix : str, default None
Which suffix to add to the right column names. This prevents confusion
when the columns in left and right tables have colliding names.
when the columns in left and right operands have colliding names.
use_threads : bool, default True
Whether to use multithreading or not.
coalesce_keys : bool, default False
@@ -202,8 +222,19 @@ def tables_join(join_type, left_table not None, left_keys,
c_right_keys.push_back(CFieldRef(<c_string>tobytes(key)))

# By default expose all columns on both left and right table
left_columns = left_table.column_names
right_columns = right_table.column_names
if isinstance(left_operand, Table):
left_columns = left_operand.column_names
elif isinstance(left_operand, Dataset):
left_columns = left_operand.schema.names
else:
raise TypeError("Unsupported left join member type")

if isinstance(right_operand, Table):
right_columns = right_operand.column_names
elif isinstance(right_operand, Dataset):
right_columns = right_operand.schema.names
else:
raise TypeError("Unsupported right join member type")

# Pick the join type
if join_type == "left semi":
@@ -262,7 +293,7 @@ def tables_join(join_type, left_table not None, left_keys,
left_columns_set = set(left_columns)
right_columns_set = set(right_columns)
# Where the right table columns start.
right_table_index = len(left_columns)
right_operand_index = len(left_columns)
for idx, col in enumerate(left_columns + right_columns):
if idx < len(left_columns) and col in left_column_keys_indices:
# Include keys only once and coalesce left+right table keys.
@@ -275,19 +306,19 @@ def tables_join(join_type, left_table not None, left_keys,
c_projections.push_back(Expression.unwrap(
Expression._call("coalesce", [
Expression._field(idx), Expression._field(
right_table_index+right_key_index)
right_operand_index+right_key_index)
])
))
elif idx >= right_table_index and col in right_column_keys_indices:
elif idx >= right_operand_index and col in right_column_keys_indices:
# Do not include right table keys, as they would lead to duplicated keys.
continue
else:
# For all the other columns include them as they are.
# Just recompute the suffixes that the join produced as the projection
# would lose them otherwise.
if left_suffix and idx < right_table_index and col in right_columns_set:
if left_suffix and idx < right_operand_index and col in right_columns_set:
col += left_suffix
if right_suffix and idx >= right_table_index and col in left_columns_set:
if right_suffix and idx >= right_operand_index and col in left_columns_set:
col += right_suffix
c_projected_col_names.push_back(tobytes(col))
c_projections.push_back(
@@ -306,7 +337,7 @@ def tables_join(join_type, left_table not None, left_keys,
))
)

result_table = execplan([left_table, right_table],
result_table = execplan([left_operand, right_operand],
output_type=Table,
plan=c_decl_plan)

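As a usage note for the renamed helper above, here is a hedged sketch of calling `_perform_join` directly, showing how the suffixes and `coalesce_keys` interact. The sample tables and the `"inner"` join type are illustrative only; the call signature follows the definition in this diff:

```python
import pyarrow as pa
from pyarrow._exec_plan import _perform_join  # internal helper defined above

left = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
right = pa.table({"id": [2, 3, 4], "value": ["x", "y", "z"]})

# Colliding non-key columns receive the suffixes; with coalesce_keys=True the
# "id" key is emitted once (coalesced) instead of appearing for both sides.
result = _perform_join("inner", left, "id", right, "id",
                       left_suffix="_l", right_suffix="_r",
                       use_threads=True, coalesce_keys=True)
print(result.column_names)  # expected along the lines of ['id', 'value_l', 'value_r']
```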
1 change: 0 additions & 1 deletion python/pyarrow/compute.py
@@ -79,7 +79,6 @@
# Expressions
Expression,
)
from pyarrow import _exec_plan # noqa

from collections import namedtuple
import inspect
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -2503,6 +2503,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" namespace "arrow::compute" nogil:
vector[Input] inputs

CDeclaration(c_string factory_name, CExecNodeOptions options)
CDeclaration(c_string factory_name, vector[Input] inputs, shared_ptr[CExecNodeOptions] options)

@staticmethod
CDeclaration Sequence(vector[CDeclaration] decls)
Expand Down
14 changes: 11 additions & 3 deletions python/pyarrow/includes/libarrow_dataset.pxd
@@ -31,6 +31,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
pass


cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil:

cdef void Initialize()


ctypedef CStatus cb_writer_finish_internal(CFileWriter*)
ctypedef void cb_writer_finish(dict, CFileWriter*)

@@ -45,11 +50,14 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
arrow::dataset::ExistingDataBehavior::kError"

cdef cppclass CScanOptions "arrow::dataset::ScanOptions":
@staticmethod
shared_ptr[CScanOptions] Make(shared_ptr[CSchema] schema)

shared_ptr[CSchema] dataset_schema
shared_ptr[CSchema] projected_schema
c_bool use_threads

cdef cppclass CScanNodeOptions "arrow::dataset::ScanNodeOptions"(CExecNodeOptions):
CScanNodeOptions(shared_ptr[CDataset] dataset, shared_ptr[CScanOptions] scan_options)

shared_ptr[CScanOptions] scan_options

cdef cppclass CFragmentScanOptions "arrow::dataset::FragmentScanOptions":
c_string type_name() const
14 changes: 12 additions & 2 deletions python/pyarrow/lib.pyx
@@ -117,10 +117,20 @@ Type_DICTIONARY = _Type_DICTIONARY
UnionMode_SPARSE = _UnionMode_SPARSE
UnionMode_DENSE = _UnionMode_DENSE

__pc = None


def _pc():
import pyarrow.compute as pc
return pc
global __pc
if __pc is None:
import pyarrow.compute as pc
try:
from pyarrow import _exec_plan
pc._exec_plan = _exec_plan
except ImportError:
pass
__pc = pc
return __pc


def _gdb_test_session():
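The `_pc()` change above both memoizes the lazily imported compute module and attaches the optional `_exec_plan` extension to it when available, so joins degrade gracefully if the dataset/exec-plan extension was not built. A standalone sketch of the same pattern, assuming the module names used in this diff:

```python
import importlib

_compute_mod = None

def _lazy_compute():
    """Import pyarrow.compute once, attach the optional _exec_plan extension,
    and reuse the cached module on subsequent calls."""
    global _compute_mod
    if _compute_mod is None:
        pc = importlib.import_module("pyarrow.compute")
        try:
            pc._exec_plan = importlib.import_module("pyarrow._exec_plan")
        except ImportError:
            # Extension not built (e.g. PYARROW_BUILD_DATASET off): joins that
            # need the exec plan machinery are simply unavailable.
            pass
        _compute_mod = pc
    return _compute_mod
```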
6 changes: 3 additions & 3 deletions python/pyarrow/table.pxi
@@ -4252,9 +4252,9 @@ cdef class Table(_PandasConvertible):
"""
if right_keys is None:
right_keys = keys
return _pc()._exec_plan.tables_join(join_type, self, keys, right_table, right_keys,
left_suffix=left_suffix, right_suffix=right_suffix,
use_threads=use_threads, coalesce_keys=coalesce_keys)
return _pc()._exec_plan._perform_join(join_type, self, keys, right_table, right_keys,
left_suffix=left_suffix, right_suffix=right_suffix,
use_threads=use_threads, coalesce_keys=coalesce_keys)

def group_by(self, keys):
"""Declare a grouping over the columns of the table.