From 6e89d22f9d33164edcf0af25af4615ebf0592855 Mon Sep 17 00:00:00 2001
From: Antoine Pitrou <antoine@python.org>
Date: Mon, 13 Nov 2023 22:53:48 +0100
Subject: [PATCH] GH-38676: [Python] Fix potential deadlock when CSV reading
 errors out

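When reading a CSV file from a Python file object fails, the error is
raised while the Python thread still holds the GIL. Dropping the
`shared_ptr[CCSVReader]` from that thread waits for the reader's
pending background I/O tasks, but those tasks may themselves need the
GIL to call `read()`/`readinto()` on the Python file object, so the
two threads can deadlock.

Fix this by adding `SharedPtrNoGIL`/`UniquePtrNoGIL` smart-pointer
wrappers to `arrow/python/common.h` that release the GIL (when it is
held and the runtime is initialized) while resetting or destroying the
managed object, and by using them for the Cython-held reader, scanner,
writer and iterator pointers in `_csv.pyx`, `_dataset.pxd`,
`_dataset.pyx`, `_parquet.pyx`, `ipc.pxi` and `lib.pxd`.
`PyReleaseGIL` now stores the saved thread state in a
`std::unique_ptr` with a custom deleter, so it is movable and can be
returned through `std::optional`.

A minimal sketch of the new behaviour (the `BlockingResource` type is
illustrative, not part of this patch):

    #include "arrow/python/common.h"

    // Illustrative type: its destructor joins an I/O thread that may
    // need the GIL to call back into a Python file object.
    struct BlockingResource {
      ~BlockingResource() { /* join I/O thread */ }
    };

    void Demo() {
      arrow::py::SharedPtrNoGIL<BlockingResource> ptr{new BlockingResource()};
      // reset() (and destruction) release the GIL, if held, around the
      // underlying shared_ptr operation, so the destructor above can
      // make progress instead of deadlocking.
      ptr.reset();
    }

Also add a regression test that repeatedly reads an invalid CSV file
from a slow Python file object.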
---
 python/pyarrow/_csv.pyx                     |  5 +-
 python/pyarrow/_dataset.pxd                 |  8 +--
 python/pyarrow/_dataset.pyx                 |  4 +-
 python/pyarrow/_parquet.pyx                 |  6 +--
 python/pyarrow/includes/libarrow_python.pxd |  7 +++
 python/pyarrow/ipc.pxi                      |  2 +-
 python/pyarrow/lib.pxd                      |  4 +-
 python/pyarrow/src/arrow/python/common.h    | 55 +++++++++++++++++++--
 python/pyarrow/tests/test_csv.py            | 21 ++++++++
 9 files changed, 92 insertions(+), 20 deletions(-)

diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index e532d8d8ab22a..508488c0c3b3c 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -26,8 +26,7 @@ from collections.abc import Mapping
 
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
-from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler,
-                                               PyInvalidRowCallback)
+from pyarrow.includes.libarrow_python cimport *
 from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
                           RecordBatchReader, ensure_type,
                           maybe_unbox_memory_pool, get_input_stream,
@@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None, parse_options=None,
         CCSVParseOptions c_parse_options
         CCSVConvertOptions c_convert_options
         CIOContext io_context
-        shared_ptr[CCSVReader] reader
+        SharedPtrNoGIL[CCSVReader] reader
         shared_ptr[CTable] table
 
     _get_reader(input_file, read_options, &stream)
diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd
index 210e5558009ec..bee9fc1f0987a 100644
--- a/python/pyarrow/_dataset.pxd
+++ b/python/pyarrow/_dataset.pxd
@@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=*)
 cdef class DatasetFactory(_Weakrefable):
 
     cdef:
-        shared_ptr[CDatasetFactory] wrapped
+        SharedPtrNoGIL[CDatasetFactory] wrapped
         CDatasetFactory* factory
 
     cdef init(self, const shared_ptr[CDatasetFactory]& sp)
@@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable):
 cdef class Dataset(_Weakrefable):
 
     cdef:
-        shared_ptr[CDataset] wrapped
+        SharedPtrNoGIL[CDataset] wrapped
         CDataset* dataset
         public dict _scan_options
 
@@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable):
 
 cdef class Scanner(_Weakrefable):
     cdef:
-        shared_ptr[CScanner] wrapped
+        SharedPtrNoGIL[CScanner] wrapped
         CScanner* scanner
 
     cdef void init(self, const shared_ptr[CScanner]& sp)
@@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable):
 cdef class Fragment(_Weakrefable):
 
     cdef:
-        shared_ptr[CFragment] wrapped
+        SharedPtrNoGIL[CFragment] wrapped
         CFragment* fragment
 
     cdef void init(self, const shared_ptr[CFragment]& sp)
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 48ee676915311..d7d69965d000a 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable):
         object iterator_owner
         # Iterator is a non-POD type and Cython uses offsetof, leading
         # to a compiler warning unless wrapped like so
-        shared_ptr[CRecordBatchIterator] iterator
+        SharedPtrNoGIL[CRecordBatchIterator] iterator
 
     def __init__(self):
         _forbid_instantiation(self.__class__, subclasses_instead=False)
@@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable):
     """An iterator over a sequence of record batches with fragments."""
     cdef:
         object iterator_owner
-        shared_ptr[CTaggedRecordBatchIterator] iterator
+        SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator
 
     def __init__(self):
         _forbid_instantiation(self.__class__, subclasses_instead=False)
diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx
index 48091367b2ff8..089ed7c75ce58 100644
--- a/python/pyarrow/_parquet.pyx
+++ b/python/pyarrow/_parquet.pyx
@@ -24,6 +24,7 @@ import warnings
 from cython.operator cimport dereference as deref
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_python cimport *
 from pyarrow.lib cimport (_Weakrefable, Buffer, Schema,
                           check_status,
                           MemoryPool, maybe_unbox_memory_pool,
@@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable):
     cdef:
         object source
         CMemoryPool* pool
-        unique_ptr[FileReader] reader
+        UniquePtrNoGIL[FileReader] reader
         FileMetaData _metadata
         shared_ptr[CRandomAccessFile] rd_handle
 
@@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable):
             vector[int] c_row_groups
             vector[int] c_column_indices
             shared_ptr[CRecordBatch] record_batch
-            unique_ptr[CRecordBatchReader] recordbatchreader
+            UniquePtrNoGIL[CRecordBatchReader] recordbatchreader
 
         self.set_batch_size(batch_size)
 
@@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable):
                 check_status(
                     recordbatchreader.get().ReadNext(&record_batch)
                 )
-
             if record_batch.get() == NULL:
                 break
 
diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd
index 4d109fc660e08..a1a52012344e9 100644
--- a/python/pyarrow/includes/libarrow_python.pxd
+++ b/python/pyarrow/includes/libarrow_python.pxd
@@ -261,6 +261,13 @@ cdef extern from "arrow/python/common.h" namespace "arrow::py":
     void RestorePyError(const CStatus& status) except *
 
 
+cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
+    cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]):
+        pass
+    cdef cppclass UniquePtrNoGIL[T, DELETER=*](unique_ptr[T, DELETER]):
+        pass
+
+
 cdef extern from "arrow/python/inference.h" namespace "arrow::py":
     c_bool IsPyBool(object o)
     c_bool IsPyInt(object o)
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index fcb9eb729ef04..5d20a4f8b72cb 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -977,7 +977,7 @@ cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):
 
 cdef class _RecordBatchFileReader(_Weakrefable):
     cdef:
-        shared_ptr[CRecordBatchFileReader] reader
+        SharedPtrNoGIL[CRecordBatchFileReader] reader
         shared_ptr[CRandomAccessFile] file
         CIpcReadOptions options
 
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 63ebe6aea8233..ae197eca1ca6b 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile):
 
 cdef class _CRecordBatchWriter(_Weakrefable):
     cdef:
-        shared_ptr[CRecordBatchWriter] writer
+        SharedPtrNoGIL[CRecordBatchWriter] writer
 
 
 cdef class RecordBatchReader(_Weakrefable):
     cdef:
-        shared_ptr[CRecordBatchReader] reader
+        SharedPtrNoGIL[CRecordBatchReader] reader
 
 
 cdef class Codec(_Weakrefable):
diff --git a/python/pyarrow/src/arrow/python/common.h b/python/pyarrow/src/arrow/python/common.h
index e36c0834fd424..1dd3ce8435811 100644
--- a/python/pyarrow/src/arrow/python/common.h
+++ b/python/pyarrow/src/arrow/python/common.h
@@ -19,6 +19,7 @@
 
 #include <functional>
 #include <memory>
+#include <optional>
 #include <utility>
 
 #include "arrow/buffer.h"
@@ -134,13 +135,15 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL {
 // A RAII-style helper that releases the GIL until the end of a lexical block
 class ARROW_PYTHON_EXPORT PyReleaseGIL {
  public:
-  PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); }
-
-  ~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); }
+  PyReleaseGIL() : ptr_(PyEval_SaveThread(), &unique_ptr_deleter) {}
 
  private:
-  PyThreadState* saved_state_;
-  ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL);
+  static void unique_ptr_deleter(PyThreadState* state) {
+    if (state) {
+      PyEval_RestoreThread(state);
+    }
+  }
+  std::unique_ptr<PyThreadState, decltype(&unique_ptr_deleter)> ptr_;
 };
 
 // A helper to call safely into the Python interpreter from arbitrary C++ code.
@@ -235,6 +238,48 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
   }
 };
 
+template <template <typename...> typename SmartPtr, typename... Ts>
+class SmartPtrNoGIL : public SmartPtr<Ts...> {
+  using Base = SmartPtr<Ts...>;
+
+ public:
+  template <typename... Args>
+  SmartPtrNoGIL(Args&&... args) : Base(std::forward<Args>(args)...) {}
+
+  ~SmartPtrNoGIL() { reset(); }
+
+  template <typename... Args>
+  void reset(Args&&... args) {
+    auto release_guard = optional_gil_release();
+    Base::reset(std::forward<Args>(args)...);
+  }
+
+  template <typename V>
+  SmartPtrNoGIL& operator=(V&& v) {
+    auto release_guard = optional_gil_release();
+    Base::operator=(std::forward<V>(v));
+    return *this;
+  }
+
+ private:
+  // Only release the GIL if we own an object *and* the Python runtime is
+  // valid *and* the GIL is held.
+  std::optional<PyReleaseGIL> optional_gil_release() const {
+    if (this->get() != nullptr && Py_IsInitialized() && PyGILState_Check()) {
+      return PyReleaseGIL();
+    }
+    return {};
+  }
+};
+
+/// \brief A std::shared_ptr<T, ...> subclass that releases the GIL when destroying T
+template <typename... Ts>
+using SharedPtrNoGIL = SmartPtrNoGIL<std::shared_ptr, Ts...>;
+
+/// \brief A std::unique_ptr<T, ...> subclass that releases the GIL when destroying T
+template <typename... Ts>
+using UniquePtrNoGIL = SmartPtrNoGIL<std::unique_ptr, Ts...>;
+
 template <typename Fn>
 struct BoundFunction;
 
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index afc5380b75516..31f24187e3b37 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -1970,3 +1970,24 @@ def test_write_csv_decimal(tmpdir, type_factory):
     out = read_csv(tmpdir / "out.csv")
 
     assert out.column('col').cast(type) == table.column('col')
+
+
+def test_read_csv_gil_deadlock():
+    # GH-38676
+    # This test depends on several preconditions:
+    # - the CSV input is a Python file object
+    # - reading the CSV file produces an error
+    data = b"a,b,c"
+
+    class MyBytesIO(io.BytesIO):
+        def read(self, *args):
+            time.sleep(0.001)
+            return super().read(*args)
+
+        def readinto(self, *args):
+            time.sleep(0.001)
+            return super().readinto(*args)
+
+    for i in range(20):
+        with pytest.raises(pa.ArrowInvalid):
+            read_csv(MyBytesIO(data))