diff --git a/docs/index.rst b/docs/index.rst
index e631a35..b7acc26 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,57 +1,141 @@
 quicken
-===================================
+=======
 
-Quicken is a Python package that enables CLI-based tools to start more quickly.
+Quicken is a library that helps Python applications start more quickly, with a focus on doing the "right thing" when
+wrapping CLI-based tools.
 
-When added to a command-line tool, Quicken starts a server transparently the first time the tool is invoked.
-After the server is running, it is responsible for executing commands. This is fast because imports happen once at
-server start.
+When a quicken-wrapped command is executed the first time, an application server will be started. If the server
+is already up then the command will be executed in a ``fork``-ed child, which avoids the overhead of loading
+libraries.
 
-The library is transparent for users. Every time a command is run all context is sent to the server, including:
+Quicken only speeds up applications on Unix platforms, but falls back to executing commands directly
+on non-Unix platforms.
 
-* arguments
+The library tries to be transparent for applications. Every time a command is run, all context is sent from the
+client to the server, including:
+
+* command-line arguments
 * current working directory
 * environment
 * umask
 * file descriptors for stdin/stdout/stderr
 
-Usage:
-
-Assume your application is ``my_app`` and your original CLI entrypoint is ``my_app.cli.cli``. Create a file ``my_app/cli_wrapper.py``, with contents:
-
-.. code-block:: python
-
-    from quicken import cli_factory
+``quicken.script``
+==================
+
+To speed up your command-line tool, add quicken as a dependency, then use ``quicken.script`` in your
+``console_scripts`` (or equivalent).
 
-    @cli_factory('my_app')
-    def main():
-        # Import your existing command-line entrypoint.
-        # This is the expensive operation that only happens once.
-        from .cli import cli
-        # Return it.
-        return cli
+If your existing entry point is ``my_app.cli:main``, then you would use ``quicken.script:my_app.cli._.main``.
+The ``._.`` separates the module to import (``my_app.cli``) from the attribute path of the callable to
+invoke (``main``).
 
-Adapt ``setup.py``:
+For example, if using setuptools (``setup.py``):
 
 .. code-block:: python
 
     setup(
         ...
         entry_points={
-            'console_scripts': ['my-command=my_app.cli_wrapper:cli']
+            'console_scripts': [
+                'my-command=my_app.cli:main',
+                # With quicken
+                'my-commandc=quicken.script:my_app.cli._.main',
+            ],
         },
         ...
     )
 
-If you have ``my_app/__main__.py``, it should look like:
+If using poetry:
+
+.. code-block:: toml
+
+    [tool.poetry.scripts]
+    poetry = 'poetry:console.run'
+    # With quicken
+    poetryc = 'quicken.script:poetry._.console.run'
+
+If using flit:
+
+.. code-block:: toml
+
+    [tool.flit.scripts]
+    flit = "flit:main"
+    # With quicken
+    flitc = "quicken.script:flit._.main"
+
+
+``quicken`` command
+===================
+
+The ``quicken`` command can be used with basic scripts and command-line tools that do not have quicken built in.
+
+Given a script ``script.py``, like:
 
 .. code-block:: python
 
-    from .cli_wrapper import main
+    import click
+    import requests
+
+    ...
+
+    @click.command()
+    def main():
+        """My script."""
+        ...
+
+    if __name__ == '__main__':
+        main()
+
+
+running ``quicken -f script.py arg1 arg2`` once will start the application server and then run ``main()``.
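+
+Conceptually (a simplified sketch of the behavior, not the actual implementation), quicken splits ``script.py``
+at the ``if __name__ == '__main__':`` block:
+
+.. code-block:: python
+
+    # Everything above the block is executed once, in the process that starts
+    # the application server - this is where the slow imports happen.
+    import click
+    import requests
+
+    @click.command()
+    def main():
+        """My script."""
+        ...
+
+    # The block itself is executed for every invocation, in a forked command
+    # process on the server.
+    if __name__ == '__main__':
+        main()
+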
+Running the command again like ``quicken -f script.py arg2 arg3`` will run the command on the server, and should be faster.
+
+If ``script.py`` is changed then the server will be restarted the next time the command is run.
+
+
+Differences and restrictions
+============================
+
+The library tries to be transparent for applications, but the execution environment cannot be exactly the same.
+Specifically, here is the behavior you can expect:
+
+* ``sys.argv`` is set to the list of arguments of the client
+* ``sys.stdin``, ``sys.stdout``, and ``sys.stderr`` are sent from the client to the command process; any console loggers
+  are re-initialized with the new streams
+* ``os.environ`` is copied from the client to the command process
+* the working directory is changed to the cwd of the client process
+
+The setup above is guaranteed to be done before the registered script function
+is invoked.
+
+In addition:
+
+* signals sent to the client that can be forwarded are sent to the command process
+* for SIGTSTP (C-z at a terminal) and SIGTTIN, we send SIGSTOP to the command process and then stop the client process
+* for SIGKILL, which cannot be intercepted, the server recognizes when the connection to the client is broken and will
+  kill the command process soon after
+* when the command runner exits, the client exits with the same exit code
+* if a client and the server differ in group id or supplementary group ids then a new
+  server process is started before the command is run
+
+While the registered script function is executed in a command process, the initial import of
+the module is done by the client that starts the server. For that reason, there are several things that
+should be avoided outside of the registered script function:
+
+1. starting threads - because we fork to create the command runner process, threads created before the fork may
+   cause undesirable effects
+2. reading configuration based on environment, arguments, or current working directory - if done when the module is
+   imported, it will capture the values of the client that started the server
+3. setting signal handlers - these would only be set by the first client starting the server, and they are
+   overridden to forward signals to the command runner process
+4. starting sub-processes
+5. reading plugin information (from e.g. ``pkg_resources``) - this happens only at server start time,
+   and not when the command is actually run
+
+Currently the following is unsupported at any point:
+
+* ``atexit`` handlers - they will not be run at the end of the command process
+* ``setuid`` or ``setgid``
 
-    main()
 
 ..
toctree:: :maxdepth: 2 diff --git a/docs/requirements.txt b/docs/requirements.txt index ef30af4..b98f9e1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -8,6 +8,7 @@ certifi==2019.3.9 ch==0.1.2 chardet==3.0.4 coverage==4.5.3 +demandimport==0.3.4 docutils==0.14 fasteners==0.14.1 idna==2.8 @@ -23,7 +24,7 @@ packaging==19.0 pathtools==0.1.2 pid==2.2.3 pluggy==0.9.0 -psutil==5.6.1 +psutil==5.6.2 py==1.8.0 py-cpuinfo==5.0.0 pydevd==1.6.0 diff --git a/examples/slow_start/app.py b/examples/slow_start/app.py index 7de136a..ab75163 100644 --- a/examples/slow_start/app.py +++ b/examples/slow_start/app.py @@ -14,3 +14,7 @@ def cli(): print('cli()') print('cli2()') logger.info('cli()') + + +if __name__ == '__main__': + cli() diff --git a/poetry.lock b/poetry.lock index 9769e02..4f1bce4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -232,12 +232,12 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "0.9.0" [[package]] -category = "main" +category = "dev" description = "Cross-platform lib for process and system monitoring in Python." name = "psutil" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "5.6.1" +version = "5.6.2" [[package]] category = "dev" @@ -500,7 +500,7 @@ python-versions = "*" version = "1.11.1" [metadata] -content-hash = "8292428c85fa528b53b6c125419ee7ddcbb0ff87e89796b8f9c5e5ea9d373fe8" +content-hash = "772a6ef0a6b4cdb74f7f8213ed4b51821ff463d43dac824c74ae2876c2fa5f59" python-versions = "^3.7" [metadata.hashes] @@ -530,7 +530,7 @@ packaging = ["0c98a5d0be38ed775798ece1b9727178c4469d9c3b4ada66e8e6b7849f8732af", pathtools = ["7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"] pid = ["077da788630394adce075c88f4a087bcdb27d98cab67eb9046ebcfeedfc1194d"] pluggy = ["19ecf9ce9db2fce065a7a0586e07cfb4ac8614fe96edf628a264b1c70116cf8f", "84d306a647cc805219916e62aab89caa97a33a1dd8c342e87a37f91073cd4746"] -psutil = ["23e9cd90db94fbced5151eaaf9033ae9667c033dffe9e709da761c20138d25b6", "27858d688a58cbfdd4434e1c40f6c79eb5014b709e725c180488ccdf2f721729", "354601a1d1a1322ae5920ba397c58d06c29728a15113598d1a8158647aaa5385", "9c3a768486194b4592c7ae9374faa55b37b9877fd9746fb4028cb0ac38fd4c60", "c1fd45931889dc1812ba61a517630d126f6185f688eac1693171c6524901b7de", "d463a142298112426ebd57351b45c39adb41341b91f033aa903fa4c6f76abecc", "e1494d20ffe7891d07d8cb9a8b306c1a38d48b13575265d090fc08910c56d474", "ec4b4b638b84d42fc48139f9352f6c6587ee1018d55253542ee28db7480cc653", "fa0a570e0a30b9dd618bffbece590ae15726b47f9f1eaf7518dfb35f4d7dcd21"] +psutil = ["206eb909aa8878101d0eca07f4b31889c748f34ed6820a12eb3168c7aa17478e", "649f7ffc02114dced8fbd08afcd021af75f5f5b2311bc0e69e53e8f100fe296f", "6ebf2b9c996bb8c7198b385bade468ac8068ad8b78c54a58ff288cd5f61992c7", "753c5988edc07da00dafd6d3d279d41f98c62cd4d3a548c4d05741a023b0c2e7", "76fb0956d6d50e68e3f22e7cc983acf4e243dc0fcc32fd693d398cb21c928802", "828e1c3ca6756c54ac00f1427fdac8b12e21b8a068c3bb9b631a1734cada25ed", "a4c62319ec6bf2b3570487dd72d471307ae5495ce3802c1be81b8a22e438b4bc", "acba1df9da3983ec3c9c963adaaf530fcb4be0cd400a8294f1ecc2db56499ddd", "ef342cb7d9b60e6100364f50c57fa3a77d02ff8665d5b956746ac01901247ac4"] py = ["64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", "dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53"] py-cpuinfo = ["2cf6426f776625b21d1db8397d3297ef7acfa59018f02a8779123f3190f18500"] pydevd = ["2308d15cc887b12917b137789ee7a2b3a9cef02f8636eea02ad952821a4a39ce", "23af20bdbc677dac38bfe92e6f8845063a2ed8d98989ac0820e2e9a84b837079", 
"43b69eb4f567403d91d8c57493b5b97525cd30da3ab00e6d0a2cc7ccd5312eb9", "5a0aa83181655e3acbd82f46b43a38815e84df0421b0230127e583d0731c0109", "5f6fdf411c2194c8873f032e30ce191b372feb94589b193cce5bcd5c667a05da", "840244b0d79d0cb7beb6782882a795c6728a34db2d7f7848d98d197a1b13c77c", "8726342890bd2b36cc471f8c8aabf694ffcdc066cb66528837d0394661620784", "cebd35cf33f765fdb1f70f8fa705a317d3258761633f5ee051b8a1c0b354d47a", "d0dbe33663fdfe186a01790a5310d1d9420c53721ab4ccd468e4236a51888fe8", "e87374d9146b662554ddb775d417aba9d2f2da3f2ff0a3a0fec7e0b1867dafc2", "feab08a1df822c653891408ad20de1762ce56cbd2b47cf866c13eb929d494336"] diff --git a/pyproject.toml b/pyproject.toml index 9ec1077..f1527ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ repository = "https://github.com/chrahunt/quicken" [tool.poetry.dependencies] python = "^3.7" pid = "^2.2" -psutil = "^5.4" tblib = "^1.3" fasteners = "^0.14.1" @@ -32,6 +31,11 @@ sphinxcontrib-trio = "^1.0" ch = "^0.1.2" pytest-benchmark = "^3.2" +[tool.poetry.scripts] +quicken = "quicken._cli:main" +quicken-c = "quicken.script:quicken._cli._.main" +quicken-ctl = "quicken.ctl_script:quicken._cli._.main" + [build-system] requires = ["poetry>=0.12"] build-backend = "poetry.masonry.api" diff --git a/quicken/__init__.py b/quicken/__init__.py index 2213888..e294182 100644 --- a/quicken/__init__.py +++ b/quicken/__init__.py @@ -1,8 +1,6 @@ -from ._decorator import quicken - +""" +""" +# Top-level package should be free from any unnecessary imports, since they are +# unconditional. __version__ = '0.1.0' - - -class QuickenError(Exception): - pass diff --git a/quicken/__main__.py b/quicken/__main__.py new file mode 100644 index 0000000..8298230 --- /dev/null +++ b/quicken/__main__.py @@ -0,0 +1,5 @@ +from ._cli import main + + +if __name__ == '__main__': + main() diff --git a/quicken/_cli.py b/quicken/_cli.py index 6364aef..d16a319 100644 --- a/quicken/_cli.py +++ b/quicken/_cli.py @@ -1,297 +1,374 @@ -"""CLI wrapper interface for starting/using server process. +"""CLI for running Python code in an application server. + +Argument types: +[command, '--', *args] +['-f', file, '--', *args] +['-h'] +['-c', code, '--', *args] +['-m', module, '--', *args] +['-s', script, '--', *args] + +The general issue this solves is: given a single piece of executable code, split +it into two parts: + +1. The part useful to preload that can be executed without issues +2. The part that should not be executed unless explicitly requested + +MainProvider: Splitting the code and executing the code matched to part 1 +Main: part 2 + +For the invoked code we try to align with the equivalent Python behavior, where +it makes sense: + +1. For files: + 1. Python sets __file__ to the path as provided to the Python interpreter, + but since initial import occurs in a directory that may be different than + the runtime execution directory, we normalize __file__ to be an absolute + path - any usages of __file__ within methods would then be correct. + 2. Python sets sys.argv[0] to the path as provided to the Python interpreter, + this behavior is OK since if relative it will be relative to cwd which will + be set by the time the if __name__ == '__main__' block is called. +1. For modules: + 1. Python sets __file__ to the absolute path of the file, we should do the same. + 2. Python sets sys.argv[0] to the absolute path of the file, we should do the same. 
""" -import json -import logging -import multiprocessing +from __future__ import annotations + +from ._timings import report +report('start cli load') + +import argparse +import ast +import importlib.util import os +import stat +import sys + +try: + # Faster to import than hashlib if _sha512 is present. See e.g. python/cpython#12742 + from _sha512 import sha512 as _sha512 +except ImportError: + from hashlib import sha512 as _sha512 -from contextlib import contextmanager from functools import partial -from pathlib import Path -from typing import Callable, Optional - -from fasteners import InterProcessLock - -from . import QuickenError -from ._client import Client -from ._constants import socket_name, server_state_name -from ._protocol import ProcessState, Request, RequestTypes -from ._signal import blocked_signals, forwarded_signals, SignalProxy -from ._typing import JSONType -from ._xdg import cache_dir, chdir, RuntimeDir - - -logger = logging.getLogger(__name__) - - -MainFunction = Callable[[], Optional[int]] -MainProvider = Callable[[], MainFunction] - - -def check_res_ids(): - ruid, euid, suid = os.getresuid() - if not ruid == euid == suid: - raise QuickenError( - f'real uid ({ruid}), effective uid ({euid}), and saved uid ({suid})' - ' must be the same' - ) - - rgid, egid, sgid = os.getresgid() - if not rgid == egid == sgid: - raise QuickenError( - f'real gid ({rgid}), effective gid ({egid}), and saved gid ({sgid})' - ' must be the same' - ) - - -def need_server_reload(manager, reload_server, user_data): - server_state = manager.server_state - gid = os.getgid() - if gid != server_state['gid']: - logger.info('Reloading server due to gid change') - return True - - # XXX: Will not have the intended effect on macOS, see os.getgroups() for - # details. - groups = os.getgroups() - if set(groups) != set(server_state['groups']): - logger.info('Reloading server due to changed groups') - return True - - if reload_server: - previous_user_data = manager.user_data - if reload_server(previous_user_data, user_data): - logger.info('Reload requested by callback, stopping server.') - return True - - # TODO: Restart based on library version. - return False - - -def _server_runner_wrapper( - name: str, - main_provider: MainProvider, - # /, - *, - runtime_dir_path: Optional[str] = None, - log_file: Optional[str] = None, - server_idle_timeout: Optional[float] = None, - reload_server: Callable[[JSONType, JSONType], bool] = None, - user_data: JSONType = None, -) -> Optional[int]: - """Run operation in server identified by name, starting it if required. 
- """ - check_res_ids() - try: - json.dumps(user_data) - except TypeError as e: - raise QuickenError('user_data must be serializable') from e +from .lib._typing import MYPY_CHECK_RUNNING +from .lib._xdg import RuntimeDir - if log_file is None: - log_file = cache_dir(f'quicken-{name}') / 'server.log' - log_file = Path(log_file).absolute() +if MYPY_CHECK_RUNNING: + from typing import List - main_provider = partial(with_reset_authkey, main_provider) - runtime_dir = RuntimeDir(f'quicken-{name}', runtime_dir_path) +report('end cli load dependencies') - manager = CliServerManager( - main_provider, runtime_dir, log_file, server_idle_timeout, user_data + +def run(name, metadata, callback): + report('start quicken load') + from .lib import quicken + report('end quicken load') + + def reload(old_data, new_data): + return old_data != new_data + + log_file = os.environ.get('QUICKEN_LOG') + + decorator = quicken( + name, reload_server=reload, log_file=log_file, user_data=metadata ) - with manager.lock: - need_start = False - try: - client = manager.connect() - except ConnectionFailed as e: - logger.info('Failed to connect to server due to %s.', e) - need_start = True - else: - if need_server_reload(manager, reload_server, user_data): - manager.stop_server() - need_start = True + return decorator(callback)() - if need_start: - logger.info('Starting server') - manager.start_server() - client = manager.connect() - proxy = SignalProxy() - # We must block signals before requesting remote process start otherwise - # a user signal to the client may race with our ability to propagate it. - with blocked_signals(forwarded_signals): - state = ProcessState.for_current_process() - logger.debug('Requesting process start') - req = Request(RequestTypes.run_process, state) - response = client.send(req) - pid = response.contents - logger.debug('Process running with pid: %d', pid) - proxy.set_target(pid) +def is_main(node): + """Whether a node represents: + if __name__ == '__main__': + if '__main__' == __name__: + """ + if not isinstance(node, ast.If): + return False - logger.debug('Waiting for process to finish') - response = client.send(Request(RequestTypes.wait_process_done, None)) - client.close() - return response.contents + test = node.test + if not isinstance(test, ast.Compare): + return False -def reset_authkey(): - multiprocessing.current_process().authkey = os.urandom(32) + if len(test.ops) != 1 or not isinstance(test.ops[0], ast.Eq): + return False + if len(test.comparators) != 1: + return False -def with_reset_authkey(main_provider): - """Ensure that user code is not executed without an authkey set. - """ - main = main_provider() + left = test.left + right = test.comparators[0] - def inner(): - reset_authkey() - return main() + if isinstance(left, ast.Name): + name = left + elif isinstance(right, ast.Name): + name = right + else: + return False - return inner + if isinstance(left, ast.Str): + main = left + elif isinstance(right, ast.Str): + main = right + else: + return False + if name.id != '__name__': + return False -class ConnectionFailed(Exception): - pass + if not isinstance(name.ctx, ast.Load): + return False + if main.s != '__main__': + return False -class CliServerManager: - """Responsible for starting (if applicable) and connecting to the server. + return True - Race conditions are prevented by acquiring an exclusive lock on - {runtime_dir}/admin during connection and start. + +# XXX: May be nicer to use a Loader implemented for our purpose. 
+def parse_file(path: str): """ - def __init__( - self, factory_fn, runtime_dir: RuntimeDir, log_file, - server_idle_timeout, user_data): - """ - Args: - factory_fn: function that provides the server request handler - runtime_dir: runtime dir used for locks/socket - log_file: server log file - server_idle_timeout: idle timeout communicated to server if the - process of connecting results in server start - user_data: added to server state - """ - self._factory = factory_fn - self._runtime_dir = runtime_dir - self._log_file = log_file - self._idle_timeout = server_idle_timeout - self._user_data = user_data - self._lock = InterProcessLock('admin') - - def connect(self) -> Client: - """Attempt to connect to the server. - - Returns: - Client connected to the server - - Raises: - ConnectionFailed on connection failure (server not up or accepting) - """ - assert self._lock.acquired, 'connect must be called under lock.' - - with chdir(self._runtime_dir): - try: - return Client(socket_name) - except FileNotFoundError as e: - raise ConnectionFailed('File not found') from e - except ConnectionRefusedError as e: - raise ConnectionFailed('Connection refused') from e + Given a path, parse it into a + Parse a file into prelude and main sections. + + We assume that the "prelude" is anything before the first "if __name__ == '__main__'". + + Returns annotated code objects as expected. + """ + path = os.path.abspath(path) + with open(path, 'rb') as f: + text = f.read() + + root = ast.parse(text, filename=path) + for i, node in enumerate(root.body): + if is_main(node): + break + else: + raise RuntimeError('Must have if __name__ == "__main__":') + + prelude = ast.copy_location( + ast.Module(root.body[:i]), root + ) + main = ast.copy_location( + ast.Module(root.body[i:]), root + ) + prelude_code = compile(prelude, filename=path, dont_inherit=True, mode="exec") + main_code = compile(main, filename=path, dont_inherit=True, mode="exec") + # Shared context. + context = { + '__name__': '__main__', + '__file__': path, + } + prelude_func = partial(exec, prelude_code, context) + main_func = partial(exec, main_code, context) + return prelude_func, main_func + + +class PathHandler: + def __init__(self, path, args): + report('start handle_path()') + self._path_arg = path + + path = os.path.abspath(path) + self._path = path + + self._args = args + + real_path = os.path.realpath(path) + digest = _sha512(path.encode('utf-8')).hexdigest() + self._name = f'quicken.file.{digest}' + + stat_result = os.stat(real_path) + + self._metadata = { + 'path': path, + 'real_path': real_path, + 'ctime': stat_result[stat.ST_CTIME], + 'mtime': stat_result[stat.ST_MTIME], + } @property - def server_state(self): - with chdir(self._runtime_dir): - text = Path(server_state_name).read_text(encoding='utf-8') - return json.loads(text) + def argv(self): + return [self._path_arg, *self._args] @property - def user_data(self): - """Returns user data for the current server. - """ - return self.server_state['user_data'] - - def start_server(self): - """Start server as background process. - - This function only returns in the parent, not the background process. - - By the time this function returns it is safe to call connect(). - """ - assert self._lock.acquired, 'start_server must be called under lock.' - # XXX: Should have logging around this, for timing. - main = self._factory() - # Lazy import so we only take the time to import if we have to start - # the server. - # XXX: Should have logging around this, for timing. 
- from ._server import run - - with chdir(self._runtime_dir): - try: - os.unlink(socket_name) - except FileNotFoundError: - pass - - run( - main, - log_file=self._log_file, - runtime_dir=self._runtime_dir, - server_idle_timeout=self._idle_timeout, - user_data=self._user_data, - ) - - def stop_server(self): - assert self._lock.acquired, 'stop_server must be called under lock.' - from psutil import NoSuchProcess, Process - server_state = self.server_state - pid = server_state['pid'] - create_time = server_state['create_time'] + def name(self): + return self._name + + @property + def metadata(self): + return self._metadata + + def main(self): + report('start file processing') + prelude_code, main_code = parse_file(self._path) + # Execute everything before if __name__ == '__main__': + prelude_code() + report('end file processing') + # Pass main back to be executed by the server. + return main_code + + +# Adapted from https://github.com/python/cpython/blob/e42b705188271da108de42b55d9344642170aa2b/Lib/runpy.py#L101 +# with changes: +# * we do not actually want to retrieve the module code yet +def _get_module_details(mod_name, error=ImportError): + if mod_name.startswith("."): + raise error("Relative module names not supported") + pkg_name, _, _ = mod_name.rpartition(".") + if pkg_name: + # Try importing the parent to avoid catching initialization errors try: - process = Process(pid=pid) - except NoSuchProcess: - logger.debug( - f'Daemon reload requested but process with pid {pid}' - ' does not exist.') - return - - if process.create_time() != create_time: - logger.debug( - 'Daemon reload requested but start time does not match' - ' expected (probably new process re-using pid), skipping.') - return + __import__(pkg_name) + except ImportError as e: + # If the parent or higher ancestor package is missing, let the + # error be raised by find_spec() below and then be caught. But do + # not allow other errors to be caught. + if e.name is None or (e.name != pkg_name and + not pkg_name.startswith(e.name + ".")): + raise + # Warn if the module has already been imported under its normal name + existing = sys.modules.get(mod_name) + if existing is not None and not hasattr(existing, "__path__"): + from warnings import warn + msg = "{mod_name!r} found in sys.modules after import of " \ + "package {pkg_name!r}, but prior to execution of " \ + "{mod_name!r}; this may result in unpredictable " \ + "behaviour".format(mod_name=mod_name, pkg_name=pkg_name) + warn(RuntimeWarning(msg)) + try: + spec = importlib.util.find_spec(mod_name) + except (ImportError, AttributeError, TypeError, ValueError) as ex: + # This hack fixes an impedance mismatch between pkgutil and + # importlib, where the latter raises other errors for cases where + # pkgutil previously raised ImportError + msg = "Error while finding module specification for {!r} ({}: {})" + raise error(msg.format(mod_name, type(ex).__name__, ex)) from ex + if spec is None: + raise error("No module named %s" % mod_name) + if spec.submodule_search_locations is not None: + if mod_name == "__main__" or mod_name.endswith(".__main__"): + raise error("Cannot use package as __main__ module") try: - # We don't want to leave it to the server to remove the socket since - # we do not wait for it. - with chdir(self._runtime_dir): - os.unlink(socket_name) - except FileNotFoundError: - # No problem, if the file was removed at some point it doesn't - # impact us. - pass - # This will cause the server to stop accepting clients and start - # shutting down. 
It will wait for any still-running processes before - # stopping completely, but it does not consume any other resources that - # we are concerned with. - process.terminate() + pkg_main_name = mod_name + ".__main__" + return _get_module_details(pkg_main_name, error) + except error as e: + if mod_name not in sys.modules: + raise # No module loaded; being a package is irrelevant + raise error(("%s; %r is a package and cannot " + + "be directly executed") %(e, mod_name)) + loader = spec.loader + if loader is None: + raise error("%r is a namespace package and cannot be executed" + % mod_name) + return mod_name, spec + + +class ModuleHandler: + """ + For modules it can go several ways: + 1. top-level module which does have if __name__ == "__main__" (pytest) + 2. __main__ module which does have if __name__ == "__main__" (pip) + 3. __main__ module which does not have if __name__ == "__main__" (flit) + 4. __main__ module which does have if __name__ == "__main__" but does imports + underneath it (poetry) + + As a result, and since they are pretty small usually, we can be more flexible + with parsing - + 1. extract all top-level import statements or import statements under if __name__ == "__main__" + into their own unit + 2. execute the import unit at server start + 3. execute the rest + this will only cause problems if some tool has order-dependent imports underneath e.g. a + platform check and we can trace a warning if that is the case. + """ + def __init__(self, module_name, args): + report('start ModuleHandler') + self._module_name, self._spec = _get_module_details(module_name) - @property - @contextmanager - def lock(self): - with chdir(self._runtime_dir): - self._lock.acquire() - - # We initialize the Client and Listener classes without an authkey - # parameter since there's no way to pre-share the secret securely - # between processes not part of the same process tree. However, the - # internal Client/Listener used as part of - # multiprocessing.resource_sharer DOES initialize its own Client and - # Listener with multiprocessing.current_process().authkey. We must have - # some value so we use this dummy value. - multiprocessing.current_process().authkey = b'0' * 32 + def main(self): + loader = self._spec.loader try: - yield - finally: - self._lock.release() + code = loader.get_code(self._module_name) + except ImportError as e: + raise ImportError(format(e)) from e + if code is None: + raise ImportError("No code object available for %s" % self._module_name) + + +def parse_args(args=None): + #args = preprocess_args(args) + parser = argparse.ArgumentParser( + description=''' + Invoke Python commands in an application server. + ''' + ) + parser.add_argument( + '--ctl', + choices=['status', 'stop'], + help='server control' + ) + selector = parser.add_mutually_exclusive_group(required=True) + # We have to have an option name otherwise the first value in `args` might + # be taken as the script path. + selector.add_argument('-f', help='path to script') + # Deferred. + #selector.add_argument('-m', help='module name') + + parser.add_argument('args', nargs='*') + + parsed = parser.parse_args(args) + return parser, parsed + + +def main(): + report('start main()') + parser, args = parse_args() + if args.f: + path = args.f + if not os.path.exists(path): + parser.error(f'{path} does not exist') + handler = PathHandler(path, args.args) + #elif args.m: + # # We do not have a good strategy for avoiding import of the parent module + # # so for now just reject. + # if '.' 
in args.m: + # parser.error('Sub-modules are not supported') + # handler = ModuleHandler(args.m, args.args) + else: + parser.print_usage(sys.stderr) + sys.exit(1) + + sys.argv = handler.argv + + if args.ctl: + from .lib._lib import CliServerManager, ConnectionFailed + + # TODO: De-duplicate runtime dir name construction. + runtime_dir = RuntimeDir(f'quicken-{handler.name}') + + manager = CliServerManager(runtime_dir) + + with manager.lock: + try: + client = manager.connect() + except ConnectionFailed: + print('Server down') + sys.exit(0) + else: + client.close() + + if args.ctl == 'status': + print(manager.server_state) + elif args.ctl == 'stop': + manager.stop_server() + else: + sys.stderr.write('Unknown action') + sys.exit(1) + else: + sys.exit(run(handler.name, handler.metadata, handler.main)) diff --git a/quicken/_scripts.py b/quicken/_scripts.py new file mode 100644 index 0000000..1d25d82 --- /dev/null +++ b/quicken/_scripts.py @@ -0,0 +1,56 @@ +"""Helpers for "console_scripts"/"script" interceptors. +""" +def parse_script_spec(parts): + """ + Returns: + (module_parts, function_parts) + """ + try: + i = parts.index('_') + except ValueError: + return parts, [] + else: + return parts[:i], parts[i+1:] + + +def get_nested_attr(o, parts): + for name in parts: + o = getattr(o, name) + return o + + +def get_attribute_accumulator(callback, context=None): + """Who knows what someone may put in their entry point spec. + + We try to take the most flexible approach here and accept as much as + possible. + + Args: + callback: called when the accumulator is called with the gathered + names as the first argument. + context: names that should have explicit returned values. + """ + # Use variable in closure to reduce chance of conflicting name. + parts = [] + + class Accumulator: + def __getattribute__(self, name): + if name == '__call__': + return object.__getattribute__(self, name) + + if context: + try: + return context[name] + except KeyError: + pass + + parts.append(name) + return self + + def __call__(self, *args, **kwargs): + nonlocal parts + current_parts = parts + parts = [] + return callback(current_parts, *args, **kwargs) + + return Accumulator() diff --git a/quicken/_timings.py b/quicken/_timings.py new file mode 100644 index 0000000..95dc998 --- /dev/null +++ b/quicken/_timings.py @@ -0,0 +1,11 @@ +import os +import time + + +def report(name): + """Used for reporting important timings. + + PYTHONHUNTER='kind="call", function="report", module="quicken._timings"' + """ + if os.environ.get('QUICKEN_TRACE_TIMINGS'): + print(f'{time.perf_counter()}: {name}') diff --git a/quicken/_typing.py b/quicken/_typing.py deleted file mode 100644 index 68970db..0000000 --- a/quicken/_typing.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Callable, List, Mapping, Union - - -NoneFunction = Callable[[], None] -JSONType = Union[str, int, float, bool, None, Mapping[str, 'JSONType'], List['JSONType']] diff --git a/quicken/ctl_script.py b/quicken/ctl_script.py new file mode 100644 index 0000000..2dd66bf --- /dev/null +++ b/quicken/ctl_script.py @@ -0,0 +1,61 @@ +"""Entrypoint wrapper that represents a control script. 
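+
+It is configured like ``quicken.script`` and controls the application server associated with the corresponding
+entry point; for example (hypothetical entry point)::
+
+    my-commandctl = "quicken.ctl_script:my_app.cli._.main"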
+ +Valid commands are: + +* stop +* status +""" +from .lib._lib import CliServerManager, ConnectionFailed +from .lib._xdg import RuntimeDir +from ._scripts import ( + get_attribute_accumulator, parse_script_spec +) + +import sys + + +__all__ = [] + + +def parse_args(): + if len(sys.argv) == 1: + sys.stderr.write( + f'Usage: {sys.argv[0]} stop|status' + ) + sys.exit(1) + + return sys.argv[1] + + +def callback(parts): + module_parts, function_parts = parse_script_spec(parts) + module_name = '.'.join(module_parts) + function_name = '.'.join(function_parts) + # TODO: De-duplicate key creation. + name = f'quicken.entrypoint.{module_name}.{function_name}' + + # TODO: De-duplicate runtime dir name construction. + runtime_dir = RuntimeDir(f'quicken-{name}') + + manager = CliServerManager(runtime_dir) + + action = parse_args() + + with manager.lock: + try: + client = manager.connect() + except ConnectionFailed: + print('Server down') + sys.exit(0) + else: + client.close() + + if action == 'status': + print(manager.server_state) + elif action == 'stop': + manager.stop_server() + else: + sys.stderr.write('Unknown action') + sys.exit(1) + +sys.modules[__name__] = get_attribute_accumulator(callback) diff --git a/quicken/lib/__init__.py b/quicken/lib/__init__.py new file mode 100644 index 0000000..5fa2407 --- /dev/null +++ b/quicken/lib/__init__.py @@ -0,0 +1,5 @@ +from ._decorator import quicken + + +class QuickenError(Exception): + pass diff --git a/quicken/_asyncio.py b/quicken/lib/_asyncio.py similarity index 90% rename from quicken/_asyncio.py rename to quicken/lib/_asyncio.py index 7cf561e..f822244 100644 --- a/quicken/_asyncio.py +++ b/quicken/lib/_asyncio.py @@ -1,6 +1,12 @@ """Asyncio utility classes. """ -import asyncio +from __future__ import annotations + +from ._typing import MYPY_CHECK_RUNNING + + +if MYPY_CHECK_RUNNING: + import asyncio class DeadlineTimer: diff --git a/quicken/_client.py b/quicken/lib/_client.py similarity index 57% rename from quicken/_client.py rename to quicken/lib/_client.py index a8bba3a..782d52b 100644 --- a/quicken/_client.py +++ b/quicken/lib/_client.py @@ -1,6 +1,10 @@ -import multiprocessing.connection +from __future__ import annotations -from ._protocol import Request, Response +from ._imports import multiprocessing_connection +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + from ._protocol import Request, Response class Client: @@ -8,7 +12,7 @@ class Client: multiprocessing.connection.Client. """ def __init__(self, *args, **kwargs): - self._client = multiprocessing.connection.Client(*args, **kwargs) + self._client = multiprocessing_connection.Client(*args, **kwargs) def send(self, request: Request) -> Response: self._client.send(request) diff --git a/quicken/_constants.py b/quicken/lib/_constants.py similarity index 68% rename from quicken/_constants.py rename to quicken/lib/_constants.py index 3ba3407..6d53d22 100644 --- a/quicken/_constants.py +++ b/quicken/lib/_constants.py @@ -1,2 +1,3 @@ -socket_name = 'socket' server_state_name = 'state.json' +socket_name = 'socket' +stop_socket_name = 'stop' diff --git a/quicken/_decorator.py b/quicken/lib/_decorator.py similarity index 87% rename from quicken/_decorator.py rename to quicken/lib/_decorator.py index 8d28464..c0686d7 100644 --- a/quicken/_decorator.py +++ b/quicken/lib/_decorator.py @@ -1,15 +1,18 @@ """User-facing decorator. 
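+
+Example usage (a sketch; assumes the decorator is imported from ``quicken.lib`` and that the decorated function
+returns the real entry point)::
+
+    from quicken.lib import quicken
+
+    @quicken('my_app')
+    def main():
+        from my_app.cli import cli
+        return cli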
""" +from __future__ import annotations + import sys from functools import wraps -from typing import Callable, Optional -from ._typing import JSONType +from ._typing import MYPY_CHECK_RUNNING +from .._timings import report +if MYPY_CHECK_RUNNING: + from typing import Callable, Optional -MainFunction = Callable[[], Optional[int]] -MainProvider = Callable[[], MainFunction] + from ._types import JSONType, MainFunction, MainProvider def quicken( @@ -65,8 +68,10 @@ def wrapper() -> Optional[int]: return main_provider()() # Lazy import to avoid overhead. - from ._cli import _server_runner_wrapper - return _server_runner_wrapper( + report('load quicken library') + from ._lib import server_runner_wrapper + report('end load quicken library') + return server_runner_wrapper( name, main_provider, runtime_dir_path=runtime_dir_path, diff --git a/quicken/lib/_import.py b/quicken/lib/_import.py new file mode 100644 index 0000000..0501d07 --- /dev/null +++ b/quicken/lib/_import.py @@ -0,0 +1,42 @@ +"""Some packages are a little overzealous with package-level imports. + +We don't need all functionality they offer and can patch them out to get speed +ups. +""" +import sys + +from contextlib import contextmanager +from types import ModuleType + + +@contextmanager +def patch_modules(modules=None, packages=None): + """Within the scope, patch the provided modules so dummy values are imported + instead. + """ + if modules is None: + modules = [] + if packages is None: + packages = [] + + current_modules = set(sys.modules.keys()) + for name in modules: + # XXX: May want to enable for unit tests. + #assert name not in current_modules + if name not in current_modules: + sys.modules[name] = ModuleType(name) + + for name in packages: + if name not in current_modules: + package = ModuleType(name) + # Required if importing module from within package. + package.__path__ = None + sys.modules[name] = package + + try: + yield + finally: + new_modules = set(sys.modules.keys()) - current_modules + # Prevent possibly half-imported modules from impacting other users. + for name in new_modules: + sys.modules.pop(name) diff --git a/quicken/lib/_imports.py b/quicken/lib/_imports.py new file mode 100644 index 0000000..024faa0 --- /dev/null +++ b/quicken/lib/_imports.py @@ -0,0 +1,50 @@ +"""Patched imports, for improving startup speed. + +Where we identify that a dependency has imported some heavy module but doesn't +use it, we can provide that dependency here but with any imports stubbed out. +""" +from __future__ import annotations + +import sys + +from ._import import patch_modules +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + import asyncio + import multiprocessing.connection as multiprocessing_connection + + from typing import Type + + from fasteners import InterProcessLock + + +class Modules: + def __init__(self): + self.__name__ = __name__ + self.__file__ = __file__ + + @property + def asyncio(self) -> asyncio: + # Saves up to 5ms, and we don't use tls. + with patch_modules(modules=['ssl']): + import asyncio + return asyncio + + @property + def InterProcessLock(self) -> Type[InterProcessLock]: + # Saves 2ms since we don't use the decorators. + # We should probably just write our own at this point. 
+ with patch_modules(modules=['six']): + from fasteners import InterProcessLock + return InterProcessLock + + @property + def multiprocessing_connection(self) -> Type[multiprocessing_connection]: + # Saves 2ms since we don't use randomly-created sockets (tempfile, shutil) + with patch_modules(modules=['tempfile']): + import multiprocessing.connection + return multiprocessing.connection + + +sys.modules[__name__] = Modules() diff --git a/quicken/lib/_lib.py b/quicken/lib/_lib.py new file mode 100644 index 0000000..e441909 --- /dev/null +++ b/quicken/lib/_lib.py @@ -0,0 +1,301 @@ +"""CLI wrapper interface for starting/using server process. +""" +from __future__ import annotations + +import json +import logging +import multiprocessing +import os +import socket + +from contextlib import contextmanager +from functools import partial + +from . import QuickenError +from ._client import Client +from ._constants import socket_name, server_state_name, stop_socket_name +from ._imports import InterProcessLock +from ._multiprocessing import set_fd_sharing_base_path_fd +from ._protocol import ProcessState, Request, RequestTypes +from ._signal import blocked_signals, forwarded_signals, SignalProxy +from ._typing import MYPY_CHECK_RUNNING +from ._xdg import cache_dir, chdir, RuntimeDir +from .._timings import report + +if MYPY_CHECK_RUNNING: + from typing import Callable, Optional + + from ._types import JSONType, MainProvider + + +logger = logging.getLogger(__name__) + + + +def check_res_ids(): + ruid, euid, suid = os.getresuid() + if not ruid == euid == suid: + raise QuickenError( + f'real uid ({ruid}), effective uid ({euid}), and saved uid ({suid})' + ' must be the same' + ) + + rgid, egid, sgid = os.getresgid() + if not rgid == egid == sgid: + raise QuickenError( + f'real gid ({rgid}), effective gid ({egid}), and saved gid ({sgid})' + ' must be the same' + ) + + +def need_server_reload(manager, reload_server, user_data): + server_state = manager.server_state + gid = os.getgid() + if gid != server_state['gid']: + logger.info('Reloading server due to gid change') + return True + + # XXX: Will not have the intended effect on macOS, see os.getgroups() for + # details. + groups = os.getgroups() + if set(groups) != set(server_state['groups']): + logger.info('Reloading server due to changed groups') + return True + + if reload_server: + previous_user_data = manager.user_data + if reload_server(previous_user_data, user_data): + logger.info('Reload requested by callback, stopping server.') + return True + + # TODO: Restart based on library version. + return False + + +def server_runner_wrapper( + name: str, + main_provider: MainProvider, + # /, + *, + runtime_dir_path: Optional[str] = None, + log_file: Optional[str] = None, + server_idle_timeout: Optional[float] = None, + reload_server: Callable[[JSONType, JSONType], bool] = None, + user_data: JSONType = None, +) -> Optional[int]: + """Run operation in server identified by name, starting it if required. 
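+
+    The flow, as implemented below: validate that real/effective/saved uids and
+    gids match, connect to (or start) the server under an exclusive lock, send
+    the client's process state to the server, forward signals to the spawned
+    command process, and return its exit code.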
+ """ + + check_res_ids() + + try: + json.dumps(user_data) + except TypeError as e: + raise QuickenError('user_data must be serializable') from e + + if log_file is None: + log_file = os.path.join(cache_dir(f'quicken-{name}'), 'server.log') + log_file = os.path.abspath(log_file) + + main_provider = partial(with_reset_authkey, main_provider) + + runtime_dir = RuntimeDir(f'quicken-{name}', runtime_dir_path) + + set_fd_sharing_base_path_fd(runtime_dir.fileno()) + + manager = CliServerManager(runtime_dir) + + report('connecting to server') + with manager.lock: + need_start = False + try: + client = manager.connect() + except ConnectionFailed as e: + logger.info('Failed to connect to server due to %s.', e) + need_start = True + else: + if need_server_reload(manager, reload_server, user_data): + manager.stop_server() + need_start = True + + if need_start: + logger.info('Starting server') + # XXX: Should have logging around this, for timing. + main = main_provider() + manager.start_server(main, log_file, server_idle_timeout, user_data) + client = manager.connect() + + report('connected to server') + proxy = SignalProxy() + # We must block signals before requesting remote process start otherwise + # a user signal to the client may race with our ability to propagate it. + with blocked_signals(forwarded_signals): + state = ProcessState.for_current_process() + report('requesting start') + logger.debug('Requesting process start') + req = Request(RequestTypes.run_process, state) + response = client.send(req) + pid = response.contents + report('process started') + logger.debug('Process running with pid: %d', pid) + proxy.set_target(pid) + + logger.debug('Waiting for process to finish') + response = client.send(Request(RequestTypes.wait_process_done, None)) + client.close() + report('client finished') + return response.contents + + +def reset_authkey(): + multiprocessing.current_process().authkey = os.urandom(32) + + +def with_reset_authkey(main_provider): + """Ensure that user code is not executed without an authkey set. + """ + main = main_provider() + + def inner(): + reset_authkey() + return main() + + return inner + + +class ConnectionFailed(Exception): + pass + + +class CliServerManager: + """Responsible for starting (if applicable) and connecting to the server. + + Race conditions are prevented by acquiring an exclusive lock on + {runtime_dir}/admin during connection and start. + """ + def __init__(self, runtime_dir: RuntimeDir): + """ + Args: + runtime_dir: runtime dir used for locks/socket + """ + self._runtime_dir = runtime_dir + self._lock = InterProcessLock('admin') + + def connect(self) -> Client: + """Attempt to connect to the server. + + Returns: + Client connected to the server + + Raises: + ConnectionFailed on connection failure (server not up or accepting) + """ + assert self._lock.acquired, 'connect must be called under lock.' + + with chdir(self._runtime_dir): + try: + return Client(socket_name) + except FileNotFoundError as e: + raise ConnectionFailed('File not found') from e + except ConnectionRefusedError as e: + raise ConnectionFailed('Connection refused') from e + + @property + def server_state(self): + with chdir(self._runtime_dir): + with open(server_state_name, encoding='utf-8') as f: + text = f.read() + return json.loads(text) + + @property + def user_data(self): + """Returns user data for the current server. + """ + return self.server_state['user_data'] + + def start_server(self, main, log_file, server_idle_timeout, user_data): + """Start server as background process. 
+ + This function only returns in the parent, not the background process. + + By the time this function returns it is safe to call connect(). + + Args: + main: function that provides the server request handler + log_file: server log file + server_idle_timeout: idle timeout communicated to server if the + process of connecting results in server start + user_data: added to server state + """ + assert self._lock.acquired, 'start_server must be called under lock.' + # Lazy import so we only take the time to import if we have to start + # the server. + # XXX: Should have logging around this, for timing. + from ._server import run + + with chdir(self._runtime_dir): + try: + os.unlink(socket_name) + except FileNotFoundError: + pass + + run( + main, + log_file=log_file, + runtime_dir=self._runtime_dir, + server_idle_timeout=server_idle_timeout, + user_data=user_data, + ) + + def stop_server(self): + assert self._lock.acquired, 'stop_server must be called under lock.' + + # We don't want to leave it to the server to remove the sockets since + # we do not wait for it to shut down before starting a new one. + try: + with chdir(self._runtime_dir): + os.unlink(socket_name) + except FileNotFoundError: + # If any file was removed it doesn't impact us. + pass + + try: + # This will cause the server to stop accepting clients and start + # shutting down. It will wait for any still-running processes before + # stopping completely, but it does not consume any other resources + # that we are concerned with. + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.connect(stop_socket_name) + except ConnectionRefusedError: + # Best effort. + pass + except FileNotFoundError: + # No problem, server not up. + pass + + try: + with chdir(self._runtime_dir): + os.unlink(stop_socket_name) + except FileNotFoundError: + # If any file was removed it doesn't impact us. + pass + + @property + @contextmanager + def lock(self): + with chdir(self._runtime_dir): + self._lock.acquire() + + # We initialize the Client and Listener classes without an authkey + # parameter since there's no way to pre-share the secret securely + # between processes not part of the same process tree. However, the + # internal Client/Listener used as part of + # multiprocessing.resource_sharer DOES initialize its own Client and + # Listener with multiprocessing.current_process().authkey. We must have + # some value so we use this dummy value. 
+ multiprocessing.current_process().authkey = b'0' * 32 + + try: + yield + finally: + self._lock.release() diff --git a/quicken/_logging.py b/quicken/lib/_logging.py similarity index 100% rename from quicken/_logging.py rename to quicken/lib/_logging.py diff --git a/quicken/_multiprocessing.py b/quicken/lib/_multiprocessing.py similarity index 66% rename from quicken/_multiprocessing.py rename to quicken/lib/_multiprocessing.py index 8a14e85..e0724f3 100644 --- a/quicken/_multiprocessing.py +++ b/quicken/lib/_multiprocessing.py @@ -1,18 +1,20 @@ -import multiprocessing +from __future__ import annotations + import os +import multiprocessing import sys +import threading from io import TextIOWrapper -from multiprocessing.connection import wait -from multiprocessing.reduction import register -from typing import TextIO +from multiprocessing.reduction import recv_handle, register, send_handle +from ._imports import multiprocessing_connection +from ._signal import blocked_signals, signal_range +from ._typing import MYPY_CHECK_RUNNING +from ._xdg import chdir -try: - from multiprocessing.reduction import DupFd -except ImportError: - # Can happen on Windows - DupFd = None +if MYPY_CHECK_RUNNING: + from typing import TextIO def run_in_process( @@ -93,7 +95,7 @@ def detach(result=None): p = ctx.Process(target=launcher, name=name) p.start() - ready = wait([p.sentinel, parent_pipe], timeout=timeout) + ready = multiprocessing_connection.wait([p.sentinel, parent_pipe], timeout=timeout) # Timeout if not ready: @@ -131,18 +133,64 @@ def detach(result=None): return result +_fd_sharing_base_path_fd = None + + +def set_fd_sharing_base_path_fd(fd: int): + global _fd_sharing_base_path_fd + _fd_sharing_base_path_fd = fd + + def reduce_textio(obj: TextIO): + """Simpler version of multiprocessing.resource_sharer._ResourceSharer + that: + + * doesn't require authkey (but does require base path to be set which should + be in a secure folder + * doesn't require random unix socket (so when we import multiprocessing.connection + we can stub out tempfile, saving 5ms). + """ # Picklable object that contains a callback id to be used by the # receiving process. if obj.readable() == obj.writable(): raise ValueError( 'TextIO object must be either readable or writable, but not both.') - df = DupFd(obj.fileno()) - return rebuild_textio, (df, obj.readable(), obj.writable()) + fd = obj.fileno() + name = f'{os.getpid()}-{fd}' - -def rebuild_textio(df: DupFd, readable: bool, _writable: bool) -> TextIO: - fd = df.detach() + with chdir(_fd_sharing_base_path_fd): + # In case a client crashed and we're re-using the pid. + try: + os.unlink(name) + except FileNotFoundError: + pass + path = os.path.abspath(name) + listener = multiprocessing_connection.Listener(address=name) + + def target(): + conn = listener.accept() + with chdir(_fd_sharing_base_path_fd): + listener.close() + pid = conn.recv() + send_handle(conn, fd, pid) + conn.close() + + t = threading.Thread(target=target) + t.daemon = True + with blocked_signals(signal_range): + # Inherits blocked signals. + t.start() + + return rebuild_textio, (path, obj.readable(), obj.writable()) + + +def rebuild_textio(path: str, readable: bool, _writable: bool) -> TextIO: + # UNIX socket path name is limited to 108 characters, so cd to the directory + # and refer to it as a relative path. 
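+    # Handshake (sketch of the protocol set up by reduce_textio above): the
+    # sending process listens on a socket named "{pid}-{fd}" under the
+    # fd-sharing base path; we connect, send our own pid, and receive a
+    # duplicated descriptor via send_handle()/recv_handle().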
+ with chdir(os.path.dirname(path)): + with multiprocessing_connection.Client(os.path.basename(path)) as c: + c.send(os.getpid()) + fd = recv_handle(c) flags = 'r' if readable else 'w' return open(fd, flags) diff --git a/quicken/_multiprocessing_asyncio.py b/quicken/lib/_multiprocessing_asyncio.py similarity index 98% rename from quicken/_multiprocessing_asyncio.py rename to quicken/lib/_multiprocessing_asyncio.py index bb3dd37..58d021b 100644 --- a/quicken/_multiprocessing_asyncio.py +++ b/quicken/lib/_multiprocessing_asyncio.py @@ -2,7 +2,6 @@ """ from __future__ import annotations -import asyncio import logging import multiprocessing import os @@ -10,7 +9,12 @@ import socket from contextlib import contextmanager -from typing import Any + +from ._imports import asyncio +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + from typing import Any logger = logging.getLogger(__name__) diff --git a/quicken/_protocol.py b/quicken/lib/_protocol.py similarity index 70% rename from quicken/_protocol.py rename to quicken/lib/_protocol.py index 86a04c0..9918071 100644 --- a/quicken/_protocol.py +++ b/quicken/lib/_protocol.py @@ -6,50 +6,61 @@ import os import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List - # Registers TextIOWrapper handler. from . import _multiprocessing +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + from typing import Any, Dict, List class RequestTypes: - get_server_state = 'get_server_state' run_process = 'run_process' wait_process_done = 'wait_process_done' -@dataclass class Request: - name: str - contents: Any + def __init__(self, name: str, contents: Any): + self.name = name + self.contents = contents -@dataclass class Response: - contents: Any + def __init__(self, contents: Any): + self.contents = contents -@dataclass class StdStreams: - stdin: io.TextIOWrapper - stdout: io.TextIOWrapper - stderr: io.TextIOWrapper + def __init__( + self, + stdin: io.TextIOWrapper, + stdout: io.TextIOWrapper, + stderr: io.TextIOWrapper, + ): + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr -@dataclass class ProcessState: - std_streams: StdStreams - cwd: Path - umask: int - environment: Dict[str, str] - argv: List[str] + def __init__( + self, + std_streams: StdStreams, + cwd: str, + umask: int, + environment: Dict[str, str], + argv: List[str], + ): + self.std_streams = std_streams + self.cwd = cwd + self.umask = umask + self.environment = environment + self.argv = argv @staticmethod def for_current_process() -> ProcessState: streams = StdStreams(sys.stdin, sys.stdout, sys.stderr) - cwd = Path.cwd() + cwd = os.getcwd() # Only way to get umask is to set umask. 
umask = os.umask(0o077) os.umask(umask) @@ -89,10 +100,3 @@ def apply_to_current_process(state: ProcessState): os.umask(state.umask) os.environ = copy.deepcopy(state.environment) sys.argv = list(state.argv) - - -@dataclass -class ServerState: - start_time: float - pid: int - context: Dict[str, str] diff --git a/quicken/_server.py b/quicken/lib/_server.py similarity index 88% rename from quicken/_server.py rename to quicken/lib/_server.py index 2a03d75..026664f 100644 --- a/quicken/_server.py +++ b/quicken/lib/_server.py @@ -11,7 +11,8 @@ To allow use of any callable in the server we override the forkserver implementation and do not """ -import asyncio +from __future__ import annotations + import contextvars import functools import json @@ -20,20 +21,18 @@ import multiprocessing import os import signal +import socket import sys import time import traceback from abc import ABC, abstractmethod from contextlib import ExitStack -from pathlib import Path -from typing import Any, Dict, Optional - -import psutil -from . import __version__ +from .. import __version__ from ._asyncio import DeadlineTimer -from ._constants import socket_name, server_state_name +from ._constants import socket_name, stop_socket_name, server_state_name +from ._imports import asyncio from ._logging import ContextLogger, NullContextFilter, UTCFormatter from ._multiprocessing import run_in_process from ._multiprocessing_asyncio import ( @@ -43,11 +42,16 @@ ConnectionClose, ListenerStopped ) -from ._typing import NoneFunction -from ._protocol import ProcessState, Request, RequestTypes, Response, ServerState +from ._typing import MYPY_CHECK_RUNNING +from ._protocol import ProcessState, Request, RequestTypes, Response from ._signal import settable_signals from ._xdg import RuntimeDir +if MYPY_CHECK_RUNNING: + from typing import Any, Dict, Optional + + from ._types import NoneFunction + logger = ContextLogger(logging.getLogger(__name__), prefix='server_') @@ -56,7 +60,7 @@ def run( socket_handler, runtime_dir: RuntimeDir, server_idle_timeout: Optional[float], - log_file: Optional[Path], + log_file: Optional[str], user_data: Optional[Any], ): """Start the server in the background. @@ -130,7 +134,7 @@ def _run_server( server_idle_timeout: Optional[float], log_file, user_data, - done + done, ) -> None: """Server that provides sockets to `callback`. @@ -149,7 +153,6 @@ def _run_server( logger.debug('_run_server()') loop = asyncio.new_event_loop() - loop.set_debug(True) def print_exception(_loop, context): exc = context['exception'] @@ -157,6 +160,7 @@ def print_exception(_loop, context): traceback.format_exception(type(exc), exc, exc.__traceback__)) logger.error( 'Error in event loop: %s\n%s', context['message'], formatted_exc) + loop.set_exception_handler(print_exception) handler = ProcessConnectionHandler(callback, {}, loop=loop) @@ -172,7 +176,13 @@ def finish_loop(): # socket_name is relative and we must already have cwd set to the # runtime_dir. server = Server( - socket_name, handler, finish_loop, server_idle_timeout, loop=loop) + socket_name, + stop_socket_name, + handler, + finish_loop, + server_idle_timeout, + loop=loop + ) def handle_sigterm(): logger.debug('Received SIGTERM') @@ -187,20 +197,20 @@ def handle_sigterm(): # For server state info. 
pid = os.getpid() - process = psutil.Process(pid) server_state = { - 'create_time': process.create_time(), + 'create_time': time.time(), 'version': __version__, 'pid': pid, 'user_data': user_data, 'groups': os.getgroups(), 'gid': os.getgid(), } - Path(server_state_name).write_text( - json.dumps(server_state), encoding='utf-8') + + with open(server_state_name, 'w', encoding='utf-8') as f: + json.dump(server_state, f) logger.debug('Starting server') - loop.create_task(server.serve()) + server.serve() loop.run_forever() logger.debug('Server finished.') @@ -270,12 +280,8 @@ async def handle_request(): nonlocal process, process_task logger.debug('Waiting for request') request = await queue.get() - if request.name == RequestTypes.get_server_state: - state = ServerState(self._start_time, self._pid, self._context) - logger.debug('Sending server state') - await connection.send(Response(state)) - elif request.name == RequestTypes.run_process: + if request.name == RequestTypes.run_process: assert process is None, \ 'Process must not have been started' process_state = request.contents @@ -383,21 +389,30 @@ class Server: Not thread-safe. """ def __init__( - self, socket_path, handler: ConnectionHandler, - on_shutdown: NoneFunction, idle_timeout: Optional[int] = None, - shutdown_ctx=None, loop=None): + self, + socket_path, + stop_socket_path, + handler: ConnectionHandler, + on_shutdown: NoneFunction, + idle_timeout: Optional[int] = None, + shutdown_ctx=None, + loop=None + ): """ Args: - socket_path: + socket_path: path to listen for client connections + stop_socket_path: path to a socket which, when a client connects, + will cause the server to shut down handler: Handler for received connections - idle_timeout: + idle_timeout: timeout (in seconds) after which server will shut itself down + without any work shutdown_ctx: Context manager to be entered prior to server shutdown. - loop: + loop """ if not loop: loop = asyncio.get_event_loop() self._loop = loop - + self._stop_socket_path = stop_socket_path self._listener = AsyncListener(socket_path, loop=self._loop) self._handler = handler self._idle_timeout = idle_timeout @@ -408,19 +423,9 @@ def __init__( self._shutdown_accept_cv = asyncio.Condition(loop=self._loop) self._on_shutdown = on_shutdown - async def serve(self): - while True: - try: - connection = await self._listener.accept() - except ListenerStopped: - if not self._shutting_down: - logger.error('Listener has stopped') - else: - async with self._shutdown_accept_cv: - self._shutdown_accept_cv.notify() - return - - self._handle_connection(connection) + def serve(self): + self._loop.create_task(self._serve_stop()) + self._loop.create_task(self._serve_clients()) async def stop(self): """Gracefully stop server, processing all pending connections. @@ -450,6 +455,33 @@ async def stop(self): # Finish everything off. self._on_shutdown() + async def _serve_stop(self): + sock = self._stop_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.bind(self._stop_socket_path) + sock.listen(1) + sock.setblocking(False) + self._loop.add_reader(sock.fileno(), self._on_stop_connect) + + def _on_stop_connect(self): + self._loop.remove_reader(self._stop_sock.fileno()) + _sock, _address = self._stop_sock.accept() + # XXX: Should we close the socket? 
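+ # Note: the accepted connection is used only as a shutdown trigger; the
+ # actual teardown happens in the stop() task scheduled below.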
+ self._loop.create_task(self.stop()) + + async def _serve_clients(self): + while True: + try: + connection = await self._listener.accept() + except ListenerStopped: + if not self._shutting_down: + logger.error('Listener has stopped') + else: + async with self._shutdown_accept_cv: + self._shutdown_accept_cv.notify() + return + + self._handle_connection(connection) + def _handle_connection(self, connection: AsyncConnectionAdapter): self._idle_handle_connect() @@ -492,8 +524,9 @@ def _idle_handle_close(self): self._set_idle_timer() -def _configure_logging(logfile: Path, loglevel: str) -> None: - logfile.parent.mkdir(parents=True, exist_ok=True) +def _configure_logging(logfile: str, loglevel: str) -> None: + parent = os.path.dirname(logfile) + os.makedirs(parent, mode=0o700, exist_ok=True) # TODO: Make fully configurable. logging.config.dictConfig({ 'version': 1, diff --git a/quicken/_signal.py b/quicken/lib/_signal.py similarity index 95% rename from quicken/_signal.py rename to quicken/lib/_signal.py index eb6689d..4b38952 100644 --- a/quicken/_signal.py +++ b/quicken/lib/_signal.py @@ -1,5 +1,8 @@ """Signal helpers. """ +from __future__ import annotations + + import errno import logging import os @@ -7,7 +10,11 @@ import sys from contextlib import contextmanager -from typing import Set + +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + from typing import Set logger = logging.getLogger(__name__) @@ -32,7 +39,6 @@ def _settable_signal(sig) -> bool: signal_range = set(range(1, signal.NSIG)) -# XXX: Can be signal.valid_signals() in 3.8+ settable_signals = set(filter(_settable_signal, signal_range)) if not sys.platform.startswith('win'): diff --git a/quicken/lib/_types.py b/quicken/lib/_types.py new file mode 100644 index 0000000..3947312 --- /dev/null +++ b/quicken/lib/_types.py @@ -0,0 +1,10 @@ +from ._typing import MYPY_CHECK_RUNNING + + +if MYPY_CHECK_RUNNING: + from typing import Callable, List, Mapping, Optional, Union + + NoneFunction = Callable[[], None] + JSONType = Union[str, int, float, bool, None, Mapping[str, 'JSONType'], List['JSONType']] + MainFunction = Callable[[], Optional[int]] + MainProvider = Callable[[], MainFunction] diff --git a/quicken/lib/_typing.py b/quicken/lib/_typing.py new file mode 100644 index 0000000..f790c13 --- /dev/null +++ b/quicken/lib/_typing.py @@ -0,0 +1 @@ +MYPY_CHECK_RUNNING = False diff --git a/quicken/_xdg.py b/quicken/lib/_xdg.py similarity index 62% rename from quicken/_xdg.py rename to quicken/lib/_xdg.py index 53eb8f7..cb2fb4b 100644 --- a/quicken/_xdg.py +++ b/quicken/lib/_xdg.py @@ -1,10 +1,14 @@ -from contextlib import contextmanager, ExitStack -from functools import wraps +from __future__ import annotations + import os -from pathlib import Path, PosixPath import stat -import threading -from typing import Any, ContextManager, Union + +from contextlib import contextmanager + +from ._typing import MYPY_CHECK_RUNNING + +if MYPY_CHECK_RUNNING: + from typing import ContextManager @contextmanager @@ -25,50 +29,6 @@ def chdir(fd) -> ContextManager: os.close(cwd) -@contextmanager -def lock_guard(l: Union[threading.Lock, threading.RLock]): - l.acquire() - try: - yield - finally: - l.release() - - -class BoundPath(PosixPath): - _lock = threading.RLock() - - def __init__(self, *_, dir_fd: int): - self._dir_fd = dir_fd - super().__init__() - - def __getattribute__(self, name: str) -> Any: - """Intercept and execute all functions in the context of the - directory. 
- """ - attr = super().__getattribute__(name) - if callable(attr): - @wraps(attr) - def wrapper(*args, **kwargs): - with ExitStack() as stack: - stack.enter_context(lock_guard(self._lock)) - try: - stack.enter_context(chdir(self._dir_fd)) - except AttributeError: - # Avoids issues during Path construction, before - # __init__ is called. - pass - return attr(*args, **kwargs) - return wrapper - return attr - - @property - def dir(self): - return self._dir_fd - - def pass_to(self, callback): - return callback(self) - - class RuntimeDir: """Helper class to create/manage the application runtime directory. @@ -108,7 +68,7 @@ def __init__(self, base_name: str = None, dir_path=None): try: self._fd = os.open(dir_path, os.O_RDONLY) except FileNotFoundError: - Path(dir_path).mkdir(mode=0o700) + os.mkdir(dir_path, mode=0o700) self._fd = os.open(dir_path, os.O_RDONLY) # Test after open to avoid toctou, also since we do not trust the mode # passed to mkdir. @@ -128,15 +88,6 @@ def __init__(self, base_name: str = None, dir_path=None): def fileno(self) -> int: return self._fd - def path(self, *args) -> BoundPath: - """Execute action in directory so relative paths are resolved inside the - directory without specific operations needing to support `dir_fd`. - """ - result = BoundPath(*args, dir_fd=self._fd) - if result.is_absolute(): - raise ValueError('Provided argument must not be absolute') - return result - def __str__(self): return self._path @@ -155,6 +106,6 @@ def runtime_dir(base_name): def cache_dir(base_name): try: - return Path(os.environ['XDG_CACHE_HOME']) / base_name + return os.path.join(os.environ['XDG_CACHE_HOME'], base_name) except KeyError: - return Path(os.environ['HOME']) / '.cache' / base_name + return os.path.join(os.environ['HOME'], '.cache', base_name) diff --git a/quicken/script.py b/quicken/script.py new file mode 100644 index 0000000..5535443 --- /dev/null +++ b/quicken/script.py @@ -0,0 +1,33 @@ +"""Entrypoint wrapper that starts a quicken server around the provided +command. 
+""" +import importlib +import os +import sys + +from ._scripts import ( + get_attribute_accumulator, get_nested_attr, parse_script_spec +) + + +__all__ = [] + + +def callback(parts): + from .lib import quicken + module_parts, function_parts = parse_script_spec(parts) + module_name = '.'.join(module_parts) + function_name = '.'.join(function_parts) + name = f'quicken.entrypoint.{module_name}.{function_name}' + + log_file = os.environ.get('QUICKEN_LOG') + + @quicken(name, log_file=log_file) + def main(): + module = importlib.import_module(module_name) + return get_nested_attr(module, function_parts) + + return main() + + +sys.modules[__name__] = get_attribute_accumulator(callback) diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cli_wrapper/__init__.py b/tests/cli_wrapper/__init__.py index 6bf03a3..65ad960 100644 --- a/tests/cli_wrapper/__init__.py +++ b/tests/cli_wrapper/__init__.py @@ -35,7 +35,7 @@ def inner(): """ from pathlib import Path -from quicken import quicken +from quicken.lib import quicken from ..utils.pytest import current_test_name from ..utils import preserved_signals diff --git a/tests/cli_wrapper/test_cli_wrapper.py b/tests/cli_wrapper/test_cli_wrapper.py index facf68f..e9b8b66 100644 --- a/tests/cli_wrapper/test_cli_wrapper.py +++ b/tests/cli_wrapper/test_cli_wrapper.py @@ -15,17 +15,19 @@ import psutil import pytest -from quicken import __version__, QuickenError -from quicken._constants import server_state_name, socket_name -from quicken._signal import forwarded_signals -from quicken._xdg import RuntimeDir +from quicken import __version__ +from quicken.lib import QuickenError +from quicken.lib._constants import server_state_name, socket_name +from quicken.lib._signal import forwarded_signals +from quicken.lib._xdg import RuntimeDir from . 
import cli_factory from ..utils import ( argv, captured_std_streams, env, isolated_filesystem, umask) +from ..utils.path import get_bound_path from ..utils.process import contained_children from ..utils.pytest import current_test_name, non_windows -from ..watch import wait_for_create +from ..utils.watch import wait_for_create pytestmark = non_windows @@ -336,7 +338,7 @@ def wait_for(predicate): p = Process(target=client) p.start() assert wait_for_create( - runtime_dir.path(runner_pid_file.name), timeout=2), \ + get_bound_path(runtime_dir, runner_pid_file.name), timeout=2), \ f'{runner_pid_file} must have been created' runner_pid = int(runner_pid_file.read_text(encoding='utf-8')) @@ -407,7 +409,7 @@ def client(): logger.debug('Waiting for pid file') assert wait_for_create( - runtime_dir.path(runner_pid_file.name), timeout=2), \ + get_bound_path(runtime_dir, runner_pid_file.name), timeout=2), \ f'{runner_pid_file} must have been created' runner_pid = int(runner_pid_file.read_text(encoding='utf-8')) logger.debug('Runner started with pid: %d', runner_pid) @@ -702,7 +704,7 @@ def inner(): worker_pid = str(p.pid) runtime_dir = RuntimeDir(dir_path=path) process = psutil.Process(pid=p.pid) - state_file = runtime_dir.path(server_state_name) + state_file = get_bound_path(runtime_dir, server_state_name) state_file.write_text(json.dumps({ 'create_time': process.create_time(), 'pid': p.pid, diff --git a/tests/conftest.py b/tests/conftest.py index 325cb46..9481541 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ import logging import logging.config import os +import subprocess import sys import threading +import venv from pathlib import Path +from typing import List import pytest @@ -23,21 +26,30 @@ def get_process_stack(pid): else: raise - +from .utils import isolated_filesystem from .utils.pytest import current_test_name from .utils.process import active_children, disable_child_tracking, kill_children -from quicken._logging import UTCFormatter +from quicken.lib._logging import UTCFormatter log_file_format = 'logs/{test_case}.log' -pytest_plugins = "tests.timeout", "tests.strace" +pytest_plugins = "tests.plugins.timeout", "tests.plugins.strace" + + +def get_log_file(test_name): + return Path(log_file_format.format(test_case=test_name)).absolute() + + +@pytest.fixture +def log_file_path(): + return get_log_file(current_test_name()) def pytest_runtest_setup(item): - path = Path(log_file_format.format(test_case=item.name)).absolute() + path = get_log_file(item.name) path.parent.mkdir(parents=True, exist_ok=True) class TestNameAdderFilter(logging.Filter): @@ -125,3 +137,15 @@ def pytest_timeout_timeout(item, report): report.longrepr = report.longrepr + '\nsubprocess stacks:\n' + '\n'.join(stacks) kill_children() + + +@pytest.fixture(scope='module') +def virtualenv(): + def run_python(cmd: List[str], *args, **kwargs): + interpreter = Path(path) / 'bin' / 'python' + cmd.insert(0, str(interpreter)) + return subprocess.run(cmd, *args, **kwargs) + + with isolated_filesystem() as path: + venv.create(path, symlinks=True, with_pip=True) + yield run_python diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strace/__init__.py b/tests/plugins/strace/__init__.py similarity index 87% rename from tests/strace/__init__.py rename to tests/plugins/strace/__init__.py index 67008da..2f07793 100644 --- a/tests/strace/__init__.py +++ b/tests/plugins/strace/__init__.py @@ -9,7 +9,7 @@ import pytest -from ..utils.pytest import 
current_test_name +from ...utils.pytest import current_test_name logger = logging.getLogger(__name__) @@ -48,12 +48,12 @@ def pytest_runtest_call(item): yield return - path = Path(__file__).parent.parent.parent / 'logs' / f'strace.{current_test_name()}.log' + path = Path(__file__).parent / '..' / '..' / '..' / 'logs' / f'strace.{current_test_name()}.log' p = subprocess.Popen( ['strace', '-yttfo', str(path), '-s', '512', '-p', str(os.getpid())], stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL + stderr=subprocess.DEVNULL, ) if not busy_wait(is_traced, timeout=5): logger.warning('Could not attach strace') diff --git a/tests/timeout/__init__.py b/tests/plugins/timeout/__init__.py similarity index 100% rename from tests/timeout/__init__.py rename to tests/plugins/timeout/__init__.py diff --git a/tests/timeout/newhooks.py b/tests/plugins/timeout/newhooks.py similarity index 100% rename from tests/timeout/newhooks.py rename to tests/plugins/timeout/newhooks.py diff --git a/tests/test__multiprocessing.py b/tests/test__multiprocessing.py index c9265a4..d26a955 100644 --- a/tests/test__multiprocessing.py +++ b/tests/test__multiprocessing.py @@ -6,7 +6,7 @@ import pytest -from quicken._multiprocessing import run_in_process +from quicken.lib._multiprocessing import run_in_process from .utils import isolated_filesystem from .utils.process import contained_children diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index e4c0a12..917bf28 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -28,7 +28,7 @@ def cli(): import os import sys -from quicken import quicken +from quicken.lib import quicken def bypass(): @@ -74,7 +74,7 @@ def target(): def test_quicken_import_time(benchmark): def target(): - return run_code('from quicken import quicken') + return run_code('from quicken.lib import quicken') result = benchmark(target) assert result == 0, 'Process must have exited cleanly' @@ -82,7 +82,7 @@ def target(): def test_quicken_cli_import_time(benchmark): def target(): - return run_code('import quicken._cli') + return run_code('import quicken.lib._lib') result = benchmark(target) assert result == 0, 'Process must have exited cleanly' @@ -101,7 +101,7 @@ def test_quicken_server_import_time(benchmark): # We import _server lazily, this shows us the portion of startup that goes # towards that. 
def target(): - return run_code('import quicken._server; import quicken._cli') + return run_code('import quicken.lib._server; import quicken.lib._lib') result = benchmark(target) assert result == 0, 'Process must have exited cleanly' diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..06ecd0f --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,299 @@ +import copy +import os +import subprocess +import sys + +from contextlib import contextmanager +from pathlib import Path +from textwrap import dedent + +import pytest + +from quicken._cli import parse_args, parse_file + +from .utils import captured_std_streams, chdir, env, isolated_filesystem +from .utils.pytest import non_windows + +import logging; logger = logging.getLogger(__name__) + +pytestmark = non_windows + + +@contextmanager +def sys_path(path): + current_sys_path = sys.path + sys.path = sys.path.copy() + sys.path.append(path) + try: + yield + finally: + sys.path = current_sys_path + + +@contextmanager +def kept(o, attr): + current_attr = copy.copy(getattr(o, attr)) + try: + yield + except: + setattr(o, attr, current_attr) + + +def test_args_ctl_passthru(): + _, args = parse_args(['-f', './script.py', '--', '--ctl']) + assert args.f == './script.py' + assert args.args == ['--ctl'] + + +#def test_args_module_passthru(): +# _, args = parse_args(['-m', 'pytest', '--', '-s', '-ra']) +# assert args.m == 'pytest' +# assert args.args == ['-s', '-ra'] + + +def test_file_args_passthru(): + _, args = parse_args(['-f', 'foo', '--', '-m', 'hello']) + assert args.f == 'foo' + assert args.args == ['-m', 'hello'] + + +def test_file_evaluation(): + # Given a package hello with + # + # hello/ + # __init__.py + # foo.py + # + # # hello/__init__.py + # foo = 1 + # + # # script.py + # from hello import foo + # import hello.foo + # + # if __name__ == '__main__': + # print(foo) + # + # should print 1 + with isolated_filesystem() as path: + Path('hello').mkdir() + Path('hello/__init__.py').write_text('foo = 1') + Path('hello/foo.py').write_text('') + + Path('script.py').write_text(dedent(''' + from hello import foo + import hello.foo + + if __name__ == '__main__': + print(foo) + ''')) + + with sys_path(str(path)): + with kept(sys, 'modules'): + prelude, main = parse_file(str(path / 'script.py')) + + prelude() + + with captured_std_streams() as (stdin, stdout, stderr): + main() + + output = stdout.read() + assert output == '1\n' + + +def test_file_backtrace_line_numbering(): + # Given a file `script.py`: + # + # import os + # raise RuntimeError('example') + # + # if __name__ == '__main__': + # raise RuntimeError('example2') + # + # When executed, the backtrace should have RuntimeError('example') coming + # from the appropriate location. + with isolated_filesystem(): + Path('script.py').write_text(dedent('''\ + import os + raise RuntimeError('example') + + if __name__ == '__main__': + raise RuntimeError('example2') + ''')) + + prelude, main = parse_file('script.py') + + with pytest.raises(RuntimeError) as e: + prelude() + + assert 'example' in str(e) + entry = e.traceback[1] + assert str(entry.path) == str(Path('script.py').absolute()) + # the pytest lineno is one less than actual + assert entry.lineno + 1 == 2 + + +def test_file_main_backtrace_line_numbering(): + # Given a file `script.py`: + # + # import os + # + # if __name__ == '__main__': + # os.getpid + # raise RuntimeError('example') + # + # When executed, the backtrace should have RuntimeError('example') coming + # from the appropriate location. 
+ with isolated_filesystem(): + Path('script.py').write_text(dedent('''\ + import os + + if __name__ == '__main__': + os.getpid + raise RuntimeError('example') + ''')) + + prelude, main = parse_file('script.py') + + prelude() + + with pytest.raises(RuntimeError) as e: + main() + + entry = e.traceback[1] + assert str(entry.path) == str(Path('script.py').absolute()) + assert entry.lineno + 1 == 5 + + +def test_file_path_set(): + # Given a file `script.py` + # Executed like python script.py + # The Python interpreter sets __file__ to the value passed as the first + # argument. + # The problem with that is we execute main from a cwd that is different than + # the initial execution for the module. + # So we normalize the __file__ attribute to always be the full, resolved path + # to the file. + with isolated_filesystem() as path: + Path('script.py').write_text(dedent(''' + print(__file__) + + if __name__ == '__main__': + print(__file__) + ''')) + + prelude, main = parse_file('script.py') + + with captured_std_streams() as (stdin, stdout, stderr): + prelude() + + assert stdout.read().strip() == str(path / 'script.py') + + with captured_std_streams() as (stdin, stdout, stderr): + main() + + assert stdout.read().strip() == str(path / 'script.py') + + +def test_file_path_set_symlink(): + # Given a file `script.py` + # __file__ should be the full, resolved path to the file. + ... + + +def test_file_path_symlink_modified(): + ... + + +# TODO: Make console scripts nicer to access for tests. +def run_cli(*args, **kwargs): + runner_path = Path('runner.py') + runner_path.write_text(dedent(''' + from quicken._cli import main + + main() + ''')) + + cmd = [ + sys.executable, + str(runner_path.absolute()), + ] + + if args: + cmd.extend(args[0]) + + return subprocess.run(cmd, *args[1:], **kwargs) + + +def test_file_argv_set(log_file_path): + # Given a file `script.py` + # sys.argv should start with `script.py` and be followed by any + # other arguments + with isolated_filesystem() as path: + Path('script.py').write_text(dedent(''' + import sys + + if __name__ == '__main__': + print(sys.argv[0]) + print(sys.argv[1]) + ''')) + + args = ['hello'] + with env(QUICKEN_LOG=str(log_file_path)): + result = run_cli( + [ + '-f', + str(Path('script.py')), + *args + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + assert result.stdout.decode('utf-8') == f'script.py\n{args[0]}\n' + + +def test_file_server_name_uses_absolute_path(log_file_path): + # Given a file a/script.py + # And a server started from a/script.py + # When in a + # And quicken -f script.py is executed + # Then the server should be used for the call + with isolated_filesystem() as path: + script_a = Path('a/script.py') + script_a.parent.mkdir(parents=True) + script_a.write_text(dedent(''' + import os + + if __name__ == '__main__': + print(os.getpid()) + print(os.getppid()) + ''')) + + with env(QUICKEN_LOG=str(log_file_path)): + result = run_cli( + [ + '-f', + str(script_a), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + current_pid = str(os.getpid()) + runner_pid_1, parent_pid_1 = result.stdout.decode('utf-8').split() + assert runner_pid_1 != current_pid + assert parent_pid_1 != current_pid + + with chdir('a'): + result = run_cli( + [ + '-f', + script_a.name, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + runner_pid_2, parent_pid_2 = result.stdout.decode('utf-8').split() + assert runner_pid_2 != current_pid + assert parent_pid_2 != current_pid + assert runner_pid_1 != runner_pid_2 + assert parent_pid_1 == 
parent_pid_2 diff --git a/tests/test_logging.py b/tests/test_logging.py index 433edc2..8273f94 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,7 +1,7 @@ import io import logging -from quicken._protocol import ProcessState +from quicken.lib._protocol import ProcessState from .utils.pytest import non_windows diff --git a/tests/test_multiprocessing_asyncio.py b/tests/test_multiprocessing_asyncio.py index a5049a5..71e3cc1 100644 --- a/tests/test_multiprocessing_asyncio.py +++ b/tests/test_multiprocessing_asyncio.py @@ -4,7 +4,7 @@ import pytest -from quicken._multiprocessing_asyncio import AsyncProcess +from quicken.lib._multiprocessing_asyncio import AsyncProcess from .utils import isolated_filesystem from .utils.pytest import non_windows diff --git a/tests/test_scripts.py b/tests/test_scripts.py new file mode 100644 index 0000000..18dbce9 --- /dev/null +++ b/tests/test_scripts.py @@ -0,0 +1,19 @@ +"""Test script helpers. +""" +from quicken._scripts import get_attribute_accumulator + + +def test_attribute_accumulator(): + result = None + + def check(this_result): + nonlocal result + result = this_result + + get_attribute_accumulator(check).foo.bar.baz() + + assert result == ['foo', 'bar', 'baz'] + + get_attribute_accumulator(check).__init__.hello._.world() + + assert result == ['__init__', 'hello', '_', 'world'] diff --git a/tests/test_xdg.py b/tests/test_xdg.py index 79652c3..b333fc7 100644 --- a/tests/test_xdg.py +++ b/tests/test_xdg.py @@ -3,11 +3,12 @@ from pathlib import Path -from quicken._xdg import RuntimeDir +from quicken.lib._xdg import RuntimeDir import pytest from .utils import env +from .utils.path import get_bound_path from .utils.pytest import non_windows @@ -77,7 +78,7 @@ def test_runtime_dir_succeeds_creating_a_file(): sample_text = 'hello' with tempfile.TemporaryDirectory() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example') + file = get_bound_path(runtime_dir, 'example') file.write_text(sample_text, encoding='utf-8') text = (Path(p) / 'example').read_text(encoding='utf-8') assert sample_text == text @@ -92,7 +93,7 @@ def test_runtime_dir_path_fails_when_directory_unlinked_and_recreated(): sample_text = 'hello' with tempfile.TemporaryDirectory() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example') + file = get_bound_path(runtime_dir, 'example') Path(p).mkdir() diff --git a/tests/tests/test_utils.py b/tests/tests/test_utils.py new file mode 100644 index 0000000..27a5857 --- /dev/null +++ b/tests/tests/test_utils.py @@ -0,0 +1,9 @@ +from ..utils import captured_std_streams + + +def test_captured_std_streams(): + text = 'hello world' + with captured_std_streams() as (stdin, stdout, stderr): + print(text) + + assert stdout.read() == f'{text}\n' diff --git a/tests/test_watch.py b/tests/tests/test_watch.py similarity index 81% rename from tests/test_watch.py rename to tests/tests/test_watch.py index 1cee5c8..7b2856f 100644 --- a/tests/test_watch.py +++ b/tests/tests/test_watch.py @@ -3,11 +3,12 @@ import threading import time -from quicken._xdg import RuntimeDir +from quicken.lib._xdg import RuntimeDir -from .utils import isolated_filesystem -from .utils.pytest import non_windows -from .watch import wait_for_create, wait_for_delete +from ..utils import isolated_filesystem +from ..utils.pytest import non_windows +from ..utils.watch import wait_for_create, wait_for_delete +from ..utils.path import get_bound_path pytestmark = non_windows @@ -16,7 +17,7 @@ def test_wait_for_create_notices_existing_file(): with 
isolated_filesystem() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') file.write_text('hello', encoding='utf-8') assert wait_for_create(file, timeout=0.01) @@ -24,7 +25,7 @@ def test_wait_for_create_notices_existing_file(): def test_wait_for_create_fails_missing_file(): with isolated_filesystem() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') assert not wait_for_create(file, timeout=0.01) @@ -33,7 +34,7 @@ def test_watch_for_create_notices_file_fast(): # To rule out dependence on being in the cwd. os.chdir('/') runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') writer_timestamp: datetime = None def create_file(): @@ -53,14 +54,14 @@ def create_file(): def test_wait_for_delete_notices_missing_file(): with isolated_filesystem() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') assert wait_for_delete(file, timeout=0.01) def test_wait_for_delete_fails_existing_file(): with isolated_filesystem() as p: runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') file.write_text('hello', encoding='utf-8') assert not wait_for_delete(file, timeout=0.1) @@ -70,7 +71,7 @@ def test_watch_for_delete_notices_file_fast(): # To rule out dependence on being in the cwd. os.chdir('/') runtime_dir = RuntimeDir(dir_path=p) - file = runtime_dir.path('example.txt') + file = get_bound_path(runtime_dir, 'example.txt') file.write_text('hello', encoding='utf-8') writer_timestamp: datetime = None diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 1371eb4..f349bb0 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -7,16 +7,16 @@ from contextlib import contextmanager from pathlib import Path -from typing import ContextManager, List +from typing import ContextManager, List, TextIO, Tuple, Union -from quicken._signal import settable_signals +from quicken.lib._signal import settable_signals logger = logging.getLogger(__name__) @contextmanager -def chdir(path: Path) -> ContextManager: +def chdir(path: Union[Path, str]) -> ContextManager: current_path = Path.cwd() try: os.chdir(str(path)) @@ -76,7 +76,7 @@ def argv(args: List[str]) -> ContextManager: @contextmanager -def umask(umask: int) -> ContextManager: +def umask(umask: int) -> ContextManager[None]: """Set umask within the context. """ umask = os.umask(umask) @@ -87,7 +87,7 @@ def umask(umask: int) -> ContextManager: @contextmanager -def preserved_signals() -> ContextManager: +def preserved_signals() -> ContextManager[None]: handlers = [(s, signal.getsignal(s)) for s in settable_signals] try: yield @@ -105,7 +105,7 @@ def preserved_signals() -> ContextManager: @contextmanager -def captured_std_streams() -> ContextManager: +def captured_std_streams() -> ContextManager[Tuple[TextIO, TextIO, TextIO]]: """Capture standard streams and provide an interface for interacting with them. 
@@ -125,6 +125,8 @@ def captured_std_streams() -> ContextManager: try: yield os.fdopen(stdin_w, 'w'), os.fdopen(stdout_r), os.fdopen(stderr_r) finally: + sys.stdout.flush() + sys.stderr.flush() os.close(stdin_r) os.close(stdout_w) os.close(stderr_w) diff --git a/tests/utils/path.py b/tests/utils/path.py new file mode 100644 index 0000000..da13cfb --- /dev/null +++ b/tests/utils/path.py @@ -0,0 +1,62 @@ +import os +import threading + +from contextlib import contextmanager, ExitStack +from pathlib import PosixPath +from functools import wraps +from typing import Any, ContextManager, Union + +from quicken.lib._xdg import chdir + + +@contextmanager +def lock_guard(l: Union[threading.Lock, threading.RLock]): + l.acquire() + try: + yield + finally: + l.release() + +class BoundPath(PosixPath): + _lock = threading.RLock() + + def __init__(self, *_, dir_fd: int): + self._dir_fd = dir_fd + super().__init__() + + def __getattribute__(self, name: str) -> Any: + """Intercept and execute all functions in the context of the + directory. + """ + attr = super().__getattribute__(name) + if callable(attr): + @wraps(attr) + def wrapper(*args, **kwargs): + with ExitStack() as stack: + stack.enter_context(lock_guard(self._lock)) + try: + stack.enter_context(chdir(self._dir_fd)) + except AttributeError: + # Avoids issues during Path construction, before + # __init__ is called. + pass + return attr(*args, **kwargs) + return wrapper + return attr + + @property + def dir(self): + return self._dir_fd + + def pass_to(self, callback): + return callback(self) + + +def get_bound_path(context, *args) -> BoundPath: + """Execute action in directory so relative paths are resolved inside the + directory without specific operations needing to support `dir_fd`. + """ + result = BoundPath(*args, dir_fd=context.fileno()) + if result.is_absolute(): + raise ValueError('Provided argument must not be absolute') + return result diff --git a/tests/watch.py b/tests/utils/watch.py similarity index 98% rename from tests/watch.py rename to tests/utils/watch.py index 8ee29dd..504ca41 100644 --- a/tests/watch.py +++ b/tests/utils/watch.py @@ -6,7 +6,7 @@ from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler -from quicken._xdg import BoundPath, chdir +from .path import BoundPath, chdir logger = logging.getLogger(__name__)
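
For context on the recurring ``runtime_dir.path(name)`` to ``get_bound_path(runtime_dir, name)`` change in the tests above, here is a minimal usage sketch of the new ``tests/utils/path.py`` helper. It mirrors the pattern in ``tests/test_xdg.py``; the temporary directory and file name are placeholders, and it assumes the repository root is importable so that ``tests.utils.path`` can be imported:

    import tempfile

    from quicken.lib._xdg import RuntimeDir
    from tests.utils.path import get_bound_path

    with tempfile.TemporaryDirectory() as p:
        runtime_dir = RuntimeDir(dir_path=p)
        # BoundPath executes each operation with the directory fd as the cwd,
        # so the relative name always resolves inside the runtime directory,
        # regardless of the caller's current working directory.
        file = get_bound_path(runtime_dir, 'example.txt')
        file.write_text('hello', encoding='utf-8')
        assert file.read_text(encoding='utf-8') == 'hello'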