diff --git a/AUTHORS.rst b/AUTHORS.rst index 1b9d66ad..8fe811d6 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -22,3 +22,4 @@ Contributors * piglei - https://github.com/piglei * Anton Gruebel - https://github.com/gruebel * Peter Byfield - https://github.com/Peter554 +* Fabian Binz - https://github.com/fbinz \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d347093f..587aeb66 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,12 @@ Changelog ========= +latest +------ + +* Add support for wildcards in forbidden and independence contracts. + + 2.0 (2024-1-9) -------------- diff --git a/docs/contract_types.rst b/docs/contract_types.rst index 075384a1..c452f904 100644 --- a/docs/contract_types.rst +++ b/docs/contract_types.rst @@ -2,7 +2,6 @@ Contract types ============== - .. _forbidden modules: Forbidden modules @@ -59,8 +58,6 @@ External packages may also be forbidden. **Configuration options** -Configuration options: - - ``source_modules``: A list of modules that should not import the forbidden modules. - ``forbidden_modules``: A list of modules that should not be imported by the source modules. These may include root level external packages (i.e. ``django``, but not ``django.db.models``). If external packages are included, @@ -335,14 +332,26 @@ Options used by multiple contracts - ``ignore_imports``: Optional list of imports, each in the form ``mypackage.foo.importer -> mypackage.bar.imported``. These imports will be ignored: if the import would cause a contract to be broken, adding it to the list will cause the - contract be kept instead. + contract be kept instead. Supports :ref:`wildcards`. + +- ``unmatched_ignore_imports_alerting``: The alerting level for handling expressions supplied in ``ignore_imports`` + that do not match any imports in the graph. Choices are: - Wildcards are supported. ``*`` stands in for a module name, without including subpackages. ``**`` includes - subpackages too. + - ``error``: Error if there are any unmatched expressions (default). + - ``warn``: Print a warning for each unmatched expression. + - ``none``: Do not alert. - Note that this wildcard format is only supported for the ``ignore_imports`` fields. It can't currently be used for - other fields, such as in the ``source_modules`` field of a :ref:`forbidden modules` contract. +.. _wildcards: +Wildcards +--------- + + Wildcards are supported in most places where a module name is required to express a set of modules. + ``*`` stands in for a module name, without including subpackages. ``**`` includes subpackages too. + + Note that at the moment, layer contracts only support wildcards in `illegal_imports`. + If you have a use case for this, please file an issue. + Examples: - ``mypackage.*``: matches ``mypackage.foo`` but not ``mypackage.foo.bar``. @@ -352,9 +361,3 @@ Options used by multiple contracts - ``mypackage.**.qux``: matches ``mypackage.foo.bar.qux`` and ``mypackage.foo.bar.baz.qux``. - ``mypackage.foo*``: not a valid expression. (The wildcard must replace a whole module name.) -- ``unmatched_ignore_imports_alerting``: The alerting level for handling expressions supplied in ``ignore_imports`` - that do not match any imports in the graph. Choices are: - - - ``error``: Error if there are any unmatched expressions (default). - - ``warn``: Print a warning for each unmatched expression. - - ``none``: Do not alert. diff --git a/src/importlinter/api.py b/src/importlinter/api.py index c766eecb..f49ce1dd 100644 --- a/src/importlinter/api.py +++ b/src/importlinter/api.py @@ -1,6 +1,7 @@ """ Module for public-facing Python functions. """ + from __future__ import annotations from importlinter.application import use_cases diff --git a/src/importlinter/application/contract_utils.py b/src/importlinter/application/contract_utils.py index 3fb6481a..b3e6e70d 100644 --- a/src/importlinter/application/contract_utils.py +++ b/src/importlinter/application/contract_utils.py @@ -1,6 +1,7 @@ import enum from typing import List, Optional, Sequence, Set + from importlinter.domain import helpers from importlinter.domain.helpers import MissingImport from importlinter.domain.imports import ImportExpression diff --git a/src/importlinter/contracts/_common.py b/src/importlinter/contracts/_common.py index 7453eb19..9da36474 100644 --- a/src/importlinter/contracts/_common.py +++ b/src/importlinter/contracts/_common.py @@ -5,6 +5,7 @@ relying on it for a custom contract type, be aware things may change without warning. """ + from __future__ import annotations import itertools diff --git a/src/importlinter/contracts/forbidden.py b/src/importlinter/contracts/forbidden.py index 6f68a2ef..f10eff52 100644 --- a/src/importlinter/contracts/forbidden.py +++ b/src/importlinter/contracts/forbidden.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, cast +from typing import Iterable, List, cast from grimp import ImportGraph @@ -9,6 +9,7 @@ from importlinter.configuration import settings from importlinter.domain import fields from importlinter.domain.contract import Contract, ContractCheck +from importlinter.domain.helpers import module_expressions_to_modules from importlinter.domain.imports import Module from ._common import format_line_numbers @@ -33,8 +34,8 @@ class ForbiddenContract(Contract): type_name = "forbidden" - source_modules = fields.ListField(subfield=fields.ModuleField()) - forbidden_modules = fields.ListField(subfield=fields.ModuleField()) + source_modules = fields.ListField(subfield=fields.ModuleExpressionField()) + forbidden_modules = fields.ListField(subfield=fields.ModuleExpressionField()) ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False) allow_indirect_imports = fields.BooleanField(required=False, default=False) unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR) @@ -49,16 +50,30 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: unmatched_alerting=self.unmatched_ignore_imports_alerting, # type: ignore ) - self._check_all_modules_exist_in_graph(graph) - self._check_external_forbidden_modules() + source_modules = list( + module_expressions_to_modules( + graph, + self.source_modules, # type: ignore + ) + ) + forbidden_modules = list( + module_expressions_to_modules( + graph, + self.forbidden_modules, # type: ignore + ) + ) + + self._check_all_modules_exist_in_graph(source_modules, graph) + self._check_external_forbidden_modules(forbidden_modules) # We only need to check for illegal imports for forbidden modules that are in the graph. - forbidden_modules_in_graph = [ - m for m in self.forbidden_modules if m.name in graph.modules # type: ignore - ] + forbidden_modules_in_graph = [m for m in forbidden_modules if m.name in graph.modules] + + def sort_key(module): + return module.name - for source_module in self.source_modules: # type: ignore - for forbidden_module in forbidden_modules_in_graph: + for source_module in sorted(source_modules, key=sort_key): + for forbidden_module in sorted(forbidden_modules_in_graph, key=sort_key): output.verbose_print( verbose, "Searching for import chains from " @@ -95,7 +110,7 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: "line_numbers": line_numbers, } ) - subpackage_chain_data["chains"].append(chain_data) + subpackage_chain_data["chains"].append(chain_data) # type: ignore if subpackage_chain_data["chains"]: invalid_chains.append(subpackage_chain_data) if verbose: @@ -106,8 +121,15 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: f"in {timer.duration_in_s}s.", ) + # Sorting by upstream and downstream module ensures that the output is deterministic + # and that the same upstream and downstream modules are always adjacent in the output. + def chain_sort_key(chain_data): + return (chain_data["upstream_module"], chain_data["downstream_module"]) + return ContractCheck( - kept=is_kept, warnings=warnings, metadata={"invalid_chains": invalid_chains} + kept=is_kept, + warnings=warnings, + metadata={"invalid_chains": sorted(invalid_chains, key=chain_sort_key)}, ) def render_broken_contract(self, check: "ContractCheck") -> None: @@ -133,13 +155,15 @@ def render_broken_contract(self, check: "ContractCheck") -> None: output.new_line() - def _check_all_modules_exist_in_graph(self, graph: ImportGraph) -> None: - for module in self.source_modules: # type: ignore + def _check_all_modules_exist_in_graph( + self, modules: Iterable[Module], graph: ImportGraph + ) -> None: + for module in modules: if module.name not in graph.modules: raise ValueError(f"Module '{module.name}' does not exist.") - def _check_external_forbidden_modules(self) -> None: - external_forbidden_modules = self._get_external_forbidden_modules() + def _check_external_forbidden_modules(self, forbidden_modules) -> None: + external_forbidden_modules = self._get_external_forbidden_modules(forbidden_modules) if external_forbidden_modules: if self._graph_was_built_with_externals(): for module in external_forbidden_modules: @@ -154,11 +178,11 @@ def _check_external_forbidden_modules(self) -> None: "when there are external forbidden modules." ) - def _get_external_forbidden_modules(self) -> set[Module]: + def _get_external_forbidden_modules(self, forbidden_modules) -> set[Module]: root_packages = [Module(name) for name in self.session_options["root_packages"]] return { forbidden_module - for forbidden_module in cast(List[Module], self.forbidden_modules) + for forbidden_module in cast(List[Module], forbidden_modules) if not any( forbidden_module.is_in_package(root_package) for root_package in root_packages ) diff --git a/src/importlinter/contracts/independence.py b/src/importlinter/contracts/independence.py index 62d0500d..9b60f72a 100644 --- a/src/importlinter/contracts/independence.py +++ b/src/importlinter/contracts/independence.py @@ -10,9 +10,15 @@ from importlinter.application.contract_utils import AlertLevel from importlinter.domain import fields from importlinter.domain.contract import Contract, ContractCheck +from importlinter.domain.helpers import module_expressions_to_modules from importlinter.domain.imports import Module -from ._common import DetailedChain, Link, build_detailed_chain_from_route, render_chain_data +from ._common import ( + DetailedChain, + Link, + build_detailed_chain_from_route, + render_chain_data, +) class _SubpackageChainData(TypedDict): @@ -41,7 +47,7 @@ class IndependenceContract(Contract): type_name = "independence" - modules = fields.ListField(subfield=fields.ModuleField()) + modules = fields.ListField(subfield=fields.ModuleExpressionField()) ignore_imports = fields.SetField(subfield=fields.ImportExpressionField(), required=False) unmatched_ignore_imports_alerting = fields.EnumField(AlertLevel, default=AlertLevel.ERROR) @@ -52,11 +58,12 @@ def check(self, graph: ImportGraph, verbose: bool) -> ContractCheck: unmatched_alerting=self.unmatched_ignore_imports_alerting, # type: ignore ) - self._check_all_modules_exist_in_graph(graph) + modules = list(module_expressions_to_modules(graph, self.modules)) # type: ignore + self._check_all_modules_exist_in_graph(graph, modules) dependencies = graph.find_illegal_dependencies_for_layers( # A single layer consisting of siblings. - layers=({module.name for module in self.modules},), # type: ignore + layers=({module.name for module in modules},), ) invalid_chains = self._build_invalid_chains(dependencies, graph) @@ -81,8 +88,8 @@ def render_broken_contract(self, check: "ContractCheck") -> None: output.new_line() - def _check_all_modules_exist_in_graph(self, graph: ImportGraph) -> None: - for module in self.modules: # type: ignore + def _check_all_modules_exist_in_graph(self, graph: ImportGraph, modules) -> None: + for module in modules: if module.name not in graph.modules: raise ValueError(f"Module '{module.name}' does not exist.") diff --git a/src/importlinter/domain/fields.py b/src/importlinter/domain/fields.py index 1e91aa11..41d33c2e 100644 --- a/src/importlinter/domain/fields.py +++ b/src/importlinter/domain/fields.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Generic, Iterable, List, Set, Type, TypeVar, Union, cast -from importlinter.domain.imports import ImportExpression, Module +from importlinter.domain.imports import ImportExpression, Module, ModuleExpression FieldValue = TypeVar("FieldValue") @@ -34,7 +34,6 @@ def __init__( required: Union[bool, Type[NotSupplied]] = NotSupplied, default: Union[FieldValue, Type[NotSupplied]] = NotSupplied, ) -> None: - if default is NotSupplied: if required is NotSupplied: self.required = True @@ -159,6 +158,37 @@ def parse(self, raw_data: Union[str, List]) -> Module: return Module(StringField().parse(raw_data)) +class ModuleExpressionField(Field): + """ + A field for ModuleExpressions. + + Accepts strings in the form: + "mypackage.foo.importer" + "mypackage.foo.*" + "mypackage.*.importer" + "mypackage.**" + """ + + def parse(self, expression: Union[str, List[str]]) -> ModuleExpression: + if isinstance(expression, list): + raise ValidationError("Expected a single value, got multiple values.") + + last_wildcard = None + for part in expression.split("."): + if "**" == last_wildcard and ("*" == part or "**" == part): + raise ValidationError("A recursive wildcard cannot be followed by a wildcard.") + if "*" == last_wildcard and "**" == part: + raise ValidationError("A wildcard cannot be followed by a recursive wildcard.") + if "*" == part or "**" == part: + last_wildcard = part + continue + if "*" in part: + raise ValidationError("A wildcard can only replace a whole module.") + last_wildcard = None + + return ModuleExpression(expression) + + class ImportExpressionField(Field): """ A field for ImportExpressions. @@ -181,24 +211,10 @@ def parse(self, raw_data: Union[str, List]) -> ImportExpression: if not (importer and imported): raise ValidationError('Must be in the form "package.importer -> package.imported".') - self._validate_wildcard(importer) - self._validate_wildcard(imported) - - return ImportExpression(importer=importer, imported=imported) - - def _validate_wildcard(self, expression: str) -> None: - last_wildcard = None - for part in expression.split("."): - if "**" == last_wildcard and ("*" == part or "**" == part): - raise ValidationError("A recursive wildcard cannot be followed by a wildcard.") - if "*" == last_wildcard and "**" == part: - raise ValidationError("A wildcard cannot be followed by a recursive wildcard.") - if "*" == part or "**" == part: - last_wildcard = part - continue - if "*" in part: - raise ValidationError("A wildcard can only replace a whole module.") - last_wildcard = None + return ImportExpression( + importer=ModuleExpressionField().parse(importer), + imported=ModuleExpressionField().parse(imported), + ) class EnumField(Field): diff --git a/src/importlinter/domain/helpers.py b/src/importlinter/domain/helpers.py index 3f53b0a4..c253c64c 100644 --- a/src/importlinter/domain/helpers.py +++ b/src/importlinter/domain/helpers.py @@ -2,10 +2,14 @@ import re from typing import Iterable, List, Pattern, Set, Tuple -from grimp import DetailedImport +from grimp import DetailedImport, ImportGraph -from importlinter.domain.imports import DirectImport, ImportExpression, Module -from grimp import ImportGraph +from importlinter.domain.imports import ( + DirectImport, + ImportExpression, + Module, + ModuleExpression, +) class MissingImport(Exception): @@ -28,13 +32,15 @@ def pop_imports(graph: ImportGraph, imports: Iterable[DirectImport]) -> List[Det for import_to_remove in imports_to_remove: import_details = graph.get_import_details( - importer=import_to_remove.importer.name, imported=import_to_remove.imported.name + importer=import_to_remove.importer.name, + imported=import_to_remove.imported.name, ) if not import_details: raise MissingImport(f"Ignored import {import_to_remove} not present in the graph.") graph.remove_import( - importer=import_to_remove.importer.name, imported=import_to_remove.imported.name + importer=import_to_remove.importer.name, + imported=import_to_remove.imported.name, ) removed_imports.extend(import_details) @@ -55,7 +61,9 @@ def import_expression_to_imports( imports: Set[DirectImport] = set() matched = False - for (importer, imported) in _expression_to_modules(expression, graph): + importers = module_expression_to_modules(graph, expression.importer) + importeds = module_expression_to_modules(graph, expression.imported) + for importer, imported in itertools.product(importers, importeds): import_details = graph.get_import_details(importer=importer.name, imported=imported.name) if import_details: @@ -78,6 +86,29 @@ def import_expression_to_imports( return list(imports) +def module_expressions_to_modules( + graph: ImportGraph, expressions: Iterable[ModuleExpression] +) -> Set[Module]: + modules = set() + for expression in expressions: + modules |= module_expression_to_modules(graph, expression) + return modules + + +def module_expression_to_modules(graph: ImportGraph, expression: ModuleExpression) -> Set[Module]: + if not expression.has_wildcard_expression(): + return {Module(expression.expression)} + + pattern = _to_pattern(expression.expression) + + modules = set() + for module in graph.modules: + if pattern.match(module): + modules.add(Module(module)) + + return modules + + def import_expressions_to_imports( graph: ImportGraph, expressions: Iterable[ImportExpression] ) -> List[DirectImport]: @@ -213,25 +244,3 @@ def _to_pattern(expression: str) -> Pattern: else: pattern_parts.append(part) return re.compile(r"^" + r"\.".join(pattern_parts) + r"$") - - -def _expression_to_modules( - expression: ImportExpression, graph: ImportGraph -) -> Iterable[Tuple[Module, Module]]: - if not expression.has_wildcard_expression(): - return [(Module(expression.importer), Module(expression.imported))] - - importer = [] - imported = [] - - importer_pattern = _to_pattern(expression.importer) - imported_expression = _to_pattern(expression.imported) - - for module in graph.modules: - - if importer_pattern.match(module): - importer.append(Module(module)) - if imported_expression.match(module): - imported.append(Module(module)) - - return itertools.product(set(importer), set(imported)) diff --git a/src/importlinter/domain/imports.py b/src/importlinter/domain/imports.py index 325a33aa..7c4277ba 100644 --- a/src/importlinter/domain/imports.py +++ b/src/importlinter/domain/imports.py @@ -83,24 +83,41 @@ def __hash__(self) -> int: return hash((str(self), self.line_contents)) +class ModuleExpression(ValueObject): + """ + A user-submitted expression describing a module or a set of modules. + + Sets of modules are notated using * or ** wildcards. + Examples: + "mypackage.foo.bar": a single module + "mypackage.foo.*": all direct submodules in the foo subpackage + "mypackage.*.bar": all bar-submodules of any mypackage submodule + "mypackage.**": all modules in the mypackage package + + Note that * and ** cannot be mixed in the same expression. + """ + + def __init__(self, expression: str) -> None: + self.expression = expression + + def has_wildcard_expression(self) -> bool: + return "*" in self.expression + + class ImportExpression(ValueObject): """ A user-submitted expression describing an import or set of imports. - Sets of imports are notated using * wildcards. - These wildcards can stand in for a module name or part of a name, but they do - not extend to subpackages. - - For example, "mypackage.*" refers to every child subpackage of mypackage. - It does not, however, include more distant descendants such as mypackage.foo.bar. + The importer and imported expressions are both ModuleExpressions + (see ModuleExpression for details). """ - def __init__(self, importer: str, imported: str) -> None: + def __init__(self, importer: ModuleExpression, imported: ModuleExpression) -> None: self.importer = importer self.imported = imported def has_wildcard_expression(self) -> bool: - return "*" in self.imported or "*" in self.importer + return self.imported.has_wildcard_expression() or self.importer.has_wildcard_expression() def __str__(self) -> str: - return "{} -> {}".format(self.importer, self.imported) + return "{} -> {}".format(self.importer.expression, self.imported.expression) diff --git a/tests/unit/application/test_contract_utils.py b/tests/unit/application/test_contract_utils.py index fa34dd4c..8a716d2f 100644 --- a/tests/unit/application/test_contract_utils.py +++ b/tests/unit/application/test_contract_utils.py @@ -3,7 +3,7 @@ from importlinter.application.contract_utils import AlertLevel, remove_ignored_imports from importlinter.domain.helpers import MissingImport -from importlinter.domain.imports import DirectImport, ImportExpression, Module +from importlinter.domain.imports import DirectImport, ImportExpression, Module, ModuleExpression class TestRemoveIgnoredImports: @@ -41,8 +41,14 @@ def test_no_unresolved_import_expressions(self, alert_level): warnings = remove_ignored_imports( graph=graph, ignore_imports=[ - ImportExpression(importer="mypackage.green", imported="mypackage.blue"), - ImportExpression(importer="mypackage.green", imported="mypackage.purple"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.blue"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.purple"), + ), ], unmatched_alerting=alert_level, ) @@ -72,10 +78,22 @@ def test_no_unresolved_import_expressions(self, alert_level): def test_unresolved_import_expressions(self, alert_level, expected_result): graph = self._build_graph(self.DIRECT_IMPORTS) ignore_imports = [ - ImportExpression(importer="mypackage.green", imported="mypackage.blue"), - ImportExpression(importer="mypackage.*", imported="mypackage.nonexistent"), - ImportExpression(importer="mypackage.green", imported="mypackage.purple"), - ImportExpression(importer="mypackage.nonexistent", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.blue"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.nonexistent"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.purple"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.nonexistent"), + imported=ModuleExpression("mypackage.blue"), + ), ] if isinstance(expected_result, Exception): diff --git a/tests/unit/contracts/test_forbidden.py b/tests/unit/contracts/test_forbidden.py index 34e7f349..17ed0ab2 100644 --- a/tests/unit/contracts/test_forbidden.py +++ b/tests/unit/contracts/test_forbidden.py @@ -59,6 +59,19 @@ def test_is_broken_when_forbidden_modules_imported(self): ], ], }, + { + "upstream_module": "mypackage.green", + "downstream_module": "mypackage.three", + "chains": [ + [ + { + "importer": "mypackage.three", + "imported": "mypackage.green", + "line_numbers": (4,), + } + ] + ], + }, { "upstream_module": "mypackage.purple", "downstream_module": "mypackage.two", @@ -77,19 +90,6 @@ def test_is_broken_when_forbidden_modules_imported(self): ] ], }, - { - "upstream_module": "mypackage.green", - "downstream_module": "mypackage.three", - "chains": [ - [ - { - "importer": "mypackage.three", - "imported": "mypackage.green", - "line_numbers": (4,), - } - ] - ], - }, ] } @@ -153,6 +153,168 @@ def test_ignore_imports_tolerates_duplicates(self): check = contract.check(graph=graph, verbose=False) assert check.kept + def test_wildcards_in_source_modules_are_resolved(self): + graph = self._build_graph() + contract = self._build_contract( + forbidden_modules=("mypackage.green"), + source_modules=("mypackage.one.*",), + include_external_packages=False, + ) + + check = contract.check(graph=graph, verbose=False) + assert check.metadata == { + "invalid_chains": [ + { + "upstream_module": "mypackage.green", + "downstream_module": "mypackage.one.alpha", + "chains": [ + [ + { + "importer": "mypackage.one.alpha", + "imported": "mypackage.green.beta", + "line_numbers": (3,), + }, + ], + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + ], + } + + def test_recursive_wildcards_in_source_modules_are_resolved(self): + graph = self._build_graph() + contract = self._build_contract( + forbidden_modules=("mypackage.green"), + source_modules=("mypackage.one.**",), + include_external_packages=False, + ) + + check = contract.check(graph=graph, verbose=False) + assert check.metadata == { + "invalid_chains": [ + { + "upstream_module": "mypackage.green", + "downstream_module": "mypackage.one.alpha", + "chains": [ + [ + { + "importer": "mypackage.one.alpha", + "imported": "mypackage.green.beta", + "line_numbers": (3,), + }, + ], + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + { + "upstream_module": "mypackage.green", + "downstream_module": "mypackage.one.alpha.circle", + "chains": [ + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + ], + } + + def test_wildcards_in_forbidden_modules_are_resolved(self): + graph = self._build_graph() + contract = self._build_contract( + forbidden_modules=("mypackage.green.*"), + source_modules=("mypackage.one",), + include_external_packages=False, + ) + + check = contract.check(graph=graph, verbose=False) + assert check.metadata == { + "invalid_chains": [ + { + "upstream_module": "mypackage.green.beta", + "downstream_module": "mypackage.one", + "chains": [ + [ + { + "importer": "mypackage.one.alpha", + "imported": "mypackage.green.beta", + "line_numbers": (3,), + }, + ], + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + ], + } + + def test_recursive_wildcards_in_forbidden_modules_are_resolved(self): + graph = self._build_graph() + contract = self._build_contract( + forbidden_modules=("mypackage.green.**"), + source_modules=("mypackage.one",), + include_external_packages=False, + ) + + check = contract.check(graph=graph, verbose=False) + assert check.metadata == { + "invalid_chains": [ + { + "upstream_module": "mypackage.green.beta", + "downstream_module": "mypackage.one", + "chains": [ + [ + { + "importer": "mypackage.one.alpha", + "imported": "mypackage.green.beta", + "line_numbers": (3,), + }, + ], + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + { + "upstream_module": "mypackage.green.beta.sphere", + "downstream_module": "mypackage.one", + "chains": [ + [ + { + "importer": "mypackage.one.alpha.circle", + "imported": "mypackage.green.beta.sphere", + "line_numbers": (8,), + }, + ], + ], + }, + ], + } + def test_ignore_imports_with_wildcards(self): graph = self._build_graph() contract = self._build_contract( @@ -415,13 +577,15 @@ def _build_contract( ignore_imports=None, include_external_packages=False, allow_indirect_imports=None, + source_modules=None, ): session_options = {"root_packages": ["mypackage"]} if include_external_packages: session_options["include_external_packages"] = "True" contract_options = { - "source_modules": ("mypackage.one", "mypackage.two", "mypackage.three"), + "source_modules": source_modules + or ("mypackage.one", "mypackage.two", "mypackage.three"), "forbidden_modules": forbidden_modules, "ignore_imports": ignore_imports or [], } @@ -639,25 +803,25 @@ def test_verbose(self): Found 0 illegal chains in 10s. Searching for import chains from mypackage.one to mypackage.green... Found 1 illegal chain in 10s. - Searching for import chains from mypackage.one to mypackage.yellow... - Found 0 illegal chains in 10s. Searching for import chains from mypackage.one to mypackage.purple... Found 0 illegal chains in 10s. - Searching for import chains from mypackage.two to mypackage.blue... - Found 0 illegal chains in 10s. - Searching for import chains from mypackage.two to mypackage.green... - Found 0 illegal chains in 10s. - Searching for import chains from mypackage.two to mypackage.yellow... + Searching for import chains from mypackage.one to mypackage.yellow... Found 0 illegal chains in 10s. - Searching for import chains from mypackage.two to mypackage.purple... - Found 1 illegal chain in 10s. Searching for import chains from mypackage.three to mypackage.blue... Found 0 illegal chains in 10s. Searching for import chains from mypackage.three to mypackage.green... Found 1 illegal chain in 10s. + Searching for import chains from mypackage.three to mypackage.purple... + Found 0 illegal chains in 10s. Searching for import chains from mypackage.three to mypackage.yellow... Found 0 illegal chains in 10s. - Searching for import chains from mypackage.three to mypackage.purple... + Searching for import chains from mypackage.two to mypackage.blue... + Found 0 illegal chains in 10s. + Searching for import chains from mypackage.two to mypackage.green... + Found 0 illegal chains in 10s. + Searching for import chains from mypackage.two to mypackage.purple... + Found 1 illegal chain in 10s. + Searching for import chains from mypackage.two to mypackage.yellow... Found 0 illegal chains in 10s. """ ) diff --git a/tests/unit/contracts/test_independence.py b/tests/unit/contracts/test_independence.py index 2fc27aa4..61c9253a 100644 --- a/tests/unit/contracts/test_independence.py +++ b/tests/unit/contracts/test_independence.py @@ -43,6 +43,14 @@ def _check_default_contract(self, graph): ) return contract.check(graph=graph, verbose=False) + def _check_wildcard_contract(self, graph): + contract = IndependenceContract( + name="Independence contract", + session_options={"root_packages": ["mypackage"]}, + contract_options={"modules": ("mypackage.*",)}, + ) + return contract.check(graph=graph, verbose=False) + def test_when_modules_are_independent(self): graph = self._build_default_graph() graph.add_import( @@ -59,9 +67,26 @@ def test_when_modules_are_independent(self): ) contract_check = self._check_default_contract(graph) + assert contract_check.kept + def test_when_wildcard_modules_are_independent(self): + graph = self._build_default_graph() + + contract_check = self._check_wildcard_contract(graph) assert contract_check.kept + def test_when_wildcard_modules_are_not_independent(self): + graph = self._build_default_graph() + graph.add_import( + importer="mypackage.blue", + imported="mypackage.green", + line_number=10, + line_contents="-", + ) + + contract_check = self._check_wildcard_contract(graph) + assert not contract_check.kept + def test_when_root_imports_root_directly(self): graph = self._build_default_graph() graph.add_import( diff --git a/tests/unit/domain/test_fields.py b/tests/unit/domain/test_fields.py index 142c967b..0af3279b 100644 --- a/tests/unit/domain/test_fields.py +++ b/tests/unit/domain/test_fields.py @@ -15,7 +15,7 @@ StringField, ValidationError, ) -from importlinter.domain.imports import ImportExpression, Module +from importlinter.domain.imports import ImportExpression, Module, ModuleExpression def test_field_cannot_be_instantiated_with_default_and_required(): @@ -91,49 +91,80 @@ class TestModuleField(BaseFieldTest): ( ( "mypackage.foo -> mypackage.bar", - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), ), ( "my_package.foo -> my_package.bar", # Extra whitespaces are supported. - ImportExpression(importer="my_package.foo", imported="my_package.bar"), + ImportExpression( + importer=ModuleExpression("my_package.foo"), + imported=ModuleExpression("my_package.bar"), + ), ), ( "my_package.foo -> my_package.foo_bar", # Underscores are supported. - ImportExpression(importer="my_package.foo", imported="my_package.foo_bar"), + ImportExpression( + importer=ModuleExpression("my_package.foo"), + imported=ModuleExpression("my_package.foo_bar"), + ), ), # Wildcards # --------- ( "mypackage.foo.* -> mypackage.bar", - ImportExpression(importer="mypackage.foo.*", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo.*"), + imported=ModuleExpression("mypackage.bar"), + ), ), ( "mypackage.foo.*.baz -> mypackage.bar", - ImportExpression(importer="mypackage.foo.*.baz", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo.*.baz"), + imported=ModuleExpression("mypackage.bar"), + ), ), ( "mypackage.foo -> mypackage.bar.*", - ImportExpression(importer="mypackage.foo", imported="mypackage.bar.*"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar.*"), + ), ), ( "*.*.* -> mypackage.*.foo.*", - ImportExpression(importer="*.*.*", imported="mypackage.*.foo.*"), + ImportExpression( + importer=ModuleExpression("*.*.*"), imported=ModuleExpression("mypackage.*.foo.*") + ), ), ( "mypackage.foo.** -> mypackage.bar", - ImportExpression(importer="mypackage.foo.**", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo.**"), + imported=ModuleExpression("mypackage.bar"), + ), ), ( "mypackage.foo.**.baz -> mypackage.bar", - ImportExpression(importer="mypackage.foo.**.baz", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo.**.baz"), + imported=ModuleExpression("mypackage.bar"), + ), ), ( "mypackage.foo -> mypackage.bar.**", - ImportExpression(importer="mypackage.foo", imported="mypackage.bar.**"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar.**"), + ), ), ( "** -> mypackage.**.foo.*", - ImportExpression(importer="**", imported="mypackage.**.foo.*"), + ImportExpression( + importer=ModuleExpression("**"), imported=ModuleExpression("mypackage.**.foo.*") + ), ), # Invalid expressions # ------------------- diff --git a/tests/unit/domain/test_helpers.py b/tests/unit/domain/test_helpers.py index 6501ae02..7832279e 100644 --- a/tests/unit/domain/test_helpers.py +++ b/tests/unit/domain/test_helpers.py @@ -13,7 +13,7 @@ pop_imports, resolve_import_expressions, ) -from importlinter.domain.imports import DirectImport, ImportExpression, Module +from importlinter.domain.imports import DirectImport, ImportExpression, Module, ModuleExpression class TestPopImports: @@ -181,8 +181,8 @@ class TestImportExpressionsToImports: "No wildcards", [ ImportExpression( - importer=DIRECT_IMPORTS[0].importer.name, - imported=DIRECT_IMPORTS[0].imported.name, + importer=ModuleExpression(DIRECT_IMPORTS[0].importer.name), + imported=ModuleExpression(DIRECT_IMPORTS[0].imported.name), ), ], [DIRECT_IMPORTS[0]], @@ -190,37 +190,53 @@ class TestImportExpressionsToImports: ( "Importer wildcard", [ - ImportExpression(importer="mypackage.*", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.blue"), + ), ], [DIRECT_IMPORTS[1]], ), ( "Imported wildcard", [ - ImportExpression(importer="mypackage.green", imported="mypackage.*"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.*"), + ), ], DIRECT_IMPORTS[0:2], ), ( "Importer and imported wildcards", [ - ImportExpression(importer="mypackage.*", imported="mypackage.*"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.*"), + ), ], DIRECT_IMPORTS[0:3], ), ( "Inner wildcard", [ - ImportExpression(importer="mypackage.*.cats", imported="mypackage.*.dogs"), + ImportExpression( + importer=ModuleExpression("mypackage.*.cats"), + imported=ModuleExpression("mypackage.*.dogs"), + ), ], DIRECT_IMPORTS[3:5], ), ( "Multiple expressions, non-overlapping", [ - ImportExpression(importer="mypackage.green", imported="mypackage.*"), ImportExpression( - importer="mypackage.green.cats", imported="mypackage.orange.*" + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.*"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green.cats"), + imported=ModuleExpression("mypackage.orange.*"), ), ], DIRECT_IMPORTS[0:2] + DIRECT_IMPORTS[4:6], @@ -228,15 +244,24 @@ class TestImportExpressionsToImports: ( "Multiple expressions, overlapping", [ - ImportExpression(importer="mypackage.*", imported="mypackage.blue"), - ImportExpression(importer="mypackage.green", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.blue"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.blue"), + ), ], [DIRECT_IMPORTS[1]], ), ( "Multiple imports of external package with same importer", [ - ImportExpression(importer="mypackage.brown", imported="someotherpackage"), + ImportExpression( + importer=ModuleExpression("mypackage.brown"), + imported=ModuleExpression("someotherpackage"), + ), ], DIRECT_IMPORTS[6:8], ), @@ -257,7 +282,9 @@ def test_raises_missing_import(self): importer="mypackage.b", imported="other.foo", line_number=1, line_contents="-" ) - expression = ImportExpression(importer="mypackage.a.*", imported="other.foo") + expression = ImportExpression( + importer=ModuleExpression("mypackage.a.*"), imported=ModuleExpression("other.foo") + ) with pytest.raises(MissingImport): import_expressions_to_imports(graph, [expression]) @@ -340,8 +367,8 @@ class TestResolveImportExpressions: "No wildcards", [ ImportExpression( - importer=DIRECT_IMPORTS[0].importer.name, - imported=DIRECT_IMPORTS[0].imported.name, + importer=ModuleExpression(DIRECT_IMPORTS[0].importer.name), + imported=ModuleExpression(DIRECT_IMPORTS[0].imported.name), ), ], {DIRECT_IMPORTS[0]}, @@ -349,65 +376,93 @@ class TestResolveImportExpressions: ( "Importer wildcard", [ - ImportExpression(importer="mypackage.*", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.blue"), + ), ], {DIRECT_IMPORTS[1]}, ), ( "Imported wildcard", [ - ImportExpression(importer="mypackage.green", imported="mypackage.*"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.*"), + ), ], set(DIRECT_IMPORTS[0:2]), ), ( "Importer and imported wildcards", [ - ImportExpression(importer="mypackage.*", imported="mypackage.*"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.*"), + ), ], set(DIRECT_IMPORTS[0:3]), ), ( "Inner wildcard", [ - ImportExpression(importer="mypackage.*.cats", imported="mypackage.*.dogs"), + ImportExpression( + importer=ModuleExpression("mypackage.*.cats"), + imported=ModuleExpression("mypackage.*.dogs"), + ), ], set(DIRECT_IMPORTS[3:5]), ), ( "Importer recursive wildcard", [ - ImportExpression(importer="mypackage.**", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.**"), + imported=ModuleExpression("mypackage.blue"), + ), ], {DIRECT_IMPORTS[1]}, ), ( "Imported recursive wildcard", [ - ImportExpression(importer="mypackage.green", imported="mypackage.**"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.**"), + ), ], set(DIRECT_IMPORTS[0:2]), ), ( "Importer and imported recursive wildcards", [ - ImportExpression(importer="mypackage.**", imported="mypackage.**"), + ImportExpression( + importer=ModuleExpression("mypackage.**"), + imported=ModuleExpression("mypackage.**"), + ), ], set(DIRECT_IMPORTS[0:6]) | {DIRECT_IMPORTS[8]}, ), ( "Inner recursive wildcard", [ - ImportExpression(importer="mypackage.**.cats", imported="mypackage.**.dogs"), + ImportExpression( + importer=ModuleExpression("mypackage.**.cats"), + imported=ModuleExpression("mypackage.**.dogs"), + ), ], set(DIRECT_IMPORTS[3:5]) | {DIRECT_IMPORTS[8]}, ), ( "Multiple expressions, non-overlapping", [ - ImportExpression(importer="mypackage.green", imported="mypackage.*"), ImportExpression( - importer="mypackage.green.cats", imported="mypackage.orange.*" + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.*"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green.cats"), + imported=ModuleExpression("mypackage.orange.*"), ), ], set(DIRECT_IMPORTS[0:2] + DIRECT_IMPORTS[4:6]), @@ -415,15 +470,24 @@ class TestResolveImportExpressions: ( "Multiple expressions, overlapping", [ - ImportExpression(importer="mypackage.*", imported="mypackage.blue"), - ImportExpression(importer="mypackage.green", imported="mypackage.blue"), + ImportExpression( + importer=ModuleExpression("mypackage.*"), + imported=ModuleExpression("mypackage.blue"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.blue"), + ), ], {DIRECT_IMPORTS[1]}, ), ( "Multiple imports of external package with same importer", [ - ImportExpression(importer="mypackage.brown", imported="someotherpackage"), + ImportExpression( + importer=ModuleExpression("mypackage.brown"), + imported=ModuleExpression("someotherpackage"), + ), ], set(DIRECT_IMPORTS[6:8]), ), @@ -449,13 +513,20 @@ def test_detects_unresolved_expression(self): graph.add_import( importer="mypackage.b", imported="other.foo", line_number=1, line_contents="-" ) - expression = ImportExpression(importer="mypackage.a.*", imported="other.foo") + expression = ImportExpression( + importer=ModuleExpression("mypackage.a.*"), imported=ModuleExpression("other.foo") + ) imports, unresolved_expressions = resolve_import_expressions(graph, [expression]) assert (imports, unresolved_expressions) == ( set(), - {ImportExpression(imported="other.foo", importer="mypackage.a.*")}, + { + ImportExpression( + imported=ModuleExpression("other.foo"), + importer=ModuleExpression("mypackage.a.*"), + ) + }, ) def _build_graph(self, direct_imports): @@ -507,10 +578,19 @@ class TestPopImportExpressions: def test_succeeds(self) -> None: graph = self._build_graph(self.DIRECT_IMPORTS) expressions = [ - ImportExpression(importer="mypackage.green", imported="mypackage.*"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.*"), + ), # Expressions can overlap. - ImportExpression(importer="mypackage.green", imported="mypackage.blue"), - ImportExpression(importer="mypackage.blue.cats", imported="mypackage.purple.dogs"), + ImportExpression( + importer=ModuleExpression("mypackage.green"), + imported=ModuleExpression("mypackage.blue"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.blue.cats"), + imported=ModuleExpression("mypackage.purple.dogs"), + ), ] popped_imports: List[DetailedImport] = pop_import_expressions(graph, expressions) diff --git a/tests/unit/domain/test_imports.py b/tests/unit/domain/test_imports.py index cc718f03..50838fa4 100644 --- a/tests/unit/domain/test_imports.py +++ b/tests/unit/domain/test_imports.py @@ -2,7 +2,7 @@ import pytest -from importlinter.domain.imports import DirectImport, ImportExpression, Module +from importlinter.domain.imports import DirectImport, ImportExpression, Module, ModuleExpression @contextmanager @@ -98,34 +98,62 @@ def test_string_object_representation(self, test_object, expected_string): class TestImportExpression: def test_object_representation(self): - test_object = ImportExpression(importer="mypackage.foo", imported="mypackage.bar") + test_object = ImportExpression( + importer=ModuleExpression("mypackage.foo"), imported=ModuleExpression("mypackage.bar") + ) assert repr(test_object) == " mypackage.bar>" def test_string_object_representation(self): - expression = ImportExpression(importer="mypackage.foo", imported="mypackage.bar") + expression = ImportExpression( + importer=ModuleExpression("mypackage.foo"), imported=ModuleExpression("mypackage.bar") + ) assert str(expression) == "mypackage.foo -> mypackage.bar" @pytest.mark.parametrize( "first, second, expected", [ ( - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), True, ), ( - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), - ImportExpression(importer="mypackage.bar", imported="mypackage.foo"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.bar"), + imported=ModuleExpression("mypackage.foo"), + ), False, ), ( - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), - ImportExpression(importer="mypackage.foo", imported="mypackage.foobar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.foobar"), + ), False, ), ( - ImportExpression(importer="mypackage.foo", imported="mypackage.bar"), - ImportExpression(importer="mypackage.foobar", imported="mypackage.bar"), + ImportExpression( + importer=ModuleExpression("mypackage.foo"), + imported=ModuleExpression("mypackage.bar"), + ), + ImportExpression( + importer=ModuleExpression("mypackage.foobar"), + imported=ModuleExpression("mypackage.bar"), + ), False, ), ], @@ -148,5 +176,7 @@ def test_equality(self, first, second, expected): ], ) def test_has_wildcard_expression(self, importer, imported, has_wildcard_expression): - expression = ImportExpression(importer=importer, imported=imported) + expression = ImportExpression( + importer=ModuleExpression(importer), imported=ModuleExpression(imported) + ) assert expression.has_wildcard_expression() == has_wildcard_expression