diff --git a/django-stubs/db/backends/postgresql/base.pyi b/django-stubs/db/backends/postgresql/base.pyi index eb27220421..a59100e003 100644 --- a/django-stubs/db/backends/postgresql/base.pyi +++ b/django-stubs/db/backends/postgresql/base.pyi @@ -3,6 +3,7 @@ from typing import Any, Dict, Tuple, Type from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper +from django.db.backends.utils import _ExecuteQuery from .client import DatabaseClient from .creation import DatabaseCreation @@ -37,5 +38,5 @@ class DatabaseWrapper(BaseDatabaseWrapper): def pg_version(self) -> int: ... class CursorDebugWrapper(BaseCursorDebugWrapper): - def copy_expert(self, sql: str, file: IOBase, *args: Any): ... + def copy_expert(self, sql: _ExecuteQuery, file: IOBase, *args: Any): ... def copy_to(self, file: IOBase, table: str, *args: Any, **kwargs: Any): ... diff --git a/django-stubs/db/backends/utils.pyi b/django-stubs/db/backends/utils.pyi index 02bf51d54b..e8a63addff 100644 --- a/django-stubs/db/backends/utils.pyi +++ b/django-stubs/db/backends/utils.pyi @@ -3,7 +3,21 @@ import sys import types from contextlib import contextmanager from decimal import Decimal -from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union, overload +from typing import ( + Any, + Dict, + Generator, + Iterator, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + Type, + Union, + overload, +) from uuid import UUID if sys.version_info < (3, 8): @@ -13,6 +27,14 @@ else: logger: Any +# Protocol matching psycopg2.sql.Composable, to avoid depending psycopg2 +class _Composable(Protocol): + def as_string(self, context: Any) -> str: ... + def __add__(self, other: _Composable) -> _Composable: ... + def __mul__(self, n: int) -> _Composable: ... + +_ExecuteQuery = str | _Composable + # Python types that can be adapted to SQL. _SQLType = Union[ None, bool, int, float, Decimal, str, bytes, datetime.date, datetime.datetime, UUID, Tuple[Any, ...], List[Any] @@ -36,8 +58,8 @@ class CursorWrapper: def callproc( self, procname: str, params: Optional[Sequence[Any]] = ..., kparams: Optional[Dict[str, int]] = ... ) -> Any: ... - def execute(self, sql: str, params: _ExecuteParameters = ...) -> Any: ... - def executemany(self, sql: str, param_list: Sequence[_ExecuteParameters]) -> Any: ... + def execute(self, sql: _ExecuteQuery, params: _ExecuteParameters = ...) -> Any: ... + def executemany(self, sql: _ExecuteQuery, param_list: Sequence[_ExecuteParameters]) -> Any: ... class CursorDebugWrapper(CursorWrapper): cursor: Any diff --git a/tests/typecheck/db/test_connection.yml b/tests/typecheck/db/test_connection.yml index 6b3f88bf16..4b2fef0ddd 100644 --- a/tests/typecheck/db/test_connection.yml +++ b/tests/typecheck/db/test_connection.yml @@ -4,6 +4,13 @@ with connection.cursor() as cursor: reveal_type(cursor) # N: Revealed type is "django.db.backends.utils.CursorWrapper" cursor.execute("SELECT %s", [123]) +- case: raw_connection_psycopg2_composable + main: | + from django.db import connection + from psycopg2.sql import SQL, Identifier + with connection.cursor() as cursor: + reveal_type(cursor) # N: Revealed type is "django.db.backends.utils.CursorWrapper" + cursor.execute(SQL("INSERT INTO {} VALUES (%s)").format(Identifier("my_table")), [123]) - case: raw_connections main: | from django.db import connections