Skip to content

Commit

Permalink
[PT][WC] Add WC tests for functional models (openvinotoolkit#2446)
Browse files Browse the repository at this point in the history
### Changes

Fixed a bug in the `channel_idx` computation for Torch convolution metatypes
Fixed a bug in `get_module_by_name()` when the base model stores weights
directly as a `torch.nn.Parameter`

### Reason for changes

These fixes are required so that weight compression works on functional models
(models that apply `torch.nn.functional` ops to bare `torch.nn.Parameter`
weights instead of using `torch.nn` modules), as exercised by the new test.

### Related tickets

124822

### Tests

Added `test_compress_weights_functional_model`
  • Loading branch information
l-bat authored Feb 5, 2024
1 parent 1391667 commit 2872f27
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
12 changes: 10 additions & 2 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@

def split_weight_name(weight_name: str) -> Tuple[str, str]:
    """Split a dotted weight name into (module path, attribute name).

    The split happens at the last dot; a name with no dot yields an empty
    module path and the full name as the attribute.
    """
    module_name, _, weight_attr_name = weight_name.rpartition(".")
    return module_name, weight_attr_name


def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module:
if not module_name:
return model
curr_module = model
for name in module_name.split("."):
for child_name, child_module in curr_module.named_children():
Expand Down Expand Up @@ -161,8 +165,12 @@ def get_channel_agnostic_reduction_axes(
elif weight_port_id == 2:
reduction_axes = [max(0, ndims - 2)]
elif node_with_weight.metatype in PTWeightCompressionAlgoBackend.CONVOLUTION_METATYPES:
layer_attributes = node_with_weight.layer_attributes
channel_idx = layer_attributes.get_target_dim_for_compression()
channel_idx = (
1
if node_with_weight.metatype
in [om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype]
else 0
)
reduction_axes = [i for i in range(ndims) if i != channel_idx]
return tuple(reduction_axes)

Expand Down
41 changes: 41 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from nncf import SensitivityMetric
from nncf.quantization import compress_weights
from nncf.torch import wrap_model
from nncf.torch.quantization.layers import WeightsDecompressor

DATA_BASED_SENSITIVITY_METRICS = (
SensitivityMetric.HESSIAN_INPUT_ACTIVATION,
Expand Down Expand Up @@ -52,6 +53,32 @@ def forward(self, input_ids):
return res


class NestedMatMul(torch.nn.Module):
    """Submodule holding a single weight parameter applied via matrix multiplication."""

    def __init__(self):
        super().__init__()
        weight = torch.ones(size=(300, 300), dtype=torch.float32)
        self.w = torch.nn.Parameter(weight)

    def forward(self, input):
        # Equivalent to `input @ self.w`.
        return torch.matmul(input, self.w)


class FunctionalModel(torch.nn.Module):
    """Model mixing functional ops (matmul, conv2d, conv_transpose2d) on bare
    parameters with a nested `torch.nn.Module` submodule."""

    def __init__(self):
        super().__init__()
        self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32))
        self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32))
        self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3)))
        self.nested_matmul = NestedMatMul()

    def forward(self, input_):
        # Cast to float32, then chain functional and module-based ops.
        out = input_.to(torch.float32)
        out = torch.matmul(out, self.matmul_w)
        out = self.nested_matmul(out)
        out = F.conv2d(out, self.conv_w)
        out = F.conv_transpose2d(out, self.conv_tr_w)
        return out


class ConvolutionModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -98,6 +125,20 @@ def test_compress_weights():
assert n_compressed_weights == n_target_modules


def test_compress_weights_functional_model():
    """Weight compression should handle models whose weights are bare parameters
    used through functional ops (see FunctionalModel)."""
    model = FunctionalModel()

    example_input = torch.randint(0, 10, [1, 3, 300, 300])
    wrapped_model = wrap_model(model, example_input=example_input, trace_parameters=True)
    compressed_model = compress_weights(wrapped_model)

    # One WeightsDecompressor external op is registered per compressed weight.
    n_compressed_weights = sum(
        isinstance(op, WeightsDecompressor) for op in compressed_model.nncf.external_op.values()
    )
    assert n_compressed_weights == 4


def test_compress_weights_conv():
model = ConvolutionModel()

Expand Down

0 comments on commit 2872f27

Please sign in to comment.