Skip to content

Commit

Permalink
[Torch] Do not change state of the compression parameter (#2670)
Browse files Browse the repository at this point in the history
### Changes

The `_compression_lr_multiplier` attribute, introduced in #2531, is removed
from the `CompressionParameter`.

### Reason for changes

The `_compression_lr_multiplier` attribute makes the `CompressionParameter` a
stateful parameter, which — for reasons not yet fully understood — does not work
properly in distributed/data-parallel mode.


### Tests

torch_nightly/213/ - finished successfully
  • Loading branch information
daniil-lyakhov authored May 8, 2024
1 parent d507585 commit 359447b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
5 changes: 0 additions & 5 deletions nncf/torch/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,7 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre
"""
super().__init__()

self._compression_lr_multiplier = compression_lr_multiplier
if compression_lr_multiplier is not None and self.dtype.is_floating_point:
self.requires_grad = True
self.register_hook(lambda grad: compression_lr_multiplier * grad)
self.requires_grad = requires_grad

@property
def compression_lr_multiplier(self):
return self._compression_lr_multiplier
3 changes: 2 additions & 1 deletion nncf/torch/sparsity/rb/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multipli
requires_grad=not self.frozen,
compression_lr_multiplier=compression_lr_multiplier,
)
self._compression_lr_multiplier = compression_lr_multiplier
self.binary_mask = binary_mask(self._mask)
self.register_buffer("uniform", torch.zeros(weight_shape))
self.mask_calculation_hook = MaskCalculationHook(self)
Expand All @@ -61,7 +62,7 @@ def get_config(self) -> Dict[str, Any]:
return {
self.WEIGHTS_SHAPE_KEY: list(self.mask.shape),
self.FROZEN_KEY: self.frozen,
self.COMPRESSION_LR_MULTIPLIER_KEY: self.mask.compression_lr_multiplier,
self.COMPRESSION_LR_MULTIPLIER_KEY: self._compression_lr_multiplier,
self.EPS_KEY: self.eps,
}

Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_rb_sparsity_mask_serialization():

assert list(recovered_mask.mask.shape) == ref_weights_shape
assert recovered_mask.frozen == ref_frozen
assert recovered_mask.mask.compression_lr_multiplier == ref_compression_lr_multiplier
assert recovered_mask._compression_lr_multiplier == ref_compression_lr_multiplier
assert recovered_mask.eps == ref_eps

assert torch.all(mask.mask == recovered_mask.mask)
Expand Down

0 comments on commit 359447b

Please sign in to comment.