Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support module wildcards everywhere #235

Merged
merged 17 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/importlinter/application/contract_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import enum
from typing import List, Optional, Sequence, Set
from typing import List, Optional, Sequence, Set, Iterable

from importlinter.domain import helpers
from importlinter.domain.helpers import MissingImport
from importlinter.domain.imports import ImportExpression
from importlinter.domain.imports import ImportExpression, ModuleExpression, Module
from grimp import ImportGraph


Expand Down Expand Up @@ -46,6 +46,13 @@ def remove_ignored_imports(
return warnings


def resolve_module_expressions(
graph: ImportGraph, expressions: Iterable[ModuleExpression]
) -> Iterable[Module]:
fbinz marked this conversation as resolved.
Show resolved Hide resolved
for expression in expressions:
yield from _resolve_module_expression(graph, expression)


# Private functions
# -----------------

Expand Down Expand Up @@ -75,3 +82,9 @@ def _handle_unresolved_import_expressions(

def _build_missing_import_message(expression: ImportExpression) -> str:
return f"No matches for ignored import {expression}."


def _resolve_module_expression(
graph: ImportGraph, expression: ModuleExpression
) -> Iterable[Module]:
pass
52 changes: 33 additions & 19 deletions src/importlinter/domain/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Generic, Iterable, List, Set, Type, TypeVar, Union, cast

from importlinter.domain.imports import ImportExpression, Module
from importlinter.domain.imports import ImportExpression, Module, ModuleExpression

FieldValue = TypeVar("FieldValue")

Expand Down Expand Up @@ -159,6 +159,34 @@ def parse(self, raw_data: Union[str, List]) -> Module:
return Module(StringField().parse(raw_data))


class ModuleExpressionField(Field):
"""
A field for ModuleExpressions.

Accepts strings in the form:
"mypackage.foo.importer"
"mypackage.foo.*"
"mypackage.*.importer"
"mypackage.**"
"""

def parse(self, expression: str) -> ModuleExpression:
last_wildcard = None
for part in expression.split("."):
if "**" == last_wildcard and ("*" == part or "**" == part):
raise ValidationError("A recursive wildcard cannot be followed by a wildcard.")
if "*" == last_wildcard and "**" == part:
raise ValidationError("A wildcard cannot be followed by a recursive wildcard.")
if "*" == part or "**" == part:
last_wildcard = part
continue
if "*" in part:
raise ValidationError("A wildcard can only replace a whole module.")
last_wildcard = None

return ModuleExpression(expression)


class ImportExpressionField(Field):
"""
A field for ImportExpressions.
Expand All @@ -181,24 +209,10 @@ def parse(self, raw_data: Union[str, List]) -> ImportExpression:
if not (importer and imported):
raise ValidationError('Must be in the form "package.importer -> package.imported".')

self._validate_wildcard(importer)
self._validate_wildcard(imported)

return ImportExpression(importer=importer, imported=imported)

def _validate_wildcard(self, expression: str) -> None:
last_wildcard = None
for part in expression.split("."):
if "**" == last_wildcard and ("*" == part or "**" == part):
raise ValidationError("A recursive wildcard cannot be followed by a wildcard.")
if "*" == last_wildcard and "**" == part:
raise ValidationError("A wildcard cannot be followed by a recursive wildcard.")
if "*" == part or "**" == part:
last_wildcard = part
continue
if "*" in part:
raise ValidationError("A wildcard can only replace a whole module.")
last_wildcard = None
return ImportExpression(
importer=ModuleExpressionField().parse(importer),
imported=ModuleExpressionField().parse(imported),
)


class EnumField(Field):
Expand Down
28 changes: 11 additions & 17 deletions src/importlinter/domain/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from grimp import DetailedImport

from importlinter.domain.imports import DirectImport, ImportExpression, Module
from importlinter.domain.imports import DirectImport, ImportExpression, Module, ModuleExpression
from grimp import ImportGraph


Expand Down Expand Up @@ -55,7 +55,9 @@ def import_expression_to_imports(
imports: Set[DirectImport] = set()
matched = False

for (importer, imported) in _expression_to_modules(expression, graph):
importers = _expression_to_modules(expression.importer, graph)
importeds = _expression_to_modules(expression.imported, graph)
for (importer, imported) in itertools.product(importers, importeds):
import_details = graph.get_import_details(importer=importer.name, imported=imported.name)

if import_details:
Expand Down Expand Up @@ -215,23 +217,15 @@ def _to_pattern(expression: str) -> Pattern:
return re.compile(r"^" + r"\.".join(pattern_parts) + r"$")


def _expression_to_modules(
expression: ImportExpression, graph: ImportGraph
) -> Iterable[Tuple[Module, Module]]:
def _expression_to_modules(expression: ModuleExpression, graph: ImportGraph) -> Iterable[Module]:
if not expression.has_wildcard_expression():
return [(Module(expression.importer), Module(expression.imported))]
return [Module(expression.expression)]

importer = []
imported = []

importer_pattern = _to_pattern(expression.importer)
imported_expression = _to_pattern(expression.imported)
pattern = _to_pattern(expression.expression)

modules = set()
for module in graph.modules:
if pattern.match(module):
modules.add(Module(module))

if importer_pattern.match(module):
importer.append(Module(module))
if imported_expression.match(module):
imported.append(Module(module))

return itertools.product(set(importer), set(imported))
return modules
14 changes: 11 additions & 3 deletions src/importlinter/domain/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def __hash__(self) -> int:
return hash((str(self), self.line_contents))


class ModuleExpression(ValueObject):
fbinz marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, expression: str) -> None:
self.expression = expression

def has_wildcard_expression(self) -> bool:
return "*" in self.expression


class ImportExpression(ValueObject):
"""
A user-submitted expression describing an import or set of imports.
Expand All @@ -95,12 +103,12 @@ class ImportExpression(ValueObject):
It does not, however, include more distant descendants such as mypackage.foo.bar.
"""

def __init__(self, importer: str, imported: str) -> None:
def __init__(self, importer: ModuleExpression, imported: ModuleExpression) -> None:
self.importer = importer
self.imported = imported

def has_wildcard_expression(self) -> bool:
return "*" in self.imported or "*" in self.importer
return self.imported.has_wildcard_expression() or self.importer.has_wildcard_expression()

def __str__(self) -> str:
return "{} -> {}".format(self.importer, self.imported)
return "{} -> {}".format(self.importer.expression, self.imported.expression)
32 changes: 25 additions & 7 deletions tests/unit/application/test_contract_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from importlinter.application.contract_utils import AlertLevel, remove_ignored_imports
from importlinter.domain.helpers import MissingImport
from importlinter.domain.imports import DirectImport, ImportExpression, Module
from importlinter.domain.imports import DirectImport, ImportExpression, Module, ModuleExpression


class TestRemoveIgnoredImports:
Expand Down Expand Up @@ -41,8 +41,14 @@ def test_no_unresolved_import_expressions(self, alert_level):
warnings = remove_ignored_imports(
graph=graph,
ignore_imports=[
ImportExpression(importer="mypackage.green", imported="mypackage.blue"),
ImportExpression(importer="mypackage.green", imported="mypackage.purple"),
ImportExpression(
importer=ModuleExpression("mypackage.green"),
imported=ModuleExpression("mypackage.blue"),
),
ImportExpression(
importer=ModuleExpression("mypackage.green"),
imported=ModuleExpression("mypackage.purple"),
),
],
unmatched_alerting=alert_level,
)
Expand Down Expand Up @@ -72,10 +78,22 @@ def test_no_unresolved_import_expressions(self, alert_level):
def test_unresolved_import_expressions(self, alert_level, expected_result):
graph = self._build_graph(self.DIRECT_IMPORTS)
ignore_imports = [
ImportExpression(importer="mypackage.green", imported="mypackage.blue"),
ImportExpression(importer="mypackage.*", imported="mypackage.nonexistent"),
ImportExpression(importer="mypackage.green", imported="mypackage.purple"),
ImportExpression(importer="mypackage.nonexistent", imported="mypackage.blue"),
ImportExpression(
importer=ModuleExpression("mypackage.green"),
imported=ModuleExpression("mypackage.blue"),
),
ImportExpression(
importer=ModuleExpression("mypackage.*"),
imported=ModuleExpression("mypackage.nonexistent"),
),
ImportExpression(
importer=ModuleExpression("mypackage.green"),
imported=ModuleExpression("mypackage.purple"),
),
ImportExpression(
importer=ModuleExpression("mypackage.nonexistent"),
imported=ModuleExpression("mypackage.blue"),
),
]

if isinstance(expected_result, Exception):
Expand Down
55 changes: 43 additions & 12 deletions tests/unit/domain/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StringField,
ValidationError,
)
from importlinter.domain.imports import ImportExpression, Module
from importlinter.domain.imports import ImportExpression, Module, ModuleExpression


def test_field_cannot_be_instantiated_with_default_and_required():
Expand Down Expand Up @@ -91,49 +91,80 @@ class TestModuleField(BaseFieldTest):
(
(
"mypackage.foo -> mypackage.bar",
ImportExpression(importer="mypackage.foo", imported="mypackage.bar"),
ImportExpression(
importer=ModuleExpression("mypackage.foo"),
imported=ModuleExpression("mypackage.bar"),
),
),
(
"my_package.foo -> my_package.bar", # Extra whitespaces are supported.
ImportExpression(importer="my_package.foo", imported="my_package.bar"),
ImportExpression(
importer=ModuleExpression("my_package.foo"),
imported=ModuleExpression("my_package.bar"),
),
),
(
"my_package.foo -> my_package.foo_bar", # Underscores are supported.
ImportExpression(importer="my_package.foo", imported="my_package.foo_bar"),
ImportExpression(
importer=ModuleExpression("my_package.foo"),
imported=ModuleExpression("my_package.foo_bar"),
),
),
# Wildcards
# ---------
(
"mypackage.foo.* -> mypackage.bar",
ImportExpression(importer="mypackage.foo.*", imported="mypackage.bar"),
ImportExpression(
importer=ModuleExpression("mypackage.foo.*"),
imported=ModuleExpression("mypackage.bar"),
),
),
(
"mypackage.foo.*.baz -> mypackage.bar",
ImportExpression(importer="mypackage.foo.*.baz", imported="mypackage.bar"),
ImportExpression(
importer=ModuleExpression("mypackage.foo.*.baz"),
imported=ModuleExpression("mypackage.bar"),
),
),
(
"mypackage.foo -> mypackage.bar.*",
ImportExpression(importer="mypackage.foo", imported="mypackage.bar.*"),
ImportExpression(
importer=ModuleExpression("mypackage.foo"),
imported=ModuleExpression("mypackage.bar.*"),
),
),
(
"*.*.* -> mypackage.*.foo.*",
ImportExpression(importer="*.*.*", imported="mypackage.*.foo.*"),
ImportExpression(
importer=ModuleExpression("*.*.*"), imported=ModuleExpression("mypackage.*.foo.*")
),
),
(
"mypackage.foo.** -> mypackage.bar",
ImportExpression(importer="mypackage.foo.**", imported="mypackage.bar"),
ImportExpression(
importer=ModuleExpression("mypackage.foo.**"),
imported=ModuleExpression("mypackage.bar"),
),
),
(
"mypackage.foo.**.baz -> mypackage.bar",
ImportExpression(importer="mypackage.foo.**.baz", imported="mypackage.bar"),
ImportExpression(
importer=ModuleExpression("mypackage.foo.**.baz"),
imported=ModuleExpression("mypackage.bar"),
),
),
(
"mypackage.foo -> mypackage.bar.**",
ImportExpression(importer="mypackage.foo", imported="mypackage.bar.**"),
ImportExpression(
importer=ModuleExpression("mypackage.foo"),
imported=ModuleExpression("mypackage.bar.**"),
),
),
(
"** -> mypackage.**.foo.*",
ImportExpression(importer="**", imported="mypackage.**.foo.*"),
ImportExpression(
importer=ModuleExpression("**"), imported=ModuleExpression("mypackage.**.foo.*")
),
),
# Invalid expressions
# -------------------
Expand Down
Loading