Skip to content

Commit

Permalink
Expose basic transformation API to Python (pytorch#13033)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#13033

Basic graph manipulation exposed to python

Reviewed By: ZolotukhinM

Differential Revision: D10519720

fbshipit-source-id: 0f9a494d122289a3a9e23d4cff99ac0a21382ec6
  • Loading branch information
bwasti authored and facebook-github-bot committed Oct 24, 2018
1 parent 4e0b6c8 commit 53ac4de
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 1 deletion.
18 changes: 18 additions & 0 deletions caffe2/python/nomnigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ def operators(self):
def tensors(self):
return self._NNModule.dataFlow().tensors

def createNode(self, val):
return self._NNModule.dataFlow().createNode(val)

def deleteNode(self, node):
return self._NNModule.dataFlow().deleteNode(node)

def createEdge(self, a, b):
return self._NNModule.dataFlow().createEdge(a, b)

def deleteEdge(self, a, b=None):
if b:
self._NNModule.dataFlow().deleteEdge(a, b)
else:
self._NNModule.dataFlow().deleteEdge(a)

def replaceNode(self, old_node, new_node):
return self._NNModule.dataFlow().replaceNode(old_node, new_node)

def convertToCaffe2Proto(self, old_proto=None):
if not old_proto:
old_proto = caffe2_pb2.NetDef()
Expand Down
78 changes: 78 additions & 0 deletions caffe2/python/nomnigraph_transformations_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace, test_util
from caffe2.proto import caffe2_pb2
import caffe2.python.nomnigraph as ng

import numpy as np
from hypothesis import given
import hypothesis.strategies as st
import random


class TestNomnigraphTransformations(test_util.TestCase):
def test_simple_replace(self):
net = core.Net("name")
net.FC(["X", "W"], ["Y"])
nn = ng.NNModule(net)
fc = nn.controlFlow[0]
add = nn.createNode(core.CreateOperator("Add", ["X"], ["Y"], engine="CUDNN"))
nn.replaceNode(fc, add)
nn.deleteNode(fc)

# Test it out
new_netdef = nn.convertToCaffe2Proto()
workspace.FeedBlob("X", np.array([1, 2, 3]))
workspace.FeedBlob("W", np.array([1, 2, 3]))
workspace.RunNetOnce(new_netdef)
out = workspace.FetchBlob("Y")
expected_out = np.array([2, 4, 6])
np.allclose(out, expected_out)

def test_simple_rewire(self):
net = core.Net("name")
# Rewire this so that we get
# c = Add(a, d)
# e = Mul(c, b)
#
# if a = 1, b = 2, d = 3
# we get 8: (1 + 3) * 2
# as opposed to 7: 1 + (3 * 2)
net.Mul(["a", "b"], ["c"])
net.Add(["c", "d"], ["e"])
nn = ng.NNModule(net)

mul = nn.controlFlow[0]
add = nn.controlFlow[1]
a = mul.inputs[0]
b = mul.inputs[1]
c = mul.outputs[0]
d = add.inputs[1]
e = add.outputs[0]

nn.deleteEdge(a, mul)
nn.deleteEdge(b, mul)
nn.deleteEdge(mul, c)
nn.deleteEdge(c, add)
nn.deleteEdge(d, add)
nn.deleteEdge(add, e)

nn.createEdge(a, add)
nn.createEdge(d, add)
nn.createEdge(add, c)
nn.createEdge(c, mul)
nn.createEdge(b, mul)
nn.createEdge(mul, e)

# Test it out
new_netdef = nn.convertToCaffe2Proto()
workspace.FeedBlob("a", np.array([1, 1, 1]))
workspace.FeedBlob("b", np.array([2, 2, 2]))
workspace.FeedBlob("d", np.array([3, 3, 3]))
workspace.RunNetOnce(new_netdef)
out = workspace.FetchBlob("e")
expected_out = np.array([8, 8, 8])
np.allclose(out, expected_out)
15 changes: 14 additions & 1 deletion caffe2/python/pybind_state_nomni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,15 @@ void addNomnigraphMethods(pybind11::module& m) {
"Edges must exist between NeuralNetOperator and NeuralNetData");
g->createEdge(a, b);
})

.def("deleteEdge", &NNGraph::deleteEdge)
.def(
"deleteEdge",
[](NNGraph* g, NNGraph::NodeRef a, NNGraph::NodeRef b) {
auto edge = g->getEdgeIfExists(a, b);
if (edge) {
g->deleteEdge(edge);
}
})
.def(
"createNode",
[](NNGraph* g, GenericOperator& op) {
Expand Down Expand Up @@ -217,6 +225,11 @@ void addNomnigraphMethods(pybind11::module& m) {
},
py::return_value_policy::reference_internal)
.def("deleteNode", &NNGraph::deleteNode)
.def(
"replaceNode",
[](NNGraph* g, NNGraph::NodeRef old_node, NNGraph::NodeRef new_node) {
g->replaceNode(old_node, new_node);
})
.def(
"getMutableNodes",
&NNGraph::getMutableNodes,
Expand Down

0 comments on commit 53ac4de

Please sign in to comment.