From 395f2deafc3cc7dd81dfb848e19d81578408def5 Mon Sep 17 00:00:00 2001
From: Ziyue Xu
Date: Fri, 13 Dec 2024 10:14:37 -0500
Subject: [PATCH] bug fixes and unittest updates

---
 nvflare/app_opt/quantization/dequantizor.py   | 59 +++++++++----------
 .../app_opt/quantization/quantization_test.py | 24 ++++++--
 2 files changed, 48 insertions(+), 35 deletions(-)

diff --git a/nvflare/app_opt/quantization/dequantizor.py b/nvflare/app_opt/quantization/dequantizor.py
index 1ac6265c6e..abd1c1c277 100644
--- a/nvflare/app_opt/quantization/dequantizor.py
+++ b/nvflare/app_opt/quantization/dequantizor.py
@@ -80,17 +80,7 @@ def dequantization(
                 n_quant_params += 1
                 if quantization_type == "float16":
-                    # direct convert back to higher precision
-                    if source_data_format == "numpy":
-                        if source_data_type == "float32":
-                            values = values.astype(np.float32)
-                        elif source_data_type == "float64":
-                            values = values.astype(np.float64)
-                    elif source_data_format == "torch":
-                        if source_data_type == "float32":
-                            values = values.float()
-                        elif source_data_type == "float64":
-                            values = values.double()
+                    # direct assign and convert back to higher precision
                     params[param_name] = values
                 elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
                     # use bitsandbytes to dequantize the values
@@ -101,13 +91,12 @@
                         quantized = torch.as_tensor(values)
                         absmax = torch.as_tensor(quant_state[param_name]["absmax"])
                         code = torch.as_tensor(quant_state[param_name]["code"])
+                    elif source_data_format == "torch":
+                        quantized = values
+                        absmax = quant_state[param_name]["absmax"]
+                        code = quant_state[param_name]["code"]
                     # de-quanitze
                     dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
-                    # assign back
-                    if source_data_format == "numpy":
-                        params[param_name] = dequantized.numpy()
-                    elif source_data_format == "torch":
-                        params[param_name] = dequantized
                 else:
                     if source_data_format == "numpy":
                         # first convert numpy array to tensor, need to use GPU
@@ -136,20 +125,30 @@
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
                     else:
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
-                    # assign back
-                    if source_data_format == "numpy":
-                        params[param_name] = dequantized.cpu().numpy()
-                    elif source_data_format == "torch":
-                        params[param_name] = dequantized.cpu()
-                # convert back to original data type
-                if source_data_type == "float32":
-                    params[param_name] = params[param_name].float()
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].double()
-                elif source_data_type == "float16":
-                    params[param_name] = params[param_name].half()
-                elif source_data_type == "bfloat16":
-                    params[param_name] = params[param_name].bfloat16()
+                    if source_data_format == "numpy":
+                        params[param_name] = dequantized.cpu().numpy()
+                    elif source_data_format == "torch":
+                        params[param_name] = dequantized.cpu()
+
+                # assign back
+                if source_data_format == "numpy":
+                    # convert back to original data type
+                    if source_data_type == "float32":
+                        params[param_name] = params[param_name].astype(np.float32)
+                    elif source_data_type == "float64":
+                        params[param_name] = params[param_name].astype(np.float64)
+                    elif source_data_type == "float16":
+                        params[param_name] = params[param_name].astype(np.float16)
+                elif source_data_format == "torch":
+                    # convert back to original data type
+                    if source_data_type == "float32":
+                        params[param_name] = params[param_name].float()
+                    elif source_data_type == "float64":
+                        params[param_name] = params[param_name].double()
+                    elif source_data_type == "float16":
+                        params[param_name] = params[param_name].half()
+                    elif source_data_type == "bfloat16":
+                        params[param_name] = params[param_name].bfloat16()
 
                 n_bytes_after += params[param_name].nbytes
diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py
index 3d3b41bddd..d2943a5e9a 100644
--- a/tests/unit_test/app_opt/quantization/quantization_test.py
+++ b/tests/unit_test/app_opt/quantization/quantization_test.py
@@ -14,11 +14,12 @@
 
 import numpy as np
 import pytest
+import torch
 
 from nvflare.apis.dxo import DXO, DataKind
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_opt.quantization.dequantizor import NumpyModelDequantizor
-from nvflare.app_opt.quantization.quantizor import NumpyModelQuantizor
+from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
+from nvflare.app_opt.quantization.quantizor import ModelQuantizor
 
 TEST_CASES = [
     (
@@ -31,6 +32,16 @@
         "blockwise8",
         {"a": np.array([0.99062496, 2.003125, 3.015625, 4.0], dtype="float32")},
     ),
+    (
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
+        "float16",
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
+    ),
+    (
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.bfloat16)},
+        "float4",
+        {"a": torch.tensor([1.0, 2.0, 2.6719, 4.0], dtype=torch.bfloat16)},
+    ),
 ]
 
 
@@ -42,12 +53,15 @@ def test_quantization(self, input_data, quantization_type, expected_data):
             data=input_data,
         )
         fl_ctx = FLContext()
-        f_quant = NumpyModelQuantizor(quantization_type=quantization_type)
+        f_quant = ModelQuantizor(quantization_type=quantization_type)
         quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx)
-        f_dequant = NumpyModelDequantizor(source_data_type="float32")
+        f_dequant = ModelDequantizor()
        dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx)
         dequant_data = dequant_dxo.data
         for key in dequant_data.keys():
             dequant_array = dequant_data[key]
             expected_array = expected_data[key]
-            assert np.allclose(dequant_array, expected_array)
+            if isinstance(dequant_array, torch.Tensor):
+                assert torch.allclose(dequant_array, expected_array)
+            else:
+                assert np.allclose(dequant_array, expected_array)
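
A minimal round-trip sketch of the renamed API, mirroring the updated unit test above. It relies only on what the patch itself shows: ModelQuantizor(quantization_type=...) on the sending side, a no-argument ModelDequantizor() on the receiving side (the source dtype is apparently recovered from state recorded at quantization time, since the old source_data_type argument is gone), and the DXO/process_dxo calls as used in the test.

    import torch

    from nvflare.apis.dxo import DXO, DataKind
    from nvflare.apis.fl_context import FLContext
    from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
    from nvflare.app_opt.quantization.quantizor import ModelQuantizor

    # a bfloat16 tensor; 4000.0 is exactly representable in both bfloat16
    # and float16, so this float16 round trip is lossless (as in the test)
    params = {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)}
    dxo = DXO(data_kind=DataKind.WEIGHTS, data=params)
    fl_ctx = FLContext()

    # sending side: cast down to float16 for transmission
    f_quant = ModelQuantizor(quantization_type="float16")
    quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx)

    # receiving side: restore the original dtype and data format
    f_dequant = ModelDequantizor()
    dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx)

    restored = dequant_dxo.data["a"]
    assert restored.dtype == torch.bfloat16
    assert torch.allclose(restored, params["a"])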