Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Use custom grad accumulation for FP8 params #3623

Merged

Conversation

kaixih
Copy link
Contributor

@kaixih kaixih commented Jan 16, 2024

This pull request introduces a custom data type rule for the FP8 parameters to implement custom gradient accumulation. Specifically, when reusing the FP8 parameters, the autograd will accumulate their gradients. In this case, we aim for the accumulation to be a maximum operation instead of the default addition operation.

@kaixih
Copy link
Contributor Author

kaixih commented Jan 16, 2024

cc. @nluehr @reedwm @zhangqiaorjc

@kaixih
Copy link
Contributor Author

kaixih commented Jan 16, 2024

Also, cc. @mingxu1067

@kaixih
Copy link
Contributor Author

kaixih commented Jan 16, 2024

Note, the failed tests on ValueError: Cannot convert_element_type to dtype=fp8_meta32 is probably because we might have utilized some feature only available from the jax nightly (0.4.24.devXXXX). Currently, Jax is on its 0.4.23. Maybe @mattjj knows better.

@cgarciae
Copy link
Collaborator

@kaixih it looks good but we will have to wait for JAX to push a new release to pypi for tests to pass (according to your comment).

@cgarciae
Copy link
Collaborator

cgarciae commented Feb 6, 2024

@kaixih thanks! Seems that pytest is still broken.

@zhangqiaorjc
Copy link
Member

@kaixih do you mind fixing the CI errors?

@kaixih
Copy link
Contributor Author

kaixih commented Feb 20, 2024

It seems the CI is still on jax 0.4.23 (failed test)

Requirement already satisfied: jax>=0.4.19 in ./venv/lib/python3.10/site-packages (from flax==0.8.1) (0.4.23)
Requirement already satisfied: jaxlib in ./venv/lib/python3.10/site-packages (from flax==0.8.1) (0.4.23)

I re-tested on my machine but using 0.4.24 and then the tests can pass.

pip install --no-deps jax==0.4.24
pip install --no-deps jaxlib==0.4.24

@zhangqiaorjc @cgarciae Can you help check if the jax has already been updated to 0.4.24 or later?

@cgarciae
Copy link
Collaborator

Based on the output of the test run it seems not:

====== test config =======
PYTEST_OPTS: --cov=flax --cov-report=xml --cov-report=term --cov-config=pyproject.toml
RUN_DOCTEST: false
RUN_PYTEST: true
RUN_MYPY: false
RUN_PYTYPE: false
GH_VENV: true
WHICH PYTHON: /home/runner/work/flax/flax/venv/bin/python
jax: 0.4.23
flax: 0.8.1
==========================

@cgarciae
Copy link
Collaborator

As a quick fix maybe add

venv/bin/python3 -m pip install -U jax jaxlib

after this line:

venv/bin/python3 -m pip install -e .[all,testing]

@kaixih
Copy link
Contributor Author

kaixih commented Feb 21, 2024

@cgarciae Do you mean I should add this line in this PR?

@cgarciae
Copy link
Collaborator

Created a PR so you can rebase when merged.

@kaixih kaixih force-pushed the fp8_meta_custom_grad_accumulation branch from 3e31661 to dd004c2 Compare February 26, 2024 19:00
@kaixih
Copy link
Contributor Author

kaixih commented Feb 26, 2024

@zhangqiaorjc It seems all tests pass now. Can you take another look or reassign? Thx.

@zhangqiaorjc
Copy link
Member

i'll let @cgarciae merge this and review the follow up in praxis

@zhangqiaorjc
Copy link
Member

@cgarciae is this PR merge blocked on any internal error?

@copybara-service copybara-service bot merged commit e4282ee into google:main Mar 7, 2024
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants