L1/L2 regularization of network weights with NNX #4160

Answered by cgarciae
pushkar5586 asked this question in Q&A

Hey @pushkar5586, sorry for the delay. To apply global regularization, you can use nnx.state to extract the Params and then follow the recipe from #1654. Here is the basic example from the landing page, set up for L2 regularization:

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  

Replies: 1 comment 2 replies

Answer selected by pushkar5586