
Commit

retrocompatibility fix
Giuseppe5 committed Oct 7, 2024
1 parent 28bd14f commit 8fa0025
Showing 5 changed files with 45 additions and 38 deletions.
1 change: 0 additions & 1 deletion src/brevitas/config.py
@@ -25,4 +25,3 @@ def env_to_bool(name, default):
 _FULL_STATE_DICT = False
 _IS_INSIDE_QUANT_LAYER = None
 _ONGOING_EXPORT = None
-_RETROCOMPATIBLE_SCALING = False
16 changes: 8 additions & 8 deletions src/brevitas/core/restrict_val.py
@@ -90,8 +90,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def retrocompatibility_op(self, x):
-        return x
+    def combine_stats_threshold(self, x, threshold):
+        return x / threshold
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> Tensor:
@@ -116,8 +116,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def retrocompatibility_op(self, x):
-        return self.power_of_two(x)
+    def combine_stats_threshold(self, x, threshold):
+        return x - threshold
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
@@ -143,8 +143,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def retrocompatibility_op(self, x):
-        return x
+    def combine_stats_threshold(self, x, threshold):
+        return x / threshold
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
@@ -171,8 +171,8 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def retrocompatibility_op(self, x):
-        return self.power_of_two(x)
+    def combine_stats_threshold(self, x, threshold):
+        return x - threshold
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
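The new combine_stats_threshold hook applies the runtime threshold in whatever domain the restrictor works in: division for the linear-domain restrictors, subtraction for the log2 (power-of-two) restrictors, since log2(s / t) = log2(s) - log2(t). A minimal standalone sketch of that contract (the helper functions below are illustrative stand-ins, not the Brevitas API):

import torch

def combine_linear(stats, threshold):
    # Linear-domain restrictor (e.g. float restriction): divide directly
    return stats / threshold

def combine_log2(stats, threshold):
    # Log2-domain restrictor (e.g. power-of-two restriction): division becomes
    # subtraction, because log2(s / t) = log2(s) - log2(t)
    return stats - threshold

s = torch.tensor([0.5, 2.0, 8.0])
t = torch.tensor([4.0])
assert torch.allclose(combine_linear(s, t), s / t)
assert torch.allclose(combine_log2(torch.log2(s), torch.log2(t)), torch.log2(s / t))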
5 changes: 4 additions & 1 deletion src/brevitas/core/scaling/runtime.py
@@ -81,13 +81,16 @@ def __init__(
         self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
+        self.restrict_scaling_impl = restrict_scaling_impl
 
     @brevitas.jit.script_method
     def forward(
             self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        stats = self.restrict_scaling_pre(stats / threshold)
+        threshold = self.restrict_scaling_pre(threshold)
+        stats = self.restrict_scaling_pre(stats)
+        stats = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
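Instead of dividing the stats by the threshold before mapping them into the restricted domain, the forward pass now maps stats and threshold separately and combines them there. For a power-of-two restriction the two orderings coincide, since log2(stats / threshold) = log2(stats) - log2(threshold). A quick equivalence check, assuming a log2 pre-restriction and an exp2 mapping back (illustrative stand-ins for restrict_init_module() and the restrict/clamp step, not the actual Brevitas modules):

import torch

def restrict_pre(x):            # stand-in for restrict_scaling_impl.restrict_init_module()
    return torch.log2(x)

def combine(stats, threshold):  # stand-in for combine_stats_threshold() under a Po2 restriction
    return stats - threshold

def back_to_linear(x):          # stand-in for the step that maps back to a linear scale
    return torch.exp2(x)

stats = torch.tensor([0.25, 1.0, 16.0])
threshold = torch.tensor([2.0])

old_path = back_to_linear(restrict_pre(stats / threshold))                        # pre-commit ordering
new_path = back_to_linear(combine(restrict_pre(stats), restrict_pre(threshold)))  # post-commit ordering
assert torch.allclose(old_path, new_path)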
60 changes: 33 additions & 27 deletions src/brevitas/core/scaling/standalone.py
@@ -70,18 +70,27 @@ def __init__(
             scaling_init = scaling_init.to(device=device, dtype=dtype)
             if restrict_scaling_impl is not None:
                 scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+            else:
+                self.restrict_init_module = Identity()
             self.value = StatelessBuffer(scaling_init.detach())
         else:
             if restrict_scaling_impl is not None:
                 scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init)
+                self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+            else:
+                self.restrict_init_module = Identity()
             self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
     def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = self.value() / threshold
-        restricted_value = self.restrict_clamp_scaling(value)
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        restricted_value = self.restrict_clamp_scaling(self.value())
+        restricted_value = restricted_value / threshold
         return restricted_value


@@ -145,6 +154,9 @@ def __init__(
             scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device)
         if restrict_scaling_impl is not None:
             scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init)
+            self.restrict_init_module = restrict_scaling_impl.restrict_init_module()
+        else:
+            self.restrict_init_module = Identity()
         if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None:
             scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device)
         self.value = Parameter(scaling_init)
@@ -154,8 +166,11 @@ def __init__(
     def forward(self, placeholder: Tensor, threshold: Optional[torch.Tensor] = None) -> Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(placeholder)
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value) / threshold)
-        return value
+        # We first apply any restriction to scaling
+        # For IntQuant, this is no-op, retrocompatible.
+        threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold))
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+        return value / threshold
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
@@ -217,18 +232,21 @@ def forward(
         # This is because we don't want to store a parameter dependant on a runtime value (threshold)
         # And because restrict needs to happen after we divide by threshold
         if self.init_done:
-            value = self.restrict_preprocess(self.value / threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(
+                self.value, self.restrict_inplace_preprocess(threshold))
             value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             return value
         else:
             stats = self.parameter_list_stats()
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats, threshold)
+                return self.stats_scaling_impl(stats)
             stats = self.restrict_inplace_preprocess(stats)
+            threshold = self.restrict_inplace_preprocess(threshold)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = self.restrict_preprocess(self.value / threshold)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = self.restrict_scaling_impl.combine_stats_threshold(stats, threshold)
+            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value))
             self.init_done = True
             return value

@@ -245,14 +263,6 @@ def _load_from_state_dict(
             error_msgs):
         value_key = prefix + 'value'
 
-        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
-        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
-        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
-        if config._RETROCOMPATIBLE_SCALING:
-            if not isinstance(self.restrict_scaling_impl, Identity):
-                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
-                    state_dict[value_key])
-
         super(ParameterFromStatsFromParameterScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         # disable stats collection when a pretrained value is loaded
@@ -365,12 +375,15 @@ def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
-            value = self.restrict_preprocess(self.value / threshold)
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             self.counter = self.counter + 1
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
         else:
-            value = self.restrict_preprocess(self.value / threshold)
+            threshold = self.restrict_preprocess(threshold)
+            value = self.restrict_scaling_impl.combine_stats_threshold(value, threshold)
             return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value)))
 
     @brevitas.jit.script_method
@@ -385,7 +398,8 @@ def forward(self, stats_input: Tensor, threshold: Optional[torch.Tensor] = None)
                 out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.restrict_preprocess(self.value / threshold)
+                threshold = self.restrict_preprocess(threshold)
+                out = self.restrict_scaling_impl.combine_stats_threshold(self.value, threshold)
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
         return out

@@ -411,14 +425,6 @@ def _load_from_state_dict(
         if retrocomp_value_key in state_dict:
             state_dict[value_key] = state_dict.pop(retrocomp_value_key)
 
-        # Before, the parameter would be stored after restrict_preprocess (e.g., Log2)
-        # When we load, if retrocompatibility is enabled, we perform the opposite operation (e.g., Po2)
-        # Knowing that during the forward pass we will re-apply restrict_preprocess (e.g., again Log2)
-        if config._RETROCOMPATIBLE_SCALING:
-            if not isinstance(self.restrict_scaling_impl, Identity):
-                state_dict[value_key] = self.restrict_scaling_impl.retrocompatibility_op(
-                    state_dict[value_key])
-
         super(ParameterFromRuntimeStatsScaling, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
         # Buffer is supposed to be always missing
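Net effect for checkpoints: the scaling parameter keeps being stored after restrict_preprocess (e.g. in the log2 domain for power-of-two scaling), while the runtime threshold is now mapped into the same domain and folded in through combine_stats_threshold during the forward pass, so previously saved state dicts load unchanged and neither the _RETROCOMPATIBLE_SCALING flag nor the retrocompatibility_op hook is needed at load time. A toy illustration of that invariant (plain tensors standing in for the Brevitas modules, not the library API):

import torch

stored_value = torch.log2(torch.tensor([8.0]))  # what a checkpoint holds for a Po2-restricted scale
threshold = torch.tensor([2.0])                 # runtime-only quantity, never stored

def forward_scale(stored_value, threshold):
    restricted_threshold = torch.log2(threshold)     # plays the role of restrict_preprocess(threshold)
    combined = stored_value - restricted_threshold   # plays the role of combine_stats_threshold for Po2
    return torch.exp2(combined)                      # map back from the log2 domain to a linear scale

# Same linear scale as dividing before restriction, so old checkpoints keep working.
assert torch.allclose(forward_scale(stored_value, threshold), torch.tensor([8.0]) / threshold)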
1 change: 0 additions & 1 deletion (fifth changed file; path not shown in this view)
@@ -16,7 +16,6 @@
 from brevitas.export import export_qonnx
 from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b
 
-config._RETROCOMPATIBLE_SCALING = True
 QUARTZNET_POSTPROCESSED_INPUT_SIZE = (1, 64, 256) # B, features, sequence
 MIN_INP_VAL = 0.0
 MAX_INP_VAL = 200.0
