Skip to content

Commit

Permalink
bug fixes and unittest updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Dec 13, 2024
1 parent dca4fd8 commit 395f2de
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 35 deletions.
59 changes: 29 additions & 30 deletions nvflare/app_opt/quantization/dequantizor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,7 @@ def dequantization(

n_quant_params += 1
if quantization_type == "float16":
# direct convert back to higher precision
if source_data_format == "numpy":
if source_data_type == "float32":
values = values.astype(np.float32)
elif source_data_type == "float64":
values = values.astype(np.float64)
elif source_data_format == "torch":
if source_data_type == "float32":
values = values.float()
elif source_data_type == "float64":
values = values.double()
# direct assign and convert back to higher precision
params[param_name] = values
elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
# use bitsandbytes to dequantize the values
Expand All @@ -101,13 +91,12 @@ def dequantization(
quantized = torch.as_tensor(values)
absmax = torch.as_tensor(quant_state[param_name]["absmax"])
code = torch.as_tensor(quant_state[param_name]["code"])
elif source_data_format == "torch":
quantized = values
absmax = quant_state[param_name]["absmax"]
code = quant_state[param_name]["code"]
# de-quantize
dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
# assign back
if source_data_format == "numpy":
params[param_name] = dequantized.numpy()
elif source_data_format == "torch":
params[param_name] = dequantized
else:
if source_data_format == "numpy":
# first convert numpy array to tensor, need to use GPU
Expand Down Expand Up @@ -136,20 +125,30 @@ def dequantization(
dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
else:
dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
# assign back
if source_data_format == "numpy":
params[param_name] = dequantized.cpu().numpy()
elif source_data_format == "torch":
params[param_name] = dequantized.cpu()
# convert back to original data type
if source_data_type == "float32":
params[param_name] = params[param_name].float()
elif source_data_type == "float64":
params[param_name] = params[param_name].double()
elif source_data_type == "float16":
params[param_name] = params[param_name].half()
elif source_data_type == "bfloat16":
params[param_name] = params[param_name].bfloat16()
if source_data_format == "numpy":
params[param_name] = dequantized.cpu().numpy()
elif source_data_format == "torch":
params[param_name] = dequantized.cpu()

# assign back
if source_data_format == "numpy":
# convert back to original data type
if source_data_type == "float32":
params[param_name] = params[param_name].astype(np.float32)
elif source_data_type == "float64":
params[param_name] = params[param_name].astype(np.float64)
elif source_data_type == "float16":
params[param_name] = params[param_name].astype(np.float16)
elif source_data_format == "torch":
# convert back to original data type
if source_data_type == "float32":
params[param_name] = params[param_name].float()
elif source_data_type == "float64":
params[param_name] = params[param_name].double()
elif source_data_type == "float16":
params[param_name] = params[param_name].half()
elif source_data_type == "bfloat16":
params[param_name] = params[param_name].bfloat16()

n_bytes_after += params[param_name].nbytes

Expand Down
24 changes: 19 additions & 5 deletions tests/unit_test/app_opt/quantization/quantization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

import numpy as np
import pytest
import torch

from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.quantization.dequantizor import NumpyModelDequantizor
from nvflare.app_opt.quantization.quantizor import NumpyModelQuantizor
from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.quantization.quantizor import ModelQuantizor

TEST_CASES = [
(
Expand All @@ -31,6 +32,16 @@
"blockwise8",
{"a": np.array([0.99062496, 2.003125, 3.015625, 4.0], dtype="float32")},
),
(
{"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
"float16",
{"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
),
(
{"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.bfloat16)},
"float4",
{"a": torch.tensor([1.0, 2.0, 2.6719, 4.0], dtype=torch.bfloat16)},
),
]


Expand All @@ -42,12 +53,15 @@ def test_quantization(self, input_data, quantization_type, expected_data):
data=input_data,
)
fl_ctx = FLContext()
f_quant = NumpyModelQuantizor(quantization_type=quantization_type)
f_quant = ModelQuantizor(quantization_type=quantization_type)
quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx)
f_dequant = NumpyModelDequantizor(source_data_type="float32")
f_dequant = ModelDequantizor()
dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx)
dequant_data = dequant_dxo.data
for key in dequant_data.keys():
dequant_array = dequant_data[key]
expected_array = expected_data[key]
assert np.allclose(dequant_array, expected_array)
if isinstance(dequant_array, torch.Tensor):
assert torch.allclose(dequant_array, expected_array)
else:
assert np.allclose(dequant_array, expected_array)

0 comments on commit 395f2de

Please sign in to comment.