Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38676: [Python] Fix potential deadlock when CSV reading errors out #38713

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/pyarrow/_csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ from collections.abc import Mapping

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler,
PyInvalidRowCallback)
from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
RecordBatchReader, ensure_type,
maybe_unbox_memory_pool, get_input_stream,
Expand Down Expand Up @@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None, parse_options=None,
CCSVParseOptions c_parse_options
CCSVConvertOptions c_convert_options
CIOContext io_context
shared_ptr[CCSVReader] reader
SharedPtrNoGIL[CCSVReader] reader
shared_ptr[CTable] table

_get_reader(input_file, read_options, &stream)
Expand Down
8 changes: 4 additions & 4 deletions python/pyarrow/_dataset.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=*)
cdef class DatasetFactory(_Weakrefable):

cdef:
shared_ptr[CDatasetFactory] wrapped
SharedPtrNoGIL[CDatasetFactory] wrapped
CDatasetFactory* factory

cdef init(self, const shared_ptr[CDatasetFactory]& sp)
Expand All @@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable):
cdef class Dataset(_Weakrefable):

cdef:
shared_ptr[CDataset] wrapped
SharedPtrNoGIL[CDataset] wrapped
CDataset* dataset
public dict _scan_options

Expand All @@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable):

cdef class Scanner(_Weakrefable):
cdef:
shared_ptr[CScanner] wrapped
SharedPtrNoGIL[CScanner] wrapped
CScanner* scanner

cdef void init(self, const shared_ptr[CScanner]& sp)
Expand Down Expand Up @@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable):
cdef class Fragment(_Weakrefable):

cdef:
shared_ptr[CFragment] wrapped
SharedPtrNoGIL[CFragment] wrapped
CFragment* fragment

cdef void init(self, const shared_ptr[CFragment]& sp)
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable):
object iterator_owner
# Iterator is a non-POD type and Cython uses offsetof, leading
# to a compiler warning unless wrapped like so
shared_ptr[CRecordBatchIterator] iterator
SharedPtrNoGIL[CRecordBatchIterator] iterator

def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
Expand Down Expand Up @@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable):
"""An iterator over a sequence of record batches with fragments."""
cdef:
object iterator_owner
shared_ptr[CTaggedRecordBatchIterator] iterator
SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator

def __init__(self):
_forbid_instantiation(self.__class__, subclasses_instead=False)
Expand Down
6 changes: 3 additions & 3 deletions python/pyarrow/_parquet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import warnings
from cython.operator cimport dereference as deref
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_python cimport *
from pyarrow.lib cimport (_Weakrefable, Buffer, Schema,
check_status,
MemoryPool, maybe_unbox_memory_pool,
Expand Down Expand Up @@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable):
cdef:
object source
CMemoryPool* pool
unique_ptr[FileReader] reader
UniquePtrNoGIL[FileReader] reader
FileMetaData _metadata
shared_ptr[CRandomAccessFile] rd_handle

Expand Down Expand Up @@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable):
vector[int] c_row_groups
vector[int] c_column_indices
shared_ptr[CRecordBatch] record_batch
unique_ptr[CRecordBatchReader] recordbatchreader
UniquePtrNoGIL[CRecordBatchReader] recordbatchreader

self.set_batch_size(batch_size)

Expand Down Expand Up @@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable):
check_status(
recordbatchreader.get().ReadNext(&record_batch)
)

if record_batch.get() == NULL:
break

Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow_python.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ cdef extern from "arrow/python/common.h" namespace "arrow::py":
void RestorePyError(const CStatus& status) except *


cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
    # Smart-pointer wrappers (defined in common.h) that release the GIL
    # while destroying/resetting the pointee, so that a blocking C++
    # destructor cannot deadlock against a thread waiting for the GIL.
    cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]):
        # Declaring operator= with "..." looks like the only way to satisfy
        # both Cython 2 and Cython 3
        SharedPtrNoGIL& operator=(...)
    cdef cppclass UniquePtrNoGIL[T, DELETER=*](unique_ptr[T, DELETER]):
        UniquePtrNoGIL& operator=(...)

cdef extern from "arrow/python/inference.h" namespace "arrow::py":
c_bool IsPyBool(object o)
c_bool IsPyInt(object o)
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c):

cdef class _RecordBatchFileReader(_Weakrefable):
cdef:
shared_ptr[CRecordBatchFileReader] reader
SharedPtrNoGIL[CRecordBatchFileReader] reader
shared_ptr[CRandomAccessFile] file
CIpcReadOptions options

Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile):

cdef class _CRecordBatchWriter(_Weakrefable):
cdef:
shared_ptr[CRecordBatchWriter] writer
SharedPtrNoGIL[CRecordBatchWriter] writer


cdef class RecordBatchReader(_Weakrefable):
cdef:
shared_ptr[CRecordBatchReader] reader
SharedPtrNoGIL[CRecordBatchReader] reader


cdef class Codec(_Weakrefable):
Expand Down
55 changes: 50 additions & 5 deletions python/pyarrow/src/arrow/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <functional>
#include <memory>
#include <optional>
#include <utility>

#include "arrow/buffer.h"
Expand Down Expand Up @@ -134,13 +135,15 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL {
// A RAII-style helper that releases the GIL until the end of a lexical block
class ARROW_PYTHON_EXPORT PyReleaseGIL {
public:
PyReleaseGIL() { saved_state_ = PyEval_SaveThread(); }

~PyReleaseGIL() { PyEval_RestoreThread(saved_state_); }
PyReleaseGIL() : ptr_(PyEval_SaveThread(), &unique_ptr_deleter) {}

private:
PyThreadState* saved_state_;
ARROW_DISALLOW_COPY_AND_ASSIGN(PyReleaseGIL);
static void unique_ptr_deleter(PyThreadState* state) {
if (state) {
PyEval_RestoreThread(state);
}
}
std::unique_ptr<PyThreadState, decltype(&unique_ptr_deleter)> ptr_;
};

// A helper to call safely into the Python interpreter from arbitrary C++ code.
Expand Down Expand Up @@ -235,6 +238,48 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef {
}
};

// A smart-pointer wrapper that releases the GIL around destruction,
// reset() and assignment of the owned object.  This avoids a deadlock
// (GH-38676) when the last reference to an object is dropped while
// holding the GIL and the object's destructor blocks on a C++ thread
// that itself needs the GIL (e.g. a reader draining a Python file object).
template <template <typename...> typename SmartPtr, typename... Ts>
class SmartPtrNoGIL : public SmartPtr<Ts...> {
  using Base = SmartPtr<Ts...>;

 public:
  // Forward any constructor arguments to the underlying smart pointer.
  template <typename... Args>
  SmartPtrNoGIL(Args&&... args) : Base(std::forward<Args>(args)...) {}

  // Route destruction through reset() so the GIL is released while the
  // pointee is destroyed.
  ~SmartPtrNoGIL() { reset(); }

  template <typename... Args>
  void reset(Args&&... args) {
    auto release_guard = optional_gil_release();
    Base::reset(std::forward<Args>(args)...);
  }

  // Assignment may destroy the previously-owned object, so it also
  // releases the GIL around the base assignment.
  template <typename V>
  SmartPtrNoGIL& operator=(V&& v) {
    auto release_guard = optional_gil_release();
    Base::operator=(std::forward<V>(v));
    return *this;
  }

 private:
  // Only release the GIL if we own an object *and* the Python runtime is
  // valid *and* the GIL is held.
  std::optional<PyReleaseGIL> optional_gil_release() const {
    if (this->get() != nullptr && Py_IsInitialized() && PyGILState_Check()) {
      return PyReleaseGIL();
    }
    return {};
  }
};

/// \brief A std::shared_ptr<T> subclass that releases the GIL when destroying T
template <typename... Ts>
using SharedPtrNoGIL = SmartPtrNoGIL<std::shared_ptr, Ts...>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the no GIL smart pointer wraps an object which must call into python as part of its destruction? (For example a dataset which wraps a python file system.) I think that this will be fine since such a call would acquire the GIL at the call site

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this will be fine since such a call would acquire the GIL at the call site

Yes, it would.


/// \brief A std::unique_ptr<T, ...> subclass that releases the GIL when destroying T
template <typename... Ts>
using UniquePtrNoGIL = SmartPtrNoGIL<std::unique_ptr, Ts...>;

template <typename Fn>
struct BoundFunction;

Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,3 +1970,24 @@ def test_write_csv_decimal(tmpdir, type_factory):
out = read_csv(tmpdir / "out.csv")

assert out.column('col').cast(type) == table.column('col')


def test_read_csv_gil_deadlock():
    # Regression test for GH-38676.
    # The deadlock only reproduces when both of these hold:
    # - the CSV input is a Python file object
    # - reading the CSV file produces an error
    payload = b"a,b,c"

    class SlowBytesIO(io.BytesIO):
        # The short sleeps widen the window in which the reader thread
        # is blocked inside a Python-level read while an error is raised.
        def read(self, *args):
            time.sleep(0.001)
            return super().read(*args)

        def readinto(self, *args):
            time.sleep(0.001)
            return super().readinto(*args)

    # Repeat a number of times to make the race more likely to trigger.
    for _ in range(20):
        with pytest.raises(pa.ArrowInvalid):
            read_csv(SlowBytesIO(payload))
Loading