diff --git a/src/stubgen.py b/src/stubgen.py index c4ff32ae..b29f0709 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -376,6 +376,20 @@ def put_nb_func(self, fn: NbFunction, name: Optional[str] = None) -> None: self.write_ln(f"@{overload}") self.put_nb_overload(fn, s, name) + def put_nb_method(self, fn: NbFunction, name: Optional[str], parent: Optional[object]) -> None: + fn_qualname = getattr(fn, "__qualname__", None) + if name and fn_qualname: + fn_class, _, fn_name = fn_qualname.rpartition('.') + # Check if this function is an alias + if name != fn_name: + if fn_class == getattr(parent, "__qualname__", None): + real_name = fn_name + else: + real_name = fn_qualname + self.write_ln(f"{name} = {real_name}\n") + return + self.put_nb_func(fn, name) + def put_function(self, fn: Callable[..., Any], name: Optional[str] = None, parent: Optional[object] = None): """Append a function of an arbitrary type to the stub""" # Don't generate a constructor for nanobind classes that aren't constructible @@ -848,7 +862,7 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object elif tp_mod == "nanobind": if tp_name == "nb_method": value = cast(NbFunction, value) - self.put_nb_func(value, name) + self.put_nb_method(value, name, parent) elif tp_name == "nb_static_property": value = cast(NbStaticProperty, value) self.put_nb_static_property(name, value) diff --git a/tests/test_typing.cpp b/tests/test_typing.cpp index 3426a8b5..45131d4a 100644 --- a/tests/test_typing.cpp +++ b/tests/test_typing.cpp @@ -57,9 +57,10 @@ NB_MODULE(test_typing_ext, m) { m.def("makeNestedClass", [] { return NestedClass(); }); - // Aliases to local functoins and types + // Aliases to functions and types m.attr("FooAlias") = m.attr("Foo"); m.attr("f_alias") = m.attr("f"); + nb::type().attr("lt_alias") = nb::type().attr("__lt__"); // Custom signature generation for classes and methods struct CustomSignature { int value; }; diff --git a/tests/test_typing_ext.pyi.ref b/tests/test_typing_ext.pyi.ref index 5156122b..c1488ad8 100644 --- a/tests/test_typing_ext.pyi.ref +++ b/tests/test_typing_ext.pyi.ref @@ -32,6 +32,8 @@ class Foo: def __ge__(self, arg: Foo, /) -> bool: ... + lt_alias = __lt__ + FooAlias: TypeAlias = Foo T = TypeVar("T", contravariant=True)