
Commit

AlexanderDokuchaev committed Mar 19, 2024
1 parent 9f47752 commit c3cbc0d
Showing 2 changed files with 64 additions and 54 deletions.
62 changes: 33 additions & 29 deletions nncf/torch/extractor.py
@@ -66,9 +66,10 @@ def extract_conv(
model: NNCFNetwork,
) -> ExtractedFunc:
"""
Extracts a convolutional layer from an NNCF graph and constructs an ExtractedConv module.
Extracts a convolutional layer from an NNCF graph and constructs an ExtractedFunc module.
:param node: The NNCF node representing the convolutional layer to extract.
:param input_nodes: The name of the input node.
:param output_nodes: The name of the output node.
:param model: The NNCF network containing the layer.
:return: The extracted convolutional layer as an ExtractedFunc module.
"""
@@ -91,7 +92,6 @@ def extract_conv(
"dilation": input_node.layer_attributes.dilations,
"groups": input_node.layer_attributes.groups,
}
extracted_module = ExtractedFunc(input_node.node_type, kwargs)
elif input_node.metatype in CONV_TRANSPOSE_METATYPES:
kwargs = {
"weight": e_weight.clone(),
@@ -101,7 +101,8 @@ def extract_conv(
"output_padding": input_node.layer_attributes.output_padding_values,
"dilation": input_node.layer_attributes.dilations,
}
extracted_module = ExtractedFunc(input_node.node_type, kwargs)

extracted_module = ExtractedFunc(input_node.node_type, kwargs)

if input_node != output_node:
extracted_module = try_to_fuse_conv(input_node, output_node, model, extracted_module)
@@ -155,39 +156,45 @@ def extract_bn(node: NNCFNode, model: NNCFNetwork) -> Optional[Union[nn.BatchNor
return extracted_bn


def try_to_fuse_conv(input_node: NNCFNode, output_node: NNCFNode, model: NNCFNetwork, extracted_module: nn.Module):
def try_to_fuse_conv(
input_node: NNCFNode, output_node: NNCFNode, model: NNCFNetwork, extracted_module: nn.Module
) -> nn.Module:
"""
Fused convolution operation with next batch if possible,
Fuses the convolution operation with the next batch norm node, if possible.
:param input_node: Input subgraph node.
:param output_node: Output subgraph node (fused with input node).
:param model: Source model.
:param extracted_module: Extracted module.
"""
next_nodes = model.nncf.get_graph().get_next_nodes(input_node)
if len(next_nodes) == 1:
if output_node != next_nodes[0]:
raise nncf.InternalError(f"Output node {output_node} not found after {input_node}")
extracted_bn = extract_bn(next_nodes[0], model)
if next_nodes[0].metatype == om.PTBatchNormMetatype:
extracted_bn = extract_bn(next_nodes[0], model)
if extracted_bn is None:
nncf_logger.debug(
f"Can`t extract fused batchnorm module for {input_node.node_name},"
" module that contain batchnorm operator should be inhered from one of {BATCH_NORM_CLASSES}."
)
return None
return nn.Sequential(extracted_module, extracted_bn)
return extracted_module

if len(next_nodes) != 1:
return extracted_module

if output_node != next_nodes[0]:
raise nncf.InternalError(f"Output node {output_node} not found after {input_node}")

if next_nodes[0].metatype != om.PTBatchNormMetatype:
raise nncf.InternalError("Supported only BatchNorm layers")

extracted_bn = extract_bn(next_nodes[0], model)
if extracted_bn is None:
nncf_logger.debug(
f"Can`t extract fused batchnorm module for {input_node.node_name},"
" module that contain batchnorm operator should be inhered from one of {BATCH_NORM_CLASSES}."
)
return None
return nn.Sequential(extracted_module, extracted_bn)
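When the batch norm is fused, the extracted result is just an nn.Sequential of the conv wrapper and the batch norm module, so it computes exactly what the original pair of layers computes. A minimal torch-only sketch of that equivalence (a plain nn.Conv2d stands in for the ExtractedFunc wrapper; not part of this commit):

import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # eval() keeps running statistics fixed
fused_view = nn.Sequential(conv, bn)  # the shape of object try_to_fuse_conv returns

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(fused_view(x), bn(conv(x)))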


def extract_model(model: NNCFNetwork, input_nodes: List[str], output_nodes: List[str]) -> Optional[nn.Module]:
"""
Extracts a submodule from a given NNCF network containing only the nodes from the input to the output node.
:param model: The NNCF network to extract the submodule from.
:param input_nodes: List containing the name of input nodes for the submodule.
:param output_nodes: List containing the name of output nodes for the submodule.
:param input_nodes: List containing names of the input nodes for the submodule.
:param output_nodes: List containing names of the output nodes for the submodule.
:return: An nn.Module containing the extracted submodel, or None if extraction is not supported.
"""

@@ -198,11 +205,8 @@ def extract_model(model: NNCFNetwork, input_nodes: List[str], output_nodes: List
input_node = graph.get_node_by_name(input_nodes[0])
output_node = graph.get_node_by_name(output_nodes[0])

extracted_module: Optional[nn.Module] = None

if input_node.metatype in CONV_METATYPES + CONV_TRANSPOSE_METATYPES:
extracted_module = extract_conv(input_node, output_node, model)
else:
nncf_logger.debug(f"Can`t extract module for {input_node.node_name}")
return None
return extracted_module
return extract_conv(input_node, output_node, model)

nncf_logger.debug(f"Can`t extract module for {input_node.node_name}")
return None
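A usage sketch for extract_model as it stands after this change. The wrap_model call with trace_parameters=True and the node lookup by node_type are assumptions about the surrounding NNCF API, not something this diff shows:

import torch
from nncf.torch import wrap_model
from nncf.torch.extractor import extract_model

model = wrap_model(
    torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)),
    example_input=torch.randn(1, 3, 32, 32),
    trace_parameters=True,
)
graph = model.nncf.get_graph()
conv = next(n for n in graph.get_all_nodes() if n.node_type == "conv2d")
bn = graph.get_next_nodes(conv)[0]

# Same input and output node: only the convolution is extracted.
conv_only = extract_model(model, [conv.node_name], [conv.node_name])
# Batch norm as the output node: try_to_fuse_conv wraps conv + BN in an nn.Sequential.
conv_bn = extract_model(model, [conv.node_name], [bn.node_name])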
56 changes: 31 additions & 25 deletions nncf/torch/model_graph_manager.py
@@ -65,7 +65,6 @@ def get_const_node(node: NNCFNode, port_id: int, graph: NNCFGraph) -> Optional[N
:param node: The NNCF node for which to find the constant input node.
:param port_id: The ID of the input port to consider.
:param graph: The NNCF graph containing the nodes.
:return: The NNCF node providing the constant input to the specified port, or None if no such node is found.
"""
for prev_node in graph.get_previous_nodes(node):
@@ -79,9 +78,9 @@ def get_const_node(node: NNCFNode, port_id: int, graph: NNCFGraph) -> Optional[N

def split_const_name(const_name: str) -> Tuple[str, str]:
"""
Splits a constant name into module and attribute names.
Splits the constant name into module and attribute names.
:param weight_name: The full name of the constant, including module and attribute.
:param const_name: The full name of the constant, including module and attribute.
:return:
- module_name: The name of the module containing the constant.
- weight_attr_name: The name of the constant attribute within the module.
@@ -115,7 +114,7 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod
return curr_module
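A small sketch of how these two helpers work together. The dotted constant name ("0.weight" for the first child's weight) is an assumption about how constants are named in the traced graph:

import torch
from nncf.torch.model_graph_manager import get_module_by_name, split_const_name

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
module_name, attr_name = split_const_name("0.weight")  # expected ("0", "weight")
owner = get_module_by_name(module_name, model)
# owner is expected to be model[0]; the constant itself is then the attribute on it
weight = getattr(owner, attr_name)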


def get_const_data(const_node: NNCFNode, model: NNCFNetwork):
def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor:
"""
Retrieves a constant tensor associated with a given node.
@@ -134,7 +133,7 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork):

def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor:
"""
Retrieves a constant tensor associated with a given node and port in an NNCF graph.
Retrieves a constant tensor associated with a given node and input port in an NNCF graph.
:param node: The node to retrieve the constant from.
:param port_id: The port id within the node that holds the constant.
@@ -150,11 +149,11 @@ def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) ->

def get_potential_fused_node(node_name: str, nncf_graph: NNCFGraph) -> Optional[NNCFNode]:
"""
Get next node that can contain fused bias in runtime.
Retrieves the next node in the NNCF graph that could be fused with the provided node during runtime optimization.
:param node_name: The node name.
:param nncf_graph: The NNCF graph.
:return: The node that can be fused or None.
:return: The node that can be fused or None if no suitable node is found.
"""
target_node = nncf_graph.get_node_by_name(node_name)

@@ -188,7 +187,7 @@ def is_node_with_fused_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:

def get_fused_bias_value(node: NNCFNode, model: NNCFNetwork) -> Optional[torch.Tensor]:
"""
Returns the bias tensor for the node or potential fused node.
Returns the bias tensor for the node or for the potential fused node.
:param node: The node that corresponds to the operation with bias.
:param model: The model that contains this operation.
@@ -219,6 +218,13 @@ def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]:


def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: NNCFNetwork) -> None:
"""
Sets the constant data associated with a specific constant node in an NNCF network model.
:param data: The constant data tensor to be set.
:param const_node: The NNCF node representing the constant data.
:param model: The NNCF network model.
"""
const_name = const_node.layer_attributes.name
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
@@ -252,27 +258,27 @@ def set_const_data_to_port_id(data: torch.Tensor, node: NNCFNode, port_id: int,
setattr(module, const_attr_name, data)
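A read-modify-write sketch built on the constant helpers above. The wrap_model call with trace_parameters=True, and port 1 being the weight input of a functional conv2d, are assumptions rather than something this diff states:

import torch
from nncf.torch import wrap_model
from nncf.torch.model_graph_manager import get_const_data, get_const_node, set_const_data

model = wrap_model(
    torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)),
    example_input=torch.randn(1, 3, 32, 32),
    trace_parameters=True,
)
graph = model.nncf.get_graph()
conv = next(n for n in graph.get_all_nodes() if n.node_type == "conv2d")

weight_node = get_const_node(conv, 1, graph)      # constant feeding the assumed weight port
weight = get_const_data(weight_node, model)       # torch.Tensor with the current values
set_const_data(weight * 0.5, weight_node, model)  # write a scaled copy back in place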


def _find_fq_node_in_constant_subgraph(node: NNCFNode, graph: NNCFGraph) -> Optional[NNCFNode]:
def _find_fq_node_in_constant_subgraph(start_node: NNCFNode, graph: NNCFGraph) -> Optional[NNCFNode]:
"""
Finds a fake quantize node within a constant subgraph.
:param node: The starting node within the subgraph.
:param start_node: The starting node within the subgraph.
:param graph: The NNCFGraph containing the subgraph.
:return: The found fake quantize node, or None if not found.
:return: The found fake quantize node, or None if not found.
"""
if node.metatype == om.PTNoopMetatype:
prev_nodes = graph.get_previous_nodes(node)
if start_node.metatype == om.PTNoopMetatype:
prev_nodes = graph.get_previous_nodes(start_node)
if len(prev_nodes) != 1:
return None
return find_const_node_in_constant_subgraph(prev_nodes[0], graph)
if node.node_type in om.QUANTIZE_NODE_TYPES:
return node
return _find_fq_node_in_constant_subgraph(prev_nodes[0], graph)
if start_node.node_type in om.QUANTIZE_NODE_TYPES:
return start_node
return None


def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Check that module have fake_quantizer for weight.
Check that the module has a fake_quantizer for its weights.
:param node: The target node.
:param nncf_graph: The NNCF graph.
@@ -300,15 +306,15 @@ def get_fake_quantizer(

address_map = model.nncf.get_node_to_op_address_mapping()
op_addr = address_map[node.node_name]

if port_id is not None:
id = PreHookId(op_address=op_addr, input_port_id=port_id)
for call_hook in model.nncf._compressed_context._pre_hooks.get(id, {}).values():
if isinstance(call_hook, ExternalQuantizerCallHook):
storage = getattr(model.nncf, call_hook._storage_name)
return storage[call_hook._storage_key]
hook_container = model.nncf._compressed_context._pre_hooks.get(id, {})
else:
for call_hook in model.nncf._compressed_context._post_hooks.get(op_addr, {}).values():
if isinstance(call_hook, ExternalQuantizerCallHook):
storage = getattr(model.nncf, call_hook._storage_name)
return storage[call_hook._storage_key]
hook_container = model.nncf._compressed_context._post_hooks.get(op_addr, {})

for call_hook in hook_container.values():
if isinstance(call_hook, ExternalQuantizerCallHook):
storage = getattr(model.nncf, call_hook._storage_name)
return storage[call_hook._storage_key]
return None
