Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
Add accelerator device arguments and cuda test (#82)
Browse files Browse the repository at this point in the history
* Add accelerator device arguments and cuda test

* Bump scvi req to 1.1.0rc2
  • Loading branch information
martinkim0 authored Jan 31, 2024
1 parent 3b47c90 commit 42612ff
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
64 changes: 64 additions & 0 deletions .github/workflows/test_linux_cuda.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: Test (Linux, CUDA)

on:
pull_request:
branches: [main]
types: [labeled, synchronize, opened]
schedule:
- cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
# if PR has label "cuda tests" or "all tests" or if scheduled or manually triggered
if: >-
(
contains(github.event.pull_request.labels.*.name, 'cuda tests') ||
contains(github.event.pull_request.labels.*.name, 'all tests') ||
contains(github.event_name, 'schedule') ||
contains(github.event_name, 'workflow_dispatch')
)
runs-on: [self-hosted, Linux, X64, CUDA]
defaults:
run:
shell: bash -e {0} # -e to fail on error

strategy:
fail-fast: false
matrix:
python: ["3.11"]
cuda: ["11"]

container:
image: scverse/scvi-tools:py${{ matrix.python }}-cu${{ matrix.cuda }}-base
options: --user root --gpus all

name: Integration (CUDA)

env:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python }}

steps:
- uses: actions/checkout@v4

- name: Install dependencies
run: |
pip install ".[dev,test]"
- name: Test
env:
MPLBACKEND: agg
PLATFORM: ${{ matrix.os }}
DISPLAY: :42
run: |
coverage run -m pytest -v --color=yes
- name: Report coverage
run: |
coverage report
- name: Upload coverage
uses: codecov/codecov-action@v3
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ urls.Documentation = "https://scvi-v2.readthedocs.io/"
urls.Source = "https://github.com/YosefLab/scvi-v2"
urls.Home-page = "https://github.com/YosefLab/scvi-v2"
dependencies = [
"scvi-tools>=1.0.0",
"scvi-tools>=1.1.0rc2",
"seaborn>=0.12.1",
"statsmodels>=0.13.0",
]
Expand Down
4 changes: 4 additions & 0 deletions src/scvi_v2/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def setup_anndata(
def train(
self,
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
validation_size: float | None = None,
batch_size: int = 128,
Expand All @@ -210,6 +212,8 @@ def train(
):
train_kwargs = dict(
max_epochs=max_epochs,
accelerator=accelerator,
devices=devices,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
Expand Down

0 comments on commit 42612ff

Please sign in to comment.