-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add grpc compression * minor fix * add uniform quantization * add unittest * update unittest * minor fix
- Loading branch information
1 parent
5ab598f
commit 5e3abad
Showing
10 changed files
with
319 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Message compression for efficient communication | ||
|
||
We provide plugins of message compression for efficient communication. | ||
|
||
## Lossless compression based on gRPC | ||
When running with distributed mode of FederatedScope, the shared messages can be compressed using the compression module provided by gRPC (More details can be found [here](https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/examples/python/compression/)). | ||
|
||
Users can turn on the message compression by adding the following configuration: | ||
```yaml | ||
distribute: | ||
grpc_compression: 'deflate' # or 'gzip' | ||
``` | ||
The effect of message compression when training ConvNet-2 on FEMNIST is shown below: | ||
| | NoCompression | Deflate | Gzip | | ||
| :---: | :---: | :---: | :---: | | ||
| Communication bytes per round (in gRPC channel) | 4.021MB | 1.888MB | 1.890MB | | ||
## Model quantization | ||
We provide symmetric uniform quantization to transform the model parameters (32-bit float) into 8/16-bit integers (note that this might cause a drop in model performance). | ||
To apply model quantization, users need to add the following configurations: | ||
```yaml | ||
quantization: | ||
method: 'uniform' | ||
nbits: 16 # or 8 | ||
``` | ||
We conduct experiments based on the scripts provided in `federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml` and report the results as: | ||
|
||
| | 32-bit float (vanilla) | 16-bit int | 8-bit int | | ||
| :---: | :---: | :---: | :---: | | ||
| Shared model size (in memory) | 25.20MB | 12.61MB | 6.31MB | | ||
| Model performance (acc) | 0.7856 | 0.7854 | 0.6807 | | ||
|
||
More fancy compression techniques are coming soon! We greatly appreciate contribution to FederatedScope! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from federatedscope.core.compression.utils import \
    symmetric_uniform_quantization, symmetric_uniform_dequantization

# Public API of the compression subpackage: quantize/dequantize helpers
# for shared model parameters.
__all__ = [
    'symmetric_uniform_quantization', 'symmetric_uniform_dequantization'
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
import logging | ||
|
||
# Module-level logger for the compression utilities.
logger = logging.getLogger(__name__)
# NOTE(review): hard-coding INFO on a library logger overrides the consuming
# application's logging configuration — confirm this is intended.
logger.setLevel(logging.INFO)
|
||
|
||
def _symmetric_uniform_quantization(x, nbits, stochastic=False): | ||
assert (torch.isnan(x).sum() == 0) | ||
assert (torch.isinf(x).sum() == 0) | ||
|
||
c = torch.max(torch.abs(x)) | ||
s = c / (2**(nbits - 1) - 1) | ||
if s == 0: | ||
return x, s | ||
c_minus = c * -1.0 | ||
|
||
# qx = torch.where(x.ge(c), c, x) | ||
# qx = torch.where(qx.le(c_minus), c_minus, qx) | ||
# qx.div_(s) | ||
qx = x / s | ||
|
||
if stochastic: | ||
noise = qx.new(qx.shape).uniform_(-0.5, 0.5) | ||
qx.add_(noise) | ||
|
||
qx.clamp_(-(2**(nbits - 1) - 1), (2**(nbits - 1) - 1)).round_() | ||
return qx, s | ||
|
||
|
||
def symmetric_uniform_quantization(state_dict, nbits=8):
    """
    Apply symmetric uniform quantization to the weights of conv/fc layers.

    Args:
        state_dict: dict of model parameters (torch_model.state_dict)
        nbits: bit-width of the quantized values, chosen from [8, 16]
    Returns:
        A dict in which each quantized entry `<layer>.weight` is replaced by
        `<layer>.weight_quant` (integer tensor) plus `<layer>.weight_scale`
        (its scale factor); every other entry is kept unchanged.
    """
    # Map the supported bit-widths to their storage dtypes; anything else
    # falls back to 8 bits.
    supported_dtypes = {8: torch.int8, 16: torch.int16}
    if nbits not in supported_dtypes:
        logger.info(f'The provided value of nbits ({nbits}) is invalid, and we'
                    f' change it to 8')
        nbits = 8
    quant_data_type = supported_dtypes[nbits]

    quantized = dict()
    for name, param in state_dict.items():
        if name.split('.')[-1] == 'weight' and ('fc' in name
                                                or 'conv' in name):
            q_param, scale = _symmetric_uniform_quantization(param,
                                                             nbits=nbits)
            quantized[name.replace('weight', 'weight_quant')] = \
                q_param.type(quant_data_type)
            quantized[name.replace('weight', 'weight_scale')] = scale
        else:
            quantized[name] = param

    return quantized
|
||
|
||
def symmetric_uniform_dequantization(state_dict):
    """
    Undo symmetric uniform quantization on a quantized state dict.

    Args:
        state_dict: dict of model parameters (torch_model.state_dict)
    Returns:
        A dict where each pair `<layer>.weight_quant` / `<layer>.weight_scale`
        is collapsed back into a float `<layer>.weight`; untouched entries are
        copied through.
    """
    restored = dict()
    for name, value in state_dict.items():
        # Scale factors are consumed together with their quantized weights.
        if 'weight_scale' in name:
            continue
        if 'weight_quant' in name:
            scale = state_dict[name.replace('weight_quant', 'weight_scale')]
            restored[name.replace('weight_quant', 'weight')] = value * scale
        else:
            restored[name] = value

    return restored
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
|
||
from federatedscope.core.configs.config import CN | ||
from federatedscope.register import register_config | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def extend_compression_cfg(cfg):
    """
    Extend the global config with compression-related options
    (for communication efficiency).
    """
    cfg.quantization = CN()

    # Quantization scheme applied to shared messages.
    cfg.quantization.method = 'none'  # ['none', 'uniform']
    # Bit-width of the quantized values.
    cfg.quantization.nbits = 8  # [8,16]

    # Register the corresponding sanity-check function.
    cfg.register_cfg_check_fun(assert_compression_cfg)
|
||
|
||
def assert_compression_cfg(cfg):
    """
    Validate (and sanitize) the quantization-related configurations.

    An unknown quantization method is reset to 'none' with a warning.

    Args:
        cfg: the global configuration object
    Raises:
        ValueError: if quantization is enabled with an nbits not in [8, 16]
    """
    if cfg.quantization.method.lower() not in ['none', 'uniform']:
        logger.warning(
            f'Quantization method is expected to be one of ["none",'
            f'"uniform"], but got "{cfg.quantization.method}". So we '
            f'change it to "none"')
        # Bug fix: the warning above promised to fall back to 'none', but the
        # config was never actually updated, so an invalid method leaked
        # through to the rest of the pipeline.
        cfg.quantization.method = 'none'

    if cfg.quantization.method.lower(
    ) != 'none' and cfg.quantization.nbits not in [8, 16]:
        raise ValueError(f'The value of cfg.quantization.nbits is invalid, '
                         f'which is expected to be one of [8, 16] but got '
                         f'{cfg.quantization.nbits}.')
|
||
|
||
register_config("compression", extend_compression_cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.