Skip to content

Commit

Permalink
Acquire GIL in pywrapper.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ShikharJ committed Jun 14, 2017
1 parent bf388b5 commit ca9d021
Showing 1 changed file with 78 additions and 1 deletion.
79 changes: 78 additions & 1 deletion symengine/lib/pywrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

namespace SymEngine {

int PyGILState_Check2(void) {
PyThreadState * tstate = _PyThreadState_Current;
return tstate && (tstate == PyGILState_GetThisThreadState());
}

// PyModule
PyModule::PyModule(PyObject* (*to_py)(const RCP<const Basic>), RCP<const Basic> (*from_py)(PyObject*),
RCP<const Number> (*eval)(PyObject*, long), RCP<const Basic> (*diff)(PyObject*, RCP<const Basic>)) :
Expand All @@ -19,7 +24,7 @@ PyModule::PyModule(PyObject* (*to_py)(const RCP<const Basic>), RCP<const Basic>
PyModule::~PyModule(){
Py_DECREF(zero);
Py_DECREF(one);
Py_DECREF(minus_one);
Py_DECREF(minus_one);
}

// PyNumber
Expand All @@ -37,8 +42,14 @@ bool PyNumber::__eq__(const Basic &o) const {
}

int PyNumber::compare(const Basic &o) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
SYMENGINE_ASSERT(is_a<PyNumber>(o))
PyObject* o1 = static_cast<const PyNumber &>(o).get_py_object();
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
if (PyObject_RichCompareBool(pyobject_, o1, Py_EQ) == 1)
return 0;
return PyObject_RichCompareBool(pyobject_, o1, Py_LT) == 1 ? -1 : 1;
Expand Down Expand Up @@ -70,6 +81,9 @@ bool PyNumber::is_complex() const {

//! Addition
RCP<const Number> PyNumber::add(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -79,10 +93,16 @@ RCP<const Number> PyNumber::add(const Number &other) const {
result = PyNumber_Add(pyobject_, other_p);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
//! Subtraction
RCP<const Number> PyNumber::sub(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -92,9 +112,15 @@ RCP<const Number> PyNumber::sub(const Number &other) const {
result = PyNumber_Subtract(pyobject_, other_p);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
RCP<const Number> PyNumber::rsub(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -104,10 +130,16 @@ RCP<const Number> PyNumber::rsub(const Number &other) const {
result = PyNumber_Subtract(other_p, pyobject_);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
//! Multiplication
RCP<const Number> PyNumber::mul(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -117,10 +149,16 @@ RCP<const Number> PyNumber::mul(const Number &other) const {
result = PyNumber_Multiply(pyobject_, other_p);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
//! Division
RCP<const Number> PyNumber::div(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -130,9 +168,15 @@ RCP<const Number> PyNumber::div(const Number &other) const {
result = PyNumber_Divide(pyobject_, other_p);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
RCP<const Number> PyNumber::rdiv(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -142,10 +186,16 @@ RCP<const Number> PyNumber::rdiv(const Number &other) const {
result = PyNumber_Divide(pyobject_, other_p);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
//! Power
RCP<const Number> PyNumber::pow(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -155,9 +205,15 @@ RCP<const Number> PyNumber::pow(const Number &other) const {
result = PyNumber_Power(pyobject_, other_p, Py_None);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}
RCP<const Number> PyNumber::rpow(const Number &other) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *other_p, *result;
if (is_a<PyNumber>(other)) {
other_p = static_cast<const PyNumber &>(other).pyobject_;
Expand All @@ -167,6 +223,9 @@ RCP<const Number> PyNumber::rpow(const Number &other) const {
result = PyNumber_Power(other_p, pyobject_, Py_None);
Py_XDECREF(other_p);
}
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return make_rcp<PyNumber>(result, pymodule_);
}

Expand All @@ -175,6 +234,9 @@ RCP<const Number> PyNumber::eval(long bits) const {
}

std::string PyNumber::__str__() const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject* temp;
std::string str;
#if PY_MAJOR_VERSION > 2
Expand All @@ -185,6 +247,9 @@ std::string PyNumber::__str__() const {
str = std::string(PyString_AsString(temp));
#endif
Py_XDECREF(temp);
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return str;
}

Expand All @@ -196,12 +261,18 @@ PyFunctionClass::PyFunctionClass(PyObject *pyobject, std::string name, const RCP
}

PyObject* PyFunctionClass::call(const vec_basic &vec) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject *tuple = PyTuple_New(vec.size());
for (unsigned i = 0; i < vec.size(); i++) {
PyTuple_SetItem(tuple, i, pymodule_->to_py_(vec[i]));
}
PyObject* result = PyObject_CallObject(pyobject_, tuple);
Py_DECREF(tuple);
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return result;
}

Expand Down Expand Up @@ -240,9 +311,15 @@ RCP<const PyFunctionClass> PyFunction::get_pyfunction_class() const {
}

RCP<const Basic> PyFunction::create(const vec_basic &x) const {
if (not PyGILState_Check2()){
PyEval_AcquireLock();
}
PyObject* pyobj = pyfunction_class_->call(x);
RCP<const Basic> result = pyfunction_class_->get_py_module()->from_py_(pyobj);
Py_XDECREF(pyobj);
if (PyGILState_Check2()){
PyEval_ReleaseLock();
}
return result;
}

Expand Down

0 comments on commit ca9d021

Please sign in to comment.