forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nomnigraph.py
123 lines (102 loc) · 3.58 KB
/
nomnigraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import caffe2.python._import_c_extension as C
from caffe2.python import core
from caffe2.proto import caffe2_pb2
import os
from subprocess import Popen, PIPE
import errno
class NNModule(object):
    """Python wrapper around the C-extension ``NNModule`` graph representation.

    The wrapped C object is stored in ``self._NNModule``; the properties and
    methods below forward to it, mostly to the graph returned by its
    ``dataFlow()`` method.
    """

    def __init__(self, net=None, device_map=None):
        """Build an NNModule, optionally converting an existing Caffe2 net.

        Args:
            net: ``None`` for an empty module, or a ``core.Net`` /
                ``caffe2_pb2.NetDef`` to convert.
            device_map: optional mapping whose values expose
                ``SerializeToString()``. When given, the net is converted with
                ``C.NNModuleFromProtobufDistributed``. Presumably maps blob
                names to device-option protos -- TODO confirm against callers.

        Raises:
            TypeError: if ``net`` is not ``None``, a ``core.Net`` or a
                ``caffe2_pb2.NetDef``. TypeError subclasses ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            # Validate the type before looking at device_map; otherwise an
            # unsupported ``net`` plus a device_map would hand ``None`` to the
            # C extension instead of producing this error.
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        # The type check above (not the truthiness of the bytes) decides
        # validity: a NetDef with no fields set serializes to b"", which is
        # still a legitimate net to convert.
        if device_map is not None:
            # Distributed
            serialized_device_map = {
                key: device_map[key].SerializeToString() for key in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # Default. Only this path populates ``self._OpList``.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """The graph returned by the wrapped module's ``dataFlow()``."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """The result of the wrapped module's ``getExecutionOrder()``."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """The ``nodes`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """The ``operators`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """The ``tensors`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a node holding ``val`` in the data-flow graph and return it."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete ``node`` from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create an edge between ``a`` and ``b`` in the data-flow graph."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete an edge from the data-flow graph; returns ``None``.

        With one argument, ``a`` alone is passed to ``dataFlow().deleteEdge``
        (presumably an edge object -- confirm). With two, ``a`` and ``b`` are
        both passed (presumably the edge's endpoints -- confirm).
        """
        # Explicit None test: a graph node must never be mistaken for
        # "argument omitted" just because it happens to be falsy.
        if b is None:
            self._NNModule.dataFlow().deleteEdge(a)
        else:
            self._NNModule.dataFlow().deleteEdge(a, b)

    def replaceNode(self, old_node, new_node):
        """Replace ``old_node`` with ``new_node`` in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def convertToCaffe2Proto(self, old_proto=None):
        """Convert this module back into a ``caffe2_pb2.NetDef``.

        Args:
            old_proto: optional NetDef handed to the C-extension conversion.
                A fresh, empty ``caffe2_pb2.NetDef()`` is used when omitted.

        Returns:
            A new ``caffe2_pb2.NetDef`` parsed from the serialized bytes
            returned by the wrapped module's ``convertToCaffe2Proto``.
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        serialized = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(serialized)
        return new_proto

    def match(self, pattern):
        """Lazily yield every truthy ``C.matchSubgraph`` result for ``pattern``.

        The match is attempted at each node from
        ``dataFlow.getMutableNodes()``; falsy results are skipped.
        """
        for node in self.dataFlow.getMutableNodes():
            result = C.matchSubgraph(node, pattern)
            if result:
                yield result
def _cmd_exists(name):
    """Return True if an executable regular file called ``name`` is on PATH.

    When ``PATH`` is unset, ``os.defpath`` is searched instead (rather than
    raising ``KeyError``), matching the default exec lookup path.
    """
    search_path = os.environ.get("PATH", os.defpath)
    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory, name)
        # A directory can also satisfy X_OK; only a real file can be executed.
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return True
    return False


def render(s):
    """Render ``str(s)`` with the ``graph-easy`` tool, or print it.

    If a ``graph-easy`` executable is found on PATH, the text is UTF-8 encoded
    and piped to the tool's stdin. Its stdout is not captured, so the output
    goes to this process's stdout. Otherwise the text is printed unchanged.
    """
    text = str(s)
    if not _cmd_exists("graph-easy"):
        print(text)
        return
    proc = Popen("graph-easy", stdin=PIPE)
    # communicate() writes the input, closes stdin and waits for the child.
    # CPython 3's implementation ignores broken-pipe errors on both the write
    # and the close. The child may exit without reading all of its input, and
    # with a buffered pipe that error typically surfaces on close(), which the
    # previous hand-rolled version did not guard.
    proc.communicate(text.encode("utf-8"))
# Public names re-exported from the C extension. The short names ``Operator``
# and ``Data`` are aliases bound to the very same objects as the long names.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Graph = C.Graph
Annotation = C.Annotation