diff --git a/craft_application/application.py b/craft_application/application.py index e098f68a..bfb43310 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -407,7 +407,7 @@ def project(self) -> models.Project: def is_managed(self) -> bool: """Shortcut to tell whether we're running in managed mode.""" - return self.services.ProviderClass.is_managed() + return self.services.get_class("provider").is_managed() def run_managed(self, platform: str | None, build_for: str | None) -> None: """Run the application in a managed instance.""" diff --git a/craft_application/services/fetch.py b/craft_application/services/fetch.py index f170576b..161f5fc7 100644 --- a/craft_application/services/fetch.py +++ b/craft_application/services/fetch.py @@ -83,7 +83,7 @@ def setup(self) -> None: """Start the fetch-service process with proper arguments.""" super().setup() - if not self._services.ProviderClass.is_managed(): + if not self._services.get_class("provider").is_managed(): # Early fail if the fetch-service is not installed. fetch.verify_installed() diff --git a/craft_application/services/service_factory.py b/craft_application/services/service_factory.py index 51416e52..715d6ac4 100644 --- a/craft_application/services/service_factory.py +++ b/craft_application/services/service_factory.py @@ -1,4 +1,4 @@ -# Copyright 2023 Canonical Ltd. +# Copyright 2023-2024 Canonical Ltd. # # This program is free software: you can redistribute it and/or modify it # under the terms of the GNU Lesser General Public License version 3, as @@ -15,14 +15,41 @@ from __future__ import annotations import dataclasses +import importlib +import re import warnings -from typing import TYPE_CHECKING, Any +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Literal, + TypeVar, + cast, + overload, +) + +import annotated_types from craft_application import models, services if TYPE_CHECKING: from craft_application.application import AppMetadata +_DEFAULT_SERVICES = { + "config": "ConfigService", + "fetch": "FetchService", + "init": "InitService", + "lifecycle": "LifecycleService", + "provider": "ProviderService", + "remote_build": "RemoteBuildService", + "request": "RequestService", +} +_CAMEL_TO_PYTHON_CASE_REGEX = re.compile(r"(? None: self._service_kwargs: dict[str, dict[str, Any]] = {} + self._services: dict[str, services.AppService] = {} + + factory_dict = dataclasses.asdict(self) + for cls_name, value in factory_dict.items(): + if cls_name.endswith("Class"): + if value is not None: + identifier = _CAMEL_TO_PYTHON_CASE_REGEX.sub( + "_", cls_name[:-5] + ).lower() + warnings.warn( + f'Registering services on service factory instantiation is deprecated. Use ServiceFactory.register("{identifier}", {value.__name__}) instead.', + category=DeprecationWarning, + stacklevel=3, + ) + self.register(identifier, value) + setattr(self, cls_name, self.get_class(cls_name)) + + if "package" not in self._service_classes: + raise TypeError( + "A PackageService must be registered before creating the ServiceFactory." + ) + + @classmethod + def register( + cls, + name: str, + service_class: type[services.AppService] | str, + *, + module: str | None = None, + ) -> None: + """Register a service class with a given name. + + :param name: the name to call the service class. + :param service_class: either a service class or a string that names the service + class. + :param module: If service_class is a string, the module from which to import + the service class. + """ + if isinstance(service_class, str): + if module is None: + raise KeyError("Must set module if service_class is set by name.") + cls._service_classes[name] = (module, service_class) + else: + if module is not None: + raise KeyError( + "Must not set module if service_class is passed by value." + ) + cls._service_classes[name] = service_class + + # For backwards compatibility with class attribute service types. + service_cls_name = "".join(word.title() for word in name.split("_")) + "Class" + setattr(cls, service_cls_name, cls.get_class(name)) + + @classmethod + def reset(cls) -> None: + """Reset the registered services.""" + cls._service_classes.clear() + for name, class_name in _DEFAULT_SERVICES.items(): + module_name = name.replace("_", "") + cls.register( + name, class_name, module=f"craft_application.services.{module_name}" + ) def set_kwargs( self, @@ -94,25 +185,101 @@ def update_kwargs( """ self._service_kwargs.setdefault(service, {}).update(kwargs) - def __getattr__(self, service: str) -> services.AppService: - """Instantiate a service class. + @overload + @classmethod + def get_class( + cls, name: Literal["config", "ConfigService", "ConfigClass"] + ) -> type[services.ConfigService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["fetch", "FetchService", "FetchClass"] + ) -> type[services.FetchService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["init", "InitService", "InitClass"] + ) -> type[services.InitService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["lifecycle", "LifecycleService", "LifecycleClass"] + ) -> type[services.LifecycleService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["package", "PackageService", "PackageClass"] + ) -> type[services.PackageService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["provider", "ProviderService", "ProviderClass"] + ) -> type[services.ProviderService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["remote_build", "RemoteBuildService", "RemoteBuildClass"] + ) -> type[services.RemoteBuildService]: ... + @overload + @classmethod + def get_class( + cls, name: Literal["request", "RequestService", "RequestClass"] + ) -> type[services.RequestService]: ... + @overload + @classmethod + def get_class(cls, name: str) -> type[services.AppService]: ... + @classmethod + def get_class(cls, name: str) -> type[services.AppService]: + """Get the class for a service by its name.""" + if name.endswith("Class"): + service_cls_name = name + service = _CAMEL_TO_PYTHON_CASE_REGEX.sub("_", name[:-5]).lower() + elif name.endswith("Service"): + service = _CAMEL_TO_PYTHON_CASE_REGEX.sub("_", name[:-7]).lower() + service_cls_name = name[:-7] + "Class" + else: + service_cls_name = "".join(word.title() for word in name.split("_")) + service_cls_name += "Class" + service = name + if service not in cls._service_classes: + raise AttributeError(f"Not a registered service: {service}") + service_info = cls._service_classes[service] + if isinstance(service_info, tuple): + module_name, class_name = service_info + module = importlib.import_module(module_name) + return cast(type[services.AppService], getattr(module, class_name)) + return service_info - This allows us to lazy-load only the necessary services whilst still - treating them as attributes of our factory in a dynamic manner. - For a service (e.g. ``package``, the PackageService instance) that has not - been instantiated, this method finds the corresponding class, instantiates - it with defaults and any values set using ``set_kwargs``, and stores the - instantiated service as an instance attribute, allowing the same service - instance to be reused for the entire run of the application. + @overload + def get(self, service: Literal["config"]) -> services.ConfigService: ... + @overload + def get(self, service: Literal["fetch"]) -> services.FetchService: ... + @overload + def get(self, service: Literal["init"]) -> services.InitService: ... + @overload + def get(self, service: Literal["package"]) -> services.PackageService: ... + @overload + def get(self, service: Literal["lifecycle"]) -> services.LifecycleService: ... + @overload + def get(self, service: Literal["provider"]) -> services.ProviderService: ... + @overload + def get(self, service: Literal["remote_build"]) -> services.RemoteBuildService: ... + @overload + def get(self, service: Literal["request"]) -> services.RequestService: ... + @overload + def get(self, service: str) -> services.AppService: ... + def get(self, service: str) -> services.AppService: + """Get a service by name. + + :param service: the name of the service (e.g. "config") + :returns: An instantiated and set up service class. + + Also caches the service so as to provide a single service instance per + ServiceFactory. """ - service_cls_name = "".join(word.title() for word in service.split("_")) - service_cls_name += "Class" - classes = dataclasses.asdict(self) - if service_cls_name not in classes: - raise AttributeError(service) - cls = getattr(self, service_cls_name) - if not issubclass(cls, services.AppService): - raise TypeError(f"{cls.__name__} is not a service class") + if service in self._services: + return self._services[service] + cls = self.get_class(service) kwargs = self._service_kwargs.get(service, {}) if issubclass(cls, services.ProjectService): if not self.project: @@ -121,7 +288,25 @@ def __getattr__(self, service: str) -> services.AppService: ) kwargs.setdefault("project", self.project) - instance: services.AppService = cls(app=self.app, services=self, **kwargs) + instance = cls(app=self.app, services=self, **kwargs) instance.setup() - setattr(self, service, instance) + self._services[service] = instance return instance + + def __getattr__(self, name: str) -> services.AppService | type[services.AppService]: + """Instantiate a service class. + + This allows us to lazy-load only the necessary services whilst still + treating them as attributes of our factory in a dynamic manner. + For a service (e.g. ``package``, the PackageService instance) that has not + been instantiated, this method finds the corresponding class, instantiates + it with defaults and any values set using ``set_kwargs``, and stores the + instantiated service as an instance attribute, allowing the same service + instance to be reused for the entire run of the application. + """ + result = self.get_class(name) if name.endswith("Class") else self.get(name) + setattr(self, name, result) + return result + + +ServiceFactory.reset() # Set up default services on import. diff --git a/pyproject.toml b/pyproject.toml index fe9adf83..9fb79003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "craft-application" description = "A framework for *craft applications." dynamic = ["version", "readme"] dependencies = [ + "annotated-types>=0.6.0", "craft-archives>=2.0.0", "craft-cli>=2.10.1", "craft-grammar>=2.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index 383a2adb..bed92b32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,7 @@ import pydantic import pytest from craft_application import application, git, launchpad, models, services, util +from craft_application.services import service_factory from craft_cli import EmitterMode, emit from craft_providers import bases from jinja2 import FileSystemLoader @@ -46,6 +47,12 @@ def _create_fake_build_plan(num_infos: int = 1) -> list[models.BuildInfo]: return [models.BuildInfo("foo", arch, arch, base)] * num_infos +@pytest.fixture(autouse=True) +def reset_services(): + yield + service_factory.ServiceFactory.reset() + + @pytest.fixture def features(request) -> dict[str, bool]: """Fixture that controls the enabled features. @@ -310,19 +317,21 @@ def _get_loader(self, template_dir: pathlib.Path) -> jinja2.BaseLoader: @pytest.fixture def fake_services( + tmp_path, app_metadata, fake_project, fake_lifecycle_service_class, fake_package_service_class, fake_init_service_class, ): - return services.ServiceFactory( - app_metadata, - project=fake_project, - PackageClass=fake_package_service_class, - LifecycleClass=fake_lifecycle_service_class, - InitClass=fake_init_service_class, + services.ServiceFactory.register("package", fake_package_service_class) + services.ServiceFactory.register("lifecycle", fake_lifecycle_service_class) + services.ServiceFactory.register("init", fake_init_service_class) + factory = services.ServiceFactory(app_metadata, project=fake_project) + factory.update_kwargs( + "lifecycle", work_dir=tmp_path, cache_dir=tmp_path / "cache", build_plan=[] ) + return factory class FakeApplication(application.Application): diff --git a/tests/integration/services/test_service_factory.py b/tests/integration/services/test_service_factory.py index 5bcab419..4b50d071 100644 --- a/tests/integration/services/test_service_factory.py +++ b/tests/integration/services/test_service_factory.py @@ -18,7 +18,7 @@ from craft_application import services -def test_gets_real_services( +def test_gets_dataclass_services( check, app_metadata, fake_project, @@ -39,6 +39,27 @@ def test_gets_real_services( check.is_instance(factory.provider, services.ProviderService) +def test_gets_registered_services( + check, + app_metadata, + fake_project, + fake_package_service_class, + fake_lifecycle_service_class, + fake_provider_service_class, +): + services.ServiceFactory.register("package", fake_package_service_class) + services.ServiceFactory.register("lifecycle", fake_lifecycle_service_class) + services.ServiceFactory.register("provider", fake_provider_service_class) + factory = services.ServiceFactory( + app_metadata, + project=fake_project, + ) + + check.is_instance(factory.package, services.PackageService) + check.is_instance(factory.lifecycle, services.LifecycleService) + check.is_instance(factory.provider, services.ProviderService) + + def test_real_service_error(app_metadata, fake_project): factory = services.ServiceFactory( app_metadata, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 91bf03bc..0a07a23e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -22,6 +22,7 @@ import pytest import pytest_mock from craft_application import git, services, util +from craft_application.services import service_factory @pytest.fixture(params=["amd64", "arm64", "riscv64"]) @@ -44,17 +45,32 @@ def provider_service( @pytest.fixture -def mock_services(app_metadata, fake_project, fake_package_service_class): - factory = services.ServiceFactory( - app_metadata, project=fake_project, PackageClass=fake_package_service_class +def mock_services(monkeypatch, app_metadata, fake_project): + services.ServiceFactory.register("config", mock.Mock(spec=services.ConfigService)) + services.ServiceFactory.register("fetch", mock.Mock(spec=services.FetchService)) + services.ServiceFactory.register("init", mock.MagicMock(spec=services.InitService)) + services.ServiceFactory.register( + "lifecycle", mock.Mock(spec=services.LifecycleService) ) - factory.lifecycle = mock.Mock(spec=services.LifecycleService) - factory.package = mock.Mock(spec=services.PackageService) - factory.provider = mock.Mock(spec=services.ProviderService) - factory.remote_build = mock.Mock(spec_set=services.RemoteBuildService) - factory.fetch = mock.Mock(spec=services.FetchService) - factory.init = mock.Mock(spec=services.InitService) - return factory + services.ServiceFactory.register("package", mock.Mock(spec=services.PackageService)) + services.ServiceFactory.register( + "provider", mock.Mock(spec=services.ProviderService) + ) + services.ServiceFactory.register( + "remote_build", mock.Mock(spec=services.RemoteBuildService) + ) + + def forgiving_is_subclass(child, parent): + if not isinstance(child, type): + return False + return issubclass(child, parent) + + # Mock out issubclass on the service factory since we're registering mock objects + # rather than actual classes. + monkeypatch.setattr( + service_factory, "issubclass", forgiving_is_subclass, raising=False + ) + return services.ServiceFactory(app_metadata, project=fake_project) @pytest.fixture diff --git a/tests/unit/services/test_service_factory.py b/tests/unit/services/test_service_factory.py index 223296e8..7dcf4f0a 100644 --- a/tests/unit/services/test_service_factory.py +++ b/tests/unit/services/test_service_factory.py @@ -21,20 +21,80 @@ from craft_application import AppMetadata, services from craft_cli import emit +pytestmark = [ + pytest.mark.filterwarnings("ignore:Registering services on service factory") +] + + +class FakeService(services.AppService): + """A fake service for testing.""" + @pytest.fixture def factory( - app_metadata, fake_project, fake_package_service_class, fake_lifecycle_service_class + tmp_path, + app_metadata, + fake_project, + fake_package_service_class, + fake_lifecycle_service_class, ): - return services.ServiceFactory( + services.ServiceFactory.register("package", fake_package_service_class) + services.ServiceFactory.register("lifecycle", fake_lifecycle_service_class) + + factory = services.ServiceFactory( app_metadata, project=fake_project, - PackageClass=fake_package_service_class, - LifecycleClass=fake_lifecycle_service_class, ) + factory.update_kwargs( + "lifecycle", + work_dir=tmp_path, + cache_dir=tmp_path / "cache", + build_plan=[], + ) + return factory + + +@pytest.mark.parametrize( + ("service_class", "module"), + [ + ("ConfigService", "craft_application.services.config"), + ("InitService", "craft_application.services.init"), + ], +) +def test_register_service_by_path(service_class, module): + services.ServiceFactory.register("testy", service_class, module=module) + + service = services.ServiceFactory.get_class("testy") + pytest_check.equal(service.__module__, module) + pytest_check.equal(service.__name__, service_class) + pytest_check.is_( + service, + services.ServiceFactory.TestyClass, # pyright: ignore[reportAttributeAccessIssue] + ) + + +def test_register_service_by_reference(): + services.ServiceFactory.register("testy", FakeService) + + service = services.ServiceFactory.get_class("testy") + pytest_check.is_(service, FakeService) + pytest_check.is_( + service, + services.ServiceFactory.TestyClass, # pyright: ignore[reportAttributeAccessIssue] + ) + +def test_register_service_by_path_no_module(): + with pytest.raises(KeyError, match="Must set module"): + services.ServiceFactory.register("testy", "FakeService") -def test_correct_init( + +def test_register_service_by_reference_with_module(): + with pytest.raises(KeyError, match="Must not set module"): + services.ServiceFactory.register("testy", FakeService, module="__main__") + + +def test_register_services_in_init( app_metadata, fake_project, fake_package_service_class, @@ -49,9 +109,9 @@ def test_correct_init( ProviderClass=fake_provider_service_class, ) - pytest_check.is_instance(factory.package, services.PackageService) - pytest_check.is_instance(factory.lifecycle, services.LifecycleService) - pytest_check.is_instance(factory.provider, services.ProviderService) + pytest_check.is_instance(factory.package, fake_package_service_class) + pytest_check.is_instance(factory.lifecycle, fake_lifecycle_service_class) + pytest_check.is_instance(factory.provider, fake_provider_service_class) @pytest.mark.parametrize( @@ -128,16 +188,79 @@ def __new__(cls, *args, **kwargs): ) -def test_getattr_cached_service(monkeypatch, check, factory): +def test_get_class(): + mock_service = mock.Mock(spec=services.AppService) + services.ServiceFactory.register("test_service", mock_service) + + pytest_check.is_(services.ServiceFactory.get_class("test_service"), mock_service) + pytest_check.is_( + services.ServiceFactory.get_class("TestServiceClass"), mock_service + ) + pytest_check.is_( + services.ServiceFactory.get_class("TestServiceService"), mock_service + ) + + +def test_get_class_not_registered(): + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + services.ServiceFactory.get_class("not_registered") + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + services.ServiceFactory.get_class("NotRegisteredService") + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + services.ServiceFactory.get_class("NotRegisteredClass") + + +def test_get_default_services( + factory, fake_package_service_class, fake_lifecycle_service_class +): + pytest_check.is_instance(factory.get("package"), fake_package_service_class) + pytest_check.is_instance(factory.get("lifecycle"), fake_lifecycle_service_class) + pytest_check.is_instance(factory.get("config"), services.ConfigService) + pytest_check.is_instance(factory.get("init"), services.InitService) + + +def test_get_registered_service(factory): + factory.register("testy", FakeService) + + first_result = factory.get("testy") + pytest_check.is_instance(first_result, FakeService) + pytest_check.is_(first_result, factory.get("testy")) + + +def test_get_unregistered_service(factory): + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + factory.get("not_registered") + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + factory.get("NotRegisteredService") + with pytest.raises( + AttributeError, match="Not a registered service: not_registered" + ): + factory.get("NotRegisteredClass") + + +def test_get_project_service_error(factory): + factory.project = None + with pytest.raises(ValueError, match="LifecycleService requires a project"): + factory.get("lifecycle") + + +def test_getattr_cached_service(monkeypatch, factory): mock_getattr = mock.Mock(wraps=factory.__getattr__) monkeypatch.setattr(services.ServiceFactory, "__getattr__", mock_getattr) first = factory.package second = factory.package - check.is_(first, second) - # Only gets called once because the second time `package` is an instance attribute. - with check: - mock_getattr.assert_called_once_with("package") + assert first is second def test_getattr_not_a_class(factory): @@ -206,3 +329,14 @@ def test_mandatory_adoptable_field( ) _ = factory.lifecycle + + +@pytest.mark.parametrize( + ("name", "cls"), + [ + ("PackageClass", services.PackageService), + ], +) +def test_services_on_instantiation_deprecated(app_metadata, name, cls): + with pytest.warns(DeprecationWarning, match="Use ServiceFactory.register"): + services.ServiceFactory(**{"app": app_metadata, name: cls}) diff --git a/tests/unit/test_application.py b/tests/unit/test_application.py index 75ad9271..02dfae78 100644 --- a/tests/unit/test_application.py +++ b/tests/unit/test_application.py @@ -515,7 +515,7 @@ def test_merge_default_commands_only(app): ) def test_log_path(monkeypatch, app, provider_managed, expected): monkeypatch.setattr( - app.services.ProviderClass, "is_managed", lambda: provider_managed + app.services.get_class("provider"), "is_managed", lambda: provider_managed ) actual = app.log_path @@ -702,7 +702,9 @@ def test_get_arg_or_config(monkeypatch, app, parsed_args, environ, item, expecte def test_get_dispatcher_error( monkeypatch, check, capsys, app, mock_dispatcher, managed, error, exit_code, message ): - monkeypatch.setattr(app.services.ProviderClass, "is_managed", lambda: managed) + monkeypatch.setattr( + app.services.get_class("provider"), "is_managed", lambda: managed + ) mock_dispatcher.pre_parse_args.side_effect = error with pytest.raises(SystemExit) as exc_info: diff --git a/tests/unit/test_application_fetch.py b/tests/unit/test_application_fetch.py index 718bba3b..3487d0ec 100644 --- a/tests/unit/test_application_fetch.py +++ b/tests/unit/test_application_fetch.py @@ -99,8 +99,8 @@ def test_run_managed_fetch_service( app._build_plan = fake_build_plan fetch_calls: list[str] = [] - app.services.FetchClass = FakeFetchService - app.services.set_kwargs("fetch", fetch_calls=fetch_calls) + app.services.register("fetch", FakeFetchService) + app.services.update_kwargs("fetch", fetch_calls=fetch_calls) monkeypatch.setattr("sys.argv", ["testcraft", "pack", *pack_args]) app.run()