Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 25, 2024
1 parent 1c06382 commit 40290e0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 120 deletions.
33 changes: 23 additions & 10 deletions tests/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,20 +566,33 @@ def add_ref(
"""
Adds references hooks.
"""
op_address = self._convert_to_op_address(target_type, target_node_name, input_port_id)
op_address = self._convert_to_op_address(
target_type, target_node_name, input_port_id, self._target_model.nncf.replace_modules
)
self._ref_hooks[target_type].update({op_address: ref_hooks})

def _convert_to_op_address(self, target_type: TargetType, target_node_name: str, input_port_id: int) -> Any:
def _convert_to_op_address(
self, target_type: TargetType, target_node_name: str, input_port_id: int, replace_modules: bool
) -> Any:
address_map = self._target_model.nncf.get_node_to_op_address_mapping()
address = address_map[target_node_name]
if target_type == TargetType.OPERATOR_PRE_HOOK:
address = PreHookId(address, input_port_id)
elif target_type in [
TargetType.OPERATION_WITH_WEIGHTS,
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
address = getattr(self._target_model, self._nncf_module_attr_name)
if replace_modules:
if target_type == TargetType.OPERATOR_PRE_HOOK:
address = PreHookId(address, input_port_id)
elif target_type in [
TargetType.OPERATION_WITH_WEIGHTS,
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
address = getattr(self._target_model, self._nncf_module_attr_name)
else:
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
address = PreHookId(address, input_port_id)
elif target_type in [
TargetType.PRE_LAYER_OPERATION,
TargetType.POST_LAYER_OPERATION,
]:
address = getattr(self._target_model, self._nncf_module_attr_name)
return address

def check_with_reference(self):
Expand Down
152 changes: 42 additions & 110 deletions tests/torch/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from tests.torch.ptq.test_ptq_params import ToNNCFNetworkInterface

IDENTITY_NODE_NAME = "PTIdentityConvModel/__add___0"
CONV_NODE_NAME = "PTIdentityConvModel/NNCFConv2d[conv]/conv2d_0"
CONV_NODE_NAME = "PTIdentityConvModel/Conv2d[conv]/conv2d_0"
INPUT_SHAPE = [1, 3, 3, 3]


Expand All @@ -55,9 +55,6 @@ def get_nncf_network(self):
return get_nncf_network(self, INPUT_SHAPE)


MinMaxTestParameters = TemplateTestStatisticsAggregator.MinMaxTestParameters


class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
@staticmethod
def get_min_max_algo_backend_cls() -> Type[PTMinMaxAlgoBackend]:
Expand Down Expand Up @@ -97,7 +94,7 @@ def get_target_point(target_type: TargetType):
port_id = 0
if target_type == TargetType.OPERATION_WITH_WEIGHTS:
target_node_name = CONV_NODE_NAME
port_id = None
port_id = 1
return PTMinMaxAlgoBackend.target_point(target_type, target_node_name, port_id)

def get_target_point_cls(self):
Expand Down Expand Up @@ -134,169 +131,98 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_
pass

@pytest.mark.parametrize(
"test_parameters",
"target_point",
(
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATOR_PRE_HOOK,
QuantizationMode.SYMMETRIC,
False,
256,
-256,
),
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATION_WITH_WEIGHTS,
QuantizationMode.SYMMETRIC,
False,
256,
-256,
),
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATOR_POST_HOOK,
QuantizationMode.SYMMETRIC,
False,
256,
-256,
),
PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, IDENTITY_NODE_NAME, input_port_id=0),
PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, CONV_NODE_NAME, input_port_id=1),
PTTargetPoint(TargetType.OPERATOR_POST_HOOK, IDENTITY_NODE_NAME, input_port_id=None),
),
)
def test_successive_statistics_aggregation(
self,
test_parameters: MinMaxTestParameters,
target_point: PTTargetPoint,
dataset_samples,
inplace_statistics,
is_backend_support_custom_estimators,
mocker,
):
is_stat_in_shape_of_scale = True
model = self.get_backend_model(dataset_samples)
quantizer_config = QuantizerConfig(
mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel
)

is_standard_estimator = test_parameters.range_estimator_params in [
RangeEstimatorParametersSet.MINMAX,
RangeEstimatorParametersSet.MEAN_MINMAX,
]
if not is_standard_estimator and not is_backend_support_custom_estimators:
pytest.skip("Custom estimators are not supported for this backend yet")
quantizer_config = QuantizerConfig(mode=QuantizationMode.SYMMETRIC, per_channel=False)

### Register operations before statistic collection
def fn(x):
return x * 2

target_point = self.get_target_point(test_parameters.target_type)
target_point = self.get_target_point(target_point.target_type)
model = self.__add_fn_to_model(model, target_point, fn)

### Check hook inserted correctly
self.__check_successive_hooks(test_parameters, model, target_point, fn)
self.__check_successive_hooks(model, target_point, fn)

### Register and collect statistics after inserted operations
statistic_points = self.__get_statistic_points(
test_parameters, model, quantizer_config, dataset_samples, inplace_statistics, mocker
target_point, model, quantizer_config, dataset_samples, inplace_statistics, mocker
)
tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples)
### Check values are changed because of the inserted operation
self.__check_collector(
test_parameters,
tensor_collector,
is_stat_in_shape_of_scale,
)
self.__check_collector(target_point, tensor_collector, is_stat_in_shape_of_scale, -256, 256)

### Check the inserted operation is inside the model
self.__check_successive_hooks(test_parameters, model, target_point, fn)
self.__check_successive_hooks(model, target_point, fn)

@pytest.mark.parametrize(
"test_parameters, nested_target_node_name",
"target_point, nested_target_node_name",
(
(
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATOR_PRE_HOOK,
QuantizationMode.SYMMETRIC,
False,
512,
-512,
),
PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, IDENTITY_NODE_NAME, input_port_id=0),
"PTIdentityConvModel/fn_0",
),
(
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATION_WITH_WEIGHTS,
QuantizationMode.SYMMETRIC,
False,
512,
-512,
),
"PTIdentityConvModel/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/fn_0",
PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, CONV_NODE_NAME, input_port_id=1),
"PTIdentityConvModel/Conv2d[conv]/fn_0",
),
(
MinMaxTestParameters(
RangeEstimatorParametersSet.MINMAX,
TargetType.OPERATOR_POST_HOOK,
QuantizationMode.SYMMETRIC,
False,
512,
-512,
),
PTTargetPoint(TargetType.OPERATOR_POST_HOOK, IDENTITY_NODE_NAME, input_port_id=None),
"PTIdentityConvModel/fn_0",
),
),
)
@pytest.mark.parametrize("nested_target_type", [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK])
def test_nested_statistics_aggregation(
self,
test_parameters: MinMaxTestParameters,
target_point: PTTargetPoint,
nested_target_type: TargetType,
nested_target_node_name,
dataset_samples,
inplace_statistics,
is_backend_support_custom_estimators,
mocker,
):
is_stat_in_shape_of_scale = True
model = self.get_backend_model(dataset_samples)
quantizer_config = QuantizerConfig(
mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel
)

is_standard_estimator = test_parameters.range_estimator_params in [
RangeEstimatorParametersSet.MINMAX,
RangeEstimatorParametersSet.MEAN_MINMAX,
]
if not is_standard_estimator and not is_backend_support_custom_estimators:
pytest.skip("Custom estimators are not supported for this backend yet")
quantizer_config = QuantizerConfig(mode=QuantizationMode.SYMMETRIC, per_channel=False)

### Register operations before statistic collection
@register_operator()
def fn(x):
return x * 2

target_point = self.get_target_point(test_parameters.target_type)
model = self.__add_fn_to_model(model, target_point, fn)
nested_target_point = PTMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0)
model = self.__add_fn_to_model(model, nested_target_point, fn)

### Check hook inserted correctly
self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn)
self.__check_nested_hooks(model, target_point, nested_target_type, nested_target_node_name, fn)

### Register and collect statistics after inserted operations
statistic_points = self.__get_statistic_points(
test_parameters, model, quantizer_config, dataset_samples, inplace_statistics, mocker
target_point, model, quantizer_config, dataset_samples, inplace_statistics, mocker
)
tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples)
### Check values are changed because of the inserted operation
self.__check_collector(
test_parameters,
tensor_collector,
is_stat_in_shape_of_scale,
)
self.__check_collector(target_point, tensor_collector, is_stat_in_shape_of_scale, -512, 512)

### Check the inserted operation is inside the model
self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn)
self.__check_nested_hooks(model, target_point, nested_target_type, nested_target_node_name, fn)

@staticmethod
def __add_fn_to_model(model, target_point, fn):
Expand All @@ -310,10 +236,10 @@ def __add_fn_to_model(model, target_point, fn):

@classmethod
def __get_statistic_points(
cls, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics, mocker
cls, target_point: PTTargetPoint, model, quantizer_config, dataset_samples, inplace_statistics, mocker
) -> StatisticPointsContainer:
statistics_points = StatisticPointsContainer()
for target_type in [test_parameters.target_type]:
for target_type in [target_point.target_type]:
target_point = cls.get_target_point(target_type)
statistic_point = cls.create_statistics_point(
model,
Expand All @@ -322,7 +248,7 @@ def __get_statistic_points(
len(dataset_samples),
"TEST_ALGO",
inplace_statistics,
test_parameters.range_estimator_params,
RangeEstimatorParametersSet.MINMAX,
mocker,
)
statistics_points.add_statistic_point(statistic_point)
Expand All @@ -345,14 +271,16 @@ def __collect_statistics_get_collector(
return tensor_collectors[0][2]

@staticmethod
def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale):
def __check_collector(
target_point: PTTargetPoint, tensor_collector, stat_in_shape_of_scale, ref_min_val, ref_max_val
):
stat = tensor_collector.get_statistics()
# Torch and Openvino backends tensor collectors return values in shape of scale
# in comparison to ONNX backends.
ref_min_val, ref_max_val = test_parameters.ref_min_val, test_parameters.ref_max_val

if isinstance(ref_min_val, np.ndarray) and stat_in_shape_of_scale:
shape = (1, 3, 1, 1)
if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS:
if target_point.target_type == TargetType.OPERATION_WITH_WEIGHTS:
shape = (3, 1, 1, 1)
ref_min_val, ref_max_val = map(lambda x: np.reshape(x, shape), (ref_min_val, ref_max_val))

Expand All @@ -367,26 +295,30 @@ def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale)
assert stat.max_values.shape == ref_shape

@staticmethod
def __check_successive_hooks(test_parameters, model, target_point, fn):
def __check_successive_hooks(model, target_point: PTTargetPoint, fn):
checker = HookChecker(model, "conv")
checker.add_ref(
ref_hooks=[fn],
target_type=test_parameters.target_type,
target_type=TargetType.OPERATOR_PRE_HOOK
if target_point.target_type == TargetType.OPERATION_WITH_WEIGHTS
else target_point.target_type,
target_node_name=target_point.target_node_name,
input_port_id=0,
input_port_id=target_point.input_port_id,
)
checker.check_with_reference()

@staticmethod
def __check_nested_hooks(
test_parameters, model, target_point, nested_target_type: TargetType, nested_target_node_name: str, fn
model, target_point: PTTargetPoint, nested_target_type: TargetType, nested_target_node_name: str, fn
):
checker = HookChecker(model, "conv")
checker.add_ref(
ref_hooks=[fn],
target_type=test_parameters.target_type,
target_type=TargetType.OPERATOR_PRE_HOOK
if target_point.target_type == TargetType.OPERATION_WITH_WEIGHTS
else target_point.target_type,
target_node_name=target_point.target_node_name,
input_port_id=0,
input_port_id=target_point.input_port_id,
)
checker.add_ref(
ref_hooks=[fn],
Expand Down

0 comments on commit 40290e0

Please sign in to comment.