Improve decorator to turn function into PyTorch operator
ghstack-source-id: 97ad2a9bfd84fb0a0a50b270e65c05a09b4ea5a7
Pull Request resolved: fairinternal/xformers#857

__original_commit__ = fairinternal/xformers@c530810
lw authored and xFormers Bot committed Nov 3, 2023
1 parent 6b10d18 commit bc5ff01
Showing 5 changed files with 93 additions and 26 deletions.
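
For context, a minimal, hypothetical sketch of how a decorator like make_pytorch_cuda_operator (defined in xformers/ops/common.py and extended by this commit) is typically applied; the function name, body and arguments below are illustrative only, not part of the commit:

import torch

from xformers.ops.common import make_pytorch_cuda_operator

# Hypothetical example: the type annotations of the decorated function are
# rendered into a dispatcher schema, and the function is registered as a CUDA
# implementation under the xformers namespace.
@make_pytorch_cuda_operator
def scaled_add(a: torch.Tensor, b: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    return a + scale * b

# Assuming CUDA tensors (the implementation is registered for the "CUDA"
# dispatch key), the op can be called directly or through the dispatcher, e.g.
# torch.ops.xformers.scaled_add(a, b, scale=2.0).
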
5 changes: 3 additions & 2 deletions .circleci/continue_config.yml
@@ -150,9 +150,10 @@ run_doc_build: &run_doc_build
when: always
command: |
source $BASH_ENV
source activate /home/circleci/venv
cd docs
python3 -m ensurepip
python3 -m pip install -r requirements.txt
# Don't install PyTorch as we already pulled it from conda, and we'd risk having two conflicting versions
$CONDA_PYTHON -m pip install $(grep -ivE "^#|^torch" requirements.txt)
make help
make singlehtml | tee make.out
! tail make.out | grep -q warning
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -4,6 +4,6 @@ sphinx==5.0.0
git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
torch>=1.6.0
numpy>=1.19.5
pyre-extensions == 0.0.29
pyre-extensions==0.0.29
jinja2==3.0.3
einops
13 changes: 0 additions & 13 deletions xformers/csrc/attention/attention.cpp
@@ -7,19 +7,6 @@
*/
#include <torch/types.h>

// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension.
// For PyMODINIT_FUNC to work, we need to include Python.h
// https://github.com/pytorch/vision/blob/main/torchvision/csrc/vision.cpp#L17
// Fixes error LNK2001: unresolved external symbol PyInit__C
#if defined(_WIN32)
#include <Python.h>
PyMODINIT_FUNC PyInit__C(void) {
// No need to do anything.
return NULL;
}
#endif // defined(_WIN32)

TORCH_LIBRARY_FRAGMENT(xformers, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
41 changes: 41 additions & 0 deletions xformers/csrc/boxing_unboxing.cpp
@@ -0,0 +1,41 @@
#include <pybind11/pybind11.h>

// Must come first to load TORCH_VERSION_FOO.
#include <torch/torch.h>

#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif

#include <torch/csrc/jit/python/pybind_utils.h>

namespace py = pybind11;

namespace {

// Starting with PyTorch 2.2, we will be able to do boxing/unboxing in Python.
// See https://github.com/pytorch/pytorch/pull/111997.
// In the meantime, we had to implement the conversions in C++ ourselves.

py::object box_process_group(
c10::intrusive_ptr<c10d::ProcessGroup> process_group) {
return torch::jit::toPyObject(c10::IValue(process_group));
}

c10::intrusive_ptr<c10d::ProcessGroup> unbox_process_group(
const py::object& obj) {
return torch::jit::toIValue(
obj,
c10::getCustomClassType<c10::intrusive_ptr<c10d::ProcessGroup>>())
.toCustomClass<c10d::ProcessGroup>();
}

} // namespace

PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
module.def("box_process_group", &box_process_group);
module.def("unbox_process_group", &unbox_process_group);
}
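
Below is a short, hypothetical usage sketch for these helpers from the Python side, assuming the extension is built and importable as xformers._C and that torch.distributed has already been initialized (the group setup is illustrative):

import torch.distributed as dist

from xformers._C import box_process_group, unbox_process_group

# Illustrative only: requires a prior dist.init_process_group(...) call.
pg = dist.new_group()

# Wrap the ProcessGroup as a TorchScript custom-class object so it can cross
# the dispatcher boundary, then recover the original handle on the other side.
boxed = box_process_group(pg)
restored = unbox_process_group(boxed)
assert isinstance(restored, dist.ProcessGroup)
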
58 changes: 48 additions & 10 deletions xformers/ops/common.py
@@ -5,11 +5,14 @@

import inspect
import os
from typing import Any, Dict, List, Type, TypeVar
from functools import wraps
from typing import Any, Dict, List, Type, TypeVar, Union, get_args, get_origin

import torch
from torch.torch_version import TorchVersion

from .._C import box_process_group, unbox_process_group


def get_operator(library: str, name: str):
def no_such_operator(*args, **kwargs):
@@ -72,35 +75,70 @@ def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
from .. import get_python_lib

def render_arg_type(annotation) -> str:
# Optional[T] is an alias for Union[T, None]
if get_origin(annotation) is Union:
inner_types = [
t for t in get_args(annotation) if t is not type(None) # noqa: E721
]
if len(inner_types) == 1:
return f"{render_arg_type(inner_types[0])}?"
if get_origin(annotation) is list:
(inner_type,) = get_args(annotation)
return f"{render_arg_type(inner_type)}[]"
if get_origin(annotation) is tuple:
return (
"("
+ ", ".join([render_arg_type(t) for t in get_args(annotation)])
+ ")"
)
if annotation is torch.Tensor:
return "Tensor"
if annotation is bool:
return "bool"
if annotation is int:
return "int"
if annotation is List[int]:
return "int[]"
if annotation is List[torch.Tensor]:
return "Tensor[]"
if annotation is float:
return "float"
if annotation is torch.dtype:
return "ScalarType"
if annotation is torch.distributed.ProcessGroup:
return "__torch__.torch.classes.c10d.ProcessGroup"
assert False, f"Unable to parse annotation: `{annotation}`"

def render_default_value(default):
if default is inspect.Parameter.empty:
return ""
return f" = {default!r}"

sign = inspect.signature(fn) # type: ignore
arguments = [
f"{render_arg_type(arg.annotation)} {arg.name}"
f"{render_arg_type(arg.annotation)} {arg.name}{render_default_value(arg.default)}"
for arg in sign.parameters.values()
]
op_name = fn.__name__ # type: ignore
definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"

def callee(*args, **kwargs):
ba = sign.bind(*args, **kwargs)
for name, value in ba.arguments.items():
if sign.parameters[name].annotation is torch.distributed.ProcessGroup:
ba.arguments[name] = unbox_process_group(value)
return fn(*ba.args, **ba.kwargs)

xformers_lib = get_python_lib()
xformers_lib.define(definition)
xformers_lib.impl(op_name, fn, "CUDA")
xformers_lib.impl(op_name, callee, "CUDA")
dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)

def wrapper(*args, **kwargs):
return dispatcher_impl(*args, **kwargs)
@wraps(fn) # type: ignore[arg-type]
def caller(*args, **kwargs):
ba = sign.bind(*args, **kwargs)
for name, value in ba.arguments.items():
if sign.parameters[name].annotation is torch.distributed.ProcessGroup:
ba.arguments[name] = box_process_group(value)
return dispatcher_impl(*ba.args, **ba.kwargs)

return wrapper # type: ignore
return caller # type: ignore


def _has_a_version_of_triton():
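
As an illustration of what the extended render_arg_type and render_default_value handle, consider a hypothetical decorated signature (the function below is made up for this example, not part of the commit):

from typing import List, Optional

import torch
import torch.distributed as dist

def fused_op(
    x: torch.Tensor,
    sizes: List[int],
    group: dist.ProcessGroup,
    scale: Optional[float] = None,
) -> torch.Tensor:
    ...

# The decorator would render roughly the following schema and register it via
# get_python_lib().define(...):
#   fused_op(Tensor x, int[] sizes,
#            __torch__.torch.classes.c10d.ProcessGroup group,
#            float? scale = None) -> Tensor
# ProcessGroup arguments are boxed in the Python-facing wrapper (caller) before
# reaching the dispatcher and unboxed again inside the registered
# implementation (callee).
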
