diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6c3ef29e1cbd..fd66e3c1f367 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -733,54 +733,49 @@ def _convert_elemwise_input(data, input_type):
 }


-def run_jit_passes(graph):
+def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
     if version.parse(torch.__version__) >= version.parse("1.4.0"):
         torch._C._jit_pass_inline(graph)


-def is_int_seq(seq):
+def _is_int_seq(seq):
     return len(seq) > 0 and all([isinstance(i, int) for i in seq])


-def get_tensor_and_var(torch_tensor, name):
+def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape)
     return tensor, var


-def get_output_name(node):
+def _get_output_name(node):
     assert node.outputsSize() == 1
     return node.output().debugName()


-def get_output_names(node):
+def _get_output_names(node):
     return [output.debugName() for output in node.outputs()]


-def get_input_names(node_or_graph):
+def _get_input_names(node_or_graph):
     return [inp.debugName() for inp in node_or_graph.inputs()]


-def get_op_inputs(op_node, outputs, output_index_map):
+def _get_op_inputs(op_node, outputs, output_index_map):
     input_names = [output_index_map[name]
-                   for name in get_input_names(op_node)]
+                   for name in _get_input_names(op_node)]
     return [outputs[name] for name in input_names]


-def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
+def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
     for output_name, output in name_output_pairs:
         output_index_map[output_name] = len(outputs)
         outputs.append(output)


-def get_all_op_names(graph):
-    nodes = list(graph.nodes())
-    return set(node.kind() for node in nodes)
-
-
-def report_missing_conversion(op_names):
+def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
@@ -795,58 +790,18 @@ def report_missing_conversion(op_names):
         raise NotImplementedError(msg)


-def getattr_attr_name(node):
+def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
     attr_name = node.s(attribute_names[0])
     return attr_name


-def get_full_attr_name(getattrs):
-    return ".".join([getattr_attr_name(node) for node in getattrs])
-
-
-def get_use_chains(root_node, terminate=lambda _: False):
-    """
-    Track a chain of users of this node forward, returning a list of chains.
-    See get_attr_chains below for its usage
-    """
-    def concat_lists(lists):
-        return itertools.chain.from_iterable(lists)
-
-    def inner(current, accum):
-        users = []
-        for output in current.outputs():
-            users += [use.user for use in output.uses()]
-
-        if not users or terminate(users):
-            return [accum]
-
-        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
-
-    return inner(root_node, [root_node])
-
-
-def get_attr_chains(root_getattr_node):
-    """ Returns chains of attribute access starting from root_getattr_node
-
-    For example, given attribute "block", as in "self.block" when "self" points
-    to the top level torch.nn.Module, it returns lists of attribute "chains",
-    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
-
-    These sets of attributes form full attribute accessors. For example,
-    "self.block.1", "self.block.2" will return the second and third submodules,
-    and "self.block.0._packed_params" will return the parameters of the first
-    submodule.
-    """
-    def terminate(users):
-        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
-        return len(next_attrs) == 0
-
-    return get_use_chains(root_getattr_node, terminate)
+def _getattr_full_name(getattrs):
+    return ".".join([_getattr_attr_name(node) for node in getattrs])


-def get_input_types(op_node):
+def _get_input_types(op_node):
     """ Returns a torch type for each input node """
     input_list_types = []
     for input_node in op_node.inputs():
@@ -854,7 +809,7 @@
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append('float')
+                input_list_types.append(None)
             else:
                 input_list_types.append(in_ty.scalarType().lower())
         elif input_node_kind == 'ListType':
@@ -874,7 +829,7 @@
     return input_list_types


-def get_constant(node):
+def _get_constant(node):
     """ Retrieve a constant associated with this prim::Constant node """
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)
@@ -903,15 +858,15 @@
     return None


-def get_operator_nodes(nodes):
+def _get_operator_nodes(nodes):
     """ Returns torch IR nodes that need conversion to Relay """
     ops = {}
     # Traverse nodes and collect the ops to be converted
     for node in nodes:
         if node.outputsSize() > 1:
-            node_name = "_".join(get_output_names(node))
+            node_name = "_".join(_get_output_names(node))
         else:
-            node_name = get_output_name(node)
+            node_name = _get_output_name(node)

         if node.kind() != "prim::GetAttr":
             ops[node_name] = node
@@ -930,6 +885,46 @@ def parse_inputs(graph_inputs, input_shapes):
     return input_vars


+def get_use_chains(root_node, terminate=lambda _: False):
+    """
+    Track a chain of users of this node forward, returning a list of chains.
+    See get_attr_chains below for its usage
+    """
+    def concat_lists(lists):
+        return itertools.chain.from_iterable(lists)
+
+    def inner(current, accum):
+        users = []
+        for output in current.outputs():
+            users += [use.user for use in output.uses()]
+
+        if not users or terminate(users):
+            return [accum]
+
+        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
+
+    return inner(root_node, [root_node])
+
+
+def get_attr_chains(root_getattr_node):
+    """ Returns chains of attribute access starting from root_getattr_node
+
+    For example, given attribute "block", as in "self.block" when "self" points
+    to the top level torch.nn.Module, it returns lists of attribute "chains",
+    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
+
+    These sets of attributes form full attribute accessors. For example,
+    "self.block.1", "self.block.2" will return the second and third submodules,
+    and "self.block.0._packed_params" will return the parameters of the first
+    submodule.
+    """
+    def terminate(users):
+        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
+        return len(next_attrs) == 0
+
+    return get_use_chains(root_getattr_node, terminate)
+
+
 def parse_params(graph, state_dict):
     """ Return Relay vars and TVM NDArrays for input parameters
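Reviewer note (not part of the patch): the get_use_chains/get_attr_chains pair is easiest to follow on a concrete module. Below is a minimal sketch, assuming the patched frontend is importable; the `Net` module is made up for illustration, and the `seen`-set bookkeeping that parse_params uses to avoid revisiting intermediate prim::GetAttr nodes is omitted here.

```python
import torch
from tvm.relay.frontend.pytorch import get_attr_chains, _getattr_full_name

class Net(torch.nn.Module):  # illustrative module, not from the PR
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Sequential(torch.nn.Linear(4, 4),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear(4, 2))

    def forward(self, x):
        return self.block(x)

graph = torch.jit.script(Net()).graph.copy()
torch._C._jit_pass_inline(graph)  # what _run_jit_passes does on torch >= 1.4

for node in graph.findAllNodes("prim::GetAttr", recurse=True):
    for getattrs in get_attr_chains(node):
        # Prints accessors such as "block.0.weight" or "block.2.bias".
        # Sub-chains repeat because every GetAttr node is visited here;
        # parse_params deduplicates them with a `seen` set.
        print(_getattr_full_name(getattrs))
```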
+ """ + def terminate(users): + next_attrs = [user for user in users if user.kind() == "prim::GetAttr"] + return len(next_attrs) == 0 + + return get_use_chains(root_getattr_node, terminate) + + def parse_params(graph, state_dict): """ Return Relay vars and TVM NDArrays for input parameters @@ -941,19 +936,19 @@ def parse_params(graph, state_dict): seen = set() for node in getattr_nodes: - if get_output_name(node) in seen: + if _get_output_name(node) in seen: continue for getattrs in get_attr_chains(node): - seen.update(map(get_output_name, getattrs)) + seen.update(map(_get_output_name, getattrs)) - full_attr = get_full_attr_name(getattrs) - full_attr_node_name = get_output_name(getattrs[-1]) + full_attr = _getattr_full_name(getattrs) + full_attr_node_name = _get_output_name(getattrs[-1]) if full_attr in state_dict: torch_tensor = state_dict[full_attr] - tensor, var = get_tensor_and_var(torch_tensor, - full_attr_node_name) + tensor, var = _get_tensor_and_var(torch_tensor, + full_attr_node_name) param_tensors[full_attr_node_name] = tensor params[full_attr_node_name] = var @@ -964,12 +959,12 @@ def parse_operators(operators, outputs, output_index_map, ret_name): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators.items(): operator = op_node.kind() - inputs = get_op_inputs(op_node, outputs, output_index_map) + inputs = _get_op_inputs(op_node, outputs, output_index_map) if operator == "prim::Constant": output_index_map[node_name] = len(outputs) - outputs.append(get_constant(op_node)) - elif operator == 'prim::ListConstruct' and is_int_seq(inputs): + outputs.append(_get_constant(op_node)) + elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): output_index_map[node_name] = len(outputs) outputs.append(_expr.var(node_name, shape=inputs)) elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: @@ -977,22 +972,28 @@ def parse_operators(operators, outputs, output_index_map, ret_name): outputs.append(inputs) elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: assert len(inputs) == 1 - unpacked_names = get_output_names(op_node) - update_outputs_from_pairs(zip(unpacked_names, inputs[0]), - outputs, output_index_map) + unpacked_names = _get_output_names(op_node) + _update_outputs_from_pairs(zip(unpacked_names, inputs[0]), + outputs, output_index_map) else: output_index_map[node_name] = len(outputs) relay_op = _convert_map[operator] - outputs.append(relay_op(inputs, get_input_types(op_node))) + outputs.append(relay_op(inputs, _get_input_types(op_node))) return outputs[output_index_map[ret_name]] +def get_all_op_names(graph): + """ Return all operator names in the input graph """ + nodes = list(graph.nodes()) + return set(node.kind() for node in nodes) + + def get_graph_input_names(script_module): """ Use this function to set the keys for input_shapes""" # It seems variable names could change the first time a copy is made # Use the copy of the graph here to prevent troubles later - ir_inputs = get_input_names(script_module.graph.copy()) + ir_inputs = _get_input_names(script_module.graph.copy()) return ir_inputs[1:] # remove self at the 0th arg @@ -1019,9 +1020,9 @@ def from_pytorch(script_module, input_shapes): Dict of converted parameters stored in tvm.runtime.ndarray format """ graph = script_module.graph.copy() - run_jit_passes(graph) + _run_jit_passes(graph) op_names = get_all_op_names(graph) - report_missing_conversion(op_names) + _report_missing_conversion(op_names) params = script_module.state_dict() input_vars = 
@@ -1019,9 +1020,9 @@
         Dict of converted parameters stored in tvm.runtime.ndarray format
     """
     graph = script_module.graph.copy()
-    run_jit_passes(graph)
+    _run_jit_passes(graph)
     op_names = get_all_op_names(graph)
-    report_missing_conversion(op_names)
+    _report_missing_conversion(op_names)
     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
@@ -1030,9 +1031,9 @@
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())[0]

-    body = parse_operators(get_operator_nodes(graph.nodes()), outputs,
+    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
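Reviewer note (not part of the patch): after this refactor all helpers are private except get_graph_input_names, get_all_op_names, and the from_pytorch entry point, whose signature is unchanged. Continuing the sketch above, a rough end-to-end usage; the (module, params) return order and the three-tuple from relay.build reflect my reading of the TVM API of this version, and the actual return statement of from_pytorch is outside the hunks shown, so treat this as an assumption.

```python
import tvm
from tvm import relay

# script_module and input_shapes as constructed in the previous sketch.
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

# Standard Relay compilation for CPU; nothing here is specific to this patch.
graph_json, lib, params = relay.build(mod, target="llvm", params=params)
```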