-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cpp_param: Add ability to map C++ template paramters to Python to ena…
…ble simple templating.
- Loading branch information
1 parent
94afd2f
commit 3c379c8
Showing
10 changed files
with
502 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from __future__ import absolute_import, print_function | ||
|
||
import ctypes | ||
import numpy as np | ||
|
||
""" | ||
@file | ||
Defines a mapping between Python and alias types, and provides canonical Python | ||
types as they relate to C++. | ||
""" | ||
|
||
|
||
def _get_type_name(t, verbose): | ||
# Gets type name as a string. | ||
# Defaults to just returning the name to shorten template names. | ||
if verbose and t.__module__ != "__builtin__": | ||
return t.__module__ + "." + t.__name__ | ||
else: | ||
return t.__name__ | ||
|
||
|
||
class _StrictMap(object): | ||
# Provides a map which may only add a key once. | ||
def __init__(self): | ||
self._values = dict() | ||
|
||
def add(self, key, value): | ||
assert key not in self._values, "Already added: {}".format(key) | ||
self._values[key] = value | ||
|
||
def get(self, key, default): | ||
return self._values.get(key, default) | ||
|
||
|
||
class _ParamAliases(object): | ||
# Registers aliases for a set of objects. This will be used for template | ||
# parameters. | ||
def __init__(self): | ||
self._to_canonical = _StrictMap() | ||
self._register_common() | ||
|
||
def _register_common(self): | ||
# Register common Python aliases relevant for C++. | ||
self.register(float, [np.double, ctypes.c_double]) | ||
self.register(np.float32, [ctypes.c_float]) | ||
self.register(int, [np.int32, ctypes.c_int32]) | ||
self.register(np.uint32, [ctypes.c_uint32]) | ||
self.register(np.int64, [ctypes.c_int64]) | ||
|
||
def register(self, canonical, aliases): | ||
# Registers a set of aliases to a canonical value. | ||
for alias in aliases: | ||
self._to_canonical.add(alias, canonical) | ||
|
||
def is_aliased(self, alias): | ||
# Determines if a parameter is aliased / registered. | ||
return self._to_canonical.get(alias, None) is not None | ||
|
||
def get_canonical(self, alias): | ||
# Gets registered canonical parameter if it is aliased; otherwise | ||
# return the same parameter. | ||
return self._to_canonical.get(alias, alias) | ||
|
||
def get_name(self, alias): | ||
# Gets string for an alias. | ||
canonical = self.get_canonical(alias) | ||
if isinstance(canonical, type): | ||
return _get_type_name(canonical, False) | ||
else: | ||
# For literals. | ||
return str(canonical) | ||
|
||
|
||
# Create singleton instance. | ||
_param_aliases = _ParamAliases() | ||
|
||
|
||
def get_param_canonical(param): | ||
"""Gets the canonical types for a set of Python types (canonical as in | ||
how they relate to C++ types. """ | ||
return tuple(map(_param_aliases.get_canonical, param)) | ||
|
||
|
||
def get_param_names(param): | ||
"""Gets the canonical type names for a set of Python types (canonical as in | ||
how they relate to C++ types. """ | ||
return tuple(map(_param_aliases.get_name, param)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#include "drake/bindings/pydrake/util/cpp_param_pybind.h" | ||
|
||
#include <pybind11/eval.h> | ||
|
||
namespace drake { | ||
namespace pydrake { | ||
namespace internal { | ||
namespace { | ||
|
||
// Creates a Python object that should uniquely hash for a primitive C++ | ||
// type. | ||
py::object GetPyHash(const std::type_info& tinfo) { | ||
return py::make_tuple("cpp_type", tinfo.hash_code()); | ||
} | ||
|
||
// Registers C++ type. | ||
template <typename T> | ||
void RegisterType( | ||
py::module m, py::object param_aliases, const std::string& canonical_str) { | ||
// Create an object that is a unique hash. | ||
py::object canonical = py::eval(canonical_str, m.attr("__dict__")); | ||
py::list aliases(1); | ||
aliases[0] = GetPyHash(typeid(T)); | ||
param_aliases.attr("register")(canonical, aliases); | ||
} | ||
|
||
// Registers common C++ types. | ||
void RegisterCommon(py::module m, py::object param_aliases) { | ||
// Make mappings for C++ RTTI to Python types. | ||
// Unfortunately, this is hard to obtain from `pybind11`. | ||
RegisterType<bool>(m, param_aliases, "bool"); | ||
RegisterType<std::string>(m, param_aliases, "str"); | ||
RegisterType<double>(m, param_aliases, "float"); | ||
RegisterType<float>(m, param_aliases, "np.float32"); | ||
RegisterType<int>(m, param_aliases, "int"); | ||
RegisterType<uint32_t>(m, param_aliases, "np.uint32"); | ||
RegisterType<int64_t>(m, param_aliases, "np.int64"); | ||
// For supporting generic Python types. | ||
RegisterType<py::object>(m, param_aliases, "object"); | ||
} | ||
|
||
} // namespace | ||
|
||
py::object GetParamAliases() { | ||
py::module m = py::module::import("pydrake.util.cpp_param"); | ||
py::object param_aliases = m.attr("_param_aliases"); | ||
const char registered_check[] = "_register_common_cpp"; | ||
if (!py::hasattr(m, registered_check)) { | ||
RegisterCommon(m, param_aliases); | ||
m.attr(registered_check) = true; | ||
} | ||
return param_aliases; | ||
} | ||
|
||
py::object GetPyParamScalarImpl(const std::type_info& tinfo) { | ||
py::object param_aliases = GetParamAliases(); | ||
py::object py_hash = GetPyHash(tinfo); | ||
if (param_aliases.attr("is_aliased")(py_hash).cast<bool>()) { | ||
// If it's an alias, return the canonical type. | ||
return param_aliases.attr("get_canonical")(py_hash); | ||
} else { | ||
// This type is not aliased. Get the pybind-registered type, | ||
// erroring out if it's not registered. | ||
// WARNING: Internal API :( | ||
auto* info = py::detail::get_type_info(tinfo); | ||
if (!info) { | ||
// TODO(eric.cousineau): Use NiceTypeName::Canonicalize(...Demangle(...)) | ||
// once simpler dependencies are used (or something else is used to | ||
// justify linking in `libdrake.so`). | ||
const std::string name = tinfo.name(); | ||
throw std::runtime_error( | ||
"C++ type is not registered in pybind: " + name); | ||
} | ||
py::handle h(reinterpret_cast<PyObject*>(info->type)); | ||
return py::reinterpret_borrow<py::object>(h); | ||
} | ||
} | ||
|
||
} // namespace internal | ||
} // namespace pydrake | ||
} // namespace drake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#pragma once | ||
|
||
/// @file | ||
/// Provides a mechanism to map C++ types to canonical Python types. | ||
|
||
#include <string> | ||
#include <typeinfo> | ||
#include <vector> | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include "drake/bindings/pydrake/util/type_pack.h" | ||
|
||
namespace drake { | ||
namespace pydrake { | ||
|
||
// This alias is intended to be part of the public API, as it follows | ||
// `pybind11` conventions. | ||
namespace py = pybind11; | ||
|
||
namespace internal { | ||
|
||
// Gets singleton for type aliases from `cpp_param`. | ||
py::object GetParamAliases(); | ||
|
||
// Gets Python type object given `std::type_info`. | ||
// @throws std::runtime_error if type is neither aliased nor registered in | ||
// `pybind11`. | ||
py::object GetPyParamScalarImpl(const std::type_info& tinfo); | ||
|
||
// Gets Python type for a C++ type (base case). | ||
template <typename T> | ||
inline py::object GetPyParamScalarImpl(type_pack<T> = {}) { | ||
return GetPyParamScalarImpl(typeid(T)); | ||
} | ||
|
||
// Gets Python literal for a C++ literal (specialization). | ||
template <typename T, T Value> | ||
inline py::object GetPyParamScalarImpl( | ||
type_pack<std::integral_constant<T, Value>> = {}) { | ||
return py::cast(Value); | ||
} | ||
|
||
} // namespace internal | ||
|
||
/// Gets the canonical Python parameters for each C++ type. | ||
/// @returns Python tuple of canonical parameters. | ||
/// @throws std::runtime_error on the first type it encounters that is neither | ||
/// aliased nor registered in `pybind11`. | ||
template <typename ... Ts> | ||
inline py::tuple GetPyParam(type_pack<Ts...> = {}) { | ||
return py::make_tuple(internal::GetPyParamScalarImpl(type_pack<Ts>{})...); | ||
} | ||
|
||
} // namespace pydrake | ||
} // namespace drake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#include "drake/bindings/pydrake/util/cpp_param_pybind.h" | ||
|
||
// @file | ||
// Tests the public interfaces in `cpp_param.py` and `cpp_param_pybind.h`. | ||
|
||
#include <stdexcept> | ||
#include <string> | ||
|
||
#include <gtest/gtest.h> | ||
#include <pybind11/embed.h> | ||
#include <pybind11/eval.h> | ||
#include <pybind11/pybind11.h> | ||
|
||
using std::string; | ||
|
||
namespace drake { | ||
namespace pydrake { | ||
|
||
// Compare two Python objects directly. | ||
bool PyEquals(py::object lhs, py::object rhs) { | ||
return lhs.attr("__eq__")(rhs).cast<bool>(); | ||
} | ||
|
||
// Ensures that the type `T` maps to the expression in `py_expr_expected`. | ||
template <typename ... Ts> | ||
bool CheckPyParam(const string& py_expr_expected, type_pack<Ts...> param = {}) { | ||
py::object actual = GetPyParam(param); | ||
py::object expected = py::eval(py_expr_expected.c_str()); | ||
return PyEquals(actual, expected); | ||
} | ||
|
||
GTEST_TEST(CppParamTest, PrimitiveTypes) { | ||
// Tests primitive types that are not expose directly via `pybind11`, thus | ||
// needing custom registration. | ||
// This follows the ordering in `cpp_param_pybind.cc`, | ||
// `RegisterCommon`. | ||
ASSERT_TRUE(CheckPyParam<bool>("bool,")); | ||
ASSERT_TRUE(CheckPyParam<std::string>("str,")); | ||
ASSERT_TRUE(CheckPyParam<double>("float,")); | ||
ASSERT_TRUE(CheckPyParam<float>("np.float32,")); | ||
ASSERT_TRUE(CheckPyParam<int>("int,")); | ||
ASSERT_TRUE(CheckPyParam<uint32_t>("np.uint32,")); | ||
ASSERT_TRUE(CheckPyParam<int64_t>("np.int64,")); | ||
ASSERT_TRUE(CheckPyParam<py::object>("object,")); | ||
} | ||
|
||
// Dummy type. | ||
// - Registered. | ||
struct CustomCppType {}; | ||
// - Unregistered. | ||
struct CustomCppTypeUnregistered {}; | ||
|
||
GTEST_TEST(CppParamTest, CustomTypes) { | ||
// Tests types that are C++ types registered with `pybind11`. | ||
ASSERT_TRUE(CheckPyParam<CustomCppType>("CustomCppType,")); | ||
EXPECT_THROW( | ||
CheckPyParam<CustomCppTypeUnregistered>("CustomCppTypeUnregistered"), | ||
std::runtime_error); | ||
} | ||
|
||
template <typename T, T Value> | ||
using constant = std::integral_constant<T, Value>; | ||
|
||
GTEST_TEST(CppParamTest, LiteralTypes) { | ||
// Tests that literal types are mapped to literals in Python. | ||
ASSERT_TRUE(CheckPyParam<std::true_type>("True,")); | ||
ASSERT_TRUE((CheckPyParam<constant<int, -1>>("-1,"))); | ||
ASSERT_TRUE((CheckPyParam<constant<uint, 1>>("1,"))); | ||
} | ||
|
||
GTEST_TEST(CppParamTest, Packs) { | ||
// Tests that type packs are properly interpreted. | ||
ASSERT_TRUE((CheckPyParam<int, bool>("int, bool"))); | ||
ASSERT_TRUE((CheckPyParam<bool, constant<bool, false>>("bool, False"))); | ||
} | ||
|
||
int main(int argc, char** argv) { | ||
// Reconstructing `scoped_interpreter` multiple times (e.g. via `SetUp()`) | ||
// while *also* importing `numpy` wreaks havoc. | ||
py::scoped_interpreter guard; | ||
|
||
// Define common scope, import numpy for use in `eval`. | ||
py::module m("__main__"); | ||
py::globals()["np"] = py::module::import("numpy"); | ||
|
||
// Define custom class only once here. | ||
py::class_<CustomCppType>(m, "CustomCppType"); | ||
|
||
::testing::InitGoogleTest(&argc, argv); | ||
return RUN_ALL_TESTS(); | ||
} | ||
|
||
} // namespace pydrake | ||
} // namespace drake | ||
|
||
int main(int argc, char** argv) { | ||
return drake::pydrake::main(argc, argv); | ||
} |
Oops, something went wrong.