Skip to content

Commit

Permalink
Pure pybind example of overload failure
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed Nov 10, 2023
1 parent 5303029 commit 77c2e56
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tmp/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("//tools/skylark:pybind.bzl", "pybind_py_library")

pybind_py_library(
name = "cc_py",
cc_so_name = "cc",
cc_srcs = ["cc_py.cc"],
py_imports = ["."],
)

py_test(
name = "cc_py_test",
srcs = ["cc_py_test.py"],
deps = [":cc_py"],
)
54 changes: 54 additions & 0 deletions tmp/cc_py.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "pybind11/eigen.h"
#include "pybind11/pybind11.h"

namespace py = pybind11;

namespace drake {

struct MyScalar {
int value{};
};

} // namespace drake

PYBIND11_NUMPY_OBJECT_DTYPE(drake::MyScalar);

namespace drake {
namespace {

std::string AcceptMatrixDense(const Eigen::MatrixXd&) {
return "dense";
}

std::string AcceptMatrixSparse(const Eigen::SparseMatrix<double>&) {
return "sparse";
}

using VectorXMyScalar = Eigen::Matrix<MyScalar, Eigen::Dynamic, 1>;

std::string AcceptMatrixAndObjectDense(
const Eigen::MatrixXd&, const VectorXMyScalar&) {
return "dense";
}

std::string AcceptMatrixAndObjectSparse(
const Eigen::SparseMatrix<double>&, const VectorXMyScalar&) {
return "sparse";
}

// Try to replicate signature as in #20516.

PYBIND11_MODULE(cc, m) {
m.def("AcceptMatrix", &AcceptMatrixDense);
m.def("AcceptMatrix", &AcceptMatrixSparse);

py::class_<MyScalar>(m, "MyScalar")
.def(py::init<int>(), py::arg("value"))
.def_readwrite("value", &MyScalar::value);

m.def("AcceptMatrixAndObject", &AcceptMatrixAndObjectSparse);
m.def("AcceptMatrixAndObject", &AcceptMatrixAndObjectDense);
}

} // namespace
} // namespace drake
23 changes: 23 additions & 0 deletions tmp/cc_py_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest

import numpy as np
import scipy.sparse

from drake.tmp.cc import AcceptMatrix, AcceptMatrixAndObject, MyScalar


class Test(unittest.TestCase):
def test_overload(self):
A_dense = np.eye(2)
A_sparse = scipy.sparse.csc_matrix(np.eye(2))
self.assertEqual(AcceptMatrix(A_dense), "dense")
self.assertEqual(AcceptMatrix(A_sparse), "sparse")

x = np.array([MyScalar(1), MyScalar(2)])
self.assertEqual(x.dtype, object)
self.assertEqual(AcceptMatrixAndObject(A_dense, x), "dense")
self.assertEqual(AcceptMatrixAndObject(A_sparse, x), "sparse")


if __name__ == "__main__":
unittest.main()

0 comments on commit 77c2e56

Please sign in to comment.