From 6596db2174a5c15d85ae096859567aa92a72cdfa Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Mon, 11 Sep 2023 21:10:36 +0200 Subject: [PATCH] Use `parse_bool` implementation from mypy --- mypy_django_plugin/lib/helpers.py | 9 --------- mypy_django_plugin/transformers/fields.py | 5 +++-- mypy_django_plugin/transformers/models.py | 8 +------- mypy_django_plugin/transformers/querysets.py | 11 ++++++----- 4 files changed, 10 insertions(+), 23 deletions(-) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 886e63ada..9f5289970 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -171,15 +171,6 @@ def make_optional(typ: MypyType) -> MypyType: return UnionType.make_union([typ, NoneTyp()]) -def parse_bool(expr: Expression) -> Optional[bool]: - if isinstance(expr, NameExpr): - if expr.fullname == "builtins.True": - return True - if expr.fullname == "builtins.False": - return False - return None - - def has_any_of_bases(info: TypeInfo, bases: Iterable[str]) -> bool: for base_fullname in bases: if info.has_base(base_fullname): diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 7569975d6..6a68207c3 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -6,6 +6,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo from mypy.plugin import FunctionContext +from mypy.semanal_shared import parse_bool from mypy.types import AnyType, Instance, TypeOfAny, UnionType from mypy.types import Type as MypyType @@ -134,12 +135,12 @@ def set_descriptor_types_for_field( is_nullable = False null_expr = helpers.get_call_argument_by_name(ctx, "null") if null_expr is not None: - is_nullable = helpers.parse_bool(null_expr) or False + is_nullable = parse_bool(null_expr) or False # Allow setting field value to `None` when a field is primary key and has a default that can produce a value default_expr = helpers.get_call_argument_by_name(ctx, "default") primary_key_expr = helpers.get_call_argument_by_name(ctx, "primary_key") if default_expr is not None and primary_key_expr is not None: - is_set_nullable = helpers.parse_bool(primary_key_expr) or False + is_set_nullable = parse_bool(primary_key_expr) or False set_type, get_type = get_field_descriptor_types( default_return_type.type, diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index e2f05b4e4..85c2e7271 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -661,13 +661,7 @@ def is_model_abstract(self) -> bool: and stmt.lvalues[0].name == "abstract" ): # abstract = True (builtins.bool) - rhs_is_true = ( - isinstance(stmt.rvalue, NameExpr) - and stmt.rvalue.name == "True" - and isinstance(stmt.rvalue.node, Var) - and isinstance(stmt.rvalue.node.type, Instance) - and stmt.rvalue.node.type.type.fullname == "builtins.bool" - ) + rhs_is_true = self.api.parse_bool(stmt.rvalue) is True # abstract: Literal[True] is_literal_true = isinstance(stmt.type, LiteralType) and stmt.type.value is True return rhs_is_true or is_literal_true diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index bca760e90..671e4fea4 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -5,8 +5,9 @@ from django.db.models.base import Model from django.db.models.fields.related import RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel -from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression, NameExpr +from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, Expression from mypy.plugin import FunctionContext, MethodContext +from mypy.semanal_shared import parse_bool from mypy.types import AnyType, Instance, TupleType, TypedDictType, TypeOfAny, get_proper_type from mypy.types import Type as MypyType @@ -159,14 +160,14 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: return default_return_type flat_expr = helpers.get_call_argument_by_name(ctx, "flat") - if flat_expr is not None and isinstance(flat_expr, NameExpr): - flat = helpers.parse_bool(flat_expr) + if flat_expr is not None: + flat = parse_bool(flat_expr) else: flat = False named_expr = helpers.get_call_argument_by_name(ctx, "named") - if named_expr is not None and isinstance(named_expr, NameExpr): - named = helpers.parse_bool(named_expr) + if named_expr is not None: + named = parse_bool(named_expr) else: named = False