Skip to content

Commit

Permalink
Add snip_momentum structured pruning which can support higher sparse …
Browse files Browse the repository at this point in the history
…ratio with minor accuracy loss (microsoft#3300)

Signed-off-by: Tian, Feng <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
ftian1 and tjruwase authored May 10, 2023
1 parent b31b46c commit 6938c44
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 12 deletions.
25 changes: 25 additions & 0 deletions deepspeed/compression/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import os
import json

try:
import neural_compressor as nc
except ImportError as e:
nc = None


def check_deepspeed_config(config):
if isinstance(config, dict):
Expand Down Expand Up @@ -117,6 +122,26 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)

# For sparse pruning snip_momentum method
shared_parameters = compress_methods[SPARSE_PRUNING][SHARED_PARAMETERS]
if shared_parameters[SPARSE_PRUNING_ENABLED] and \
shared_parameters[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:

assert nc is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"

from .helper import generate_pruners, register_on_step_begin
from nc import WeightPruningConfig

config = WeightPruningConfig(target_sparsity=1 - shared_parameters[SPARSE_PRUNING_DENSE_RATIO],
pattern=shared_parameters[SPARSE_PRUNING_BLOCK_PATTERN],
pruning_frequency=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE],
start_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET],
end_step=shared_parameters[SPARSE_PRUNING_SCHEDULE_OFFSET_END],
excluded_op_names=shared_parameters[SPARSE_PRUNING_EXCLUDED_MODULES])
pruners = generate_pruners(config, c_model)
c_model.pruners = pruners
register_on_step_begin(c_model)

return model


Expand Down
28 changes: 23 additions & 5 deletions deepspeed/compression/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .constants import *
import copy
from ..runtime.config_utils import get_scalar_param
from ..runtime.config_utils import get_scalar_param, get_list_param


def get_compression_config(param_dict):
Expand Down Expand Up @@ -221,26 +221,44 @@ def get_sparse_pruning(param_dict):
# shared parameters
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED] and output[SHARED_PARAMETERS][
SPARSE_PRUNING_METHOD] != SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
), f"Sparse Pruning is enabled and not snip_momentum method, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output


def get_sparse_pruning_shared_parameters(param_dict):
output = {}

if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK, SPARSE_PRUNING_METHOD_SNIP_MOMENTUM
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}, {SPARSE_PRUNING_METHOD_SNIP_MOMENTUM}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
if output[SPARSE_PRUNING_METHOD] == SPARSE_PRUNING_METHOD_SNIP_MOMENTUM:
output[SPARSE_PRUNING_BLOCK_PATTERN] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_BLOCK_PATTERN,
SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT)
output[SPARSE_PRUNING_DENSE_RATIO] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_DENSE_RATIO,
SPARSE_PRUNING_DENSE_RATIO_DEFAULT)
assert output[SPARSE_PRUNING_DENSE_RATIO] > 0 and output[
SPARSE_PRUNING_DENSE_RATIO] < 1, f"Invalid dense_ratio value. Must be less than 1"
output[SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE] = get_scalar_param(
sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE, SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT)
output[SPARSE_PRUNING_EXCLUDED_MODULES] = get_list_param(sub_param_dict, SPARSE_PRUNING_EXCLUDED_MODULES,
SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT)
output[SPARSE_PRUNING_SCHEDULE_OFFSET_END] = get_scalar_param(sub_param_dict,
SPARSE_PRUNING_SCHEDULE_OFFSET_END,
output[SPARSE_PRUNING_SCHEDULE_OFFSET])
assert output[SPARSE_PRUNING_SCHEDULE_OFFSET] <= output[
SPARSE_PRUNING_SCHEDULE_OFFSET_END], f"Invalid schedule_offset and schedule_offset_end values"
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
Expand Down
15 changes: 15 additions & 0 deletions deepspeed/compression/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DIFFERENT_GROUPS = "different_groups"
TECHNIQUE_ENABLED = "enabled"
TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset"
TECHNIQUE_SCHEDULE_OFFSET_END = "schedule_offset_end"
DIFFERENT_GROUPS_PARAMETERS = "params"
DIFFERENT_GROUPS_MODULE_SCOPE = "modules"
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*"
Expand Down Expand Up @@ -111,11 +112,25 @@
SPARSE_PRUNING_METHOD_DEFAULT = "l1"
SPARSE_PRUNING_METHOD_L1 = "l1"
SPARSE_PRUNING_METHOD_TOPK = "topk"
SPARSE_PRUNING_METHOD_SNIP_MOMENTUM = "snip_momentum"

SPARSE_PRUNING_BLOCK_PATTERN = "block_pattern"
SPARSE_PRUNING_BLOCK_PATTERN_DEFAULT = "4x1"

SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE = "schedule_offset_stride"
SPARSE_PRUNING_SCHEDULE_OFFSET_STRIDE_DEFAULT = 1

SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000

SPARSE_PRUNING_SCHEDULE_OFFSET_END = TECHNIQUE_SCHEDULE_OFFSET_END
SPARSE_PRUNING_SCHEDULE_OFFSET_END_DEFAULT = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT

SPARSE_PRUNING_DENSE_RATIO = "dense_ratio"
SPARSE_PRUNING_DENSE_RATIO_DEFAULT = 0.1

SPARSE_PRUNING_EXCLUDED_MODULES = "excluded_modules"
SPARSE_PRUNING_EXCLUDED_MODULES_DEFAULT = []
###
# Row Pruning
###
Expand Down
74 changes: 74 additions & 0 deletions deepspeed/compression/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
from deepspeed.utils import logger

try:
from neural_compressor.compression import pruner as nc_pruner
except ImportError as e:
nc_pruner = None


def recursive_getattr(model, module_name):
Expand Down Expand Up @@ -246,3 +252,71 @@ def convert_conv1d_to_linear(model, convert_type):
recursive_setattr(c_model, name, new_module)

return model


def generate_pruners(config, model):
"""Generate pruners.
Args:
config (`neural_compressor.WeightPruningConfig`)
The object to the class WeightPruningConfig.
model (`torch.nn.module`)
The torch module object to be pruned.
"""
assert nc_pruner is not None, "please ensure the neural_compressor python package is installed by pip or conda if user wants to use snip_momentum sparse pruning"
from nc_pruner.utils import process_config, parse_to_prune
from nc_pruner.pruners import get_pruner
assert isinstance(model, torch.nn.Module)
pruners_info = process_config(config)
pruners = []
for info in pruners_info:
modules = parse_to_prune(info, model)
if modules == {}:
logger.warning("one pruner hooks no layers, please have a check")

pruners.append(get_pruner(info, modules))
info['modules'] = [key for key in modules.keys()]
info['len_of_modules'] = len(info['modules'])
logger.info(info)
return pruners


def register_on_step_begin(model):
"""Mount on_step_begin to the model.
Args:
model (`torch.nn.module`)
The torch module object to be pruned.
"""

def hook(module, input):
for pruner in module.pruners:
pruner.on_step_begin(0)

hook_handle = model.register_forward_pre_hook(hook)
return hook_handle


def rewrite_optimizer_step(opt: torch.optim.Optimizer):
"""Mount on_before/after_optimizer_step to the optimizer.
Args:
model (`torch.opt.Optimizer`)
The torch optimizer object to be hooked.
"""

def new_step(self, closure=None):
if hasattr(self, "pruners"):
for pruner in self.pruners:
pruner.on_before_optimizer_step()

if closure is not None:
res = self.orig_step(closure)
else:
res = self.orig_step()
if hasattr(self, "pruners"):
for pruner in self.pruners:
pruner.on_after_optimizer_step()
return res

opt.orig_step = opt.step
import types
opt.step = types.MethodType(new_step, opt)
return opt
4 changes: 3 additions & 1 deletion deepspeed/compression/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def check_sparse_pruning(self):
return
else:
shared_parameters = sp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
if self.training_steps >= shared_parameters[
TECHNIQUE_SCHEDULE_OFFSET] and self.training_steps <= shared_parameters[
TECHNIQUE_SCHEDULE_OFFSET_END]:
for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
Expand Down
6 changes: 6 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ def __init__(
elif self.bfloat16_enabled():
self.optimizer = self._configure_bf16_optimizer(optimizer=None)

# Hook optimizer for snip_momentum pruning
if hasattr(model, 'pruners'):
from ..compression.helper import rewrite_optimizer_step
self.optimizer.pruners = model.pruners
rewrite_optimizer_step(self.optimizer)

# Bookkeeping for sparse support
self.sparse_tensor_module_names = set()
# if self.sparse_gradients_enabled():
Expand Down
27 changes: 26 additions & 1 deletion docs/_pages/config-json.md
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,25 @@ Different quantization sets, this is used for different quantization parameters.
}
```

```json
"compression_training": {
"sparse_pruning":{
"shared_parameters":{
"enabled": true,
"schedule_offset": 30,
"schedule_offset_end": 90,
"schedule_offset_stride": 15,
"method": "snip_momentum",
"block_pattern": "4x1",
"dense_ratio": 0.4,
"excluded_modules": ['classifier', 'pooler']
},
"different_groups":{
}
}
}
```

<i>**shared_parameters**</i>: [dictionary]

Shared parameters for all sparse pruning groups.
Expand All @@ -1443,11 +1462,17 @@ Shared parameters for all sparse pruning groups.
| ----- | ----- | ----- |
| <i>**enabled**</i>: [boolean] | Enable sparse pruning or not. | `false` |
| <i>**schedule_offset**</i>: [integer] | Enable sparse pruning after scheduled steps (can be treated as warmup steps). | `0` |
| <i>**method**</i>: [string] | Choose different pruning methods, l1 (static, magnitude based) or topk (dynamic, learnable). | `"l1"` |
| <i>**schedule_offset_end**</i>: [integer] | Disable sparse pruning after scheduled steps, mandotory for `snip_momentum`. | `0` |
| <i>**schedule_offset_stride**</i>: [integer] | The stride of pruning on training steps, mandotory for `snip_momentum`. | `"1"` |
| <i>**method**</i>: [string] | Choose different pruning methods, l1 (static, magnitude based), topk (dynamic, learnable) or snip_momentum (structured pruning). | `"l1"` |
| <i>**block_pattern**</i>: [string] | Choose different structured pruning block patterns, NxM or N:M (N and M are integers). For instance, "4x1" or "2:4" are common block patterns, mandotory for `snip_momentum`. | `"4x1"` |
| <i>**dense_ratio**</i>: [float] | Used to get the targeted global sparsity ratio, mandotory for `snip_momentum`. | `"0.1"` |
| <i>**excluded_modules**</i>: [list] | Excluded pruning scope on some special modules like output layer. | `[]` |

<i>**different_groups**</i>: [dictionary]

Different pruning sets, this is used for different pruning parameters. In this example, we give one set. In practice, you can choose the number of sets based on your requirements.
Note for `snip_momentum` method, you can leave it as empty.

| Fields | Value | Default |
| ----- | ----- | ----- |
Expand Down
12 changes: 7 additions & 5 deletions docs/_tutorials/model-compression.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ Pruning aims to reduce the number of parameters and operations involved in gener

| **Method** | **Type** |
| --------------------- | ------------ |
| [Sparse pruning](#141-sparse-pruning) | Unstructured |
| [Sparse pruning](#141-sparse-pruning) | Unstructured and Structured |
| [Row pruning](#142-row-pruning) | Structured |
| [Head pruning](#143-head-pruning) | Structured |
| [Channel pruning](#144-channel-pruning) | Structured |

#### 1.4.1 Sparse Pruning
**What is sparse pruning**

Sparse pruning means we set some of the elements in each weight matrix with zero values. There is no structure pattern in the zero values. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626).
Sparse pruning means we set some of the elements in each weight matrix with zero values. Relying on the pruning method user chosen, the zero values may have structured pattern or unstructured pattern. One way to perform pruning is based on the absolute value of the weight parameters, see for instance [this paper](https://arxiv.org/abs/1506.02626). Another way to perform pruning is based on the weights' effect to the loss function when they are masked, see for instance [this paper](https://arxiv.org/abs/1810.02340).

**When to use sparse pruning**

Expand All @@ -178,11 +178,13 @@ Sparse pruning can be enabled and configured using the DeepSpeed config JSON fil

(1)`schedule_offset`, we empirically find that when using `method: topk`, it’s better to set the `schedule_offset` to a large value such as 10% of the total training steps.

(2)`method`, we support L1 norm and topk methods. Users are welcome to contribute more methods.
(2)`method`, we support L1 norm, topk and snip_momentum methods. Users are welcome to contribute more methods.

(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc.
(3)`sp1`, users can expand more groups such as `sp2`, `sp3`, etc. Note this is not needed for snip_momentum method.

(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet.
(4)`dense_ratio`, for unstructured sparse pruning, the dense ratio could be less than 0.1 for BRET-base model while still yielding a good accuracy. For ResNet-50, the dense ratio could be as low as 0.3 while still having good accuracy on ImageNet. for structured sparse pruning like snip_momentum, the dense ratio should be specified in shared_parameters and is used to calculate the global sparsity ratio.

(5)`frequency`, `block_pattern` and `schedule_offset_end`, they are used to specify the pruning frequency on steps, the block-wise pruning pattern (NxM and N in M), and the end steps for pruning. For snip_momentum method, these configurations are mandotory.

The client code change is the same as [weight quantization](#12-weight-quantization).

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-sparse_pruning.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
neural-compressor==2.1.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def fetch_requirements(path):
'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt'),
'sd': fetch_requirements('requirements/requirements-sd.txt')
}
Expand Down

0 comments on commit 6938c44

Please sign in to comment.