Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve the constant subgraph #2902

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions nncf/quantization/algorithms/accuracy_control/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func
from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset
from nncf.quantization.passes import remove_shapeof_subgraphs
from nncf.quantization.passes import find_shapeof_subgraphs

TModel = TypeVar("TModel")
TPModel = TypeVar("TPModel")
Expand Down Expand Up @@ -109,11 +109,13 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) ->
*self._algo_backend.get_start_nodes_for_activation_path_tracing(quantized_model_graph),
]

quantized_model_graph_without_shapeof = remove_shapeof_subgraphs(
deepcopy(quantized_model_graph),
shapeof_subgraphs = find_shapeof_subgraphs(
quantized_model_graph,
self._algo_backend.get_shapeof_metatypes(),
input_nodes,
)
quantized_model_graph_without_shapeof = deepcopy(quantized_model_graph)
quantized_model_graph_without_shapeof.remove_nodes_from(shapeof_subgraphs)

for quantizer_node in reversed(quantizers):
if processed.get(quantizer_node.node_name, False):
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/layerwise/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def schedule(
"""
# Initialize input nodes and create a copy of the graph for inference
input_nodes = graph.get_input_nodes()
inference_graph = transform_to_inference_graph(deepcopy(graph), input_nodes, [], [])
inference_graph = transform_to_inference_graph(deepcopy(graph), input_nodes, [], [], [])

steps = []
visited_map = {node: False for node in inference_graph.get_all_nodes()}
Expand Down
1 change: 1 addition & 0 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ def _find_quantization_target_points(
self._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.preserved_metatypes,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
Expand Down
8 changes: 8 additions & 0 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@


class MinMaxAlgoBackend(ABC):
@property
@abstractmethod
def preserved_metatypes(self) -> List[OperatorMetatype]:
"""
Property for backend-specific metatypes that require preserving float subgraphs
when removing the ShapeOf subgraph.
"""

@property
@abstractmethod
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@


class ONNXMinMaxAlgoBackend(MinMaxAlgoBackend):
@property
def preserved_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return MATMUL_METATYPES
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@


class OVMinMaxAlgoBackend(MinMaxAlgoBackend):
@property
def preserved_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype, om.OVLSTMSequenceMetatype]

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return [om.OVMatMulMetatype]
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@


class PTMinMaxAlgoBackend(MinMaxAlgoBackend):
@property
def preserved_metatypes(self) -> List[OperatorMetatype]:
return []

TARGET_TYPE_TO_PT_INS_TYPE_MAP = {
TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class FXMinMaxAlgoBackend(MinMaxAlgoBackend):
TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
}

@property
def preserved_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def mat_mul_metatypes(self) -> List[OperatorMetatype]:
return [om.PTLinearMetatype, om.PTMatMulMetatype]
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def apply(
"""
matches = []

inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [])
inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], [])
nx_graph = inference_nncf_graph.get_nx_graph_copy()
for _, pattern_graph in self._patterns.items():
matches.extend(find_subgraphs_matching_pattern(nx_graph, pattern_graph(), strict=False))
Expand Down
94 changes: 67 additions & 27 deletions nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import OperatorMetatype

TModel = TypeVar("TModel")
Expand All @@ -24,6 +25,7 @@ def transform_to_inference_graph(
input_nodes: List[NNCFNode],
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
preserved_metatypes: List[OperatorMetatype],
) -> NNCFGraph:
"""
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
Expand All @@ -34,27 +36,33 @@ def transform_to_inference_graph(
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:return: NNCFGraph in the inference style.
"""
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes)
filter_constant_nodes(nncf_graph, input_nodes)
shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes)
preserved_nodes = find_preserved_nodes(nncf_graph, shapeof_subgraphs, preserved_metatypes)
constant_subgraphs = find_constant_subgraphs(nncf_graph, input_nodes)

nodes_to_drop = set([*shapeof_subgraphs, *constant_subgraphs]).difference(preserved_nodes)
nncf_graph.remove_nodes_from(nodes_to_drop)

remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
return nncf_graph


def remove_shapeof_subgraphs(
def find_shapeof_subgraphs(
andrey-churkin marked this conversation as resolved.
Show resolved Hide resolved
nncf_graph: NNCFGraph,
shapeof_metatypes: List[OperatorMetatype],
input_nodes: List[NNCFNode],
) -> NNCFGraph:
) -> List[NNCFNode]:
"""
Removes the ShapeOf subgraphs from the provided NNCFGraph instance inplace.
Constant subgraph should be already removed from the given NNCFGraph.

:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
:param input_nodes: List of input nodes for the given NNCFGraph.
:return: NNCFGraph without ShapeOf subgraphs.
Returns a list of nodes belonging to ShapeOf subgraphs.

:param nncf_graph: The input graph to be analyzed.
:param shapeof_metatypes: A list of metatypes representing backend-specific
ShapeOf operations.
:param input_nodes: A list of nodes designated as graph inputs. These nodes are
used to identify which nodes depend on input data.
:return: A list of nodes belonging to ShapeOf subgraphs.
"""
nodes_to_drop = set()
shapeof_subgraphs = set()
shape_of_nodes = []
infer_nodes = []

Expand All @@ -70,21 +78,53 @@ def remove_shapeof_subgraphs(
nodes_queue.extend(nncf_graph.get_next_nodes(node))

for shape_of_node in shape_of_nodes:
nodes_to_drop.add(shape_of_node.node_name)
shapeof_subgraphs.add(shape_of_node)

shape_of_queue = collections.deque()
shape_of_queue.extend(nncf_graph.get_next_nodes(shape_of_node))
while shape_of_queue:
node = shape_of_queue.pop()
if node.node_name in nodes_to_drop or node.node_name in infer_nodes:
if node in shapeof_subgraphs or node.node_name in infer_nodes:
continue
nodes_to_drop.add(node.node_name)
shapeof_subgraphs.add(node)
# traverse forward and backward to exclude full shape of subgraph
# recursion excluded due to infer_nodes list around subgraph shape
shape_of_queue.extend(nncf_graph.get_next_nodes(node) + nncf_graph.get_previous_nodes(node))

nncf_graph.remove_nodes_from([nncf_graph.get_node_by_name(name) for name in nodes_to_drop])
return nncf_graph
return list(shapeof_subgraphs)


def find_preserved_nodes(
graph: NNCFGraph,
shapeof_subgraphs: List[NNCFNode],
preserved_metatypes: List[OperatorMetatype],
) -> List[NNCFNode]:
"""
:param graph: The input graph to be analyzed.
:param shapeof_subgraphs: A list of nodes belonging to ShapeOf subgraphs.
:param preserved_metatypes: Backend-specific metatypes that require preserving
float subgraphs when removing the ShapeOf subgraph.
:return: A list of nodes in float subgraphs of ShapeOf subgraphs.
"""
preserved_nodes = set()
for node in graph.get_nodes_by_metatypes(preserved_metatypes):
for e in graph.get_input_edges(node):
if e.from_node in shapeof_subgraphs and e.dtype == Dtype.FLOAT:
preserved_nodes.add(e.from_node)

queue = collections.deque(preserved_nodes)
while queue:
node = queue.pop()

for e in graph.get_input_edges(node):
if e.from_node in preserved_nodes:
continue

if e.dtype == Dtype.FLOAT and e.from_node in shapeof_subgraphs:
queue.append(e.from_node)
preserved_nodes.add(e.from_node)

return list(preserved_nodes)


def remove_nodes_and_reconnect_graph(
Expand Down Expand Up @@ -137,20 +177,20 @@ def remove_nodes_and_reconnect_graph(
return nncf_graph


def filter_constant_nodes(
def find_constant_subgraphs(
nncf_graph: NNCFGraph,
input_nodes: List[NNCFNode],
) -> NNCFGraph:
) -> List[NNCFNode]:
"""
Removes all Constant nodes from NNCFGraph inplace, making it inference graph.
The traversing starts from the input nodes and nodes with weights.
Returns a list of nodes belonging to constant subgraphs.

:param nncf_graph: NNCFGraph instance for the transformation.
:param input_nodes: List of input nodes for the given NNCFGraph.
:return: NNCFGraph without Constant nodes.
:param nncf_graph: The input graph to be analyzed.
:param input_nodes: A list of nodes designated as graph inputs. These nodes are
used to identify which nodes depend on input data.
:return: A list of nodes belonging to constant subgraphs.
"""
if not input_nodes:
return nncf_graph
return []

visited_nodes = set()
nodes_queue = collections.deque(input_nodes)
Expand All @@ -161,5 +201,5 @@ def filter_constant_nodes(
visited_nodes.add(node)
nodes_queue.extend(nncf_graph.get_next_nodes(node))
constant_nodes = [node for node in nncf_graph.get_all_nodes() if node not in visited_nodes]
nncf_graph.remove_nodes_from(constant_nodes)
return nncf_graph

return constant_nodes
7 changes: 4 additions & 3 deletions tests/common/quantization/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.quantization.passes import filter_constant_nodes
from nncf.quantization.passes import find_constant_subgraphs
from nncf.quantization.passes import remove_nodes_and_reconnect_graph
from tests.cross_fw.test_templates.models import NNCFGraphDropoutRemovingCase
from tests.cross_fw.test_templates.models import NNCFGraphToTestConstantFiltering
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_remove_nodes_and_reconnect_graph(mode: ParameterTestModes):


@pytest.mark.parametrize("node_between_const_and_op", [False, True])
def test_filter_constant_nodes(node_between_const_and_op):
def test_find_constant_subgraphs(node_between_const_and_op):
dot_reference_path_before = (
Path("passes") / f"test_constant_filtering_model_before{int(node_between_const_and_op)}.dot"
)
Expand All @@ -88,5 +88,6 @@ class NodeWithWeightMetatype(OperatorMetatype):
additional_input_names = ["/Conv2_0", "/Concat_with_missed_input_0"]
input_nodes = nncf_graph.get_input_nodes() + [nncf_graph.get_node_by_name(name) for name in additional_input_names]
_check_graphs(dot_reference_path_before, nncf_graph)
filter_constant_nodes(nncf_graph, input_nodes)
constant_subgraphs = find_constant_subgraphs(nncf_graph, input_nodes)
nncf_graph.remove_nodes_from(constant_subgraphs)
_check_graphs(dot_reference_path_after, nncf_graph)
8 changes: 6 additions & 2 deletions tests/common/quantization/test_quantizer_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.quantization.quantizer_removal import find_quantizer_nodes_to_cut
from nncf.quantization.passes import remove_shapeof_subgraphs
from nncf.quantization.passes import find_shapeof_subgraphs
from tests.common.quantization.metatypes import CONSTANT_METATYPES
from tests.common.quantization.metatypes import METATYPES_FOR_TEST
from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES
Expand Down Expand Up @@ -304,7 +304,11 @@ def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: Parameter
# As test graphs are fully connected and does not have readvariable metatype,
# this should work
input_nodes = nncf_graph.get_input_nodes()
nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), SHAPEOF_METATYPES, input_nodes)

shapeof_subgraphs = find_shapeof_subgraphs(nncf_graph, SHAPEOF_METATYPES, input_nodes)
nncf_graph_without_shapeof = deepcopy(nncf_graph)
nncf_graph_without_shapeof.remove_nodes_from(shapeof_subgraphs)

nodes, ops = find_quantizer_nodes_to_cut(
nncf_graph_without_shapeof,
quantizer_node,
Expand Down
4 changes: 4 additions & 0 deletions tests/cross_fw/test_templates/test_ptq_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_quantize_outputs(self, test_params, quantize_outputs):
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
act_num_q, weight_num_q = 0, 0
Expand All @@ -261,6 +262,7 @@ def test_ignored_scopes(self, test_params, ignored_scopes_data):
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
act_num_q, weight_num_q = 0, 0
Expand All @@ -286,6 +288,7 @@ def test_model_type_pass(self, test_params, model_type):
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
q_setup = min_max_algo._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
for quantization_point in q_setup.quantization_points.values():
Expand Down Expand Up @@ -384,6 +387,7 @@ def test_validate_scope(self, test_params, validate_scopes):
self.get_algo_backend().get_start_nodes_for_activation_path_tracing(nncf_graph),
[],
[],
[],
)
ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"]
algo = MinMaxQuantization(
Expand Down
3 changes: 3 additions & 0 deletions tests/cross_fw/test_templates/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_default_quantizer_config(self, single_conv_nncf_graph):
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
q_setup = min_max_algo._get_quantizer_setup(
nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern()
Expand Down Expand Up @@ -188,6 +189,7 @@ def test_quantizer_config_from_ptq_params_for_CPU(
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
if signed_weights is False or signed_activations in [True, False]: # Incompatible with HW CPU config
with pytest.raises(
Expand Down Expand Up @@ -230,6 +232,7 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
min_max_algo._backend_entity.preserved_metatypes,
)
q_setup = min_max_algo._get_quantizer_setup(
nncf_graph, inference_nncf_graph, hw_patterns=GraphPattern(), ignored_patterns=GraphPattern()
Expand Down
Loading