Skip to content

Commit

Permalink
test/mypy: types in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rwe committed Jan 10, 2022
1 parent 0595a93 commit 5047ae3
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 75 deletions.
130 changes: 95 additions & 35 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
import textwrap
import subprocess
import traceback
from types import TracebackType
from typing import (
Any,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Type,
Union,
TYPE_CHECKING,
)
from gitrevise.odb import Repository
from gitrevise.utils import sh_path
from contextlib import contextmanager
Expand All @@ -17,8 +29,15 @@
import dummy_editor


if TYPE_CHECKING:
from _typeshed import StrPath


@pytest.fixture(autouse=True)
def hermetic_seal(tmp_path_factory, monkeypatch):
def hermetic_seal(
tmp_path_factory: pytest.TempPathFactory,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Lock down user git configuration
home = tmp_path_factory.mktemp("home")
xdg_config_home = home / ".config"
Expand Down Expand Up @@ -59,23 +78,27 @@ def hermetic_seal(tmp_path_factory, monkeypatch):


@pytest.fixture
def repo(hermetic_seal):
def repo(hermetic_seal: None) -> Generator[Repository, None, None]:
with Repository() as repo:
yield repo


@pytest.fixture
def short_tmpdir():
def short_tmpdir() -> Generator[py.path.local, None, None]:
with tempfile.TemporaryDirectory() as tdir:
yield py.path.local(tdir)


@contextmanager
def in_parallel(func, *args, **kwargs):
def in_parallel(
func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> "Generator[None, None, None]":
class HelperThread(Thread):
exception = None
exception: Optional[Exception] = None

def run(self):
def run(self) -> None:
try:
func(*args, **kwargs)
except Exception as exc:
Expand All @@ -93,7 +116,7 @@ def run(self):
raise thread.exception


def bash(command):
def bash(command: str) -> None:
# Use a custom environment for bash commands so commits with those commands
# have unique names and emails.
env = dict(
Expand All @@ -106,7 +129,7 @@ def bash(command):
subprocess.run([sh_path(), "-ec", textwrap.dedent(command)], check=True, env=env)


def changeline(path, lineno, newline):
def changeline(path: "StrPath", lineno: int, newline: bytes) -> None:
with open(path, "rb") as f:
lines = f.readlines()
lines[lineno] = newline
Expand All @@ -115,15 +138,23 @@ def changeline(path, lineno, newline):


# Run the main entry point for git-revise in a subprocess.
def main(args, **kwargs):
kwargs.setdefault("check", True)
def main(
args: Sequence[str],
cwd: Optional["StrPath"] = None,
input: Optional[bytes] = None,
check: bool = True,
) -> "subprocess.CompletedProcess[bytes]":
cmd = [sys.executable, "-m", "gitrevise", *args]
print("Running", cmd, kwargs)
return subprocess.run(cmd, **kwargs)
print("Running", cmd, dict(cwd=cwd, input=input, check=check))
return subprocess.run(cmd, cwd=cwd, input=input, check=check)


@contextmanager
def editor_main(args, **kwargs):
def editor_main(
args: Sequence[str],
cwd: Optional["StrPath"] = None,
input: Optional[bytes] = None,
) -> "Generator[Editor, None, None]":
with pytest.MonkeyPatch().context() as m, Editor() as ed:
editor_cmd = " ".join(
shlex.quote(p)
Expand All @@ -135,11 +166,12 @@ def editor_main(args, **kwargs):
)
m.setenv("GIT_EDITOR", editor_cmd)

def main_wrapper():
def main_wrapper() -> Optional["subprocess.CompletedProcess[bytes]"]:
try:
return main(args, **kwargs)
return main(args, cwd=cwd, input=input)
except Exception as e:
ed.exception = e
return None
finally:
if not ed.exception:
ed.exception = Exception(
Expand All @@ -152,14 +184,23 @@ def main_wrapper():


class EditorFile(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
indata: Optional[bytes]
outdata: Optional[bytes]
server: "Editor"

def __init__(
self,
request: bytes,
client_address: Tuple[str, int],
server: "Editor",
) -> None:
self.response_ready = Event()
self.indata = None
self.outdata = None
self.exception = None
super().__init__(*args, **kwargs)
super().__init__(request=request, client_address=client_address, server=server)

def do_POST(self):
def do_POST(self) -> None:
length = int(self.headers.get("content-length"))
self.indata = self.rfile.read(length)
self.outdata = b""
Expand All @@ -174,41 +215,49 @@ def do_POST(self):
finally:
self.server.current = None

def send_editor_reply(self, status, data):
def send_editor_reply(self, status: int, data: bytes) -> None:
assert not self.response_ready.is_set(), "already replied?"
self.send_response(status)
self.send_header("content-length", len(data))
self.send_header("content-length", str(len(data)))
self.end_headers()
self.wfile.write(data)
self.response_ready.set()

# Ensure the handle thread has shut down
self.server.handle_thread.join()
self.server.handle_thread = None
if self.server.handle_thread is not None:
self.server.handle_thread.join()
self.server.handle_thread = None
assert self.server.current is None

def startswith(self, text):
def startswith(self, text: bytes) -> bool:
assert self.indata is not None
return self.indata.startswith(text)

def startswith_dedent(self, text):
def startswith_dedent(self, text: str) -> bool:
return self.startswith(textwrap.dedent(text).encode())

def equals(self, text):
def equals(self, text: bytes) -> bool:
return self.indata == text

def equals_dedent(self, text):
def equals_dedent(self, text: str) -> bool:
return self.equals(textwrap.dedent(text).encode())

def replace_dedent(self, text):
def replace_dedent(self, text: Union[str, bytes]) -> None:
if isinstance(text, str):
text = textwrap.dedent(text).encode()
self.outdata = text

def __enter__(self):
def __enter__(self) -> "EditorFile":
return self

def __exit__(self, etype, evalue, tb):
def __exit__(
self,
etype: Optional[Type[BaseException]],
evalue: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
if etype is None:
assert self.outdata
self.send_editor_reply(200, self.outdata)
else:
exc = "".join(traceback.format_exception(etype, evalue, tb)).encode()
Expand All @@ -217,12 +266,18 @@ def __exit__(self, etype, evalue, tb):
except:
pass

def __repr__(self):
def __repr__(self) -> str:
return f"<EditorFile {self.indata!r}>"


class Editor(HTTPServer):
def __init__(self):
request_ready: Event
handle_thread: Optional[Thread]
current: Optional[EditorFile]
exception: Optional[Exception]
timeout: int

def __init__(self) -> None:
# Bind to a randomly-allocated free port.
super().__init__(("127.0.0.1", 0), EditorFile)
self.request_ready = Event()
Expand All @@ -231,7 +286,7 @@ def __init__(self):
self.exception = None
self.timeout = 10

def next_file(self):
def next_file(self) -> EditorFile:
assert self.handle_thread is None
assert self.current is None

Expand All @@ -249,13 +304,18 @@ def next_file(self):
assert self.current
return self.current

def is_idle(self):
def is_idle(self) -> bool:
return self.handle_thread is None and self.current is None

def __enter__(self):
def __enter__(self) -> "Editor":
return self

def __exit__(self, etype, value, tb):
def __exit__(
self,
etype: Optional[Type[BaseException]],
value: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
try:
# Only assert if we're not already raising an exception.
if etype is None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from conftest import *


def test_cut(repo):
def test_cut(repo: Repository) -> None:
bash(
"""
echo "Hello, World" >> file1
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_cut(repo):
assert new_uu == prev_uu


def test_cut_root(repo):
def test_cut_root(repo: Repository) -> None:
bash(
"""
echo "Hello, World" >> file1
Expand Down
Loading

0 comments on commit 5047ae3

Please sign in to comment.