optax.multi_transform + nnx.State / nnx.Optimizer troubles
#3955
Labels: Priority: P1 - soon
Discussed in #3954
Originally posted by yklcs June 1, 2024
`optax.multi_transform` defines multiple transforms with a `Mapping[Hashable, GradientTransformation]` and uses a PyTree or a function to map each parameter to one of the keys. Using `optax.multi_transform` with `nnx.Optimizer` means this parameter-to-key mapping has to be an `nnx.State`. However, `nnx.State` is typed to hold `StateLeaf` values, so we can't use string or integer keys as its leaves. Ignoring the type annotations does work, but it feels brittle and might break later.
Is there any other solution for this problem?