Skip to content

Commit

Permalink
Add out= to tiled matmul, and turn it into a PyTorch op
Browse files — browse the repository at this point in the history
And use it in seqpar to avoid wave quantization

ghstack-source-id: 742a20129fd538f3504eb469ca13853e56bf6e04
Pull Request resolved: fairinternal/xformers#876

__original_commit__ = fairinternal/xformers@0abb9cd
  • Loading branch information
lw authored and xFormers Bot committed Nov 9, 2023
1 parent 4846ebc commit 35fcc2d
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions xformers/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import inspect
import os
from dataclasses import dataclass
from functools import wraps
from typing import Any, Dict, List, Type, TypeVar, Union, get_args, get_origin
from typing import Any, Callable, Dict, List, Type, TypeVar, Union

import torch
from torch.torch_version import TorchVersion
from typing_extensions import Annotated, get_args, get_origin


def get_operator(library: str, name: str):
Expand Down Expand Up @@ -69,7 +71,24 @@ def _get_storage_base(x: torch.Tensor) -> int:
return _GET_TENSOR_STORAGE(x).data_ptr() # type: ignore


@dataclass(frozen=True)
class Alias:
    """Schema-annotation marker describing tensor aliasing for a PyTorch op.

    Used inside ``typing.Annotated`` on an argument so that the schema
    renderer emits ``Tensor(a)`` / ``Tensor(a!)`` alias annotations when
    building the operator definition string.
    """

    # Alias label used in the rendered schema, e.g. "a" -> "Tensor(a)".
    name: str
    # True if the op writes through this alias; rendered with a trailing "!".
    write: bool


def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
    """Register ``fn`` as a PyTorch custom operator on the CUDA dispatch key.

    Convenience shorthand for ``turn_into_pytorch_op(fn, "CUDA")``.
    """
    cuda_op = turn_into_pytorch_op(fn, "CUDA")
    return cuda_op


def make_pytorch_operator_for_dispatch_key(dispatch_key: str) -> Callable[[ClsT], ClsT]:
    """Return a decorator that registers a function as a PyTorch custom op.

    The returned decorator binds the given ``dispatch_key`` (e.g. ``"CUDA"``)
    and delegates the actual registration to ``turn_into_pytorch_op``.
    """

    def _register(op_fn: ClsT) -> ClsT:
        # Defer to the generic registration helper with the captured key.
        return turn_into_pytorch_op(op_fn, dispatch_key)

    return _register


def turn_into_pytorch_op(fn: ClsT, dispatch_key: str) -> ClsT:
from .. import get_python_lib

def render_arg_type(annotation) -> str:
Expand All @@ -89,6 +108,11 @@ def render_arg_type(annotation) -> str:
+ ", ".join([render_arg_type(t) for t in get_args(annotation)])
+ ")"
)
if get_origin(annotation) is Annotated:
inner_type, annotation = get_args(annotation)
if isinstance(annotation, Alias):
alias = annotation.name + ("!" if annotation.write else "")
return f"{render_arg_type(inner_type)}({alias})"
if annotation is torch.Tensor:
return "Tensor"
if annotation is bool:
Expand Down Expand Up @@ -127,7 +151,7 @@ def callee(*args, **kwargs):

xformers_lib = get_python_lib()
xformers_lib.define(definition)
xformers_lib.impl(op_name, callee, "CUDA")
xformers_lib.impl(op_name, callee, dispatch_key)
dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)

@wraps(fn) # type: ignore[arg-type]
Expand Down

0 comments on commit 35fcc2d

Please sign in to comment.