Skip to content

Commit

Permalink
Add type annotations, enforce strict typing (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyreese authored Apr 22, 2024
1 parent 163aedc commit 50d1815
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 156 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,6 @@ skip_covered = true

[tool.mypy]
python_version = "3.8"
# strict = true
strict = true
ignore_missing_imports = true
disallow_untyped_calls = false
16 changes: 8 additions & 8 deletions sphinx_mdinclude/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mistune import BlockParser, InlineParser
from mistune.core import BlockState, InlineState
from mistune.inline_parser import HTML_ATTRIBUTES, HTML_TAGNAME
from mistune.helpers import HTML_ATTRIBUTES, HTML_TAGNAME


State = Dict[str, Any]
Expand All @@ -26,17 +26,17 @@ class RestBlockParser(BlockParser):
"rest_code_block",
)

def parse_directive(self, m: Match, state: BlockState) -> int:
def parse_directive(self, m: Match[str], state: BlockState) -> int:
state.append_token({"type": "directive", "raw": m.group("directive_1")})
return m.end()

def parse_oneline_directive(self, m: Match, state: BlockState) -> int:
def parse_oneline_directive(self, m: Match[str], state: BlockState) -> int:
# reuse directive output
state.append_token({"type": "directive", "raw": m.group("directive_2")})
# $ does not count '\n'
return m.end() + 1

def parse_rest_code_block(self, m: Match, state: BlockState) -> int:
def parse_rest_code_block(self, m: Match[str], state: BlockState) -> int:
state.append_token({"type": "rest_code_block", "text": ""})
# $ does not count '\n'
return m.end() + 1
Expand Down Expand Up @@ -74,22 +74,22 @@ class RestInlineParser(InlineParser):
"eol_literal_marker",
) + InlineParser.DEFAULT_RULES # type: ignore[has-type]

def parse_rest_role(self, m: Match, state: InlineState) -> int:
def parse_rest_role(self, m: Match[str], state: InlineState) -> int:
"""Pass through rest role."""
state.append_token({"type": "rest_role", "raw": m.group(0)})
return m.end()

def parse_rest_link(self, m: Match, state: InlineState) -> int:
def parse_rest_link(self, m: Match[str], state: InlineState) -> int:
"""Pass through rest link."""
state.append_token({"type": "rest_link", "raw": m.group(0)})
return m.end()

def parse_inline_math(self, m: Match, state: InlineState) -> int:
def parse_inline_math(self, m: Match[str], state: InlineState) -> int:
"""Pass through inline math."""
state.append_token({"type": "inline_math", "raw": m.group("math_1")})
return m.end()

def parse_eol_literal_marker(self, m: Match, state: InlineState) -> int:
def parse_eol_literal_marker(self, m: Match[str], state: InlineState) -> int:
"""Pass through rest link."""
marker = ":" if m.group("eol_space") is None else ""
state.append_token({"type": "eol_literal_marker", "raw": marker})
Expand Down
109 changes: 60 additions & 49 deletions sphinx_mdinclude/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import textwrap
from functools import partial
from importlib import import_module
from typing import Any, Dict
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from docutils.utils import column_width
from mistune import Markdown
from mistune.core import BaseRenderer
from mistune.core import BaseRenderer, BlockState
from mistune.plugins import _plugins

from .parse import RestBlockParser, RestInlineParser
Expand Down Expand Up @@ -34,13 +34,13 @@ class RestRenderer(BaseRenderer):
6: "#",
}

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._indent_block = partial(textwrap.indent, prefix=self.indent)
super().__init__(*args, **kwargs)

def render_token(self, token, state):
def render_token(self, token: Dict[str, Any], state: BlockState) -> str:
# based on mistune 3.0.2, mistune/renderers/html.py
func = self._get_method(token["type"])
func: Callable[..., str] = self._get_method(token["type"])
attrs = token.get("attrs")
style = token.get("style")

Expand Down Expand Up @@ -68,14 +68,14 @@ def render_token(self, token, state):
else:
return func(text)

def finalize(self, data):
def finalize(self, data: Iterable[str]) -> str:
return "".join(data)

def _raw_html(self, html):
def _raw_html(self, html: str) -> str:
self._include_raw_html = True
return r":raw-html-md:`{}`".format(html)

def block_code(self, code, style, info=None):
def block_code(self, code: str, style: str, info: Optional[str] = None) -> str:
if info == "math":
first_line = "\n.. math::\n\n"
elif info:
Expand All @@ -86,21 +86,21 @@ def block_code(self, code, style, info=None):
newline = "\n" if style == "indent" else ""
return first_line + self._indent_block(code + newline)

def block_quote(self, text):
def block_quote(self, text: str) -> str:
# text includes some empty line
return "\n..\n\n{}\n\n".format(self._indent_block(text.strip("\n")))

def block_text(self, text):
def block_text(self, text: str) -> str:
return text

def block_html(self, html):
def block_html(self, html: str) -> str:
"""Rendering block level pure html content.
:param html: text content of the html snippet.
"""
return "\n\n.. raw:: html\n\n" + self._indent_block(html) + "\n"

def heading(self, text, level, **attrs):
def heading(self, text: str, level: int, **attrs: Any) -> str:
"""Rendering header/heading tags like ``<h1>`` ``<h2>``.
:param text: rendered text content for the header.
Expand All @@ -109,11 +109,11 @@ def heading(self, text, level, **attrs):
"""
return "\n{0}\n{1}\n".format(text, self.hmarks[level] * column_width(text))

def thematic_break(self):
def thematic_break(self) -> str:
"""Rendering method for ``<hr>`` tag."""
return "\n----\n"

def list(self, text, ordered, **attrs):
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
"""Rendering list tags like ``<ul>`` and ``<ol>``.
:param text: body contents of the list.
Expand All @@ -128,15 +128,15 @@ def list(self, text, ordered, **attrs):
result = "\n{}\n".format("\n".join(lines)).replace(self.list_marker, mark)
return result

def list_item(self, text):
def list_item(self, text: str) -> str:
"""Rendering list item snippet. Like ``<li>``."""
return "\n" + self.list_marker + text

def paragraph(self, text):
def paragraph(self, text: str) -> str:
"""Rendering paragraph tags. Like ``<p>``."""
return "\n" + text + "\n"

def table(self, body):
def table(self, body: str) -> str:
"""Rendering table element. Wrap header and body in it.
:param header: header part of the table.
Expand All @@ -146,13 +146,13 @@ def table(self, body):
table = table + self._indent_block(body) + "\n"
return table

def table_head(self, text):
def table_head(self, text: str) -> str:
return ":header-rows: 1\n\n" + self.table_row(text)

def table_body(self, text):
def table_body(self, text: str) -> str:
return text

def table_row(self, content):
def table_row(self, content: str) -> str:
"""Rendering a table row. Like ``<tr>``.
:param content: content of current table row.
Expand All @@ -166,7 +166,7 @@ def table_row(self, content):
clist.append(" " + c)
return "\n".join(clist) + "\n"

def table_cell(self, content, align=None, head=False):
def table_cell(self, content: str, align: None = None, head: bool = False) -> str:
"""Rendering a table cell. Like ``<th>`` ``<td>``.
:param content: content of current table cell.
Expand All @@ -175,24 +175,24 @@ def table_cell(self, content, align=None, head=False):
"""
return "- " + content + "\n"

def double_emphasis(self, text):
def double_emphasis(self, text: str) -> str:
"""Rendering **strong** text.
:param text: text content for emphasis.
"""
return r"**{}**".format(text)

def emphasis(self, text):
def emphasis(self, text: str) -> str:
"""Rendering *emphasis* text.
:param text: text content for emphasis.
"""
return r"*{}*".format(text)

def strong(self, text):
def strong(self, text: str) -> str:
return r"**{}**".format(text)

def codespan(self, text):
def codespan(self, text: str) -> str:
"""Rendering inline `code` text.
:param text: text content for inline code.
Expand All @@ -208,29 +208,29 @@ def codespan(self, text):
else:
return r"``{}``".format(text)

def linebreak(self):
def linebreak(self) -> str:
"""Rendering line break like ``<br>``."""
return " " + self._raw_html("<br />") + "\n"

def softbreak(self):
def softbreak(self) -> str:
"""Rendering soft line break."""
return "\n"

def strikethrough(self, text):
def strikethrough(self, text: str) -> str:
"""Rendering ~~strikethrough~~ text.
:param text: text content for strikethrough.
"""
return self._raw_html("<del>{}</del>".format(text))

def text(self, text):
def text(self, text: str) -> str:
"""Rendering unformatted text.
:param text: text content.
"""
return text

def link(self, text, url, title=None):
def link(self, text: str, url: str, title: Optional[str] = None) -> str:
"""Rendering a given link with content and title.
:param text: text content for description.
Expand All @@ -256,7 +256,7 @@ def link(self, text, url, title=None):
target=url, text=text, underscore=underscore
)

def image(self, text, url, title=None):
def image(self, text: str, url: str, title: Optional[str] = None) -> str:
"""Rendering a image with title and text.
:param text: alt text of the image.
Expand All @@ -275,7 +275,7 @@ def image(self, text, url, title=None):
]
)

def image_link(self, url, target, alt):
def image_link(self, url: str, target: str, alt: str) -> str:
return "\n".join(
[
"",
Expand All @@ -286,34 +286,34 @@ def image_link(self, url, target, alt):
]
)

def inline_html(self, html):
def inline_html(self, html: str) -> str:
"""Rendering span level pure html content.
:param html: text content of the html snippet.
"""
return self._raw_html(html)

def newline(self):
def newline(self) -> str:
"""Rendering newline element."""
return ""

def footnote_ref(self, key, index):
def footnote_ref(self, key: str, index: int) -> str:
"""Rendering the ref anchor of a footnote.
:param key: identity key for the footnote.
:param index: the index count of current footnote.
"""
return r"[#fn-{}]_".format(key)

def footnote_item(self, text, key, index):
def footnote_item(self, text: str, key: str, index: int) -> str:
"""Rendering a footnote item.
:param key: identity key for the footnote.
:param text: text content of the footnote.
"""
return ".. [#fn-{0}] {1}\n".format(key, text.strip())

def footnotes(self, text):
def footnotes(self, text: str) -> str:
"""Wrapper for all footnotes.
:param text: contents of all footnotes.
Expand All @@ -325,32 +325,39 @@ def footnotes(self, text):

"""Below outputs are for rst."""

def rest_role(self, raw):
def rest_role(self, raw: str) -> str:
return raw

def rest_link(self, raw):
def rest_link(self, raw: str) -> str:
return raw

def inline_math(self, raw):
def inline_math(self, raw: str) -> str:
"""Extension of recommonmark."""
return r":math:`{}`".format(raw)

def eol_literal_marker(self, raw):
def eol_literal_marker(self, raw: str) -> str:
"""Extension of recommonmark."""
return raw

def directive(self, text):
def directive(self, text: str) -> str:
return "\n" + text

def rest_code_block(self, text):
def rest_code_block(self, text: str) -> str:
return "\n\n"

def blank_line(self):
def blank_line(self) -> str:
return ""


class RestMarkdown(Markdown):
def __init__(self, renderer=None, block=None, inline=None, plugins=None, **kwargs):
def __init__(
self,
renderer: Optional[BaseRenderer] = None,
block: Optional[RestBlockParser] = None,
inline: Optional[RestInlineParser] = None,
plugins: Optional[List[Any]] = None,
**kwargs: Any,
) -> None:
renderer = renderer or RestRenderer()
block = block or RestBlockParser()
inline = inline or RestInlineParser()
Expand All @@ -372,18 +379,22 @@ def __init__(self, renderer=None, block=None, inline=None, plugins=None, **kwarg

super().__init__(renderer, block=block, inline=inline, plugins=plugins)

def parse(self, text):
def parse(
self,
text: str,
state: Optional[BlockState] = None,
) -> Tuple[str, Optional[BlockState]]:
output, state = super().parse(text)
output = self.post_process(output)

return output, state

def post_process(self, text):
def post_process(self, text: str) -> str:
if self.renderer._include_raw_html:
return PROLOG + text
else:
return text


def convert(text, **kwargs):
return RestMarkdown(**kwargs)(text)
def convert(text: str, **kwargs: Any) -> str:
return str(RestMarkdown(**kwargs)(text))
Loading

0 comments on commit 50d1815

Please sign in to comment.