From bc5ff010ef1e32beb3d1205b513ccc1618d606ab Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 3 Nov 2023 18:06:27 +0000 Subject: [PATCH] Improve decorator to turn function into PyTorch operator ghstack-source-id: 97ad2a9bfd84fb0a0a50b270e65c05a09b4ea5a7 Pull Request resolved: https://github.com/fairinternal/xformers/pull/857 __original_commit__ = fairinternal/xformers@c5308102454bae98cd139e5a83b6014ea13e713a --- .circleci/continue_config.yml | 5 ++- docs/requirements.txt | 2 +- xformers/csrc/attention/attention.cpp | 13 ------ xformers/csrc/boxing_unboxing.cpp | 41 +++++++++++++++++++ xformers/ops/common.py | 58 ++++++++++++++++++++++----- 5 files changed, 93 insertions(+), 26 deletions(-) create mode 100644 xformers/csrc/boxing_unboxing.cpp diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 10b42ccdbf..556c0fa32a 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -150,9 +150,10 @@ run_doc_build: &run_doc_build when: always command: | source $BASH_ENV + source activate /home/circleci/venv cd docs - python3 -m ensurepip - python3 -m pip install -r requirements.txt + # Don't install PyTorch as we already pulled it from conda, and we'd risk having two conflicting versions + $CONDA_PYTHON -m pip install $(grep -ivE "^#|^torch" requirements.txt) make help make singlehtml | tee make.out ! tail make.out | grep -q warning diff --git a/docs/requirements.txt b/docs/requirements.txt index 2fff160cc0..7c651146f5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,6 +4,6 @@ sphinx==5.0.0 git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme torch>=1.6.0 numpy>=1.19.5 -pyre-extensions == 0.0.29 +pyre-extensions==0.0.29 jinja2==3.0.3 einops diff --git a/xformers/csrc/attention/attention.cpp b/xformers/csrc/attention/attention.cpp index b09a7cfad4..6b025d6763 100644 --- a/xformers/csrc/attention/attention.cpp +++ b/xformers/csrc/attention/attention.cpp @@ -7,19 +7,6 @@ */ #include -// If we are in a Windows environment, we need to define -// initialization functions for the _custom_ops extension. -// For PyMODINIT_FUNC to work, we need to include Python.h -// https://github.com/pytorch/vision/blob/main/torchvision/csrc/vision.cpp#L17 -// Fixes error LNK2001: unresolved external symbol PyInit__C -#if defined(_WIN32) -#include -PyMODINIT_FUNC PyInit__C(void) { - // No need to do anything. - return NULL; -} -#endif // defined(_WIN32) - TORCH_LIBRARY_FRAGMENT(xformers, m) { m.def(TORCH_SELECTIVE_SCHEMA( "xformers::efficient_attention_forward_small_k(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)")); diff --git a/xformers/csrc/boxing_unboxing.cpp b/xformers/csrc/boxing_unboxing.cpp new file mode 100644 index 0000000000..2348fdcc10 --- /dev/null +++ b/xformers/csrc/boxing_unboxing.cpp @@ -0,0 +1,41 @@ +#include + +// Must come first to load TORCH_VERSION_FOO. +#include + +#if TORCH_VERSION_MAJOR > 1 || \ + (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) +#include +#else +#include +#endif + +#include + +namespace py = pybind11; + +namespace { + +// Starting with PyTorch 2.2, we will be able to do boxing/unboxing in Python. +// See https://github.com/pytorch/pytorch/pull/111997. +// In the meantime, we had to implement the conversions in C++ ourselves. + +py::object box_process_group( + c10::intrusive_ptr process_group) { + return torch::jit::toPyObject(c10::IValue(process_group)); +} + +c10::intrusive_ptr unbox_process_group( + const py::object& obj) { + return torch::jit::toIValue( + obj, + c10::getCustomClassType>()) + .toCustomClass(); +} + +} // namespace + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) { + module.def("box_process_group", &box_process_group); + module.def("unbox_process_group", &unbox_process_group); +} diff --git a/xformers/ops/common.py b/xformers/ops/common.py index a32ec3f556..bd893c7034 100644 --- a/xformers/ops/common.py +++ b/xformers/ops/common.py @@ -5,11 +5,14 @@ import inspect import os -from typing import Any, Dict, List, Type, TypeVar +from functools import wraps +from typing import Any, Dict, List, Type, TypeVar, Union, get_args, get_origin import torch from torch.torch_version import TorchVersion +from .._C import box_process_group, unbox_process_group + def get_operator(library: str, name: str): def no_such_operator(*args, **kwargs): @@ -72,35 +75,70 @@ def make_pytorch_cuda_operator(fn: ClsT) -> ClsT: from .. import get_python_lib def render_arg_type(annotation) -> str: + # Optional[T] is an alias for Union[T, None] + if get_origin(annotation) is Union: + inner_types = [ + t for t in get_args(annotation) if t is not type(None) # noqa: E721 + ] + if len(inner_types) == 1: + return f"{render_arg_type(inner_types[0])}?" + if get_origin(annotation) is list: + (inner_type,) = get_args(annotation) + return f"{render_arg_type(inner_type)}[]" + if get_origin(annotation) is tuple: + return ( + "(" + + ", ".join([render_arg_type(t) for t in get_args(annotation)]) + + ")" + ) if annotation is torch.Tensor: return "Tensor" if annotation is bool: return "bool" if annotation is int: return "int" - if annotation is List[int]: - return "int[]" - if annotation is List[torch.Tensor]: - return "Tensor[]" + if annotation is float: + return "float" + if annotation is torch.dtype: + return "ScalarType" + if annotation is torch.distributed.ProcessGroup: + return "__torch__.torch.classes.c10d.ProcessGroup" assert False, f"Unable to parse annotation: `{annotation}`" + def render_default_value(default): + if default is inspect.Parameter.empty: + return "" + return f" = {default!r}" + sign = inspect.signature(fn) # type: ignore arguments = [ - f"{render_arg_type(arg.annotation)} {arg.name}" + f"{render_arg_type(arg.annotation)} {arg.name}{render_default_value(arg.default)}" for arg in sign.parameters.values() ] op_name = fn.__name__ # type: ignore definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}" + def callee(*args, **kwargs): + ba = sign.bind(*args, **kwargs) + for name, value in ba.arguments.items(): + if sign.parameters[name].annotation is torch.distributed.ProcessGroup: + ba.arguments[name] = unbox_process_group(value) + return fn(*ba.args, **ba.kwargs) + xformers_lib = get_python_lib() xformers_lib.define(definition) - xformers_lib.impl(op_name, fn, "CUDA") + xformers_lib.impl(op_name, callee, "CUDA") dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name) - def wrapper(*args, **kwargs): - return dispatcher_impl(*args, **kwargs) + @wraps(fn) # type: ignore[arg-type] + def caller(*args, **kwargs): + ba = sign.bind(*args, **kwargs) + for name, value in ba.arguments.items(): + if sign.parameters[name].annotation is torch.distributed.ProcessGroup: + ba.arguments[name] = box_process_group(value) + return dispatcher_impl(*ba.args, **ba.kwargs) - return wrapper # type: ignore + return caller # type: ignore def _has_a_version_of_triton():