Skip to content

Commit

Permalink
apacheGH-38676: [Python] Fix potential deadlock when CSV reading erro…
Browse files Browse the repository at this point in the history
…rs out
  • Loading branch information
pitrou committed Nov 14, 2023
1 parent 160d45c commit 6e89d22
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 20 deletions.
5 changes: 2 additions & 3 deletions python/pyarrow/_csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ from collections.abc import Mapping

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler,
PyInvalidRowCallback)
from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
RecordBatchReader, ensure_type,
maybe_unbox_memory_pool, get_input_stream,
Expand Down Expand Up @@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None, parse_options=None,
CCSVParseOptions c_parse_options
CCSVConvertOptions c_convert_options
CIOContext io_context
shared_ptr[CCSVReader] reader
SharedPtrNoGIL[CCSVReader] reader
shared_ptr[CTable] table

_get_reader(input_file, read_options, &stream)
Expand Down
8 changes: 4 additions & 4 deletions python/pyarrow/_dataset.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=*)
cdef class DatasetFactory(_Weakrefable):

cdef:
shared_ptr[CDatasetFactory] wrapped
SharedPtrNoGIL[CDatasetFactory] wrapped
CDatasetFactory* factory

cdef init(self, const shared_ptr[CDatasetFactory]& sp)
Expand All @@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable):
cdef class Dataset(_Weakrefable):

cdef:
shared_ptr[CDataset] wrapped
SharedPtrNoGIL[CDataset] wrapped
CDataset* dataset
public dict _scan_options

Expand All @@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable):

cdef class Scanner(_Weakrefable):
cdef:
shared_ptr[CScanner] wrapped
SharedPtrNoGIL[CScanner] wrapped
CScanner* scanner

cdef void init(self, const shared_ptr[CScanner]& sp)
Expand Down Expand Up @@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable):
cdef class Fragment(_Weakrefable):

cdef:
shared_ptr[CFragment] wrapped
SharedPtrNoGIL[CFragment] wrapped
CFragment* fragment

cdef void init(self, const shared_ptr[CFragment]& sp)
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable):
object iterator_owner
# Iterator is a non-POD type and Cython uses offsetof, leading
# to a compiler warning unless wrapped like so
shared_ptr[CRecordBatchIterator] iterator
SharedPtrNoGIL[CRecordBatchIterator] iterator

def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
Expand Down Expand Up @@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable):
"""An iterator over a sequence of record batches with fragments."""
cdef:
object iterator_owner
shared_ptr[CTaggedRecordBatchIterator] iterator
SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator

def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
Expand Down
6 changes: 3 additions & 3 deletions python/pyarrow/_parquet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import warnings
from cython.operator cimport dereference as deref
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (_Weakrefable, Buffer, Schema,
check_status,
MemoryPool, maybe_unbox_memory_pool,
Expand Down Expand Up @@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable):
cdef:
object source
CMemoryPool* pool
unique_ptr[FileReader] reader
UniquePtrNoGIL[FileReader] reader
FileMetaData _metadata
shared_ptr[CRandomAccessFile] rd_handle

Expand Down Expand Up @@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable):
vector[int] c_row_groups
vector[int] c_column_indices
shared_ptr[CRecordBatch] record_batch
unique_ptr[CRecordBatchReader] recordbatchreader
UniquePtrNoGIL[CRecordBatchReader] recordbatchreader

self.set_batch_size(batch_size)

Expand Down Expand Up @@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable):
check_status(
recordbatchreader.get().ReadNext(&record_batch)
)

if record_batch.get() == NULL:
break

Expand Down
7 changes: 7 additions & 0 deletions python/pyarrow/includes/libarrow_python.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ cdef extern from "arrow/python/common.h" namespace "arrow::py":
void RestorePyError(const CStatus& status) except *


cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]):
pass
cdef cppclass UniquePtrNoGIL[T, DELETER=*](unique_ptr[T, DELETER]):
pass


cdef extern from "arrow/python/inference.h" namespace "arrow::py":
c_bool IsPyBool(object o)
c_bool IsPyInt(object o)
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):

cdef class _RecordBatchFileReader(_Weakrefable):
cdef:
shared_ptr[CRecordBatchFileReader] reader
SharedPtrNoGIL[CRecordBatchFileReader] reader
shared_ptr[CRandomAccessFile] file
CIpcReadOptions options

Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile):

cdef class _CRecordBatchWriter(_Weakrefable):
cdef:
shared_ptr[CRecordBatchWriter] writer
SharedPtrNoGIL[CRecordBatchWriter] writer


cdef class RecordBatchReader(_Weakrefable):
cdef:
shared_ptr[CRecordBatchReader] reader
SharedPtrNoGIL[CRecordBatchReader] reader


cdef class Codec(_Weakrefable):
Expand Down
55 changes: 50 additions & 5 deletions python/pyarrow/src/arrow/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <functional>
#include <memory>
#include <optional>
#include <utility>

#include "arrow/buffer.h"
Expand Down Expand Up @@ -134,13 +135,15 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL {
// A RAII-style helper that releases the GIL until the end of a lexical block
class ARROW_PYTHON_EXPORT PyReleaseGIL {
public:
PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); }

~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); }
PyReleaseGIL() : ptr_(PyEval_SaveThread(), &unique_ptr_deleter) {}

private:
PyThreadState* saved_state_;
ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL);
static void unique_ptr_deleter(PyThreadState* state) {
if (state) {
PyEval_RestoreThread(state);
}
}
std::unique_ptr<PyThreadState, decltype(&unique_ptr_deleter)> ptr_;
};

// A helper to call safely into the Python interpreter from arbitrary C++ code.
Expand Down Expand Up @@ -235,6 +238,48 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
}
};

template <template <typename...> typename SmartPtr, typename... Ts>
class SmartPtrNoGIL : public SmartPtr<Ts...> {
using Base = SmartPtr<Ts...>;

public:
template <typename... Args>
SmartPtrNoGIL(Args&&... args) : Base(std::forward<Args>(args)...) {}

~SmartPtrNoGIL() { reset(); }

template <typename... Args>
void reset(Args&&... args) {
auto release_guard = optional_gil_release();
Base::reset(std::forward<Args>(args)...);
}

template <typename V>
SmartPtrNoGIL& operator=(V&& v) {
auto release_guard = optional_gil_release();
Base::operator=(std::forward<V>(v));
return *this;
}

private:
// Only release the GIL if we own an object *and* the Python runtime is
// valid *and* the GIL is held.
std::optional<PyReleaseGIL> optional_gil_release() const {
if (this->get() != nullptr && Py_IsInitialized() && PyGILState_Check()) {
return PyReleaseGIL();
}
return {};
}
};

/// \brief A std::shared_ptr<T, ...> subclass that releases the GIL when destroying T
template <typename... Ts>
using SharedPtrNoGIL = SmartPtrNoGIL<std::shared_ptr, Ts...>;

/// \brief A std::unique_ptr<T, ...> subclass that releases the GIL when destroying T
template <typename... Ts>
using UniquePtrNoGIL = SmartPtrNoGIL<std::unique_ptr, Ts...>;

template <typename Fn>
struct BoundFunction;

Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,3 +1970,24 @@ def test_write_csv_decimal(tmpdir, type_factory):
out = read_csv(tmpdir / "out.csv")

assert out.column('col').cast(type) == table.column('col')


def test_read_csv_gil_deadlock():
# GH-38676
# This test depends on several preconditions:
# - the CSV input is a Python file object
# - reading the CSV file produces an error
data = b"a,b,c"

class MyBytesIO(io.BytesIO):
def read(self, *args):
time.sleep(0.001)
return super().read(*args)

def readinto(self, *args):
time.sleep(0.001)
return super().readinto(*args)

for i in range(20):
with pytest.raises(pa.ArrowInvalid):
read_csv(MyBytesIO(data))

0 comments on commit 6e89d22

Please sign in to comment.