From 6e89d22f9d33164edcf0af25af4615ebf0592855 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou <antoine@python.org> Date: Mon, 13 Nov 2023 22:53:48 +0100 Subject: [PATCH] GH-38676: [Python] Fix potential deadlock when CSV reading errors out --- python/pyarrow/_csv.pyx | 5 +- python/pyarrow/_dataset.pxd | 8 +-- python/pyarrow/_dataset.pyx | 4 +- python/pyarrow/_parquet.pyx | 6 +-- python/pyarrow/includes/libarrow_python.pxd | 7 +++ python/pyarrow/ipc.pxi | 2 +- python/pyarrow/lib.pxd | 4 +- python/pyarrow/src/arrow/python/common.h | 55 +++++++++++++++++++-- python/pyarrow/tests/test_csv.py | 21 ++++++++ 9 files changed, 92 insertions(+), 20 deletions(-) diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index e532d8d8ab22a..508488c0c3b3c 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -26,8 +26,7 @@ from collections.abc import Mapping from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler, - PyInvalidRowCallback) +from pyarrow.includes.libarrow_python cimport * from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, RecordBatchReader, ensure_type, maybe_unbox_memory_pool, get_input_stream, @@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None, parse_options=None, CCSVParseOptions c_parse_options CCSVConvertOptions c_convert_options CIOContext io_context - shared_ptr[CCSVReader] reader + SharedPtrNoGIL[CCSVReader] reader shared_ptr[CTable] table _get_reader(input_file, read_options, &stream) diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd index 210e5558009ec..bee9fc1f0987a 100644 --- a/python/pyarrow/_dataset.pxd +++ b/python/pyarrow/_dataset.pxd @@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=*) cdef class DatasetFactory(_Weakrefable): cdef: - shared_ptr[CDatasetFactory] wrapped + SharedPtrNoGIL[CDatasetFactory] wrapped CDatasetFactory* factory cdef init(self, const shared_ptr[CDatasetFactory]& sp) @@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable): cdef class Dataset(_Weakrefable): cdef: - shared_ptr[CDataset] wrapped + SharedPtrNoGIL[CDataset] wrapped CDataset* dataset public dict _scan_options @@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable): cdef class Scanner(_Weakrefable): cdef: - shared_ptr[CScanner] wrapped + SharedPtrNoGIL[CScanner] wrapped CScanner* scanner cdef void init(self, const shared_ptr[CScanner]& sp) @@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable): cdef class Fragment(_Weakrefable): cdef: - shared_ptr[CFragment] wrapped + SharedPtrNoGIL[CFragment] wrapped CFragment* fragment cdef void init(self, const shared_ptr[CFragment]& sp) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 48ee676915311..d7d69965d000a 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable): object iterator_owner # Iterator is a non-POD type and Cython uses offsetof, leading # to a compiler warning unless wrapped like so - shared_ptr[CRecordBatchIterator] iterator + SharedPtrNoGIL[CRecordBatchIterator] iterator def __init__(self): _forbid_instantiation(self.__class__, subclasses_instead=False) @@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable): """An iterator over a sequence of record batches with fragments.""" cdef: object iterator_owner - shared_ptr[CTaggedRecordBatchIterator] iterator + SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator def __init__(self): _forbid_instantiation(self.__class__, subclasses_instead=False) diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx index 48091367b2ff8..089ed7c75ce58 100644 --- a/python/pyarrow/_parquet.pyx +++ b/python/pyarrow/_parquet.pyx @@ -24,6 +24,7 @@ import warnings from cython.operator cimport dereference as deref from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * from pyarrow.lib cimport (_Weakrefable, Buffer, Schema, check_status, MemoryPool, maybe_unbox_memory_pool, @@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable): cdef: object source CMemoryPool* pool - unique_ptr[FileReader] reader + UniquePtrNoGIL[FileReader] reader FileMetaData _metadata shared_ptr[CRandomAccessFile] rd_handle @@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable): vector[int] c_row_groups vector[int] c_column_indices shared_ptr[CRecordBatch] record_batch - unique_ptr[CRecordBatchReader] recordbatchreader + UniquePtrNoGIL[CRecordBatchReader] recordbatchreader self.set_batch_size(batch_size) @@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable): check_status( recordbatchreader.get().ReadNext(&record_batch) ) - if record_batch.get() == NULL: break diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd index 4d109fc660e08..a1a52012344e9 100644 --- a/python/pyarrow/includes/libarrow_python.pxd +++ b/python/pyarrow/includes/libarrow_python.pxd @@ -261,6 +261,13 @@ cdef extern from "arrow/python/common.h" namespace "arrow::py": void RestorePyError(const CStatus& status) except * +cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil: + cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]): + pass + cdef cppclass UniquePtrNoGIL[T, DELETER=*](unique_ptr[T, DELETER]): + pass + + cdef extern from "arrow/python/inference.h" namespace "arrow::py": c_bool IsPyBool(object o) c_bool IsPyInt(object o) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index fcb9eb729ef04..5d20a4f8b72cb 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -977,7 +977,7 @@ cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c): cdef class _RecordBatchFileReader(_Weakrefable): cdef: - shared_ptr[CRecordBatchFileReader] reader + SharedPtrNoGIL[CRecordBatchFileReader] reader shared_ptr[CRandomAccessFile] file CIpcReadOptions options diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 63ebe6aea8233..ae197eca1ca6b 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile): cdef class _CRecordBatchWriter(_Weakrefable): cdef: - shared_ptr[CRecordBatchWriter] writer + SharedPtrNoGIL[CRecordBatchWriter] writer cdef class RecordBatchReader(_Weakrefable): cdef: - shared_ptr[CRecordBatchReader] reader + SharedPtrNoGIL[CRecordBatchReader] reader cdef class Codec(_Weakrefable): diff --git a/python/pyarrow/src/arrow/python/common.h b/python/pyarrow/src/arrow/python/common.h index e36c0834fd424..1dd3ce8435811 100644 --- a/python/pyarrow/src/arrow/python/common.h +++ b/python/pyarrow/src/arrow/python/common.h @@ -19,6 +19,7 @@ #include <functional> #include <memory> +#include <optional> #include <utility> #include "arrow/buffer.h" @@ -134,13 +135,15 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL { // A RAII-style helper that releases the GIL until the end of a lexical block class ARROW_PYTHON_EXPORT PyReleaseGIL { public: - PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); } - - ~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); } + PyReleaseGIL() : ptr_(PyEval_SaveThread(), &unique_ptr_deleter) {} private: - PyThreadState* saved_state_; - ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL); + static void unique_ptr_deleter(PyThreadState* state) { + if (state) { + PyEval_RestoreThread(state); + } + } + std::unique_ptr<PyThreadState, decltype(&unique_ptr_deleter)> ptr_; }; // A helper to call safely into the Python interpreter from arbitrary C++ code. @@ -235,6 +238,48 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef { } }; +template <template <typename...> typename SmartPtr, typename... Ts> +class SmartPtrNoGIL : public SmartPtr<Ts...> { + using Base = SmartPtr<Ts...>; + + public: + template <typename... Args> + SmartPtrNoGIL(Args&&... args) : Base(std::forward<Args>(args)...) {} + + ~SmartPtrNoGIL() { reset(); } + + template <typename... Args> + void reset(Args&&... args) { + auto release_guard = optional_gil_release(); + Base::reset(std::forward<Args>(args)...); + } + + template <typename V> + SmartPtrNoGIL& operator=(V&& v) { + auto release_guard = optional_gil_release(); + Base::operator=(std::forward<V>(v)); + return *this; + } + + private: + // Only release the GIL if we own an object *and* the Python runtime is + // valid *and* the GIL is held. + std::optional<PyReleaseGIL> optional_gil_release() const { + if (this->get() != nullptr && Py_IsInitialized() && PyGILState_Check()) { + return PyReleaseGIL(); + } + return {}; + } +}; + +/// \brief A std::shared_ptr<T, ...> subclass that releases the GIL when destroying T +template <typename... Ts> +using SharedPtrNoGIL = SmartPtrNoGIL<std::shared_ptr, Ts...>; + +/// \brief A std::unique_ptr<T, ...> subclass that releases the GIL when destroying T +template <typename... Ts> +using UniquePtrNoGIL = SmartPtrNoGIL<std::unique_ptr, Ts...>; + template <typename Fn> struct BoundFunction; diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index afc5380b75516..31f24187e3b37 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -1970,3 +1970,24 @@ def test_write_csv_decimal(tmpdir, type_factory): out = read_csv(tmpdir / "out.csv") assert out.column('col').cast(type) == table.column('col') + + +def test_read_csv_gil_deadlock(): + # GH-38676 + # This test depends on several preconditions: + # - the CSV input is a Python file object + # - reading the CSV file produces an error + data = b"a,b,c" + + class MyBytesIO(io.BytesIO): + def read(self, *args): + time.sleep(0.001) + return super().read(*args) + + def readinto(self, *args): + time.sleep(0.001) + return super().readinto(*args) + + for i in range(20): + with pytest.raises(pa.ArrowInvalid): + read_csv(MyBytesIO(data))