Skip to content

Commit

Permalink
perf: intern cached str objects (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored Dec 1, 2023
1 parent abb03ab commit 442fecc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
18 changes: 12 additions & 6 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,22 @@ inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/refe
constexpr bool NONE_IS_LEAF = true;
constexpr bool NONE_IS_NODE = false;

// NOTE: Use raw pointers to leak the memory intentionally to avoid py::object deallocation and
// garbage collection.
#define Py_Declare_ID(name) \
inline PyObject* Py_ID_##name() { \
static PyObject* ptr = (new py::str{#name})->ptr(); \
return ptr; \
#define Py_Declare_ID(name) \
inline PyObject* Py_ID_##name() { \
static PyObject* obj = []() { \
PyObject* ptr = PyUnicode_InternFromString(#name); \
if (ptr == nullptr) [[unlikely]] { \
throw py::error_already_set(); \
} \
Py_INCREF(ptr); /* leak a reference on purpose */ \
return ptr; \
}(); \
return obj; \
}

#define Py_Get_ID(name) (Py_ID_##name())

Py_Declare_ID(optree);
Py_Declare_ID(__module__); // type.__module__
Py_Declare_ID(__qualname__); // type.__qualname__
Py_Declare_ID(__name__); // type.__name__
Expand Down
2 changes: 1 addition & 1 deletion optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def __init_subclass__(cls) -> NoReturn:
raise TypeError("type 'structseq' is not an acceptable base type")

# pylint: disable-next=unused-argument,redefined-builtin
def __new__(cls: type[Self], sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
def __new__(cls, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
raise NotImplementedError


Expand Down
4 changes: 2 additions & 2 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
.value("DEQUE", PyTreeKind::Deque, "A collections.deque.")
.value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.");
reinterpret_cast<PyTypeObject*>(PyTreeKindTypeObject.ptr())->tp_name = "optree.PyTreeKind";
py::setattr(PyTreeKindTypeObject.ptr(), "__module__", py::str("optree"));
py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

auto PyTreeSpecTypeObject =
py::class_<PyTreeSpec>(mod, "PyTreeSpec", "Representing the structure of the pytree.");
reinterpret_cast<PyTypeObject*>(PyTreeSpecTypeObject.ptr())->tp_name = "optree.PyTreeSpec";
py::setattr(PyTreeSpecTypeObject.ptr(), "__module__", py::str("optree"));
py::setattr(PyTreeSpecTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree));

PyTreeSpecTypeObject
.def_property_readonly(
Expand Down

0 comments on commit 442fecc

Please sign in to comment.