tensor_split expects tensor_indices_or_sections to be on cpu #94
Workflow file for this run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This workflow installs Python dependencies, then lints, format-checks and
# tests the project across a matrix of Python and PyTorch versions.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python tests

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      # Let every matrix combination run to completion even if one fails.
      fail-fast: false
      matrix:
        # Versions are quoted so that "3.10" stays a string, not the float 3.1.
        python-version: ["3.8", "3.10", "3.12"]
        torch-version: ["1.13.1", "2.5.0"]
        # Skip the Python 3.12 + torch 1.13.1 combination.
        # NOTE(review): presumably no torch 1.13.1 build exists for 3.12 - confirm.
        exclude:
          - python-version: "3.12"
            torch-version: "1.13.1"
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install torch==${{ matrix.torch-version }}
          # Quoted so the shell cannot treat [jupyter] as a glob pattern.
          python -m pip install flake8 "black[jupyter]"
          if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
      # Runs after the dependency install, so it replaces whatever numpy
      # version the previous step pulled in for torch 1.x matrix entries.
      - name: numpy downgrade for pytorch 1.x
        if: startsWith(matrix.torch-version, '1.')
        run: |
          pip install "numpy<2"
      - name: Lint check with flake8
        run: |
          # No --select filter is given, so any violation flake8 reports
          # (not only syntax errors or undefined names) fails the build.
          flake8 . --count --show-source --statistics
      - name: Check code formatting with black
        run: |
          black --check .
      - name: Run pytest
        run: |
          python -m pytest
      # Installs the package from the checkout; this step runs after pytest.
      - name: Test installing package
        run: |
          python -m pip install .