From e1027dc70f8344257baa30fe385c5a8154f366a9 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 12 Jan 2023 16:55:45 -0800 Subject: [PATCH] GH-33212: [C++][Python] Add use_threads to pyarrow.substrait.run_query (#33623) Also adds memory_pool and function_registry to the various DeclarationToXyz methods. Converts `ExecuteSerializedPlan` to use `DeclarationToReader` instead of the bespoke thing it was doing before. * Closes: #33212 Lead-authored-by: Weston Pace Co-authored-by: Vibhatha Lakmal Abeykoon Signed-off-by: Weston Pace --- cpp/src/arrow/compute/exec/exec_plan.cc | 87 ++++++++------ cpp/src/arrow/compute/exec/exec_plan.h | 68 ++++++++--- cpp/src/arrow/engine/substrait/serde.cc | 22 ++++ cpp/src/arrow/engine/substrait/serde.h | 17 +++ cpp/src/arrow/engine/substrait/util.cc | 112 ++---------------- cpp/src/arrow/engine/substrait/util.h | 16 ++- python/pyarrow/_substrait.pyx | 10 +- .../pyarrow/includes/libarrow_substrait.pxd | 3 +- python/pyarrow/tests/test_substrait.py | 22 ++-- 9 files changed, 185 insertions(+), 172 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index b8886619d7dae..88cd298d2cb4e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -563,23 +563,28 @@ Future> DeclarationToTableAsync(Declaration declaration, return exec_plan->finished().Then([exec_plan, output_table] { return *output_table; }); } -Future> DeclarationToTableAsync(Declaration declaration, - bool use_threads) { +Future> DeclarationToTableAsync( + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { if (use_threads) { - return DeclarationToTableAsync(std::move(declaration), *threaded_exec_context()); + ExecContext ctx(memory_pool, ::arrow::internal::GetCpuThreadPool(), + function_registry); + return DeclarationToTableAsync(std::move(declaration), ctx); } else { ARROW_ASSIGN_OR_RAISE(std::shared_ptr tpool, 
ThreadPool::Make(1)); - ExecContext ctx(default_memory_pool(), tpool.get()); + ExecContext ctx(memory_pool, tpool.get(), function_registry); return DeclarationToTableAsync(std::move(declaration), ctx) .Then([tpool](const std::shared_ptr& table) { return table; }); } } Result> DeclarationToTable(Declaration declaration, - bool use_threads) { + bool use_threads, + MemoryPool* memory_pool, + FunctionRegistry* function_registry) { return ::arrow::internal::RunSynchronously>>( - [declaration = std::move(declaration)](::arrow::internal::Executor* executor) { - ExecContext ctx(default_memory_pool(), executor); + [=, declaration = std::move(declaration)](::arrow::internal::Executor* executor) { + ExecContext ctx(memory_pool, executor, function_registry); return DeclarationToTableAsync(std::move(declaration), ctx); }, use_threads); @@ -594,12 +599,15 @@ Future>> DeclarationToBatchesAsync( } Future>> DeclarationToBatchesAsync( - Declaration declaration, bool use_threads) { + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { if (use_threads) { - return DeclarationToBatchesAsync(std::move(declaration), *threaded_exec_context()); + ExecContext ctx(memory_pool, ::arrow::internal::GetCpuThreadPool(), + function_registry); + return DeclarationToBatchesAsync(std::move(declaration), ctx); } else { ARROW_ASSIGN_OR_RAISE(std::shared_ptr tpool, ThreadPool::Make(1)); - ExecContext ctx(default_memory_pool(), tpool.get()); + ExecContext ctx(memory_pool, tpool.get(), function_registry); return DeclarationToBatchesAsync(std::move(declaration), ctx) .Then([tpool](const std::vector>& batches) { return batches; @@ -608,11 +616,12 @@ Future>> DeclarationToBatchesAsync( } Result>> DeclarationToBatches( - Declaration declaration, bool use_threads) { + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { return ::arrow::internal::RunSynchronously< Future>>>( - [declaration = 
std::move(declaration)](::arrow::internal::Executor* executor) { - ExecContext ctx(default_memory_pool(), executor); + [=, declaration = std::move(declaration)](::arrow::internal::Executor* executor) { + ExecContext ctx(memory_pool, executor, function_registry); return DeclarationToBatchesAsync(std::move(declaration), ctx); }, use_threads); @@ -641,24 +650,27 @@ Future DeclarationToExecBatchesAsync(Declaration declar }); } -Future DeclarationToExecBatchesAsync(Declaration declaration, - bool use_threads) { +Future DeclarationToExecBatchesAsync( + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { if (use_threads) { - return DeclarationToExecBatchesAsync(std::move(declaration), - *threaded_exec_context()); + ExecContext ctx(memory_pool, ::arrow::internal::GetCpuThreadPool(), + function_registry); + return DeclarationToExecBatchesAsync(std::move(declaration), ctx); } else { ARROW_ASSIGN_OR_RAISE(std::shared_ptr tpool, ThreadPool::Make(1)); - ExecContext ctx(default_memory_pool(), tpool.get()); + ExecContext ctx(memory_pool, tpool.get(), function_registry); return DeclarationToExecBatchesAsync(std::move(declaration), ctx) .Then([tpool](const BatchesWithCommonSchema& batches) { return batches; }); } } -Result DeclarationToExecBatches(Declaration declaration, - bool use_threads) { +Result DeclarationToExecBatches( + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { return ::arrow::internal::RunSynchronously>( - [declaration = std::move(declaration)](::arrow::internal::Executor* executor) { - ExecContext ctx(default_memory_pool(), executor); + [=, declaration = std::move(declaration)](::arrow::internal::Executor* executor) { + ExecContext ctx(memory_pool, executor, function_registry); return DeclarationToExecBatchesAsync(std::move(declaration), ctx); }, use_threads); @@ -680,20 +692,25 @@ Future<> DeclarationToStatusAsync(Declaration declaration, 
ExecContext exec_cont return exec_plan->finished().Then([exec_plan]() {}); } -Future<> DeclarationToStatusAsync(Declaration declaration, bool use_threads) { +Future<> DeclarationToStatusAsync(Declaration declaration, bool use_threads, + MemoryPool* memory_pool, + FunctionRegistry* function_registry) { if (use_threads) { - return DeclarationToStatusAsync(std::move(declaration), *threaded_exec_context()); + ExecContext ctx(memory_pool, ::arrow::internal::GetCpuThreadPool(), + function_registry); + return DeclarationToStatusAsync(std::move(declaration), ctx); } else { ARROW_ASSIGN_OR_RAISE(std::shared_ptr tpool, ThreadPool::Make(1)); - ExecContext ctx(default_memory_pool(), tpool.get()); + ExecContext ctx(memory_pool, tpool.get(), function_registry); return DeclarationToStatusAsync(std::move(declaration), ctx).Then([tpool]() {}); } } -Status DeclarationToStatus(Declaration declaration, bool use_threads) { +Status DeclarationToStatus(Declaration declaration, bool use_threads, + MemoryPool* memory_pool, FunctionRegistry* function_registry) { return ::arrow::internal::RunSynchronously>( - [declaration = std::move(declaration)](::arrow::internal::Executor* executor) { - ExecContext ctx(default_memory_pool(), executor); + [=, declaration = std::move(declaration)](::arrow::internal::Executor* executor) { + ExecContext ctx(memory_pool, executor, function_registry); return DeclarationToStatusAsync(std::move(declaration), ctx); }, use_threads); @@ -738,11 +755,9 @@ struct BatchConverter { }; Result>> DeclarationToRecordBatchGenerator( - Declaration declaration, ::arrow::internal::Executor* executor, - std::shared_ptr* out_schema) { + Declaration declaration, ExecContext exec_ctx, std::shared_ptr* out_schema) { auto converter = std::make_shared(); - ExecContext exec_context(default_memory_pool(), executor); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, ExecPlan::Make(exec_context)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, ExecPlan::Make(exec_ctx)); Declaration with_sink = 
Declaration::Sequence( {declaration, {"sink", SinkNodeOptions(&converter->exec_batch_gen, &converter->schema)}}); @@ -754,14 +769,16 @@ Result>> DeclarationToRecordBatchGen } } // namespace -Result> DeclarationToReader(Declaration declaration, - bool use_threads) { +Result> DeclarationToReader( + Declaration declaration, bool use_threads, MemoryPool* memory_pool, + FunctionRegistry* function_registry) { std::shared_ptr schema; auto batch_iterator = std::make_unique>>( ::arrow::internal::IterateSynchronously>( [&](::arrow::internal::Executor* executor) -> Result>> { - return DeclarationToRecordBatchGenerator(declaration, executor, &schema); + ExecContext exec_ctx(memory_pool, executor, function_registry); + return DeclarationToRecordBatchGenerator(declaration, exec_ctx, &schema); }, use_threads)); diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 09fab00727824..f7519bbd8812b 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -426,24 +426,36 @@ struct ARROW_EXPORT Declaration { /// \brief Utility method to run a declaration and collect the results into a table /// +/// \param use_threads If `use_threads` is false then all CPU work will be done on the +/// calling thread. I/O tasks will still happen on the I/O executor +/// and may be multi-threaded (but should not use significant CPU +/// resources). +/// \param memory_pool The memory pool to use for allocations made while running the plan. +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. +/// /// This method will add a sink node to the declaration to collect results into a /// table. It will then create an ExecPlan from the declaration, start the exec plan, /// block until the plan has finished, and return the created table. -/// -/// If `use_threads` is false then all CPU work will be done on the calling thread. 
I/O -/// tasks will still happen on the I/O executor and may be multi-threaded (but should -/// not use significant CPU resources) -ARROW_EXPORT Result> DeclarationToTable(Declaration declaration, - bool use_threads = true); +ARROW_EXPORT Result> DeclarationToTable( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Asynchronous version of \see DeclarationToTable /// -/// The behavior of use_threads is slightly different than the synchronous version since -/// we cannot run synchronously on the calling thread. Instead, if use_threads=false then -/// a new thread pool will be created with a single thread and this will be used for all -/// compute work. +/// \param use_threads The behavior of use_threads is slightly different than the +/// synchronous version since we cannot run synchronously on the +/// calling thread. Instead, if use_threads=false then a new thread +/// pool will be created with a single thread and this will be used for +/// all compute work. +/// \param memory_pool The memory pool to use for allocations made while running the plan. +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. 
ARROW_EXPORT Future> DeclarationToTableAsync( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Overload of \see DeclarationToTableAsync accepting a custom exec context /// @@ -463,13 +475,17 @@ struct BatchesWithCommonSchema { /// /// \see DeclarationToTable for details on threading & execution ARROW_EXPORT Result DeclarationToExecBatches( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Asynchronous version of \see DeclarationToExecBatches /// /// \see DeclarationToTableAsync for details on threading & execution ARROW_EXPORT Future DeclarationToExecBatchesAsync( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Overload of \see DeclarationToExecBatchesAsync accepting a custom exec context /// @@ -481,13 +497,17 @@ ARROW_EXPORT Future DeclarationToExecBatchesAsync( /// /// \see DeclarationToTable for details on threading & execution ARROW_EXPORT Result>> DeclarationToBatches( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Asynchronous version of \see DeclarationToBatches /// /// \see DeclarationToTableAsync for details on threading & execution ARROW_EXPORT Future>> DeclarationToBatchesAsync( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief 
Overload of \see DeclarationToBatchesAsync accepting a custom exec context /// @@ -511,7 +531,13 @@ ARROW_EXPORT Future>> DeclarationToBatc /// /// If a custom exec context is provided then the value of `use_threads` will be ignored. ARROW_EXPORT Result> DeclarationToReader( - Declaration declaration, bool use_threads = true); + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +/// \brief Overload of \see DeclarationToReader accepting a custom exec context +ARROW_EXPORT Result> DeclarationToReader( + Declaration declaration, ExecContext exec_context); /// \brief Utility method to run a declaration and ignore results /// @@ -519,7 +545,9 @@ ARROW_EXPORT Result> DeclarationToReader( /// example, when the plan ends with a write node. /// /// \see DeclarationToTable for details on threading & execution -ARROW_EXPORT Status DeclarationToStatus(Declaration declaration, bool use_threads = true); +ARROW_EXPORT Status DeclarationToStatus(Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Asynchronous version of \see DeclarationToStatus /// @@ -527,8 +555,10 @@ ARROW_EXPORT Status DeclarationToStatus(Declaration declaration, bool use_thread /// example, when the plan ends with a write node. 
/// /// \see DeclarationToTableAsync for details on threading & execution -ARROW_EXPORT Future<> DeclarationToStatusAsync(Declaration declaration, - bool use_threads = true); +ARROW_EXPORT Future<> DeclarationToStatusAsync( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); /// \brief Overload of \see DeclarationToStatusAsync accepting a custom exec context /// diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index ac5de90326ebb..f588aff44481d 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -127,6 +127,13 @@ DeclarationFactory MakeWriteDeclarationFactory( }; } +DeclarationFactory MakeNoSinkDeclarationFactory() { + return [](compute::Declaration input, + std::vector names) -> Result { + return input; + }; +} + // FIXME - Replace with actual version that includes the change constexpr uint32_t kMinimumMajorVersion = 0; constexpr uint32_t kMinimumMinorVersion = 19; @@ -188,6 +195,21 @@ Result> DeserializePlans( registry, ext_set_out, conversion_options); } +ARROW_ENGINE_EXPORT Result DeserializePlan( + const Buffer& buf, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(std::vector top_level_decls, + DeserializePlans(buf, MakeNoSinkDeclarationFactory(), registry, + ext_set_out, conversion_options)); + if (top_level_decls.empty()) { + return Status::Invalid("No RelRoot in plan"); + } + if (top_level_decls.size() != 1) { + return Status::Invalid("Multiple top level declarations found in Substrait plan"); + } + return top_level_decls[0]; +} + namespace { Result> MakeSingleDeclarationPlan( diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 0f61ba209f4a1..a4e3b3df14513 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ 
b/cpp/src/arrow/engine/substrait/serde.h @@ -139,6 +139,23 @@ ARROW_ENGINE_EXPORT Result> DeserializePlan( const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); +/// \brief Deserializes a Substrait Plan message to a Declaration +/// +/// The plan will not contain any sink nodes and will be suitable for use in any +/// of the arrow::compute::DeclarationToXyz methods. +/// +/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan +/// message +/// \param[in] registry an extension-id-registry to use, or null for the default one. +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// Plan is returned here. +/// \param[in] conversion_options options to control how the conversion is to be done. +/// \return A declaration representing the Substrait plan +ARROW_ENGINE_EXPORT Result DeserializePlan( + const Buffer& buf, const ExtensionIdRegistry* registry = NULLPTR, + ExtensionSet* ext_set_out = NULLPTR, + const ConversionOptions& conversion_options = {}); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 22091d969980d..e0c876d21d2c9 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -40,113 +40,15 @@ namespace arrow { namespace engine { -namespace { - -/// \brief A SinkNodeConsumer specialized to output ExecBatches via PushGenerator -class SubstraitSinkConsumer : public compute::SinkNodeConsumer { - public: - explicit SubstraitSinkConsumer( - arrow::PushGenerator>::Producer producer) - : producer_(std::move(producer)) {} - - Status Consume(compute::ExecBatch batch) override { - // Consume a batch of data - bool did_push = producer_.Push(batch); - if (!did_push) 
return Status::Invalid("Producer closed already"); - return Status::OK(); - } - - Status Init(const std::shared_ptr& schema, - compute::BackpressureControl* backpressure_control, - compute::ExecPlan* plan) override { - schema_ = schema; - return Status::OK(); - } - - Future<> Finish() override { - ARROW_UNUSED(producer_.Close()); - return Future<>::MakeFinished(); - } - - std::shared_ptr schema() { return schema_; } - - private: - arrow::PushGenerator>::Producer producer_; - std::shared_ptr schema_; -}; - -/// \brief An executor to run a Substrait Query -/// This interface is provided as a utility when creating language -/// bindings for consuming a Substrait plan. -class SubstraitExecutor { - public: - explicit SubstraitExecutor(std::shared_ptr plan, - compute::ExecContext exec_context, - const ConversionOptions& conversion_options = {}) - : plan_(std::move(plan)), - plan_started_(false), - exec_context_(exec_context), - conversion_options_(conversion_options) {} - - ~SubstraitExecutor() { ARROW_UNUSED(this->Close()); } - - Result> Execute() { - for (const compute::Declaration& decl : declarations_) { - RETURN_NOT_OK(decl.AddToPlan(plan_.get()).status()); - } - RETURN_NOT_OK(plan_->Validate()); - plan_started_ = true; - RETURN_NOT_OK(plan_->StartProducing()); - auto schema = sink_consumer_->schema(); - std::shared_ptr sink_reader = compute::MakeGeneratorReader( - std::move(schema), std::move(generator_), exec_context_.memory_pool()); - return sink_reader; - } - - Status Close() { - if (plan_started_) return plan_->finished().status(); - return Status::OK(); - } - - Status Init(const Buffer& substrait_buffer, const ExtensionIdRegistry* registry) { - if (substrait_buffer.size() == 0) { - return Status::Invalid("Empty substrait plan is passed."); - } - sink_consumer_ = std::make_shared(generator_.producer()); - std::function()> consumer_factory = [&] { - return sink_consumer_; - }; - ARROW_ASSIGN_OR_RAISE( - declarations_, engine::DeserializePlans(substrait_buffer, 
consumer_factory, - registry, nullptr, conversion_options_)); - return Status::OK(); - } - - private: - arrow::PushGenerator> generator_; - std::vector declarations_; - std::shared_ptr plan_; - bool plan_started_; - compute::ExecContext exec_context_; - std::shared_ptr sink_consumer_; - const ConversionOptions& conversion_options_; -}; - -} // namespace - Result> ExecuteSerializedPlan( const Buffer& substrait_buffer, const ExtensionIdRegistry* registry, - compute::FunctionRegistry* func_registry, - const ConversionOptions& conversion_options) { - compute::ExecContext exec_context(arrow::default_memory_pool(), - ::arrow::internal::GetCpuThreadPool(), func_registry); - ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context)); - SubstraitExecutor executor(std::move(plan), exec_context, conversion_options); - RETURN_NOT_OK(executor.Init(substrait_buffer, registry)); - ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); - // check closing here, not in destructor, to expose error to caller - RETURN_NOT_OK(executor.Close()); - return sink_reader; + compute::FunctionRegistry* func_registry, const ConversionOptions& conversion_options, + bool use_threads, MemoryPool* memory_pool) { + ARROW_ASSIGN_OR_RAISE(compute::Declaration plan, + DeserializePlan(substrait_buffer, registry, + /*ext_set_out=*/nullptr, conversion_options)); + return compute::DeclarationToReader(std::move(plan), use_threads, memory_pool, + func_registry); } Result> SerializeJsonPlan(const std::string& substrait_json) { diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 38b6447b00a45..9f8bd8048899a 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -38,10 +38,24 @@ namespace engine { using PythonTableProvider = std::function>(const std::vector&)>; +/// \brief Utility method to run a Substrait plan +/// \param substrait_buffer The plan to run, must be in binary protobuf format +/// \param 
registry A registry of extension functions to make available to the plan. +/// If null then the default registry will be used. +/// \param memory_pool The memory pool the plan should use to make allocations. +/// \param func_registry A registry of functions used for executing expressions. +/// `registry` maps from Substrait function IDs to "names". These +/// names will be provided to `func_registry` to get the actual +/// kernel. +/// \param conversion_options Options to control plan deserialization +/// \param use_threads If True then the CPU thread pool will be used for CPU work. If +/// False then all CPU work will be done on the calling thread. +/// \return A record batch reader that will read out the results ARROW_ENGINE_EXPORT Result> ExecuteSerializedPlan( const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR, compute::FunctionRegistry* func_registry = NULLPTR, - const ConversionOptions& conversion_options = {}); + const ConversionOptions& conversion_options = {}, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool()); /// \brief Get a Serialized Plan from a Substrait JSON plan. /// This is a helper method for Python tests. diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 7974aea152810..da061d8cd3245 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -48,7 +48,7 @@ cdef CDeclaration _create_named_table_provider(dict named_args, const std_vector no_c_inputs, c_input_node_opts) -def run_query(plan, *, table_provider=None): +def run_query(plan, *, table_provider=None, use_threads=True): """ Execute a Substrait plan and read the results as a RecordBatchReader. @@ -60,6 +60,9 @@ def run_query(plan, table_provider=None): A function to resolve any NamedTable relation to a table. The function will receive a single argument which will be a list of strings representing the table name and should return a pyarrow.Table. 
+ use_threads : bool, default True + If True then multiple threads will be used to run the query. If False then + all CPU intensive work will be done on the calling thread. Returns ------- @@ -123,7 +126,9 @@ def run_query(plan, table_provider=None): shared_ptr[CBuffer] c_buf_plan function[CNamedTableProvider] c_named_table_provider CConversionOptions c_conversion_options + c_bool c_use_threads + c_use_threads = use_threads if isinstance(plan, bytes): c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan)) elif isinstance(plan, Buffer): @@ -141,7 +146,8 @@ def run_query(plan, table_provider=None): with nogil: c_res_reader = ExecuteSerializedPlan( - deref(c_buf_plan), default_extension_id_registry(), GetFunctionRegistry(), c_conversion_options) + deref(c_buf_plan), default_extension_id_registry(), + GetFunctionRegistry(), c_conversion_options, c_use_threads) c_reader = GetResultValue(c_res_reader) diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 04990380d97a6..b3ad00516d898 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -51,6 +51,7 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \ cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan( const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry, - CFunctionRegistry* func_registry, const CConversionOptions& conversion_options) + CFunctionRegistry* func_registry, const CConversionOptions& conversion_options, + c_bool use_threads) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 50b16cd6b6f3c..bd32178feaa9a 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -41,7 +41,8 @@ def 
_write_dummy_data_to_disk(tmpdir, file_name, table): return path -def test_run_serialized_query(tmpdir): +@pytest.mark.parametrize("use_threads", [True, False]) +def test_run_serialized_query(tmpdir, use_threads): substrait_query = """ { "version": { "major": 9999 }, @@ -80,7 +81,7 @@ def test_run_serialized_query(tmpdir): buf = pa._substrait._parse_json_plan(query) - reader = substrait.run_query(buf) + reader = substrait.run_query(buf, use_threads=use_threads) res_tb = reader.read_all() assert table.select(["foo"]) == res_tb.select(["foo"]) @@ -110,12 +111,13 @@ def test_invalid_plan(): } """ buf = pa._substrait._parse_json_plan(tobytes(query)) - exec_message = "Empty substrait plan is passed." + exec_message = "No RelRoot in plan" with pytest.raises(ArrowInvalid, match=exec_message): substrait.run_query(buf) -def test_binary_conversion_with_json_options(tmpdir): +@pytest.mark.parametrize("use_threads", [True, False]) +def test_binary_conversion_with_json_options(tmpdir, use_threads): substrait_query = """ { "version": { "major": 9999 }, @@ -156,7 +158,7 @@ def test_binary_conversion_with_json_options(tmpdir): "FILENAME_PLACEHOLDER", pathlib.Path(path).as_uri())) buf = pa._substrait._parse_json_plan(tobytes(query)) - reader = substrait.run_query(buf) + reader = substrait.run_query(buf, use_threads=use_threads) res_tb = reader.read_all() assert table.select(["bar"]) == res_tb.select(["bar"]) @@ -182,7 +184,8 @@ def test_get_supported_functions(): 'functions_arithmetic.yaml', 'sum') -def test_named_table(): +@pytest.mark.parametrize("use_threads", [True, False]) +def test_named_table(use_threads): test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]}) @@ -222,7 +225,8 @@ def table_provider(names): """ buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) - reader = pa.substrait.run_query(buf, table_provider) + reader = pa.substrait.run_query( + buf, table_provider=table_provider, use_threads=use_threads) 
res_tb = reader.read_all() assert res_tb == test_table_1 @@ -266,7 +270,7 @@ def table_provider(names): buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) exec_message = "Invalid NamedTable Source" with pytest.raises(ArrowInvalid, match=exec_message): - substrait.run_query(buf, table_provider) + substrait.run_query(buf, table_provider=table_provider) def test_named_table_empty_names(): @@ -308,4 +312,4 @@ def table_provider(names): buf = pa._substrait._parse_json_plan(tobytes(query)) exec_message = "names for NamedTable not provided" with pytest.raises(ArrowInvalid, match=exec_message): - substrait.run_query(buf, table_provider) + substrait.run_query(buf, table_provider=table_provider)