From 232b7f2eedb3780b8f126c8669448e9fd8796167 Mon Sep 17 00:00:00 2001
From: Alessandro Pappalardo
Date: Thu, 15 Jun 2023 18:32:57 +0200
Subject: [PATCH] Feat (nn): cache modules that require subtensor slicing

---
 src/brevitas/nn/mixin/parameter.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/nn/mixin/parameter.py b/src/brevitas/nn/mixin/parameter.py
index 898006444..5ece48325 100644
--- a/src/brevitas/nn/mixin/parameter.py
+++ b/src/brevitas/nn/mixin/parameter.py
@@ -34,6 +34,7 @@ def __init__(self, weight_quant: Optional[WeightQuantType], **kwargs):
             kwargs_prefix='weight_',
             proxy_prefix='weight_',
             **kwargs)
+        self._cached_sub_tensor_slice_list_modules = None
 
     @property
     @abstractmethod
@@ -64,9 +65,17 @@ def quant_weight(
             # prepare the quantizer for a subtensor input, if any modifications are required
             # we set a list of tuples rather than a list of slices so that it's jit friendly
             # slices generation is handled by each module internally
-            for m in self.weight_quant.modules():
-                if hasattr(m, 'subtensor_slice_list'):
-                    m.setattr('subtensor_slice_list', subtensor_slice_list)
+
+            # we cache which modules require the attribute
+            if self._cached_sub_tensor_slice_list_modules is not None:
+                for m in self._cached_sub_tensor_slice_list_modules:
+                    m.subtensor_slice_list = subtensor_slice_list
+            else:
+                self._cached_sub_tensor_slice_list_modules = []
+                for m in self.weight_quant.modules():
+                    if hasattr(m, 'subtensor_slice_list'):
+                        self._cached_sub_tensor_slice_list_modules.append(m)
+                        m.subtensor_slice_list = subtensor_slice_list
             # generate slices for the weight tensor based on the list passed in
             weight_slice_tuple = tuple(
                 slice(*s) if s is not None else slice(s) for s in subtensor_slice_list)
@@ -91,9 +100,10 @@ def quant_weight(
         out = self.weight_quant(self.weight[weight_slice_tuple])
         if subtensor_slice_list is not None:
             # Restore the quantizer behaviour to full tensor quantization
-            for m in self.weight_quant.modules():
-                if hasattr(m, 'subtensor_slice_list'):
-                    m.setattr('subtensor_slice_list', None)
+            # The modules to slice should have been cached already at this point
+            assert self._cached_sub_tensor_slice_list_modules is not None, "Missing cache of modules to slice."
+            for m in self._cached_sub_tensor_slice_list_modules:
+                m.subtensor_slice_list = None
         return out
 
     def int_weight(self, float_datatype=False):
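
Note (not part of the patch): below is a minimal standalone sketch of the caching pattern the patch introduces, for illustration only. The class names FakeQuantizer and QuantWrapper and the method set_slices are hypothetical; only the attribute names subtensor_slice_list and _cached_sub_tensor_slice_list_modules come from the patch. The idea is that the first call walks the quantizer's module tree once to discover which submodules expose the attribute, and every later call reuses that cached list instead of re-traversing the tree.

    import torch.nn as nn

    class FakeQuantizer(nn.Module):
        """Hypothetical stand-in for a quantizer submodule that supports subtensor slicing."""

        def __init__(self):
            super().__init__()
            # None means "quantize the full tensor"; a list of (start, stop) tuples
            # restricts quantization to a slice of the tensor.
            self.subtensor_slice_list = None

    class QuantWrapper(nn.Module):
        """Hypothetical layer illustrating the first-call-populates, later-calls-reuse cache."""

        def __init__(self):
            super().__init__()
            self.weight_quant = nn.Sequential(FakeQuantizer(), nn.Identity())
            self._cached_sub_tensor_slice_list_modules = None

        def set_slices(self, subtensor_slice_list):
            if self._cached_sub_tensor_slice_list_modules is not None:
                # Fast path: reuse the cached list instead of re-walking every submodule.
                for m in self._cached_sub_tensor_slice_list_modules:
                    m.subtensor_slice_list = subtensor_slice_list
            else:
                # First call: discover and cache the modules that expose the attribute.
                self._cached_sub_tensor_slice_list_modules = []
                for m in self.weight_quant.modules():
                    if hasattr(m, 'subtensor_slice_list'):
                        self._cached_sub_tensor_slice_list_modules.append(m)
                        m.subtensor_slice_list = subtensor_slice_list

    layer = QuantWrapper()
    layer.set_slices([(0, 4), None])   # first call walks the module tree and caches
    layer.set_slices(None)             # later calls touch only the cached modules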