diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 05e5a176d8ff..e1eaf99624e0 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -1,11 +1,12 @@ import sys -from _typeshed import Self, StrOrBytesPath +from _typeshed import ReadableBuffer, Self, StrOrBytesPath from datetime import date, datetime, time from types import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, Protocol, TypeVar +from typing import Any, Callable, Generator, Iterable, Iterator, Protocol, TypeVar, overload from typing_extensions import Literal, final _T = TypeVar("_T") +_SqliteData = str | bytes | int | float | None paramstyle: str threadsafety: int @@ -128,6 +129,18 @@ class _AggregateProtocol(Protocol): def step(self, value: int) -> None: ... def finalize(self) -> int: ... +class _SingleParamWindowAggregateClass(Protocol): + def step(self, __param: Any) -> None: ... + def inverse(self, __param: Any) -> None: ... + def value(self) -> _SqliteData: ... + def finalize(self) -> _SqliteData: ... + +class _WindowAggretateClass(Protocol): + step: Callable[..., None] + inverse: Callable[..., None] + def value(self) -> _SqliteData: ... + def finalize(self) -> _SqliteData: ... + class Connection: DataError: Any DatabaseError: Any @@ -146,8 +159,23 @@ class Connection: total_changes: Any def __init__(self, *args: Any, **kwargs: Any) -> None: ... def close(self) -> None: ... + if sys.version_info >= (3, 11): + def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ... + def commit(self) -> None: ... def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ... + if sys.version_info >= (3, 11): + # num_params determines how many params will be passed to the aggregate class. We provide an overload + # for the case where num_params = 1, which is expected to be the common case. + @overload + def create_window_function( + self, __name: str, __num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] + ) -> None: ... + @overload + def create_window_function( + self, __name: str, __num_params: int, aggregate_class: Callable[[], _WindowAggretateClass] + ) -> None: ... + def create_collation(self, __name: str, __callback: Any) -> None: ... if sys.version_info >= (3, 8): def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ... @@ -181,6 +209,11 @@ class Connection: name: str = ..., sleep: float = ..., ) -> None: ... + if sys.version_info >= (3, 11): + def setlimit(self, __category: int, __limit: int) -> int: ... + def getlimit(self, __category: int) -> int: ... + def serialize(self, *, name: str = ...) -> bytes: ... + def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ... def __call__(self, *args: Any, **kwargs: Any) -> Any: ... def __enter__(self: Self) -> Self: ... @@ -253,3 +286,13 @@ if sys.version_info < (3, 8): def __init__(self, *args, **kwargs): ... class Warning(Exception): ... + +if sys.version_info >= (3, 11): + class Blob: + def close(self) -> None: ... + def read(self, __length: int = ...) -> bytes: ... + def write(self, __data: bytes) -> None: ... + def tell(self) -> int: ... + # whence must be one of os.SEEK_SET, os.SEEK_CUR, os.SEEK_END + def seek(self, __offset: int, __whence: int = ...) -> None: ... + def __len__(self) -> int: ...