Skip to content

Commit

Permalink
Selective build/testing in CI (fairinternal/xformers#1273)
Browse files Browse the repository at this point in the history
__original_commit__ = fairinternal/xformers@c0afbd4
  • Loading branch information
danthe3rd authored and xFormers Bot committed Dec 20, 2024
1 parent 4b035ad commit 9a59df2
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/selective_ci/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
GitPython
159 changes: 159 additions & 0 deletions .github/selective_ci/selective_ci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import fnmatch
import os
from dataclasses import dataclass, field
from pathlib import Path

import git


@dataclass
class ComponentInfo:
"""
A component is deemed to have changed if any of its
files or dependencies have changed.
If it has not changed, its files will be removed.
"""

name: str
# These files will be deleted if the component is not enabled
files: list[str]
dependencies: list[str]
disable_set_env: dict[str, str] = field(default_factory=dict)


COMMON_PATTERNS = [
# All components will be tested if something in there changes
"setup.py",
]

COMPONENTS = [
ComponentInfo(
name="attention",
files=[
"tests/test_mem_eff_attention.py",
"tests/test_find_sparse_locations*.py",
"tests/test_block_sparse_mem_eff_attention*.py",
"tests/test_sparse_mem_eff_attention*.py",
"tests/test_attention_patterns.py",
"tests/test_rope_padded.py",
"tests/test_tree_attention*.py",
],
dependencies=[
"xformers/ops/fmha/*",
"third_party/cutlass",
"third_party/flash-attention",
"third_party/composable_kernel_tiled",
"xformers/csrc/attention/*",
],
disable_set_env={
"XFORMERS_DISABLE_FLASH_ATTN": "1",
},
),
ComponentInfo(
name="sp24",
files=[
"tests/test_sparsity24.py",
"xformers/csrc/sparse24/*",
],
dependencies=[
"xformers/ops/sp24.py",
],
),
ComponentInfo(
name="sequence_parallel_fused",
files=[
"tests/test_seqpar.py",
"tests/test_sequence_parallel_fused_ops.py",
"tests/test_tiled_matmul.py",
"xformers/csrc/sequence_parallel_fused/*",
],
dependencies=[
"tests/multiprocessing_utils.py",
"xformers/ops/sequence_parallel_fused_ops.py",
"xformers/ops/ipc.py",
],
),
ComponentInfo(
name="swiglu",
files=[
"tests/test_swiglu.py",
"xformers/csrc/swiglu/*",
],
dependencies=[
"xformers/ops/swiglu_op.py",
],
),
]

repo_root_path = Path(__file__).parent.parent.parent.resolve().absolute()
repo = git.Repo(repo_root_path)


def list_files_in_commit(commit: git.Commit):
file_list = []
stack = [commit.tree]
while len(stack) > 0:
tree = stack.pop()
# enumerate blobs (files) at this level
for b in tree.blobs:
file_list.append(str(Path(b.path).absolute().relative_to(repo_root_path)))
for subtree in tree.trees:
stack.append(subtree)
# you can return dir_list if you want directories too
return file_list


def check_patterns_are_valid(patterns):
found_patterns = set()
for f in all_files:
for pattern in patterns:
if fnmatch.fnmatch(f, pattern):
found_patterns.add(pattern)
for pattern in patterns:
if pattern not in found_patterns:
assert False, f"Pattern does not match any file: `{pattern}`"


all_files = list_files_in_commit(repo.head.commit) + [sm.path for sm in repo.submodules]
all_modified_files = set()
for item in repo.head.commit.diff(repo.rev_parse("origin/main")):
if item.a_path is not None:
all_modified_files.add(item.a_path)
if item.b_path is not None:
all_modified_files.add(item.b_path)

check_patterns_are_valid(COMMON_PATTERNS)
for component in COMPONENTS:
# Sanity check that all files exist
check_patterns_are_valid(component.files + component.dependencies)

# Check if module is updated
skip_module = True
for pattern in COMMON_PATTERNS + component.files + component.dependencies:
for f in all_modified_files:
if fnmatch.fnmatch(f, pattern):
skip_module = False
break
print(component.name, "SKIP" if skip_module else "TEST")
if not skip_module:
continue

# Delete component files
for f in all_files:
for pattern in component.files:
if fnmatch.fnmatch(f, pattern):
if Path(f).exists():
Path(f).unlink()

# Set env variable
for env_k, env_v in component.disable_set_env.items():
if "GITHUB_ENV" not in os.environ:
print(f"{env_k}={env_v}")
continue
with open(os.environ["GITHUB_ENV"], "a") as fd:
fd.write(f"{env_k}={env_v}\n")
6 changes: 6 additions & 0 deletions .github/workflows/gpu_test_gh.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
with:
submodules: recursive
path: "."
fetch-depth: 0 # We need commits history as well
- run: nvidia-smi
- name: Install micromamba
run: |
Expand All @@ -66,6 +67,11 @@ jobs:
echo "micromamba activate env" >> ~/.profile
echo "==== .profile ====="
cat ~/.profile
- name: Selective build/tests
if: github.event_name == 'pull_request'
run: |
pip install -r .github/selective_ci/requirements.txt
python .github/selective_ci/selective_ci.py
- name: Setup test requirements
run: |
which python
Expand Down

0 comments on commit 9a59df2

Please sign in to comment.