Skip to content

Commit

Permalink
Feat (core): support for stochastic round (#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius authored Jul 28, 2023
1 parent 0dc9268 commit 09dce0e
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/core/function_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .clamp import ClampMin
from .clamp import ScalarClamp
from .clamp import TensorClamp
from .learned_round import LearnedRoundSte
from .misc import Identity
from .misc import InplaceLogTwo
from .misc import LogTwo
Expand All @@ -21,3 +22,4 @@
from .shape import OverOutputChannelView
from .shape import OverTensorView
from .shape import StatsInputViewShapeImpl
from .stochastic_round import StochasticRoundSte
46 changes: 46 additions & 0 deletions src/brevitas/core/function_wrapper/stochastic_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from brevitas.function.ops_ste import round_ste


def stochastic_round_ste_fn(generator):

class StochasticRoundSteFn(torch.autograd.Function):

@staticmethod
def forward(ctx, x):
floor_x = torch.floor(x)
x_diff = torch.abs(x - floor_x)
prob = torch.bernoulli(x_diff, generator=generator)
out = torch.where(prob.to(torch.bool), floor_x + 1., floor_x)
return out

@staticmethod
def backward(ctx, x_grad):
return x_grad

return StochasticRoundSteFn.apply


class StochasticRoundSte(torch.nn.Module):

def __init__(self, deterministic_inference=True, seed=None, device=None) -> None:
super().__init__()
self.generator = torch.Generator(device=device)
if seed is not None:
self.generator.manual_seed(seed)
self.round_fn = stochastic_round_ste_fn(self.generator)
self.deterministic_inference = deterministic_inference
if deterministic_inference:
self.inference_fn = round_ste
else:
self.inference_fn = None

@torch.jit.ignore
def forward(self, x):
if self.deterministic_inference and not self.training:
return self.inference_fn(x)
return self.round_fn(x)
1 change: 1 addition & 0 deletions src/brevitas/inject/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class FloatToIntImplType(AutoName):
ROUND_TO_ZERO = auto()
DPU = auto()
LEARNED_ROUND = auto()
STOCHASTIC_ROUND = auto()


class LearnedRoundImplType(AutoName):
Expand Down
10 changes: 5 additions & 5 deletions src/brevitas/quant/experimental/float_base.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper import RoundSte
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling.float_scaling import FloatScaling
from brevitas.inject import ExtendedInjector
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver
from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum


class FloatWeightBase(ExtendedInjector):
class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
proxy_class = WeightQuantProxyFromInjector
tensor_quant = FloatQuant
signed = True
float_to_int_impl = RoundSte
float_to_int_impl_type = 'round'


class FloatActBase(ExtendedInjector):
class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum):
proxy_class = ActQuantProxyFromInjector
tensor_quant = FloatQuant
signed = True
float_to_int_impl = RoundSte
float_to_int_impl_type = 'round'


class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
Expand Down
3 changes: 3 additions & 0 deletions src/brevitas/quant/solver/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid
from brevitas.core.function_wrapper.learned_round import LearnedRoundSigmoid
from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
from brevitas.core.function_wrapper.stochastic_round import StochasticRoundSte
from brevitas.core.quant import *
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import *
Expand Down Expand Up @@ -45,6 +46,8 @@ def solve_float_to_int_impl_from_enum(impl_type):
return DPURoundSte
elif impl_type == FloatToIntImplType.LEARNED_ROUND:
return LearnedRoundSte
elif impl_type == FloatToIntImplType.STOCHASTIC_ROUND:
return StochasticRoundSte
else:
raise Exception(f"{impl_type} not recognized.")

Expand Down
9 changes: 8 additions & 1 deletion src/brevitas/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def float_to_int_impl_to_enum(module):
elif isinstance(module, CeilSte):
return FloatToIntImplType.CEIL
elif isinstance(module, DPURoundSte):
return DPURoundSte
return FloatToIntImplType.DPU
elif isinstance(module, LearnedRoundSte):
return FloatToIntImplType.LEARNED_ROUND
elif isinstance(module, StochasticRoundSte):
if module.deterministic_inference:
return FloatToIntImplType.ROUND
else:
return FloatToIntImplType.STOCHASTIC_ROUND
else:
return None

0 comments on commit 09dce0e

Please sign in to comment.