diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 16b42c1b2..bdb91557e 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -86,15 +86,15 @@ def __init__(self, django_settings_module: str) -> None: self.settings = settings @cached_property - def model_modules(self) -> Dict[str, Set[Type[Model]]]: + def model_modules(self) -> Dict[str, Dict[str, Type[Model]]]: """All modules that contain Django models.""" - modules: Dict[str, Set[Type[Model]]] = defaultdict(set) + modules: Dict[str, Dict[str, Type[Model]]] = defaultdict(dict) for concrete_model_cls in self.apps_registry.get_models(): - modules[concrete_model_cls.__module__].add(concrete_model_cls) + modules[concrete_model_cls.__module__][concrete_model_cls.__name__] = concrete_model_cls # collect abstract=True models for model_cls in concrete_model_cls.mro()[1:]: if issubclass(model_cls, Model) and hasattr(model_cls, "_meta") and model_cls._meta.abstract: - modules[model_cls.__module__].add(model_cls) + modules[model_cls.__module__][model_cls.__name__] = model_cls return modules def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]: @@ -109,10 +109,7 @@ def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]: fullname = fullname.replace("__", ".") module, _, model_cls_name = fullname.rpartition(".") - for model_cls in self.model_modules.get(module, set()): - if model_cls.__name__ == model_cls_name: - return model_cls - return None + return self.model_modules.get(module, {}).get(model_cls_name) def get_model_fields(self, model_cls: Type[Model]) -> Iterator["Field[Any, Any]"]: for field in model_cls._meta.get_fields(): diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index f97aeab17..454a6cb13 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -127,7 +127,7 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: return [] deps = set() - for model_class in defined_model_classes: + for model_class in defined_model_classes.values(): for field in itertools.chain( # forward relations self.django_context.get_model_related_fields(model_class),