diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py
index bd69df3c26a..ccb9f01a432 100644
--- a/nncf/quantization/quantize_model.py
+++ b/nncf/quantization/quantize_model.py
@@ -473,14 +473,21 @@ def compress_weights(
         if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
             raise AttributeError("Torch backend does not support NF4 and E2M1 modes for weight compression.")
 
-        if True in [awq, scale_estimation, gptq, lora_correction]:
+        options = {
+            "sensitivity_metric": sensitivity_metric,
+            "awq": awq,
+            "scale_estimation": scale_estimation,
+            "gptq": gptq,
+            "lora_correction": lora_correction,
+        }
+        unsupported_options = [name for name, value in options.items() if value is not None]
+        if unsupported_options:
             raise AttributeError(
-                "Torch backend does not support 'awq', 'scale_estimation', 'gptq' and 'lora_correction' options. "
-                "Set them to None."
+                f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
 
-        if backup_mode is not None:
-            raise AttributeError("Torch backend does not support backup_mode option.")
+        if ratio is not None and ratio != 1:
+            raise AttributeError("Torch backend does not support ratio != 1.")
 
         if is_wrapped_model(model):
             if not model.nncf.trace_parameters:
@@ -506,14 +513,22 @@ def compress_weights(
         if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]:
             raise AttributeError("Torch backend does not support NF4 and E2M1 modes for weight compression.")
 
-        if backup_mode is not None:
-            raise AttributeError("TorchFX backend does not support backup_mode option.")
-
-        if any((awq, scale_estimation, gptq, lora_correction)):
+        options = {
+            "sensitivity_metric": sensitivity_metric,
+            "awq": awq,
+            "scale_estimation": scale_estimation,
+            "gptq": gptq,
+            "lora_correction": lora_correction,
+        }
+        unsupported_options = [name for name, value in options.items() if value is not None]
+        if unsupported_options:
             raise AttributeError(
-                "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq',"
-                "and 'lora_correction' options. Set them to None."
+                f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
             )
+
+        if ratio is not None and ratio != 1:
+            raise AttributeError("TorchFX backend does not support ratio != 1.")
+
         if dataset:
             raise AttributeError(
                 "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None"
diff --git a/tests/post_training/experimental/sparsify_activations/pipelines.py b/tests/post_training/experimental/sparsify_activations/pipelines.py
index ef2ecbd1847..a6705b6e8e9 100644
--- a/tests/post_training/experimental/sparsify_activations/pipelines.py
+++ b/tests/post_training/experimental/sparsify_activations/pipelines.py
@@ -21,17 +21,17 @@
 import torch.utils
 import torch.utils.data
 import torchvision
-from datasets import load_dataset
 from optimum.exporters.openvino.convert import export_from_model
 from optimum.intel.openvino import OVModelForCausalLM
 from transformers import AutoModelForCausalLM
 
 import nncf
+from datasets import load_dataset
 from nncf.experimental.torch.sparsify_activations import sparsify_activations
 from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
 from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend
-from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
-from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
 from tests.post_training.pipelines.base import LIMIT_LENGTH_OF_STATUS
 from tests.post_training.pipelines.base import PT_BACKENDS
 from tests.post_training.pipelines.base import BackendType
@@ -267,7 +267,7 @@ def save_compressed_model(self):
         if self.backend == BackendType.CUDA_TORCH:
             self.model_hf.float()
             for module in self.model_hf.nncf.modules():
-                if isinstance(module, (AsymmetricWeightsDecompressor, SymmetricWeightsDecompressor)):
+                if isinstance(module, (INT8AsymmetricWeightsDecompressor, INT8SymmetricWeightsDecompressor)):
                     module.result_dtype = torch.float32
             export_from_model(
                 self.model_hf, self.output_model_dir, stateful=False, compression_option="fp32", device="cuda"
diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py
index 20793e31493..700c3e3f7f8 100644
--- a/tests/torch/fx/test_compress_weights.py
+++ b/tests/torch/fx/test_compress_weights.py
@@ -23,6 +23,8 @@
 from nncf.quantization import compress_weights
 from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
 from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS
+from tests.torch.ptq.test_weights_compression import INT4_MODES
+from tests.torch.ptq.test_weights_compression import INT8_MODES
 from tests.torch.ptq.test_weights_compression import SUPPORTED_MODES
 from tests.torch.ptq.test_weights_compression import UNSUPPORTED_MODES
 from tests.torch.ptq.test_weights_compression import ConvolutionModel
@@ -76,10 +78,13 @@ def _capture_model(model, inputs):
 
 @pytest.mark.parametrize("mode", SUPPORTED_MODES)
 def test_compress_weights(mode):
-    model = ShortTransformer(5, 10)
-    input_ids = torch.randint(0, 10, (5,))
+    model = ShortTransformer(8, 16)
+    input_ids = torch.randint(0, 10, (8,))
     exported_model = _capture_model(model, input_ids)
-    compressed_model = compress_weights(exported_model, mode=mode)
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(exported_model, mode=mode, **kwargs)
     dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
     n_compressed_weights = 0
     n_target_modules = 0
@@ -91,7 +96,7 @@ def test_compress_weights(mode):
     assert n_target_modules == n_compressed_weights
 
 
-@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+@pytest.mark.parametrize("mode", INT8_MODES)
 def test_compress_weights_graph_edge(mode):
     model = ShortTransformer(5, 10)
     input_ids = torch.randint(0, 10, (5,))
@@ -108,10 +113,13 @@
 @pytest.mark.parametrize("mode", SUPPORTED_MODES)
 def test_compress_weights_shared_weights(mocker, mode):
     with disable_patching():
-        model = ShortTransformer(5, 10, share_weights=True)
-        input_ids = torch.randint(0, 10, (5,))
+        model = ShortTransformer(8, 16, share_weights=True)
+        input_ids = torch.randint(0, 10, (8,))
         exported_model = _capture_model(model, input_ids)
-        compressed_model = compress_weights(exported_model, mode=mode)
+        kwargs = {}
+        if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+            kwargs["group_size"] = 4
+        compressed_model = compress_weights(exported_model, mode=mode, **kwargs)
         dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
         n_compressed_weights = 0
         n_target_modules = 0
@@ -141,11 +149,14 @@ def test_compress_weights_shared_weights(mocker, mode):
 @pytest.mark.parametrize("mode", SUPPORTED_MODES)
 def test_compressed_model_inference(mode):
     torch.manual_seed(42)
-    model = ShortTransformer(5, 10, share_weights=True)
-    input_ids = torch.randint(0, 10, (5,))
+    model = ShortTransformer(8, 16, share_weights=True)
+    input_ids = torch.randint(0, 10, (8,))
     exported_model = _capture_model(model, input_ids)
     exported_model_output = exported_model(input_ids)
-    compressed_model = compress_weights(exported_model, mode=mode)
+    kwargs = {}
+    if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
+        kwargs["group_size"] = 4
+    compressed_model = compress_weights(exported_model, mode=mode, **kwargs)
     compressed_model_outputs = compressed_model(input_ids)
 
     assert (
@@ -160,7 +171,7 @@ def test_compress_weights_model_size_conv(mode):
     dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8
 
     model = ConvolutionModel()
-    input_ids = torch.randint(0, 10, [1, 3, 300, 300])
+    input_ids = torch.randint(0, 10, [1, 3, 256, 256])
     exported_model = _capture_model(model, input_ids)
     model_size = get_model_size(exported_model)
     compressed_model = compress_weights(exported_model, mode=mode)
@@ -181,9 +192,11 @@ def test_compress_weights_model_size_conv(mode):
 
 @pytest.mark.parametrize("mode", SUPPORTED_MODES)
 def test_compress_weights_functional_model(mode):
     model = FunctionalModel()
-    decompressor_type = "symmetric" if mode == CompressWeightsMode.INT8_SYM else "asymmetric"
+    decompressor_type = (
+        "symmetric" if mode in (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT4_SYM) else "asymmetric"
+    )
 
-    input_ids = torch.randint(0, 10, [1, 3, 300, 300])
+    input_ids = torch.randint(0, 10, [1, 3, 256, 256])
     exported_model = _capture_model(model, input_ids)
     compressed_model = compress_weights(exported_model, mode=mode)
@@ -195,7 +208,7 @@ def test_compress_weights_functional_model(mode):
     assert n_compressed_weights == 4
 
 
-@pytest.mark.parametrize("mode", SUPPORTED_MODES)
+@pytest.mark.parametrize("mode", INT8_MODES)
 @pytest.mark.parametrize(
     "params",
     (
@@ -222,6 +235,27 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params):
         compress_weights(exported_model, mode=mode, **params)
 
 
+@pytest.mark.parametrize("mode", INT4_MODES)
+@pytest.mark.parametrize(
+    "params",
+    (
+        {"ratio": 0.5},
+        *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS),
+        {"gptq": True},
+        {"awq": True},
+        {"scale_estimation": True},
+        {"lora_correction": True},
+        {"dataset": Dataset([1])},
+    ),
+)
+def test_raise_error_with_unsupported_params_for_int4(mode, params):
+    dummy_torch_model = EmptyModel()
+    dummy_input = torch.Tensor()
+    exported_model = _capture_model(dummy_torch_model, dummy_input)
+    with pytest.raises(AttributeError):
+        compress_weights(exported_model, mode=mode, **params)
+
+
 @pytest.mark.parametrize("mode", UNSUPPORTED_MODES)
 def test_raise_error_with_not_int8(mode):
     dummy_torch_model = EmptyModel()
@@ -251,7 +285,7 @@ def test_model_devices_and_precisions(use_cuda, dtype):
     model = MatMulModel().to(device)
     if dtype == torch.float16:
         model.half()
-    dummy_input = torch.rand((1, 300), dtype=dtype, device=device)
+    dummy_input = torch.rand((1, 256), dtype=dtype, device=device)
     exported_model = _capture_model(model, dummy_input)
     compressed_model = compress_weights(exported_model)
     result = compressed_model(dummy_input)
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index dee60e92e5f..bede8793150 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -18,8 +18,10 @@
 from nncf import SensitivityMetric
 from nncf.quantization import compress_weights
 from nncf.torch import wrap_model
-from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor
-from nncf.torch.quantization.layers import SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
+from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
 
 DATA_BASED_SENSITIVITY_METRICS = (
     SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
@@ -30,12 +32,10 @@
 
 ALL_SENSITIVITY_METRICS = DATA_BASED_SENSITIVITY_METRICS + (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,)
 
-SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM)
-UNSUPPORTED_MODES = (
-    CompressWeightsMode.INT4_SYM,
-    CompressWeightsMode.INT4_ASYM,
-    CompressWeightsMode.NF4,
-)
+INT8_MODES = (CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM)
+INT4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM)
+SUPPORTED_MODES = INT8_MODES + INT4_MODES
+UNSUPPORTED_MODES = (CompressWeightsMode.NF4, CompressWeightsMode.E2M1)
 
 
 class ShortTransformer(torch.nn.Module):
@@ -58,7 +58,7 @@ def forward(self, input_ids):
 class MatMulModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32))
+        self.w = torch.nn.Parameter(torch.ones(size=(256, 256), dtype=torch.float32))
 
     def forward(self, input):
         return input @ self.w
@@ -68,7 +68,7 @@ class FunctionalModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
-        self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
+        self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 256, 256), dtype=torch.float32))
         self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
         self.nested_matmul = MatMulModel()
 
@@ -108,14 +108,18 @@ def forward(self, input_):
         return x
 
 
-@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights(mode): - model = ShortTransformer(5, 10) + model = ShortTransformer(8, 16) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 - input_ids = torch.randint(0, 10, (5,)) + input_ids = torch.randint(0, 10, (8,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model, mode=mode) + + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs) n_compressed_weights = 0 n_target_modules = 0 @@ -129,14 +133,19 @@ def test_compress_weights(mode): assert n_compressed_weights == n_target_modules -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights_functional_model(mode): model = FunctionalModel() - decompressor_type = ( - SymmetricWeightsDecompressor if mode == CompressWeightsMode.INT8_SYM else AsymmetricWeightsDecompressor - ) + decompressor_map = { + CompressWeightsMode.INT8_SYM: (INT8SymmetricWeightsDecompressor,), + CompressWeightsMode.INT8_ASYM: (INT8AsymmetricWeightsDecompressor,), + CompressWeightsMode.INT4_SYM: (INT4SymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor), + CompressWeightsMode.INT4_ASYM: (INT4AsymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor), + } - input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + decompressor_type = decompressor_map[mode] + + input_ids = torch.randint(0, 10, [1, 3, 256, 256]) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) compressed_model = compress_weights(wrapped_model, mode=mode) @@ -166,14 +175,18 @@ def test_compress_weights_conv(): assert n_compressed_weights == n_target_modules -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_shared_weights(mocker, mode): - model = ShortTransformer(5, 10, share_weights=True) + model = ShortTransformer(8, 16, share_weights=True) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 - input_ids = torch.randint(0, 10, (5,)) + input_ids = torch.randint(0, 10, (8,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model, mode=mode) + + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs) n_compressed_weights = 0 n_target_modules = 0 @@ -202,7 +215,7 @@ def forward(self, input): return input -@pytest.mark.parametrize("mode", SUPPORTED_MODES) +@pytest.mark.parametrize("mode", INT8_MODES) @pytest.mark.parametrize( "params", ( @@ -228,6 +241,26 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): compress_weights(wrapped_model, mode=mode, **params) +@pytest.mark.parametrize("mode", INT4_MODES) +@pytest.mark.parametrize( + "params", + ( + {"ratio": 0.5}, + *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS), + {"gptq": True}, + {"awq": True}, + {"scale_estimation": True}, + {"lora_correction": True}, + ), +) +def 
test_raise_error_with_unsupported_params_for_int4(mode, params): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) + with pytest.raises(AttributeError): + compress_weights(wrapped_model, mode=mode, **params) + + @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() @@ -269,7 +302,7 @@ def test_model_devices_and_precisions(use_cuda, dtype): if dtype == torch.float16: model.half() - dummy_input = torch.rand((1, 300), dtype=dtype, device=device) + dummy_input = torch.rand((1, 256), dtype=dtype, device=device) wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True) compressed_model = compress_weights(wrapped_model) result = compressed_model(dummy_input)
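
For orientation, a minimal usage sketch (not part of the change set) of the call pattern the updated tests exercise: INT4 weight compression through the Torch backend via `wrap_model` and `compress_weights`. `TinyModel`, the tensor shapes, and `group_size=4` are illustrative assumptions; real models typically use a larger group size (e.g. 64 or 128).

```python
import torch

from nncf import CompressWeightsMode
from nncf.quantization import compress_weights
from nncf.torch import wrap_model


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for a model with embedding/linear weights to compress.
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(16, 16)
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, input_ids):
        return self.linear(self.embedding(input_ids))


model = TinyModel()
example_input = torch.randint(0, 16, (8,))

# trace_parameters=True is required for weight compression on wrapped Torch models.
wrapped_model = wrap_model(model, example_input=example_input, trace_parameters=True)

# INT4 modes take an explicit group_size; per the checks added above, a ratio other
# than 1 and the awq/gptq/scale_estimation/lora_correction/sensitivity_metric options
# are rejected on the Torch and TorchFX backends.
compressed_model = compress_weights(wrapped_model, mode=CompressWeightsMode.INT4_SYM, group_size=4)
print(compressed_model(example_input).shape)
```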