From 827141b3082d629c9f7893edc13fbcb1c916625c Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Tue, 23 Jan 2024 12:31:31 -0500 Subject: [PATCH] feat: multi-plugins Signed-off-by: Henry Schreiner --- docs/dev-guide.rst | 2 +- src/validate_pyproject/api.py | 4 +- src/validate_pyproject/cli.py | 3 +- src/validate_pyproject/plugins/__init__.py | 44 +++++++++++++------- src/validate_pyproject/pre_compile/cli.py | 3 +- tests/test_api.py | 4 +- tests/test_cli.py | 2 +- tests/test_plugins.py | 47 ++++++++++++++++++---- 8 files changed, 76 insertions(+), 33 deletions(-) diff --git a/docs/dev-guide.rst b/docs/dev-guide.rst index cf69db5..16a487c 100644 --- a/docs/dev-guide.rst +++ b/docs/dev-guide.rst @@ -77,7 +77,7 @@ specify which ``tool`` subtable it would be checking: available_plugins = [ - *plugins.list_from_entry_points(), + *plugins.list_plugins_from_entry_points(), plugins.PluginWrapper("your-tool", your_plugin), ] validator = api.Validator(available_plugins) diff --git a/src/validate_pyproject/api.py b/src/validate_pyproject/api.py index 003bf78..ee3c051 100644 --- a/src/validate_pyproject/api.py +++ b/src/validate_pyproject/api.py @@ -202,9 +202,9 @@ def __init__( self._extra_validations = tuple(extra_validations) if plugins is ALL_PLUGINS: - from .plugins import list_from_entry_points + from .plugins import list_plugins_from_entry_points - plugins = list_from_entry_points() + plugins = list_plugins_from_entry_points() self._plugins = (*plugins, *extra_plugins) diff --git a/src/validate_pyproject/cli.py b/src/validate_pyproject/cli.py index e5f0ad5..7c4f780 100644 --- a/src/validate_pyproject/cli.py +++ b/src/validate_pyproject/cli.py @@ -30,8 +30,7 @@ from . import _tomllib as tomllib from .api import Validator from .errors import ValidationError -from .plugins import PluginWrapper -from .plugins import list_from_entry_points as list_plugins_from_entry_points +from .plugins import PluginWrapper, list_plugins_from_entry_points from .remote import RemotePlugin, load_store _logger = logging.getLogger(__package__) diff --git a/src/validate_pyproject/plugins/__init__.py b/src/validate_pyproject/plugins/__init__.py index 813ac2a..36eb5d9 100644 --- a/src/validate_pyproject/plugins/__init__.py +++ b/src/validate_pyproject/plugins/__init__.py @@ -29,8 +29,6 @@ else: Protocol = object -ENTRYPOINT_GROUP = "validate_pyproject.tool_schema" - class PluginProtocol(Protocol): @property @@ -56,7 +54,7 @@ def fragment(self) -> str: class PluginWrapper: def __init__(self, tool: str, load_fn: "Plugin"): - self._tool = tool + self._tool, _, self._fragment = tool.partition("#") self._load_fn = load_fn @property @@ -73,7 +71,7 @@ def schema(self) -> "Schema": @property def fragment(self) -> str: - return "" + return self._fragment @property def help_text(self) -> str: @@ -90,12 +88,13 @@ def __repr__(self) -> str: _: PluginProtocol = typing.cast(PluginWrapper, None) -def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]: +def iterate_entry_points(group: str) -> Iterable[EntryPoint]: """Produces a generator yielding an EntryPoint object for each plugin registered via ``setuptools`` `entry point`_ mechanism. - This method can be used in conjunction with :obj:`load_from_entry_point` to filter - the plugins before actually loading them. + This method can be used in conjunction with :obj:`load_from_entry_point` to + filter the plugins before actually loading them. The entry points are not + deduplicated, but they are sorted. """ entries = entry_points() if hasattr(entries, "select"): # pragma: no cover @@ -110,8 +109,7 @@ def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]: # TODO: Once Python 3.10 becomes the oldest version supported, this fallback and # conditional statement can be removed. entries_ = (plugin for plugin in entries.get(group, [])) - deduplicated = {e.name: e for e in sorted(entries_, key=lambda e: e.name)} - return list(deduplicated.values()) + return sorted(entries_, key=lambda e: e.name) def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper: @@ -123,23 +121,39 @@ def load_from_entry_point(entry_point: EntryPoint) -> PluginWrapper: raise ErrorLoadingPlugin(entry_point=entry_point) from ex -def list_from_entry_points( - group: str = ENTRYPOINT_GROUP, +def load_multi_entry_point(entry_point: EntryPoint) -> List[PluginWrapper]: + """Carefully load the plugin, raising a meaningful message in case of errors""" + try: + dict_plugins = entry_point.load() + return [PluginWrapper(k, v) for k, v in dict_plugins().items()] + except Exception as ex: + raise ErrorLoadingPlugin(entry_point=entry_point) from ex + + +def list_plugins_from_entry_points( filtering: Callable[[EntryPoint], bool] = lambda _: True, ) -> List[PluginWrapper]: """Produces a list of plugin objects for each plugin registered via ``setuptools`` `entry point`_ mechanism. Args: - group: name of the setuptools' entry point group where plugins is being - registered filtering: function returning a boolean deciding if the entry point should be loaded and included (or not) in the final list. A ``True`` return means the plugin should be included. """ - return [ - load_from_entry_point(e) for e in iterate_entry_points(group) if filtering(e) + eps = [ + load_from_entry_point(e) + for e in iterate_entry_points("validate_pyproject.tool_schema") + if filtering(e) + ] + eps += [ + ep + for e in iterate_entry_points("validate_pyproject.multi_schema") + for ep in load_multi_entry_point(e) + if filtering(e) ] + dedup = {e.tool: e for e in sorted(eps, key=lambda e: e.tool)} + return list(dedup.values()) class ErrorLoadingPlugin(RuntimeError): diff --git a/src/validate_pyproject/pre_compile/cli.py b/src/validate_pyproject/pre_compile/cli.py index 6694b7b..83ed6cc 100644 --- a/src/validate_pyproject/pre_compile/cli.py +++ b/src/validate_pyproject/pre_compile/cli.py @@ -10,8 +10,7 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Sequence from .. import cli -from ..plugins import PluginWrapper -from ..plugins import list_from_entry_points as list_plugins_from_entry_points +from ..plugins import PluginWrapper, list_plugins_from_entry_points from ..remote import RemotePlugin, load_store from . import pre_compile diff --git a/tests/test_api.py b/tests/test_api.py index 48e7ce2..9c5f03c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -32,7 +32,7 @@ def test_load_plugin(): class TestRegistry: def test_with_plugins(self): - plg = plugins.list_from_entry_points() + plg = plugins.list_plugins_from_entry_points() registry = api.SchemaRegistry(plg) main_schema = registry[registry.main] project = main_schema["properties"]["project"] @@ -112,7 +112,7 @@ def test_invalid(self): # --- def plugin(self, tool): - plg = plugins.list_from_entry_points(filtering=lambda e: e.name == tool) + plg = plugins.list_plugins_from_entry_points(filtering=lambda e: e.name == tool) return plg[0] TOOLS = ("distutils", "setuptools") diff --git a/tests/test_cli.py b/tests/test_cli.py index 7fd30f3..fb2c244 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -36,7 +36,7 @@ def test_custom_plugins(self, capsys): def parse_args(args): - plg = plugins.list_from_entry_points() + plg = plugins.list_plugins_from_entry_points() return cli.parse_args(args, plg) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index a5e8d01..f24fab3 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -3,11 +3,13 @@ # The original PyScaffold license can be found in 'NOTICE.txt' import sys +from types import ModuleType +from typing import Any, List import pytest from validate_pyproject import plugins -from validate_pyproject.plugins import ENTRYPOINT_GROUP, ErrorLoadingPlugin +from validate_pyproject.plugins import ErrorLoadingPlugin EXISTING = ( "setuptools", @@ -17,16 +19,16 @@ if sys.version_info[:2] >= (3, 8): # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` - from importlib.metadata import EntryPoint # pragma: no cover + from importlib import metadata # pragma: no cover else: - from importlib_metadata import EntryPoint # pragma: no cover + import importlib_metadata as metadata # pragma: no cover def test_load_from_entry_point__error(): # This module does not exist, so Python will have some trouble loading it - # EntryPoint(name, value, group) + # metadata.EntryPoint(name, value, group) entry = "mypkg.SOOOOO___fake___:activate" - fake = EntryPoint("fake", entry, ENTRYPOINT_GROUP) + fake = metadata.EntryPoint("fake", entry, "validate_pyproject.tool_schema") with pytest.raises(ErrorLoadingPlugin): plugins.load_from_entry_point(fake) @@ -36,7 +38,7 @@ def is_entry_point(ep): def test_iterate_entry_points(): - plugin_iter = plugins.iterate_entry_points() + plugin_iter = plugins.iterate_entry_points("validate_pyproject.tool_schema") assert hasattr(plugin_iter, "__iter__") pluging_list = list(plugin_iter) assert all(is_entry_point(e) for e in pluging_list) @@ -47,14 +49,14 @@ def test_iterate_entry_points(): def test_list_from_entry_points(): # Should return a list with all the plugins registered in the entrypoints - pluging_list = plugins.list_from_entry_points() + pluging_list = plugins.list_plugins_from_entry_points() orig_len = len(pluging_list) plugin_names = " ".join(e.tool for e in pluging_list) for example in EXISTING: assert example in plugin_names # a filtering function can be passed to avoid loading plugins that are not needed - pluging_list = plugins.list_from_entry_points( + pluging_list = plugins.list_plugins_from_entry_points( filtering=lambda e: e.name != "setuptools" ) plugin_names = " ".join(e.tool for e in pluging_list) @@ -76,3 +78,32 @@ def _fn2(_): pw = plugins.PluginWrapper("name", _fn2) assert pw.help_text == "Help for `name`" + + +def loader(name: str) -> Any: + return {"example": "thing"} + + +def dynamic_ep(): + return {"some#fragment": loader} + + +class Select(list): + def select(self, group: str) -> List[str]: + return list(self) if group == "validate_pyproject.multi_schema" else [] + + +def test_process_checks(monkeypatch: pytest.MonkeyPatch) -> None: + ep = metadata.EntryPoint( + name="_", + group="validate_pyproject.multi_schema", + value="test_module:dynamic_ep", + ) + sys.modules["test_module"] = ModuleType("test_module") + sys.modules["test_module"].dynamic_ep = dynamic_ep # type: ignore[attr-defined] + sys.modules["test_module"].loader = loader # type: ignore[attr-defined] + monkeypatch.setattr(plugins, "entry_points", lambda: Select([ep])) + eps = plugins.list_plugins_from_entry_points() + (ep,) = eps + assert ep.tool == "some" + assert ep.fragment == "fragment"