Skip to content

Commit

Permalink
framework_py: Add test of using AbstractValue's in input ports (follo…
Browse files Browse the repository at this point in the history
…wing up from pybind unique_ptr snafu)
  • Loading branch information
EricCousineau-TRI committed Feb 17, 2018
1 parent 0b81c9f commit 9e808c3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
22 changes: 20 additions & 2 deletions bindings/pydrake/systems/framework_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ PYBIND11_MODULE(framework, m) {
[](const System<T>* self, const Context<T>& arg1, int arg2) {
return self->EvalVectorInput(arg1, arg2);
}, py_reference_internal)
.def(
"EvalAbstractInput",
[](const System<T>* self, const Context<T>& arg1, int arg2) {
return self->EvalAbstractInput(arg1, arg2);
}, py_reference,
// Keep alive, ownership: `return` keeps `Context` alive.
py::keep_alive<0, 2>())
.def("CalcOutput", &System<T>::CalcOutput)
// Sugar.
.def(
Expand Down Expand Up @@ -349,9 +356,16 @@ PYBIND11_MODULE(framework, m) {
.def("get_num_input_ports", &Context<T>::get_num_input_ports)
.def("FixInputPort",
py::overload_cast<int, unique_ptr<BasicVector<T>>>(
&Context<T>::FixInputPort), py_reference_internal,
&Context<T>::FixInputPort),
py_reference_internal,
// Keep alive, ownership: `BasicVector` keeps `self` alive.
py::keep_alive<3, 1>())
.def("FixInputPort",
py::overload_cast<int, unique_ptr<AbstractValue>>(
&Context<T>::FixInputPort),
py_reference_internal,
// Keep alive, ownership: `AbstractValue` keeps `self` alive.
py::keep_alive<3, 1>())
.def("get_time", &Context<T>::get_time)
.def("Clone", &Context<T>::Clone)
.def("__copy__", &Context<T>::Clone)
Expand Down Expand Up @@ -432,8 +446,12 @@ PYBIND11_MODULE(framework, m) {

py::class_<OutputPort<T>>(m, "OutputPort");

py::class_<SystemOutput<T>>(m, "SystemOutput")
py::class_<SystemOutput<T>> system_output(m, "SystemOutput");
DefClone(&system_output);
system_output
.def("get_num_ports", &SystemOutput<T>::get_num_ports)
.def("get_data", &SystemOutput<T>::get_data,
py_reference_internal)
.def("get_vector_data", &SystemOutput<T>::get_vector_data,
py_reference_internal);

Expand Down
5 changes: 5 additions & 0 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "drake/systems/primitives/constant_value_source.h"
#include "drake/systems/primitives/constant_vector_source.h"
#include "drake/systems/primitives/integrator.h"
#include "drake/systems/primitives/pass_through.h"
#include "drake/systems/primitives/signal_logger.h"
#include "drake/systems/primitives/zero_order_hold.h"

Expand All @@ -31,6 +32,10 @@ PYBIND11_MODULE(primitives, m) {
py::class_<Integrator<T>, LeafSystem<T>>(m, "Integrator")
.def(py::init<int>());

py::class_<PassThrough<T>, LeafSystem<T>>(m, "PassThrough")
.def(py::init<int>())
.def(py::init<const AbstractValue&>());

py::class_<ZeroOrderHold<T>, LeafSystem<T>>(m, "ZeroOrderHold")
.def(py::init<double, int>());

Expand Down
15 changes: 15 additions & 0 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Simulator,
)
from pydrake.systems.framework import (
AbstractValue,
BasicVector,
Diagram,
DiagramBuilder,
Expand All @@ -20,6 +21,7 @@
Adder,
ConstantVectorSource,
Integrator,
PassThrough,
SignalLogger,
)

Expand Down Expand Up @@ -165,6 +167,19 @@ def test_signal_logger(self):
self.assertTrue(t.shape[0] == x.shape[1])
self.assertAlmostEqual(x[0, -1], t[-1]*kValue, places=2)

def test_abstract_pass_through(self):
model_value = AbstractValue.Make("Hello world")
input_value = model_value.Clone()
system = PassThrough(model_value)
context = system.CreateDefaultContext()
context.FixInputPort(0, input_value)
output = system.AllocateOutput(context)
input_eval = system.EvalAbstractInput(context, 0)
self.assertEquals(input_eval.get_value(), input_value.get_value())
system.CalcOutput(context, output)
output_value = output.get_data(0)
self.assertEquals(output_value.get_value(), input_value.get_value())


if __name__ == '__main__':
unittest.main()

0 comments on commit 9e808c3

Please sign in to comment.