Commit 17b72c8: Release v2.8.1 of NNCF to master

KodiaqQ committed Feb 9, 2024 (1 parent: 72eb39c)

Showing 33 changed files with 279 additions and 146 deletions.
11 changes: 11 additions & 0 deletions ReleaseNotes.md
@@ -1,5 +1,16 @@
# Release Notes

## New in Release 2.8.1

Post-training Quantization:

- Bugfixes:
  - (Common) Fixed an issue with `nncf.compress_weights()` to avoid overflows on 32-bit Windows systems (see the usage sketch after this list).
  - (Common) Fixed a performance issue with `nncf.compress_weights()` on Llama models.
  - (Common) Fixed the `nncf.quantize_with_accuracy_control` pipeline when the `tune_hyperparams=True` option is enabled.
  - (OpenVINO) Fixed an issue with stateful LLM models and added state restoration after inference.
  - (PyTorch) Fixed an issue with `nncf.compress_weights()` for LLM models where `is_floating_point` was executed during tracing.
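
For orientation, a minimal, hypothetical sketch of the `nncf.compress_weights()` entry point the fixes above refer to, assuming an OpenVINO IR model and default 8-bit compression (the model path is illustrative):

```python
import openvino.runtime as ov

import nncf

core = ov.Core()
model = core.read_model("llm.xml")  # illustrative path to an OpenVINO IR model

# Default 8-bit weight compression; the overflow and performance fixes
# listed above apply to this entry point.
compressed_model = nncf.compress_weights(model)
```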

## New in Release 2.8.0

Post-training Quantization:
1 change: 1 addition & 0 deletions docs/Installation.md

@@ -70,6 +70,7 @@ as well as the supported versions of Python:
| NNCF | OpenVINO | PyTorch | ONNX | TensorFlow | Python |
|-----------|------------|----------|----------|------------|--------|
| `develop` | `2023.3.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
| `2.8.1` | `2023.3.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
| `2.8.0` | `2023.3.0` | `2.1.2` | `1.13.1` | `2.12.0` | `3.8` |
| `2.7.0` | `2023.2.0` | `2.1` | `1.13.1` | `2.12.0` | `3.8` |
| `2.6.0` | `2023.1.0` | `2.0.1` | `1.13.1` | `2.12.0` | `3.8` |
8 changes: 0 additions & 8 deletions nncf/common/factory.py

@@ -17,9 +17,7 @@
from nncf.common.graph.transformations.command_creation import CommandCreator
from nncf.common.tensor_statistics import aggregator
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_available_backends
from nncf.common.utils.backend import get_backend
from nncf.common.utils.backend import is_openvino_compiled_model
from nncf.data.dataset import Dataset

TModel = TypeVar("TModel")
@@ -86,12 +84,6 @@ def create(model: TModel) -> Engine:
:param model: backend-specific model instance.
:return: backend-specific Engine instance.
"""
available_backends = get_available_backends()
if BackendType.OPENVINO in available_backends and is_openvino_compiled_model(model):
from nncf.openvino.engine import OVCompiledModelEngine

return OVCompiledModelEngine(model)

model_backend = get_backend(model)
if model_backend == BackendType.ONNX:
from nncf.onnx.engine import ONNXEngine
19 changes: 13 additions & 6 deletions nncf/openvino/engine.py

@@ -15,6 +15,7 @@
import openvino.runtime as ov

from nncf.common.engine import Engine
from nncf.openvino.graph.model_utils import model_has_state
from nncf.parameters import TargetDevice


@@ -27,11 +28,12 @@ class OVCompiledModelEngine(Engine):
to infer the compiled model.
"""

def __init__(self, model: ov.CompiledModel):
self.compiled_model = model
def __init__(self, compiled_model: ov.CompiledModel, stateful: bool):
self.infer_request = compiled_model.create_infer_request()
self.reset_state = stateful and hasattr(self.infer_request, "reset_state")
self.input_tensor_names = set()
self.number_of_inputs = len(model.inputs)
for model_input in model.inputs:
self.number_of_inputs = len(compiled_model.inputs)
for model_input in compiled_model.inputs:
self.input_tensor_names.update(model_input.get_names())

def _check_input_data_format(
@@ -63,7 +65,11 @@ def infer(
:return output_data: Model's output.
"""
self._check_input_data_format(input_data)
model_outputs = self.compiled_model(input_data)

if self.reset_state:
self.infer_request.reset_state()

model_outputs = self.infer_request.infer(input_data, share_inputs=True)

output_data = {}
for tensor, value in model_outputs.items():
@@ -86,8 +92,9 @@ def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.C
target_device = TargetDevice.CPU

ie = ov.Core()
stateful = model_has_state(model)
compiled_model = ie.compile_model(model, target_device.value)
self.engine = OVCompiledModelEngine(compiled_model)
self.engine = OVCompiledModelEngine(compiled_model, stateful)

def infer(
self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
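
A hedged usage sketch of the reworked `OVCompiledModelEngine` (model path and input name are illustrative): when the model is stateful, the engine's dedicated infer request resets its state before each call, so successive samples are inferred independently instead of accumulating KV-cache state.

```python
import numpy as np
import openvino.runtime as ov

from nncf.openvino.engine import OVCompiledModelEngine
from nncf.openvino.graph.model_utils import model_has_state

core = ov.Core()
model = core.read_model("stateful_llm.xml")  # illustrative path
compiled_model = core.compile_model(model, "CPU")

engine = OVCompiledModelEngine(compiled_model, stateful=model_has_state(model))

# State (if any) is reset before this call; inputs are shared to avoid copies.
outputs = engine.infer({"input_ids": np.array([[1, 2, 3]])})  # input name is illustrative
```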
10 changes: 10 additions & 0 deletions nncf/openvino/graph/model_utils.py
@@ -60,3 +60,13 @@ def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[N
:return: Target NNCFGraph input nodes.
"""
return nncf_graph.get_input_nodes() + nncf_graph.get_nodes_by_metatypes([OVReadValueMetatype])


def model_has_state(model: ov.Model) -> bool:
    """
    Checks whether the model has state, i.e. contains sink nodes (Assign operations).

    :param model: OpenVINO model.
    :return: True if the model has state, False otherwise.
    """
    return len(model.get_sinks()) > 0
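
For intuition: OpenVINO represents state as ReadValue/Assign pairs, and Assign nodes are registered as graph sinks, which is what `get_sinks()` reports. A short, hedged check (the model path is illustrative):

```python
import openvino.runtime as ov

from nncf.openvino.graph.model_utils import model_has_state

core = ov.Core()
model = core.read_model("llm_with_kv_cache.xml")  # illustrative path

# True when the graph contains Assign sinks, i.e. the model carries state.
print(model_has_state(model))
```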
2 changes: 1 addition & 1 deletion nncf/openvino/quantization/quantize_model.py
@@ -263,7 +263,7 @@ def native_quantize_with_accuracy_control_impl(
fast_bias_correction,
model_type,
ignored_scope,
advanced_quantization_parameters,
copied_parameters,
)
tuned_quantized_metric_results = evaluator.collect_metric_results(
tuned_quantized_model, validation_dataset, model_name="tuned"
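
The `copied_parameters` fix above targets the hyperparameter-tuning branch of accuracy-aware quantization. A hedged sketch of invoking that path (the validation function, datasets, and model path are illustrative):

```python
import openvino.runtime as ov

import nncf
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters


def validate(compiled_model, validation_items) -> float:
    # Illustrative validation function: returns a scalar accuracy metric.
    return 1.0


core = ov.Core()
model = core.read_model("model.xml")  # illustrative path
calibration_dataset = nncf.Dataset([{"input": [0.0]}])  # illustrative data items
validation_dataset = nncf.Dataset([{"input": [0.0]}])

quantized_model = nncf.quantize_with_accuracy_control(
    model,
    calibration_dataset,
    validation_dataset,
    validation_fn=validate,
    max_drop=0.01,
    advanced_accuracy_restorer_parameters=AdvancedAccuracyRestorerParameters(tune_hyperparams=True),
)
```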
42 changes: 30 additions & 12 deletions nncf/quantization/algorithms/accuracy_control/backend.py
@@ -13,6 +13,7 @@
from abc import abstractmethod
from typing import Any, List, Optional, TypeVar

from nncf.common.engine import Engine
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
@@ -21,6 +22,35 @@
TPModel = TypeVar("TPModel")


class PreparedModel(ABC):
    @property
    @abstractmethod
    def model_for_inference(self) -> TPModel:
        """
        Returns the prepared model for inference.

        :return: Prepared model for inference.
        """

    @property
    @abstractmethod
    def engine(self) -> Engine:
        """
        Returns the engine used to infer the prepared model.

        :return: The engine for inference of the prepared model.
        """

    def __call__(self, input_data: Any) -> Any:
        """
        Runs the model on the provided input data and returns the raw model outputs.

        :param input_data: Inputs for the model.
        :return: Raw model outputs.
        """
        return self.engine.infer(input_data)


class AccuracyControlAlgoBackend(ABC):
# Metatypes

@@ -158,15 +188,3 @@ def get_model_size(model: TModel) -> int:
:param model: A model
:return: Model size (in bytes)
"""

# Preparation of model

@staticmethod
@abstractmethod
def prepare_for_inference(model: TModel) -> TPModel:
"""
Prepares model for inference.
:param model: A model that should be prepared.
:return: Prepared model for inference.
"""
58 changes: 26 additions & 32 deletions nncf/quantization/algorithms/accuracy_control/evaluator.py
@@ -12,15 +12,14 @@
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union

from nncf.common.factory import EngineFactory
from nncf.common.logging import nncf_logger
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.common.utils.timer import timer
from nncf.data.dataset import Dataset
from nncf.quantization.algorithms.accuracy_control.backend import PreparedModel

TModel = TypeVar("TModel")
TPModel = TypeVar("TPModel")
TTensor = TypeVar("TTensor")


@@ -111,7 +110,7 @@ def is_metric_mode(self) -> bool:
"""
return self._metric_mode

def prepare_model_for_inference(self, model: TModel) -> TPModel:
def prepare_model(self, model: TModel) -> PreparedModel:
"""
Prepares model for inference.
@@ -121,21 +120,19 @@ def prepare_model_for_inference(self, model: TModel) -> TPModel:
backend = get_backend(model)

if backend == BackendType.OPENVINO:
import openvino.runtime as ov
from nncf.quantization.algorithms.accuracy_control.openvino_backend import OVPreparedModel

return ov.compile_model(model)
return OVPreparedModel(model)

raise NotImplementedError(
f"The `prepare_model_for_inference()` method is not implemented for the {backend} backend."
)
raise NotImplementedError(f"The `prepare_model()` method is not implemented for the {backend} backend.")

def validate_model_for_inference(
self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None
def validate_prepared_model(
self, prepared_model: PreparedModel, dataset: Dataset, indices: Optional[List[int]] = None
):
"""
Validates prepared model for inference.
:param model: Prepared model to validate.
:param prepared_model: Prepared model to validate.
:param dataset: Dataset to validate the model.
:param indices: Zero-based indices of data items that should be selected from
the dataset.
@@ -147,7 +144,7 @@ def validate_model_for_inference(
item.
"""
if self._metric_mode is None:
self._metric_mode = Evaluator.determine_mode(model_for_inference, dataset, self._validation_fn)
self._metric_mode = Evaluator.determine_mode(prepared_model, dataset, self._validation_fn)

if not self.is_metric_mode() and indices is not None:
raise ValueError("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True")
@@ -156,7 +153,7 @@
if self._enable_iteration_count:
validation_dataset = IterationCounter(validation_dataset)

metric, values_for_each_item = self._validation_fn(model_for_inference, validation_dataset)
metric, values_for_each_item = self._validation_fn(prepared_model.model_for_inference, validation_dataset)

self._num_passed_iterations = validation_dataset.num_iterations if self._enable_iteration_count else 0

@@ -189,20 +186,20 @@ def validate(
Otherwise, if the condition is false, it represents list of logits for each
item.
"""
model_for_inference = self.prepare_model_for_inference(model)
return self.validate_model_for_inference(model_for_inference, dataset, indices)
prepared_model = self.prepare_model(model)
return self.validate_prepared_model(prepared_model, dataset, indices)

@staticmethod
def determine_mode(
model_for_inference: TPModel,
prepared_model: PreparedModel,
dataset: Dataset,
validation_fn: Callable[[Any, Iterable[Any]], Tuple[float, Union[None, List[float], List[List[TTensor]]]]],
) -> bool:
"""
Determines mode based on the type of returned value from the
validation function.
:param model_for_inference: Model to validate.
:param prepared_model: Model to validate.
:param dataset: Dataset to validate the model.
:param validation_fn: Validation function to validate model.
:return: A boolean indicator where `True` means that the `Evaluator` collects
Expand All @@ -214,7 +211,7 @@ def determine_mode(
data_item = dataset.get_data([0])

try:
metric_value, values_for_each_item = validation_fn(model_for_inference, data_item)
metric_value, values_for_each_item = validation_fn(prepared_model.model_for_inference, data_item)
except Exception:
metric_mode = False

@@ -261,15 +258,15 @@ def determine_mode(

return metric_mode

def collect_values_for_each_item_using_model_for_inference(
self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None
def collect_values_for_each_item_using_prepared_model(
self, prepared_model: PreparedModel, dataset: Dataset, indices: Optional[List[int]] = None
) -> Union[List[float], List[List[TTensor]]]:
"""
Collects value for each item from the dataset using prepared model for inference.
If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item.
It is an output of the model for i-th data item otherwise.
:param model: Model to infer.
:param prepared_model: Model to infer.
:param dataset: Dataset to collect values.
:param indices: The zero-based indices of data items that should be selected from
the dataset.
@@ -278,15 +275,14 @@ def collect_values_for_each_item_using_model_for_inference(
if self._metric_mode:
# Collect metrics for each item
values_for_each_item = [
self._validation_fn(model_for_inference, [data_item])[0] for data_item in dataset.get_data(indices)
self._validation_fn(prepared_model.model_for_inference, [data_item])[0]
for data_item in dataset.get_data(indices)
]
else:
# Collect outputs for each item
engine = EngineFactory.create(model_for_inference)

values_for_each_item = []
for data_item in dataset.get_inference_data(indices):
logits = engine.infer(data_item)
logits = prepared_model(data_item)
values_for_each_item.append(list(logits.values()))

self._num_passed_iterations = len(values_for_each_item) if self._enable_iteration_count else 0
@@ -307,8 +303,8 @@ def collect_values_for_each_item(
the dataset.
:return: Collected values.
"""
model_for_inference = self.prepare_model_for_inference(model)
return self.collect_values_for_each_item_using_model_for_inference(model_for_inference, dataset, indices)
prepared_model = self.prepare_model(model)
return self.collect_values_for_each_item_using_prepared_model(prepared_model, dataset, indices)

def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: str = "") -> MetricResults:
"""
@@ -322,18 +318,16 @@ def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: st
nncf_logger.info(f"Validation of {model_name} model was started")

with timer() as preparation_time:
model_for_inference = self.prepare_model_for_inference(model)
prepared_model = self.prepare_model(model)

with timer() as validation_time:
metric, values_for_each_item = self.validate_model_for_inference(model_for_inference, dataset)
metric, values_for_each_item = self.validate_prepared_model(prepared_model, dataset)

nncf_logger.info(f"Metric of {model_name} model: {metric}")

if values_for_each_item is None:
nncf_logger.info(f"Collecting values for each data item using the {model_name} model")
with timer():
values_for_each_item = self.collect_values_for_each_item_using_model_for_inference(
model_for_inference, dataset
)
values_for_each_item = self.collect_values_for_each_item_using_prepared_model(prepared_model, dataset)

return MetricResults(metric, values_for_each_item, preparation_time(), validation_time())
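
Putting the pieces together, a hedged sketch of the reworked Evaluator flow, assuming the Evaluator is constructed from the validation function alone (the validation function, data, and model path are illustrative): the model is prepared once, and validation and per-item collection both reuse the same `PreparedModel`.

```python
import openvino.runtime as ov

import nncf
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator


def validation_fn(compiled_model, validation_items):
    # Illustrative: a fixed metric and no per-item values.
    return 1.0, None


core = ov.Core()
model = core.read_model("model.xml")  # illustrative path
dataset = nncf.Dataset([{"input": [0.0]}])  # illustrative data items

evaluator = Evaluator(validation_fn)
prepared_model = evaluator.prepare_model(model)  # wraps the model in an OVPreparedModel
metric, values_for_each_item = evaluator.validate_prepared_model(prepared_model, dataset)
```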