Skip to content

Commit

Permalink
framework_py: Add test of using AbstractValue's in input ports (follo…
Browse files Browse the repository at this point in the history
…wing up from pybind unique_ptr snafu)
  • Loading branch information
EricCousineau-TRI committed Feb 20, 2018
1 parent f46b536 commit ea3a403
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
26 changes: 23 additions & 3 deletions bindings/pydrake/systems/framework_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,16 @@ PYBIND11_MODULE(framework, m) {
"EvalVectorInput",
[](const System<T>* self, const Context<T>& arg1, int arg2) {
return self->EvalVectorInput(arg1, arg2);
}, py_reference_internal)
}, py_reference,
// Keep alive, ownership: `return` keeps `Context` alive.
py::keep_alive<0, 2>())
.def(
"EvalAbstractInput",
[](const System<T>* self, const Context<T>& arg1, int arg2) {
return self->EvalAbstractInput(arg1, arg2);
}, py_reference,
// Keep alive, ownership: `return` keeps `Context` alive.
py::keep_alive<0, 2>())
.def("CalcOutput", &System<T>::CalcOutput)
// Sugar.
.def(
Expand Down Expand Up @@ -352,9 +361,16 @@ PYBIND11_MODULE(framework, m) {
.def("get_num_input_ports", &Context<T>::get_num_input_ports)
.def("FixInputPort",
py::overload_cast<int, unique_ptr<BasicVector<T>>>(
&Context<T>::FixInputPort), py_reference_internal,
&Context<T>::FixInputPort),
py_reference_internal,
// Keep alive, ownership: `BasicVector` keeps `self` alive.
py::keep_alive<3, 1>())
.def("FixInputPort",
py::overload_cast<int, unique_ptr<AbstractValue>>(
&Context<T>::FixInputPort),
py_reference_internal,
// Keep alive, ownership: `AbstractValue` keeps `self` alive.
py::keep_alive<3, 1>())
.def("get_time", &Context<T>::get_time)
.def("set_time", &Context<T>::set_time)
.def("Clone", &Context<T>::Clone)
Expand Down Expand Up @@ -437,8 +453,12 @@ PYBIND11_MODULE(framework, m) {
py::class_<OutputPort<T>>(m, "OutputPort")
.def("size", &OutputPort<T>::size);

py::class_<SystemOutput<T>>(m, "SystemOutput")
py::class_<SystemOutput<T>> system_output(m, "SystemOutput");
DefClone(&system_output);
system_output
.def("get_num_ports", &SystemOutput<T>::get_num_ports)
.def("get_data", &SystemOutput<T>::get_data,
py_reference_internal)
.def("get_vector_data", &SystemOutput<T>::get_vector_data,
py_reference_internal);

Expand Down
5 changes: 5 additions & 0 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "drake/systems/primitives/constant_vector_source.h"
#include "drake/systems/primitives/integrator.h"
#include "drake/systems/primitives/linear_system.h"
#include "drake/systems/primitives/pass_through.h"
#include "drake/systems/primitives/signal_logger.h"
#include "drake/systems/primitives/zero_order_hold.h"

Expand Down Expand Up @@ -72,6 +73,10 @@ PYBIND11_MODULE(primitives, m) {
py::arg("A"), py::arg("B"), py::arg("C"), py::arg("D"),
py::arg("time_period") = 0.0);

py::class_<PassThrough<T>, LeafSystem<T>>(m, "PassThrough")
.def(py::init<int>())
.def(py::init<const AbstractValue&>());

py::class_<SignalLogger<T>, LeafSystem<T>>(m, "SignalLogger")
.def(py::init<int>())
.def(py::init<int, int>())
Expand Down
35 changes: 35 additions & 0 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,31 @@
Simulator,
)
from pydrake.systems.framework import (
AbstractValue,
BasicVector,
Diagram,
DiagramBuilder,
VectorBase,
)
from pydrake.systems.primitives import (
Adder,
AffineSystem,
ConstantVectorSource,
Integrator,
LinearSystem,
PassThrough,
SignalLogger,
)


def compare_value(test, a, b):
if isinstance(a, VectorBase):
test.assertTrue(np.allclose(a.get_value(), b.get_value()))
else:
test.assertEquals(type(a.get_value()), type(b.get_value()))
test.assertEquals(a.get_value(), b.get_value())


class TestGeneral(unittest.TestCase):
def test_simulator_ctor(self):
# Create simple system.
Expand Down Expand Up @@ -203,6 +214,30 @@ def test_linear_affine_system(self):
self.assertEqual(system.y0(), y0)
self.assertEqual(system.time_period(), .1)

def test_vector_pass_through(self):
model_value = BasicVector([1., 2, 3])
system = PassThrough(model_value.size())
context = system.CreateDefaultContext()
context.FixInputPort(0, model_value)
output = system.AllocateOutput(context)
input_eval = system.EvalVectorInput(context, 0)
compare_value(self, input_eval, model_value)
system.CalcOutput(context, output)
output_value = output.get_vector_data(0)
compare_value(self, output_value, model_value)

def test_abstract_pass_through(self):
model_value = AbstractValue.Make("Hello world")
system = PassThrough(model_value)
context = system.CreateDefaultContext()
context.FixInputPort(0, model_value)
output = system.AllocateOutput(context)
input_eval = system.EvalAbstractInput(context, 0)
compare_value(self, input_eval, model_value)
system.CalcOutput(context, output)
output_value = output.get_data(0)
compare_value(self, output_value, model_value)


if __name__ == '__main__':
unittest.main()

0 comments on commit ea3a403

Please sign in to comment.