Skip to content

Commit

Permalink
Expand ignore_unused_kwargs for pybind11 compiled functions.
Browse files Browse the repository at this point in the history
Add support functions to extract partial function signatures from
pybind11 functions via ast-based reparse of pybind11 generated
docstring. (Workaround for pybind/pybind11#945)

Add initial test coverage with basic c++-level overload, default args
and template-based type signatures.
  • Loading branch information
asford committed Dec 25, 2018
1 parent cbcfc33 commit fe3d39f
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
100 changes: 100 additions & 0 deletions tmol/tests/utility/test_args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

from tmol.utility.args import _signature
from tmol.utility.args import ignore_unused_kwargs
from tmol.utility.cpp_extension import load_inline

import inspect

import numba
import numpy
Expand Down Expand Up @@ -55,3 +59,99 @@ def vector_foo(a):

with pytest.raises(TypeError):
vector_foo(v, 2)


def test_ignore_unused_kwargs_pybind11():
test_source = """
#include <deque>
void template_param(std::deque<int> x) {
return;
}
at::Tensor tensor_param(at::Tensor x) {
return x;
}
template<typename Real>
Real overloaded(Real x) {
return x;
}
template<typename Real>
Real defaults(Real x, Real y = 1) {
return x;
}
int add_args(int x) {
return x;
}
int add_args(int x, int y) {
return x;
}
int invalid_overload(int x, int y_int) {
return x;
}
int invalid_overload(int x, float y_float) {
return x;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
using namespace pybind11::literals;
m.def("tensor_param", &tensor_param,
"A tensor parameter, includes '::' namespace separator.", "x"_a);
m.def("template_param", &template_param,
"A template parameter, includes '<', '>'.", "x"_a);
m.def("overloaded", &overloaded<float>,
"Multiple valid overloads.", "x"_a);
m.def("overloaded", &overloaded<double>,
"Multiple valid overloads.", "x"_a);
m.def("defaults", &defaults<float>,
"Has default values", "x"_a, "y"_a=1);
m.def("add_args", static_cast<int(*)(int)>(add_args),
"Additional arguments.", "x"_a);
m.def("add_args", static_cast<int(*)(int, int)>(add_args),
"Additional arguments.", "x"_a, "y"_a);
m.def("invalid_overload", static_cast<int(*)(int, float)>(invalid_overload),
"Overload with varying arg names.", "x"_a, "y_int"_a);
m.def("invalid_overload", static_cast<int(*)(int, float)>(invalid_overload),
"Overload with varying arg names.", "x"_a, "y_float"_a);
}
"""

c = load_inline("test_ignore_unused_kwargs_pybind11", test_source, extra_cflags=())

# Test signature extraction for various combinations of overloads, default
# values, and type signatures.
assert _signature(c.template_param) == inspect.signature(lambda x: None)
assert _signature(c.tensor_param) == inspect.signature(lambda x: None)
assert _signature(c.overloaded) == inspect.signature(lambda x: None)
assert _signature(c.defaults) == inspect.signature(lambda x, y=True: None)
assert _signature(c.add_args) == inspect.signature(lambda x, y=True: None)

with pytest.raises(ValueError):
_signature(c.invalid_overload)

# Test ignore_unused_kwargs for a pybind11-defined function
overloaded = ignore_unused_kwargs(c.overloaded)

assert overloaded(1) == 1
assert overloaded(x=1) == 1

assert overloaded(x=1, b=2) == 1
assert overloaded(1, b=2) == 1

with pytest.raises(TypeError):
overloaded(1, 2)

with pytest.raises(TypeError):
overloaded("expects int or float")
116 changes: 116 additions & 0 deletions tmol/utility/args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import functools
import inspect
import toolz
import copy

import ast
import types
from itertools import zip_longest
from inspect import Signature, Parameter

import re


@functools.singledispatch
Expand Down Expand Up @@ -56,3 +64,111 @@ def _dufunc_wraps(f):

except ImportError:
pass


@_signature.register(types.BuiltinFunctionType)
def _builtin_signature(f):
"""Resolve inspect.Signature for builtin types."""

if getattr(f, "__text_signature__", None):
return inspect.signature(f)
else:
return _pybind11_signature(f)


def _pybind11_doc_signatures(f):
name = f.__name__

overloaded = (
re.search(r"^Overloaded function.", f.__doc__, re.MULTILINE) is not None
)

if overloaded:
doc_signatures = [
m.group(1)
for m in re.finditer(fr"^\d+\. ({name}.*)$", f.__doc__, re.MULTILINE)
]
else:
doc_signatures = [
m.group(1) for m in re.finditer(fr"^({name}.*)$", f.__doc__, re.MULTILINE)
]

signatures = [_parse_doc_signature(ds) for ds in doc_signatures]

return signatures


def _parse_doc_signature(sig):
# Hack attempt to sanitize c++ type signatures into syntactically valid
# python type annotations. Split off the return value annotation then
# coerce c++ namespace and template markers into compatible syntax.
psig = sig.split("->")[0]
for c, p in (("::", "."), ("<", "["), (">", "]")):
psig = psig.replace(c, p)

fdef, = ast.parse(f"def {psig}: pass").body

# Just count the number of arguments w/ default values, no attempt to parse
# the values.
args_no_default = len(fdef.args.args) - len(fdef.args.defaults)
kwargs_no_default = len(fdef.args.kwonlyargs) - len(fdef.args.kw_defaults)

return Signature(
[
Parameter(
a.arg,
Parameter.POSITIONAL_OR_KEYWORD,
default=Parameter.empty if i < args_no_default else True,
)
for i, a in enumerate(fdef.args.args)
]
+ (
[Parameter(fdef.args.kwarg, Parameter.VAR_KEYWORD)]
if fdef.args.kwarg
else []
)
+ (
[Parameter(fdef.args.vararg, Parameter.VAR_POSITIONAL)]
if fdef.args.vararg
else []
)
+ [
Parameter(
a.arg,
Parameter.KEYWORD_ONLY,
default=Parameter.empty if i < kwargs_no_default else True,
)
for i, a in enumerate(fdef.args.kwonlyargs)
]
)


def _aligned_signature(sigs):
combined_params = []
param_sets = zip_longest(*(s.parameters.values() for s in sigs))

for i, ps in enumerate(param_sets):
p = set(filter(None, ps))

if len(p) != 1:
raise ValueError(
f"Incompatible params: {ps} index: {i} in signatures:\n{sigs}"
)

param = p.pop()

combined_params.append(
Parameter(
param.name,
param.kind,
# If the parameter is not present in all signatures mark as
# having a default value, otherwise use existing default.
default=param.default if None not in ps else True,
)
)

return Signature(combined_params)


def _pybind11_signature(pybind11_f):
return _aligned_signature(_pybind11_doc_signatures(pybind11_f))

0 comments on commit fe3d39f

Please sign in to comment.