From d43890f93c69efcc0a132bf47182e58e60e50f93 Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Fri, 29 Sep 2023 18:15:20 +0200 Subject: [PATCH] Initial commit --- .../common/tensor_statistics/collectors.py | 5 +- nncf/openvino/graph/model_transformer.py | 76 +++--- nncf/openvino/graph/node_utils.py | 12 + .../graph/transformations/commands.py | 8 +- nncf/openvino/statistics/collectors.py | 2 +- .../algorithms/bias_correction/algorithm.py | 223 +++++++++--------- .../algorithms/bias_correction/backend.py | 25 +- .../bias_correction/onnx_backend.py | 4 - .../bias_correction/openvino_backend.py | 25 +- tests/openvino/native/test_bias_correction.py | 21 +- .../test_templates/test_bias_correction.py | 2 +- 11 files changed, 188 insertions(+), 215 deletions(-) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 3655fffe5d6..ff26c5e71de 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -185,12 +185,13 @@ class TensorCollector: a dict could be collected by `get_statistics` call. """ - def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> None: + def __init__(self, statistic_container: Optional[TensorStatistic] = None, skip_empty_stats: Optional[bool] = True) -> None: self._reducers: Set[TensorReducerBase] = set() self._aggregators: Dict[Tuple[int, int], TensorAggregatorBase] = {} self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {} self._stat_container = statistic_container self._enabled = True + self._skip_empty_stats = skip_empty_stats @property def num_samples(self) -> Optional[int]: @@ -279,7 +280,7 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None: for reducer in self._reducers: reducer_hash = hash(reducer) input_ = inputs[reducer_hash] - if any(tensor.is_empty() for tensor in input_): + if any(tensor.is_empty() for tensor in input_) and self._skip_empty_stats: continue reduced_inputs[reducer_hash] = reducer(input_) diff --git a/nncf/openvino/graph/model_transformer.py b/nncf/openvino/graph/model_transformer.py index 19e43f4b131..929dc87fd5a 100644 --- a/nncf/openvino/graph/model_transformer.py +++ b/nncf/openvino/graph/model_transformer.py @@ -22,6 +22,7 @@ from nncf.common.graph.model_transformer import TModel from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.openvino.graph.node_utils import get_parameter_node_name from nncf.openvino.graph.node_utils import get_result_node_name from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand @@ -68,25 +69,6 @@ def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]: """ return {op.get_friendly_name(): op for op in model.get_ops()} - @staticmethod - def _get_activation_node_names(model: ov.Model) -> List[str]: - """ - Returns list of the activation node names. - - :param model: Model to get list. - :return: List with the activation names. 
- """ - activation_nodes = set() - nodes_queue = deque(model.get_parameters()) - while nodes_queue: - node = nodes_queue.popleft() - if node.name in activation_nodes: - continue - activation_nodes.add(node.name) - for node_output in node.outputs(): - nodes_queue.extend([i.get_node() for i in node_output.get_target_inputs()]) - return list(activation_nodes) - @staticmethod def _update_tensor_name(tensors: List[DescriptorTensor], name: str) -> None: """ @@ -389,38 +371,36 @@ def _apply_model_extraction_transformation( """ transformation = transformations[-1] name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model) - activation_node_names = OVModelTransformer._get_activation_node_names(model) params, results = [], [] - for input_name in transformation.inputs: - input_node = name_to_node_mapping[input_name] - if input_name in [tensor.node.get_friendly_name() for tensor in model.inputs]: - params.append(input_node) - continue - for input_port in input_node.inputs(): - if input_port.get_source_output().get_node().name not in activation_node_names: - continue - input_node_output = input_port.get_source_output() - parameter_name = f"Parameter_{input_name}" - new_param = opset.parameter( - shape=input_node_output.partial_shape, - dtype=input_node_output.get_element_type(), - name=parameter_name, - ) - input_port.replace_source_output(new_param.output(0)) - new_param_tensors = [o.get_tensor() for o in new_param.outputs()] - OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name) - params.append(new_param) - for output_name in transformation.outputs: + for node_id in transformation.inputs: + node_name, port_id = node_id + node = name_to_node_mapping[node_name] + + node_input_port = node.input(port_id) + source_output_port = node_input_port.get_source_output() + + parameter_name = get_parameter_node_name(node_name, port_id) + new_param = opset.parameter( + shape=source_output_port.partial_shape, + dtype=source_output_port.get_element_type(), + name=parameter_name, + ) + + node_input_port.replace_source_output(new_param.output(0)) + + new_param_tensors = [o.get_tensor() for o in new_param.outputs()] + OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name) + params.append(new_param) + + for output_id in transformation.outputs: + output_name, port_id = output_id output_node = name_to_node_mapping[output_name] - for node_out in output_node.outputs(): - result_name = get_result_node_name(output_name, 0) - new_result = opset.result(node_out, name=result_name) - OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_name) - results.append(new_result) - - if not results: - results = model.get_results() + node_out = output_node.output(port_id) + result_name = get_result_node_name(output_name, port_id) + new_result = opset.result(node_out, name=result_name) + OVModelTransformer._update_tensor_name([new_result.get_output_tensor(port_id)], result_name) + results.append(new_result) return ov.Model(results, params) diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py index 9c9d41137cf..650488126a6 100644 --- a/nncf/openvino/graph/node_utils.py +++ b/nncf/openvino/graph/node_utils.py @@ -148,6 +148,18 @@ def get_result_node_name(output_name: str, port_id: int) -> str: return f"Result_{output_name}.{port_id}" +def get_parameter_node_name(parameter_name: str, port_id: int) -> str: + """ + Returns name of Parameter based on node name and its port. + + :param parameter_name: Node name. 
+ :param port_id: Node port. + :return: Name of result. + """ + + return f"Parameter_{parameter_name}.{port_id}" + + def get_ov_model_reduce_node_name(output_name: str, reduce_node_name: str, port_id: int) -> str: """ Returns name of reduce node based on output name, node type and port id. diff --git a/nncf/openvino/graph/transformations/commands.py b/nncf/openvino/graph/transformations/commands.py index 491515aa0f5..ced400920a0 100644 --- a/nncf/openvino/graph/transformations/commands.py +++ b/nncf/openvino/graph/transformations/commands.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Tuple import numpy as np import openvino.runtime as ov @@ -134,10 +134,10 @@ class OVModelExtractionCommand(Command): Extracts sub-graph based on the sub-model input and output names. """ - def __init__(self, inputs: List[str], outputs: List[str]): + def __init__(self, inputs: List[Tuple[str, int]], outputs: List[Tuple[str, int]]): """ - :param inputs: List of the input names that denote the sub-graph beginning. - :param outputs: List of the output names that denote the sub-graph ending. + :param inputs: List of the input ids that denote the sub-graph beginning. + :param outputs: List of the output ids that denote the sub-graph ending. """ super().__init__(TransformationType.EXTRACT) self.inputs = inputs diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 61ef776fab7..2a79f364db2 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -269,7 +269,7 @@ def get_raw_stat_collector(num_samples, inplace=False): reducer = OVNoopReducer() aggregator = NoopAggregator(num_samples) - collector = TensorCollector(OVRawTensorStatistic) + collector = TensorCollector(OVRawTensorStatistic, skip_empty_stats=False) collector.register_statistic_branch(OVRawTensorStatistic.VALUES_STATS, reducer, aggregator) return collector diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index c3e1e10a6f6..113f87a2e81 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -33,6 +33,7 @@ from nncf.common.utils.backend import get_backend from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.bias_correction.backend import ALGO_BACKENDS +from nncf.quantization.passes import filter_constant_nodes TModel = TypeVar("TModel") @@ -202,7 +203,7 @@ def apply( # Also, we need to remove unnecessary statistics that we don't need anymore, # to reduce memory usage during the algorithm's pipeline. - self._remove_unnecessary_stats(position, subgraphs_data) + # self._remove_unnecessary_stats(position, subgraphs_data) return main_model_transformer.transform(main_transformations_layout) @@ -216,69 +217,73 @@ def _get_subgraph_data_for_node(self, node: NNCFNode, nncf_graph: NNCFGraph) -> :param nncf_graph: NNCFGraph instance for graph analysis. :return: A dict with the list of the nodes for the subgraph input and statistics collection. 
""" - statistic_nodes, subgraph_input_nodes, subgraph_output_nodes, subgraph_output_ids = [], [], [], [] + statistic_nodes = [] + subgraph_output_ids, subgraph_input_ids = [], [] - def fill_statistic_nodes(node): + def fill_subgraph_output_ids(edge): + from_node, to_node = edge.from_node, edge.to_node # A small hack to speed up graph traversal. - if node in statistic_nodes or node in visited_nodes: + if to_node in visited_nodes: return - visited_nodes.append(node) + visited_nodes.append(to_node) # If we found a node with bias, we have to collect it as a statistic node, - # and its input for _collected_stat_inputs_map, + # and ID of it's precessor (node_name, input_port_id) into _collected_statistics_ids, # which will be used during the collection of statistics for the next node. - if self._backend_entity.is_node_with_bias(node, nncf_graph) and self._backend_entity.is_quantized_weights( - node, nncf_graph - ): - statistic_nodes.append(node) - activation_node, output_port_id = self._get_activation_node_and_port(node, nncf_graph) - subgraph_output_nodes.append(activation_node) + if self._backend_entity.is_node_with_bias( + to_node, nncf_graph + ) and self._backend_entity.is_quantized_weights(to_node, nncf_graph): + activation_id = (from_node.node_name, edge.output_port_id) + node_id = (to_node.node_name, edge.input_port_id) + subgraph_output_ids.append(activation_id) + + self._collected_stat_inputs_map[node_id] = activation_id - output_id = (activation_node.node_name, output_port_id) - subgraph_output_ids.append(output_id) - self._collected_stat_inputs_map[node.node_name] = output_id + statistic_nodes.append(to_node) return - for next_node in nncf_graph.get_next_nodes(node): - fill_statistic_nodes(next_node) + for output_edge in nncf_graph.get_output_edges(to_node): + fill_subgraph_output_ids(output_edge) - def fill_subgraph_input_nodes(node): - # A small hack to speed up graph traversal. - if node in subgraph_input_nodes or node in visited_nodes: - return - visited_nodes.append(node) + def fill_subgraph_input_ids(edge): + from_node, to_node = edge.from_node, edge.to_node + node_id = (to_node.node_name, edge.input_port_id) # Since we need to find the inputs for the subgraph, # we can take only those layers for which we have already collected statistics. - if node.node_name in self._collected_stat_inputs_map and node not in statistic_nodes: - subgraph_input_nodes.append(node) + if node_id in self._collected_stat_inputs_map: + activation_id = self._collected_stat_inputs_map[node_id] + if activation_id not in subgraph_output_ids: + subgraph_input_ids.append(node_id) + return + + if from_node in visited_nodes: return + visited_nodes.append(from_node) - for previous_node in nncf_graph.get_previous_nodes(node): - fill_subgraph_input_nodes(previous_node) + for input_edge in nncf_graph.get_input_edges(from_node): + fill_subgraph_input_ids(input_edge) # First, we need to find out the nodes with bias that follow by main node. # To collect statistics for next nodes. visited_nodes = [] - for next_node in nncf_graph.get_next_nodes(node): - fill_statistic_nodes(next_node) + for output_edge in nncf_graph.get_output_edges(node): + fill_subgraph_output_ids(output_edge) # We then need to find nodes for which statistics have already been collected, # to use them as inputs for the subgraph. 
- statistic_nodes = statistic_nodes if statistic_nodes else nncf_graph.get_next_nodes(node) + statistic_nodes = statistic_nodes if statistic_nodes else [node] visited_nodes = [] for stat_node in statistic_nodes: - fill_subgraph_input_nodes(stat_node) + for input_edge in nncf_graph.get_input_edges(stat_node): + fill_subgraph_input_ids(input_edge) + + if not subgraph_output_ids: + for edge in nncf_graph.get_output_edges(node): + subgraph_output_ids.append((edge.to_node.node_name, OUTPUT_PORT_OF_NODE)) - # In case the outputs were not found during the collection of statistics nodes, - # we use the latter as the outputs of the subgraph. - subgraph_output_nodes = subgraph_output_nodes if subgraph_output_nodes else statistic_nodes - subgraph_output_names = [ - n.node_name for n in subgraph_output_nodes if NNCFGraphNodeType.OUTPUT_NODE not in n.node_name - ] subgraph_data = { - "subgraph_input_names": set(n.node_name for n in subgraph_input_nodes), - "subgraph_output_names": set(subgraph_output_names), + "subgraph_input_ids": set(subgraph_input_ids), "subgraph_output_ids": set(subgraph_output_ids), } @@ -295,7 +300,7 @@ def _prepare_subgraph(self, node: NNCFNode, model: TModel, nncf_graph: NNCFGraph :return: Backend-specific subgraph extracted from the model. """ extracted_model = self.extract_model( - model, subgraph_data["subgraph_input_names"], subgraph_data["subgraph_output_names"] + model, subgraph_data["subgraph_input_ids"], subgraph_data["subgraph_output_ids"] ) transformation_layout = TransformationLayout() @@ -324,20 +329,22 @@ def _create_feed_dicts( statistics_size = self.subset_size statistics_per_input = {} - for input_node_name in subgraph_data["subgraph_input_names"]: - input_tensor_name = self._backend_entity.get_input_name(model, input_node_name) - activation_name, port_id = self._collected_stat_inputs_map[input_node_name] - input_fp = self._get_fp_inputs(statistic_points, node_name=activation_name, port_id=port_id) + for input_id in subgraph_data["subgraph_input_ids"]: + input_name, input_port_id = input_id + input_tensor_name = self._backend_entity.get_input_name(model, input_name, input_port_id) + activation_name, output_port_id = self._collected_stat_inputs_map[input_id] + input_fp = self._get_fp_inputs(statistic_points, node_name=activation_name, port_id=output_port_id) statistics_per_input[input_tensor_name] = input_fp statistics_size = min(statistics_size, len(input_fp)) for stat_id in range(statistics_size): feed_dict = {} - for input_node_name in subgraph_data["subgraph_input_names"]: - input_tensor_name = self._backend_entity.get_input_name(model, input_node_name) + for input_id in subgraph_data["subgraph_input_ids"]: + input_name, input_port_id = input_id + input_tensor_name = self._backend_entity.get_input_name(model, input_name, input_port_id) # Since we do not use as inputs the layers from which the statistics are gathered, # but those that follow them, we need to take this into account when creating feed dicts. 
-                activation_name, port_id = self._collected_stat_inputs_map[input_node_name]
+                activation_name, _ = self._collected_stat_inputs_map[input_id]
                 feed_dict[input_tensor_name] = statistics_per_input[input_tensor_name][stat_id]
             feed_dicts.append(feed_dict)
         return feed_dicts
@@ -422,11 +429,11 @@ def _remove_unnecessary_stats(self, position: int, subgraphs_data: Dict[str, Dic
         needed_stats_list = []
         for i in range(position + 1, len(subgraphs_data)):
             input_names = subgraphs_data[i]["subgraph_input_names"]
-            needed_stats_list.extend([self._collected_stat_inputs_map[name][0] for name in input_names])
+            needed_stats_list.extend([self._collected_statistics_ids[name][0] for name in input_names])

         node_inputs_name = subgraphs_data[position]["subgraph_input_names"]
         for node_input_name in node_inputs_name:
-            activation_name, port_id = self._collected_stat_inputs_map[node_input_name]
+            activation_name, port_id = self._collected_statistics_ids[node_input_name]
             input_id = (activation_name, port_id)
             if activation_name not in needed_stats_list and input_id in self._fp_inputs:
                 nncf_logger.debug(f"Dropped {activation_name} output statistics.")
@@ -460,7 +467,10 @@ def input_filter_func(point):
         for tensor_collector in statistic_points.get_algo_statistics_for_node(
             node_name, input_filter_func, self._algorithm_key
         ):
-            input_fp.extend(tensor_collector.get_statistics().values)
+            statistics_value = tensor_collector.get_statistics().values
+            if statistics_value is None:
+                raise RuntimeError(f"Statistics were not collected for the node {node_name}")
+            input_fp.extend(statistics_value)
         self._fp_inputs[input_id] = input_fp
         return self._fp_inputs[input_id]

@@ -488,18 +498,18 @@ def output_filter_func(point):
     def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
         self._set_backend_entity(model)
-        model_copy = self._backend_entity.remove_fq_from_inputs(copy_model(model), graph)
-        graph_copy = NNCFGraphFactory.create(model_copy)
-        model_copy = self._backend_entity.insert_null_biases(model_copy, graph_copy)
-        nncf_graph = NNCFGraphFactory.create(model_copy)
+        inference_graph = NNCFGraphFactory.create(model)
+        inference_graph = filter_constant_nodes(inference_graph)

         statistic_container = StatisticPointsContainer()

         nodes_with_bias = [
-            node for node in nncf_graph.topological_sort() if self._backend_entity.is_node_with_bias(node, nncf_graph)
+            node for node in graph.topological_sort() if self._backend_entity.is_node_with_bias(node, graph)
         ]
-        model_inputs = nncf_graph.get_input_nodes()
+        model_inputs = graph.get_input_nodes()
+        biased_after_input_nodes = self._get_biased_after_nodes(graph, model_inputs)

         # Collection of statistics after layers where biases will be corrected.
+        # These floating-point statistics are needed for the bias shift calculation.
         for node in nodes_with_bias:
             node_name = node.node_name
             channel_axis = node.metatype.output_channel_axis
@@ -517,69 +527,66 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
                 )
             )

-        # We must collect the nodes with biases following the model inputs.
-        biased_after_input_nodes = self._get_biased_after_nodes(nncf_graph, model_inputs, model_copy)
-
+        # We must collect the nodes with biases following the model inputs using a greedy approach.
         for biased_after_input_node in biased_after_input_nodes:
             # We need to collect activation input to register it for the biased layer as the layer with statistics.
- activation_node, output_port_id = self._get_activation_node_and_port(biased_after_input_node, nncf_graph) - activation_node_name = activation_node.node_name + for edge in inference_graph.get_input_edges(biased_after_input_node): + activation_name = edge.from_node.node_name + output_port_id = edge.output_port_id + activation_id = (activation_name, output_port_id) + node_id = (edge.to_node.node_name, edge.input_port_id) - self._collected_stat_inputs_map[biased_after_input_node.node_name] = (activation_node_name, output_port_id) - statistic_point = self._backend_entity.target_point( - TargetType.POST_LAYER_OPERATION, activation_node_name, port_id=output_port_id - ) - stat_collector = self._backend_entity.raw_statistic_collector( - num_samples=self.subset_size, inplace=self.inplace_statistics - ) - statistic_container.add_statistic_point( - StatisticPoint( - target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key + self._collected_stat_inputs_map[node_id] = activation_id + + if activation_name in statistic_container: + continue + + statistic_point = self._backend_entity.target_point( + TargetType.POST_LAYER_OPERATION, activation_name, port_id=output_port_id + ) + stat_collector = self._backend_entity.raw_statistic_collector( + num_samples=self.subset_size, inplace=self.inplace_statistics + ) + statistic_container.add_statistic_point( + StatisticPoint( + target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key + ) ) - ) - # Then we need also to collect model input statistics to prevent cases when nodes with bias have no input data. + # Then we need also to collect model input statistics to prevent cases when subgraph inputs have no input data. for input_node in model_inputs: - # We assume that input node has only one output port - input_name = input_node.node_name - if input_name in statistic_container: - continue - for next_layer in nncf_graph.get_next_nodes(input_node): - self._collected_stat_inputs_map[next_layer.node_name] = (input_node.node_name, OUTPUT_PORT_OF_NODE) - statistic_point = self._backend_entity.target_point( - TargetType.POST_LAYER_OPERATION, input_node.node_name, port_id=OUTPUT_PORT_OF_NODE - ) - stat_collector = self._backend_entity.raw_statistic_collector( - num_samples=self.subset_size, inplace=self.inplace_statistics - ) - statistic_container.add_statistic_point( - StatisticPoint( - target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key - ) - ) + # We assume that input node has only one output port - 0. + for edge in inference_graph.get_output_edges(input_node): + activation_name = edge.from_node.node_name + output_port_id = edge.output_port_id + activation_id = (activation_name, output_port_id) + node_id = (edge.to_node.node_name, edge.input_port_id) - return statistic_container + self._collected_stat_inputs_map[node_id] = activation_id - def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]: - """ - This method returns the activation layer and corresponding port id for the node. + if activation_name in statistic_container: + continue - :param node: NNCFGraph node for which the activation is sought. - :param nncf_graph: NNCFGraph instance with the node. - :return: Tuple with the activation node and port id. 
- """ - activation_port = self._backend_entity.get_activation_port_id(node, nncf_graph) - activation_node = nncf_graph.get_input_edges(node)[activation_port].from_node - port_id = nncf_graph.get_edge(activation_node, node).output_port_id - return activation_node, port_id + statistic_point = self._backend_entity.target_point( + TargetType.POST_LAYER_OPERATION, activation_name, port_id=output_port_id + ) + stat_collector = self._backend_entity.raw_statistic_collector( + num_samples=self.subset_size, inplace=self.inplace_statistics + ) + statistic_container.add_statistic_point( + StatisticPoint( + target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key + ) + ) + + return statistic_container - def _get_biased_after_nodes(self, nncf_graph: NNCFGraph, nodes: List[NNCFNode], model: TModel) -> List[NNCFNode]: + def _get_biased_after_nodes(self, nncf_graph: NNCFGraph, nodes: List[NNCFNode]) -> List[NNCFNode]: """ This method finds and returns nodes with the bias in the model that follows after the input nodes. :param nncf_graph: NNCFGraph instance. :param nodes: List of the model inputs as NNCFNodes. - :param model: TModel instance. :return: List of the nodes with bias. """ @@ -617,17 +624,19 @@ def traverse_to_biased(node, condition_container): return list(biased_nodes - dependant_nodes) - def extract_model(self, model: TModel, input_node_names: List[str], output_node_names: List[str]) -> TModel: + def extract_model( + self, model: TModel, input_node_ids: List[Tuple[str, int]], output_node_ids: List[Tuple[str, int]] + ) -> TModel: """ Returns the backend-specific model that bounded by the specified input & output layers. :param model: Backend-specific model. - :param input_node_names: List with the input node names. - :param output_node_names: List with the output node names. + :param input_node_ids: List with the input node ids. + :param output_node_ids: List with the output node ids. :return: Extracted backend-specific model. """ transformation_layout = TransformationLayout() model_transformer = ModelTransformerFactory.create(model) - model_extraction_command = self._backend_entity.model_extraction_command(input_node_names, output_node_names) + model_extraction_command = self._backend_entity.model_extraction_command(input_node_ids, output_node_ids) transformation_layout.register(model_extraction_command) return model_transformer.transform(transformation_layout) diff --git a/nncf/quantization/algorithms/bias_correction/backend.py b/nncf/quantization/algorithms/bias_correction/backend.py index a85f2fffb0a..05795b2fa96 100644 --- a/nncf/quantization/algorithms/bias_correction/backend.py +++ b/nncf/quantization/algorithms/bias_correction/backend.py @@ -11,7 +11,7 @@ from abc import ABC from abc import abstractmethod -from typing import List, Optional, TypeVar +from typing import List, Optional, Tuple, TypeVar import numpy as np @@ -64,12 +64,14 @@ def create_bias_correction_command(node: NNCFNode, bias_value: np.ndarray) -> Tr @staticmethod @abstractmethod - def model_extraction_command(inputs: List[str], outputs: List[str]) -> TransformationCommand: + def model_extraction_command( + inputs: List[Tuple[str, int]], outputs: List[Tuple[str, int]] + ) -> TransformationCommand: """ - Returns backend-specific command to extract sub-model based on input & output names. + Returns backend-specific command to extract sub-model based on input & output ids. - :param inputs: List of the input names for sub-model beginning. 
- :param outputs: List of the output names for sub-model end. + :param inputs: List of the input ids for sub-model beginning. + :param outputs: List of the output ids for sub-model end. :return: Backend-specific TransformationCommand for the model extraction. """ @@ -125,19 +127,6 @@ def process_model_output(raw_data: OutputType, output_name: str) -> NNCFTensor: :return: Processed output as NNCFTensor. """ - @staticmethod - @abstractmethod - def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: - """ - Returns input port id corresponding to activation input edge for - the node. - Supports only nodes that could have bias value. - - :param node: Node of NNCFGraph with bias value. - :param nncf_graph: NNCFGraph instance with the node. - :return: boolean port id. - """ - @staticmethod @abstractmethod def get_bias_value(node: NNCFNode, model: TModel, nncf_graph: NNCFGraph) -> np.ndarray: diff --git a/nncf/quantization/algorithms/bias_correction/onnx_backend.py b/nncf/quantization/algorithms/bias_correction/onnx_backend.py index 0e9ad720a10..d0bbb4f51b8 100644 --- a/nncf/quantization/algorithms/bias_correction/onnx_backend.py +++ b/nncf/quantization/algorithms/bias_correction/onnx_backend.py @@ -92,10 +92,6 @@ def raw_statistic_collector(inplace: bool, num_samples: int = None) -> ONNXMeanS def process_model_output(raw_data: Dict, output_name: str) -> ONNXNNCFTensor: return ONNXNNCFTensor(raw_data[output_name]) - @staticmethod - def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[int, int]: - return 0 - @staticmethod def get_bias_value(node: NNCFNode, model: onnx.ModelProto, nncf_graph: NNCFGraph) -> np.ndarray: return get_bias_value(node, model) diff --git a/nncf/quantization/algorithms/bias_correction/openvino_backend.py b/nncf/quantization/algorithms/bias_correction/openvino_backend.py index c57f9df6a20..632163fdc2b 100644 --- a/nncf/quantization/algorithms/bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/bias_correction/openvino_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import numpy as np import openvino.runtime as ov @@ -56,7 +56,9 @@ def create_bias_correction_command( return OVCommandCreator.create_command_to_update_bias(node, bias_value, nncf_graph) @staticmethod - def model_extraction_command(inputs: List[str], outputs: List[str]) -> OVModelExtractionCommand: + def model_extraction_command( + inputs: List[Tuple[str, int]], outputs: List[Tuple[str, int]] + ) -> OVModelExtractionCommand: return OVModelExtractionCommand(inputs, outputs) @staticmethod @@ -80,21 +82,12 @@ def raw_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCol def process_model_output(raw_data: Dict, output_name: str) -> OVNNCFTensor: return OVNNCFTensor(raw_data[output_name]) - @staticmethod - def get_activation_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: - constant_ports = node.layer_attributes.get_const_port_ids() - activation_ports = [ - e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in constant_ports - ] - assert len(activation_ports) == 1 - return activation_ports[0] - @staticmethod def get_bias_value(node: NNCFNode, model: ov.Model, nncf_graph: NNCFGraph) -> np.ndarray: return get_bias_value(node, nncf_graph, model) @staticmethod - def get_input_name(model: ov.Model, node_name: str) -> str: + def get_input_name(model: ov.Model, node_name: str, port_id: int) -> str: ops_dict = {op.get_friendly_name(): op for op in model.get_ops()} model_input_names = [] @@ -103,10 +96,10 @@ def get_input_name(model: ov.Model, node_name: str) -> str: if node_name in model_input_names: return node_name - for input_port in ops_dict[node_name].inputs(): - input_node = input_port.get_source_output().get_node() - if input_node.get_type_name() == "Parameter": - return input_port.get_tensor().get_any_name() + input_port = ops_dict[node_name].input(port_id) + input_node = input_port.get_source_output().get_node() + if input_node.get_type_name() == "Parameter": + return input_port.get_tensor().get_any_name() raise RuntimeError(f"Input layer not found for {node_name}") @staticmethod diff --git a/tests/openvino/native/test_bias_correction.py b/tests/openvino/native/test_bias_correction.py index 6b42f11d4e2..e1f4589d6bb 100644 --- a/tests/openvino/native/test_bias_correction.py +++ b/tests/openvino/native/test_bias_correction.py @@ -91,8 +91,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: { "collected_inputs": {"/conv_1/Conv/WithoutBiases": ("input.1", 0)}, "subgraph_data": { - "subgraph_input_names": {"/conv_1/Conv/WithoutBiases"}, - "subgraph_output_names": {"/maxpool_1/MaxPool", "/Split"}, + "subgraph_input_ids": {"/conv_1/Conv/WithoutBiases"}, "subgraph_output_ids": {("/Split", 0), ("/maxpool_1/MaxPool", 0), ("/Split", 1)}, }, }, @@ -107,8 +106,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/conv_6/Conv/WithoutBiases": ("/Split", 1), }, "subgraph_data": { - "subgraph_input_names": {"/conv_2/Conv/WithoutBiases"}, - "subgraph_output_names": {"/Relu_1"}, + "subgraph_input_ids": {"/conv_2/Conv/WithoutBiases"}, "subgraph_output_ids": {("/Relu_1", 0)}, }, }, @@ -124,8 +122,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/conv_6/Conv/WithoutBiases": ("/Split", 1), }, "subgraph_data": { - "subgraph_input_names": {"/conv_1/Conv/WithoutBiases", "/conv_3/Conv/WithoutBiases"}, - "subgraph_output_names": {"/Split"}, + "subgraph_input_ids": {"/conv_1/Conv/WithoutBiases", "/conv_3/Conv/WithoutBiases"}, 
"subgraph_output_ids": {("/Split", 0), ("/Split", 1)}, }, }, @@ -138,8 +135,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/conv_6/Conv/WithoutBiases": ("/Split", 1), }, "subgraph_data": { - "subgraph_input_names": {"/conv_4/Conv/WithoutBiases"}, - "subgraph_output_names": {"/Relu_2"}, + "subgraph_input_ids": {"/conv_4/Conv/WithoutBiases"}, "subgraph_output_ids": {("/Relu_2", 0)}, }, }, @@ -152,8 +148,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/conv_6/Conv/WithoutBiases": ("/Split", 1), }, "subgraph_data": { - "subgraph_input_names": {"/conv_5/Conv/WithoutBiases", "/conv_6/Conv/WithoutBiases"}, - "subgraph_output_names": {"/Add_3", "/Concat"}, + "subgraph_input_ids": {"/conv_5/Conv/WithoutBiases", "/conv_6/Conv/WithoutBiases"}, "subgraph_output_ids": {("/Add_3", 0), ("/Concat", 0)}, }, }, @@ -167,12 +162,11 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/conv_10/Conv/WithoutBiases": ("/Concat", 0), }, "subgraph_data": { - "subgraph_input_names": { + "subgraph_input_ids": { "/conv_8/Conv/WithoutBiases", "/conv_9/Conv/WithoutBiases", "/conv_10/Conv/WithoutBiases", }, - "subgraph_output_names": {"/Concat_1"}, "subgraph_output_ids": {("/Concat_1", 0)}, }, }, @@ -184,8 +178,7 @@ def check_bias(model: ov.Model, ref_biases: Dict) -> None: "/MatMul": ("/Reshape", 0), }, "subgraph_data": { - "subgraph_input_names": {"/MatMul"}, - "subgraph_output_names": {"/Reshape_1", "/Add_4"}, + "subgraph_input_ids": {"/MatMul"}, "subgraph_output_ids": {("/Reshape_1", 0), ("/Add_4", 0)}, }, }, diff --git a/tests/post_training/test_templates/test_bias_correction.py b/tests/post_training/test_templates/test_bias_correction.py index 68c72301707..bced2affc78 100644 --- a/tests/post_training/test_templates/test_bias_correction.py +++ b/tests/post_training/test_templates/test_bias_correction.py @@ -174,7 +174,7 @@ def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref bc_algo._set_backend_entity(quantized_test_model) node = nncf_graph.get_node_by_name(layer_name) - bc_algo._collected_stat_inputs_map.update(ref_data["collected_inputs"]) + bc_algo._collected_statistics_ids.update(ref_data["collected_inputs"]) subgraph_data = bc_algo._get_subgraph_data_for_node(node, nncf_graph) ref_subgraph_data = ref_data["subgraph_data"]