Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plum.overload #93

Merged
merged 20 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade --no-cache-dir -e '.[dev]'
- name: Test linter assertions
run: |
python check_linter_assertions.py tests/typechecked
- name: Test
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
9 changes: 5 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
.PHONY: docmake docopen docinit docremove docupdate install test clean
.PHONY: install test

PACKAGE := plum

install:
pip install -e '.[dev]'

test:
pre-commit run --all-files && sleep 0.2 && \
PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \
pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing
python check_linter_assertions.py tests/typechecked
pre-commit run --all-files
PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \
pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing
256 changes: 256 additions & 0 deletions check_linter_assertions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import re
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Tuple

FileLineInfo = Dict[Path, Dict[int, List[str]]]
"""type: Type of a nested dictionary that gives for a collection of files line-wise
information, where the information is of the form `list[str]`."""


def next_noncomment_line(index: int, lines: List[str], path: Path) -> int:
"""Starting at `index`, find the next line with code.

Args:
index (int): Index to start at.
lines (list[str]): Source code lines.
path (:class:`pathlib.Path`): Path where the source code is from.

Returns:
int: Index of the next line with code.
"""
i = index + 1 # Start at the next line.
while i < len(lines):
line_content = lines[i].strip()
if line_content and not line_content.startswith("#"):
return i
i += 1
raise RuntimeError(f"{path}:{index}: Cannot match error assertion to code line.")


def parse_assertions(source_dir: Path, linter: str) -> FileLineInfo:
"""Parse all assertions in all Python files in `source_dir` for linter `linter`.

Args:
source_dir (:class:`pathlib.Path`): Source directory.
linter (str): Linter.

Returns:
:obj:`FileLineInfo`: Assertions.
"""
asserted_errors: FileLineInfo = defaultdict(lambda: defaultdict(list))

for path in source_dir.resolve().rglob("*.py"): # Important to `resolve` here!
with open(path, "r") as f:
lines = f.read().splitlines()
for i, line in enumerate(lines):
# Check if the line has an error assertion.
try:
code, comment = line.rsplit("# E:", 1)
except ValueError:
continue

# Find error assertions.
assertions = re.findall(linter + r"\(([^\)]*)\)", comment)

# There is nothing to do if there are no assertions.
if not assertions:
continue

# Find line number of the code that the assertions pertains to. If there is
# no code on the line, find the next non-comment line.
if not code.strip():
i = next_noncomment_line(i, lines, path)

line_number = i + 1 # Line numbers start at one.
asserted_errors[path][line_number].extend(assertions)

return asserted_errors


def parse_mypy_line(line: str) -> Tuple[Path, int, str, str]:
"""Parse a line of the output of `mypy`.

Args:
line (str): Line.

Raises:
ValueError: If the line cannot be parsed.

Returns:
:class:`pathlib.Path`: Path of file.
int: Line number.
str: Kind of message.
str: Message.
"""
path, line_number, status, message = line.split(":", 3)
# Path must be `resolve`d!
return Path(path).resolve(), int(line_number), status, message


def parse_pyright_line(line: str) -> Tuple[Path, int, str, str]:
"""Parse a line of the output of `pyright`.

Args:
line (str): Line.

Raises:
ValueError: If the line cannot be parsed.

Returns:
:class:`pathlib.Path`: Path of file.
int: Line number.
str: Kind of message.
str: Message.
"""
specification, status_message = line.split(" - ", 1)
path, line_number, _ = specification.split(":", 2)
status, message = status_message.split(":", 1)
# Path must be `resolve`d!
return Path(path.strip()).resolve(), int(line_number), status, message


parse_line: Dict[str, Callable[[str], Tuple[Path, int, str, str]]] = {
"mypy": parse_mypy_line,
"pyright": parse_pyright_line,
}
"""dict[str, Callable[[str], tuple[:class:`pathlib.Path`, int, str, str]]]: Map a
linter to a function that parses a line of the output of the linter."""


def parse_output(stdout: str, linter: str) -> FileLineInfo:
"""Parse the whole output of a linter.

Args:
stdout (str): `stdout` of the linter.
linter (str): Name of the linter.

Returns:
:obj:`FileLineInfo`: Linter errors.
"""
errors: FileLineInfo = defaultdict(lambda: defaultdict(list))

for line in stdout.splitlines():
# Parse line in the output of `mypy`. If it cannot be parsed, just skip it.
try:
path, line_number, status, message = parse_line[linter](line)
except ValueError:
continue

# We only need to validate errors.
if status.lower().strip() != "error":
continue

errors[Path(path)][line_number].append(message)

return errors


def run_linter(linter: str) -> str:
"""Run a linter and get the `stdout`.

Args:
linter (str): Name of the linter.

Returns:
str: `stdout`.
"""
p = subprocess.Popen(
[linter, source_dir],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = p.communicate()
assert stderr == b"", "`stderr` must be empty."
return stdout.decode()


def get_missed(
errors: FileLineInfo,
assertions: FileLineInfo,
match: Callable[[str, str], bool],
) -> FileLineInfo:
"""Find unasserted errors.

Args:
error (:obj:`FileLineInfo`): Errors.
assertions (:obj:`FileLineInfo`): Assertions.
match (Callable[[str, str], bool]): Function that takes in an error and an
assertion and checks whether the assertion asserts the error.

Returns:
:obj:`FileLineInfo`: Unasserted errors.
"""
missed_errors: FileLineInfo = defaultdict(lambda: defaultdict(list))
for path, path_errors in errors.items():
# If there are no assertions for `path`, report all errors as missing.
if path not in assertions:
for line_number in errors[path]:
missed_errors[path][line_number].extend(errors[path][line_number])
continue
for line_number, line_errors in path_errors.items():
# If there are no assertions for `line_number`, report all errors as
# missing.
if line_number not in assertions[path]:
missed_errors[path][line_number].extend(errors[path][line_number])
continue
# Check every error for the line.
for e in line_errors:
if not any(match(e, a) for a in assertions[path][line_number]):
missed_errors[path][line_number].append(e)
return missed_errors


def check_linter(source_dir: Path, linter: str) -> bool:
"""Run a linter and check if all errors were asserted and all assertions yielded
errors. If not, print an overview of what was missed.

Args:
source_dir (:class:`pathlib.Path`): Source directory.
linter (str): Name of the linter.

Returns:
bool: `True` if nothing was missed, else `False`.
"""
stdout = run_linter(linter)

errors = parse_output(stdout, linter)
assertions = parse_assertions(source_dir, linter)

missed_errors = get_missed(
errors,
assertions,
lambda e, a: a.strip().lower() in e.strip().lower(),
)
missed_assertions = get_missed(
assertions,
errors,
lambda a, e: a.strip().lower() in e.strip().lower(),
)

for path, path_errors in missed_errors.items():
for line_number, line_errors in path_errors.items():
for error in line_errors:
print(f"{linter}:{path}:{line_number}: Error: {error.strip()}")

for path, path_assertions in missed_assertions.items():
for line_number, line_assertions in path_assertions.items():
for assertion in line_assertions:
print(
f"{linter}:{path}:{line_number}: "
f"Missed assertion: {assertion.strip()}"
)

return not (missed_errors or missed_assertions)


if __name__ == "__main__":
source_dir = Path(sys.argv[1]) # Files that must be validated
status = True
status |= check_linter(source_dir, "mypy")
status |= check_linter(source_dir, "pyright")
if status:
print("All OK!")
exit(0 if status else 1)
1 change: 1 addition & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ chapters:
- file: classes
- file: keyword_arguments
- file: comparison
- file: integration
- file: advanced_usage
sections:
- file: conversion_promotion
Expand Down
34 changes: 34 additions & 0 deletions docs/integration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Integration with Linters and `mypy`

Plum's integration with linters and `mypy` is unfortunately limited.
Properly supporting multiple dispatch in these tools is challenging for a [variety of reasons](https://github.com/python/mypy/issues/11727).
In this section, we collect various patterns in which Plum plays nicely with type checking.

## Overload Support

At the moment, the only know pattern in which Plum produces `mypy`-compliant code uses `typing.overload`.

An example is as follows:

```python
from plum import dispatch, overload


@overload
def f(x: int) -> int:
return x


@overload
def f(x: str) -> str:
return x


@dispatch
def f(x):
pass
```

In the above, for Python versions prior to 3.11, `plum.overload` is `typing_extensions.overload`.
For this pattern to work in Python versions prior to 3.11, you must use `typing_extensions.overload`, not `typing.overload`.
By importing `overload` from `plum`, you will always use the correct `overload`.
1 change: 1 addition & 0 deletions plum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .autoreload import * # noqa: F401, F403
from .dispatcher import * # noqa: F401, F403
from .function import * # noqa: F401, F403
from .overload import overload # noqa: F401
from .parametric import * # noqa: F401, F403
from .promotion import * # noqa: F401, F403
from .resolver import * # noqa: F401, F403
Expand Down
23 changes: 16 additions & 7 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union

from .function import Function
from .overload import get_overloads
from .signature import Signature
from .util import TypeHint, get_class, is_in_class
from .util import Callable, TypeHint, get_class, is_in_class

__all__ = ["Dispatcher", "dispatch", "clear_all_cache"]

T = TypeVar("T", bound=Callable[..., Any])


class Dispatcher:
"""A namespace for functions.
Expand All @@ -20,11 +23,7 @@ def __init__(self):
self.functions: Dict[str, Function] = {}
self.classes: Dict[str, Dict[str, Function]] = {}

def __call__(
self,
method: Optional[Callable] = None,
precedence: int = 0,
) -> Callable:
def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T:
"""Decorator to register for a particular signature.

Args:
Expand All @@ -36,6 +35,16 @@ def __call__(
if method is None:
return lambda m: self(m, precedence=precedence)

# If `method` has overloads, assume that those overloads need to be registered
# and that `method` is not an implementation.
overloads = get_overloads(method)
if overloads:
for overload_method in overloads:
# All `f` returned by `self._add_method` are the same.
f = self._add_method(overload_method, None, precedence=precedence)
# We do not need to register `method`, because it is not an implementation.
return f

# The signature will be automatically derived from `method`, so we can safely
# set the signature argument to `None`.
return self._add_method(method, None, precedence=precedence)
Expand Down
2 changes: 1 addition & 1 deletion plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __doc__(self) -> Optional[str]:
"""
try:
self._resolve_pending_registrations()
except NameError:
except NameError: # pragma: specific no cover 3.7 3.8 3.9
# When `staticmethod` is combined with
# `from __future__ import annotations`, in Python 3.10 and higher
# `staticmethod` will attempt to inherit `__doc__` (see
Expand Down
Loading
Loading