Skip to content

Commit

Permalink
[GraphRuntime] Support parameter out in the graph runtime debug (apac…
Browse files Browse the repository at this point in the history
…he#4598)

* [GraphRuntime] Support parameter out in the graph runtime debug

* Dummy commit to trigger build
  • Loading branch information
cchung100m authored and zhiics committed Dec 31, 2019
1 parent 67a06a4 commit 5fc1712
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
14 changes: 6 additions & 8 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
ctx : TVMContext
The context this module is under.
Expand Down Expand Up @@ -188,7 +188,7 @@ def _run_debug(self):
out_tensor = array(out_tensor)
self.debug_datum._output_tensor_list.append(out_tensor)

def debug_get_output(self, node, out):
def debug_get_output(self, node, out=None):
"""Run graph up to node and get the output to out
Parameters
Expand All @@ -199,12 +199,11 @@ def debug_get_output(self, node, out):
out : NDArray
The output array container
"""
ret = None
if isinstance(node, str):
output_tensors = self.debug_datum.get_output_tensors()
try:
ret = output_tensors[node]
except:
out = output_tensors[node]
except KeyError:
node_list = output_tensors.keys()
raise RuntimeError(
"Node "
Expand All @@ -215,10 +214,10 @@ def debug_get_output(self, node, out):
)
elif isinstance(node, int):
output_tensors = self.debug_datum._output_tensor_list
ret = output_tensors[node]
out = output_tensors[node]
else:
raise RuntimeError("Require node index or name only.")
return ret
return out

def run(self, **input_dict):
"""Run forward execution of the graph with debug
Expand All @@ -244,7 +243,6 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0):
ret = self._run_individual(number, repeat, min_repeat_ms)
return ret.strip(",").split(",") if ret else []


def exit(self):
"""Exits the dump folder and all its contents"""
self._remove_dump_root()
10 changes: 6 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .._ffi.runtime_ctypes import TVMContext
from ..rpc import base as rpc_base


def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
Expand Down Expand Up @@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx):

return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))


def get_device_ctx(libmod, ctx):
"""Parse and validate all the device context(s).
Parameters
Expand Down Expand Up @@ -112,12 +114,12 @@ class GraphModule(object):
Parameters
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
Attributes
----------
module : Module
The interal tvm module that holds the actual graph functions.
The internal tvm module that holds the actual graph functions.
"""

def __init__(self, module):
Expand All @@ -142,7 +144,7 @@ def set_input(self, key=None, value=None, **params):
The input key
params : dict of str to NDArray
Additonal arguments
Additional arguments
"""
if key is not None:
self._get_input(key).copyfrom(value)
Expand Down Expand Up @@ -211,7 +213,7 @@ def get_output(self, index, out=None):
return self._get_output(index)

def debug_get_output(self, node, out):
"""Run graph upto node and get the output to out
"""Run graph up to node and get the output to out
Parameters
----------
Expand Down

0 comments on commit 5fc1712

Please sign in to comment.