Update Flash-Attention to v2.3.2
ghstack-source-id: 73d59258856da7688e58f11c12ba26c904207ce9
Pull Request resolved: https://github.com/fairinternal/xformers/pull/839

__original_commit__ = fairinternal/xformers@1e667b7d28ff062a52c828807491039e6e53c3d7
danthe3rd authored and xFormers Bot committed Oct 13, 2023
1 parent 8d50e38 commit 59ec470
Showing 3 changed files with 10 additions and 7 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/win-build.yml
@@ -11,7 +11,7 @@ on:
 
 env:
   FORCE_CUDA: 1
-  MAX_JOBS: 4
+  MAX_JOBS: 2
   DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
   XFORMERS_BUILD_TYPE: "Release"
 
@@ -43,12 +43,13 @@ jobs:
       - name: Setup Runner
         uses: ./.github/actions/setup-windows-runner
         with:
-          cuda: "11.6.2"
+          # (FAv2 requires cuda 12+)
+          cuda: "12.1.0"
           python: "3.8"
 
       - name: Install build dependencies
         run: |
-          $PY -m pip install wheel setuptools ninja torch==2.1.0 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
+          $PY -m pip install wheel setuptools ninja torch==2.1.0 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121
           git config --global --add safe.directory "*"
           $PY -c "import torch; print('torch', torch.__version__)"
           $PY -c "import torch; print('torch.cuda', torch.version.cuda)"
@@ -60,4 +61,6 @@ jobs:
         run: $PY -m pip install -v dist/*
 
       - name: Info
-        run: $PY -m xformers.info
+        run: |
+          cd ../../ # So we don't have a folder named `xformers`
+          XFORMERS_MORE_DETAILS=1 $PY -m xformers.info
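
The `cd ../../` in the new Info step keeps Python from importing the in-tree `xformers/` source directory instead of the wheel that was just installed. A minimal sketch of that sanity check, not part of the commit (the site-packages assertion is an assumption about the runner layout):

    # Hypothetical post-install check: confirm `xformers` resolves to the
    # installed wheel, not the source checkout in the working directory.
    import os
    import xformers

    path = os.path.abspath(xformers.__file__)
    assert "site-packages" in path, f"imported from source tree: {path}"
    print("xformers", xformers.__version__, "loaded from", path)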
4 changes: 2 additions & 2 deletions setup.py
@@ -118,9 +118,9 @@ def get_cuda_version(cuda_dir) -> int:
 
 
 def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
-    # XXX: Not supported on windows yet
+    # XXX: Not supported on windows for cuda<12
     # https://github.com/Dao-AILab/flash-attention/issues/345
-    if platform.system() != "Linux":
+    if platform.system() != "Linux" and cuda_version < 1200:
         return []
     # Figure out default archs to target
     DEFAULT_ARCHS_LIST = ""
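
For reference, the updated gate only skips building the FlashAttention extensions on non-Linux platforms when CUDA is older than 12. A standalone sketch of the condition, assuming `get_cuda_version` packs major.minor roughly as major*100 + minor (so 12.1 -> 1201), which the `< 1200` comparison suggests but which is not shown in this hunk:

    import platform

    def builds_flash_attention(cuda_version: int) -> bool:
        # Mirrors the new check: only non-Linux platforms with CUDA < 12
        # are excluded; Linux builds FlashAttention regardless of version.
        if platform.system() != "Linux" and cuda_version < 1200:
            return False
        return True

    # On a Windows runner:
    #   builds_flash_attention(1108)  -> False  (CUDA 11.8, FAv2 skipped)
    #   builds_flash_attention(1201)  -> True   (CUDA 12.1, FAv2 built)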
