Skip to content

Commit

Permalink
Fix bug with allow_indirect_imports (#197)
Browse files Browse the repository at this point in the history
Prior to this commit, forbidden contracts
with allow_indirect_imports only checked
imports between the source/forbidden
modules specified, not the descendants of
those modules.
  • Loading branch information
seddonym authored Sep 24, 2023
1 parent 6383ff7 commit 2cdb36f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 26 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ Changelog
=========

Latest
-------------------
------

* Officially support Python 3.12.
* Fix: Error when using `click` version 6.0 and 7.0 (#191).
* Fix error when using `click` version 6.0 and 7.0 (#191).
* Allow extra whitespace around the module names in import expressions.
* Ignore blank lines in multiple value fields.
* Fix bug with allow_indirect_imports in forbidden contracts.
Prior to this fix, forbidden contracts with allow_indirect_imports
only checked imports between the source/forbidden modules specified,
not the descendants of those modules.

1.11.1 (2023-08-21)
-------------------
Expand Down
43 changes: 31 additions & 12 deletions src/importlinter/contracts/forbidden.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations
from typing import List, Tuple, cast

from typing import List, cast

from grimp import ImportGraph

from importlinter.application import contract_utils, output
from importlinter.application.contract_utils import AlertLevel
from importlinter.configuration import settings
from importlinter.domain import fields
from importlinter.domain.contract import Contract, ContractCheck
from importlinter.domain.imports import Module
from grimp import ImportGraph

from ._common import format_line_numbers

Expand All @@ -34,7 +36,7 @@ class ForbiddenContract(Contract):
source_modules = fields.ListField(subfield=fields.ModuleField())
forbidden_modules = fields.ListField(subfield=fields.ModuleField())
ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False)
allow_indirect_imports = fields.StringField(required=False)
allow_indirect_imports = fields.BooleanField(required=False, default=False)
unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR)

def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
Expand Down Expand Up @@ -70,15 +72,7 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck:
}

if str(self.allow_indirect_imports).lower() == "true":
chains = {
cast(
Tuple[str, ...],
(str(import_det["importer"]), str(import_det["imported"])),
)
for import_det in graph.get_import_details(
importer=source_module.name, imported=forbidden_module.name
)
}
chains = self._get_direct_chains(source_module, forbidden_module, graph)
else:
chains = graph.find_shortest_chains(
importer=source_module.name, imported=forbidden_module.name
Expand Down Expand Up @@ -172,3 +166,28 @@ def _get_external_forbidden_modules(self) -> set[Module]:

def _graph_was_built_with_externals(self) -> bool:
return str(self.session_options.get("include_external_packages")).lower() == "true"

def _get_direct_chains(
self, source_package: Module, forbidden_package: Module, graph: ImportGraph
) -> set[tuple[str, ...]]:
chains: set[tuple[str, ...]] = set()
source_modules = self._get_all_modules_in_package(source_package, graph)
forbidden_modules = self._get_all_modules_in_package(forbidden_package, graph)
for source_module in source_modules:
imported_module_names = graph.find_modules_directly_imported_by(source_module.name)
for imported_module_name in imported_module_names:
imported_module = Module(imported_module_name)
if imported_module in forbidden_modules:
chains.add((source_module.name, imported_module.name))
return chains

def _get_all_modules_in_package(self, module: Module, graph: ImportGraph) -> set[Module]:
"""
Return all the modules in the supplied module, including itself.
If the module is squashed, it will be treated as a single module.
"""
importer_modules = {module}
if not graph.is_module_squashed(module.name):
importer_modules |= {Module(m) for m in graph.find_descendants(module.name)}
return importer_modules
89 changes: 77 additions & 12 deletions tests/unit/contracts/test_forbidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,19 +222,79 @@ def test_ignore_imports_with_recursive_wildcards(self):
}

@pytest.mark.parametrize(
"allow_indirect_imports, contract_is_kept",
((None, False), ("false", False), ("True", True), ("true", True), ("anything", False)),
"importer",
("mypackage.one", "mypackage.one.alpha"),
)
def test_allow_indirect_imports(self, allow_indirect_imports, contract_is_kept):
@pytest.mark.parametrize(
"imported",
("mypackage.mauve", "mypackage.mauve.beta"),
)
@pytest.mark.parametrize(
"allow_indirect_imports",
(False, True),
)
def test_allow_indirect_imports(self, importer, imported, allow_indirect_imports):
graph = self._build_graph()
contract = self._build_contract(
forbidden_modules=("mypackage.purple"),
forbidden_modules=("mypackage.mauve",),
allow_indirect_imports=allow_indirect_imports,
)
graph.add_module("mypackage.mauve")
# Add a direct import.
graph.add_import(
importer=importer,
imported=imported,
line_number=10,
line_contents="-",
)
# Add an indirect import.
graph.add_import(
importer="mypackage.one.delta",
imported="mypackage.something",
line_number=20,
line_contents="-",
)
graph.add_import(
importer="mypackage.something",
imported="mypackage.mauve.gamma",
line_number=30,
line_contents="-",
)

contract_check = contract.check(graph=graph, verbose=False)

assert contract_check.kept == contract_is_kept
direct_chain = [
{
"importer": importer,
"imported": imported,
"line_numbers": (10,),
},
]
indirect_chain = [
{
"importer": "mypackage.one.delta",
"imported": "mypackage.something",
"line_numbers": (20,),
},
{
"importer": "mypackage.something",
"imported": "mypackage.mauve.gamma",
"line_numbers": (30,),
},
]
if allow_indirect_imports:
expected_chains = [direct_chain]
else:
expected_chains = [direct_chain, indirect_chain]
assert contract_check.metadata == {
"invalid_chains": [
{
"upstream_module": "mypackage.mauve",
"downstream_module": "mypackage.one",
"chains": expected_chains,
},
],
}

def test_ignore_imports_adds_warnings(self):
graph = self._build_graph()
Expand Down Expand Up @@ -360,15 +420,20 @@ def _build_contract(
if include_external_packages:
session_options["include_external_packages"] = "True"

contract_options = {
"source_modules": ("mypackage.one", "mypackage.two", "mypackage.three"),
"forbidden_modules": forbidden_modules,
"ignore_imports": ignore_imports or [],
}
if allow_indirect_imports is not None:
contract_options["allow_indirect_imports"] = (
"true" if allow_indirect_imports else "false"
)

return ForbiddenContract(
name="Forbid contract",
session_options=session_options,
contract_options={
"source_modules": ("mypackage.one", "mypackage.two", "mypackage.three"),
"forbidden_modules": forbidden_modules,
"ignore_imports": ignore_imports or [],
"allow_indirect_imports": allow_indirect_imports,
},
contract_options=contract_options,
)


Expand Down Expand Up @@ -562,7 +627,7 @@ def test_verbose(self):
"mypackage.purple",
),
"ignore_imports": [],
"allow_indirect_imports": False,
"allow_indirect_imports": "false",
},
)

Expand Down

0 comments on commit 2cdb36f

Please sign in to comment.