Skip to content

Commit

Permalink
allow forcing triton (fairinternal/xformers#1137)
Browse files Browse the repository at this point in the history
* Allow forcing triton

Setting XFORMERS_ENABLE_TRITON skips checks that might unwantedly
initialize CUDA. Also remove checks for old triton versions.

__original_commit__ = fairinternal/xformers@9b0a743
  • Loading branch information
bottler authored and xFormers Bot committed Jun 20, 2024
1 parent 133d7f1 commit be13e22
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 26 deletions.
2 changes: 2 additions & 0 deletions xformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def func_wrapper():

@compute_once
def _is_triton_available():
if os.environ.get("XFORMERS_ENABLE_TRITON", "0") == "1":
return True
if not torch.cuda.is_available():
return False
if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1":
Expand Down
21 changes: 0 additions & 21 deletions xformers/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
from typing import Any, Callable, Dict, List, Type, TypeVar, Union

import torch
from torch.torch_version import TorchVersion
from typing_extensions import Annotated, get_args, get_origin

from .. import _is_triton_available


def get_operator(library: str, name: str):
def no_such_operator(*args, **kwargs):
Expand Down Expand Up @@ -166,21 +163,3 @@ def caller(*args, **kwargs):
return dispatcher_impl(*ba.args, **ba.kwargs)

return caller # type: ignore


def _has_triton2():
    """Return True if triton is usable and its version is 2.0 or >= 2.1.

    Returns False immediately when triton is unavailable, so ``import
    triton`` is only attempted on machines that can actually use it.
    """
    if not _is_triton_available():
        return False
    import triton

    version = TorchVersion(triton.__version__)
    # Accept exactly 2.0, or anything from 2.1 onward.
    return version == (2, 0) or version >= (2, 1)


def _has_triton21():
    """Return True if triton is usable and its version is at least 2.1.

    Short-circuits (returning False) when triton is unavailable, avoiding
    the ``import triton`` on machines without it.
    """
    if not _is_triton_available():
        return False
    import triton

    return TorchVersion(triton.__version__) >= (2, 1)
5 changes: 3 additions & 2 deletions xformers/ops/fmha/triton_splitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

import torch

from ..common import _has_triton21, register_operator
from ... import _is_triton_available
from ..common import register_operator
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalWithOffsetGappyKeysMask,
Expand Down Expand Up @@ -84,7 +85,7 @@ def _is_supported_paged_bias(attn_bias: Any) -> bool:
"BLOCK_N_PER_SPLIT",
]

if TYPE_CHECKING or _has_triton21():
if TYPE_CHECKING or _is_triton_available():
import triton
import triton.language as tl

Expand Down
5 changes: 2 additions & 3 deletions xformers/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# LICENSE file in the root directory of this source tree.


import torch
from .. import _is_triton_available

_triton_available = torch.cuda.is_available()
if _triton_available:
if _is_triton_available():
try:
from .dropout import FusedDropoutBias, dropout # noqa
from .fused_linear_layer import FusedLinear # noqa
Expand Down

0 comments on commit be13e22

Please sign in to comment.