From 048da3393c72fe308facc7ee7cdf5c96eca3a506 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Tue, 26 Nov 2024 21:10:57 -0500 Subject: [PATCH 01/16] Added support for bfloat16 tensor using JIT --- nvflare/app_opt/pt/decomposers.py | 51 +++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index f009a6b1c2..f50ea47fb4 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -22,18 +22,63 @@ from nvflare.fuel.utils.fobs.datum import DatumManager +class SerializationModule(torch.nn.Module): + def __init__(self, tensor): + super().__init__() + self.register_buffer("saved_tensor", tensor) + + class TensorDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: + if target.dtype == torch.bfloat16: + return self._jit_serialize(target) + else: + return self._numpy_serialize(target) + + def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: + + if isinstance(data, dict): + if data["dtype"] == "torch.bfloat16": + return self._jit_deserialize(data) + else: + buf = data["buffer"] + else: + buf = data + + return self._numpy_deserialize(buf) + + @staticmethod + def _numpy_serialize(tensor: torch.Tensor) -> dict: stream = BytesIO() # torch.save uses Pickle so converting Tensor to ndarray first - array = target.detach().cpu().numpy() + array = tensor.detach().cpu().numpy() np.save(stream, array, allow_pickle=False) - return stream.getvalue() + return { + "buffer": stream.getvalue(), + "dtype": str(tensor.dtype), + } - def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: + @staticmethod + def _numpy_deserialize(data: Any) -> torch.Tensor: stream = BytesIO(data) array = np.load(stream, allow_pickle=False) return torch.from_numpy(array) + + @staticmethod + def _jit_serialize(tensor: torch.Tensor) -> dict: + module = SerializationModule(tensor) + stream = BytesIO() + torch.jit.save(torch.jit.script(module), stream) + return { + "buffer": stream.getvalue(), + "dtype": str(tensor.dtype), + } + + @staticmethod + def _jit_deserialize(data: Any) -> torch.Tensor: + stream = BytesIO(data["buffer"]) + loaded_module = torch.jit.load(stream) + return loaded_module.saved_tensor From 0dcefd3f34e1c67996066517755e94a6ec3451dd Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 2 Dec 2024 16:56:03 -0500 Subject: [PATCH 02/16] directly send tensor via jit serialization --- examples/advanced/llm_hf/sft_job.py | 12 +++- .../advanced/llm_hf/src/hf_sft_peft_fl.py | 2 +- .../advanced/llm_hf/src/params_converter.py | 64 +++++++++++++++++++ .../in_process_client_api_executor.py | 33 ++++++++-- nvflare/app_opt/pt/decomposers.py | 24 +++++-- nvflare/app_opt/pt/params_converter.py | 2 + nvflare/job_config/script_runner.py | 9 ++- 7 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 examples/advanced/llm_hf/src/params_converter.py diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 3b5221baec..ab328b7297 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -21,8 +21,9 @@ from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor -from nvflare.job_config.script_runner import 
ScriptRunner +from nvflare.job_config.script_runner import BaseScriptRunner +from src.params_converter import PTSendParamsConverter, PTReceiveParamsConverter def main(): args = define_parser() @@ -85,11 +86,16 @@ def main(): for i in range(num_clients): client_id = client_ids[i] site_name = f"site-{client_id}" - data_path_train = os.path.join(args.data_path, client_id, "training.jsonl") + data_path_train = os.path.join(args.data_path, client_id, "validation.jsonl") data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl") - runner = ScriptRunner( + # Add params converters and send to client + job.to(PTSendParamsConverter(), site_name, id="pt_send") + job.to(PTReceiveParamsConverter(), site_name, id="pt_receive") + runner = BaseScriptRunner( script=train_script, script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + from_nvflare_converter_id="pt_receive", + to_nvflare_converter_id="pt_send", ) job.to(runner, site_name, tasks=["train"]) if args.quantize_mode: diff --git a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py index 1411fabd95..d0114ffec7 100755 --- a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -233,7 +233,7 @@ def evaluate(input_weights, mode): out_param["model." + key] = out_param.pop(key).cpu() # cast out_param to float32 preparing for communication - out_param = {k: v.to(torch.float32) for k, v in out_param.items()} + #out_param = {k: v.to(torch.float32) for k, v in out_param.items()} # construct trained FL model output_model = flare.FLModel( diff --git a/examples/advanced/llm_hf/src/params_converter.py b/examples/advanced/llm_hf/src/params_converter.py new file mode 100644 index 0000000000..2b8cdc9c3f --- /dev/null +++ b/examples/advanced/llm_hf/src/params_converter.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
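[Editor's note on the converter pair this new file introduces, before the implementation below: the send side strips params down to CPU tensors, while the receive side restores torch tensors from whatever numpy arrays arrive, using shapes recorded on the FLContext. A minimal sketch of that receive-side restore, with hypothetical parameter names, assuming shapes were recorded under the "tensor_shapes" property as the code below expects:

```python
import numpy as np
import torch

# hypothetical example: a flat float32 array plus the shape recorded for it
tensor_shapes = {"fc.weight": (2, 3)}
params = {"fc.weight": np.arange(6, dtype=np.float32)}

# what PTReceiveParamsConverter.convert does per entry when a shape is known
restored = {k: torch.as_tensor(np.reshape(v, tensor_shapes[k])) for k, v in params.items()}
assert restored["fc.weight"].shape == torch.Size([2, 3])
```
]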
+ +from typing import Dict + +import numpy as np +import torch + +from nvflare.app_common.abstract.params_converter import ParamsConverter + + +class PTReceiveParamsConverter(ParamsConverter): + def convert(self, params: Dict, fl_ctx) -> Dict: + tensor_shapes = fl_ctx.get_prop("tensor_shapes") + exclude_vars = fl_ctx.get_prop("exclude_vars") + + return_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + return_params[k] = v + else: + if tensor_shapes: + if k in tensor_shapes: + return_params[k] = torch.as_tensor(np.reshape(v, tensor_shapes[k])) + else: + return_params[k] = torch.as_tensor(v) + else: + return_params[k] = torch.as_tensor(v) + + if exclude_vars: + for k, v in exclude_vars.items(): + return_params[k] = v + + return return_params + + +class PTSendParamsConverter(ParamsConverter): + def convert(self, params: Dict, fl_ctx) -> Dict: + return_tensors = {} + exclude_vars = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + return_tensors[k] = v.cpu() + else: + exclude_vars[k] = v + + if exclude_vars: + fl_ctx.set_prop("exclude_vars", exclude_vars) + self.logger.warning( + f"{len(exclude_vars)} vars excluded as they were non-tensor type: " f"{list(exclude_vars.keys())}" + ) + + return return_tensors diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index e04ff5673e..84c18ec72a 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -54,8 +54,8 @@ def __init__( log_pull_interval: Optional[float] = None, params_exchange_format: str = ExchangeFormat.NUMPY, params_transfer_type: TransferType = TransferType.FULL, - from_nvflare_converter_id: Optional[str] = None, - to_nvflare_converter_id: Optional[str] = None, + from_nvflare_converter_id: str = None, + to_nvflare_converter_id: str = None, train_with_evaluation: bool = True, train_task_name: str = AppConstants.TASK_TRAIN, evaluate_task_name: str = AppConstants.TASK_VALIDATION, @@ -138,6 +138,14 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id()) shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name()) + + # print the from and to nvflare converter + print("----------------------------------------------------------------") + print(f"from_nvflare_converter: {self._from_nvflare_converter}") + print(f"to_nvflare_converter: {self._to_nvflare_converter}") + print("----------------------------------------------------------------") + + if self._from_nvflare_converter is not None: shareable = self._from_nvflare_converter.process(task_name, shareable, fl_ctx) @@ -200,12 +208,29 @@ def send_data_to_peer(self, shareable, fl_ctx: FLContext): def _init_converter(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() - from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id) + + print("********************************") + print(self._from_nvflare_converter_id) + print(self._params_exchange_format) + print("********************************") + + from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id[0]) + + print("********************************") + print(engine.get_component(self._from_nvflare_converter_id[0])) + print(from_nvflare_converter) + print("********************************") + if from_nvflare_converter is not None: 
check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter) self._from_nvflare_converter = from_nvflare_converter - to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id) + + print("********************************") + print(self._from_nvflare_converter) + print("********************************") + + to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id[0]) if to_nvflare_converter is not None: check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter) self._to_nvflare_converter = to_nvflare_converter diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index f009a6b1c2..12549a232d 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -21,6 +21,14 @@ from nvflare.fuel.utils import fobs from nvflare.fuel.utils.fobs.datum import DatumManager +# Create a module +class TensorModule(torch.nn.Module): + def __init__(self, tensor): + super().__init__() + self.tensor = tensor + + def forward(self): + return self.tensor class TensorDecomposer(fobs.Decomposer): def supported_type(self): @@ -28,12 +36,20 @@ def supported_type(self): def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: stream = BytesIO() + + scripted_module = torch.jit.script(TensorModule(target)) + torch.jit.save(scripted_module, stream) + stream.seek(0) + # torch.save uses Pickle so converting Tensor to ndarray first - array = target.detach().cpu().numpy() - np.save(stream, array, allow_pickle=False) + #array = target.detach().cpu().numpy() + #np.save(stream, array, allow_pickle=False) return stream.getvalue() def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: + stream = BytesIO(data) - array = np.load(stream, allow_pickle=False) - return torch.from_numpy(array) + #array = np.load(stream, allow_pickle=False) + #return torch.from_numpy(array) + loaded_module = torch.jit.load(stream) + return loaded_module() diff --git a/nvflare/app_opt/pt/params_converter.py b/nvflare/app_opt/pt/params_converter.py index 503da8e284..c64508616c 100644 --- a/nvflare/app_opt/pt/params_converter.py +++ b/nvflare/app_opt/pt/params_converter.py @@ -48,6 +48,8 @@ def convert(self, params: Dict, fl_ctx) -> Dict: exclude_vars = {} for k, v in params.items(): if isinstance(v, torch.Tensor): + # print the data type of the tensor + print(v.dtype) return_tensors[k] = v.cpu().numpy() tensor_shapes[k] = v.shape else: diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index 5fa79e3fb0..0bcb6d8a05 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -45,6 +45,8 @@ def __init__( framework: FrameworkType = FrameworkType.PYTORCH, params_transfer_type: str = TransferType.FULL, executor: Union[ClientAPILauncherExecutor, InProcessClientAPIExecutor, None] = None, + from_nvflare_converter_id: str = None, + to_nvflare_converter_id: str = None, task_pipe: Optional[Pipe] = None, launcher: Optional[Launcher] = None, metric_relay: Optional[MetricRelay] = None, @@ -96,7 +98,8 @@ def __init__( self._launch_external_process = launch_external_process self._framework = framework self._params_transfer_type = params_transfer_type - + self._from_nvflare_converter_id = from_nvflare_converter_id, + self._to_nvflare_converter_id = to_nvflare_converter_id, self._params_exchange_format = None if self._framework == FrameworkType.PYTORCH: @@ -186,6 +189,8 @@ def 
add_to_fed_job(self, job: FedJob, ctx, **kwargs): launcher_id=launcher_id, params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, + from_nvflare_converter_id=self._from_nvflare_converter_id, + to_nvflare_converter_id=self._to_nvflare_converter_id, heartbeat_timeout=0, ) ) @@ -231,6 +236,8 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): task_script_args=self._script_args, params_exchange_format=self._params_exchange_format, params_transfer_type=self._params_transfer_type, + from_nvflare_converter_id=self._from_nvflare_converter_id, + to_nvflare_converter_id=self._to_nvflare_converter_id, ) ) job.add_executor(executor, tasks=tasks, ctx=ctx) From a2e6849df2866c24a47d6ff345c1d0a6ab9998d4 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 10:15:24 -0500 Subject: [PATCH 03/16] polish sft_job --- examples/advanced/llm_hf/sft_job.py | 36 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index ab328b7297..6e21272305 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -47,6 +47,7 @@ def main(): job_dir = args.job_dir model_name_or_path = args.model_name_or_path train_mode = args.train_mode + message_mode = args.message_mode # Create the FedJob if train_mode.lower() == "sft": @@ -88,16 +89,27 @@ def main(): site_name = f"site-{client_id}" data_path_train = os.path.join(args.data_path, client_id, "validation.jsonl") data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl") - # Add params converters and send to client - job.to(PTSendParamsConverter(), site_name, id="pt_send") - job.to(PTReceiveParamsConverter(), site_name, id="pt_receive") - runner = BaseScriptRunner( - script=train_script, - script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", - from_nvflare_converter_id="pt_receive", - to_nvflare_converter_id="pt_send", - ) + + if message_mode == "tensor": + # Add params converters and send to client + job.to(PTSendParamsConverter(), site_name, id="pt_send") + job.to(PTReceiveParamsConverter(), site_name, id="pt_receive") + runner = BaseScriptRunner( + script=train_script, + script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + from_nvflare_converter_id="pt_receive", + to_nvflare_converter_id="pt_send", + ) + job.to(runner, site_name, tasks=["train"]) + elif message_mode == "numpy": + runner = BaseScriptRunner( + script=train_script, + script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + ) + else: + raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.") job.to(runner, site_name, tasks=["train"]) + if args.quantize_mode: job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT) job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA) @@ -163,6 +175,12 @@ def define_parser(): default=None, help="quantization mode, float16 or blockwise8, default to None (no quantization)", ) + parser.add_argument( + "--message_mode", + 
type=str, + default="numpy", + help="message mode, numpy or tensor, default to numpy", + ) parser.add_argument( "--threads", type=int, From 247e7c71e94b987a2e66ce57291b7cc4d50c2869 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 10:15:49 -0500 Subject: [PATCH 04/16] polish sft_job --- examples/advanced/llm_hf/sft_job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 6e21272305..b1536f568f 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -87,7 +87,7 @@ def main(): for i in range(num_clients): client_id = client_ids[i] site_name = f"site-{client_id}" - data_path_train = os.path.join(args.data_path, client_id, "validation.jsonl") + data_path_train = os.path.join(args.data_path, client_id, "training.jsonl") data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl") if message_mode == "tensor": @@ -109,7 +109,7 @@ def main(): else: raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.") job.to(runner, site_name, tasks=["train"]) - + if args.quantize_mode: job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT) job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA) From 6acec33a3ece24d38db93f053e5b4df7d34511fd Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 10:22:18 -0500 Subject: [PATCH 05/16] polish local training script --- examples/advanced/llm_hf/sft_job.py | 4 ++-- examples/advanced/llm_hf/src/hf_sft_peft_fl.py | 12 ++++++++++-- .../in_process_client_api_executor.py | 18 ------------------ 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index b1536f568f..2a280e8bd2 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -96,7 +96,7 @@ def main(): job.to(PTReceiveParamsConverter(), site_name, id="pt_receive") runner = BaseScriptRunner( script=train_script, - script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}", from_nvflare_converter_id="pt_receive", to_nvflare_converter_id="pt_send", ) @@ -104,7 +104,7 @@ def main(): elif message_mode == "numpy": runner = BaseScriptRunner( script=train_script, - script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}", ) else: raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.") diff --git a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py index d0114ffec7..78ecac8950 100755 --- a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -69,6 +69,12 @@ 
def main(): default="SFT", help="training mode, SFT or PEFT, default to SFT", ) + parser.add_argument( + "--message_mode", + type=str, + default="numpy", + help="message mode, numpy or tensor, default to numpy", + ) parser.add_argument("--local_epoch", type=int, default=1) parser.add_argument("--clean_up", type=int, default=0) args = parser.parse_args() @@ -232,8 +238,10 @@ def evaluate(input_weights, mode): for key in list(out_param.keys()): out_param["model." + key] = out_param.pop(key).cpu() - # cast out_param to float32 preparing for communication - #out_param = {k: v.to(torch.float32) for k, v in out_param.items()} + if args.message_mode.lower() == "numpy": + # cast out_param to float32 preparing for communication with numpy + # otherwise do nothing + out_param = {k: v.to(torch.float32) for k, v in out_param.items()} # construct trained FL model output_model = flare.FLModel( diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index 84c18ec72a..99a29df2a0 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -208,28 +208,10 @@ def send_data_to_peer(self, shareable, fl_ctx: FLContext): def _init_converter(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() - - print("********************************") - print(self._from_nvflare_converter_id) - print(self._params_exchange_format) - print("********************************") - from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id[0]) - - print("********************************") - print(engine.get_component(self._from_nvflare_converter_id[0])) - print(from_nvflare_converter) - print("********************************") - if from_nvflare_converter is not None: check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter) self._from_nvflare_converter = from_nvflare_converter - - - print("********************************") - print(self._from_nvflare_converter) - print("********************************") - to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id[0]) if to_nvflare_converter is not None: check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter) From 15240d818844f77a7a71ee9a9d511a2c7963b028 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 10:30:10 -0500 Subject: [PATCH 06/16] polish tensor params converter --- examples/advanced/llm_hf/sft_job.py | 3 +-- .../executors/in_process_client_api_executor.py | 8 -------- nvflare/app_opt/pt/params_converter.py | 2 -- .../app_opt/pt/tensor_params_converter.py | 0 4 files changed, 1 insertion(+), 12 deletions(-) rename examples/advanced/llm_hf/src/params_converter.py => nvflare/app_opt/pt/tensor_params_converter.py (100%) diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 2a280e8bd2..257022fc4e 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -22,8 +22,7 @@ from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor from nvflare.job_config.script_runner import BaseScriptRunner - -from src.params_converter import PTSendParamsConverter, PTReceiveParamsConverter +from nvflare.app_opt.pt.tensor_params_converter import PTSendParamsConverter, PTReceiveParamsConverter def main(): args = 
define_parser() diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index 99a29df2a0..98445c66e2 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -138,14 +138,6 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id()) shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name()) - - # print the from and to nvflare converter - print("----------------------------------------------------------------") - print(f"from_nvflare_converter: {self._from_nvflare_converter}") - print(f"to_nvflare_converter: {self._to_nvflare_converter}") - print("----------------------------------------------------------------") - - if self._from_nvflare_converter is not None: shareable = self._from_nvflare_converter.process(task_name, shareable, fl_ctx) diff --git a/nvflare/app_opt/pt/params_converter.py b/nvflare/app_opt/pt/params_converter.py index c64508616c..503da8e284 100644 --- a/nvflare/app_opt/pt/params_converter.py +++ b/nvflare/app_opt/pt/params_converter.py @@ -48,8 +48,6 @@ def convert(self, params: Dict, fl_ctx) -> Dict: exclude_vars = {} for k, v in params.items(): if isinstance(v, torch.Tensor): - # print the data type of the tensor - print(v.dtype) return_tensors[k] = v.cpu().numpy() tensor_shapes[k] = v.shape else: diff --git a/examples/advanced/llm_hf/src/params_converter.py b/nvflare/app_opt/pt/tensor_params_converter.py similarity index 100% rename from examples/advanced/llm_hf/src/params_converter.py rename to nvflare/app_opt/pt/tensor_params_converter.py From 531304b66624afa480b0206ab23594386e1fec5f Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 11:00:47 -0500 Subject: [PATCH 07/16] polish decomposer --- nvflare/app_opt/pt/decomposers.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index 12549a232d..2d8b27b555 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -30,26 +30,39 @@ def __init__(self, tensor): def forward(self): return self.tensor -class TensorDecomposer(fobs.Decomposer): +class TensorJitDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: stream = BytesIO() + # Use JIT serialization to avoid Pickle scripted_module = torch.jit.script(TensorModule(target)) torch.jit.save(scripted_module, stream) stream.seek(0) - # torch.save uses Pickle so converting Tensor to ndarray first - #array = target.detach().cpu().numpy() - #np.save(stream, array, allow_pickle=False) return stream.getvalue() def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: - stream = BytesIO(data) - #array = np.load(stream, allow_pickle=False) - #return torch.from_numpy(array) loaded_module = torch.jit.load(stream) return loaded_module() + + +class TensorNumpyDecomposer(fobs.Decomposer): + def supported_type(self): + return torch.Tensor + + def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: + stream = BytesIO() + + # torch.save uses Pickle so converting Tensor to ndarray first + array = target.detach().cpu().numpy() + np.save(stream, array, allow_pickle=False) + return stream.getvalue() + + def recompose(self, data: Any, 
manager: DatumManager = None) -> torch.Tensor: + stream = BytesIO(data) + array = np.load(stream, allow_pickle=False) + return torch.from_numpy(array) \ No newline at end of file From 059c64fcc36f3cce2b01bacbbe54fdb6739483bf Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 12:31:49 -0500 Subject: [PATCH 08/16] format correction --- examples/advanced/llm_hf/sft_job.py | 4 ++-- nvflare/app_opt/pt/decomposers.py | 6 ++++-- nvflare/job_config/script_runner.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 257022fc4e..6420861935 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -19,10 +19,11 @@ from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector from nvflare.app_common.workflows.fedavg import FedAvg from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor +from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor from nvflare.job_config.script_runner import BaseScriptRunner -from nvflare.app_opt.pt.tensor_params_converter import PTSendParamsConverter, PTReceiveParamsConverter + def main(): args = define_parser() @@ -99,7 +100,6 @@ def main(): from_nvflare_converter_id="pt_receive", to_nvflare_converter_id="pt_send", ) - job.to(runner, site_name, tasks=["train"]) elif message_mode == "numpy": runner = BaseScriptRunner( script=train_script, diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index 2d8b27b555..725df2b544 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -21,6 +21,7 @@ from nvflare.fuel.utils import fobs from nvflare.fuel.utils.fobs.datum import DatumManager + # Create a module class TensorModule(torch.nn.Module): def __init__(self, tensor): @@ -30,6 +31,7 @@ def __init__(self, tensor): def forward(self): return self.tensor + class TensorJitDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor @@ -50,7 +52,7 @@ def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: return loaded_module() -class TensorNumpyDecomposer(fobs.Decomposer): +class TensorDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor @@ -65,4 +67,4 @@ def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: stream = BytesIO(data) array = np.load(stream, allow_pickle=False) - return torch.from_numpy(array) \ No newline at end of file + return torch.from_numpy(array) diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index 0bcb6d8a05..cb418408d9 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -98,8 +98,8 @@ def __init__( self._launch_external_process = launch_external_process self._framework = framework self._params_transfer_type = params_transfer_type - self._from_nvflare_converter_id = from_nvflare_converter_id, - self._to_nvflare_converter_id = to_nvflare_converter_id, + self._from_nvflare_converter_id = (from_nvflare_converter_id,) + self._to_nvflare_converter_id = (to_nvflare_converter_id,) self._params_exchange_format = None if self._framework == FrameworkType.PYTORCH: From 
f694a143e69b4f3bfbd6f7a99c089641fe4081cf Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 12:34:39 -0500 Subject: [PATCH 09/16] header update --- examples/advanced/llm_hf/src/hf_sft_peft_fl.py | 2 +- nvflare/app_opt/pt/decomposers.py | 2 +- nvflare/app_opt/pt/tensor_params_converter.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py index 78ecac8950..49113190ff 100755 --- a/examples/advanced/llm_hf/src/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index 725df2b544..1f87df543b 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nvflare/app_opt/pt/tensor_params_converter.py b/nvflare/app_opt/pt/tensor_params_converter.py index 2b8cdc9c3f..1e8275f58c 100644 --- a/nvflare/app_opt/pt/tensor_params_converter.py +++ b/nvflare/app_opt/pt/tensor_params_converter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 5c567c74afb48358b8dfead4f35497d5ef5c1dfe Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 14:21:25 -0500 Subject: [PATCH 10/16] update decomposer --- nvflare/app_opt/pt/decomposers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index 0696f3cd0c..a7f071f33d 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
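[Editor's note, for orientation between these hunks: the dispatch below routes bfloat16 tensors through TorchScript because numpy has no bfloat16 dtype, while every dtype numpy can represent keeps the lighter `np.save` path. A minimal round trip of the JIT path, assuming a recent PyTorch (`SerializationModule` is the buffer-holding module defined in this file):

```python
import torch
from io import BytesIO

class SerializationModule(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self.register_buffer("saved_tensor", tensor)

t = torch.randn(4, dtype=torch.bfloat16)
# t.numpy() would raise TypeError here: numpy has no bfloat16 scalar type
stream = BytesIO()
torch.jit.save(torch.jit.script(SerializationModule(t)), stream)
stream.seek(0)
restored = torch.jit.load(stream).saved_tensor
assert restored.dtype == torch.bfloat16 and torch.equal(restored, t)
```
]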
@@ -27,6 +27,7 @@ def __init__(self, tensor): super().__init__() self.register_buffer("saved_tensor", tensor) + class TensorDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor @@ -38,7 +39,6 @@ def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: return self._numpy_serialize(target) def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: - if isinstance(data, dict): if data["dtype"] == "torch.bfloat16": return self._jit_deserialize(data) @@ -52,8 +52,7 @@ def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: @staticmethod def _numpy_serialize(tensor: torch.Tensor) -> dict: stream = BytesIO() - - # torch.save uses Pickle so converting Tensor to ndarray first + # supported ScalarType, use numpy to avoid Pickle array = tensor.detach().cpu().numpy() np.save(stream, array, allow_pickle=False) return { @@ -69,8 +68,9 @@ def _numpy_deserialize(data: Any) -> torch.Tensor: @staticmethod def _jit_serialize(tensor: torch.Tensor) -> dict: - module = SerializationModule(tensor) stream = BytesIO() + # unsupported ScalarType by numpy, use torch.jit to avoid Pickle + module = SerializationModule(tensor) torch.jit.save(torch.jit.script(module), stream) return { "buffer": stream.getvalue(), From 3cf01c480803a0ff834a1f544dbe63bc43240cc3 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 17:15:16 -0500 Subject: [PATCH 11/16] end to end tensor communication passed --- examples/advanced/llm_hf/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/advanced/llm_hf/README.md b/examples/advanced/llm_hf/README.md index 6cbe1ae994..831719eff9 100644 --- a/examples/advanced/llm_hf/README.md +++ b/examples/advanced/llm_hf/README.md @@ -99,7 +99,7 @@ Similar patterns can be observed from the PEFT curves, purple for centralized re ![peft](./figs/fl_peft.png) ## Model Quantization for Communication -In the above example, we used float32 for communication. To reduce the message size, we can use model precision conversion and quantization +In the above example, we used numpy in float32 for communication. To reduce the message size, we can use model precision conversion and quantization from float32 to 16-bit, 8-bit, and 4-bit for communication. Quantization is enabled by NVFlare's [filter mechanism](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html). We can use the following command to run the federated training with model quantization. 16-bit is a direct precision conversion, while 8-bit, 4-bit quantization is performed by [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes/tree/main). Note that 4-bit quantizations (`fp4` or `nf4`) need device support. @@ -125,6 +125,13 @@ For message reduce, from float32 to 16-/8-/4-bit, the message size (in MB) of Ll Note that quantization will generate additional meta data, which can be significant for 4-bit cases. +## Model Communication with Tensor +In addition, since the model is trained with bf16, instead of first converting to numpy in float32, we can directly communicate with tensor in bf16 to avoid the message size inflation due to the conversion. +We can use the following command to run the federated training with direct tensor communication. 
+``` +python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor +``` + ## Federated Training with Multiple Clients With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients: ``` From dca4fd812e240713c1494fab27d60de6e67acef2 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Thu, 12 Dec 2024 18:33:09 -0500 Subject: [PATCH 12/16] update quantization filters to handle tensor --- examples/advanced/llm_hf/README.md | 4 + examples/advanced/llm_hf/sft_job.py | 8 +- nvflare/app_opt/quantization/constant.py | 3 + nvflare/app_opt/quantization/dequantizor.py | 198 +++++++++++++++++ .../app_opt/quantization/numpy_dequantizor.py | 136 ------------ .../app_opt/quantization/numpy_quantizor.py | 148 ------------- nvflare/app_opt/quantization/quantizor.py | 206 ++++++++++++++++++ .../app_opt/quantization/quantization_test.py | 4 +- 8 files changed, 417 insertions(+), 290 deletions(-) create mode 100644 nvflare/app_opt/quantization/dequantizor.py delete mode 100644 nvflare/app_opt/quantization/numpy_dequantizor.py delete mode 100644 nvflare/app_opt/quantization/numpy_quantizor.py create mode 100644 nvflare/app_opt/quantization/quantizor.py diff --git a/examples/advanced/llm_hf/README.md b/examples/advanced/llm_hf/README.md index 831719eff9..173536d61f 100644 --- a/examples/advanced/llm_hf/README.md +++ b/examples/advanced/llm_hf/README.md @@ -131,6 +131,10 @@ We can use the following command to run the federated training with direct tenso ``` python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor --job_dir ${PWD}/workspace/jobs/hf_sft_tensor --train_mode SFT --message_mode tensor ``` +Similarly, quantization can be applied to tensor communication as well. +``` +python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_tensor_fp4 --job_dir ${PWD}/workspace/jobs/hf_sft_tensor_fp4 --train_mode SFT --message_mode tensor --quantize_mode float4 +``` ## Federated Training with Multiple Clients With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients: diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py index 6420861935..a1810ab3ee 100644 --- a/examples/advanced/llm_hf/sft_job.py +++ b/examples/advanced/llm_hf/sft_job.py @@ -20,8 +20,8 @@ from nvflare.app_common.workflows.fedavg import FedAvg from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter -from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor -from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor +from nvflare.app_opt.quantization.dequantizor import ModelDequantizor +from nvflare.app_opt.quantization.quantizor import ModelQuantizor from nvflare.job_config.script_runner import BaseScriptRunner @@ -68,8 +68,8 @@ def main(): if args.quantize_mode: # If using quantization, add quantize filters. 
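[Editor's note before the filter swap just below: the cheapest mode these filters support is float16 direct conversion, which clamps before casting so out-of-range float32 values saturate instead of overflowing to inf. A small sketch of that step, mirroring the float16 branch of `ModelQuantizor` later in this patch:

```python
import numpy as np

FP16_MIN = np.finfo(np.float16).min  # -65504.0
FP16_MAX = np.finfo(np.float16).max  # 65504.0

values = np.array([1e9, -1e9, 0.5], dtype=np.float32)
# clamp into float16 range first, then cast: message size is halved
quantized = np.clip(values, FP16_MIN, FP16_MAX).astype(np.float16)
print(quantized)  # [ 65504. -65504.  0.5] -- saturated, not inf
```
]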
-        quantizor = NumpyModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = NumpyModelDequantizor(source_data_type="float32")
+        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
+        dequantizor = ModelDequantizor()
         job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
         job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
diff --git a/nvflare/app_opt/quantization/constant.py b/nvflare/app_opt/quantization/constant.py
index 06c422c655..e1e48ea779 100644
--- a/nvflare/app_opt/quantization/constant.py
+++ b/nvflare/app_opt/quantization/constant.py
@@ -13,7 +13,10 @@
 # limitations under the License.
 
 DATA_TYPE = [
+    "FLOAT64",
     "FLOAT32",
+    "FLOAT16",
+    "BFLOAT16",
 ]
 
 QUANTIZATION_TYPE = [
diff --git a/nvflare/app_opt/quantization/dequantizor.py b/nvflare/app_opt/quantization/dequantizor.py
new file mode 100644
index 0000000000..1ac6265c6e
--- /dev/null
+++ b/nvflare/app_opt/quantization/dequantizor.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Union
+
+import numpy as np
+import torch
+from bitsandbytes.functional import QuantState, dequantize_4bit, dequantize_blockwise
+
+from nvflare.apis.dxo import DXO, DataKind, MetaKey
+from nvflare.apis.dxo_filter import DXOFilter
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable
+from nvflare.app_opt.quantization.constant import QUANTIZATION_TYPE
+
+
+class ModelDequantizor(DXOFilter):
+    def __init__(self):
+        """Filter to dequantize Shareable object to recover from quantization
+
+        Args:
+            None
+
+        """
+
+        # support weight and weight_diff data kinds
+        data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF]
+        super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds)
+        self.logger.info("Using model dequantizer.")
+
+    def dequantization(
+        self, params: dict, quant_state: dict, quantization_type: str, source_datatype: dict, fl_ctx: FLContext
+    ):
+        n_params = len(params.keys())
+        self.log_info(fl_ctx, f"Running dequantization on {n_params} variables")
+        n_bytes_before = 0
+        n_bytes_after = 0
+        n_bytes_meta = 0
+        n_quant_params = 0
+        for i, param_name in enumerate(params.keys()):
+            source_data_type = source_datatype[param_name]
+
+            # get the bits information
+            source_data_bits = int(re.findall(r"\d+", source_data_type)[0])
+            quantization_bits = int(re.findall(r"\d+", quantization_type)[0])
+
+            # only dequantize if the quantization type is lower than the source data type
+            if quantization_bits >= source_data_bits:
+                self.log_info(
+                    fl_ctx,
+                    f"Skipping dequantization for {param_name}, quantization bit {quantization_type} >= source data bit {source_data_type}",
+                )
+                continue
+            else:
+                values = params[param_name]
+                n_bytes_before += values.nbytes
+                for item in quant_state[param_name].values():
+                    if isinstance(item, np.ndarray) or isinstance(item, torch.Tensor):
+                        n_bytes_meta += item.nbytes
+
+                # check the data format, numpy or torch; otherwise raise an error
+                if isinstance(values, np.ndarray):
+                    source_data_format = "numpy"
+                elif isinstance(values, torch.Tensor):
+                    source_data_format = "torch"
+                else:
+                    raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch")
+
+                n_quant_params += 1
+                if quantization_type == "float16":
+                    # direct convert back to higher precision
+                    if source_data_format == "numpy":
+                        if source_data_type == "float32":
+                            values = values.astype(np.float32)
+                        elif source_data_type == "float64":
+                            values = values.astype(np.float64)
+                    elif source_data_format == "torch":
+                        if source_data_type == "float32":
+                            values = values.float()
+                        elif source_data_type == "float64":
+                            values = values.double()
+                    params[param_name] = values
+                elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
+                    # use bitsandbytes to dequantize the values
+                    # extract quantization state
+                    if quantization_type == "blockwise8":
+                        if source_data_format == "numpy":
+                            # first convert numpy arrays to tensors
+                            quantized = torch.as_tensor(values)
+                            absmax = torch.as_tensor(quant_state[param_name]["absmax"])
+                            code = torch.as_tensor(quant_state[param_name]["code"])
+                        elif source_data_format == "torch":
+                            # tensors and their quantization state can be used as-is
+                            quantized = values
+                            absmax = quant_state[param_name]["absmax"]
+                            code = quant_state[param_name]["code"]
+                        # de-quantize
+                        dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
+                        # assign back
+                        if source_data_format == "numpy":
+                            params[param_name] = dequantized.numpy()
+                        elif source_data_format == "torch":
+                            params[param_name] = dequantized
+                    else:
+                        if source_data_format == "numpy":
+                            # first convert numpy array to tensor, need to use GPU
+                            quantized = torch.as_tensor(values).cuda()
+                            # create QuantState object
+                            quantize_state = QuantState(
+                                quant_type=quant_state[param_name]["quant_type"],
+                                absmax=torch.as_tensor(quant_state[param_name]["absmax"]).cuda(),
+                                blocksize=quant_state[param_name]["blocksize"],
+                                code=torch.as_tensor(quant_state[param_name]["quant_map"]).cuda(),
+                                dtype=getattr(torch, quant_state[param_name]["dtype"]),
+                                shape=torch.Size(quant_state[param_name]["shape"]),
+                            )
+                        elif source_data_format == "torch":
+                            quantized = values.cuda()
+                            quantize_state = QuantState(
+                                quant_type=quant_state[param_name]["quant_type"],
+                                absmax=quant_state[param_name]["absmax"].cuda(),
+                                blocksize=quant_state[param_name]["blocksize"],
+                                code=quant_state[param_name]["quant_map"].cuda(),
+                                dtype=getattr(torch, quant_state[param_name]["dtype"]),
+                                shape=torch.Size(quant_state[param_name]["shape"]),
+                            )
+                        # de-quantize
+                        if quantization_type == "float4":
+                            dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
+                        else:
+                            dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
+                        # assign back
+                        if source_data_format == "numpy":
+                            params[param_name] = dequantized.cpu().numpy()
+                        elif source_data_format == "torch":
+                            params[param_name] = dequantized.cpu()
+                    # convert back to the original data type, in the original format
+                    # (numpy arrays need astype; the torch dtype is looked up by name)
+                    if source_data_format == "numpy":
+                        params[param_name] = params[param_name].astype(source_data_type)
+                    elif source_data_format == "torch":
+                        params[param_name] = params[param_name].to(getattr(torch, source_data_type))
+
+                n_bytes_after += params[param_name].nbytes
+
+        self.log_info(
+            fl_ctx,
+            f"Dequantized {n_quant_params}/{n_params} params."
+            f" Before dequantization: {n_bytes_before / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB."
+ f" After dequantization: {n_bytes_after / (1024 ** 2):.2f} MB.", + ) + return params + + def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: + """Filter process apply to the Shareable object. + + Args: + dxo: data to be processed + shareable: that the dxo belongs to + fl_ctx: FLContext + + Returns: DXO object with dequantized weights + + """ + + self.log_info(fl_ctx, "Running dequantization...") + + # check config + quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None) + if quantization_type.upper() not in QUANTIZATION_TYPE: + raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") + + dequantized_params = self.dequantization( + params=dxo.data, + quant_state=dxo.meta["quant_state"], + quantization_type=quantization_type, + source_datatype=dxo.meta["source_datatype"], + fl_ctx=fl_ctx, + ) + # Compose new DXO with dequantized data + dxo.data = dequantized_params + dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM) + dxo.remove_meta_props("quant_state") + dxo.remove_meta_props("source_datatype") + dxo.update_shareable(shareable) + self.log_info(fl_ctx, "Dequantized back") + + return dxo diff --git a/nvflare/app_opt/quantization/numpy_dequantizor.py b/nvflare/app_opt/quantization/numpy_dequantizor.py deleted file mode 100644 index 0409ac45c2..0000000000 --- a/nvflare/app_opt/quantization/numpy_dequantizor.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Union - -import numpy as np -import torch -from bitsandbytes.functional import QuantState, dequantize_4bit, dequantize_blockwise - -from nvflare.apis.dxo import DXO, DataKind, MetaKey -from nvflare.apis.dxo_filter import DXOFilter -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE - - -class NumpyModelDequantizor(DXOFilter): - def __init__(self, source_data_type="float32"): - """Filter to dequantize Shareable object to recover from quantization - - Args: - source_data_type: original data type of the model - - """ - - # support weight and weight_diff data kinds - data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] - super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) - - # assign data type and check if it is valid - self.logger.info("Using model dequantizator.") - if source_data_type.upper() not in DATA_TYPE: - raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}") - else: - self.source_data_type = source_data_type - - def dequantization(self, params: dict, quant_state: dict, quant_type: str, fl_ctx: FLContext): - n_params = len(params.keys()) - self.log_info(fl_ctx, f"Running dequantization on {n_params} variables") - n_bytes_before = 0 - n_bytes_after = 0 - n_bytes_meta = 0 - n_quant_params = 0 - for i, param_name in enumerate(params.keys()): - if self.source_data_type == "float32": - values = params[param_name] - n_bytes_before += values.nbytes - for item in quant_state[param_name].values(): - if isinstance(item, np.ndarray): - n_bytes_meta += item.nbytes - if self.source_data_type != quant_type: - # if the source data type is not the same as the quantization type, convert it - n_quant_params += 1 - if quant_type == "float16": - # direct convert - values = values.astype(np.float32) - params[param_name] = values - elif quant_type in ["blockwise8", "float4", "normfloat4"]: - # use bitsandbytes to dequantize the values - # extract quantization state - if quant_type == "blockwise8": - quantized = torch.as_tensor(values) - absmax = torch.as_tensor(quant_state[param_name]["absmax"]) - code = torch.as_tensor(quant_state[param_name]["code"]) - # de-quanitze - dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code) - params[param_name] = dequantized.numpy() - else: - # first convert numpy array to tensor, need to use GPU - quantized = torch.as_tensor(values).cuda() - # create QuantState object - quantize_state = QuantState( - quant_type=quant_state[param_name]["quant_type"], - absmax=torch.as_tensor(quant_state[param_name]["absmax"]).cuda(), - blocksize=quant_state[param_name]["blocksize"], - code=torch.as_tensor(quant_state[param_name]["quant_map"]).cuda(), - dtype=getattr(torch, quant_state[param_name]["dtype"]), - shape=torch.Size(quant_state[param_name]["shape"]), - ) - # de-quanitze - if quant_type == "float4": - dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4") - else: - dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4") - params[param_name] = dequantized.cpu().numpy() - n_bytes_after += params[param_name].nbytes - - self.log_info( - fl_ctx, - f"Dequantized {n_quant_params}/{n_params} params." - f" Before dequantization: {n_bytes_before / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB." 
- f" After dequantization: {n_bytes_after / (1024 ** 2):.2f} MB.", - ) - return params - - def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: - """Filter process apply to the Shareable object. - - Args: - dxo: data to be processed - shareable: that the dxo belongs to - fl_ctx: FLContext - - Returns: DXO object with dequantized weights - - """ - - self.log_info(fl_ctx, "Running dequantization...") - - # check config - quantization_type = dxo.get_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, default=None) - if quantization_type.upper() not in QUANTIZATION_TYPE: - raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") - - dequantized_params = self.dequantization( - params=dxo.data, quant_state=dxo.meta["quant_state"], quant_type=quantization_type, fl_ctx=fl_ctx - ) - # Compose new DXO with dequantized data - dxo.data = dequantized_params - dxo.remove_meta_props(MetaKey.PROCESSED_ALGORITHM) - dxo.remove_meta_props("quant_state") - dxo.update_shareable(shareable) - self.log_info(fl_ctx, f"Dequantized back to {self.source_data_type}") - - return dxo diff --git a/nvflare/app_opt/quantization/numpy_quantizor.py b/nvflare/app_opt/quantization/numpy_quantizor.py deleted file mode 100644 index 13b0b44379..0000000000 --- a/nvflare/app_opt/quantization/numpy_quantizor.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -import numpy as np -import torch -from bitsandbytes.functional import quantize_4bit, quantize_blockwise - -from nvflare.apis.dxo import DXO, DataKind, MetaKey -from nvflare.apis.dxo_filter import DXOFilter -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE - - -class NumpyModelQuantizor(DXOFilter): - def __init__( - self, - quantization_type="float16", - ): - """Filter to quantize Shareable object to reduce communication burden. 
- - Args: - quantization_type: method used for quantization - - """ - - # support weight and weight_diff data kinds - data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF] - super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) - - # assign quantization type and check if it is valid - self.logger.info("Using model quantizator.") - if quantization_type.upper() not in QUANTIZATION_TYPE: - raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}") - else: - self.quantization_type = quantization_type - - # quantization constants - self.FP16_MIN = np.finfo(np.float16).min - self.FP16_MAX = np.finfo(np.float16).max - - def quantization(self, params: dict, fl_ctx: FLContext): - n_params = len(params.keys()) - self.log_info(fl_ctx, f"Running quantization on {n_params} variables") - n_bytes_before = 0 - n_bytes_after = 0 - n_bytes_meta = 0 - n_quant_params = 0 - quant_state = {} - for i, param_name in enumerate(params.keys()): - values = params[param_name] - quant_state[param_name] = {} - # check the data type of the values and if it is valid - source_data_type = values.dtype.name - if source_data_type.upper() not in DATA_TYPE: - raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}") - # add the number of bytes of the values - n_bytes_before += values.nbytes - if source_data_type != self.quantization_type: - # if the source data type is not the same as the quantization type, convert it - n_quant_params += 1 - if source_data_type == "float32": - if self.quantization_type == "float16": - # first clamp the values to the range of float16 - values = np.clip(values, self.FP16_MIN, self.FP16_MAX) - # then convert to float16 - values = values.astype(np.float16) - params[param_name] = values - elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]: - # use bitsandbytes to quantize the values - # input is a tensor, output is a tuple of (quantized tensor, quantized_state) - if self.quantization_type == "blockwise8": - # first convert numpy array to tensor - values_tensor = torch.as_tensor(values) - # then quantize the tensor - quantized, quantized_state = quantize_blockwise(values_tensor) - # add the quantization state - quant_state[param_name]["absmax"] = quantized_state.absmax.numpy() - n_bytes_meta += quant_state[param_name]["absmax"].nbytes - quant_state[param_name]["code"] = quantized_state.code.numpy() - n_bytes_meta += quant_state[param_name]["code"].nbytes - # add values - values = quantized.numpy() - else: - # first convert numpy array to tensor, need to use GPU - values_tensor = torch.as_tensor(values).cuda() - # then quantize the tensor - if self.quantization_type == "float4": - quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4") - else: - quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4") - # add the quantization state - quantized_state = quantized_state.as_dict() - for state_name, state in quantized_state.items(): - # if the state is a tensor, convert it to numpy array - if isinstance(state, torch.Tensor): - quant_state[param_name][state_name] = state.cpu().numpy() - n_bytes_meta += state.nbytes - else: - quant_state[param_name][state_name] = state - # add values - values = quantized.cpu().numpy() - params[param_name] = values - n_bytes_after += params[param_name].nbytes - - self.log_info( - fl_ctx, - f"Quantized {n_quant_params}/{n_params} params." - f" Before quantization: {n_bytes_before / (1024 ** 2):.2f} MB." 
- f" After quantization: {n_bytes_after / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB.", - ) - return params, quant_state - - def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]: - """Filter process apply to the Shareable object. - - Args: - dxo: data to be processed - shareable: that the dxo belongs to - fl_ctx: FLContext - - Returns: DXO object with quantized weights - - """ - - self.log_info(fl_ctx, "Running quantization...") - quantized_params, quant_state = self.quantization(params=dxo.data, fl_ctx=fl_ctx) - # Compose new DXO with quantized data - # Add quant_state to the new DXO meta - new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta) - new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type) - new_dxo.set_meta_prop(key="quant_state", value=quant_state) - self.log_info(fl_ctx, f"Quantized to {self.quantization_type}") - - return new_dxo diff --git a/nvflare/app_opt/quantization/quantizor.py b/nvflare/app_opt/quantization/quantizor.py new file mode 100644 index 0000000000..ca427e8e03 --- /dev/null +++ b/nvflare/app_opt/quantization/quantizor.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Union + +import numpy as np +import torch +from bitsandbytes.functional import quantize_4bit, quantize_blockwise + +from nvflare.apis.dxo import DXO, DataKind, MetaKey +from nvflare.apis.dxo_filter import DXOFilter +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE + + +class ModelQuantizor(DXOFilter): + def __init__( + self, + quantization_type="float16", + ): + """Filter to quantize Shareable object to reduce communication burden. 
+
+        Args:
+            quantization_type: method used for quantization
+
+        """
+
+        # support weight and weight_diff data kinds
+        data_kinds = [DataKind.WEIGHTS, DataKind.WEIGHT_DIFF]
+        super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds)
+
+        # assign quantization type and check if it is valid
+        self.logger.info("Using model quantizer.")
+        if quantization_type.upper() not in QUANTIZATION_TYPE:
+            raise ValueError(f"Invalid quantization type: {quantization_type}, valid: {QUANTIZATION_TYPE}")
+        else:
+            self.quantization_type = quantization_type
+
+        # quantization constants
+        self.NP_FP16_MIN = np.finfo(np.float16).min
+        self.NP_FP16_MAX = np.finfo(np.float16).max
+        self.TS_FP16_MIN = torch.finfo(torch.float16).min
+        self.TS_FP16_MAX = torch.finfo(torch.float16).max
+
+    def quantization(self, params: dict, fl_ctx: FLContext):
+        n_params = len(params.keys())
+        self.log_info(fl_ctx, f"Running quantization on {n_params} variables")
+        n_bytes_before = 0
+        n_bytes_after = 0
+        n_bytes_meta = 0
+        n_quant_params = 0
+        quant_state = {}
+        source_datatype = {}
+        for i, param_name in enumerate(params.keys()):
+            values = params[param_name]
+            quant_state[param_name] = {}
+
+            # check the container type: numpy array or torch tensor,
+            # otherwise raise an error
+            if isinstance(values, np.ndarray):
+                # numpy array: record the source format
+                source_data_format = "numpy"
+            elif isinstance(values, torch.Tensor):
+                source_data_format = "torch"
+            else:
+                raise ValueError(f"Invalid source data type: {type(values)}, valid: numpy or torch")
+
+            # get the data type of the values
+            if source_data_format == "numpy":
+                source_data_type = values.dtype.name
+            elif source_data_format == "torch":
+                source_data_type = str(values.dtype).split(".")[1]
+            source_datatype[param_name] = source_data_type
+
+            # check if the data type is valid
+            if source_data_type.upper() not in DATA_TYPE:
+                raise ValueError(f"Invalid source data type: {source_data_type}, valid: {DATA_TYPE}")
+
+            # get the bits information
+            source_data_bits = int(re.findall(r"\d+", source_data_type)[0])
+            quantization_bits = int(re.findall(r"\d+", self.quantization_type)[0])
+
+            # add the number of bytes of the values
+            n_bytes_before += values.nbytes
+            # only quantize if the target precision is lower than the source precision
+            if quantization_bits >= source_data_bits:
+                self.log_info(
+                    fl_ctx,
+                    f"Skipping quantization for {param_name}, quantization bits {self.quantization_type} >= source data bits {source_data_type}",
+                )
+                continue
+            else:
+                n_quant_params += 1
+                if self.quantization_type == "float16":
+                    if source_data_format == "numpy":
+                        # first clamp the values to the range of float16
+                        values = np.clip(values, self.NP_FP16_MIN, self.NP_FP16_MAX)
+                        # then convert to float16
+                        values = values.astype(np.float16)
+                    elif source_data_format == "torch":
+                        # first clamp the values to the range of float16
+                        values = torch.clamp(values, self.TS_FP16_MIN, self.TS_FP16_MAX)
+                        # then convert to float16
+                        values = values.to(torch.float16)
+                    params[param_name] = values
+                elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]:
+                    # use bitsandbytes to quantize the values
+                    # input is a tensor, output is a tuple of (quantized tensor, quantized_state)
+                    if self.quantization_type == "blockwise8":
+                        if source_data_format == "numpy":
+                            # if numpy, first convert the numpy array to a tensor
+                            values_tensor = torch.as_tensor(values)
+                        elif source_data_format == "torch":
+                            values_tensor = values
+
+                        # then quantize the tensor
+                        quantized, quantized_state = quantize_blockwise(values_tensor)
+                        # add the quantization state and values, keep source data format
+                        if source_data_format == "numpy":
+                            quant_state[param_name]["absmax"] = quantized_state.absmax.numpy()
+                            quant_state[param_name]["code"] = quantized_state.code.numpy()
+                            values = quantized.numpy()
+                        elif source_data_format == "torch":
+                            quant_state[param_name]["absmax"] = quantized_state.absmax
+                            quant_state[param_name]["code"] = quantized_state.code
+                            values = quantized
+                        n_bytes_meta += quant_state[param_name]["absmax"].nbytes
+                        n_bytes_meta += quant_state[param_name]["code"].nbytes
+                    else:
+                        if source_data_format == "numpy":
+                            # if numpy, first convert the numpy array to a tensor; 4-bit kernels need the GPU
+                            values_tensor = torch.as_tensor(values).cuda()
+                        elif source_data_format == "torch":
+                            # if torch, use the tensor directly; 4-bit kernels need the GPU
+                            values_tensor = values.cuda()
+                        # then quantize the tensor
+                        if self.quantization_type == "float4":
+                            quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4")
+                        else:
+                            quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4")
+                        # add the quantization state and values, keep source data format
+                        quantized_state = quantized_state.as_dict()
+
+                        for state_name, state in quantized_state.items():
+                            if isinstance(state, torch.Tensor):
+                                if source_data_format == "numpy":
+                                    # if the state is a tensor, convert it to a numpy array
+                                    quant_state[param_name][state_name] = state.cpu().numpy()
+                                elif source_data_format == "torch":
+                                    # if the state is a tensor, keep it as a tensor
+                                    quant_state[param_name][state_name] = state.cpu()
+                                n_bytes_meta += state.nbytes
+                            else:
+                                quant_state[param_name][state_name] = state
+                        # add values
+                        if source_data_format == "numpy":
+                            values = quantized.cpu().numpy()
+                        elif source_data_format == "torch":
+                            values = quantized.cpu()
+                    params[param_name] = values
+            n_bytes_after += params[param_name].nbytes
+
+        self.log_info(
+            fl_ctx,
+            f"Quantized {n_quant_params}/{n_params} params."
+            f" Before quantization: {n_bytes_before / (1024 ** 2):.2f} MB."
+            f" After quantization: {n_bytes_after / (1024 ** 2):.2f} MB with meta: {n_bytes_meta / (1024 ** 2):.2f} MB.",
+        )
+        return params, quant_state, source_datatype
+
+    def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]:
+        """Filter process applied to the Shareable object.
+
+        Args:
+            dxo: data to be processed
+            shareable: the Shareable that the dxo belongs to
+            fl_ctx: FLContext
+
+        Returns: DXO object with quantized weights
+
+        """
+
+        self.log_info(fl_ctx, "Running quantization...")
+        quantized_params, quant_state, source_datatype = self.quantization(params=dxo.data, fl_ctx=fl_ctx)
+        # Compose new DXO with quantized data
+        # Add quant_state to the new DXO meta
+        new_dxo = DXO(data_kind=dxo.data_kind, data=quantized_params, meta=dxo.meta)
+        new_dxo.set_meta_prop(key=MetaKey.PROCESSED_ALGORITHM, value=self.quantization_type)
+        new_dxo.set_meta_prop(key="quant_state", value=quant_state)
+        new_dxo.set_meta_prop(key="source_datatype", value=source_datatype)
+        self.log_info(fl_ctx, f"Quantized to {self.quantization_type}")
+
+        return new_dxo
diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py
index 9fdc00d54a..3d3b41bddd 100644
--- a/tests/unit_test/app_opt/quantization/quantization_test.py
+++ b/tests/unit_test/app_opt/quantization/quantization_test.py
@@ -17,8 +17,8 @@
 
 from nvflare.apis.dxo import DXO, DataKind
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_opt.quantization.numpy_dequantizor import NumpyModelDequantizor
-from nvflare.app_opt.quantization.numpy_quantizor import NumpyModelQuantizor
+from nvflare.app_opt.quantization.dequantizor import NumpyModelDequantizor
+from nvflare.app_opt.quantization.quantizor import NumpyModelQuantizor
 
 TEST_CASES = [
     (

From 395f2deafc3cc7dd81dfb848e19d81578408def5 Mon Sep 17 00:00:00 2001
From: Ziyue Xu
Date: Fri, 13 Dec 2024 10:14:37 -0500
Subject: [PATCH 13/16] bug fixes and unittest updates

---
 nvflare/app_opt/quantization/dequantizor.py   | 59 +++++++++----------
 .../app_opt/quantization/quantization_test.py | 24 ++++++--
 2 files changed, 48 insertions(+), 35 deletions(-)

diff --git a/nvflare/app_opt/quantization/dequantizor.py b/nvflare/app_opt/quantization/dequantizor.py
index 1ac6265c6e..abd1c1c277 100644
--- a/nvflare/app_opt/quantization/dequantizor.py
+++ b/nvflare/app_opt/quantization/dequantizor.py
@@ -80,17 +80,7 @@ def dequantization(
             n_quant_params += 1
             if quantization_type == "float16":
-                # direct convert back to higher precision
-                if source_data_format == "numpy":
-                    if source_data_type == "float32":
-                        values = values.astype(np.float32)
-                    elif source_data_type == "float64":
-                        values = values.astype(np.float64)
-                elif source_data_format == "torch":
-                    if source_data_type == "float32":
-                        values = values.float()
-                    elif source_data_type == "float64":
-                        values = values.double()
+                # direct assign; conversion back to the original precision happens below
                 params[param_name] = values
             elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
                 # use bitsandbytes to dequantize the values
@@ -101,13 +91,12 @@ def dequantization(
                         quantized = torch.as_tensor(values)
                         absmax = torch.as_tensor(quant_state[param_name]["absmax"])
                         code = torch.as_tensor(quant_state[param_name]["code"])
+                    elif source_data_format == "torch":
+                        quantized = values
+                        absmax = quant_state[param_name]["absmax"]
+                        code = quant_state[param_name]["code"]
                     # de-quantize
                     dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
-                    # assign back
-                    if source_data_format == "numpy":
-                        params[param_name] = dequantized.numpy()
-                    elif source_data_format == "torch":
-                        params[param_name] = dequantized
                 else:
                     if source_data_format == "numpy":
                         # first convert numpy array to tensor, need to use GPU
@@ -136,20 +125,30 @@ def dequantization(
                     dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
                 else:
                     dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
-                # assign back
-                if source_data_format == "numpy":
-                    params[param_name] = dequantized.cpu().numpy()
-                elif source_data_format == "torch":
-                    params[param_name] = dequantized.cpu()
-                # convert back to original data type
-                if source_data_type == "float32":
-                    params[param_name] = params[param_name].float()
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].double()
-                elif source_data_type == "float16":
-                    params[param_name] = params[param_name].half()
-                elif source_data_type == "bfloat16":
-                    params[param_name] = params[param_name].bfloat16()
+                if source_data_format == "numpy":
+                    params[param_name] = dequantized.cpu().numpy()
+                elif source_data_format == "torch":
+                    params[param_name] = dequantized.cpu()
+
+            # assign back
+            if source_data_format == "numpy":
+                # convert back to original data type
+                if source_data_type == "float32":
+                    params[param_name] = params[param_name].astype(np.float32)
+                elif source_data_type == "float64":
+                    params[param_name] = params[param_name].astype(np.float64)
+                elif source_data_type == "float16":
+                    params[param_name] = params[param_name].astype(np.float16)
+            elif source_data_format == "torch":
+                # convert back to original data type
+                if source_data_type == "float32":
+                    params[param_name] = params[param_name].float()
+                elif source_data_type == "float64":
+                    params[param_name] = params[param_name].double()
+                elif source_data_type == "float16":
+                    params[param_name] = params[param_name].half()
+                elif source_data_type == "bfloat16":
+                    params[param_name] = params[param_name].bfloat16()
 
             n_bytes_after += params[param_name].nbytes
diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py
index 3d3b41bddd..d2943a5e9a 100644
--- a/tests/unit_test/app_opt/quantization/quantization_test.py
+++ b/tests/unit_test/app_opt/quantization/quantization_test.py
@@ -14,11 +14,12 @@
 
 import numpy as np
 import pytest
+import torch
 
 from nvflare.apis.dxo import DXO, DataKind
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_opt.quantization.dequantizor import NumpyModelDequantizor
-from nvflare.app_opt.quantization.quantizor import NumpyModelQuantizor
+from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
+from nvflare.app_opt.quantization.quantizor import ModelQuantizor
 
 TEST_CASES = [
     (
@@ -31,6 +32,16 @@
         "blockwise8",
         {"a": np.array([0.99062496, 2.003125, 3.015625, 4.0], dtype="float32")},
     ),
+    (
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
+        "float16",
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
+    ),
+    (
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.bfloat16)},
+        "float4",
+        {"a": torch.tensor([1.0, 2.0, 2.6719, 4.0], dtype=torch.bfloat16)},
+    ),
 ]
 
 
@@ -42,12 +53,15 @@ def test_quantization(self, input_data, quantization_type, expected_data):
             data=input_data,
         )
         fl_ctx = FLContext()
-        f_quant = NumpyModelQuantizor(quantization_type=quantization_type)
+        f_quant = ModelQuantizor(quantization_type=quantization_type)
         quant_dxo = f_quant.process_dxo(dxo, dxo.to_shareable(), fl_ctx)
-        f_dequant = NumpyModelDequantizor(source_data_type="float32")
+        f_dequant = ModelDequantizor()
         dequant_dxo = f_dequant.process_dxo(quant_dxo, dxo.to_shareable(), fl_ctx)
         dequant_data = dequant_dxo.data
         for key in dequant_data.keys():
             dequant_array = dequant_data[key]
             expected_array = expected_data[key]
-            assert np.allclose(dequant_array, expected_array)
+            if isinstance(dequant_array, torch.Tensor):
+                assert torch.allclose(dequant_array, expected_array)
+            else:
+                assert np.allclose(dequant_array, expected_array)

From e59210fafaeb63698b9a354dab5472135125c9c8 Mon Sep 17 00:00:00 2001
From: Ziyue Xu
Date: Fri, 13 Dec 2024 10:32:07 -0500
Subject: [PATCH 14/16] unit test cannot run on gpu, update case

---
 .../unit_test/app_opt/quantization/quantization_test.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py
index d2943a5e9a..cd273480a9 100644
--- a/tests/unit_test/app_opt/quantization/quantization_test.py
+++ b/tests/unit_test/app_opt/quantization/quantization_test.py
@@ -38,9 +38,9 @@
         {"a": torch.tensor([1.0, 2.0, 3.0, 4000.0], dtype=torch.bfloat16)},
     ),
     (
-        {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.bfloat16)},
-        "float4",
-        {"a": torch.tensor([1.0, 2.0, 2.6719, 4.0], dtype=torch.bfloat16)},
+        {"a": torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)},
+        "blockwise8",
+        {"a": torch.tensor([0.99062496, 2.003125, 3.015625, 4.0], dtype=torch.float32)},
     ),
 ]
 
@@ -61,6 +61,9 @@ def test_quantization(self, input_data, quantization_type, expected_data):
         for key in dequant_data.keys():
             dequant_array = dequant_data[key]
             expected_array = expected_data[key]
+            # print the values
+            print(f"dequant_array: {dequant_array}")
+            print(f"expected_array: {expected_array}")
             if isinstance(dequant_array, torch.Tensor):
                 assert torch.allclose(dequant_array, expected_array)
             else:
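
The blockwise8 round trip exercised by these test cases can be reproduced standalone. Below is a minimal sketch, assuming bitsandbytes is installed and running CPU-only; the tolerance is illustrative, and the exact recovered values (e.g. 0.99062496 for an input of 1.0) are the ones asserted in the test above.

# Minimal sketch of the blockwise8 round trip used by the filters above.
# Assumes bitsandbytes is installed; the atol is illustrative only.
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

values = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
# quantize_blockwise returns the uint8 payload plus a state carrying absmax/code
quantized, quantized_state = quantize_blockwise(values)
recovered = dequantize_blockwise(quantized, absmax=quantized_state.absmax, code=quantized_state.code)
# 8-bit blockwise quantization is lossy, so recovery is close but not exact
# (e.g. roughly [0.9906, 2.0031, 3.0156, 4.0] for this input)
assert torch.allclose(recovered, values, atol=0.05)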
From 54ea4fb96ef8d3dca2fc5a14453f693b5364e882 Mon Sep 17 00:00:00 2001
From: Ziyue Xu
Date: Fri, 13 Dec 2024 15:30:05 -0500
Subject: [PATCH 15/16] bug fixes and polish

---
 examples/advanced/llm_hf/sft_job.py                      | 9 +++++----
 .../executors/in_process_client_api_executor.py          | 8 ++++----
 nvflare/app_opt/{ => pt}/quantization/__init__.py        | 0
 nvflare/app_opt/{ => pt}/quantization/constant.py        | 0
 nvflare/app_opt/{ => pt}/quantization/dequantizor.py     | 2 +-
 nvflare/app_opt/{ => pt}/quantization/quantizor.py       | 2 +-
 nvflare/app_opt/pt/tensor_params_converter.py            | 1 +
 nvflare/job_config/script_runner.py                      | 8 ++++----
 .../unit_test/app_opt/quantization/quantization_test.py  | 4 ++--
 9 files changed, 18 insertions(+), 16 deletions(-)
 rename nvflare/app_opt/{ => pt}/quantization/__init__.py (100%)
 rename nvflare/app_opt/{ => pt}/quantization/constant.py (100%)
 rename nvflare/app_opt/{ => pt}/quantization/dequantizor.py (99%)
 rename nvflare/app_opt/{ => pt}/quantization/quantizor.py (99%)

diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py
index a1810ab3ee..c8c6ac6bbf 100644
--- a/examples/advanced/llm_hf/sft_job.py
+++ b/examples/advanced/llm_hf/sft_job.py
@@ -20,8 +20,8 @@
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
 from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
-from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
+from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
 from nvflare.job_config.script_runner import BaseScriptRunner
 
@@ -90,20 +90,21 @@ def main():
         data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
         data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
+        script_args = f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}"
         if message_mode == "tensor":
             # Add params converters and send to client
             job.to(PTSendParamsConverter(), site_name, id="pt_send")
             job.to(PTReceiveParamsConverter(), site_name, id="pt_receive")
             runner = BaseScriptRunner(
                 script=train_script,
-                script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}",
+                script_args=script_args,
                 from_nvflare_converter_id="pt_receive",
                 to_nvflare_converter_id="pt_send",
             )
         elif message_mode == "numpy":
             runner = BaseScriptRunner(
                 script=train_script,
-                script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --message_mode {message_mode} --clean_up {clean_up}",
+                script_args=script_args,
             )
         else:
             raise ValueError(f"Invalid message_mode: {message_mode}, only numpy and tensor are supported.")
diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py
index 98445c66e2..c89233904b 100644
--- a/nvflare/app_common/executors/in_process_client_api_executor.py
+++ b/nvflare/app_common/executors/in_process_client_api_executor.py
@@ -54,8 +54,8 @@ def __init__(
         log_pull_interval: Optional[float] = None,
         params_exchange_format: str = ExchangeFormat.NUMPY,
         params_transfer_type: TransferType = TransferType.FULL,
-        from_nvflare_converter_id: str = None,
-        to_nvflare_converter_id: str = None,
+        from_nvflare_converter_id: Optional[str] = None,
+        to_nvflare_converter_id: Optional[str] = None,
         train_with_evaluation: bool = True,
         train_task_name: str = AppConstants.TASK_TRAIN,
         evaluate_task_name: str = AppConstants.TASK_VALIDATION,
@@ -200,11 +200,11 @@ def send_data_to_peer(self, shareable, fl_ctx: FLContext):
 
     def _init_converter(self, fl_ctx: FLContext):
         engine = fl_ctx.get_engine()
-        from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id[0])
+        from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id)
         if from_nvflare_converter is not None:
             check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
             self._from_nvflare_converter = from_nvflare_converter
-        to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id[0])
+        to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
         if to_nvflare_converter is not None:
             check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
             self._to_nvflare_converter = to_nvflare_converter
diff --git a/nvflare/app_opt/quantization/__init__.py b/nvflare/app_opt/pt/quantization/__init__.py
similarity index 100%
rename from nvflare/app_opt/quantization/__init__.py
rename to nvflare/app_opt/pt/quantization/__init__.py
diff --git a/nvflare/app_opt/quantization/constant.py b/nvflare/app_opt/pt/quantization/constant.py
similarity index 100%
rename from nvflare/app_opt/quantization/constant.py
rename to nvflare/app_opt/pt/quantization/constant.py
diff --git a/nvflare/app_opt/quantization/dequantizor.py b/nvflare/app_opt/pt/quantization/dequantizor.py
similarity index 99%
rename from nvflare/app_opt/quantization/dequantizor.py
rename to nvflare/app_opt/pt/quantization/dequantizor.py
index abd1c1c277..d19a9584ec 100644
--- a/nvflare/app_opt/quantization/dequantizor.py
+++ b/nvflare/app_opt/pt/quantization/dequantizor.py
@@ -23,7 +23,7 @@
 from nvflare.apis.dxo_filter import DXOFilter
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
-from nvflare.app_opt.quantization.constant import QUANTIZATION_TYPE
+from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE
 
 
 class ModelDequantizor(DXOFilter):
diff --git a/nvflare/app_opt/quantization/quantizor.py b/nvflare/app_opt/pt/quantization/quantizor.py
similarity index 99%
rename from nvflare/app_opt/quantization/quantizor.py
rename to nvflare/app_opt/pt/quantization/quantizor.py
index ca427e8e03..083ce8bde5 100644
--- a/nvflare/app_opt/quantization/quantizor.py
+++ b/nvflare/app_opt/pt/quantization/quantizor.py
@@ -23,7 +23,7 @@
 from nvflare.apis.dxo_filter import DXOFilter
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
-from nvflare.app_opt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE
+from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE
 
 
 class ModelQuantizor(DXOFilter):
diff --git a/nvflare/app_opt/pt/tensor_params_converter.py b/nvflare/app_opt/pt/tensor_params_converter.py
index 1e8275f58c..6c87cdf9d3 100644
--- a/nvflare/app_opt/pt/tensor_params_converter.py
+++ b/nvflare/app_opt/pt/tensor_params_converter.py
@@ -30,6 +30,7 @@ def convert(self, params: Dict, fl_ctx) -> Dict:
             if isinstance(v, torch.Tensor):
                 return_params[k] = v
             else:
+                # PT receive side: may also need to convert incoming numpy arrays to tensors
                 if tensor_shapes:
                     if k in tensor_shapes:
                         return_params[k] = torch.as_tensor(np.reshape(v, tensor_shapes[k]))
diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py
index cb418408d9..68fbd8a2a9 100644
--- a/nvflare/job_config/script_runner.py
+++ b/nvflare/job_config/script_runner.py
@@ -45,8 +45,8 @@ def __init__(
         framework: FrameworkType = FrameworkType.PYTORCH,
         params_transfer_type: str = TransferType.FULL,
         executor: Union[ClientAPILauncherExecutor, InProcessClientAPIExecutor, None] = None,
-        from_nvflare_converter_id: str = None,
-        to_nvflare_converter_id: str = None,
+        from_nvflare_converter_id: Optional[str] = None,
+        to_nvflare_converter_id: Optional[str] = None,
         task_pipe: Optional[Pipe] = None,
         launcher: Optional[Launcher] = None,
         metric_relay: Optional[MetricRelay] = None,
@@ -98,8 +98,8 @@ def __init__(
         self._launch_external_process = launch_external_process
         self._framework = framework
         self._params_transfer_type = params_transfer_type
-        self._from_nvflare_converter_id = (from_nvflare_converter_id,)
-        self._to_nvflare_converter_id = (to_nvflare_converter_id,)
+        self._from_nvflare_converter_id = from_nvflare_converter_id
+        self._to_nvflare_converter_id = to_nvflare_converter_id
 
         self._params_exchange_format = None
         if self._framework == FrameworkType.PYTORCH:
diff --git a/tests/unit_test/app_opt/quantization/quantization_test.py b/tests/unit_test/app_opt/quantization/quantization_test.py
index cd273480a9..b8b2a6fb35 100644
--- a/tests/unit_test/app_opt/quantization/quantization_test.py
+++ b/tests/unit_test/app_opt/quantization/quantization_test.py
@@ -18,8 +18,8 @@
 
 from nvflare.apis.dxo import DXO, DataKind
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_opt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
+from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
 
 TEST_CASES = [
     (

From 837689416882423d343bd91cb7484bd105501022 Mon Sep 17 00:00:00 2001
From: Ziyue Xu
Date: Fri, 13 Dec 2024 15:40:15 -0500
Subject: [PATCH 16/16] format update

---
 examples/advanced/llm_hf/sft_job.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py
index c8c6ac6bbf..f14a0ce4cb 100644
--- a/examples/advanced/llm_hf/sft_job.py
+++ b/examples/advanced/llm_hf/sft_job.py
@@ -19,9 +19,9 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
 from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
 from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
 from nvflare.job_config.script_runner import BaseScriptRunner
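
Taken together, the series relocates the quantization filters under nvflare.app_opt.pt and wires tensor exchange through the params converters. The following is a minimal sketch of how the pieces compose in a job script, reusing the `job`, `train_script`, and `site_name` objects from sft_job.py above; the FilterType import location and the filter_type argument to job.to are assumptions for illustration and are not shown in these patches.

# Hedged sketch: composing the converters and quantization filters in a job.
# `job`, `train_script`, and `site_name` are the objects from sft_job.py above.
from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
from nvflare.app_opt.pt.tensor_params_converter import PTReceiveParamsConverter, PTSendParamsConverter
from nvflare.job_config.defs import FilterType  # assumed location of the filter-type enum
from nvflare.job_config.script_runner import BaseScriptRunner

# exchange torch tensors directly instead of numpy arrays
job.to(PTSendParamsConverter(), site_name, id="pt_send")
job.to(PTReceiveParamsConverter(), site_name, id="pt_receive")
runner = BaseScriptRunner(
    script=train_script,
    script_args="--message_mode tensor",  # illustrative subset of the real arguments
    from_nvflare_converter_id="pt_receive",
    to_nvflare_converter_id="pt_send",
)
job.to(runner, site_name, tasks=["train"])

# quantize outgoing results and dequantize incoming tasks; because both filters
# keep torch tensors end to end, bfloat16 weights survive the round trip
job.to(ModelQuantizor(quantization_type="blockwise8"), site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
job.to(ModelDequantizor(), site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)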