From 2cdb36fb2c4f657a5339d36dd5797bd891e7111d Mon Sep 17 00:00:00 2001 From: David Seddon Date: Sun, 24 Sep 2023 17:13:33 +0100 Subject: [PATCH] Fix bug with allow_indirect_imports (#197) Prior to this commit, forbidden contracts with allow_indirect_imports only checked imports between the source/forbidden modules specified, not the descendants of those modules. --- CHANGELOG.rst | 8 ++- src/importlinter/contracts/forbidden.py | 43 ++++++++---- tests/unit/contracts/test_forbidden.py | 89 +++++++++++++++++++++---- 3 files changed, 114 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d76e6fc4..e0407d02 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,12 +2,16 @@ Changelog ========= Latest -------------------- +------ * Officially support Python 3.12. -* Fix: Error when using `click` version 6.0 and 7.0 (#191). +* Fix error when using `click` version 6.0 and 7.0 (#191). * Allow extra whitespace around the module names in import expressions. * Ignore blank lines in multiple value fields. +* Fix bug with allow_indirect_imports in forbidden contracts. + Prior to this fix, forbidden contracts with allow_indirect_imports + only checked imports between the source/forbidden modules specified, + not the descendants of those modules. 1.11.1 (2023-08-21) ------------------- diff --git a/src/importlinter/contracts/forbidden.py b/src/importlinter/contracts/forbidden.py index 4c298614..6f68a2ef 100644 --- a/src/importlinter/contracts/forbidden.py +++ b/src/importlinter/contracts/forbidden.py @@ -1,5 +1,8 @@ from __future__ import annotations -from typing import List, Tuple, cast + +from typing import List, cast + +from grimp import ImportGraph from importlinter.application import contract_utils, output from importlinter.application.contract_utils import AlertLevel @@ -7,7 +10,6 @@ from importlinter.domain import fields from importlinter.domain.contract import Contract, ContractCheck from importlinter.domain.imports import Module -from grimp import ImportGraph from ._common import format_line_numbers @@ -34,7 +36,7 @@ class ForbiddenContract(Contract): source_modules = fields.ListField(subfield=fields.ModuleField()) forbidden_modules = fields.ListField(subfield=fields.ModuleField()) ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False) - allow_indirect_imports = fields.StringField(required=False) + allow_indirect_imports = fields.BooleanField(required=False, default=False) unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR) def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: @@ -70,15 +72,7 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: } if str(self.allow_indirect_imports).lower() == "true": - chains = { - cast( - Tuple[str, ...], - (str(import_det["importer"]), str(import_det["imported"])), - ) - for import_det in graph.get_import_details( - importer=source_module.name, imported=forbidden_module.name - ) - } + chains = self._get_direct_chains(source_module, forbidden_module, graph) else: chains = graph.find_shortest_chains( importer=source_module.name, imported=forbidden_module.name @@ -172,3 +166,28 @@ def _get_external_forbidden_modules(self) -> set[Module]: def _graph_was_built_with_externals(self) -> bool: return str(self.session_options.get("include_external_packages")).lower() == "true" + + def _get_direct_chains( + self, source_package: Module, forbidden_package: Module, graph: ImportGraph + ) -> set[tuple[str, ...]]: + chains: set[tuple[str, ...]] = set() + source_modules = self._get_all_modules_in_package(source_package, graph) + forbidden_modules = self._get_all_modules_in_package(forbidden_package, graph) + for source_module in source_modules: + imported_module_names = graph.find_modules_directly_imported_by(source_module.name) + for imported_module_name in imported_module_names: + imported_module = Module(imported_module_name) + if imported_module in forbidden_modules: + chains.add((source_module.name, imported_module.name)) + return chains + + def _get_all_modules_in_package(self, module: Module, graph: ImportGraph) -> set[Module]: + """ + Return all the modules in the supplied module, including itself. + + If the module is squashed, it will be treated as a single module. + """ + importer_modules = {module} + if not graph.is_module_squashed(module.name): + importer_modules |= {Module(m) for m in graph.find_descendants(module.name)} + return importer_modules diff --git a/tests/unit/contracts/test_forbidden.py b/tests/unit/contracts/test_forbidden.py index 58efb419..34e7f349 100644 --- a/tests/unit/contracts/test_forbidden.py +++ b/tests/unit/contracts/test_forbidden.py @@ -222,19 +222,79 @@ def test_ignore_imports_with_recursive_wildcards(self): } @pytest.mark.parametrize( - "allow_indirect_imports, contract_is_kept", - ((None, False), ("false", False), ("True", True), ("true", True), ("anything", False)), + "importer", + ("mypackage.one", "mypackage.one.alpha"), ) - def test_allow_indirect_imports(self, allow_indirect_imports, contract_is_kept): + @pytest.mark.parametrize( + "imported", + ("mypackage.mauve", "mypackage.mauve.beta"), + ) + @pytest.mark.parametrize( + "allow_indirect_imports", + (False, True), + ) + def test_allow_indirect_imports(self, importer, imported, allow_indirect_imports): graph = self._build_graph() contract = self._build_contract( - forbidden_modules=("mypackage.purple"), + forbidden_modules=("mypackage.mauve",), allow_indirect_imports=allow_indirect_imports, ) + graph.add_module("mypackage.mauve") + # Add a direct import. + graph.add_import( + importer=importer, + imported=imported, + line_number=10, + line_contents="-", + ) + # Add an indirect import. + graph.add_import( + importer="mypackage.one.delta", + imported="mypackage.something", + line_number=20, + line_contents="-", + ) + graph.add_import( + importer="mypackage.something", + imported="mypackage.mauve.gamma", + line_number=30, + line_contents="-", + ) contract_check = contract.check(graph=graph, verbose=False) - assert contract_check.kept == contract_is_kept + direct_chain = [ + { + "importer": importer, + "imported": imported, + "line_numbers": (10,), + }, + ] + indirect_chain = [ + { + "importer": "mypackage.one.delta", + "imported": "mypackage.something", + "line_numbers": (20,), + }, + { + "importer": "mypackage.something", + "imported": "mypackage.mauve.gamma", + "line_numbers": (30,), + }, + ] + if allow_indirect_imports: + expected_chains = [direct_chain] + else: + expected_chains = [direct_chain, indirect_chain] + assert contract_check.metadata == { + "invalid_chains": [ + { + "upstream_module": "mypackage.mauve", + "downstream_module": "mypackage.one", + "chains": expected_chains, + }, + ], + } def test_ignore_imports_adds_warnings(self): graph = self._build_graph() @@ -360,15 +420,20 @@ def _build_contract( if include_external_packages: session_options["include_external_packages"] = "True" + contract_options = { + "source_modules": ("mypackage.one", "mypackage.two", "mypackage.three"), + "forbidden_modules": forbidden_modules, + "ignore_imports": ignore_imports or [], + } + if allow_indirect_imports is not None: + contract_options["allow_indirect_imports"] = ( + "true" if allow_indirect_imports else "false" + ) + return ForbiddenContract( name="Forbid contract", session_options=session_options, - contract_options={ - "source_modules": ("mypackage.one", "mypackage.two", "mypackage.three"), - "forbidden_modules": forbidden_modules, - "ignore_imports": ignore_imports or [], - "allow_indirect_imports": allow_indirect_imports, - }, + contract_options=contract_options, ) @@ -562,7 +627,7 @@ def test_verbose(self): "mypackage.purple", ), "ignore_imports": [], - "allow_indirect_imports": False, + "allow_indirect_imports": "false", }, )