Skip to content

Commit

Permalink
stubgen: Preserve simple defaults in function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
hamdanal committed Jun 2, 2023
1 parent f8f9453 commit 4692138
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 25 deletions.
33 changes: 28 additions & 5 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,14 +749,15 @@ def visit_func_def(self, o: FuncDef) -> None:
args.append("*")

if arg_.initializer:
default = self.get_str_default_of_node(arg_.initializer)
if not annotation:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
if typename == "":
annotation = "=..."
annotation = f"={default}"
else:
annotation = f": {typename} = ..."
annotation = f": {typename} = {default}"
else:
annotation += " = ..."
annotation += f" = {default}"
arg = name + annotation
elif kind == ARG_STAR:
arg = f"*{name}{annotation}"
Expand Down Expand Up @@ -1401,8 +1402,11 @@ def get_str_type_of_node(
return "bytes"
if isinstance(rvalue, FloatExpr):
return "float"
if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
return "int"
if isinstance(rvalue, UnaryExpr):
if isinstance(rvalue.expr, IntExpr):
return "int"
if isinstance(rvalue.expr, FloatExpr):
return "float"
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
return "bool"
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
Expand All @@ -1414,6 +1418,25 @@ def get_str_type_of_node(
else:
return ""

def get_str_default_of_node(self, rvalue: Expression) -> str:
default = "..."
if isinstance(rvalue, NameExpr):
if rvalue.name in ("None", "True", "False"):
default = rvalue.name
elif isinstance(rvalue, (IntExpr, FloatExpr)):
default = f"{rvalue.value}"
elif isinstance(rvalue, UnaryExpr):
if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
default = f"{rvalue.op}{rvalue.expr.value}"
elif isinstance(rvalue, StrExpr):
default = repr(rvalue.value)
elif isinstance(rvalue, BytesExpr):
default = f"b{rvalue.value!r}"

if len(default) > 200: # TODO: what's a good limit?
default = "..." # long literals are not useful in stubs
return default

def print_annotation(self, t: Type) -> str:
printer = AnnotationPrinter(self)
return t.accept(printer)
Expand Down
61 changes: 41 additions & 20 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,35 @@ def g(arg) -> None: ...
def f(a, b=2): ...
def g(b=-1, c=0): ...
[out]
def f(a, b: int = ...) -> None: ...
def g(b: int = ..., c: int = ...) -> None: ...
def f(a, b: int = 2) -> None: ...
def g(b: int = -1, c: int = 0) -> None: ...

[case testDefaultArgNone]
def f(x=None): ...
[out]
from _typeshed import Incomplete

def f(x: Incomplete | None = ...) -> None: ...
def f(x: Incomplete | None = None) -> None: ...

[case testDefaultArgBool]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...) -> None: ...
def f(x: bool = True, y: bool = False) -> None: ...

[case testDefaultArgStr]
def f(x='foo'): ...
def f(x='foo',y="how's quotes"): ...
[out]
def f(x: str = ...) -> None: ...
def f(x: str = 'foo', y: str = "how's quotes") -> None: ...

[case testDefaultArgBytes]
def f(x=b'foo'): ...
def f(x=b'foo',y=b"what's up"): ...
[out]
def f(x: bytes = ...) -> None: ...
def f(x: bytes = b'foo', y: bytes = b"what's up") -> None: ...

[case testDefaultArgFloat]
def f(x=1.2): ...
def f(x=1.2,y=1e-6,z=0.0,w=-0.0,v=+1.0): ...
[out]
def f(x: float = ...) -> None: ...
def f(x: float = 1.2, y: float = 1e-06, z: float = 0.0, w: float = -0.0, v: float = +1.0) -> None: ...

[case testDefaultArgOther]
def f(x=ord): ...
Expand Down Expand Up @@ -111,10 +111,10 @@ def i(a, *, b=1): ...
def j(a, *, b=1, **c): ...
[out]
def f(a, *b, **c) -> None: ...
def g(a, *b, c: int = ...) -> None: ...
def h(a, *b, c: int = ..., **d) -> None: ...
def i(a, *, b: int = ...) -> None: ...
def j(a, *, b: int = ..., **c) -> None: ...
def g(a, *b, c: int = 1) -> None: ...
def h(a, *b, c: int = 1, **d) -> None: ...
def i(a, *, b: int = 1) -> None: ...
def j(a, *, b: int = 1, **c) -> None: ...

[case testClass]
class A:
Expand Down Expand Up @@ -298,8 +298,8 @@ y: Incomplete
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
[out]
def f(x, *, y: int = ...) -> None: ...
def g(x, *, y: int = ..., z: int = ...) -> None: ...
def f(x, *, y: int = 1) -> None: ...
def g(x, *, y: int = 1, z: int = 2) -> None: ...

[case testProperty]
class A:
Expand Down Expand Up @@ -983,8 +983,8 @@ from _typeshed import Incomplete

class A:
x: Incomplete
def __init__(self, a: Incomplete | None = ...) -> None: ...
def method(self, a: Incomplete | None = ...) -> None: ...
def __init__(self, a: Incomplete | None = None) -> None: ...
def method(self, a: Incomplete | None = None) -> None: ...

[case testAnnotationImportsFrom]
import foo
Expand Down Expand Up @@ -2142,7 +2142,7 @@ from _typeshed import Incomplete as _Incomplete

Y: _Incomplete

def g(x: _Incomplete | None = ...) -> None: ...
def g(x: _Incomplete | None = None) -> None: ...

x: _Incomplete

Expand Down Expand Up @@ -3052,7 +3052,7 @@ class P(Protocol):
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
[out]
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...

[case testNestedGenerator]
def f1():
Expand Down Expand Up @@ -3317,3 +3317,24 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testIgnoreLongDefaults]
def f(x='abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def g(x=b'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def h(x=123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890): ...

[out]
def f(x: str = ...) -> None: ...
def g(x: bytes = ...) -> None: ...
def h(x: int = ...) -> None: ...

0 comments on commit 4692138

Please sign in to comment.