
Commit

update quantization filters to handle tensor
ZiyueXu77 committed Dec 12, 2024
1 parent 6cc4c8c commit dca4fd8
Showing 8 changed files with 417 additions and 290 deletions.
4 changes: 4 additions & 0 deletions examples/advanced/llm_hf/README.md
@@ -131,6 +131,10 @@ We can use the following command to run the federated training with direct tensor
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor
```
Similarly, quantization can be applied to tensor communication:
```
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor_fp4 --job_dir ${PWD}/workspace/jobs/hf_sft_tensor_fp4 --train_mode SFT --message_mode tensor --quantize_mode float4
```

## Federated Training with Multiple Clients
With the above example, we can easily extend the federated training to multiple clients using the following command:
8 changes: 4 additions & 4 deletions examples/advanced/llm_hf/sft_job.py
@@ -20,8 +20,8 @@
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
-from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor
-from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor
+from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
+from nvflare.app_opt.quantization.quantizor import ModelQuantizor
from nvflare.job_config.script_runner import BaseScriptRunner


@@ -68,8 +68,8 @@ def main():

    if args.quantize_mode:
        # If using quantization, add quantize filters.
-       quantizor = NumpyModelQuantizor(quantization_type=args.quantize_mode)
-       dequantizor = NumpyModelDequantizor(source_data_type="float32")
+       quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
+       dequantizor = ModelDequantizor()
        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
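
The hunk above wires the filters on the server side only; the matching client-side calls fall outside this hunk. As a hedged sketch (the `clients` list and the exact filter directions are assumptions, not shown in this diff), the symmetric wiring would look roughly like this:

```
# hypothetical client-side counterpart (not part of the hunk shown above):
# each client quantizes the task result it sends back and dequantizes the
# global model it receives, mirroring the server-side filters.
for site_name in clients:  # `clients` assumed to hold the client site names
    job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
    job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
```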

3 changes: 3 additions & 0 deletions nvflare/app_opt/quantization/constant.py
@@ -13,7 +13,10 @@
# limitations under the License.

DATA_TYPE = [
    "FLOAT64",
    "FLOAT32",
    "FLOAT16",
    "BFLOAT16",
]

QUANTIZATION_TYPE = [
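The expanded `DATA_TYPE` list lets the filters record each parameter's original precision, including torch-only types such as bfloat16, so it can be restored after dequantization. A rough illustration of that lookup (the helper name below is hypothetical, not part of this commit):

```
import numpy as np
import torch

from nvflare.app_opt.quantization.constant import DATA_TYPE


def source_datatype_of(value) -> str:
    """Map a numpy array or torch tensor to a supported DATA_TYPE name (hypothetical helper)."""
    if isinstance(value, np.ndarray):
        name = str(value.dtype)  # e.g. "float32"
    elif isinstance(value, torch.Tensor):
        name = str(value.dtype).split(".")[-1]  # e.g. "bfloat16" from "torch.bfloat16"
    else:
        raise ValueError(f"Unsupported value type: {type(value)}")
    if name.upper() not in DATA_TYPE:
        raise ValueError(f"Unsupported source data type: {name}, valid: {DATA_TYPE}")
    return name
```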
198 changes: 198 additions & 0 deletions nvflare/app_opt/quantization/dequantizor.py
@@ -0,0 +1,198 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Union

import numpy as np
import torch
from bitsandbytes.functional import QuantState, dequantize_4bit, dequantize_blockwise

from nvflare.apis.dxo import DXO, DataKind, MetaKey
from nvflare.apis.dxo_filter import DXOFilter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.quantization.constant import QUANTIZATION_TYPE


class ModelDequantizor(DXOFilter):
    def __init__(self):
        """Filter to dequantize Shareable object to recover from quantization

        Args:
            None
        """
        # support weight and weight_diff data kinds
        data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF]
        super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds)
        self.logger.info("Using model dequantizor.")

    def dequantization(
        self, params: dict, quant_state: dict, quantization_type: str, source_datatype: dict, fl_ctx: FLContext
    ):
        n_params = len(params.keys())
        self.log_info(fl_ctx, f"Running dequantization on {n_params} variables")
        n_bytes_before = 0
        n_bytes_after = 0
        n_bytes_meta = 0
        n_quant_params = 0
        for i, param_name in enumerate(params.keys()):
            source_data_type = source_datatype[param_name]

            # get the bits information
            source_data_bits = int(re.findall(r"\d+", source_data_type)[0])
            quantization_bits = int(re.findall(r"\d+", quantization_type)[0])

            # only dequantize if the quantization type is lower than the source data type
            if quantization_bits >= source_data_bits:
                self.log_info(
                    fl_ctx,
                    f"Skipping dequantization for {param_name}, quantization bit {quantization_type} >= source data bit {source_data_type}",
                )
                continue
            else:
                values = params[param_name]
                n_bytes_before += values.nbytes
                for item in quant_state[param_name].values():
                    if isinstance(item, np.ndarray) or isinstance(item, torch.Tensor):
                        n_bytes_meta += item.nbytes

                if isinstance(values, np.ndarray):
                    # if numpy, convert to torch
                    source_data_format = "numpy"
                elif isinstance(values, torch.Tensor):
                    source_data_format = "torch"
                else:
                    raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch")

                n_quant_params += 1
                if quantization_type == "float16":
                    # direct convert back to higher precision
                    if source_data_format == "numpy":
                        if source_data_type == "float32":
                            values = values.astype(np.float32)
                        elif source_data_type == "float64":
                            values = values.astype(np.float64)
                    elif source_data_format == "torch":
                        if source_data_type == "float32":
                            values = values.float()
                        elif source_data_type == "float64":
                            values = values.double()
                    params[param_name] = values
                elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
                    # use bitsandbytes to dequantize the values
                    # extract quantization state
                    if quantization_type == "blockwise8":
                        if source_data_format == "numpy":
                            # first convert numpy array to tensor if numpy
                            quantized = torch.as_tensor(values)
                            absmax = torch.as_tensor(quant_state[param_name]["absmax"])
                            code = torch.as_tensor(quant_state[param_name]["code"])
                        elif source_data_format == "torch":
                            # torch tensors can be passed to bitsandbytes directly
                            quantized = values
                            absmax = quant_state[param_name]["absmax"]
                            code = quant_state[param_name]["code"]
                        # de-quantize
                        dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
                        # assign back
                        if source_data_format == "numpy":
                            params[param_name] = dequantized.numpy()
                        elif source_data_format == "torch":
                            params[param_name] = dequantized
                    else:
                        if source_data_format == "numpy":
                            # first convert numpy array to tensor, need to use GPU
                            quantized = torch.as_tensor(values).cuda()
                            # create QuantState object
                            quantize_state = QuantState(
                                quant_type=quant_state[param_name]["quant_type"],
                                absmax=torch.as_tensor(quant_state[param_name]["absmax"]).cuda(),
                                blocksize=quant_state[param_name]["blocksize"],
                                code=torch.as_tensor(quant_state[param_name]["quant_map"]).cuda(),
                                dtype=getattr(torch, quant_state[param_name]["dtype"]),
                                shape=torch.Size(quant_state[param_name]["shape"]),
                            )
                        elif source_data_format == "torch":
                            quantized = values.cuda()
                            quantize_state = QuantState(
                                quant_type=quant_state[param_name]["quant_type"],
                                absmax=quant_state[param_name]["absmax"].cuda(),
                                blocksize=quant_state[param_name]["blocksize"],
                                code=quant_state[param_name]["quant_map"].cuda(),
                                dtype=getattr(torch, quant_state[param_name]["dtype"]),
                                shape=torch.Size(quant_state[param_name]["shape"]),
                            )
                        # de-quantize
                        if quantization_type == "float4":
                            dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
                        else:
                            dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
                        # assign back
                        if source_data_format == "numpy":
                            params[param_name] = dequantized.cpu().numpy()
                        elif source_data_format == "torch":
                            params[param_name] = dequantized.cpu()
                    # convert back to original data type
                    if source_data_type == "float32":
                        params[param_name] = params[param_name].float()
                    elif source_data_type == "float64":
                        params[param_name] = params[param_name].double()
                    elif source_data_type == "float16":
                        params[param_name] = params[param_name].half()
                    elif source_data_type == "bfloat16":
                        params[param_name] = params[param_name].bfloat16()

                n_bytes_after += params[param_name].nbytes

        self.log_info(
            fl_ctx,
            f"Dequantized {n_quant_params}/{n_params} params."
            f" Before dequantization: {n_bytes_before / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB."
            f" After dequantization: {n_bytes_after / (1024 ** 2):.2f} MB.",
        )
        return params

    def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]:
        """Filter process applied to the Shareable object.

        Args:
            dxo: data to be processed
            shareable: shareable that the dxo belongs to
            fl_ctx: FLContext

        Returns: DXO object with dequantized weights
        """

        self.log_info(fl_ctx, "Running dequantization...")

        # check config
        quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None)
        if quantization_type is None or quantization_type.upper() not in QUANTIZATION_TYPE:
            raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")

        dequantized_params = self.dequantization(
            params=dxo.data,
            quant_state=dxo.meta["quant_state"],
            quantization_type=quantization_type,
            source_datatype=dxo.meta["source_datatype"],
            fl_ctx=fl_ctx,
        )
        # Compose new DXO with dequantized data
        dxo.data = dequantized_params
        dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM)
        dxo.remove_meta_props("quant_state")
        dxo.remove_meta_props("source_datatype")
        dxo.update_shareable(shareable)
        self.log_info(fl_ctx, "Dequantized back")

        return dxo
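
To see both new filters act on torch tensors end to end, the pair can be exercised directly on a DXO outside of a job. A minimal, untested sketch (float16 mode is chosen so no GPU or bitsandbytes 4-bit kernels are needed; the parameter name is arbitrary):

```
import torch

from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.quantization.quantizor import ModelQuantizor

fl_ctx = FLContext()
dxo = DXO(data_kind=DataKind.WEIGHTS, data={"layer.weight": torch.randn(4, 4)})

# quantize float32 -> float16 for transmission, then recover the original precision
quantized = ModelQuantizor(quantization_type="float16").process_dxo(dxo, dxo.to_shareable(), fl_ctx)
restored = ModelDequantizor().process_dxo(quantized, quantized.to_shareable(), fl_ctx)
print(restored.data["layer.weight"].dtype)  # expected: torch.float32
```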
136 changes: 0 additions & 136 deletions nvflare/app_opt/quantization/numpy_dequantizor.py

This file was deleted.

