From ea3a4033520484d01caec30ebef3142b96f2efbc Mon Sep 17 00:00:00 2001 From: Eric Cousineau Date: Tue, 20 Feb 2018 10:56:18 -0500 Subject: [PATCH] framework_py: Add test of using AbstractValue's in input ports (following up from pybind unique_ptr snafu) --- bindings/pydrake/systems/framework_py.cc | 26 ++++++++++++-- bindings/pydrake/systems/primitives_py.cc | 5 +++ bindings/pydrake/systems/test/general_test.py | 35 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/bindings/pydrake/systems/framework_py.cc b/bindings/pydrake/systems/framework_py.cc index 34a1028b4af3..e3ff65cd823b 100644 --- a/bindings/pydrake/systems/framework_py.cc +++ b/bindings/pydrake/systems/framework_py.cc @@ -283,7 +283,16 @@ PYBIND11_MODULE(framework, m) { "EvalVectorInput", [](const System* self, const Context& arg1, int arg2) { return self->EvalVectorInput(arg1, arg2); - }, py_reference_internal) + }, py_reference, + // Keep alive, ownership: `return` keeps `Context` alive. + py::keep_alive<0, 2>()) + .def( + "EvalAbstractInput", + [](const System* self, const Context& arg1, int arg2) { + return self->EvalAbstractInput(arg1, arg2); + }, py_reference, + // Keep alive, ownership: `return` keeps `Context` alive. + py::keep_alive<0, 2>()) .def("CalcOutput", &System::CalcOutput) // Sugar. .def( @@ -352,9 +361,16 @@ PYBIND11_MODULE(framework, m) { .def("get_num_input_ports", &Context::get_num_input_ports) .def("FixInputPort", py::overload_cast>>( - &Context::FixInputPort), py_reference_internal, + &Context::FixInputPort), + py_reference_internal, // Keep alive, ownership: `BasicVector` keeps `self` alive. py::keep_alive<3, 1>()) + .def("FixInputPort", + py::overload_cast>( + &Context::FixInputPort), + py_reference_internal, + // Keep alive, ownership: `AbstractValue` keeps `self` alive. + py::keep_alive<3, 1>()) .def("get_time", &Context::get_time) .def("set_time", &Context::set_time) .def("Clone", &Context::Clone) @@ -437,8 +453,12 @@ PYBIND11_MODULE(framework, m) { py::class_>(m, "OutputPort") .def("size", &OutputPort::size); - py::class_>(m, "SystemOutput") + py::class_> system_output(m, "SystemOutput"); + DefClone(&system_output); + system_output .def("get_num_ports", &SystemOutput::get_num_ports) + .def("get_data", &SystemOutput::get_data, + py_reference_internal) .def("get_vector_data", &SystemOutput::get_vector_data, py_reference_internal); diff --git a/bindings/pydrake/systems/primitives_py.cc b/bindings/pydrake/systems/primitives_py.cc index f39e73d81073..e21611166781 100644 --- a/bindings/pydrake/systems/primitives_py.cc +++ b/bindings/pydrake/systems/primitives_py.cc @@ -9,6 +9,7 @@ #include "drake/systems/primitives/constant_vector_source.h" #include "drake/systems/primitives/integrator.h" #include "drake/systems/primitives/linear_system.h" +#include "drake/systems/primitives/pass_through.h" #include "drake/systems/primitives/signal_logger.h" #include "drake/systems/primitives/zero_order_hold.h" @@ -72,6 +73,10 @@ PYBIND11_MODULE(primitives, m) { py::arg("A"), py::arg("B"), py::arg("C"), py::arg("D"), py::arg("time_period") = 0.0); + py::class_, LeafSystem>(m, "PassThrough") + .def(py::init()) + .def(py::init()); + py::class_, LeafSystem>(m, "SignalLogger") .def(py::init()) .def(py::init()) diff --git a/bindings/pydrake/systems/test/general_test.py b/bindings/pydrake/systems/test/general_test.py index 5835d606d855..e57ca1a3e625 100644 --- a/bindings/pydrake/systems/test/general_test.py +++ b/bindings/pydrake/systems/test/general_test.py @@ -12,9 +12,11 @@ Simulator, ) from pydrake.systems.framework import ( + AbstractValue, BasicVector, Diagram, DiagramBuilder, + VectorBase, ) from pydrake.systems.primitives import ( Adder, @@ -22,10 +24,19 @@ ConstantVectorSource, Integrator, LinearSystem, + PassThrough, SignalLogger, ) +def compare_value(test, a, b): + if isinstance(a, VectorBase): + test.assertTrue(np.allclose(a.get_value(), b.get_value())) + else: + test.assertEquals(type(a.get_value()), type(b.get_value())) + test.assertEquals(a.get_value(), b.get_value()) + + class TestGeneral(unittest.TestCase): def test_simulator_ctor(self): # Create simple system. @@ -203,6 +214,30 @@ def test_linear_affine_system(self): self.assertEqual(system.y0(), y0) self.assertEqual(system.time_period(), .1) + def test_vector_pass_through(self): + model_value = BasicVector([1., 2, 3]) + system = PassThrough(model_value.size()) + context = system.CreateDefaultContext() + context.FixInputPort(0, model_value) + output = system.AllocateOutput(context) + input_eval = system.EvalVectorInput(context, 0) + compare_value(self, input_eval, model_value) + system.CalcOutput(context, output) + output_value = output.get_vector_data(0) + compare_value(self, output_value, model_value) + + def test_abstract_pass_through(self): + model_value = AbstractValue.Make("Hello world") + system = PassThrough(model_value) + context = system.CreateDefaultContext() + context.FixInputPort(0, model_value) + output = system.AllocateOutput(context) + input_eval = system.EvalAbstractInput(context, 0) + compare_value(self, input_eval, model_value) + system.CalcOutput(context, output) + output_value = output.get_data(0) + compare_value(self, output_value, model_value) + if __name__ == '__main__': unittest.main()