diff --git a/caffe2/python/nomnigraph.py b/caffe2/python/nomnigraph.py index d69587e9a98ef..542dc37dcc978 100644 --- a/caffe2/python/nomnigraph.py +++ b/caffe2/python/nomnigraph.py @@ -57,6 +57,24 @@ def operators(self): def tensors(self): return self._NNModule.dataFlow().tensors + def createNode(self, val): + return self._NNModule.dataFlow().createNode(val) + + def deleteNode(self, node): + return self._NNModule.dataFlow().deleteNode(node) + + def createEdge(self, a, b): + return self._NNModule.dataFlow().createEdge(a, b) + + def deleteEdge(self, a, b=None): + if b: + self._NNModule.dataFlow().deleteEdge(a, b) + else: + self._NNModule.dataFlow().deleteEdge(a) + + def replaceNode(self, old_node, new_node): + return self._NNModule.dataFlow().replaceNode(old_node, new_node) + def convertToCaffe2Proto(self, old_proto=None): if not old_proto: old_proto = caffe2_pb2.NetDef() diff --git a/caffe2/python/nomnigraph_transformations_test.py b/caffe2/python/nomnigraph_transformations_test.py new file mode 100644 index 0000000000000..9c9c00ad5a754 --- /dev/null +++ b/caffe2/python/nomnigraph_transformations_test.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from caffe2.python import core, workspace, test_util +from caffe2.proto import caffe2_pb2 +import caffe2.python.nomnigraph as ng + +import numpy as np +from hypothesis import given +import hypothesis.strategies as st +import random + + +class TestNomnigraphTransformations(test_util.TestCase): + def test_simple_replace(self): + net = core.Net("name") + net.FC(["X", "W"], ["Y"]) + nn = ng.NNModule(net) + fc = nn.controlFlow[0] + add = nn.createNode(core.CreateOperator("Add", ["X"], ["Y"], engine="CUDNN")) + nn.replaceNode(fc, add) + nn.deleteNode(fc) + + # Test it out + new_netdef = nn.convertToCaffe2Proto() + workspace.FeedBlob("X", np.array([1, 2, 3])) + workspace.FeedBlob("W", np.array([1, 2, 3])) + workspace.RunNetOnce(new_netdef) + out = workspace.FetchBlob("Y") + expected_out = np.array([2, 4, 6]) + np.allclose(out, expected_out) + + def test_simple_rewire(self): + net = core.Net("name") + # Rewire this so that we get + # c = Add(a, d) + # e = Mul(c, b) + # + # if a = 1, b = 2, d = 3 + # we get 8: (1 + 3) * 2 + # as opposed to 7: 1 + (3 * 2) + net.Mul(["a", "b"], ["c"]) + net.Add(["c", "d"], ["e"]) + nn = ng.NNModule(net) + + mul = nn.controlFlow[0] + add = nn.controlFlow[1] + a = mul.inputs[0] + b = mul.inputs[1] + c = mul.outputs[0] + d = add.inputs[1] + e = add.outputs[0] + + nn.deleteEdge(a, mul) + nn.deleteEdge(b, mul) + nn.deleteEdge(mul, c) + nn.deleteEdge(c, add) + nn.deleteEdge(d, add) + nn.deleteEdge(add, e) + + nn.createEdge(a, add) + nn.createEdge(d, add) + nn.createEdge(add, c) + nn.createEdge(c, mul) + nn.createEdge(b, mul) + nn.createEdge(mul, e) + + # Test it out + new_netdef = nn.convertToCaffe2Proto() + workspace.FeedBlob("a", np.array([1, 1, 1])) + workspace.FeedBlob("b", np.array([2, 2, 2])) + workspace.FeedBlob("d", np.array([3, 3, 3])) + workspace.RunNetOnce(new_netdef) + out = workspace.FetchBlob("e") + expected_out = np.array([8, 8, 8]) + np.allclose(out, expected_out) diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc index 802ffbfad7b2a..b6588fded8243 100644 --- a/caffe2/python/pybind_state_nomni.cc +++ b/caffe2/python/pybind_state_nomni.cc @@ -182,7 +182,15 @@ void addNomnigraphMethods(pybind11::module& m) { "Edges must exist between NeuralNetOperator and NeuralNetData"); g->createEdge(a, b); }) - + .def("deleteEdge", &NNGraph::deleteEdge) + .def( + "deleteEdge", + [](NNGraph* g, NNGraph::NodeRef a, NNGraph::NodeRef b) { + auto edge = g->getEdgeIfExists(a, b); + if (edge) { + g->deleteEdge(edge); + } + }) .def( "createNode", [](NNGraph* g, GenericOperator& op) { @@ -217,6 +225,11 @@ void addNomnigraphMethods(pybind11::module& m) { }, py::return_value_policy::reference_internal) .def("deleteNode", &NNGraph::deleteNode) + .def( + "replaceNode", + [](NNGraph* g, NNGraph::NodeRef old_node, NNGraph::NodeRef new_node) { + g->replaceNode(old_node, new_node); + }) .def( "getMutableNodes", &NNGraph::getMutableNodes,