diff --git a/caffe2/python/nomnigraph_test.py b/caffe2/python/nomnigraph_test.py index ebeec7cf020262..5900a020dd9cf4 100644 --- a/caffe2/python/nomnigraph_test.py +++ b/caffe2/python/nomnigraph_test.py @@ -274,6 +274,17 @@ def test_annotation_from_graph(self): new_annot = node.getAnnotation() assert new_annot.getDeviceType() == 7 + def test_annotation_operator_def(self): + nn = ng.NNModule() + opdef = core.CreateOperator("Conv", [], [], engine="SENTINEL") + node = nn.dataFlow.createNode(opdef) + assert node.annotation.operator_def.engine == "SENTINEL" + opdef = core.CreateOperator("Conv", [], [], engine="NEW_SENTINEL") + node.annotation.operator_def = opdef + netdef = nn.convertToCaffe2Proto() + assert len(netdef.op) == 1 + assert netdef.op[0].engine == "NEW_SENTINEL" + def test_annotation_device_option(self): nn = ng.NNModule() node = nn.dataFlow.createNode(ng.NeuralNetOperator("TestOp")) diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc index b6588fded82435..5222e315aabe05 100644 --- a/caffe2/python/pybind_state_nomni.cc +++ b/caffe2/python/pybind_state_nomni.cc @@ -490,13 +490,34 @@ void addNomnigraphMethods(pybind11::module& m) { [](Caffe2Annotation& annot, py::object& def) { CAFFE_ENFORCE( pybind11::hasattr(def, "SerializeToString"), - "convertToCaffe2Proto takes either no args", - "a NetDef"); + "device_option can only be set to a DeviceOption"); auto str = def.attr("SerializeToString")(); caffe2::DeviceOption proto; proto.ParseFromString(py::bytes(str)); annot.setDeviceOption(proto); }, + py::return_value_policy::reference) + .def_property( + "operator_def", + [](Caffe2Annotation& annot) { + auto opDef = py::module::import("caffe2.proto.caffe2_pb2") + .attr("OperatorDef"); + auto proto = annot.getOperatorDef(); + std::string serialized_proto; + proto.SerializeToString(&serialized_proto); + auto py_op_def= opDef(); + py_op_def.attr("ParseFromString")(py::bytes(serialized_proto)); + return py_op_def; + }, + [](Caffe2Annotation& annot, py::object& def) { + CAFFE_ENFORCE( + pybind11::hasattr(def, "SerializeToString"), + "operator_def can only be set to an OperatorDef"); + auto str = def.attr("SerializeToString")(); + caffe2::OperatorDef proto; + proto.ParseFromString(py::bytes(str)); + annot.setOperatorDef(proto); + }, py::return_value_policy::reference); }