Skip to content

Commit

Permalink
Call PySys_SetArgv when initializing interpreter.
Browse files Browse the repository at this point in the history
  • Loading branch information
drmoose committed Jul 29, 2020
1 parent 1491c94 commit 1801811
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 5 deletions.
104 changes: 99 additions & 5 deletions include/pybind11/embed.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,95 @@ struct embedded_module {
}
};

/// Python 2.x/3.x-compatible version of `PySys_SetArgv`
inline void set_interpreter_argv(int argc, char** argv, bool add_current_dir_to_path) {
// Before it was special-cased in python 3.8, passing an empty or null argv
// caused a segfault, so we have to reimplement the special case ourselves.
char** safe_argv = argv;
if (nullptr == argv || argc <= 0) {
safe_argv = new char*[1];
if (nullptr == safe_argv) return;
safe_argv[0] = new char[1];
if (nullptr == safe_argv[0]) {
delete[] safe_argv;
return;
}
safe_argv[0][0] = '\0';
argc = 1;
}
#if PY_MAJOR_VERSION >= 3
// SetArgv* on python 3 takes wchar_t, so we have to convert.
wchar_t** widened_argv = new wchar_t*[static_cast<unsigned>(argc)];
for (int ii = 0; ii < argc; ++ii) {
# if PY_MINOR_VERSION >= 5
// From Python 3.5 onwards, we're supposed to use Py_DecodeLocale to
// generate the wchar_t version of argv.
widened_argv[ii] = Py_DecodeLocale(safe_argv[ii], nullptr);
# define FREE_WIDENED_ARG(X) PyMem_RawFree(X)
# else
// Before Python 3.5, we're stuck with mbstowcs, which may or may not
// actually work. Mercifully, pyconfig.h provides this define:
# ifdef HAVE_BROKEN_MBSTOWCS
size_t count = strlen(safe_argv[ii]);
# else
size_t count = mbstowcs(nullptr, safe_argv[ii], 0);
# endif
widened_argv[ii] = nullptr;
if (count != static_cast<size_t>(-1)) {
widened_argv[ii] = new wchar_t[count + 1];
mbstowcs(widened_argv[ii], safe_argv[ii], count + 1);
}
# define FREE_WIDENED_ARG(X) delete[] X
# endif
if (nullptr == widened_argv[ii]) {
// Either we ran out of memory or had a unicode encoding issue.
// Free what we've encoded so far and bail.
for (--ii; ii >= 0; --ii)
FREE_WIDENED_ARG(widened_argv[ii]);
return;
}
}

# if PY_MINOR_VERSION < 1 || (PY_MINOR_VERSION == 1 && PY_MICRO_VERSION < 3)
# define NEED_PYRUN_TO_SANITIZE_PATH 1
// don't have SetArgvEx yet
PySys_SetArgv(argc, widened_argv);
# else
PySys_SetArgvEx(argc, widened_argv, add_current_dir_to_path ? 1 : 0);
# endif

// PySys_SetArgv makes new PyUnicode objects so we can clean up this memory
if (nullptr != widened_argv) {
for (int ii = 0; ii < argc; ++ii)
if (nullptr != widened_argv[ii])
FREE_WIDENED_ARG(widened_argv[ii]);
delete[] widened_argv;
}
# undef FREE_WIDENED_ARG
#else
// python 2.x
# if PY_MINOR_VERSION < 6 || (PY_MINOR_VERSION == 6 && PY_MICRO_VERSION < 6)
# define NEED_PYRUN_TO_SANITIZE_PATH 1
// don't have SetArgvEx yet
PySys_SetArgv(argc, safe_argv);
# else
PySys_SetArgvEx(argc, safe_argv, add_current_dir_to_path ? 1 : 0);
# endif
#endif

#ifdef NEED_PYRUN_TO_SANITIZE_PATH
# undef NEED_PYRUN_TO_SANITIZE_PATH
if (!add_current_dir_to_path)
PyRun_SimpleString("import sys; sys.path.pop(0)\n");
#endif

// if we allocated new memory to make safe_argv, we need to free it
if (safe_argv != argv) {
delete[] safe_argv[0];
delete[] safe_argv;
}
}

PYBIND11_NAMESPACE_END(detail)

/** \rst
Expand All @@ -102,14 +191,16 @@ PYBIND11_NAMESPACE_END(detail)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
inline void initialize_interpreter(bool init_signal_handlers = true,
int argc = 0,
char** argv = nullptr,
bool add_current_dir_to_path = true) {
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");

Py_InitializeEx(init_signal_handlers ? 1 : 0);

// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
detail::set_interpreter_argv(argc, argv, add_current_dir_to_path);
}

/** \rst
Expand Down Expand Up @@ -182,8 +273,11 @@ inline void finalize_interpreter() {
\endrst */
class scoped_interpreter {
public:
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
scoped_interpreter(bool init_signal_handlers = true,
int argc = 0,
char** argv = nullptr,
bool add_current_dir_to_path = true) {
initialize_interpreter(init_signal_handlers, argc, argv, add_current_dir_to_path);
}

scoped_interpreter(const scoped_interpreter &) = delete;
Expand Down
24 changes: 24 additions & 0 deletions tests/test_embed/test_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Widget {

std::string the_message() const { return message; }
virtual int the_answer() const = 0;
virtual std::string argv0() const = 0;

private:
std::string message;
Expand All @@ -31,6 +32,7 @@ class PyWidget final : public Widget {
using Widget::Widget;

int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
std::string argv0() const override { PYBIND11_OVERLOAD_PURE(std::string, Widget, argv0); }
};

PYBIND11_EMBEDDED_MODULE(widget_module, m) {
Expand Down Expand Up @@ -282,3 +284,25 @@ TEST_CASE("Reload module from file") {
result = module.attr("test")().cast<int>();
REQUIRE(result == 2);
}

TEST_CASE("sys.argv gets initialized properly") {
py::finalize_interpreter();
{
py::scoped_interpreter default_scope;
auto module = py::module::import("test_interpreter");
auto py_widget = module.attr("DerivedWidget")("The question");
const auto &cpp_widget = py_widget.cast<const Widget &>();
REQUIRE(cpp_widget.argv0() == "");
}

{
char* argv[] = { strdup("a.out") };
py::scoped_interpreter argv_scope(true, 1, argv);
free(argv[0]);
auto module = py::module::import("test_interpreter");
auto py_widget = module.attr("DerivedWidget")("The question");
const auto &cpp_widget = py_widget.cast<const Widget &>();
REQUIRE(cpp_widget.argv0() == "a.out");
}
py::initialize_interpreter();
}
4 changes: 4 additions & 0 deletions tests/test_embed/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from widget_module import Widget
import sys


class DerivedWidget(Widget):
Expand All @@ -8,3 +9,6 @@ def __init__(self, message):

def the_answer(self):
return 42

def argv0(self):
return sys.argv[0]

0 comments on commit 1801811

Please sign in to comment.