Skip to content

Commit

Permalink
Compression methods (#555)
Browse files Browse the repository at this point in the history
* add grpc compression

* minor fix

* add uniform quantization

* add unittest

* update unittest

* minor fix
  • Loading branch information
xieyxclack authored Mar 29, 2023
1 parent 5ab598f commit 5e3abad
Show file tree
Hide file tree
Showing 10 changed files with 319 additions and 16 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/test_distribute.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ jobs:
- name: Test Distributed (LR on toy with a unified files)
run: |
python scripts/distributed_scripts/gen_data.py
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server_no_data.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server_no_data.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml distribute.grpc_compression gzip
[ $? -eq 1 ] && exit 1 || echo "Passed"
- name: Test Distributed (LR on toy with multiple files)
run: |
Expand Down
21 changes: 14 additions & 7 deletions federatedscope/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from collections import deque

from federatedscope.core.configs.config import global_cfg
from federatedscope.core.proto import gRPC_comm_manager_pb2, \
gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
Expand Down Expand Up @@ -106,17 +105,23 @@ class gRPCCommManager(object):
The implementation of gRPCCommManager is referred to the tutorial on
https://grpc.io/docs/languages/python/
"""
def __init__(self, host='0.0.0.0', port='50050', client_num=2):
def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
self.host = host
self.port = port
options = [
("grpc.max_send_message_length",
global_cfg.distribute.grpc_max_send_message_length),
("grpc.max_send_message_length", cfg.grpc_max_send_message_length),
("grpc.max_receive_message_length",
global_cfg.distribute.grpc_max_receive_message_length),
("grpc.enable_http_proxy",
global_cfg.distribute.grpc_enable_http_proxy),
cfg.grpc_max_receive_message_length),
("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy),
]

if cfg.grpc_compression.lower() == 'deflate':
self.comp_method = grpc.Compression.Deflate
elif cfg.grpc_compression.lower() == 'gzip':
self.comp_method = grpc.Compression.Gzip
else:
self.comp_method = grpc.Compression.NoCompression

self.server_funcs = gRPCComServeFunc()
self.grpc_server = self.serve(max_workers=client_num,
host=host,
Expand All @@ -132,6 +137,7 @@ def serve(self, max_workers, host, port, options):
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
compression=self.comp_method,
options=options)
gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
self.server_funcs, server)
Expand Down Expand Up @@ -170,6 +176,7 @@ def _create_stub(receiver_address):
https://grpc.io/docs/languages/python/basics/#creating-a-stub
"""
channel = grpc.insecure_channel(receiver_address,
compression=self.comp_method,
options=(('grpc.enable_http_proxy',
0), ))
stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel)
Expand Down
38 changes: 38 additions & 0 deletions federatedscope/core/compression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Message compression for efficient communication

We provide plugins of message compression for efficient communication.

## Lossless compression based on gRPC
When running with distributed mode of FederatedScope, the shared messages can be compressed using the compression module provided by gRPC (More details can be found [here](https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/examples/python/compression/)).

Users can turn on the message compression by adding the following configuration:
```yaml
distribute:
grpc_compression: 'deflate' # or 'gzip'
```
The compression of training ConvNet-2 on FEMNIST is shown as below:
| | NoCompression | Deflate | Gzip |
| :---: | :---: | :---: | :---: |
| Communication bytes per round (in gRPC channel) | 4.021MB | 1.888MB | 1.890MB |
## Model quantization
We provide a symmetric uniform quantization to transform the model parameters (32-bit float) to 8/16-bit int (note that it might bring model performance drop).
To apply model quantization, users need to add the following configurations:
```yaml
quantization:
method: 'uniform'
nbits: 16 # or 8
```
We conduct experiments based on the scripts provided in `federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml` and report the results as:

| | 32-bit float (vanilla) | 16-bit int | 8-bit int |
| :---: | :---: | :---: | :---: |
| Shared model size (in memory) | 25.20MB | 12.61MB | 6.31MB |
| Model performance (acc) | 0.7856 | 0.7854 | 0.6807 |

More fancy compression techniques are coming soon! We greatly appreciate contribution to FederatedScope!
6 changes: 6 additions & 0 deletions federatedscope/core/compression/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from federatedscope.core.compression.utils import \
symmetric_uniform_quantization, symmetric_uniform_dequantization

__all__ = [
'symmetric_uniform_quantization', 'symmetric_uniform_dequantization'
]
84 changes: 84 additions & 0 deletions federatedscope/core/compression/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _symmetric_uniform_quantization(x, nbits, stochastic=False):
assert (torch.isnan(x).sum() == 0)
assert (torch.isinf(x).sum() == 0)

c = torch.max(torch.abs(x))
s = c / (2**(nbits - 1) - 1)
if s == 0:
return x, s
c_minus = c * -1.0

# qx = torch.where(x.ge(c), c, x)
# qx = torch.where(qx.le(c_minus), c_minus, qx)
# qx.div_(s)
qx = x / s

if stochastic:
noise = qx.new(qx.shape).uniform_(-0.5, 0.5)
qx.add_(noise)

qx.clamp_(-(2**(nbits - 1) - 1), (2**(nbits - 1) - 1)).round_()
return qx, s


def symmetric_uniform_quantization(state_dict, nbits=8):
"""
Perform symmetric uniform quantization to weight in conv & fc layers
Args:
state_dict: dict of model parameter (torch_model.state_dict)
nbits: the bit of values after quantized, chosen from [8, 16]
Returns:
The quantized model parameters
"""
if nbits == 8:
quant_data_type = torch.int8
elif nbits == 16:
quant_data_type = torch.int16
else:
logger.info(f'The provided value of nbits ({nbits}) is invalid, and we'
f' change it to 8')
nbits = 8
quant_data_type = torch.int8

quant_state_dict = dict()
for key, value in state_dict.items():
if ('fc' in key or 'conv' in key) and 'weight' == key.split('.')[-1]:
q_weight, w_s = _symmetric_uniform_quantization(value, nbits=nbits)
quant_state_dict[key.replace(
'weight', 'weight_quant')] = q_weight.type(quant_data_type)
quant_state_dict[key.replace('weight', 'weight_scale')] = w_s
else:
quant_state_dict[key] = value

return quant_state_dict


def symmetric_uniform_dequantization(state_dict):
"""
Perform symmetric uniform dequantization
Args:
state_dict: dict of model parameter (torch_model.state_dict)
Returns:
The model parameters after dequantization
"""
dequantizated_state_dict = dict()
for key, value in state_dict.items():
if 'weight_quant' in key:
alpha = state_dict[key.replace('weight_quant', 'weight_scale')]
dequantizated_state_dict[key.replace('weight_quant',
'weight')] = value * alpha
elif 'weight_scale' in key:
pass
else:
dequantizated_state_dict[key] = value

return dequantizated_state_dict
38 changes: 38 additions & 0 deletions federatedscope/core/configs/cfg_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config

logger = logging.getLogger(__name__)


def extend_compression_cfg(cfg):
# ---------------------------------------------------------------------- #
# Compression (for communication efficiency) related options
# ---------------------------------------------------------------------- #
cfg.quantization = CN()

# Params
cfg.quantization.method = 'none' # ['none', 'uniform']
cfg.quantization.nbits = 8 # [8,16]

# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_compression_cfg)


def assert_compression_cfg(cfg):

if cfg.quantization.method.lower() not in ['none', 'uniform']:
logger.warning(
f'Quantization method is expected to be one of ["none",'
f'"uniform"], but got "{cfg.quantization.method}". So we '
f'change it to "none"')

if cfg.quantization.method.lower(
) != 'none' and cfg.quantization.nbits not in [8, 16]:
raise ValueError(f'The value of cfg.quantization.nbits is invalid, '
f'which is expected to be one on [8, 16] but got '
f'{cfg.quantization.nbits}.')


register_config("compression", extend_compression_cfg)
8 changes: 8 additions & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def extend_fl_setting_cfg(cfg):
cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_enable_http_proxy = False
cfg.distribute.grpc_compression = 'nocompression' # [deflate, gzip]

# ---------------------------------------------------------------------- #
# Vertical FL related options (for demo)
Expand Down Expand Up @@ -263,5 +264,12 @@ def assert_fl_setting_cfg(cfg):
f'must be in (0, 1.0], but got '
f'{cfg.vertical.feature_subsample_ratio}')

if cfg.distribute.use and cfg.distribute.grpc_compression.lower() not in [
'nocompression', 'deflate', 'gzip'
]:
raise ValueError(f'The type of grpc compression is expected to be one '
f'of ["nocompression", "deflate", "gzip"], but got '
f'{cfg.distribute.grpc_compression}.')


register_config("fl_setting", extend_fl_setting_cfg)
31 changes: 30 additions & 1 deletion federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def __init__(self,
server_host = kwargs['server_host']
server_port = kwargs['server_port']
self.comm_manager = gRPCCommManager(
host=host, port=port, client_num=self._cfg.federate.client_num)
host=host,
port=port,
client_num=self._cfg.federate.client_num,
cfg=self._cfg.distribute)
logger.info('Client: Listen to {}:{}...'.format(host, port))
self.comm_manager.add_neighbors(neighbor_id=server_id,
address={
Expand Down Expand Up @@ -291,6 +294,18 @@ def callback_funcs_for_model_para(self, message: Message):
sender = message.sender
timestamp = message.timestamp
content = message.content

# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
if isinstance(content, list): # multiple model
content = [
symmetric_uniform_dequantization(x) for x in content
]
else:
content = symmetric_uniform_dequantization(content)

# When clients share the local model, we must set strict=True to
# ensure all the model params (which might be updated by other
# clients in the previous local training process) are overwritten
Expand Down Expand Up @@ -394,6 +409,20 @@ def callback_funcs_for_model_para(self, message: Message):
else:
shared_model_para = model_para_all

# quantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if isinstance(shared_model_para, list):
shared_model_para = [
symmetric_uniform_quantization(x, nbits)
for x in shared_model_para
]
else:
shared_model_para = symmetric_uniform_quantization(
shared_model_para, nbits)

self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
Expand Down
34 changes: 30 additions & 4 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def __init__(self,
port = kwargs['port']
self.comm_manager = gRPCCommManager(host=host,
port=port,
client_num=client_num)
client_num=client_num,
cfg=self._cfg.distribute)
logger.info('Server: Listen to {}:{}...'.format(host, port))

# inject noise before broadcast
Expand Down Expand Up @@ -666,6 +667,20 @@ def broadcast_model_para(self,
else:
model_para = {} if skip_broadcast else self.models[0].state_dict()

# quantization
if msg_type == 'model_para' and not skip_broadcast and \
self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if self.model_num > 1:
model_para = [
symmetric_uniform_quantization(x, nbits)
for x in model_para
]
else:
model_para = symmetric_uniform_quantization(model_para, nbits)

# We define the evaluation happens at the end of an epoch
rnd = self.state - 1 if msg_type == 'evaluate' else self.state

Expand Down Expand Up @@ -815,9 +830,6 @@ def trigger_for_start(self):
logger.info(
'----------- Starting training (Round #{:d}) -------------'.
format(self.state))
print(
time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))

def trigger_for_feat_engr(self,
trigger_train_func,
Expand Down Expand Up @@ -920,6 +932,20 @@ def callback_funcs_model_para(self, message: Message):
content = message.content
self.sampler.change_state(sender, 'idle')

# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
if isinstance(content[1], list): # multiple model
sample_size = content[0]
quant_model = [
symmetric_uniform_dequantization(x) for x in content[1]
]
else:
sample_size = content[0]
quant_model = symmetric_uniform_dequantization(content[1])
content = (sample_size, quant_model)

# update the currency timestamp according to the received message
assert timestamp >= self.cur_timestamp # for test
self.cur_timestamp = timestamp
Expand Down
Loading

0 comments on commit 5e3abad

Please sign in to comment.