diff --git a/changelogs/fragments/113-subprocess_util_loopback.yaml b/changelogs/fragments/113-subprocess_util_loopback.yaml new file mode 100644 index 0000000..8c86e29 --- /dev/null +++ b/changelogs/fragments/113-subprocess_util_loopback.yaml @@ -0,0 +1,6 @@ +--- +minor_changes: + - "``subprocess_util.async_log_run()``, ``subprocess_util.log_run()``, and + the corresponding functions in ``venv`` now support passing generic + callback functions for ``stdout_loglevel`` and ``stderr_loglevel`` + (https://github.com/ansible-community/antsibull-core/pull/113)." diff --git a/pyproject.toml b/pyproject.toml index 9d99a56..7874359 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ typing = [ "pyre-check >= 0.9.17", "types-aiofiles", "types-PyYAML", + "typing-extensions", ] dev = [ # Used by nox sessions diff --git a/src/antsibull_core/subprocess_util.py b/src/antsibull_core/subprocess_util.py index b2acbd9..8a11274 100644 --- a/src/antsibull_core/subprocess_util.py +++ b/src/antsibull_core/subprocess_util.py @@ -11,9 +11,12 @@ import asyncio import subprocess +import sys from asyncio.exceptions import IncompleteReadError, LimitOverrunError -from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, cast +from collections.abc import Awaitable, Callable, Sequence +from functools import partial +from inspect import isawaitable +from typing import TYPE_CHECKING, Any, TypeVar, cast from antsibull_core.logging import log @@ -22,15 +25,36 @@ from _typeshed import StrOrBytesPath from twiggy.logger import Logger as TwiggyLogger # type: ignore[import] + from typing_extensions import ParamSpec, TypeAlias + + _T = TypeVar("_T") + _P = ParamSpec("_P") mlog = log.fields(mod=__name__) CalledProcessError = subprocess.CalledProcessError +OutputCallbackType: TypeAlias = "Callable[[str], Any] | Callable[[str], Awaitable[Any]]" + +stdout_callback = print +stderr_callback = partial(print, file=sys.stderr) + + +async def _sync_or_async( + func: Callable[_P, Awaitable[_T]] | Callable[_P, _T], + /, + *args: _P.args, + **kwargs: _P.kwargs, +) -> _T: + out = func(*args, **kwargs) + if isawaitable(out): + return await out + return cast("_T", out) + async def _stream_log( name: str, - callback: Callable[[str], Any] | None, + callback: OutputCallbackType | None, stream: asyncio.StreamReader, errors: str, ) -> str: @@ -58,7 +82,7 @@ async def _stream_log( break text = line.decode("utf-8", errors=errors) if callback: - callback(f"{name}: {text.strip()}") + await _sync_or_async(callback, f"{name}{text.strip()}") lines.append(text) return "".join(lines) @@ -66,8 +90,8 @@ async def _stream_log( async def async_log_run( args: Sequence[StrOrBytesPath], logger: TwiggyLogger | StdLogger | None = None, - stdout_loglevel: str | None = None, - stderr_loglevel: str | None = "debug", + stdout_loglevel: str | OutputCallbackType | None = None, + stderr_loglevel: str | OutputCallbackType | None = "debug", check: bool = True, *, errors: str = "strict", @@ -83,8 +107,12 @@ async def async_log_run( :param logger: Logger in which to log the command. Can be a `twiggy.logger.Logger` or a stdlib `logger.Logger`. - :param stdout_loglevel: Which level to use to log stdout. `None` disables logging. - :param stderr_loglevel: Which level to use to log stderr. `None` disables logging. + :param stdout_loglevel: + Which level to use to log stdout or a generic callback function. + `None` disables logging. + :param stderr_loglevel: + Which level to use to log stdout or a generic callback function. + `None` disables logging. :param check: Whether to raise a `subprocess.CalledProcessError` when the command returns a non-zero exit code @@ -93,12 +121,21 @@ async def async_log_run( """ logger = logger or mlog stdout_logfunc: Callable[[str], Any] | None = None + stdout_log_prefix = "stdout: " if stdout_loglevel: - stdout_logfunc = getattr(logger, stdout_loglevel) + if callable(stdout_loglevel): + stdout_logfunc = stdout_loglevel + stdout_log_prefix = "" + else: + stdout_logfunc = getattr(logger, stdout_loglevel) stderr_logfunc: Callable[[str], Any] | None = None + stderr_log_prefix = "stderr: " if stderr_loglevel: - stderr_logfunc = getattr(logger, stderr_loglevel) - + if callable(stderr_loglevel): + stderr_logfunc = stderr_loglevel + stderr_log_prefix = "" + else: + stderr_logfunc = getattr(logger, stderr_loglevel) logger.debug(f"Running subprocess: {args!r}") kwargs["stdout"] = asyncio.subprocess.PIPE kwargs["stderr"] = asyncio.subprocess.PIPE @@ -108,7 +145,7 @@ async def async_log_run( # proc.stdout and proc.stderr won't be None with PIPE, hence the cast() asyncio.create_task( _stream_log( - "stdout", + stdout_log_prefix, stdout_logfunc, cast(asyncio.StreamReader, proc.stdout), errors, @@ -116,7 +153,7 @@ async def async_log_run( ), asyncio.create_task( _stream_log( - "stderr", + stderr_log_prefix, stderr_logfunc, cast(asyncio.StreamReader, proc.stderr), errors, @@ -136,8 +173,8 @@ async def async_log_run( def log_run( args: Sequence[StrOrBytesPath], logger: TwiggyLogger | StdLogger | None = None, - stdout_loglevel: str | None = None, - stderr_loglevel: str | None = "debug", + stdout_loglevel: str | OutputCallbackType | None = None, + stderr_loglevel: str | OutputCallbackType | None = "debug", check: bool = True, **kwargs, ) -> subprocess.CompletedProcess[str]: @@ -151,4 +188,11 @@ def log_run( ) -__all__ = ("async_log_run", "log_run", "CalledProcessError") +__all__ = ( + "async_log_run", + "log_run", + "CalledProcessError", + "stdout_callback", + "stderr_callback", + "OutputCallbackType", +) diff --git a/src/antsibull_core/venv.py b/src/antsibull_core/venv.py index d55564e..cccf5ce 100644 --- a/src/antsibull_core/venv.py +++ b/src/antsibull_core/venv.py @@ -98,8 +98,8 @@ async def async_log_run( self, args: Sequence[StrPath], logger: TwiggyLogger | StdLogger | None = None, - stdout_loglevel: str | None = None, - stderr_loglevel: str | None = "debug", + stdout_loglevel: str | subprocess_util.OutputCallbackType | None = None, + stderr_loglevel: str | subprocess_util.OutputCallbackType | None = "debug", check: bool = True, *, errors: str = "strict", @@ -133,8 +133,8 @@ def log_run( self, args: Sequence[StrPath], logger: TwiggyLogger | StdLogger | None = None, - stdout_loglevel: str | None = None, - stderr_loglevel: str | None = "debug", + stdout_loglevel: str | subprocess_util.OutputCallbackType | None = None, + stderr_loglevel: str | subprocess_util.OutputCallbackType | None = "debug", check: bool = True, *, errors: str = "strict", diff --git a/tests/units/test_subprocess_util.py b/tests/units/test_subprocess_util.py index b7ffa5e..22890f4 100644 --- a/tests/units/test_subprocess_util.py +++ b/tests/units/test_subprocess_util.py @@ -72,3 +72,20 @@ def test_log_run_long_line(count: int) -> None: assert proc.args == args assert proc.returncode == 0 assert proc.stdout == ("\u0000" * count) + "\nfoo\n" + + +def test_log_run_callback() -> None: + stdout_lines: list[str] = [] + stderr_lines: list[str] = [] + + async def add_to_stderr(string: str, /) -> None: + stderr_lines.append(string) + + antsibull_core.subprocess_util.log_run( + ["sh", "-c", "echo Never; echo gonna >&2; echo give"], + None, + stdout_lines.append, + add_to_stderr, + ) + assert stdout_lines == ["Never", "give"] + assert stderr_lines == ["gonna"]