diff --git a/README.md b/README.md
index 9929b92255c..3e2c5c983cb 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ learning frameworks.
| :------------------------------------------------------------------------------------------------------- | :-------: | :-------: | :-----------: | :-----------: |
| [Post-Training Quantization](./docs/usage/post_training_compression/post_training_quantization/Usage.md) | Supported | Supported | Supported | Supported |
| [Weights Compression](./docs/usage/post_training_compression/weights_compression/Usage.md) | Supported | Supported | Not supported | Not supported |
+| [Activation Sparsity](./nncf/experimental/torch/sparsify_activations/ActivationSparsity.md) | Not supported | Experimental | Not supported | Not supported |
### Training-Time Compression Algorithms
diff --git a/nncf/experimental/torch/sparsify_activations/ActivationSparsity.md b/nncf/experimental/torch/sparsify_activations/ActivationSparsity.md
new file mode 100644
index 00000000000..c0d94ceac05
--- /dev/null
+++ b/nncf/experimental/torch/sparsify_activations/ActivationSparsity.md
@@ -0,0 +1,146 @@
+### Activation Sparsity (experimental feature)
+
+The `sparsify_activations` algorithm is a post-training method that introduces sparsity into the activations of a neural network. It reduces the number of active neurons during inference by zeroing out activations whose magnitudes fall below a calibrated static threshold.
+
+The algorithm sparsifies the input of a layer by applying the following function:
+
+$$
+sparsify(X) =
+\begin{cases}
+X & \text{if } |X| > \tau \\
+0 & \text{if } |X| \le \tau
+\end{cases}
+$$
+
+The magnitude threshold $\tau$ that corresponds to the desired sparsity level is determined as the statistical quantile of the absolute activation values collected on a calibration dataset:
+
+$$
+\tau = Quantile(|X|,\ target\ sparsity)
+$$
+
+`sparsify_activations` automates the identification of these pruning thresholds from the user-specified layers, target sparsity levels, and a calibration dataset.
+
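+For illustration, here is a minimal PyTorch sketch of the two formulas above (the helpers `calibrate_threshold` and `sparsify` are hypothetical names for this sketch, not part of the NNCF API):
+
+```python
+import torch
+
+
+def calibrate_threshold(x: torch.Tensor, target_sparsity: float) -> torch.Tensor:
+    # tau = Quantile(|X|, target sparsity), computed over calibration activations.
+    return torch.quantile(x.abs().flatten().float(), q=target_sparsity)
+
+
+def sparsify(x: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
+    # Zero out activations whose magnitude does not exceed the threshold.
+    return x.masked_fill(x.abs() <= tau, 0.0)
+```
+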
+> Note: This feature is **experimental** and intended solely for evaluating how activation sparsity affects task performance. While activation sparsity can improve the inference efficiency of the decoding phase for Large Language Models (LLMs) ([Liu et al., 2023](https://arxiv.org/abs/2310.17157)), it necessitates optimized runtime kernels, which are still in development.
+
+#### Example Usage
+
+Below is an example of applying the `sparsify_activations` algorithm to a Torch model. Optionally, you can call `nncf.compress_weights()` before sparsification to obtain an optimized model with both quantized weights and sparse activations.
+
+```python
+import nncf
+from nncf.experimental.torch.sparsify_activations import sparsify_activations, TargetScope
+
+model = ... # Your model
+dataset = ... # Calibration set
+
+# (Optional) Weight-only quantization
+model = nncf.compress_weights(
+ model=model,
+ mode=nncf.CompressWeightsMode.INT8_ASYM,
+ dataset=dataset,
+)
+
+# Activation sparsification
+model = sparsify_activations(
+ model=model,
+ dataset=dataset,
+ target_sparsity_by_scope={
+ TargetScope(patterns=[".*up_proj.*", ".*gate_proj.*"]): 0.3,
+ TargetScope(patterns=[".*down_proj.*",]): 0.5,
+ }
+)
+```
+
+In this example, we first conduct data-free INT8 asymmetric weight quantization on the model. Then we apply activation sparsification, setting the target activation sparsity to 30% for all layers whose names contain "up_proj" or "gate_proj", and to 50% for layers whose names contain "down_proj".
+
+#### Interface Details
+
+- `model`: The model to be sparsified. Currently only the Torch backend is supported.
+- `dataset`: An `nncf.Dataset` instance used to calibrate the pruning thresholds.
+- `target_sparsity_by_scope`: A dictionary that defines the target activation sparsity level for specified layers. For each item, the key is an instance of the `TargetScope` class representing the layers to match in the model's NNCF graph; the corresponding value is a float in the range [0, 1] representing the target sparsity level. `TargetScope` supports both absolute and regex-based name matching.
+
+ - Example:
+
+ ```python
+ {
+ # Target sparsity is 60% for node "Dummy/Linear[layer]/linear_0" in the model graph
+ TargetScope(names=["Dummy/Linear[layer]/linear_0"]): 0.6,
+ # Target sparsity is 30% for the layers whose name contains "up_proj" or "down_proj".
+ TargetScope(patterns=[".*up_proj.*", ".*down_proj.*"]): 0.3,
+ }
+ ```
+
+- `ignored_scope`: Optional. If specified, an instance of the `nncf.IgnoredScope` class that defines the nodes in the model graph to be ignored by this algorithm; see the sketch below. Note that unsupported layer types are already filtered out internally, so there is no need to list them in `ignored_scope`. The algorithm currently supports only Linear layers, as they benefit most from dynamic activation sparsity: it reduces the memory read bandwidth for the large Linear weights used in LLMs.
+
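+As a sketch of `ignored_scope` usage (the regex below is a hypothetical pattern, and `model` and `dataset` are assumed to be defined as in the earlier example):
+
+```python
+import nncf
+from nncf.experimental.torch.sparsify_activations import sparsify_activations, TargetScope
+
+model = sparsify_activations(
+    model=model,
+    dataset=dataset,
+    target_sparsity_by_scope={TargetScope(patterns=[".*up_proj.*"]): 0.3},
+    # Exclude any layer whose node name matches this regex from sparsification.
+    ignored_scope=nncf.IgnoredScope(patterns=[".*layers\\.0\\..*"]),
+)
+```
+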
+#### Evaluation results
+
+Here is the word perplexity of several language models on a subset of the [wikitext dataset](https://arxiv.org/abs/1609.07843), with the maximum context length set to 2048. In the table, "int8_asym" means the model weights are asymmetrically quantized to INT8. "up/gate/down" refers to the up, gate, and down projection layers in [Gated Linear Unit](https://arxiv.org/abs/1612.08083) (GLU) style feed-forward networks. The "Avg. Activation Sparsity" column shows the average activation sparsity on the evaluation samples. For example, "down50%" means that, on average, the input activations of all "down" layers have a sparsity of 50%.
+
+| Model | Mode | Avg. Activation Sparsity | Word Perplexity (↓) |
+| :----------------------------------- | :----------------------------- | :----------------------- | :------------------ |
+| meta-llama/Llama-2-7b-hf | fp32 | - | 9.242 |
+| | sparse_activation | up/gate30% + down50% | 9.508 |
+| | int8_asym + sparse_activation | up/gate30% + down50% | 9.511 |
+| meta-llama/Meta-Llama-3-8B-Instruct | fp32 | - | 10.802 |
+| | sparse_activation | up/gate30% + down50% | 11.294 |
+| | int8_asym + sparse_activation | up/gate30% + down50% | 11.302 |
+| mistralai/Mixtral-8x7B-Instruct-v0.1 | fp32 | - | 6.224 |
+| | sparse_activation | up/gate40% + down50% | 6.561 |
+| | int8_asym + sparse_activation | up/gate40% + down50% | 6.579 |
+
+#### Known Limitations
+
+1. Currently, activation sparsity supports only the Torch backend. Consequently, when using `nncf.compress_weights()` before activation sparsification, the available compression modes are restricted to the 8-bit integer modes. More information on supported modes can be found at [Weights Compression](../../../../docs/usage/post_training_compression/weights_compression/Usage.md#limitations).
+2. The actual activation sparsity during inference is dynamic and varies per input, so deviation from the target should be expected. In our local experiments, the statistical mean of the actual activation sparsity aligned with the target when the thresholds were calibrated on datasets similar to the final task.
+3. As with other compression methods, there is a trade-off between model accuracy and activation sparsity. For LLMs like [Llama](https://llama.meta.com), it is recommended to start with 30%~50% sparsity for the Linear layers in feed-forward networks.
diff --git a/nncf/experimental/torch/sparsify_activations/__init__.py b/nncf/experimental/torch/sparsify_activations/__init__.py
new file mode 100644
index 00000000000..ecfaa78cc4f
--- /dev/null
+++ b/nncf/experimental/torch/sparsify_activations/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import sparsify_activations # noqa: F401
+from nncf.experimental.torch.sparsify_activations.target_scope import TargetScope # noqa: F401
diff --git a/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py
new file mode 100644
index 00000000000..83a7a418911
--- /dev/null
+++ b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from abc import abstractmethod
+from typing import Dict, List, Optional, Type, TypeVar
+
+import nncf
+from nncf.common import factory
+from nncf.common.factory import NNCFGraphFactory
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.graph.graph import NNCFNode
+from nncf.common.graph.operator_metatypes import OperatorMetatype
+from nncf.common.logging.track_progress import track
+from nncf.common.scopes import should_consider_scope
+from nncf.common.utils.backend import BackendType
+from nncf.common.utils.backend import get_backend
+from nncf.data import Dataset
+from nncf.experimental.torch.sparsify_activations.target_scope import TargetScope
+from nncf.experimental.torch.sparsify_activations.target_scope import get_target_node_names_from_target_scope
+from nncf.scopes import IgnoredScope
+from nncf.scopes import get_ignored_node_names_from_ignored_scope
+from nncf.torch.model_creation import is_wrapped_model
+from nncf.torch.model_creation import wrap_model
+
+TModel = TypeVar("TModel")
+
+
+class SparsifyActivationsAlgoBackend(ABC):
+ """
+ Abstract class for activation sparsification algorithm backend.
+ """
+
+ CALIBRATION_TRACKING_DESC = "Conducting Activations Sparsifier Calibration"
+
+ @staticmethod
+ def do_inference(model: TModel, dataset: Dataset):
+ """
+ Conducts model inference on given dataset to calibrate the activation sparsifiers.
+
+ :param model: The model with activation sparsifiers.
+ :param dataset: The calibration dataset to update the sparsifiers.
+ """
+ engine = factory.EngineFactory.create(model)
+ for input_data in track(
+ dataset.get_inference_data(),
+ total=dataset.get_length(),
+ description=SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC,
+ ):
+ engine.infer(input_data)
+
+ @property
+ @abstractmethod
+ def supported_metatypes(self) -> List[Type[OperatorMetatype]]:
+ """
+ Property for the backend-specific metatypes for supported layers.
+ """
+
+ @abstractmethod
+ def insert_sparsifiers(
+ self,
+ model: TModel,
+ graph: NNCFGraph,
+ target_sparsity_by_node: Dict[NNCFNode, float],
+ ) -> TModel:
+ """
+ Inserts the activation sparsifiers to the model.
+
+ :param model: The model to conduct activation sparsification.
+ :param graph: The model's NNCF graph.
+ :param target_sparsity_by_node: The target sparsity level for the input activation in each given node layer.
+ :return: The model with inserted activation sparsifiers.
+ """
+
+ @abstractmethod
+ def calibrate_sparsifiers(self, model: TModel, graph: NNCFGraph, dataset: Dataset) -> TModel:
+ """
+ Calibrates the thresholds in the activation sparsifiers.
+
+ :param model: The model with inserted activation sparsifiers.
+ :param graph: The model's NNCF graph.
+ :param dataset: The calibration dataset to update the thresholds in the sparsifiers.
+ :return: The model with calibrated activation sparsifiers.
+ """
+
+
+class SparsifyActivationsAlgorithm:
+ """
+ Implementation of activation sparsification algorithm.
+ """
+
+ def __init__(
+ self,
+ target_sparsity_by_scope: Dict[TargetScope, float],
+ ignored_scope: IgnoredScope,
+ ):
+ """
+ :param target_sparsity_by_scope: A dictionary that defines the target sparsity level for specified layers.
+ :param ignored_scope: An ignored scope that defines the list of model control flow
+ graph nodes to be ignored during activation sparsification.
+ """
+ self._target_sparsity_by_scope = target_sparsity_by_scope
+ self._ignored_scope = ignored_scope
+        self._backend_entity: Optional[SparsifyActivationsAlgoBackend] = None
+
+ @property
+ def available_backends(self) -> List[BackendType]:
+ """
+ Supported backends for this algorithm.
+ """
+ return [BackendType.TORCH]
+
+ def apply(
+ self,
+ model: TModel,
+ graph: NNCFGraph,
+ dataset: Dataset,
+ ) -> TModel:
+ """
+ Applies the algorithm to the given model.
+
+ :param model: The model to be sparsified.
+ :param graph: The model's NNCF graph.
+ :param dataset: The dataset to calibrate the activation sparsifiers.
+ :return: The sparsified model.
+ """
+ self._set_backend_entity(model)
+ target_sparsity_by_node = self._get_target_sparsity_by_node(graph)
+ sparse_model = self.do_sparsification(model, graph, target_sparsity_by_node, dataset)
+ return sparse_model
+
+ def do_sparsification(
+ self,
+ model: TModel,
+ graph: NNCFGraph,
+ target_sparsity_by_node: Dict[NNCFNode, float],
+ dataset: Dataset,
+ ):
+ """
+ Transforms the model into a sparsified one with node-specific target activation sparsity levels.
+
+ :param model: The model to be sparsified.
+ :param graph: The model's NNCF graph.
+ :param target_sparsity_by_node: A dictionary that defines the target sparsity level
+ for specified node layers.
+ :param dataset: The dataset to calibrate the activation sparsifiers.
+ :return: The sparsified model.
+ """
+ model = self._backend_entity.insert_sparsifiers(model, graph, target_sparsity_by_node)
+ model = self._backend_entity.calibrate_sparsifiers(model, graph, dataset)
+ return model
+
+ def _set_backend_entity(self, model: TModel) -> None:
+ """
+ Creates a helper class with a backend-specific logic of the algorithm.
+
+ :param model: Backend-specific input model.
+ """
+ model_backend = get_backend(model)
+ if model_backend == BackendType.TORCH:
+ from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend
+
+ self._backend_entity = PTSparsifyActivationsAlgoBackend()
+ else:
+ raise nncf.UnsupportedBackendError(
+ f"{model_backend.value} backend is not supported for `sparsify_activations`."
+ )
+
+ def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float]:
+ """
+ Collects nodes in the model's graph corresponding to the layers for sparsification.
+
+ :param graph: NNCFGraph instance.
+ :return: A dictionary with nodes and the corresponding target sparsity level.
+ """
+ supported_metatypes = self._backend_entity.supported_metatypes
+ ignored_names = get_ignored_node_names_from_ignored_scope(
+ self._ignored_scope, graph, strict=self._ignored_scope.validate
+ )
+ target_sparsity_by_node = {}
+ for scope, target_sparsity in self._target_sparsity_by_scope.items():
+ target_names = get_target_node_names_from_target_scope(scope, graph, strict=scope.validate)
+ for node_name in target_names:
+ node = graph.get_node_by_name(node_name)
+ if node.metatype not in supported_metatypes or not should_consider_scope(
+ node.node_name, ignored_scopes=ignored_names
+ ):
+ continue
+ if node in target_sparsity_by_node:
+ raise nncf.ValidationError(
+ f'"{node.node_name}" is matched by multiple items in `target_sparsity_by_scope`.'
+ )
+ target_sparsity_by_node[node] = target_sparsity
+ if not target_sparsity_by_node:
+ raise nncf.ValidationError("No layers to conduct activation sparsification.")
+ return target_sparsity_by_node
+
+
+def sparsify_activations(
+ model: TModel,
+ dataset: Dataset,
+ target_sparsity_by_scope: Dict[TargetScope, float],
+ ignored_scope: Optional[IgnoredScope] = None,
+) -> TModel:
+ """
+ Post-training activation sparsification on the given model.
+
+ This algorithm sparsifies the input activations in supported layers based on a calibration
+    dataset. The goal is to zero out activations with small magnitudes around 0, thereby
+    roughly achieving the target sparsity at a statistical level.
+
+ Note that currently only linear layers are supported.
+
+ :param model: The model to be sparsified.
+ :param dataset: The dataset to calibrate the activation sparsifiers.
+ :param target_sparsity_by_scope: Defines the target activation sparsity level
+ for specified layers. For each item, the key is an instance of `TargetScope` class
+ representing the layers to match in the model's NNCF graph; the corresponding value
+ is a float number in the range [0, 1] representing the target sparsity level.
+
+ Example:
+ .. code-block:: python
+ {
+ # Target sparsity is 60% for node "Dummy/Linear[layer]/linear_0" in the model graph
+ TargetScope(names=["Dummy/Linear[layer]/linear_0"]): 0.6,
+ # Target sparsity is 30% for the layers whose name contains "up_proj" or "down_proj".
+ TargetScope(patterns=[".*up_proj.*", ".*down_proj.*"]): 0.3,
+ }
+
+ :param ignored_scope: Optional. It defines the nodes in the model graph that should be
+ ignored during activation sparsification. Note that unsupported layer types are already
+ filtered out internally, so there is no need to mention them in `ignored_scope`.
+ :return: The sparsified model.
+ """
+
+ for scope, target_sparsity in target_sparsity_by_scope.items():
+ if target_sparsity < 0.0 or target_sparsity > 1.0:
+ raise ValueError(f'Target sparsity for scope "{scope}" should be in range [0, 1].')
+
+ if ignored_scope is None:
+ ignored_scope = IgnoredScope()
+
+ backend = get_backend(model)
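+    # Torch algorithms in NNCF operate on a wrapped `NNCFNetwork`; if the model is a
+    # raw `torch.nn.Module`, wrap it using an example input drawn from the dataset.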
+ if backend == BackendType.TORCH and not is_wrapped_model(model):
+ example_input = next(iter(dataset.get_inference_data()))
+ model = wrap_model(
+ model,
+ example_input=example_input,
+ trace_parameters=True,
+ )
+
+ algorithm = SparsifyActivationsAlgorithm(target_sparsity_by_scope, ignored_scope)
+
+ graph = NNCFGraphFactory.create(model)
+ sparse_model = algorithm.apply(model, graph, dataset)
+ return sparse_model
diff --git a/nncf/experimental/torch/sparsify_activations/target_scope.py b/nncf/experimental/torch/sparsify_activations/target_scope.py
new file mode 100644
index 00000000000..eb09718224a
--- /dev/null
+++ b/nncf/experimental/torch/sparsify_activations/target_scope.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Set
+
+import nncf
+from nncf.common.graph.graph import NNCFGraph
+from nncf.scopes import IgnoredScope
+from nncf.scopes import get_difference_ignored_scope
+from nncf.scopes import get_matched_ignored_scope_info
+
+
+@dataclass
+class TargetScope(IgnoredScope):
+ """
+ Specifies the target portions in a model graph.
+
+ Example:
+
+ .. code-block:: python
+ # Specified by node names:
+ node_names = ['node_1', 'node_2', 'node_3']
+ target_scope = TargetScope(names=node_names)
+
+ # Specified using regular expressions:
+ patterns = ['.*node_\\d']
+ target_scope = TargetScope(patterns=patterns)
+
+ # Specified by operation types, e.g.,
+
+ # OpenVINO opset https://docs.openvino.ai/latest/openvino_docs_ops_opset.html
+ operation_types = ['Multiply', 'GroupConvolution', 'Interpolate']
+ target_scope = TargetScope(types=operation_types)
+
+ # ONNX opset https://github.com/onnx/onnx/blob/main/docs/Operators.md
+ operation_types = ['Mul', 'Conv', 'Resize']
+ target_scope = TargetScope(types=operation_types)
+
+    # Specified by subgraphs:
+ from nncf import Subgraph
+ target_scope = TargetScope(subgraphs=[
+ Subgraph(inputs=["node_1"], outputs=["node_3"])
+ ])
+
+ **Note:** Operation types must be specified according to the model framework.
+
+ :param names: List of target node names.
+ :param patterns: List of regular expressions that define patterns for names of target nodes.
+ :param types: List of target operation types.
+ :param subgraphs: List of target subgraphs.
+ :param validate: If set to True, then a RuntimeError will be raised if any target scope does not match
+ in the model graph.
+ """
+
+ def __hash__(self) -> int:
+ return hash(
+ (
+ frozenset(self.names),
+ frozenset(self.patterns),
+ frozenset(self.types),
+ frozenset((frozenset(subgraph.inputs), frozenset(subgraph.outputs)) for subgraph in self.subgraphs),
+ self.validate,
+ )
+ )
+
+
+def get_target_node_names_from_target_scope(
+ target_scope: TargetScope, nncf_graph: NNCFGraph, strict: bool = True
+) -> Set[str]:
+ """
+ Returns NNCF node names from the graph that are matched by target scope.
+ If strict is True, raises nncf.ValidationError if no rule is matched.
+
+ :param target_scope: Target scope specifying the matching rules.
+ :param nncf_graph: The graph.
+ :param strict: Whether target_scope must match at least one node or not.
+ :return: NNCF node names from the given graph matched by target scope.
+ """
+ matched_target_scope, matches = get_matched_ignored_scope_info(target_scope, [nncf_graph])
+ if strict:
+ _check_target_scope_strictly_matched(target_scope, matched_target_scope)
+ return set().union(*matches.values())
+
+
+def _check_target_scope_strictly_matched(target_scope: TargetScope, matched_target_scope: TargetScope):
+ """
+ Passes when target_scope and matched_target_scope are equal, otherwise raises ValidationError.
+
+ :param target_scope: The given target scope.
+ :param matched_target_scope: The actual target scope matched in a graph.
+ """
+ unmatched_scope = get_difference_ignored_scope(target_scope, matched_target_scope)
+ error_messages = []
+ for match_type in ("names", "types", "patterns", "subgraphs"):
+ unmatched_rules = getattr(unmatched_scope, match_type)
+ if unmatched_rules:
+ error_messages.append(f"The following {match_type} are not found in the graph: {unmatched_rules}.")
+ if error_messages:
+ raise nncf.ValidationError("\n".join(error_messages))
diff --git a/nncf/experimental/torch/sparsify_activations/torch_backend.py b/nncf/experimental/torch/sparsify_activations/torch_backend.py
new file mode 100644
index 00000000000..a10f12c6518
--- /dev/null
+++ b/nncf/experimental/torch/sparsify_activations/torch_backend.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Type, TypeVar
+
+import torch
+import torch.nn as nn
+
+import nncf
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.graph.graph import NNCFNode
+from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
+from nncf.common.graph.operator_metatypes import OperatorMetatype
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.data import Dataset
+from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
+from nncf.tensor.functions.torch_numeric import quantile
+from nncf.torch.graph import operator_metatypes as om
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.layout import PTTransformationLayout
+from nncf.torch.model_transformer import PTModelTransformer
+from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.utils import training_mode_switcher
+
+ACTIVATIONS_SPARSIFIER_PREFIX = "activations_sparsifier"
+TModel = TypeVar("TModel")
+
+
+class ActivationsSparsifier(nn.Module):
+ """
+ Sparsifies input activations by masking out values around zero.
+ """
+
+ def __init__(self, target_sparsity: float, alpha: float = 0.2):
+ """
+ :param target_sparsity: The target activation sparsity level.
+ :param alpha: The exponential moving average decay factor in range (0, 1) for calibrating
+ the threshold. A larger alpha will give more weight to the most recent batches.
+ """
+ super().__init__()
+ self.target_sparsity = target_sparsity
+ if alpha <= 0.0 or alpha >= 1.0:
+ raise ValueError("The decay factor `alpha` should be in range (0, 1).")
+ self.alpha = alpha
+ self.register_buffer("running_threshold", torch.tensor(float("-inf")))
+ self.register_buffer("num_batches_tracked", torch.tensor(0))
+ self.running_threshold: torch.Tensor
+ self.num_batches_tracked: torch.Tensor
+ self._freeze = True
+
+ @staticmethod
+ def calculate_threshold(x: torch.Tensor, target_sparsity: float) -> torch.Tensor:
+ """
+ Calculates the threshold to sparsify the input tensor with target sparsity if locations of
+ `x.abs() <= threshold` are zeroed out.
+
+ :param x: The input tensor.
+ :param target_sparsity: The target sparsity level on the input tensor.
+ :return: The threshold value.
+ """
+ return quantile(x.detach().abs().view(-1), q=target_sparsity, axis=0)
+
+ @property
+ def freeze(self):
+ return self._freeze
+
+ @freeze.setter
+ def freeze(self, value: bool):
+ self._freeze = value
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if not self.freeze:
+ threshold = self.calculate_threshold(x, self.target_sparsity)
+ self._update(threshold, dtype=x.dtype)
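+        # Mask out activations whose magnitude does not exceed the running threshold.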
+ mask = torch.le(x.abs(), self.running_threshold)
+ x = torch.masked_fill(x, mask, 0.0)
+ return x
+
+ def reset_running_stats(self):
+ """
+ Resets the running threshold and the number of tracked batches to the initial stage.
+ """
+ self.running_threshold.fill_(float("-inf"))
+ self.num_batches_tracked.zero_()
+
+ def extra_repr(self) -> str:
+ return f"target_sparsity={self.target_sparsity}"
+
+ def _update(self, threshold: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ """
+ Updates the running threshold by exponential moving average with decaying adjustment.
+ The updating logic is similar to `pandas.DataFrame.ewm(adjust=True)`.
+
+ :param threshold: The threshold value derived from this batch to update the running threshold.
+ :param dtype: Data type of the updated running threshold.
+ :return: The updated running threshold.
+ """
+ if self.num_batches_tracked == 0:
+ running_threshold = threshold
+ else:
+ beta = 1.0 - self.alpha
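+            # Adjusted EMA (cf. `pandas.DataFrame.ewm(adjust=True)`): batch thresholds are
+            # weighted 1, beta, beta^2, ... from newest to oldest, and the weighted sum is
+            # normalized by the geometric partial sum (1 - beta^(n + 1)) / (1 - beta).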
+ old_running_threshold = self.running_threshold.to(device=threshold.device, dtype=torch.float64)
+ running_threshold = (
+ threshold.to(torch.float64) * self.alpha
+ + old_running_threshold * beta * (1 - beta**self.num_batches_tracked)
+ ) / (1 - beta ** (self.num_batches_tracked + 1))
+ self.running_threshold = running_threshold.type(dtype)
+ self.num_batches_tracked += 1
+ return self.running_threshold
+
+
+class PTSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend):
+ """
+ Torch backend for the activation sparsification algorithm.
+ """
+
+ SUPPORTED_METATYPES = [om.PTLinearMetatype]
+
+ @staticmethod
+ def get_sparsifiers(model: NNCFNetwork) -> List[ActivationsSparsifier]:
+ """
+ Finds all the activation sparsifiers in the model.
+
+ :param model: The model with activation sparsifiers.
+ :return: List of activation sparsifiers.
+ """
+ return [m for m in model.nncf.modules() if isinstance(m, ActivationsSparsifier)]
+
+ @property
+ def supported_metatypes(self) -> List[Type[OperatorMetatype]]:
+ return PTSparsifyActivationsAlgoBackend.SUPPORTED_METATYPES
+
+ def insert_sparsifiers(
+ self,
+ model: NNCFNetwork,
+ graph: NNCFGraph,
+ target_sparsity_by_node: Dict[NNCFNode, float],
+ ) -> NNCFNetwork:
+ transformation_layout = PTTransformationLayout()
+ for node, target_sparsity in target_sparsity_by_node.items():
+ activation_port_id = self._get_activation_port_id(node, graph)
+ sparsifier = ActivationsSparsifier(target_sparsity=target_sparsity)
+ sparsifier_name = f"{ACTIVATIONS_SPARSIFIER_PREFIX}_{node.node_name.replace('.', '_')}"
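+            # Hook the sparsifier as a pre-layer operation on the activation input port,
+            # so it runs on the input of the matched Linear node.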
+ transformation_layout.register(
+ PTSharedFnInsertionCommand(
+ [
+ PTTargetPoint(
+ target_type=TargetType.PRE_LAYER_OPERATION,
+ target_node_name=node.node_name,
+ input_port_id=activation_port_id,
+ )
+ ],
+ sparsifier,
+ sparsifier_name,
+ )
+ )
+
+ transformed_model = PTModelTransformer(model).transform(transformation_layout)
+ return transformed_model
+
+ def calibrate_sparsifiers(self, model: NNCFNetwork, graph: NNCFGraph, dataset: Dataset) -> NNCFNetwork:
+ sparsifiers = self.get_sparsifiers(model)
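+        # Unfreeze the sparsifiers so that their running thresholds are updated during
+        # the calibration inference below; freeze them again once calibration is done.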
+ for sparsifier in sparsifiers:
+ sparsifier.reset_running_stats()
+ sparsifier.freeze = False
+ with training_mode_switcher(model, is_training=False):
+ with torch.no_grad():
+ self.do_inference(model, dataset)
+ for sparsifier in sparsifiers:
+ sparsifier.freeze = True
+ return model
+
+ @staticmethod
+ def _get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
+ """
+ Finds the input activation port id for the node.
+
+ :param node: The node to find its activation port id.
+ :param graph: The NNCF graph containing the node.
+ :return: The activation port id.
+ """
+ activation_ports = []
+ for prev_node in graph.get_previous_nodes(node):
+ edge = graph.get_edge(prev_node, node)
+ if prev_node.metatype in CONST_NOOP_METATYPES or edge.input_port_id in node.metatype.weight_port_ids:
+ continue
+ activation_ports.append(edge.input_port_id)
+ if len(activation_ports) != 1:
+ raise nncf.InternalError(f'Cannot find activation port for node "{node}".')
+ return activation_ports[0]
diff --git a/tests/post_training/experimental/sparsify_activations/model_scope.py b/tests/post_training/experimental/sparsify_activations/model_scope.py
new file mode 100644
index 00000000000..5a89847829d
--- /dev/null
+++ b/tests/post_training/experimental/sparsify_activations/model_scope.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Dict, List
+
+import nncf
+from nncf.experimental.torch.sparsify_activations import TargetScope
+from nncf.parameters import CompressWeightsMode
+from tests.post_training.experimental.sparsify_activations.pipelines import ImageClassificationTimmSparsifyActivations
+from tests.post_training.experimental.sparsify_activations.pipelines import LMSparsifyActivations
+from tests.post_training.pipelines.base import BackendType
+
+SPARSIFY_ACTIVATIONS_MODELS = [
+ {
+ "reported_name": "tinyllama",
+ "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
+ "pipeline_cls": LMSparsifyActivations,
+ "compression_params": {},
+ "backends": [BackendType.FP32],
+ },
+ {
+ "reported_name": "tinyllama_ffn_sparse20",
+ "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
+ "pipeline_cls": LMSparsifyActivations,
+ "compression_params": {
+ "compress_weights": None,
+ "sparsify_activations": {
+ "target_sparsity_by_scope": {
+ TargetScope(patterns=[".*up_proj.*", ".*gate_proj.*", ".*down_proj.*"]): 0.2,
+ }
+ },
+ },
+ "backends": [BackendType.TORCH, BackendType.CUDA_TORCH],
+ "batch_size": 8,
+ },
+ {
+ "reported_name": "tinyllama_int8_asym_data_free_ffn_sparse20",
+ "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
+ "pipeline_cls": LMSparsifyActivations,
+ "compression_params": {
+ "compress_weights": {
+ "mode": CompressWeightsMode.INT8_ASYM,
+ },
+ "sparsify_activations": {
+ "target_sparsity_by_scope": {
+ TargetScope(patterns=[".*up_proj.*", ".*gate_proj.*", ".*down_proj.*"]): 0.2,
+ }
+ },
+ },
+ "backends": [BackendType.TORCH, BackendType.CUDA_TORCH],
+ "batch_size": 8,
+ },
+ {
+ "reported_name": "timm/deit3_small_patch16_224",
+ "model_id": "deit3_small_patch16_224",
+ "pipeline_cls": ImageClassificationTimmSparsifyActivations,
+ "compression_params": {},
+ "backends": [BackendType.FP32],
+ "batch_size": 128,
+ },
+ {
+ "reported_name": "timm/deit3_small_patch16_224_qkv_sparse20_fc1_sparse20_fc2_sparse30",
+ "model_id": "deit3_small_patch16_224",
+ "pipeline_cls": ImageClassificationTimmSparsifyActivations,
+ "compression_params": {
+ "sparsify_activations": {
+ "target_sparsity_by_scope": {
+ TargetScope(patterns=[".*qkv.*", ".*fc1.*"]): 0.2,
+ TargetScope(patterns=[".*fc2.*"]): 0.3,
+ }
+ },
+ },
+ "backends": [BackendType.TORCH, BackendType.CUDA_TORCH],
+ "batch_size": 128,
+ },
+]
+
+
+def generate_tests_scope(models_list: List[Dict]) -> Dict[str, Dict]:
+ """
+    Generates the test scope, with test cases named "{reported_name}_backend_{backend}".
+ """
+ tests_scope = {}
+ fp32_models = set()
+ for test_model_param in models_list:
+ model_id = test_model_param["model_id"]
+ reported_name = test_model_param["reported_name"]
+
+ for backend in test_model_param["backends"]:
+ model_param = copy.deepcopy(test_model_param)
+ if "is_batch_size_supported" not in model_param: # Set default value of is_batch_size_supported.
+ model_param["is_batch_size_supported"] = True
+ test_case_name = f"{reported_name}_backend_{backend.value}"
+ model_param["backend"] = backend
+ model_param.pop("backends")
+ if backend == BackendType.FP32:
+ if model_id in fp32_models:
+ raise nncf.ValidationError(f"Duplicate test case for {model_id} with FP32 backend")
+ fp32_models.add(model_id)
+ if test_case_name in tests_scope:
+ raise nncf.ValidationError(f"{test_case_name} already in tests_scope")
+ tests_scope[test_case_name] = model_param
+
+ return tests_scope
+
+
+SPARSIFY_ACTIVATIONS_TEST_CASES = generate_tests_scope(SPARSIFY_ACTIVATIONS_MODELS)
diff --git a/tests/post_training/experimental/sparsify_activations/pipelines.py b/tests/post_training/experimental/sparsify_activations/pipelines.py
new file mode 100644
index 00000000000..82da57caa86
--- /dev/null
+++ b/tests/post_training/experimental/sparsify_activations/pipelines.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from dataclasses import field
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import numpy as np
+import openvino as ov
+import torch
+import torch.utils
+import torch.utils.data
+import torchvision
+from datasets import load_dataset
+from optimum.exporters.openvino.convert import export_from_model
+from optimum.intel.openvino import OVModelForCausalLM
+from transformers import AutoModelForCausalLM
+
+import nncf
+from nncf.experimental.torch.sparsify_activations import sparsify_activations
+from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
+from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend
+from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
+from tests.post_training.pipelines.base import LIMIT_LENGTH_OF_STATUS
+from tests.post_training.pipelines.base import PT_BACKENDS
+from tests.post_training.pipelines.base import BackendType
+from tests.post_training.pipelines.base import NumCompressNodes
+from tests.post_training.pipelines.base import RunInfo
+from tests.post_training.pipelines.image_classification_timm import ImageClassificationTimm
+from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
+from tests.post_training.pipelines.lm_weight_compression import WCTimeStats
+from tests.torch.experimental.sparsify_activations.helpers import count_sparsifier_patterns_in_ov
+from tests.torch.helpers import set_torch_seed
+
+
+@dataclass
+class SATimeStats(WCTimeStats):
+ """
+ Contains statistics that are parsed from the stdout of Sparsify Activations tests.
+ """
+
+ time_sparsifier_calibration: Optional[str] = None
+ STAT_NAMES = [*WCTimeStats.STAT_NAMES, "Activations Sparsifier calibration time"]
+ VAR_NAMES = [*WCTimeStats.VAR_NAMES, "time_sparsifier_calibration"]
+ REGEX_PREFIX = [*WCTimeStats.REGEX_PREFIX, SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC]
+
+
+@dataclass
+class SANumCompressNodes(NumCompressNodes):
+ num_sparse_activations: Optional[int] = None
+
+
+@dataclass
+class SARunInfo(RunInfo):
+ num_compress_nodes: SANumCompressNodes = field(default_factory=SANumCompressNodes)
+
+ def get_result_dict(self):
+ return {
+ "Model": self.model,
+ "Backend": self.backend.value if self.backend else None,
+ "Metric name": self.metric_name,
+ "Metric value": self.metric_value,
+ "Metric diff": self.metric_diff,
+ "Num FQ": self.num_compress_nodes.num_fq_nodes,
+ "Num int4": self.num_compress_nodes.num_int4,
+ "Num int8": self.num_compress_nodes.num_int8,
+ "Num sparse activations": self.num_compress_nodes.num_sparse_activations,
+ "RAM MiB": self.format_memory_usage(self.compression_memory_usage),
+ "Compr. time": self.format_time(self.time_compression),
+ **self.stats_from_output.get_stats(),
+ "Total time": self.format_time(self.time_total),
+ "FPS": self.fps,
+ "Status": self.status[:LIMIT_LENGTH_OF_STATUS] if self.status is not None else None,
+ }
+
+
+class SAPipelineMixin:
+ """
+ Common methods in the test pipeline for Sparsify Activations.
+ """
+
+ def __init__(
+ self,
+ reported_name: str,
+ model_id: str,
+ backend: BackendType,
+ compression_params: dict,
+ output_dir: Path,
+ data_dir: Path,
+ reference_data: dict,
+ no_eval: bool,
+ run_benchmark_app: bool,
+ params: dict = None,
+ batch_size: int = 1,
+ ):
+ super().__init__(
+ reported_name=reported_name,
+ model_id=model_id,
+ backend=backend,
+ compression_params=compression_params,
+ output_dir=output_dir,
+ data_dir=data_dir,
+ reference_data=reference_data,
+ no_eval=no_eval,
+ run_benchmark_app=run_benchmark_app,
+ params=params,
+ batch_size=batch_size,
+ )
+ self.run_info = SARunInfo(model=reported_name, backend=backend)
+
+ @staticmethod
+ def count_compressed_nodes_from_ir(model: ov.Model) -> SANumCompressNodes:
+ """
+ Get number of compressed nodes in the compressed IR.
+ """
+ num_fq_nodes = 0
+ num_int8 = 0
+ num_int4 = 0
+ for node in model.get_ops():
+ if node.type_info.name == "FakeQuantize":
+ num_fq_nodes += 1
+ for i in range(node.get_output_size()):
+ if node.get_output_element_type(i).get_type_name() in ["i8", "u8"]:
+ num_int8 += 1
+ if node.get_output_element_type(i).get_type_name() in ["i4", "u4"]:
+ num_int4 += 1
+
+ num_sparse_activations = count_sparsifier_patterns_in_ov(model)
+ return SANumCompressNodes(
+ num_fq_nodes=num_fq_nodes,
+ num_int8=num_int8,
+ num_int4=num_int4,
+ num_sparse_activations=num_sparse_activations,
+ )
+
+ def collect_data_from_stdout(self, stdout: str):
+ stats = SATimeStats()
+ stats.fill(stdout)
+ self.run_info.stats_from_output = stats
+
+ @set_torch_seed(seed=42)
+ @torch.no_grad()
+ def _compress(self):
+ """
+ Actual call of weight compression and/or activation sparsification.
+ """
+ self.compressed_model = self.model
+ if self.compression_params.get("compress_weights", None) is not None:
+ self.compressed_model = nncf.compress_weights(
+ self.compressed_model,
+ dataset=self.calibration_dataset,
+ **self.compression_params["compress_weights"],
+ )
+ if self.compression_params.get("sparsify_activations", None) is not None:
+ self.compressed_model = sparsify_activations(
+ self.compressed_model,
+ dataset=self.calibration_dataset,
+ **self.compression_params["sparsify_activations"],
+ )
+
+ def _validate(self):
+ super()._validate()
+ ref_num_sparse_activations = self.reference_data.get("num_sparse_activations", 0)
+ num_sparse_activations = self.run_info.num_compress_nodes.num_sparse_activations
+ if num_sparse_activations != ref_num_sparse_activations:
+ status_msg = f"Regression: The number of sparse activations is {num_sparse_activations}, \
+ which differs from reference {ref_num_sparse_activations}."
+ raise ValueError(status_msg)
+
+
+class LMSparsifyActivations(SAPipelineMixin, LMWeightCompression):
+ DEFAULT_SUBSET_SIZE = 32
+
+ def prepare_model(self):
+ is_stateful = self.params.get("is_stateful", False)
+
+ if self.backend in PT_BACKENDS:
+ if is_stateful:
+ raise RuntimeError(f"is_stateful={is_stateful} is not supported for PyTorch backend.")
+
+ self.model_hf = AutoModelForCausalLM.from_pretrained(
+ self.model_id,
+ torch_dtype=torch.float32,
+ device_map="cuda" if self.backend == BackendType.CUDA_TORCH else "cpu",
+ attn_implementation="eager",
+ )
+ self.model = self.model_hf
+ elif self.backend in [BackendType.OV, BackendType.FP32]:
+ if is_stateful:
+ self.fp32_model_dir = self.fp32_model_dir.parent / (self.fp32_model_dir.name + "_sf")
+ if not (self.fp32_model_dir / self.OV_MODEL_NAME).exists():
+ # export by model_id
+ self.model_hf = OVModelForCausalLM.from_pretrained(
+ self.model_id,
+ trust_remote_code=True,
+ export=True,
+ load_in_8bit=False,
+ compile=False,
+ stateful=is_stateful,
+ )
+ else:
+ # no export, load from IR. Applicable for sequential run of test cases in local environment.
+ self.model_hf = OVModelForCausalLM.from_pretrained(
+ self.fp32_model_dir, load_in_8bit=False, compile=False, stateful=is_stateful
+ )
+ self.model = self.model_hf.model
+ else:
+ raise RuntimeError(f"backend={self.backend.value} is not supported.")
+
+ if not (self.fp32_model_dir / self.OV_MODEL_NAME).exists():
+ self._dump_model_fp32()
+
+ # Use FP16 for CUDA_TORCH backend as it is more common when running LLM on CUDA.
+ if self.backend == BackendType.CUDA_TORCH:
+ self.model_hf.half()
+
+ def get_transform_calibration_fn(self):
+ process_one = super().get_transform_calibration_fn()
+
+ def transform_fn(chunk: List[Dict]):
+ samples = [process_one(data, max_tokens=128, filter_bad_tokens=False) for data in chunk]
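+            # Collate the per-sample feature dicts into a single batched input dict.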
+ inputs = {}
+ for input_name, sample_value in samples[0].items():
+ if isinstance(sample_value, torch.Tensor):
+ inputs[input_name] = torch.cat([sample[input_name] for sample in samples], dim=0)
+ elif isinstance(sample_value, np.ndarray):
+ inputs[input_name] = np.concatenate([sample[input_name] for sample in samples], axis=0)
+ elif isinstance(sample_value, ov.Tensor):
+ shape = sample_value.get_shape()
+ shape[0] = len(samples)
+ inputs[input_name] = ov.Tensor(sample_value.get_element_type(), shape)
+ else:
+ raise RuntimeError(
+ f"Failed to generate calibration set for {input_name} in type {type(sample_value)}"
+ )
+ if self.backend == BackendType.CUDA_TORCH:
+ for input_name in inputs:
+ inputs[input_name] = torch.from_numpy(inputs[input_name]).cuda()
+ return inputs
+
+ return transform_fn
+
+ def prepare_calibration_dataset(self):
+ subset_size = self.compression_params.get("subset_size") or self.DEFAULT_SUBSET_SIZE
+ dataset = (
+ load_dataset("wikitext", "wikitext-2-v1", split="train", revision="b08601e")
+ .filter(lambda example: len(example["text"].split()) > 256)
+ .shuffle(seed=42)
+ .select(range(subset_size))
+ .to_list()
+ )
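+        # Split the samples into batch-size chunks; transform_fn collates each chunk into one batched input.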
+ chunks = [dataset[i : i + self.batch_size] for i in range(0, subset_size, self.batch_size)]
+ self.calibration_dataset = nncf.Dataset(chunks, self.get_transform_calibration_fn())
+
+ def save_compressed_model(self):
+ if self.backend == BackendType.CUDA_TORCH:
+ self.model_hf.float()
+ for module in self.model_hf.nncf.modules():
+ if isinstance(module, (AsymmetricWeightsDecompressor, SymmetricWeightsDecompressor)):
+ module.result_dtype = torch.float32
+ export_from_model(
+ self.model_hf, self.output_model_dir, stateful=False, compression_option="fp32", device="cuda"
+ )
+ else:
+ super().save_compressed_model()
+
+ def get_num_compressed(self):
+ """
+ Get number of quantization ops and sparsifier ops in the compressed IR.
+ """
+ if self.backend in PT_BACKENDS:
+ model = ov.Core().read_model(self.output_model_dir / self.OV_MODEL_NAME)
+ else:
+ model = self.model
+ self.run_info.num_compress_nodes = self.count_compressed_nodes_from_ir(model)
+
+ def _dump_model_fp32(self):
+ if self.backend == BackendType.CUDA_TORCH:
+ export_from_model(
+ self.model_hf, self.fp32_model_dir, stateful=False, compression_option="fp32", device="cuda"
+ )
+ else:
+ super()._dump_model_fp32()
+
+ def _compress(self):
+ super()._compress()
+ if self.backend in PT_BACKENDS:
+ # This helps reproducibility but is not needed in actual use.
+ for sparsifier in PTSparsifyActivationsAlgoBackend.get_sparsifiers(self.compressed_model):
+ original_dtype = sparsifier.running_threshold.dtype
+ sparsifier.running_threshold = sparsifier.running_threshold.half().to(original_dtype)
+
+
+class ImageClassificationTimmSparsifyActivations(SAPipelineMixin, ImageClassificationTimm):
+ DEFAULT_SUBSET_SIZE = 256
+
+ def prepare_calibration_dataset(self):
+ subset_size = self.compression_params.get("subset_size") or self.DEFAULT_SUBSET_SIZE
+ val_dataset = torchvision.datasets.ImageFolder(
+ root=self.data_dir / "imagenet" / "val", transform=self.transform
+ )
+ indices = np.random.default_rng(42).choice(len(val_dataset), size=subset_size, replace=False)
+ subset = torch.utils.data.Subset(val_dataset, indices=indices)
+ loader = torch.utils.data.DataLoader(subset, batch_size=self.batch_size, num_workers=2, shuffle=False)
+ self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())
+
+ def get_num_compressed(self):
+ """
+ Get number of quantization ops and sparsifier ops in the compressed IR.
+ """
+ model = ov.Core().read_model(model=self.path_compressed_ir)
+ self.run_info.num_compress_nodes = self.count_compressed_nodes_from_ir(model)
diff --git a/tests/post_training/experimental/sparsify_activations/reference_data.yaml b/tests/post_training/experimental/sparsify_activations/reference_data.yaml
new file mode 100644
index 00000000000..3e368a9c185
--- /dev/null
+++ b/tests/post_training/experimental/sparsify_activations/reference_data.yaml
@@ -0,0 +1,47 @@
+tinyllama_backend_FP32:
+ metric_value: 1.0
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 0
+tinyllama_ffn_sparse20_backend_CUDA_TORCH:
+ metric_value: 0.7818
+ atol: 0.025
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 44
+tinyllama_ffn_sparse20_backend_TORCH:
+ metric_value: 0.7879
+ atol: 0.025
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 44
+tinyllama_int8_asym_data_free_ffn_sparse20_backend_CUDA_TORCH:
+ metric_value: 0.8044
+ atol: 0.025
+ num_int4: 0
+ num_int8: 312
+ num_sparse_activations: 44
+tinyllama_int8_asym_data_free_ffn_sparse20_backend_TORCH:
+ metric_value: 0.7846
+ atol: 0.030
+ num_int4: 0
+ num_int8: 312
+ num_sparse_activations: 44
+timm/deit3_small_patch16_224_backend_FP32:
+ metric_value: 0.8135
+ atol: 0.001
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 0
+timm/deit3_small_patch16_224_qkv_sparse20_fc1_sparse20_fc2_sparse30_backend_CUDA_TORCH:
+ metric_value: 0.8102
+ atol: 0.001
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 36
+timm/deit3_small_patch16_224_qkv_sparse20_fc1_sparse20_fc2_sparse30_backend_TORCH:
+ metric_value: 0.8102
+ atol: 0.001
+ num_int4: 0
+ num_int8: 0
+ num_sparse_activations: 36
\ No newline at end of file
diff --git a/tests/post_training/experimental/sparsify_activations/test_sparsify_activations_conformance.py b/tests/post_training/experimental/sparsify_activations/test_sparsify_activations_conformance.py
new file mode 100644
index 00000000000..ebcb1921981
--- /dev/null
+++ b/tests/post_training/experimental/sparsify_activations/test_sparsify_activations_conformance.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+import traceback
+from collections import OrderedDict
+from pathlib import Path
+from typing import Dict, Optional
+
+import pandas as pd
+import pytest
+import yaml
+
+from tests.post_training.experimental.sparsify_activations.model_scope import SPARSIFY_ACTIVATIONS_TEST_CASES
+from tests.post_training.experimental.sparsify_activations.pipelines import SARunInfo
+from tests.post_training.pipelines.base import BackendType
+from tests.post_training.pipelines.base import BaseTestPipeline
+from tests.post_training.test_quantize_conformance import create_short_run_info
+from tests.post_training.test_quantize_conformance import fixture_batch_size # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_data # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_extra_columns # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_no_eval # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_output # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_run_benchmark_app # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_run_fp32_backend # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_run_torch_cuda_backend # noqa: F401
+from tests.post_training.test_quantize_conformance import fixture_subset_size # noqa: F401
+from tests.post_training.test_quantize_conformance import maybe_skip_test_case
+from tests.post_training.test_quantize_conformance import write_logs
+
+
+@pytest.fixture(scope="session", name="sparsify_activations_reference_data")
+def fixture_sparsify_activations_reference_data():
+ path_reference = Path(__file__).parent / "reference_data.yaml"
+ with path_reference.open() as f:
+ data = yaml.safe_load(f)
+ for test_case in data.values():
+ test_case["atol"] = test_case.get("atol", 1e-3)
+ return data
+
+
+@pytest.fixture(scope="session", name="sparsify_activations_result_data")
+def fixture_sparsify_activations_report_data(output_dir):
+ data: Dict[str, SARunInfo] = {}
+ yield data
+ if data:
+ test_results = OrderedDict(sorted(data.items()))
+ df = pd.DataFrame(v.get_result_dict() for v in test_results.values())
+ output_dir.mkdir(parents=True, exist_ok=True)
+ df.to_csv(output_dir / "results.csv", index=False)
+
+
+def create_pipeline_kwargs(
+ test_model_param: Dict,
+ subset_size,
+ test_case_name: str,
+ reference_data: Dict[str, Dict],
+ fp32_model_params: Dict[str, Dict],
+):
+ if subset_size:
+ if "compression_params" not in test_model_param:
+ test_model_param["compression_params"] = {}
+ test_model_param["compression_params"]["subset_size"] = subset_size
+
+ print("\n")
+ print(f"Model: {test_model_param['reported_name']}")
+ print(f"Backend: {test_model_param['backend']}")
+ print(f"Comprssion params: {test_model_param['compression_params']}")
+
+ # Get target fp32 metric value
+ model_id = test_model_param["model_id"]
+ fp32_test_case_name = fp32_model_params[model_id]["reported_name"] + f"_backend_{BackendType.FP32.value}"
+ test_reference = reference_data[test_case_name]
+ test_reference["metric_value_fp32"] = reference_data[fp32_test_case_name]["metric_value"]
+
+ return {
+ "reported_name": test_model_param["reported_name"],
+ "model_id": test_model_param["model_id"],
+ "backend": test_model_param["backend"],
+ "compression_params": test_model_param["compression_params"],
+ "params": test_model_param.get("params"),
+ "reference_data": test_reference,
+ }
+
+
+@pytest.mark.parametrize("test_case_name", SPARSIFY_ACTIVATIONS_TEST_CASES.keys())
+def test_sparsify_activations(
+ sparsify_activations_reference_data: dict,
+ test_case_name: str,
+ data_dir: Path,
+ output_dir: Path,
+ sparsify_activations_result_data: Dict[str, SARunInfo],
+ no_eval: bool,
+ batch_size: int,
+ run_fp32_backend: bool,
+ run_torch_cuda_backend: bool,
+ subset_size: Optional[int],
+ run_benchmark_app: bool,
+ capsys: pytest.CaptureFixture,
+ extra_columns: bool,
+):
+ pipeline = None
+ err_msg = None
+ test_model_param = None
+ start_time = time.perf_counter()
+ try:
+ if test_case_name not in sparsify_activations_reference_data:
+ raise RuntimeError(f"{test_case_name} is not defined in `sparsify_activations_reference_data` fixture")
+ test_model_param = SPARSIFY_ACTIVATIONS_TEST_CASES[test_case_name]
+ maybe_skip_test_case(test_model_param, run_fp32_backend, run_torch_cuda_backend, batch_size)
+ fp32_model_params = {
+ tc["model_id"]: tc for tc in SPARSIFY_ACTIVATIONS_TEST_CASES.values() if tc["backend"] == BackendType.FP32
+ }
+ pipeline_cls = test_model_param["pipeline_cls"]
+ pipeline_kwargs = create_pipeline_kwargs(
+ test_model_param, subset_size, test_case_name, sparsify_activations_reference_data, fp32_model_params
+ )
+ calibration_batch_size = batch_size or test_model_param.get("batch_size", 1)
+ pipeline_kwargs.update(
+ {
+ "output_dir": output_dir,
+ "data_dir": data_dir,
+ "no_eval": no_eval,
+ "run_benchmark_app": run_benchmark_app,
+ "batch_size": calibration_batch_size,
+ }
+ )
+ pipeline: BaseTestPipeline = pipeline_cls(**pipeline_kwargs)
+ pipeline.run()
+ except Exception as e:
+ err_msg = str(e)
+ traceback.print_exc()
+
+ if pipeline is not None:
+ pipeline.cleanup_cache()
+ run_info = pipeline.run_info
+ if err_msg:
+ run_info.status = f"{run_info.status} | {err_msg}" if run_info.status else err_msg
+
+ captured = capsys.readouterr()
+ write_logs(captured, pipeline)
+
+ if extra_columns:
+ pipeline.collect_data_from_stdout(captured.out)
+ else:
+ run_info = create_short_run_info(test_model_param, err_msg, test_case_name)
+
+ run_info.time_total = time.perf_counter() - start_time
+ sparsify_activations_result_data[test_case_name] = run_info
+
+ if err_msg:
+ pytest.fail(err_msg)
diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py
index 31266d172f9..27479fe6a50 100644
--- a/tests/post_training/pipelines/lm_weight_compression.py
+++ b/tests/post_training/pipelines/lm_weight_compression.py
@@ -110,12 +110,14 @@ def prepare_preprocessor(self) -> None:
self.preprocessor = AutoTokenizer.from_pretrained(self.model_id)
def get_transform_calibration_fn(self):
- def transform_fn(data, max_tokens=128):
+ def transform_fn(data, max_tokens=128, filter_bad_tokens=True):
tokenized_text = self.preprocessor(data["text"], return_tensors="np")
-
- bad_tokens = self.preprocessor("", return_tensors="np")["input_ids"]
raw_tokens = tokenized_text["input_ids"][0, :]
- filtered_tokens = np.array(list(filter(lambda x: x not in bad_tokens, raw_tokens)))
+ if filter_bad_tokens:
+ bad_tokens = self.preprocessor("", return_tensors="np")["input_ids"]
+ filtered_tokens = np.array(list(filter(lambda x: x not in bad_tokens, raw_tokens)))
+ else:
+ filtered_tokens = raw_tokens
tokenized_text["input_ids"] = np.expand_dims(filtered_tokens, 0)
tokenized_text["attention_mask"] = tokenized_text["attention_mask"][:, : filtered_tokens.shape[0]]
diff --git a/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot
new file mode 100644
index 00000000000..c3e5cf0d0c9
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot
@@ -0,0 +1,488 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 model.embed_tokens.weight" [id=1, type=nncf_model_const];
+"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric];
+"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0" [id=3, type=type];
+"4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" [id=4, type=embedding];
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" [id=5, type=to];
+"6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0" [id=6, type=pow];
+"7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0" [id=7, type=mean];
+"8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0" [id=8, type=__add__];
+"9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0" [id=9, type=rsqrt];
+"10 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0" [id=10, type=__mul__];
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1" [id=11, type=to];
+"12 model.layers.0.input_layernorm.weight" [id=12, type=nncf_model_const];
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=13, type=__mul__];
+"14 model.layers.0.self_attn.q_proj.weight" [id=14, type=nncf_model_const];
+"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=15, type=decompress_symmetric];
+"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" [id=16, type=type];
+"17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=17, type=linear];
+"18 model.layers.0.self_attn.k_proj.weight" [id=18, type=nncf_model_const];
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric];
+"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" [id=20, type=type];
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=21, type=linear];
+"22 model.layers.0.self_attn.v_proj.weight" [id=22, type=nncf_model_const];
+"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=23, type=decompress_symmetric];
+"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" [id=24, type=type];
+"25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=25, type=linear];
+"26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" [id=26, type=view];
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" [id=27, type=transpose];
+"28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1" [id=28, type=view];
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" [id=29, type=transpose];
+"30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2" [id=30, type=view];
+"31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2" [id=31, type=transpose];
+"32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" [id=32, type=cat];
+"33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" [id=33, type=cos];
+"34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" [id=34, type=sin];
+"35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" [id=35, type=to];
+"36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" [id=36, type=to];
+"37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" [id=37, type=unsqueeze];
+"38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" [id=38, type=unsqueeze];
+"39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0" [id=39, type=__mul__];
+"40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0" [id=40, type=__getitem__];
+"41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1" [id=41, type=__getitem__];
+"42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0" [id=42, type=__neg__];
+"43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0" [id=43, type=cat];
+"44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1" [id=44, type=__mul__];
+"45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0" [id=45, type=__add__];
+"46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2" [id=46, type=__mul__];
+"47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2" [id=47, type=__getitem__];
+"48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3" [id=48, type=__getitem__];
+"49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1" [id=49, type=__neg__];
+"50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1" [id=50, type=cat];
+"51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3" [id=51, type=__mul__];
+"52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1" [id=52, type=__add__];
+"53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4" [id=53, type=__getitem__];
+"54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0" [id=54, type=expand];
+"55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0" [id=55, type=reshape];
+"56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5" [id=56, type=__getitem__];
+"57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1" [id=57, type=expand];
+"58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1" [id=58, type=reshape];
+"59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3" [id=59, type=transpose];
+"60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0" [id=60, type=matmul];
+"61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0" [id=61, type=__truediv__];
+"62 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2" [id=62, type=__add__];
+"63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0" [id=63, type=softmax];
+"64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0" [id=64, type=to];
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0" [id=65, type=dropout];
+"66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1" [id=66, type=matmul];
+"67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4" [id=67, type=transpose];
+"68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" [id=68, type=contiguous];
+"69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" [id=69, type=reshape];
+"70 model.layers.0.self_attn.o_proj.weight" [id=70, type=nncf_model_const];
+"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=71, type=decompress_symmetric];
+"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" [id=72, type=type];
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=73, type=linear];
+"74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" [id=74, type=__add__];
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=75, type=to];
+"76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0" [id=76, type=pow];
+"77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0" [id=77, type=mean];
+"78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0" [id=78, type=__add__];
+"79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" [id=79, type=rsqrt];
+"80 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" [id=80, type=__mul__];
+"81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1" [id=81, type=to];
+"82 model.layers.0.post_attention_layernorm.weight" [id=82, type=nncf_model_const];
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=83, type=__mul__];
+"84 model.layers.0.mlp.gate_proj.weight" [id=84, type=nncf_model_const];
+"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=85, type=decompress_symmetric];
+"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" [id=86, type=type];
+"87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=87, type=abs];
+"88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=88, type=le];
+"89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=89, type=masked_fill];
+"90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=90, type=linear];
+"91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=91, type=silu];
+"92 model.layers.0.mlp.up_proj.weight" [id=92, type=nncf_model_const];
+"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=93, type=decompress_symmetric];
+"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" [id=94, type=type];
+"95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=95, type=abs];
+"96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=96, type=le];
+"97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=97, type=masked_fill];
+"98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=98, type=linear];
+"99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" [id=99, type=__mul__];
+"100 model.layers.0.mlp.down_proj.weight" [id=100, type=nncf_model_const];
+"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=101, type=decompress_symmetric];
+"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" [id=102, type=type];
+"103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=103, type=abs];
+"104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=104, type=le];
+"105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=105, type=masked_fill];
+"106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" [id=106, type=linear];
+"107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1" [id=107, type=__add__];
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" [id=108, type=to];
+"109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0" [id=109, type=pow];
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0" [id=110, type=mean];
+"111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0" [id=111, type=__add__];
+"112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0" [id=112, type=rsqrt];
+"113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0" [id=113, type=__mul__];
+"114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1" [id=114, type=to];
+"115 model.layers.1.input_layernorm.weight" [id=115, type=nncf_model_const];
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=116, type=__mul__];
+"117 model.layers.1.self_attn.q_proj.weight" [id=117, type=nncf_model_const];
+"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=118, type=decompress_symmetric];
+"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" [id=119, type=type];
+"120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=120, type=linear];
+"121 model.layers.1.self_attn.k_proj.weight" [id=121, type=nncf_model_const];
+"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=122, type=decompress_symmetric];
+"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" [id=123, type=type];
+"124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=124, type=linear];
+"125 model.layers.1.self_attn.v_proj.weight" [id=125, type=nncf_model_const];
+"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=126, type=decompress_symmetric];
+"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" [id=127, type=type];
+"128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=128, type=linear];
+"129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" [id=129, type=view];
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" [id=130, type=transpose];
+"131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1" [id=131, type=view];
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" [id=132, type=transpose];
+"133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2" [id=133, type=view];
+"134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2" [id=134, type=transpose];
+"135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" [id=135, type=cat];
+"136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" [id=136, type=cos];
+"137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" [id=137, type=sin];
+"138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" [id=138, type=to];
+"139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" [id=139, type=to];
+"140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" [id=140, type=unsqueeze];
+"141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" [id=141, type=unsqueeze];
+"142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0" [id=142, type=__mul__];
+"143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0" [id=143, type=__getitem__];
+"144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1" [id=144, type=__getitem__];
+"145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0" [id=145, type=__neg__];
+"146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0" [id=146, type=cat];
+"147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1" [id=147, type=__mul__];
+"148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0" [id=148, type=__add__];
+"149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2" [id=149, type=__mul__];
+"150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2" [id=150, type=__getitem__];
+"151 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3" [id=151, type=__getitem__];
+"152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1" [id=152, type=__neg__];
+"153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1" [id=153, type=cat];
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3" [id=154, type=__mul__];
+"155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1" [id=155, type=__add__];
+"156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4" [id=156, type=__getitem__];
+"157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0" [id=157, type=expand];
+"158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0" [id=158, type=reshape];
+"159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5" [id=159, type=__getitem__];
+"160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1" [id=160, type=expand];
+"161 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1" [id=161, type=reshape];
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3" [id=162, type=transpose];
+"163 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0" [id=163, type=matmul];
+"164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0" [id=164, type=__truediv__];
+"165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2" [id=165, type=__add__];
+"166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0" [id=166, type=softmax];
+"167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0" [id=167, type=to];
+"168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0" [id=168, type=dropout];
+"169 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1" [id=169, type=matmul];
+"170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4" [id=170, type=transpose];
+"171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" [id=171, type=contiguous];
+"172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" [id=172, type=reshape];
+"173 model.layers.1.self_attn.o_proj.weight" [id=173, type=nncf_model_const];
+"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=174, type=decompress_symmetric];
+"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" [id=175, type=type];
+"176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=176, type=linear];
+"177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" [id=177, type=__add__];
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=178, type=to];
+"179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0" [id=179, type=pow];
+"180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0" [id=180, type=mean];
+"181 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0" [id=181, type=__add__];
+"182 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" [id=182, type=rsqrt];
+"183 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" [id=183, type=__mul__];
+"184 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1" [id=184, type=to];
+"185 model.layers.1.post_attention_layernorm.weight" [id=185, type=nncf_model_const];
+"186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=186, type=__mul__];
+"187 model.layers.1.mlp.gate_proj.weight" [id=187, type=nncf_model_const];
+"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=188, type=decompress_symmetric];
+"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" [id=189, type=type];
+"190 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=190, type=abs];
+"191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=191, type=le];
+"192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=192, type=masked_fill];
+"193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=193, type=linear];
+"194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=194, type=silu];
+"195 model.layers.1.mlp.up_proj.weight" [id=195, type=nncf_model_const];
+"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=196, type=decompress_symmetric];
+"197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" [id=197, type=type];
+"198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=198, type=abs];
+"199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=199, type=le];
+"200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=200, type=masked_fill];
+"201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=201, type=linear];
+"202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" [id=202, type=__mul__];
+"203 model.layers.1.mlp.down_proj.weight" [id=203, type=nncf_model_const];
+"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=204, type=decompress_symmetric];
+"205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" [id=205, type=type];
+"206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=206, type=abs];
+"207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=207, type=le];
+"208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=208, type=masked_fill];
+"209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" [id=209, type=linear];
+"210 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1" [id=210, type=__add__];
+"211 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" [id=211, type=to];
+"212 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0" [id=212, type=pow];
+"213 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0" [id=213, type=mean];
+"214 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0" [id=214, type=__add__];
+"215 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0" [id=215, type=rsqrt];
+"216 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0" [id=216, type=__mul__];
+"217 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1" [id=217, type=to];
+"218 model.norm.weight" [id=218, type=nncf_model_const];
+"219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" [id=219, type=__mul__];
+"220 lm_head.weight" [id=220, type=nncf_model_const];
+"221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=221, type=decompress_symmetric];
+"222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0" [id=222, type=type];
+"223 LlamaForCausalLM/Linear[lm_head]/linear_0" [id=223, type=linear];
+"224 LlamaForCausalLM/float_0" [id=224, type=float];
+"225 /nncf_model_output_0" [id=225, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0";
+"1 model.embed_tokens.weight" -> "2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0";
+"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0" -> "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0";
+"4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" -> "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0";
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0";
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "10 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0";
+"6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0" -> "7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0";
+"7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0" -> "8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0";
+"8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0" -> "9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0";
+"9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0" -> "10 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"10 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0" -> "11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1";
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1" -> "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"12 model.layers.0.input_layernorm.weight" -> "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"14 model.layers.0.self_attn.q_proj.weight" -> "15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0";
+"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0";
+"18 model.layers.0.self_attn.k_proj.weight" -> "19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0";
+"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1";
+"22 model.layers.0.self_attn.v_proj.weight" -> "23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0";
+"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2";
+"26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" -> "27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0";
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0";
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0";
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1";
+"28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1" -> "29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1";
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2";
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2";
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3";
+"30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2" -> "31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2";
+"31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2" -> "56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5";
+"32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0";
+"32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0";
+"33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" -> "35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0";
+"34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" -> "36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1";
+"35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" -> "37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0";
+"36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" -> "38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1";
+"37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" -> "39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0";
+"37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" -> "46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2";
+"38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" -> "44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1";
+"38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" -> "51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3";
+"39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0" -> "45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0";
+"40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0" -> "43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0";
+"41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1" -> "42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0";
+"42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0" -> "43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0";
+"43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0" -> "44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1";
+"44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1" -> "45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0";
+"45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0" -> "60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0";
+"46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2" -> "52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1";
+"47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2" -> "50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1";
+"48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3" -> "49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1";
+"49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1" -> "50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1";
+"50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1" -> "51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3";
+"51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3" -> "52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1";
+"52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1" -> "53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4";
+"53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4" -> "54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0";
+"54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0" -> "55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0";
+"55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0" -> "59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3";
+"56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5" -> "57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1";
+"57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1" -> "58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1";
+"58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1" -> "66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1";
+"59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3" -> "60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0";
+"60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0" -> "61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0";
+"61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0" -> "62 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2";
+"62 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2" -> "63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0";
+"63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0" -> "64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0";
+"64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0" -> "65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0";
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0" -> "66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1";
+"66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1" -> "67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4";
+"67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4" -> "68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0";
+"68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" -> "69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2";
+"69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"70 model.layers.0.self_attn.o_proj.weight" -> "71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0";
+"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0";
+"74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" -> "75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0";
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0";
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "80 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1";
+"76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0" -> "77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0";
+"77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0" -> "78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0";
+"78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0" -> "79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0";
+"79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" -> "80 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"80 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" -> "81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1";
+"81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1" -> "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"82 model.layers.0.post_attention_layernorm.weight" -> "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0";
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0";
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"84 model.layers.0.mlp.gate_proj.weight" -> "85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0";
+"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0";
+"88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0";
+"91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0";
+"92 model.layers.0.mlp.up_proj.weight" -> "93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0";
+"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0";
+"96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0";
+"99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0";
+"99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"100 model.layers.0.mlp.down_proj.weight" -> "101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0";
+"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0";
+"104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" -> "107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1";
+"107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1" -> "108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0";
+"109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0" -> "110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0";
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0" -> "111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0";
+"111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0" -> "112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0";
+"112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0" -> "113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0" -> "114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1";
+"114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1" -> "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"115 model.layers.1.input_layernorm.weight" -> "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"117 model.layers.1.self_attn.q_proj.weight" -> "118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0";
+"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0";
+"121 model.layers.1.self_attn.k_proj.weight" -> "122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0";
+"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1";
+"125 model.layers.1.self_attn.v_proj.weight" -> "126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0";
+"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2";
+"129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" -> "130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0";
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0";
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0";
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1";
+"131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1" -> "132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1";
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2";
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2";
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "151 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3";
+"133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2" -> "134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2";
+"134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2" -> "159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5";
+"135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0";
+"135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0";
+"136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" -> "138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0";
+"137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" -> "139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1";
+"138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" -> "140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0";
+"139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" -> "141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1";
+"140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" -> "142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0";
+"140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" -> "149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2";
+"141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" -> "147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1";
+"141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" -> "154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3";
+"142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0" -> "148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0";
+"143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0" -> "146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0";
+"144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1" -> "145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0";
+"145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0" -> "146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0";
+"146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0" -> "147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1";
+"147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1" -> "148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0";
+"148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0" -> "163 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0";
+"149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2" -> "155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1";
+"150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2" -> "153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1";
+"151 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3" -> "152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1";
+"152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1" -> "153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1";
+"153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1" -> "154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3";
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3" -> "155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1";
+"155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1" -> "156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4";
+"156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4" -> "157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0";
+"157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0" -> "158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0";
+"158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0" -> "162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3";
+"159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5" -> "160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1";
+"160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1" -> "161 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1";
+"161 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1" -> "169 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1";
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3" -> "163 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0";
+"163 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0" -> "164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0";
+"164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0" -> "165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2";
+"165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2" -> "166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0";
+"166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0" -> "167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0";
+"167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0" -> "168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0";
+"168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0" -> "169 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1";
+"169 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1" -> "170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4";
+"170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4" -> "171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0";
+"171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" -> "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2";
+"172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" -> "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"173 model.layers.1.self_attn.o_proj.weight" -> "174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0";
+"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" -> "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0";
+"177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" -> "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0";
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0";
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "183 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "210 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1";
+"179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0" -> "180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0";
+"180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0" -> "181 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0";
+"181 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0" -> "182 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0";
+"182 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" -> "183 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"183 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" -> "184 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1";
+"184 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1" -> "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"185 model.layers.1.post_attention_layernorm.weight" -> "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "190 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0";
+"186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0";
+"186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"187 model.layers.1.mlp.gate_proj.weight" -> "188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0";
+"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" -> "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"190 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0";
+"191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0";
+"194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0";
+"195 model.layers.1.mlp.up_proj.weight" -> "196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0";
+"197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" -> "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0";
+"199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0";
+"202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0";
+"202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"203 model.layers.1.mlp.down_proj.weight" -> "204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0";
+"205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" -> "209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0";
+"207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" -> "210 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1";
+"210 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1" -> "211 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0";
+"211 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" -> "212 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0";
+"211 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" -> "216 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0";
+"212 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0" -> "213 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0";
+"213 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0" -> "214 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0";
+"214 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0" -> "215 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0";
+"215 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0" -> "216 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0";
+"216 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0" -> "217 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1";
+"217 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1" -> "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1";
+"218 model.norm.weight" -> "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1";
+"219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" -> "223 LlamaForCausalLM/Linear[lm_head]/linear_0";
+"220 lm_head.weight" -> "221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0";
+"222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0" -> "223 LlamaForCausalLM/Linear[lm_head]/linear_0";
+"223 LlamaForCausalLM/Linear[lm_head]/linear_0" -> "224 LlamaForCausalLM/float_0";
+"224 LlamaForCausalLM/float_0" -> "225 /nncf_model_output_0";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/dummy_llama_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/dummy_llama_sparse_activations.dot
new file mode 100644
index 00000000000..05ba7d8f87c
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/dummy_llama_sparse_activations.dot
@@ -0,0 +1,424 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 model.embed_tokens.weight" [id=1, type=nncf_model_const];
+"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" [id=2, type=embedding];
+"3 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" [id=3, type=to];
+"4 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0" [id=4, type=pow];
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0" [id=5, type=mean];
+"6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0" [id=6, type=__add__];
+"7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0" [id=7, type=rsqrt];
+"8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0" [id=8, type=__mul__];
+"9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1" [id=9, type=to];
+"10 model.layers.0.input_layernorm.weight" [id=10, type=nncf_model_const];
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=11, type=__mul__];
+"12 model.layers.0.self_attn.q_proj.weight" [id=12, type=nncf_model_const];
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=13, type=linear];
+"14 model.layers.0.self_attn.k_proj.weight" [id=14, type=nncf_model_const];
+"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=15, type=linear];
+"16 model.layers.0.self_attn.v_proj.weight" [id=16, type=nncf_model_const];
+"17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=17, type=linear];
+"18 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" [id=18, type=view];
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" [id=19, type=transpose];
+"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1" [id=20, type=view];
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" [id=21, type=transpose];
+"22 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2" [id=22, type=view];
+"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2" [id=23, type=transpose];
+"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" [id=24, type=cat];
+"25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" [id=25, type=cos];
+"26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" [id=26, type=sin];
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" [id=27, type=to];
+"28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" [id=28, type=to];
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" [id=29, type=unsqueeze];
+"30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" [id=30, type=unsqueeze];
+"31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0" [id=31, type=__mul__];
+"32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0" [id=32, type=__getitem__];
+"33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1" [id=33, type=__getitem__];
+"34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0" [id=34, type=__neg__];
+"35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0" [id=35, type=cat];
+"36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1" [id=36, type=__mul__];
+"37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0" [id=37, type=__add__];
+"38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2" [id=38, type=__mul__];
+"39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2" [id=39, type=__getitem__];
+"40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3" [id=40, type=__getitem__];
+"41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1" [id=41, type=__neg__];
+"42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1" [id=42, type=cat];
+"43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3" [id=43, type=__mul__];
+"44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1" [id=44, type=__add__];
+"45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4" [id=45, type=__getitem__];
+"46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0" [id=46, type=expand];
+"47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0" [id=47, type=reshape];
+"48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5" [id=48, type=__getitem__];
+"49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1" [id=49, type=expand];
+"50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1" [id=50, type=reshape];
+"51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3" [id=51, type=transpose];
+"52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0" [id=52, type=matmul];
+"53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0" [id=53, type=__truediv__];
+"54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2" [id=54, type=__add__];
+"55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0" [id=55, type=softmax];
+"56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0" [id=56, type=to];
+"57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0" [id=57, type=dropout];
+"58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1" [id=58, type=matmul];
+"59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4" [id=59, type=transpose];
+"60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" [id=60, type=contiguous];
+"61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" [id=61, type=reshape];
+"62 model.layers.0.self_attn.o_proj.weight" [id=62, type=nncf_model_const];
+"63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=63, type=linear];
+"64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" [id=64, type=__add__];
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=65, type=to];
+"66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0" [id=66, type=pow];
+"67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0" [id=67, type=mean];
+"68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0" [id=68, type=__add__];
+"69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" [id=69, type=rsqrt];
+"70 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" [id=70, type=__mul__];
+"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1" [id=71, type=to];
+"72 model.layers.0.post_attention_layernorm.weight" [id=72, type=nncf_model_const];
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=73, type=__mul__];
+"74 model.layers.0.mlp.gate_proj.weight" [id=74, type=nncf_model_const];
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=75, type=abs];
+"76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=76, type=le];
+"77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=77, type=masked_fill];
+"78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=78, type=linear];
+"79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=79, type=silu];
+"80 model.layers.0.mlp.up_proj.weight" [id=80, type=nncf_model_const];
+"81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=81, type=abs];
+"82 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=82, type=le];
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=83, type=masked_fill];
+"84 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=84, type=linear];
+"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" [id=85, type=__mul__];
+"86 model.layers.0.mlp.down_proj.weight" [id=86, type=nncf_model_const];
+"87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=87, type=abs];
+"88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=88, type=le];
+"89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=89, type=masked_fill];
+"90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" [id=90, type=linear];
+"91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1" [id=91, type=__add__];
+"92 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" [id=92, type=to];
+"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0" [id=93, type=pow];
+"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0" [id=94, type=mean];
+"95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0" [id=95, type=__add__];
+"96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0" [id=96, type=rsqrt];
+"97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0" [id=97, type=__mul__];
+"98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1" [id=98, type=to];
+"99 model.layers.1.input_layernorm.weight" [id=99, type=nncf_model_const];
+"100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=100, type=__mul__];
+"101 model.layers.1.self_attn.q_proj.weight" [id=101, type=nncf_model_const];
+"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=102, type=linear];
+"103 model.layers.1.self_attn.k_proj.weight" [id=103, type=nncf_model_const];
+"104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=104, type=linear];
+"105 model.layers.1.self_attn.v_proj.weight" [id=105, type=nncf_model_const];
+"106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=106, type=linear];
+"107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" [id=107, type=view];
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" [id=108, type=transpose];
+"109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1" [id=109, type=view];
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" [id=110, type=transpose];
+"111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2" [id=111, type=view];
+"112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2" [id=112, type=transpose];
+"113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" [id=113, type=cat];
+"114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" [id=114, type=cos];
+"115 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" [id=115, type=sin];
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" [id=116, type=to];
+"117 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" [id=117, type=to];
+"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" [id=118, type=unsqueeze];
+"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" [id=119, type=unsqueeze];
+"120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0" [id=120, type=__mul__];
+"121 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0" [id=121, type=__getitem__];
+"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1" [id=122, type=__getitem__];
+"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0" [id=123, type=__neg__];
+"124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0" [id=124, type=cat];
+"125 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1" [id=125, type=__mul__];
+"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0" [id=126, type=__add__];
+"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2" [id=127, type=__mul__];
+"128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2" [id=128, type=__getitem__];
+"129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3" [id=129, type=__getitem__];
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1" [id=130, type=__neg__];
+"131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1" [id=131, type=cat];
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3" [id=132, type=__mul__];
+"133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1" [id=133, type=__add__];
+"134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4" [id=134, type=__getitem__];
+"135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0" [id=135, type=expand];
+"136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0" [id=136, type=reshape];
+"137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5" [id=137, type=__getitem__];
+"138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1" [id=138, type=expand];
+"139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1" [id=139, type=reshape];
+"140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3" [id=140, type=transpose];
+"141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0" [id=141, type=matmul];
+"142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0" [id=142, type=__truediv__];
+"143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2" [id=143, type=__add__];
+"144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0" [id=144, type=softmax];
+"145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0" [id=145, type=to];
+"146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0" [id=146, type=dropout];
+"147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1" [id=147, type=matmul];
+"148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4" [id=148, type=transpose];
+"149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" [id=149, type=contiguous];
+"150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" [id=150, type=reshape];
+"151 model.layers.1.self_attn.o_proj.weight" [id=151, type=nncf_model_const];
+"152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=152, type=linear];
+"153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" [id=153, type=__add__];
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=154, type=to];
+"155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0" [id=155, type=pow];
+"156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0" [id=156, type=mean];
+"157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0" [id=157, type=__add__];
+"158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" [id=158, type=rsqrt];
+"159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" [id=159, type=__mul__];
+"160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1" [id=160, type=to];
+"161 model.layers.1.post_attention_layernorm.weight" [id=161, type=nncf_model_const];
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=162, type=__mul__];
+"163 model.layers.1.mlp.gate_proj.weight" [id=163, type=nncf_model_const];
+"164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=164, type=abs];
+"165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=165, type=le];
+"166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=166, type=masked_fill];
+"167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=167, type=linear];
+"168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=168, type=silu];
+"169 model.layers.1.mlp.up_proj.weight" [id=169, type=nncf_model_const];
+"170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=170, type=abs];
+"171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=171, type=le];
+"172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=172, type=masked_fill];
+"173 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=173, type=linear];
+"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" [id=174, type=__mul__];
+"175 model.layers.1.mlp.down_proj.weight" [id=175, type=nncf_model_const];
+"176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=176, type=abs];
+"177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=177, type=le];
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=178, type=masked_fill];
+"179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" [id=179, type=linear];
+"180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1" [id=180, type=__add__];
+"181 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" [id=181, type=to];
+"182 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0" [id=182, type=pow];
+"183 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0" [id=183, type=mean];
+"184 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0" [id=184, type=__add__];
+"185 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0" [id=185, type=rsqrt];
+"186 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0" [id=186, type=__mul__];
+"187 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1" [id=187, type=to];
+"188 model.norm.weight" [id=188, type=nncf_model_const];
+"189 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" [id=189, type=__mul__];
+"190 lm_head.weight" [id=190, type=nncf_model_const];
+"191 LlamaForCausalLM/Linear[lm_head]/linear_0" [id=191, type=linear];
+"192 LlamaForCausalLM/float_0" [id=192, type=float];
+"193 /nncf_model_output_0" [id=193, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0";
+"1 model.embed_tokens.weight" -> "2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0";
+"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" -> "3 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0";
+"3 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "4 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0";
+"3 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"3 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0";
+"4 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0" -> "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0";
+"5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/mean_0" -> "6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0";
+"6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__add___0" -> "7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0";
+"7 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/rsqrt_0" -> "8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"8 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0" -> "9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1";
+"9 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_1" -> "11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"10 model.layers.0.input_layernorm.weight" -> "11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"11 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"12 model.layers.0.self_attn.q_proj.weight" -> "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "18 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0";
+"14 model.layers.0.self_attn.k_proj.weight" -> "15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1";
+"16 model.layers.0.self_attn.v_proj.weight" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "22 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2";
+"18 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" -> "19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0";
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0";
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0";
+"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1";
+"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1";
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2";
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2";
+"21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_1" -> "40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3";
+"22 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2" -> "23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2";
+"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_2" -> "48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5";
+"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0";
+"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0";
+"25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" -> "27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0";
+"26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" -> "28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1";
+"27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" -> "29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0";
+"28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" -> "30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1";
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" -> "31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0";
+"29 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_0" -> "38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2";
+"30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" -> "36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1";
+"30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/unsqueeze_1" -> "43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3";
+"31 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0" -> "37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0";
+"32 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___0" -> "35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0";
+"33 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___1" -> "34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0";
+"34 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___0" -> "35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0";
+"35 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_0" -> "36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1";
+"36 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___1" -> "37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0";
+"37 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___0" -> "52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0";
+"38 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___2" -> "44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1";
+"39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___2" -> "42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1";
+"40 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___3" -> "41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1";
+"41 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__neg___1" -> "42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1";
+"42 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/cat_1" -> "43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3";
+"43 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___3" -> "44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1";
+"44 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___1" -> "45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4";
+"45 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___4" -> "46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0";
+"46 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_0" -> "47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0";
+"47 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_0" -> "51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3";
+"48 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__getitem___5" -> "49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1";
+"49 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/expand_1" -> "50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1";
+"50 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_1" -> "58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1";
+"51 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_3" -> "52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0";
+"52 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_0" -> "53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0";
+"53 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__truediv___0" -> "54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2";
+"54 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__add___2" -> "55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0";
+"55 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/softmax_0" -> "56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0";
+"56 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/to_0" -> "57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0";
+"57 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/dropout_0" -> "58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1";
+"58 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/matmul_1" -> "59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4";
+"59 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4" -> "60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0";
+"60 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" -> "61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2";
+"61 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" -> "63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"62 model.layers.0.self_attn.o_proj.weight" -> "63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"63 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0";
+"64 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" -> "65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0";
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0";
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "70 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"65 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1";
+"66 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0" -> "67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0";
+"67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/mean_0" -> "68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0";
+"68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__add___0" -> "69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0";
+"69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" -> "70 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"70 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" -> "71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1";
+"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_1" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"72 model.layers.0.post_attention_layernorm.weight" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0";
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0";
+"73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"74 model.layers.0.mlp.gate_proj.weight" -> "78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0";
+"76 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"77 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"78 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0";
+"79 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0";
+"80 model.layers.0.mlp.up_proj.weight" -> "84 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"81 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "82 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0";
+"82 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "84 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"84 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0";
+"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0";
+"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"86 model.layers.0.mlp.down_proj.weight" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0";
+"88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" -> "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1";
+"91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___1" -> "92 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0";
+"92 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0";
+"92 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"92 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_0" -> "153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0";
+"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/pow_0" -> "94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0";
+"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/mean_0" -> "95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0";
+"95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__add___0" -> "96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0";
+"96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/rsqrt_0" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0";
+"97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1";
+"98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/to_1" -> "100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"99 model.layers.1.input_layernorm.weight" -> "100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1";
+"100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"100 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"101 model.layers.1.self_attn.q_proj.weight" -> "102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0";
+"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0";
+"103 model.layers.1.self_attn.k_proj.weight" -> "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0";
+"104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1";
+"105 model.layers.1.self_attn.v_proj.weight" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0";
+"106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2";
+"107 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" -> "108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "121 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0";
+"108 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1";
+"109 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1" -> "110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1";
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2";
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2";
+"110 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_1" -> "129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3";
+"111 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2" -> "112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2";
+"112 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_2" -> "137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5";
+"113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0";
+"113 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cat_0" -> "115 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0";
+"114 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/cos_0" -> "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0";
+"115 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/sin_0" -> "117 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1";
+"116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_0" -> "118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0";
+"117 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/LlamaRotaryEmbedding[rotary_emb]/to_1" -> "119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1";
+"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0";
+"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_0" -> "127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2";
+"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" -> "125 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1";
+"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/unsqueeze_1" -> "132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3";
+"120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0" -> "126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0";
+"121 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___0" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0";
+"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___1" -> "123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0";
+"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___0" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0";
+"124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_0" -> "125 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1";
+"125 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___1" -> "126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0";
+"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___0" -> "141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0";
+"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___2" -> "133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1";
+"128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___2" -> "131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1";
+"129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___3" -> "130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1";
+"130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__neg___1" -> "131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1";
+"131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/cat_1" -> "132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3";
+"132 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___3" -> "133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1";
+"133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___1" -> "134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4";
+"134 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___4" -> "135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0";
+"135 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_0" -> "136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0";
+"136 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_0" -> "140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3";
+"137 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__getitem___5" -> "138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1";
+"138 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/expand_1" -> "139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1";
+"139 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_1" -> "147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1";
+"140 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_3" -> "141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0";
+"141 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_0" -> "142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0";
+"142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__truediv___0" -> "143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2";
+"143 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__add___2" -> "144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0";
+"144 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/softmax_0" -> "145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0";
+"145 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/to_0" -> "146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0";
+"146 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/dropout_0" -> "147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1";
+"147 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/matmul_1" -> "148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4";
+"148 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4" -> "149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0";
+"149 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" -> "150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2";
+"150 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" -> "152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"151 model.layers.1.self_attn.o_proj.weight" -> "152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0";
+"152 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0";
+"153 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" -> "154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0";
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0";
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"154 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1";
+"155 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0" -> "156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0";
+"156 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/mean_0" -> "157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0";
+"157 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__add___0" -> "158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0";
+"158 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/rsqrt_0" -> "159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0";
+"159 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___0" -> "160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1";
+"160 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_1" -> "162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"161 model.layers.1.post_attention_layernorm.weight" -> "162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1";
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0";
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0";
+"162 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"163 model.layers.1.mlp.gate_proj.weight" -> "167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"164 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0";
+"165 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0";
+"166 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0";
+"167 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0";
+"168 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0";
+"169 model.layers.1.mlp.up_proj.weight" -> "173 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0";
+"171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0";
+"172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "173 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0";
+"173 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0";
+"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0";
+"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"175 model.layers.1.mlp.down_proj.weight" -> "179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0";
+"177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0";
+"178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0";
+"179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0" -> "180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1";
+"180 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___1" -> "181 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0";
+"181 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" -> "182 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0";
+"181 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_0" -> "186 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0";
+"182 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/pow_0" -> "183 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0";
+"183 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/mean_0" -> "184 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0";
+"184 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__add___0" -> "185 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0";
+"185 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/rsqrt_0" -> "186 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0";
+"186 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___0" -> "187 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1";
+"187 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1" -> "189 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1";
+"188 model.norm.weight" -> "189 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1";
+"189 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" -> "191 LlamaForCausalLM/Linear[lm_head]/linear_0";
+"190 lm_head.weight" -> "191 LlamaForCausalLM/Linear[lm_head]/linear_0";
+"191 LlamaForCausalLM/Linear[lm_head]/linear_0" -> "192 LlamaForCausalLM/float_0";
+"192 LlamaForCausalLM/float_0" -> "193 /nncf_model_output_0";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot
new file mode 100644
index 00000000000..aa24d54a2e0
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot
@@ -0,0 +1,22 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 weight" [id=1, type=nncf_model_const];
+"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" [id=2, type=decompress_symmetric];
+"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" [id=3, type=type];
+"4 bias" [id=4, type=nncf_model_const];
+"5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" [id=5, type=abs];
+"6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" [id=6, type=le];
+"7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0" [id=7, type=masked_fill];
+"8 Linear/linear_0" [id=8, type=linear];
+"9 /nncf_model_output_0" [id=9, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0";
+"0 /nncf_model_input_0" -> "7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0";
+"1 weight" -> "2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0";
+"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" -> "3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0";
+"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" -> "8 Linear/linear_0";
+"4 bias" -> "8 Linear/linear_0";
+"5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" -> "6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0";
+"6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" -> "7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0";
+"7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0" -> "8 Linear/linear_0";
+"8 Linear/linear_0" -> "9 /nncf_model_output_0";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/linear_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/linear_sparse_activations.dot
new file mode 100644
index 00000000000..a3192dfff20
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/linear_sparse_activations.dot
@@ -0,0 +1,18 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 weight" [id=1, type=nncf_model_const];
+"2 bias" [id=2, type=nncf_model_const];
+"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" [id=3, type=abs];
+"4 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" [id=4, type=le];
+"5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0" [id=5, type=masked_fill];
+"6 Linear/linear_0" [id=6, type=linear];
+"7 /nncf_model_output_0" [id=7, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0";
+"0 /nncf_model_input_0" -> "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0";
+"1 weight" -> "6 Linear/linear_0";
+"2 bias" -> "6 Linear/linear_0";
+"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" -> "4 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0";
+"4 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" -> "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0";
+"5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0" -> "6 Linear/linear_0";
+"6 Linear/linear_0" -> "7 /nncf_model_output_0";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot
new file mode 100644
index 00000000000..ae3f667ff3a
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot
@@ -0,0 +1,57 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 embedding.weight" [id=1, type=nncf_model_const];
+"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric];
+"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" [id=3, type=type];
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=4, type=embedding];
+"5 linear1.weight" [id=5, type=nncf_model_const];
+"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric];
+"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" [id=7, type=type];
+"8 linear1.bias" [id=8, type=nncf_model_const];
+"9 ThreeLinearModel/Linear[linear1]/linear_0" [id=9, type=linear];
+"10 linear3.weight" [id=10, type=nncf_model_const];
+"11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=11, type=decompress_symmetric];
+"12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" [id=12, type=type];
+"13 linear3.bias" [id=13, type=nncf_model_const];
+"14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=14, type=abs];
+"15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=15, type=le];
+"16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=16, type=masked_fill];
+"17 ThreeLinearModel/Linear[linear3]/linear_0" [id=17, type=linear];
+"18 linear2.weight" [id=18, type=nncf_model_const];
+"19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric];
+"20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" [id=20, type=type];
+"21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=21, type=abs];
+"22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=22, type=le];
+"23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=23, type=masked_fill];
+"24 ThreeLinearModel/Linear[linear2]/linear_0" [id=24, type=linear];
+"25 /nncf_model_output_0" [id=25, type=nncf_model_output];
+"26 /nncf_model_output_1" [id=26, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0";
+"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "9 ThreeLinearModel/Linear[linear1]/linear_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"5 linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0";
+"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" -> "9 ThreeLinearModel/Linear[linear1]/linear_0";
+"8 linear1.bias" -> "9 ThreeLinearModel/Linear[linear1]/linear_0";
+"9 ThreeLinearModel/Linear[linear1]/linear_0" -> "14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0";
+"9 ThreeLinearModel/Linear[linear1]/linear_0" -> "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"10 linear3.weight" -> "11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0";
+"12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" -> "17 ThreeLinearModel/Linear[linear3]/linear_0";
+"13 linear3.bias" -> "17 ThreeLinearModel/Linear[linear3]/linear_0";
+"14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0";
+"15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "17 ThreeLinearModel/Linear[linear3]/linear_0";
+"17 ThreeLinearModel/Linear[linear3]/linear_0" -> "25 /nncf_model_output_0";
+"18 linear2.weight" -> "19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0";
+"20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" -> "24 ThreeLinearModel/Linear[linear2]/linear_0";
+"21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0";
+"22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "24 ThreeLinearModel/Linear[linear2]/linear_0";
+"24 ThreeLinearModel/Linear[linear2]/linear_0" -> "26 /nncf_model_output_1";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_sparse_activations.dot
new file mode 100644
index 00000000000..19a4b32561e
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_sparse_activations.dot
@@ -0,0 +1,41 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 embedding.weight" [id=1, type=nncf_model_const];
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=2, type=embedding];
+"3 linear1.weight" [id=3, type=nncf_model_const];
+"4 linear1.bias" [id=4, type=nncf_model_const];
+"5 ThreeLinearModel/Linear[linear1]/linear_0" [id=5, type=linear];
+"6 linear3.weight" [id=6, type=nncf_model_const];
+"7 linear3.bias" [id=7, type=nncf_model_const];
+"8 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=8, type=abs];
+"9 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=9, type=le];
+"10 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=10, type=masked_fill];
+"11 ThreeLinearModel/Linear[linear3]/linear_0" [id=11, type=linear];
+"12 linear2.weight" [id=12, type=nncf_model_const];
+"13 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=13, type=abs];
+"14 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=14, type=le];
+"15 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=15, type=masked_fill];
+"16 ThreeLinearModel/Linear[linear2]/linear_0" [id=16, type=linear];
+"17 /nncf_model_output_0" [id=17, type=nncf_model_output];
+"18 /nncf_model_output_1" [id=18, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "2 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "5 ThreeLinearModel/Linear[linear1]/linear_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "13 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "15 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"3 linear1.weight" -> "5 ThreeLinearModel/Linear[linear1]/linear_0";
+"4 linear1.bias" -> "5 ThreeLinearModel/Linear[linear1]/linear_0";
+"5 ThreeLinearModel/Linear[linear1]/linear_0" -> "8 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0";
+"5 ThreeLinearModel/Linear[linear1]/linear_0" -> "10 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"6 linear3.weight" -> "11 ThreeLinearModel/Linear[linear3]/linear_0";
+"7 linear3.bias" -> "11 ThreeLinearModel/Linear[linear3]/linear_0";
+"8 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "9 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0";
+"9 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "10 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"10 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "11 ThreeLinearModel/Linear[linear3]/linear_0";
+"11 ThreeLinearModel/Linear[linear3]/linear_0" -> "17 /nncf_model_output_0";
+"12 linear2.weight" -> "16 ThreeLinearModel/Linear[linear2]/linear_0";
+"13 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "14 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0";
+"14 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> "15 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"15 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "16 ThreeLinearModel/Linear[linear2]/linear_0";
+"16 ThreeLinearModel/Linear[linear2]/linear_0" -> "18 /nncf_model_output_1";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot
new file mode 100644
index 00000000000..c6488f1131b
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot
@@ -0,0 +1,64 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 embedding.weight" [id=1, type=nncf_model_const];
+"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric];
+"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" [id=3, type=type];
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=4, type=embedding];
+"5 linear1.weight" [id=5, type=nncf_model_const];
+"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric];
+"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" [id=7, type=type];
+"8 linear1.bias" [id=8, type=nncf_model_const];
+"9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" [id=9, type=abs];
+"10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" [id=10, type=le];
+"11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" [id=11, type=masked_fill];
+"12 ThreeLinearModel/Linear[linear1]/linear_0" [id=12, type=linear];
+"13 linear3.weight" [id=13, type=nncf_model_const];
+"14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=14, type=decompress_symmetric];
+"15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" [id=15, type=type];
+"16 linear3.bias" [id=16, type=nncf_model_const];
+"17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=17, type=abs];
+"18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=18, type=le];
+"19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=19, type=masked_fill];
+"20 ThreeLinearModel/Linear[linear3]/linear_0" [id=20, type=linear];
+"21 linear2.weight" [id=21, type=nncf_model_const];
+"22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=22, type=decompress_symmetric];
+"23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" [id=23, type=type];
+"24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=24, type=abs];
+"25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=25, type=le];
+"26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=26, type=masked_fill];
+"27 ThreeLinearModel/Linear[linear2]/linear_0" [id=27, type=linear];
+"28 /nncf_model_output_0" [id=28, type=nncf_model_output];
+"29 /nncf_model_output_1" [id=29, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0";
+"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0";
+"4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"5 linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0";
+"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" -> "12 ThreeLinearModel/Linear[linear1]/linear_0";
+"8 linear1.bias" -> "12 ThreeLinearModel/Linear[linear1]/linear_0";
+"9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" -> "10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0";
+"10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" -> "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0";
+"11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" -> "12 ThreeLinearModel/Linear[linear1]/linear_0";
+"12 ThreeLinearModel/Linear[linear1]/linear_0" -> "17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0";
+"12 ThreeLinearModel/Linear[linear1]/linear_0" -> "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"13 linear3.weight" -> "14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0";
+"15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" -> "20 ThreeLinearModel/Linear[linear3]/linear_0";
+"16 linear3.bias" -> "20 ThreeLinearModel/Linear[linear3]/linear_0";
+"17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0";
+"18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "20 ThreeLinearModel/Linear[linear3]/linear_0";
+"20 ThreeLinearModel/Linear[linear3]/linear_0" -> "28 /nncf_model_output_0";
+"21 linear2.weight" -> "22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0";
+"22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0";
+"23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" -> "27 ThreeLinearModel/Linear[linear2]/linear_0";
+"24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0";
+"25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> "26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "27 ThreeLinearModel/Linear[linear2]/linear_0";
+"27 ThreeLinearModel/Linear[linear2]/linear_0" -> "29 /nncf_model_output_1";
+}
diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_sparse_activations.dot
new file mode 100644
index 00000000000..36779fe7f61
--- /dev/null
+++ b/tests/torch/data/experimental/sparsify_activations/three_linear_sparse_activations.dot
@@ -0,0 +1,48 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 embedding.weight" [id=1, type=nncf_model_const];
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=2, type=embedding];
+"3 linear1.weight" [id=3, type=nncf_model_const];
+"4 linear1.bias" [id=4, type=nncf_model_const];
+"5 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" [id=5, type=abs];
+"6 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" [id=6, type=le];
+"7 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" [id=7, type=masked_fill];
+"8 ThreeLinearModel/Linear[linear1]/linear_0" [id=8, type=linear];
+"9 linear3.weight" [id=9, type=nncf_model_const];
+"10 linear3.bias" [id=10, type=nncf_model_const];
+"11 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=11, type=abs];
+"12 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=12, type=le];
+"13 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=13, type=masked_fill];
+"14 ThreeLinearModel/Linear[linear3]/linear_0" [id=14, type=linear];
+"15 linear2.weight" [id=15, type=nncf_model_const];
+"16 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=16, type=abs];
+"17 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=17, type=le];
+"18 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=18, type=masked_fill];
+"19 ThreeLinearModel/Linear[linear2]/linear_0" [id=19, type=linear];
+"20 /nncf_model_output_0" [id=20, type=nncf_model_output];
+"21 /nncf_model_output_1" [id=21, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "2 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/embedding_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "5 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "7 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "16 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0";
+"2 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "18 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"3 linear1.weight" -> "8 ThreeLinearModel/Linear[linear1]/linear_0";
+"4 linear1.bias" -> "8 ThreeLinearModel/Linear[linear1]/linear_0";
+"5 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" -> "6 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0";
+"6 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" -> "7 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0";
+"7 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" -> "8 ThreeLinearModel/Linear[linear1]/linear_0";
+"8 ThreeLinearModel/Linear[linear1]/linear_0" -> "11 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0";
+"8 ThreeLinearModel/Linear[linear1]/linear_0" -> "13 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"9 linear3.weight" -> "14 ThreeLinearModel/Linear[linear3]/linear_0";
+"10 linear3.bias" -> "14 ThreeLinearModel/Linear[linear3]/linear_0";
+"11 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "12 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0";
+"12 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "13 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0";
+"13 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "14 ThreeLinearModel/Linear[linear3]/linear_0";
+"14 ThreeLinearModel/Linear[linear3]/linear_0" -> "20 /nncf_model_output_0";
+"15 linear2.weight" -> "19 ThreeLinearModel/Linear[linear2]/linear_0";
+"16 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "17 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0";
+"17 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> "18 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0";
+"18 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "19 ThreeLinearModel/Linear[linear2]/linear_0";
+"19 ThreeLinearModel/Linear[linear2]/linear_0" -> "21 /nncf_model_output_1";
+}
diff --git a/tests/torch/experimental/sparsify_activations/__init__.py b/tests/torch/experimental/sparsify_activations/__init__.py
new file mode 100644
index 00000000000..2e49d63977d
--- /dev/null
+++ b/tests/torch/experimental/sparsify_activations/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/torch/experimental/sparsify_activations/helpers.py b/tests/torch/experimental/sparsify_activations/helpers.py
new file mode 100644
index 00000000000..437103ec166
--- /dev/null
+++ b/tests/torch/experimental/sparsify_activations/helpers.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import defaultdict
+
+import openvino as ov
+import torch
+import torch.nn as nn
+import transformers.models
+
+from nncf import IgnoredScope
+from nncf.experimental.torch.sparsify_activations import TargetScope
+
+
+class ThreeLinearModel(nn.Module):
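+    """A small embedding + three-linear-layer model shared across the sparsification tests."""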
+ def __init__(self) -> None:
+ super().__init__()
+ self.embedding = nn.Embedding(32, 2)
+ self.linear1 = nn.Linear(2, 3)
+ self.linear2 = nn.Linear(2, 4, bias=False)
+ self.linear3 = nn.Linear(3, 5)
+
+ def forward(self, input_ids: torch.Tensor):
+ x = self.embedding(input_ids)
+ y0 = self.linear3(self.linear1(x))
+ y1 = self.linear2(x)
+ return y0, y1
+
+
+def dummy_llama_model():
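+    """Creates a tiny randomly initialized Llama model so the graph tests run quickly."""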
+ config = transformers.models.llama.configuration_llama.LlamaConfig(
+ vocab_size=32,
+ hidden_size=8,
+ intermediate_size=14,
+ num_attention_heads=2,
+ num_key_value_heads=1,
+ num_hidden_layers=2,
+ use_cache=False,
+ return_dict=False,
+ )
+ model = transformers.AutoModelForCausalLM.from_config(config, attn_implementation="eager")
+ return model
+
+
+def count_sparsifier_patterns_in_ov(model: ov.Model) -> int:
+ """
+    Counts the occurrences of the activation sparsification pattern
+    "Abs -> LessEqual -> Select" in the given OpenVINO model.
+ """
+ pattern = ("Abs", "LessEqual", "Select")
+ result = 0
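+    # Map each op to its direct consumer nodes so that pattern chains can be followed.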
+ connections = defaultdict(list)
+ for node in model.get_ops():
+ for output in node.outputs():
+ for input_ in output.get_target_inputs():
+ connections[node].append(input_.get_node())
+
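+    # Depth-first walk that advances along `pattern`; reaching its last element counts one match.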
+ def dfs(node, location=0):
+ nonlocal result
+ if location < len(pattern) and node.get_type_name() == pattern[location]:
+ if location == len(pattern) - 1:
+ result += 1
+ else:
+ for next_node in connections[node]:
+ dfs(next_node, location + 1)
+
+ for node in model.get_ops():
+ dfs(node)
+ return result
+
+
+def convert_ignored_scope_to_target_scope(ignored_scope: IgnoredScope) -> TargetScope:
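+    """Builds a TargetScope with the same matching fields as the given IgnoredScope."""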
+ return TargetScope(
+ ignored_scope.names,
+ ignored_scope.patterns,
+ ignored_scope.types,
+ ignored_scope.subgraphs,
+ ignored_scope.validate,
+ )
diff --git a/tests/torch/experimental/sparsify_activations/test_algo.py b/tests/torch/experimental/sparsify_activations/test_algo.py
new file mode 100644
index 00000000000..b7214aaa5fa
--- /dev/null
+++ b/tests/torch/experimental/sparsify_activations/test_algo.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, Optional
+
+import openvino as ov
+import pytest
+import torch
+import torch.nn as nn
+
+import nncf
+import nncf.experimental
+import nncf.experimental.torch.sparsify_activations
+from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgorithm
+from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import TargetScope
+from nncf.experimental.torch.sparsify_activations.torch_backend import ACTIVATIONS_SPARSIFIER_PREFIX
+from nncf.experimental.torch.sparsify_activations.torch_backend import ActivationsSparsifier
+from nncf.scopes import IgnoredScope
+from nncf.torch.model_creation import wrap_model
+from nncf.torch.nncf_network import NNCFNetwork
+from tests.shared.nx_graph import compare_nx_graph_with_reference
+from tests.shared.paths import TEST_ROOT
+from tests.torch.experimental.sparsify_activations.helpers import ThreeLinearModel
+from tests.torch.experimental.sparsify_activations.helpers import count_sparsifier_patterns_in_ov
+from tests.torch.experimental.sparsify_activations.helpers import dummy_llama_model
+from tests.torch.helpers import set_torch_seed
+
+
+@dataclass
+class SparsifyActivationsAlgorithmTestDesc:
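+    """Describes one end-to-end test case: model, calibration data, sparsity targets, and expected results."""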
+ name: str
+ model_getter: Callable[[], nn.Module]
+ dataset_getter: Callable[[torch.device], nncf.Dataset]
+ target_sparsity_by_scope: Dict[TargetScope, float]
+ ignored_scope: Optional[nncf.IgnoredScope]
+ ref_sparsifier_target_sparsity: Dict[str, float]
+ ref_num_batches_tracked: int
+ ref_num_patterns_in_ov: int
+
+
+sparsify_activations_algorithm_test_descs = [
+ SparsifyActivationsAlgorithmTestDesc(
+ name="linear",
+ model_getter=lambda: nn.Linear(4, 2),
+ dataset_getter=lambda device: nncf.Dataset(torch.randn([3, 2, 4]).to(device)),
+ target_sparsity_by_scope={
+ TargetScope(names=["Linear/linear_0"]): 0.3,
+ },
+ ignored_scope=None,
+ ref_sparsifier_target_sparsity={
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_Linear/linear_0": 0.3,
+ },
+ ref_num_batches_tracked=3,
+ ref_num_patterns_in_ov=1,
+ ),
+ SparsifyActivationsAlgorithmTestDesc(
+ name="three_linear",
+ model_getter=ThreeLinearModel,
+ dataset_getter=lambda device: nncf.Dataset(torch.randint(0, 30, (3, 2, 8)).to(device)),
+ target_sparsity_by_scope={
+ TargetScope(types=["linear"]): 0.4,
+ },
+ ignored_scope=None,
+ ref_sparsifier_target_sparsity={
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_ThreeLinearModel/Linear[linear1]/linear_0": 0.4,
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_ThreeLinearModel/Linear[linear2]/linear_0": 0.4,
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_ThreeLinearModel/Linear[linear3]/linear_0": 0.4,
+ },
+ ref_num_batches_tracked=3,
+        ref_num_patterns_in_ov=2,  # The sparsifiers of linear1 and linear2 share one input and get merged on export
+ ),
+ SparsifyActivationsAlgorithmTestDesc(
+ name="three_linear_ignore1",
+ model_getter=ThreeLinearModel,
+ dataset_getter=lambda device: nncf.Dataset(torch.randint(0, 30, (3, 2, 8)).to(device)),
+ target_sparsity_by_scope={
+ TargetScope(names=["ThreeLinearModel/Linear[linear2]/linear_0"]): 0.4,
+ TargetScope(patterns=[".*linear3.*"]): 0.4,
+ },
+ ignored_scope=IgnoredScope(patterns=[".*linear1.*"]),
+ ref_sparsifier_target_sparsity={
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_ThreeLinearModel/Linear[linear2]/linear_0": 0.4,
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_ThreeLinearModel/Linear[linear3]/linear_0": 0.4,
+ },
+ ref_num_batches_tracked=3,
+ ref_num_patterns_in_ov=2,
+ ),
+ SparsifyActivationsAlgorithmTestDesc(
+ name="dummy_llama",
+ model_getter=dummy_llama_model,
+ dataset_getter=lambda device: nncf.Dataset(torch.randint(0, 30, (3, 2, 8)).to(device)),
+ target_sparsity_by_scope={
+ TargetScope(patterns=[".*gate_proj.*"]): 0.2,
+ TargetScope(patterns=[".*up_proj.*"]): 0.3,
+ TargetScope(patterns=[".*down_proj.*"]): 0.4,
+ },
+ ignored_scope=None,
+ ref_sparsifier_target_sparsity={
+ (
+ f"{ACTIVATIONS_SPARSIFIER_PREFIX}_LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/"
+ f"LlamaDecoderLayer[{layer_id}]/LlamaMLP[mlp]/Linear[{name}]/linear_0"
+ ): sparsity
+ for name, sparsity in [("gate_proj", 0.2), ("up_proj", 0.3), ("down_proj", 0.4)]
+ for layer_id in [0, 1]
+ },
+ ref_num_batches_tracked=3,
+ ref_num_patterns_in_ov=6,
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "desc",
+ sparsify_activations_algorithm_test_descs,
+ ids=[p.name for p in sparsify_activations_algorithm_test_descs],
+ scope="class",
+)
+@pytest.mark.parametrize("compress_weights", [False, True], scope="class")
+@pytest.mark.parametrize("use_cuda", [False, True], ids=["cpu", "cuda"], scope="class")
+class TestSparsifyActivationsAlgorithm:
+
+ @pytest.fixture(autouse=True, scope="class")
+ def setup(self, request, desc: SparsifyActivationsAlgorithmTestDesc, compress_weights: bool, use_cuda: bool):
+ if use_cuda and not torch.cuda.is_available():
+ pytest.skip("CUDA is not available")
+ request.cls.use_cuda = use_cuda
+ device = torch.device("cuda" if use_cuda else "cpu")
+ request.cls.device = device
+ request.cls.desc = desc
+ request.cls.compress_weights = compress_weights
+ with set_torch_seed():
+ model = desc.model_getter()
+ model = model.to(device).eval()
+ dataset = desc.dataset_getter(device)
+ if compress_weights:
+ model = nncf.compress_weights(
+ model,
+ mode=nncf.CompressWeightsMode.INT8_SYM,
+ dataset=dataset,
+ )
+ model = nncf.experimental.torch.sparsify_activations.sparsify_activations(
+ model=model,
+ dataset=dataset,
+ target_sparsity_by_scope=desc.target_sparsity_by_scope,
+ ignored_scope=desc.ignored_scope,
+ )
+ request.cls.model = model
+ request.cls.dataset = dataset
+
+ def test_inserted_sparsifier(self):
+ desc: SparsifyActivationsAlgorithmTestDesc = self.desc
+ model = self.model
+ assert isinstance(model, NNCFNetwork)
+ num_sparsifiers = 0
+ for name, op in model.nncf.external_op.items():
+ if isinstance(op, ActivationsSparsifier):
+ assert op.target_sparsity == desc.ref_sparsifier_target_sparsity[name]
+ assert op.num_batches_tracked == desc.ref_num_batches_tracked
+ num_sparsifiers += 1
+ assert num_sparsifiers == len(desc.ref_sparsifier_target_sparsity)
+
+ def test_nncf_graph(self):
+ desc: SparsifyActivationsAlgorithmTestDesc = self.desc
+ model: NNCFNetwork = self.model
+ file_name = "_".join(
+ filter(None, [desc.name, "int8_sym_weights" if self.compress_weights else None, "sparse_activations"])
+ )
+ ref_dot_path = Path(TEST_ROOT, "torch", "data", "experimental", "sparsify_activations", f"{file_name}.dot")
+ graph = model.nncf.get_graph().get_graph_for_structure_analysis()
+ compare_nx_graph_with_reference(graph, ref_dot_path)
+
+ def test_export_openvino(self):
+ model: NNCFNetwork = self.model
+ example_input = next(iter(self.dataset.get_inference_data()))
+ with torch.no_grad():
+ torch_outputs = model(example_input)
+ if isinstance(torch_outputs, dict):
+ torch_outputs = tuple(torch_outputs.values())
+ if not isinstance(torch_outputs, tuple):
+ torch_outputs = (torch_outputs,)
+
+ ov_model = ov.convert_model(model, example_input=example_input)
+ assert count_sparsifier_patterns_in_ov(ov_model) == self.desc.ref_num_patterns_in_ov
+
+ compiled_model = ov.compile_model(ov_model, "CPU", config={ov.properties.hint.inference_precision: "f32"})
+ ov_outputs = compiled_model(example_input.cpu()).to_tuple()
+ assert len(torch_outputs) == len(ov_outputs)
+ for torch_output, ov_output in zip(torch_outputs, ov_outputs):
+ torch.testing.assert_close(torch_output.cpu(), torch.from_numpy(ov_output), rtol=1e-3, atol=1e-3)
+
+
+@dataclass
+class TargetSparsityByNodeTestDesc:
+ target_sparsity_by_scope: Dict[TargetScope, float]
+ ignored_scope: IgnoredScope
+ ref_target_sparsity_by_node_name: Optional[Dict[str, float]] = None
+ raised_error_message: Optional[str] = None
+
+
+@pytest.mark.parametrize(
+ "desc",
+ [
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={TargetScope(patterns=[".*linear.*"]): 0.3},
+ ignored_scope=IgnoredScope(),
+ ref_target_sparsity_by_node_name={
+ "ThreeLinearModel/Linear[linear1]/linear_0": 0.3,
+ "ThreeLinearModel/Linear[linear2]/linear_0": 0.3,
+ "ThreeLinearModel/Linear[linear3]/linear_0": 0.3,
+ },
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={TargetScope(patterns=[".*linear[23].*"], types=["linear"]): 0.3},
+ ignored_scope=IgnoredScope(),
+ ref_target_sparsity_by_node_name={
+ "ThreeLinearModel/Linear[linear1]/linear_0": 0.3,
+ "ThreeLinearModel/Linear[linear2]/linear_0": 0.3,
+ "ThreeLinearModel/Linear[linear3]/linear_0": 0.3,
+ },
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={
+ TargetScope(
+ subgraphs=[nncf.Subgraph(inputs=["/nncf_model_input_0"], outputs=["/nncf_model_output_0"])]
+ ): 0.1,
+ },
+ ignored_scope=IgnoredScope(),
+ ref_target_sparsity_by_node_name={
+ "ThreeLinearModel/Linear[linear1]/linear_0": 0.1,
+ "ThreeLinearModel/Linear[linear3]/linear_0": 0.1,
+ },
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={
+ TargetScope(names=["ThreeLinearModel/Linear[linear1]/linear_0"]): 0.1,
+ TargetScope(patterns=[".*linear[23].*"]): 0.3,
+ },
+ ignored_scope=IgnoredScope(patterns=[".*linear2.*"]),
+ ref_target_sparsity_by_node_name={
+ "ThreeLinearModel/Linear[linear1]/linear_0": 0.1,
+ "ThreeLinearModel/Linear[linear3]/linear_0": 0.3,
+ },
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={
+ TargetScope(patterns=[".*nonexist.*"], validate=False): 0.3,
+ TargetScope(names=["ThreeLinearModel/Linear[linear1]/linear_0"]): 0.3,
+ },
+ ignored_scope=IgnoredScope(),
+ ref_target_sparsity_by_node_name={
+ "ThreeLinearModel/Linear[linear1]/linear_0": 0.3,
+ },
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={TargetScope(patterns=[".*nonexist.*"]): 0.3},
+ ignored_scope=IgnoredScope(),
+ raised_error_message="not found in the graph",
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={
+ TargetScope(patterns=[".*linear2.*"]): 0.3,
+ TargetScope(types=["embedding"]): 0.3, # Embedding is not supported
+ },
+ ignored_scope=IgnoredScope(patterns=[".*linear2.*"]),
+ raised_error_message="No layers to conduct activation sparsification",
+ ),
+ TargetSparsityByNodeTestDesc(
+ target_sparsity_by_scope={
+ TargetScope(names=["ThreeLinearModel/Linear[linear1]/linear_0"]): 0.3,
+ TargetScope(patterns=[".*linear1.*"]): 0.4,
+ },
+ ignored_scope=IgnoredScope(),
+ raised_error_message="matched by multiple items",
+ ),
+ ],
+)
+def test_get_target_sparsity_by_node(desc: TargetSparsityByNodeTestDesc):
+ model = wrap_model(
+ ThreeLinearModel(),
+ example_input=torch.ones((2, 4)).long(),
+ trace_parameters=True,
+ )
+ graph = model.nncf.get_graph()
+ algo = SparsifyActivationsAlgorithm(desc.target_sparsity_by_scope, desc.ignored_scope)
+ algo._set_backend_entity(model)
+ if desc.raised_error_message is not None:
+ with pytest.raises(nncf.ValidationError, match=desc.raised_error_message):
+ algo._get_target_sparsity_by_node(graph)
+ else:
+ target_sparsity_by_node = algo._get_target_sparsity_by_node(graph)
+ target_sparsity_by_node_name = {node.node_name: sparsity for node, sparsity in target_sparsity_by_node.items()}
+ assert sorted(target_sparsity_by_node_name.items()) == sorted(desc.ref_target_sparsity_by_node_name.items())
diff --git a/tests/torch/experimental/sparsify_activations/test_components.py b/tests/torch/experimental/sparsify_activations/test_components.py
new file mode 100644
index 00000000000..9c5fde1c9e5
--- /dev/null
+++ b/tests/torch/experimental/sparsify_activations/test_components.py
@@ -0,0 +1,295 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+import torch
+
+import nncf
+import nncf.experimental
+import nncf.experimental.torch.sparsify_activations
+from nncf.experimental.torch.sparsify_activations.target_scope import TargetScope
+from nncf.experimental.torch.sparsify_activations.target_scope import get_target_node_names_from_target_scope
+from nncf.experimental.torch.sparsify_activations.torch_backend import ActivationsSparsifier
+from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend
+from nncf.torch.model_creation import wrap_model
+from nncf.torch.nncf_network import NNCFNetwork
+from tests.common.test_ignored_scope import CONV_TYPE
+from tests.common.test_ignored_scope import IGNORED_SCOPES_TEST_DATA
+from tests.common.test_ignored_scope import LINEAR_TYPE
+from tests.common.test_ignored_scope import NNCFGraphToTestIgnoredScope
+from tests.common.test_ignored_scope import WRONG_IGNORED_SCOPES_TEST_DATA
+from tests.torch.experimental.sparsify_activations.helpers import ThreeLinearModel
+from tests.torch.experimental.sparsify_activations.helpers import convert_ignored_scope_to_target_scope
+
+
+@dataclass
+class SparsifierForwardTestDesc:
+ target_sparsity: float
+ alpha: float
+ input_batches: List[torch.Tensor]
+ ref_running_thresholds: List[torch.Tensor]
+ ref_outputs: List[torch.Tensor]
+
+
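+
+# Note: the reference running thresholds below are consistent with a bias-corrected
+# exponential moving average of the per-batch magnitude quantiles. With
+# q_t = quantile(|x_t|, target_sparsity): s_t = alpha * q_t + (1 - alpha) * s_(t-1) and
+# threshold_t = s_t / (1 - (1 - alpha)^t). E.g. in the "fp32" case, q_1 = 1.9 and q_2 = 1.6,
+# so threshold_2 = (0.1 * 1.6 + 0.9 * 0.19) / (1 - 0.9**2) = 1.7421.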
+sparsifier_forward_during_calibration_test_descs = {
+ "fp16": SparsifierForwardTestDesc(
+ target_sparsity=0.4,
+ alpha=0.2,
+ input_batches=[
+ torch.tensor([1.0, 3.0, 2.0, 4.0], dtype=torch.float16),
+ torch.tensor([4.0, 5.0, 4.5, -3.0], dtype=torch.float16),
+ ],
+ ref_running_thresholds=[
+ torch.tensor(2.1992, dtype=torch.float16),
+ torch.tensor(3.2559, dtype=torch.float16),
+ ],
+ ref_outputs=[
+ torch.tensor([0.0, 3.0, 0.0, 4.0], dtype=torch.float16),
+ torch.tensor([4.0, 5.0, 4.5, 0.0], dtype=torch.float16),
+ ],
+ ),
+ "fp32": SparsifierForwardTestDesc(
+ target_sparsity=0.8,
+ alpha=0.1,
+ input_batches=[
+ torch.tensor([-1.0, 1.0, 2.5]),
+ torch.tensor([1.0, 2.0, 0.0]),
+ torch.tensor([2.0, 0.0, 3.0]),
+ ],
+ ref_running_thresholds=[
+ torch.tensor(1.9000),
+ torch.tensor(1.7421),
+ torch.tensor(2.0587),
+ ],
+ ref_outputs=[
+ torch.tensor([0.0, 0.0, 2.5]),
+ torch.tensor([0.0, 2.0, 0.0]),
+ torch.tensor([0.0, 0.0, 3.0]),
+ ],
+ ),
+ "varying_shape": SparsifierForwardTestDesc(
+ target_sparsity=0.6,
+ alpha=0.5,
+ input_batches=[
+ torch.tensor([1.0, 2.0, 7.0]),
+ torch.tensor([[1.0, 2.0], [7.0, -3.0]]),
+ torch.tensor([[[1.0], [5.5], [8.5], [-3.0], [2.5]]]),
+ ],
+ ref_running_thresholds=[
+ torch.tensor(3.0000),
+ torch.tensor(2.8667),
+ torch.tensor(3.5143),
+ ],
+ ref_outputs=[
+ torch.tensor([0.0, 0.0, 7.0]),
+ torch.tensor([[0.0, 0.0], [7.0, -3.0]]),
+ torch.tensor([[[0.0], [5.5], [8.5], [0.0], [0.0]]]),
+ ],
+ ),
+}
+
+
+class TestActivationsSparsifier:
+ @pytest.fixture(autouse=True)
+ def setup(self, use_cuda: bool):
+ if use_cuda and not torch.cuda.is_available():
+ pytest.skip("CUDA is not available")
+ self.device = torch.device("cuda" if use_cuda else "cpu")
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+ def test_forward_before_calibration(self, use_cuda: bool, dtype: torch.dtype):
+ device = self.device
+ input_tensor = torch.rand([3, 3], device=device, dtype=dtype)
+ sparsifier = ActivationsSparsifier(target_sparsity=0.9).to(device)
+ assert sparsifier.freeze is True
+ assert not sparsifier.num_batches_tracked.is_nonzero()
+ assert sparsifier.running_threshold.isneginf()
+ output_tensor = sparsifier(input_tensor)
+ # The output tensor is a new tensor
+ assert not output_tensor.is_set_to(input_tensor)
+ # Before calibration, the sparsifier does not change the input
+ torch.testing.assert_close(output_tensor, input_tensor, rtol=1e-4, atol=1e-4)
+
+ @pytest.mark.parametrize(
+ "desc",
+ sparsifier_forward_during_calibration_test_descs.values(),
+ ids=sparsifier_forward_during_calibration_test_descs.keys(),
+ )
+ def test_forward_during_calibration(self, use_cuda: bool, desc: SparsifierForwardTestDesc):
+ device = self.device
+ sparsifier = ActivationsSparsifier(desc.target_sparsity, desc.alpha).to(device)
+ sparsifier.freeze = False
+ running_thresholds = []
+ outputs = []
+ with torch.no_grad():
+ for batch in desc.input_batches:
+ output = sparsifier(batch.to(device))
+ running_thresholds.append(sparsifier.running_threshold)
+ outputs.append(output)
+ assert sparsifier.num_batches_tracked == len(desc.input_batches)
+ assert len(running_thresholds) == len(desc.ref_running_thresholds)
+ for threshold, ref_threshold in zip(running_thresholds, desc.ref_running_thresholds):
+ assert threshold.device.type == device.type
+ torch.testing.assert_close(threshold, ref_threshold, rtol=1e-4, atol=1e-4, check_device=False)
+ assert len(outputs) == len(desc.ref_outputs)
+ for output, ref_output in zip(outputs, desc.ref_outputs):
+ assert output.device.type == device.type
+ torch.testing.assert_close(output, ref_output, rtol=1e-4, atol=1e-4, check_device=False)
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+ def test_forward_after_calibration(self, use_cuda: bool, dtype: torch.dtype):
+ device = self.device
+ sparsifier = ActivationsSparsifier(target_sparsity=0.9).to(device)
+ sparsifier.running_threshold.fill_(0.1)
+ sparsifier.num_batches_tracked.fill_(100)
+
+ for _ in range(2):
+ # The sparsifier does not change in the following forwards
+ input_tensor = torch.rand([2, 10], device=device, dtype=dtype)
+ ref_output = torch.where(input_tensor.abs() <= 0.1, 0.0, input_tensor)
+            output_tensor = sparsifier(input_tensor)
+ assert sparsifier.num_batches_tracked == 100
+ torch.testing.assert_close(
+ sparsifier.running_threshold, torch.tensor(0.1, device=device), rtol=1e-4, atol=1e-4
+ )
+ torch.testing.assert_close(output_tensor, ref_output, rtol=1e-4, atol=1e-4)
+
+
+class TestPTSparsifyActivationsAlgoBackend:
+ def test_get_sparsifiers(self):
+ model, dataset = self.create_model_and_dataset()
+ sparse_model = nncf.experimental.torch.sparsify_activations.sparsify_activations(
+ model, dataset, target_sparsity_by_scope={TargetScope(patterns=[".*"]): 0.5}
+ )
+ backend = PTSparsifyActivationsAlgoBackend()
+ sparsifiers = backend.get_sparsifiers(sparse_model)
+ assert len(sparsifiers) == 3
+
+ @pytest.mark.parametrize("compress_weights", [False, True])
+ def test_insert_sparsifiers(self, compress_weights: bool):
+ model, dataset = self.create_model_and_dataset(compress_weights=compress_weights)
+ example_input = next(iter(dataset.get_inference_data()))
+ ref_output = model(example_input)
+
+ graph = model.nncf.get_graph()
+ nodes = graph.get_nodes_by_metatypes(PTSparsifyActivationsAlgoBackend.SUPPORTED_METATYPES)
+ backend = PTSparsifyActivationsAlgoBackend()
+ model_with_sparsifiers = backend.insert_sparsifiers(model, graph, {node: 0.9 for node in nodes})
+ assert len(backend.get_sparsifiers(model_with_sparsifiers)) == len(nodes)
+
+ output = model_with_sparsifiers(example_input)
+ torch.testing.assert_close(
+ output, ref_output, rtol=1e-4, atol=1e-4
+        )  # Freshly inserted sparsifiers are not calibrated yet and do not change the output
+
+ def test_calibrate_sparsifiers(self, mocker):
+ model, dataset = self.create_model_and_dataset()
+ graph = model.nncf.get_graph()
+ backend = PTSparsifyActivationsAlgoBackend()
+ mock_sparsifier = ActivationsSparsifier(0.5, 0.1)
+ mock_sparsifier.freeze = True
+ num_model_forward_calls = 0
+
+ def model_forward_pre_hook(model: NNCFNetwork, args):
+ nonlocal num_model_forward_calls
+ num_model_forward_calls += 1
+ assert model.training is False
+
+ model.register_forward_pre_hook(model_forward_pre_hook)
+
+        mocker.patch.object(backend, "get_sparsifiers", return_value=[mock_sparsifier])
+        backend.calibrate_sparsifiers(model, graph, dataset)
+ assert mock_sparsifier.freeze is True
+ assert num_model_forward_calls == dataset.get_length()
+
+ def create_model_and_dataset(self, compress_weights: bool = False):
+ model = ThreeLinearModel()
+ dataset = nncf.Dataset(torch.randint(0, 30, (3, 2, 8)))
+ if compress_weights:
+ model = nncf.compress_weights(
+ model,
+ mode=nncf.CompressWeightsMode.INT8_SYM,
+ dataset=dataset,
+ )
+ else:
+ model = wrap_model(
+ model,
+ example_input=next(iter(dataset.get_inference_data())),
+ trace_parameters=True,
+ )
+ return model, dataset
+
+
+class TestTargetScope:
+ SAME_HASH_PAIRS = [
+ (TargetScope(), TargetScope()),
+ (
+ TargetScope(
+ names=["node_1", "node_2"],
+ patterns=["node\\d", "layer\\d"],
+ types=["Conv", "MatMul"],
+ subgraphs=[
+ nncf.Subgraph(inputs=["node_1", "node_2"], outputs=["node_3", "node_4"]),
+ nncf.Subgraph(inputs=["layer_1", "layer_2"], outputs=["layer_3", "layer_4", "layer_5"]),
+ ],
+ ),
+ TargetScope(
+ names=["node_2", "node_1"],
+ patterns=["layer\\d", "node\\d"],
+ types=["MatMul", "Conv"],
+ subgraphs=[
+ nncf.Subgraph(inputs=["layer_2", "layer_1"], outputs=["layer_5", "layer_4", "layer_3"]),
+ nncf.Subgraph(inputs=["node_2", "node_1"], outputs=["node_4", "node_3"]),
+ ],
+ ),
+ ),
+ ]
+
+ DIFFERENT_HASH_PAIRS = [
+ (TargetScope(), TargetScope(types=["Conv"])),
+ (
+ TargetScope(names=["node_1"]),
+ TargetScope(names=["node_1"], patterns=["layer\\d"]),
+ ),
+ (
+ TargetScope(subgraphs=[nncf.Subgraph(inputs=["node_1"], outputs=["node_2"])]),
+ TargetScope(subgraphs=[nncf.Subgraph(inputs=["node_1"], outputs=["node_3"])]),
+ ),
+ ]
+
+ TARGET_SCOPE_MATCH_DATA = [
+ (convert_ignored_scope_to_target_scope(ignored_scope), ref_ignored_names)
+ for ignored_scope, ref_ignored_names in IGNORED_SCOPES_TEST_DATA
+ ]
+ WRONG_TARGET_SCOPE_MATCH_DATA = list(map(convert_ignored_scope_to_target_scope, WRONG_IGNORED_SCOPES_TEST_DATA))
+
+ @pytest.mark.parametrize("target_scope1,target_scope2", SAME_HASH_PAIRS)
+ def test_same_hash(self, target_scope1: TargetScope, target_scope2: TargetScope):
+ assert hash(target_scope1) == hash(target_scope2)
+
+ @pytest.mark.parametrize("target_scope1,target_scope2", DIFFERENT_HASH_PAIRS)
+ def test_different_hash(self, target_scope1: TargetScope, target_scope2: TargetScope):
+ assert hash(target_scope1) != hash(target_scope2)
+
+ @pytest.mark.parametrize("target_scope,ref_target_names", TARGET_SCOPE_MATCH_DATA)
+ def test_get_target_node_names_from_target_scope(self, target_scope: TargetScope, ref_target_names: List[str]):
+ nncf_graph = NNCFGraphToTestIgnoredScope(CONV_TYPE, LINEAR_TYPE).nncf_graph
+ target_names = get_target_node_names_from_target_scope(target_scope, nncf_graph)
+ assert sorted(target_names) == sorted(ref_target_names)
+
+ @pytest.mark.parametrize("target_scope", WRONG_TARGET_SCOPE_MATCH_DATA)
+ def test_wrong_target_scope(self, target_scope: TargetScope):
+ nncf_graph = NNCFGraphToTestIgnoredScope(CONV_TYPE, LINEAR_TYPE).nncf_graph
+ with pytest.raises(nncf.ValidationError):
+ get_target_node_names_from_target_scope(target_scope, nncf_graph)