diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 9c81857339c6..7d150c7c3d34 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -85,7 +85,7 @@ class GraphModuleDebug(graph_runtime.GraphModule): Parameters ---------- module : Module - The interal tvm module that holds the actual graph functions. + The internal tvm module that holds the actual graph functions. ctx : TVMContext The context this module is under. @@ -188,7 +188,7 @@ def _run_debug(self): out_tensor = array(out_tensor) self.debug_datum._output_tensor_list.append(out_tensor) - def debug_get_output(self, node, out): + def debug_get_output(self, node, out=None): """Run graph up to node and get the output to out Parameters @@ -199,12 +199,11 @@ def debug_get_output(self, node, out): out : NDArray The output array container """ - ret = None if isinstance(node, str): output_tensors = self.debug_datum.get_output_tensors() try: - ret = output_tensors[node] - except: + out = output_tensors[node] + except KeyError: node_list = output_tensors.keys() raise RuntimeError( "Node " @@ -215,10 +214,10 @@ def debug_get_output(self, node, out): ) elif isinstance(node, int): output_tensors = self.debug_datum._output_tensor_list - ret = output_tensors[node] + out = output_tensors[node] else: raise RuntimeError("Require node index or name only.") - return ret + return out def run(self, **input_dict): """Run forward execution of the graph with debug @@ -244,7 +243,6 @@ def run_individual(self, number, repeat=1, min_repeat_ms=0): ret = self._run_individual(number, repeat, min_repeat_ms) return ret.strip(",").split(",") if ret else [] - def exit(self): """Exits the dump folder and all its contents""" self._remove_dump_root() diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 99e0bba7af83..2c945d2fca95 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -22,6 +22,7 @@ from .._ffi.runtime_ctypes import TVMContext from ..rpc import base as rpc_base + def create(graph_json_str, libmod, ctx): """Create a runtime executor module given a graph and module. Parameters @@ -57,6 +58,7 @@ def create(graph_json_str, libmod, ctx): return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) + def get_device_ctx(libmod, ctx): """Parse and validate all the device context(s). Parameters @@ -112,12 +114,12 @@ class GraphModule(object): Parameters ---------- module : Module - The interal tvm module that holds the actual graph functions. + The internal tvm module that holds the actual graph functions. Attributes ---------- module : Module - The interal tvm module that holds the actual graph functions. + The internal tvm module that holds the actual graph functions. """ def __init__(self, module): @@ -142,7 +144,7 @@ def set_input(self, key=None, value=None, **params): The input key params : dict of str to NDArray - Additonal arguments + Additional arguments """ if key is not None: self._get_input(key).copyfrom(value) @@ -211,7 +213,7 @@ def get_output(self, index, out=None): return self._get_output(index) def debug_get_output(self, node, out): - """Run graph upto node and get the output to out + """Run graph up to node and get the output to out Parameters ----------