diff --git a/src/importlinter/domain/helpers.py b/src/importlinter/domain/helpers.py index 90a7c157..03b8a9b5 100644 --- a/src/importlinter/domain/helpers.py +++ b/src/importlinter/domain/helpers.py @@ -1,8 +1,8 @@ import itertools -from typing import Dict, Iterable, List, Tuple, Union, Pattern import re +from typing import Dict, Iterable, List, Pattern, Tuple, Union, Set -from importlinter.domain.imports import ImportExpression, Module, DirectImport +from importlinter.domain.imports import DirectImport, ImportExpression, Module from importlinter.domain.ports.graph import ImportGraph @@ -18,7 +18,7 @@ def pop_imports( Returns: The list of import details that were removed, including any additional metadata. Raises: - MissingImport if the import is not present in the graph. + MissingImport if an import is not present in the graph. """ removed_imports: List[Dict[str, Union[str, int]]] = [] for import_to_remove in imports: @@ -37,7 +37,14 @@ def pop_imports( def import_expressions_to_imports( graph: ImportGraph, expressions: Iterable[ImportExpression] ) -> List[DirectImport]: - imports: List[DirectImport] = [] + """ + Returns a list of imports in a graph, given some import expressions. + + Raises: + MissingImport if an import is not present in the graph. For a wildcarded import expression, + this is raised if there is not at least one match. + """ + imports: Set[DirectImport] = set() for expression in expressions: matched = False for (importer, imported) in _expression_to_modules(expression, graph): @@ -45,18 +52,35 @@ def import_expressions_to_imports( importer=importer.name, imported=imported.name ) if import_details: - imports.append(DirectImport(importer=importer, imported=imported)) + for individual_import_details in import_details: + imports.add( + DirectImport( + importer=Module(individual_import_details["importer"]), + imported=Module(individual_import_details["imported"]), + line_number=individual_import_details["line_number"], + line_contents=individual_import_details["line_contents"], + ) + ) matched = True if not matched: raise MissingImport( f"Ignored import expression {expression} didn't match anything in the graph." ) - return imports + return list(imports) def pop_import_expressions( graph: ImportGraph, expressions: Iterable[ImportExpression] ) -> List[Dict[str, Union[str, int]]]: + """ + Removes any imports matching the supplied import expressions from the graph. + + Returns: + The list of imports that were removed, including any additional metadata. + Raises: + MissingImport if an import is not present in the graph. For a wildcarded import expression, + this is raised if there is not at least one match. + """ imports = import_expressions_to_imports(graph, expressions) return pop_imports(graph, imports) @@ -94,7 +118,7 @@ def _to_pattern(expression: str) -> Pattern: pattern_parts.append(part.replace("*", r"[^\.]+")) else: pattern_parts.append(part) - return re.compile(r"\.".join(pattern_parts)) + return re.compile(r"^" + r"\.".join(pattern_parts) + r"$") def _expression_to_modules( @@ -115,5 +139,6 @@ def _expression_to_modules( importer.append(Module(module)) if imported_expression.match(module): imported.append(Module(module)) + imported.append(Module(module)) - return itertools.product(importer, imported) + return itertools.product(set(importer), set(imported)) diff --git a/tests/unit/domain/test_helpers.py b/tests/unit/domain/test_helpers.py index 4de13491..d51d384c 100644 --- a/tests/unit/domain/test_helpers.py +++ b/tests/unit/domain/test_helpers.py @@ -1,7 +1,274 @@ -from grimp.adaptors.graph import ImportGraph # type: ignore -from importlinter.domain.helpers import MissingImport, add_imports, import_expressions_to_imports -from importlinter.domain.imports import DirectImport, ImportExpression +import re +from typing import Dict, Union + import pytest +from grimp.adaptors.graph import ImportGraph # type: ignore + +from importlinter.domain.helpers import ( + MissingImport, + add_imports, + import_expressions_to_imports, + pop_import_expressions, + pop_imports, +) +from importlinter.domain.imports import DirectImport, ImportExpression, Module + + +class TestPopImports: + IMPORTS = [ + dict( + importer="mypackage.green", + imported="mypackage.yellow", + line_number=1, + line_contents="blah", + ), + dict( + importer="mypackage.green", + imported="mypackage.blue", + line_number=2, + line_contents="blahblah", + ), + dict( + importer="mypackage.blue", + imported="mypackage.green", + line_number=10, + line_contents="blahblahblah", + ), + ] + + def test_succeeds(self): + graph = self._build_graph(imports=self.IMPORTS) + imports_to_pop = self.IMPORTS[0:2] + import_to_leave = self.IMPORTS[2] + + result = pop_imports( + graph, + [ + DirectImport(importer=Module(i["importer"]), imported=Module(i["imported"])) + for i in imports_to_pop + ], + ) + + assert result == imports_to_pop + assert graph.direct_import_exists( + importer=import_to_leave["importer"], imported=import_to_leave["imported"] + ) + assert graph.count_imports() == 1 + + def test_raises_missing_import_if_module_not_found(self): + graph = self._build_graph(imports=self.IMPORTS) + non_existent_import = DirectImport( + importer=Module("mypackage.nonexistent"), + imported=Module("mypackage.yellow"), + line_number=1, + line_contents="-", + ) + with pytest.raises( + MissingImport, + match=re.escape(f"Ignored import {non_existent_import} not present in the graph."), + ): + pop_imports(graph, [non_existent_import]) + + def _build_graph(self, imports): + graph = ImportGraph() + for module in ("mypackage", "mypackage.green", "mypackage.blue", "mypackage.yellow"): + graph.add_module(module) + for import_ in imports: + graph.add_import(**import_) + return graph + + +class TestImportExpressionsToImports: + DIRECT_IMPORTS = [ + DirectImport( + importer=Module("mypackage.green"), + imported=Module("mypackage.yellow"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.green"), + imported=Module("mypackage.blue"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.blue"), + imported=Module("mypackage.green"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.blue.cats"), + imported=Module("mypackage.purple.dogs"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.green.cats"), + imported=Module("mypackage.orange.dogs"), + line_number=1, + line_contents="-", + ), + ] + + @pytest.mark.parametrize( + "description, expressions, expected", + [ + ( + "No wildcards", + [ + ImportExpression( + importer=DIRECT_IMPORTS[0].importer.name, + imported=DIRECT_IMPORTS[0].imported.name, + ), + ], + [DIRECT_IMPORTS[0]], + ), + ( + "Importer wildcard", + [ + ImportExpression(importer="mypackage.*", imported="mypackage.blue"), + ], + [DIRECT_IMPORTS[1]], + ), + ( + "Imported wildcard", + [ + ImportExpression(importer="mypackage.green", imported="mypackage.*"), + ], + DIRECT_IMPORTS[0:2], + ), + ( + "Importer and imported wildcards", + [ + ImportExpression(importer="mypackage.*", imported="mypackage.*"), + ], + DIRECT_IMPORTS[0:3], + ), + ( + "Inner wildcard", + [ + ImportExpression(importer="mypackage.*.cats", imported="mypackage.*.dogs"), + ], + DIRECT_IMPORTS[3:5], + ), + ( + "Overlapping expressions", + [ + ImportExpression(importer="mypackage.*", imported="mypackage.blue"), + ImportExpression(importer="mypackage.green", imported="mypackage.blue"), + ], + [DIRECT_IMPORTS[1]], + ), + ], + ) + def test_succeeds(self, description, expressions, expected): + graph = self._build_graph(self.DIRECT_IMPORTS) + + assert sorted( + import_expressions_to_imports(graph, expressions), key=_direct_import_sort_key + ) == sorted(expected, key=_direct_import_sort_key) + + def test_fails(self): + graph = ImportGraph() + graph.add_module("mypackage") + graph.add_module("other") + graph.add_import( + importer="mypackage.b", imported="other.foo", line_number=1, line_contents="-" + ) + + expression = ImportExpression(importer="mypackage.a.*", imported="other.foo") + with pytest.raises(MissingImport): + import_expressions_to_imports(graph, [expression]) + + def _build_graph(self, direct_imports): + graph = ImportGraph() + for direct_import in direct_imports: + graph.add_import( + importer=direct_import.importer.name, + imported=direct_import.imported.name, + line_number=direct_import.line_number, + line_contents=direct_import.line_contents, + ) + return graph + + +class TestPopImportExpressions: + DIRECT_IMPORTS = [ + DirectImport( + importer=Module("mypackage.green"), + imported=Module("mypackage.yellow"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.green"), + imported=Module("mypackage.blue"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.blue"), + imported=Module("mypackage.green"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.blue.cats"), + imported=Module("mypackage.purple.dogs"), + line_number=1, + line_contents="-", + ), + DirectImport( + importer=Module("mypackage.green.cats"), + imported=Module("mypackage.orange.dogs"), + line_number=1, + line_contents="-", + ), + ] + + def test_succeeds(self): + graph = self._build_graph(self.DIRECT_IMPORTS) + expressions = [ + ImportExpression(importer="mypackage.green", imported="mypackage.*"), + # Expressions can overlap. + ImportExpression(importer="mypackage.green", imported="mypackage.blue"), + ImportExpression(importer="mypackage.blue.cats", imported="mypackage.purple.dogs"), + ] + result = pop_import_expressions(graph, expressions) + + x = map(self._dict_to_direct_import, result) + sorted_x = sorted(x, key=_direct_import_sort_key) + sorted_y = sorted( + [ + self.DIRECT_IMPORTS[0], + self.DIRECT_IMPORTS[1], + self.DIRECT_IMPORTS[3], + ], + key=_direct_import_sort_key, + ) + assert sorted_x == sorted_y + assert graph.count_imports() == 2 + + def _build_graph(self, direct_imports): + graph = ImportGraph() + for direct_import in direct_imports: + graph.add_import( + importer=direct_import.importer.name, + imported=direct_import.imported.name, + line_number=direct_import.line_number, + line_contents=direct_import.line_contents, + ) + return graph + + def _dict_to_direct_import(self, import_details: Dict[str, Union[str, int]]) -> DirectImport: + return DirectImport( + importer=Module(import_details["importer"]), + imported=Module(import_details["imported"]), + line_number=import_details["line_number"], + line_contents=import_details["line_contents"], + ) def test_add_imports(): @@ -15,31 +282,10 @@ def test_add_imports(): assert graph.modules == {"a", "b", "c", "d"} -def test_import_expressions_to_imports(): - graph = ImportGraph() - graph.add_module("mypackage") - graph.add_module("other") - graph.add_import( - importer="mypackage.a", imported="other.foo", line_number=1, line_contents="-" - ) - graph.add_import( - importer="mypackage.c", imported="other.baz", line_number=1, line_contents="-" +def _direct_import_sort_key(direct_import: DirectImport): + # Doesn't matter how we sort, just a way of sorting consistently for comparison. + return ( + direct_import.importer.name, + direct_import.imported.name, + direct_import.line_number, ) - - expression = ImportExpression(importer="mypackage.*", imported="other.foo") - assert import_expressions_to_imports(graph, [expression]) == [ - DirectImport(importer="mypackage.a", imported="other.foo") - ] - - -def test_import_expressions_to_imports_fails(): - graph = ImportGraph() - graph.add_module("mypackage") - graph.add_module("other") - graph.add_import( - importer="mypackage.b", imported="other.foo", line_number=1, line_contents="-" - ) - - expression = ImportExpression(importer="mypackage.a.*", imported="other.foo") - with pytest.raises(MissingImport): - import_expressions_to_imports(graph, [expression])