diff --git a/src/brevitas/core/function_wrapper/__init__.py b/src/brevitas/core/function_wrapper/__init__.py index bf7c84e60..9d5d929d6 100644 --- a/src/brevitas/core/function_wrapper/__init__.py +++ b/src/brevitas/core/function_wrapper/__init__.py @@ -4,6 +4,7 @@ from .clamp import ClampMin from .clamp import ScalarClamp from .clamp import TensorClamp +from .learned_round import LearnedRoundSte from .misc import Identity from .misc import InplaceLogTwo from .misc import LogTwo @@ -21,3 +22,4 @@ from .shape import OverOutputChannelView from .shape import OverTensorView from .shape import StatsInputViewShapeImpl +from .stochastic_round import StochasticRoundSte diff --git a/src/brevitas/core/function_wrapper/stochastic_round.py b/src/brevitas/core/function_wrapper/stochastic_round.py new file mode 100644 index 000000000..529038d7b --- /dev/null +++ b/src/brevitas/core/function_wrapper/stochastic_round.py @@ -0,0 +1,46 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import torch + +from brevitas.function.ops_ste import round_ste + + +def stochastic_round_ste_fn(generator): + + class StochasticRoundSteFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + floor_x = torch.floor(x) + x_diff = torch.abs(x - floor_x) + prob = torch.bernoulli(x_diff, generator=generator) + out = torch.where(prob.to(torch.bool), floor_x + 1., floor_x) + return out + + @staticmethod + def backward(ctx, x_grad): + return x_grad + + return StochasticRoundSteFn.apply + + +class StochasticRoundSte(torch.nn.Module): + + def __init__(self, deterministic_inference=True, seed=None, device=None) -> None: + super().__init__() + self.generator = torch.Generator(device=device) + if seed is not None: + self.generator.manual_seed(seed) + self.round_fn = stochastic_round_ste_fn(self.generator) + self.deterministic_inference = deterministic_inference + if deterministic_inference: + self.inference_fn = round_ste + else: + self.inference_fn = None + + @torch.jit.ignore + def forward(self, x): + if self.deterministic_inference is not None and not self.training: + return self.inference_fn(x) + return self.round_fn(x) diff --git a/src/brevitas/inject/enum.py b/src/brevitas/inject/enum.py index 4b62d71b8..95ea05355 100644 --- a/src/brevitas/inject/enum.py +++ b/src/brevitas/inject/enum.py @@ -45,6 +45,7 @@ class FloatToIntImplType(AutoName): ROUND_TO_ZERO = auto() DPU = auto() LEARNED_ROUND = auto() + STOCHASTIC_ROUND = auto() class LearnedRoundImplType(AutoName): diff --git a/src/brevitas/quant/experimental/float_base.py b/src/brevitas/quant/experimental/float_base.py index 7c3738226..f1d191057 100644 --- a/src/brevitas/quant/experimental/float_base.py +++ b/src/brevitas/quant/experimental/float_base.py @@ -1,7 +1,6 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -from brevitas.core.function_wrapper import RoundSte from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling.float_scaling import FloatScaling from brevitas.inject import ExtendedInjector @@ -9,20 +8,21 @@ from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector from brevitas.quant.solver import ActQuantSolver from brevitas.quant.solver import WeightQuantSolver +from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum -class FloatWeightBase(ExtendedInjector): +class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum): proxy_class = WeightQuantProxyFromInjector tensor_quant = FloatQuant signed = True - float_to_int_impl = RoundSte + float_to_int_impl_type = 'round' -class FloatActBase(ExtendedInjector): +class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum): proxy_class = ActQuantProxyFromInjector tensor_quant = FloatQuant signed = True - float_to_int_impl = RoundSte + float_to_int_impl_type = 'round' class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver): diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 55b295253..2e5e1e982 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -6,6 +6,7 @@ from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid from brevitas.core.function_wrapper.learned_round import LearnedRoundSte +from brevitas.core.function_wrapper.stochastic_round import StochasticRoundSte from brevitas.core.quant import * from brevitas.core.quant import QuantType from brevitas.core.restrict_val import * @@ -45,6 +46,8 @@ def solve_float_to_int_impl_from_enum(impl_type): return DPURoundSte elif impl_type == FloatToIntImplType.LEARNED_ROUND: return LearnedRoundSte + elif impl_type == FloatToIntImplType.STOCHASTIC_ROUND: + return StochasticRoundSte else: raise Exception(f"{impl_type} not recognized.") diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 5d3f2fcd2..5afed36c0 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -43,6 +43,13 @@ def float_to_int_impl_to_enum(module): elif isinstance(module, CeilSte): return FloatToIntImplType.CEIL elif isinstance(module, DPURoundSte): - return DPURoundSte + return FloatToIntImplType.DPU + elif isinstance(module, LearnedRoundSte): + return FloatToIntImplType.LEARNED_ROUND + elif isinstance(module, StochasticRoundSte): + if module.deterministic_inference: + return FloatToIntImplType.ROUND + else: + return FloatToIntImplType.STOCHASTIC_ROUND else: return None