Deprecate xformers.components (fairinternal/xformers#1210)
__original_commit__ = fairinternal/xformers@01c08fb
fmassa authored and xFormers Bot committed Aug 27, 2024
1 parent feaaa1f commit 2bc3175
Showing 11 changed files with 42 additions and 0 deletions.
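
Every changed file follows the same two-part pattern: xformers/components/__init__.py gains a module-level FutureWarning that fires at import time, and each component class calls deprecated_function(self) from its __init__ so that instantiating a component warns as well. The helper module xformers/_deprecation_warning.py is not included in this diff; a minimal sketch of what such a helper might look like, assuming it emits a FutureWarning naming the deprecated class, is:

    # Hypothetical sketch of xformers/_deprecation_warning.py; this module is
    # not shown in the diff, so the real implementation may differ.
    import warnings


    def deprecated_function(obj) -> None:
        """Warn that the class being instantiated is deprecated."""
        warnings.warn(
            f"{type(obj).__name__} is part of xformers.components, which is "
            "deprecated and is not maintained anymore. "
            "It might be removed in a future version of xFormers",
            FutureWarning,
            stacklevel=3,  # report the call site of the deprecated __init__
        )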
9 changes: 9 additions & 0 deletions xformers/components/__init__.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import warnings
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Dict, Union
@@ -24,6 +25,14 @@
 from .residual import Residual  # noqa
 from .residual import ResidualNormStyle  # noqa
 
+warnings.warn(
+    "xformers.components is deprecated and is not maintained anymore. "
+    "It might be removed in a future version of xFormers ",
+    FutureWarning,
+    stacklevel=2,
+)
+
+
 # automatically import any Python files in the directory
 import_all_modules(str(Path(__file__).parent), "xformers.components")
 
5 changes: 5 additions & 0 deletions xformers/components/activations.py
@@ -10,6 +10,8 @@
 import torch
 from torch import nn
 
+from xformers._deprecation_warning import deprecated_function
+
 
 class Activation(str, Enum):
     SquaredReLU = "squared_relu"
@@ -24,6 +26,7 @@ class Activation(str, Enum):
 class SquaredReLU(nn.Module):
     def __init__(self) -> None:
         super().__init__()
+        deprecated_function(self)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_ = torch.nn.functional.relu(x)
@@ -33,6 +36,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class StarReLU(nn.Module):
     def __init__(self) -> None:
         super().__init__()
+        deprecated_function(self)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_ = torch.nn.functional.relu(x)
@@ -43,6 +47,7 @@ class SmeLU(nn.Module):
     def __init__(self, beta: float = 2.0) -> None:
         super().__init__()
         self.beta = beta
+        deprecated_function(self)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         relu = torch.where(
2 changes: 2 additions & 0 deletions xformers/components/attention/base.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from xformers._deprecation_warning import deprecated_function
 from xformers.components.attention import AttentionMask
 
 
@@ -36,6 +37,7 @@ class Attention(nn.Module, metaclass=ABCMeta):
     @abstractmethod
     def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         super().__init__()
+        deprecated_function(self)
 
         # Requires the inputs to be projected
         self.requires_input_projection = True
2 changes: 2 additions & 0 deletions xformers/components/feedforward/base.py
@@ -10,6 +10,7 @@
 
 import torch.nn as nn
 
+from xformers._deprecation_warning import deprecated_function
 from xformers.components import Activation
 
 Self = TypeVar("Self", bound="Feedforward")
@@ -35,6 +36,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        deprecated_function(self)
 
         # This feedforward requires a CUDA accelerator
         self.requires_cuda = False
3 changes: 3 additions & 0 deletions xformers/components/input_projection.py
@@ -14,6 +14,8 @@
 import torch
 from torch import nn
 
+from xformers._deprecation_warning import deprecated_function
+
 logger = logging.getLogger("xformers")
 
 
@@ -38,6 +40,7 @@ def __init__(
     ):
 
         super().__init__()
+        deprecated_function(self)
 
         self.out_features = query_proj_params.out_features
 
2 changes: 2 additions & 0 deletions xformers/components/multi_head_dispatch.py
@@ -12,6 +12,7 @@
 import torch.nn as nn
 from torch.nn.init import constant_
 
+from xformers._deprecation_warning import deprecated_function
 from xformers.components.attention import Attention
 from xformers.components.input_projection import InputProjection, InputProjectionConfig
 from xformers.components.positional_embedding import RotaryEmbedding
@@ -90,6 +91,7 @@ def __init__(
         **kwargs,
    ):
         super().__init__()
+        deprecated_function(self)
 
         if isinstance(bias, bool):
             logger.warning(
4 changes: 4 additions & 0 deletions xformers/components/patch_embedding.py
@@ -9,6 +9,8 @@
 
 import torch
 
+from xformers._deprecation_warning import deprecated_function
+
 
 class PoolType(str, Enum):
     Conv2D = "CONV_2D"
@@ -39,6 +41,7 @@ class PatchEmbeddingConfig:
 class ConditionalReshape(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
+        deprecated_function(self)
 
     def forward(self, x):
         if x.ndim == 3:
@@ -54,6 +57,7 @@ def forward(self, x):
 class PatchToSequence(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
+        deprecated_function(self)
 
     def forward(self, x):
         return x.flatten(2, 3).transpose(1, 2).contiguous()  # B HW C
3 changes: 3 additions & 0 deletions xformers/components/positional_embedding/base.py
@@ -10,6 +10,8 @@
 
 import torch.nn as nn
 
+from xformers._deprecation_warning import deprecated_function
+
 Self = TypeVar("Self", bound="PositionEmbedding")
 
 
@@ -24,6 +26,7 @@ class PositionEmbedding(nn.Module, metaclass=ABCMeta):
     @abstractmethod
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
+        deprecated_function(self)
 
     @classmethod
     def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
6 changes: 6 additions & 0 deletions xformers/components/residual.py
@@ -11,6 +11,8 @@
 import torch
 import torch.nn as nn
 
+from xformers._deprecation_warning import deprecated_function
+
 
 class ResidualNormStyle(str, Enum):
     """Support different residual path and norm styles.
@@ -34,6 +36,7 @@ def get_normalization_layer(normalization_type: NormalizationType):
 class Skip(nn.Module):
     def __init__(self, *_, **__) -> None:
         super().__init__()
+        deprecated_function(self)
 
     def forward(self, x: torch.Tensor, **_):
         return x
@@ -64,6 +67,7 @@ class Residual(nn.Module, RequiresWrappedInputs):
 
     def __init__(self, layer: nn.Module, scale: Optional[float] = None):
         super().__init__()
+        deprecated_function(self)
         self.layer = layer
         self.scale = scale
 
@@ -97,6 +101,7 @@ def __init__(
     ):
 
         super().__init__()
+        deprecated_function(self)
         self.norm = get_normalization_layer(normalization)(d_norm)
 
         self.sublayer = sublayer
@@ -132,6 +137,7 @@ def __init__(
         use_triton: bool = True,
     ):
         super().__init__()
+        deprecated_function(self)
         self.norm = get_normalization_layer(normalization)(d_norm)
 
         self.sublayer = sublayer
3 changes: 3 additions & 0 deletions xformers/components/reversible.py
@@ -11,6 +11,7 @@
 from torch.autograd.function import Function
 from torch.utils.checkpoint import get_device_states, set_device_states
 
+from xformers._deprecation_warning import deprecated_function
 from xformers.components import RequiresWrappedInputs
 
 # CREDITS: Code adapted from
@@ -23,6 +24,7 @@
 class Deterministic(nn.Module):
     def __init__(self, net: nn.Module):
         super().__init__()
+        deprecated_function(self)
         self.net = net
         self.cpu_state: torch.Tensor = torch.get_rng_state()
         self.cuda_in_fwd: bool = False
@@ -146,6 +148,7 @@ def backward(
 class ReversibleSequence(nn.Module):
     def __init__(self, blocks: nn.ModuleList):
         super().__init__()
+        deprecated_function(self)
 
         # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values.
         self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])
3 changes: 3 additions & 0 deletions xformers/components/simplicial_embedding.py
@@ -8,6 +8,8 @@
 
 import torch
 
+from xformers._deprecation_warning import deprecated_function
+
 Self = TypeVar("Self", bound="SimplicialEmbedding")
 
 
@@ -32,6 +34,7 @@ class SimplicialEmbedding(torch.nn.Module):
 
     def __init__(self, L: int, temperature: Optional[float] = None) -> None:
         super().__init__()
+        deprecated_function(self)
         self.L = L
         self.temperature = temperature
 
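
With these changes applied, a user of the deprecated package sees two kinds of warnings: one at import time (from the package __init__) and one per component instantiation (via deprecated_function). A minimal sketch of triggering both, assuming the deprecated modules remain importable:

    import warnings

    # Surface every FutureWarning instead of only the first one per location.
    warnings.simplefilter("always", FutureWarning)

    # Importing anything under xformers.components runs the package __init__,
    # which now emits the module-level FutureWarning.
    from xformers.components.activations import SquaredReLU

    # Instantiating a component triggers deprecated_function(self) as well.
    act = SquaredReLU()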
