ci: install GPU JAX in GPU CI #4293
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough

The pull request modifies the GitHub Actions workflow for CUDA testing: the `test_cuda` job now installs `jax[cuda12]` and sets `XLA_PYTHON_CLIENT_PREALLOCATE=false` before running the tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant GitHub Actions
    participant CUDA Environment
    User->>GitHub Actions: Trigger merge_group event
    GitHub Actions->>CUDA Environment: Run test_cuda job
    CUDA Environment->>CUDA Environment: Install jax[cuda12]
    CUDA Environment->>CUDA Environment: Set XLA_PYTHON_CLIENT_PREALLOCATE=false
    CUDA Environment->>GitHub Actions: Execute tests
    GitHub Actions->>User: Return test results
```
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
.github/workflows/test_cuda.yml (1)

Line 50: Consider pinning the JAX version for better reproducibility.

While the tensorflow and torch versions are pinned, the JAX version is not. This could lead to unexpected behavior if JAX releases a breaking change.
Consider updating to:
```diff
- - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]"
+ - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]~=0.4.23"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
.github/workflows/test_cuda.yml (2 hunks)
🔇 Additional comments (2)
.github/workflows/test_cuda.yml (2)
Lines 64-65: LGTM! Well-documented memory management configuration.

Setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` is a good practice for CI environments, as it prevents JAX from pre-allocating all GPU memory. The documentation reference is helpful for future maintenance.
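As a minimal sketch outside the workflow itself: the flag must be present in the environment before the Python process imports JAX, since the XLA client reads it once at startup (the `printenv` check here merely stands in for the eventual `import jax`):

```shell
#!/bin/sh
# With "false", JAX allocates GPU memory on demand instead of
# reserving a large fraction of device memory up front. The variable
# must be exported before JAX initializes its backend.
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Confirm the child process inherits the setting, as `import jax` would.
printenv XLA_PYTHON_CLIENT_PREALLOCATE   # prints: false
```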
Line range hint 19-20: Verify CUDA version compatibility across dependencies.
There's a minor version mismatch in CUDA versions:
- Container uses CUDA 12.3.1
- PyTorch is downloaded with CUDA 12.4
- JAX is installed with CUDA 12 support
While minor version differences might work, it's better to align all versions for consistency.
Also applies to: 50-50, 71-72
🧰 Tools
🪛 actionlint
51-51: shellcheck reported issue in this script: SC2155:warning:1:8: Declare and assign separately to avoid masking return values (shellcheck)

51-51: shellcheck reported issue in this script: SC2155:warning:2:8: Declare and assign separately to avoid masking return values (shellcheck)

51-51: shellcheck reported issue in this script: SC2102:info:3:61: Ranges can only match single chars (mentioned due to duplicates) (shellcheck)
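For context on SC2155 (this is a self-contained illustration of the warning, not code from the workflow): combining `export` or `local` with a command substitution discards the command's exit status, because `$?` ends up reflecting the builtin rather than the substituted command:

```shell
#!/bin/sh
# SC2155: `export VAR=$(cmd)` masks cmd's exit status, because the
# `export` builtin itself succeeds (status 0) regardless of cmd.

fail3() { return 3; }

# Combined form: $? reflects `export`, not fail3 — always 0.
export MASKED="$(fail3)"
combined_status=$?

# Separated form: a bare assignment propagates fail3's status.
UNMASKED="$(fail3)"
separate_status=$?
export UNMASKED

echo "combined=$combined_status separate=$separate_status"
# prints: combined=0 separate=3
```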