Skip to content

Commit

Permalink
Update stubgen to output async def and coroutine decorators (#5845)
Browse files Browse the repository at this point in the history
Fixes #5844
  • Loading branch information
bryanforbes authored and ilevkivskyi committed Nov 26, 2018
1 parent dec6004 commit 1a9e280
Show file tree
Hide file tree
Showing 2 changed files with 353 additions and 10 deletions.
48 changes: 39 additions & 9 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,15 @@ def visit_func_def(self, o: FuncDef) -> None:
return
if self.is_recorded_name(o.name()):
return
if not self._indent and self._state not in (EMPTY, FUNC):
if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine:
self.add('\n')
if not self.is_top_level():
self_inits = find_self_initializers(o)
for init, value in self_inits:
init_code = self.get_init(init, value)
if init_code:
self.add(init_code)
self.add("%sdef %s(" % (self._indent, o.name()))
self.add("%s%sdef %s(" % (self._indent, 'async ' if o.is_coroutine else '', o.name()))
self.record_name(o.name())
args = [] # type: List[str]
for i, arg_ in enumerate(o.arguments):
Expand Down Expand Up @@ -513,13 +513,36 @@ def visit_decorator(self, o: Decorator) -> None:
if self.is_private_name(o.func.name()):
return
for decorator in o.decorators:
if isinstance(decorator, NameExpr) and decorator.name in ('property',
'staticmethod',
'classmethod'):
self.add('%s@%s\n' % (self._indent, decorator.name))
elif (isinstance(decorator, MemberExpr) and decorator.name == 'setter' and
isinstance(decorator.expr, NameExpr)):
self.add('%s@%s.setter\n' % (self._indent, decorator.expr.name))
if isinstance(decorator, NameExpr):
if decorator.name in ('property',
'staticmethod',
'classmethod'):
self.add('%s@%s\n' % (self._indent, decorator.name))
elif self.import_tracker.module_for.get(decorator.name) in ('asyncio',
'asyncio.coroutines',
'types'):
self.add_coroutine_decorator(o.func, decorator.name, decorator.name)
elif isinstance(decorator, MemberExpr):
if decorator.name == 'setter' and isinstance(decorator.expr, NameExpr):
self.add('%s@%s.setter\n' % (self._indent, decorator.expr.name))
elif decorator.name == 'coroutine':
if (isinstance(decorator.expr, MemberExpr) and
decorator.expr.name == 'coroutines' and
isinstance(decorator.expr.expr, NameExpr) and
(decorator.expr.expr.name == 'asyncio' or
self.import_tracker.reverse_alias.get(decorator.expr.expr.name) ==
'asyncio')):
self.add_coroutine_decorator(o.func,
'%s.coroutines.coroutine' %
(decorator.expr.expr.name,),
decorator.expr.expr.name)
elif (isinstance(decorator.expr, NameExpr) and
(decorator.expr.name in ('asyncio', 'types') or
self.import_tracker.reverse_alias.get(decorator.expr.name) in
('asyncio', 'asyncio.coroutines', 'types'))):
self.add_coroutine_decorator(o.func,
decorator.expr.name + '.coroutine',
decorator.expr.name)
super().visit_decorator(o)

def visit_class_def(self, o: ClassDef) -> None:
Expand Down Expand Up @@ -750,6 +773,13 @@ def add_import_line(self, line: str) -> None:
if line not in self._import_lines:
self._import_lines.append(line)

def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) -> None:
func.is_awaitable_coroutine = True
if not self._indent and self._state not in (EMPTY, FUNC):
self.add('\n')
self.add('%s@%s\n' % (self._indent, name))
self.import_tracker.require_name(require_name)

def output(self) -> str:
"""Return the text for the stub."""
imports = ''
Expand Down
Loading

0 comments on commit 1a9e280

Please sign in to comment.