Replace jax.lax.select with jnp.where #25
Hi @nlsfnr, thank you for the PR. I think the issue in your code snippet is that we expect the loss scale not to be an integer; in our example we suggest using the half dtype for the loss scale:

    loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))

I think Perhaps one way to have the best of both worlds would be to add a |
Hi @tomhennigan, I accidentally closed the PR because I undid the only commit in it. The new commit enforces the dtype of |
Thanks, this looks good! We'll pull it on our internal copy later this week and a robot account will take care of marking this PR as merged shortly after.
Hi @tomhennigan, I noticed some shortcomings in my latest commits. First, they might break downstream code. One example is the unit tests that I had to change, i.e. Second, and probably more importantly, the change seems to make
I believe this is because jax will call the LMK what you think of this. If you find the two try-blocks in the |
Thanks for the thoughtful comments, I agree that the post_init function is a bit busy now. I think a useful helper function would be the following; it will act as the identity for tracers:

    def as_jax_array(x) -> jax.Array:
        # Tracers and existing arrays pass through unchanged.
        if not isinstance(x, jax.Array):
            x = jnp.asarray(x)
        return x

We can then extract out a helper for the check_floating test:

    def check_floating(name, x):
        if not jnp.issubdtype(x.dtype, jnp.floating):
            warnings.warn(f'Expected floating type for {name}, got {x.dtype}')

Finally, we can re-assign all fields and apply the check in post_init:

    def __post_init__(self):
        object.__setattr__(self, 'loss_scale', as_jax_array(self.loss_scale))
        object.__setattr__(self, 'min_loss_scale', as_jax_array(self.min_loss_scale))
        check_floating('loss_scale', self.loss_scale)
        check_floating('min_loss_scale', self.min_loss_scale)
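To see how these helpers behave, here is a minimal self-contained sketch. `ToyLossScale` is a hypothetical stand-in with a single field, not the library's `DynamicLossScale`, and the snippet assumes a JAX version in which tracers pass `isinstance(x, jax.Array)`.

```python
import dataclasses
import warnings

import jax
import jax.numpy as jnp


def as_jax_array(x) -> jax.Array:
    # Existing arrays (and tracers) are returned unchanged.
    if not isinstance(x, jax.Array):
        x = jnp.asarray(x)
    return x


def check_floating(name, x):
    if not jnp.issubdtype(x.dtype, jnp.floating):
        warnings.warn(f'Expected floating type for {name}, got {x.dtype}')


@dataclasses.dataclass(frozen=True)
class ToyLossScale:  # hypothetical stand-in, not jmp.DynamicLossScale
    loss_scale: jax.Array

    def __post_init__(self):
        # Frozen dataclasses need object.__setattr__ to re-assign fields.
        object.__setattr__(self, 'loss_scale', as_jax_array(self.loss_scale))
        check_floating('loss_scale', self.loss_scale)


# A Python float becomes a floating array; a Python int triggers the warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    float_scale = ToyLossScale(2.0 ** 15)
    int_scale = ToyLossScale(2 ** 15)
messages = [str(w.message) for w in caught]


def double(x):
    y = as_jax_array(x)
    assert y is x  # under jit, x is a tracer and passes through untouched
    return y * 2


doubled = jax.jit(double)(jnp.float32(3.0))
```

The integer case only warns rather than raising, so existing callers that pass integers keep working while being nudged towards a floating dtype.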
Hi @tomhennigan , I've had a deeper look at the issues behind the mysterious The TLDR is that Jax needs to determine the shape of custom PyTrees during compilation. To do this, it passes raw To avoid this, they recommend checking if A notable alternative to all of this would be to convert |
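As a rough illustration of the kind of guard discussed here, the sketch below skips conversion when a field holds a bare `object()` placeholder or `None`. The class `ToyScale` and the exact form of the check are assumptions made for this example; they are not the code that was merged.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyScale:  # hypothetical example class, not jmp.DynamicLossScale
    loss_scale: object

    def __post_init__(self):
        value = self.loss_scale
        # Placeholders may appear when the tree structure is rebuilt;
        # leave them alone instead of calling jnp.asarray on them.
        if type(value) is object or value is None:
            return
        object.__setattr__(self, 'loss_scale', jnp.asarray(value))

    def tree_flatten(self):
        # One child (the scale), no static auxiliary data.
        return (self.loss_scale,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


scale = ToyScale(2.0 ** 15)

# Rebuilding the tree with object() leaves goes through __post_init__
# without raising, because the guard skips the conversion.
placeholder_tree = jax.tree_util.tree_map(lambda _: object(), scale)

# Ordinary use under jit still converts and computes as expected.
halved = jax.jit(lambda s: ToyScale(s.loss_scale / 2))(scale)
```

Without the guard, the `tree_map` call above would fail inside `jnp.asarray`, since a bare `object()` cannot be converted to an array.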
Thank you for digging into this, I did not realise Looks like the unit tests pass, so we'll pull this into our internal repo and run tests to check internal usages pass. Should be merged soon 😄 |
-- 21900d0 by Nicolas Forstner <[email protected]>: enforce floating dtype for loss_scale and min_loss_scale in DynamicLossScale -- 4d9ce3d by Nicolas Forstner <[email protected]>: added unittests FUTURE_COPYBARA_INTEGRATE_REVIEW=#25 from nlsfnr:main 4d9ce3d PiperOrigin-RevId: 502849580
Thanks for the awesome work!
This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale. The exception's stack trace ends with:
My code looks similar to
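For illustration only (this is not the author's snippet), here is a minimal sketch of the dtype mismatch the PR title refers to. The variable names are hypothetical, and the values simply stand in for an integer loss scale meeting a floating-point one.

```python
import jax
import jax.numpy as jnp

pred = jnp.asarray(True)
int_scale = jnp.asarray(2 ** 15, dtype=jnp.int32)        # integer loss scale
float_scale = jnp.asarray(2.0 ** 14, dtype=jnp.float32)  # floating-point value

# jax.lax.select does not promote dtypes, so mixed operands are rejected.
try:
    jax.lax.select(pred, int_scale, float_scale)
    select_failed = False
except TypeError:
    select_failed = True

# jnp.where applies the usual type promotion, so the same call succeeds
# and the result takes the promoted (floating) dtype.
result = jnp.where(pred, int_scale, float_scale)
```

The trade-off is that `jnp.where` silently promotes, which is why the thread also discusses enforcing or checking a floating dtype up front.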