Skip to content

Commit

Permalink
Add operator_def property to annotation (pytorch#13094)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#13094

Expose operator_def property

Reviewed By: duc0

Differential Revision: D10847125

fbshipit-source-id: 67a066555b690715e1f5f04125fd446ab197f45a
  • Loading branch information
bwasti authored and facebook-github-bot committed Oct 25, 2018
1 parent b883afc commit f1e4304
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
11 changes: 11 additions & 0 deletions caffe2/python/nomnigraph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def test_annotation_from_graph(self):
new_annot = node.getAnnotation()
assert new_annot.getDeviceType() == 7

def test_annotation_operator_def(self):
nn = ng.NNModule()
opdef = core.CreateOperator("Conv", [], [], engine="SENTINEL")
node = nn.dataFlow.createNode(opdef)
assert node.annotation.operator_def.engine == "SENTINEL"
opdef = core.CreateOperator("Conv", [], [], engine="NEW_SENTINEL")
node.annotation.operator_def = opdef
netdef = nn.convertToCaffe2Proto()
assert len(netdef.op) == 1
assert netdef.op[0].engine == "NEW_SENTINEL"

def test_annotation_device_option(self):
nn = ng.NNModule()
node = nn.dataFlow.createNode(ng.NeuralNetOperator("TestOp"))
Expand Down
25 changes: 23 additions & 2 deletions caffe2/python/pybind_state_nomni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,13 +490,34 @@ void addNomnigraphMethods(pybind11::module& m) {
[](Caffe2Annotation& annot, py::object& def) {
CAFFE_ENFORCE(
pybind11::hasattr(def, "SerializeToString"),
"convertToCaffe2Proto takes either no args",
"a NetDef");
"device_option can only be set to a DeviceOption");
auto str = def.attr("SerializeToString")();
caffe2::DeviceOption proto;
proto.ParseFromString(py::bytes(str));
annot.setDeviceOption(proto);
},
py::return_value_policy::reference)
.def_property(
"operator_def",
[](Caffe2Annotation& annot) {
auto opDef = py::module::import("caffe2.proto.caffe2_pb2")
.attr("OperatorDef");
auto proto = annot.getOperatorDef();
std::string serialized_proto;
proto.SerializeToString(&serialized_proto);
auto py_op_def= opDef();
py_op_def.attr("ParseFromString")(py::bytes(serialized_proto));
return py_op_def;
},
[](Caffe2Annotation& annot, py::object& def) {
CAFFE_ENFORCE(
pybind11::hasattr(def, "SerializeToString"),
"operator_def can only be set to an OperatorDef");
auto str = def.attr("SerializeToString")();
caffe2::OperatorDef proto;
proto.ParseFromString(py::bytes(str));
annot.setOperatorDef(proto);
},
py::return_value_policy::reference);
}

Expand Down

0 comments on commit f1e4304

Please sign in to comment.