Skip to content

Commit

Permalink
Improve type hint for backends
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Jul 31, 2024
1 parent 2c634bf commit e353b97
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 50 deletions.
2 changes: 1 addition & 1 deletion tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(self, connection: AsyncpgDBClient) -> None:
def _in_transaction(self) -> "TransactionContext":
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> "ConnectionWrapper":
def acquire_connection(self) -> ConnectionWrapper[asyncpg.Connection]:
return ConnectionWrapper(self._lock, self)

@translate_exceptions
Expand Down
41 changes: 28 additions & 13 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from __future__ import annotations

import asyncio
from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from pypika import Query

Expand All @@ -9,6 +22,8 @@
from tortoise.exceptions import TransactionManagementError
from tortoise.log import db_client_logger

T_conn = TypeVar("T_conn") # Instance of client connection, such as: asyncpg.Connection()


class Capabilities:
"""
Expand Down Expand Up @@ -202,21 +217,21 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) ->
raise NotImplementedError() # pragma: nocoverage


class ConnectionWrapper:
class ConnectionWrapper(Generic[T_conn]):
__slots__ = ("connection", "lock", "client")

def __init__(self, lock: asyncio.Lock, client: Any) -> None:
"""Wraps the connections with a lock to facilitate safe concurrent access."""
self.lock: asyncio.Lock = lock
self.client = client
self.connection: Any = client._connection
self.connection: T_conn = client._connection

async def ensure_connection(self) -> None:
if not self.connection:
await self.client.create_connection(with_db=True)
self.connection = self.client._connection

async def __aenter__(self):
async def __aenter__(self) -> T_conn:
await self.lock.acquire()
await self.ensure_connection()
return self.connection
Expand All @@ -225,7 +240,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.lock.release()


class TransactionContext:
class TransactionContext(Generic[T_conn]):
__slots__ = ("connection", "connection_name", "token", "lock")

def __init__(self, connection: Any) -> None:
Expand All @@ -238,7 +253,7 @@ async def ensure_connection(self) -> None:
await self.connection._parent.create_connection(with_db=True)
self.connection._connection = self.connection._parent._connection

async def __aenter__(self):
async def __aenter__(self) -> T_conn:
await self.ensure_connection()
await self.lock.acquire() # type:ignore
self.token = connections.set(self.connection_name, self.connection)
Expand All @@ -264,7 +279,7 @@ async def ensure_connection(self) -> None:
if not self.connection._parent._pool:
await self.connection._parent.create_connection(with_db=True)

async def __aenter__(self):
async def __aenter__(self) -> T_conn:
await self.ensure_connection()
self.token = connections.set(self.connection_name, self.connection)
self.connection._connection = await self.connection._parent._pool.acquire()
Expand All @@ -285,7 +300,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:


class NestedTransactionContext(TransactionContext):
async def __aenter__(self):
async def __aenter__(self) -> T_conn:
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand All @@ -297,7 +312,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:


class NestedTransactionPooledContext(TransactionContext):
async def __aenter__(self):
async def __aenter__(self) -> T_conn:
await self.lock.acquire() # type:ignore
return self.connection

Expand All @@ -310,23 +325,23 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.connection.rollback()


class PoolConnectionWrapper:
class PoolConnectionWrapper(Generic[T_conn]):
def __init__(self, client: Any) -> None:
"""Class to manage acquiring from and releasing connections to a pool."""
self.pool = client._pool
self.client = client
self.connection = None
self.connection: Optional[T_conn] = None

async def ensure_connection(self) -> None:
if not self.pool:
await self.client.create_connection(with_db=True)
self.pool = self.client._pool

async def __aenter__(self):
async def __aenter__(self) -> T_conn:
await self.ensure_connection()
# get first available connection
self.connection = await self.pool.acquire()
return self.connection
return cast(T_conn, self.connection)

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# release the connection back to the pool
Expand Down
11 changes: 6 additions & 5 deletions tortoise/backends/base_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
Any,
Callable,
Coroutine,
List,
Optional,
SupportsInt,
Expand All @@ -24,16 +25,16 @@
from tortoise.backends.base_postgres.executor import BasePostgresExecutor
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
T = TypeVar("T")
FuncType = Callable[..., Coroutine[None, None, T]]


def translate_exceptions(func: F) -> F:
def translate_exceptions(func: FuncType) -> FuncType:
@wraps(func)
async def _translate_exceptions(self, *args, **kwargs):
async def _translate_exceptions(self, *args, **kwargs) -> T:
return await self._translate_exceptions(func, *args, **kwargs)

return _translate_exceptions # type: ignore
return _translate_exceptions


class BasePostgresPool:
Expand Down
24 changes: 17 additions & 7 deletions tortoise/backends/mysql/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import asyncio
from functools import wraps
from typing import Any, Callable, List, Optional, SupportsInt, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Coroutine,
List,
Optional,
SupportsInt,
Tuple,
TypeVar,
Union,
)

try:
import asyncmy as mysql
Expand Down Expand Up @@ -33,13 +43,13 @@
TransactionManagementError,
)

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
T = TypeVar("T")
FuncType = Callable[..., Coroutine[None, None, T]]


def translate_exceptions(func: F) -> F:
def translate_exceptions(func: FuncType) -> FuncType:
@wraps(func)
async def translate_exceptions_(self, *args):
async def translate_exceptions_(self, *args) -> T:
try:
return await func(self, *args)
except (
Expand All @@ -53,7 +63,7 @@ async def translate_exceptions_(self, *args):
except errors.IntegrityError as exc:
raise IntegrityError(exc)

return translate_exceptions_ # type: ignore
return translate_exceptions_


class MySQLClient(BaseDBAsyncClient):
Expand Down Expand Up @@ -228,7 +238,7 @@ def __init__(self, connection: MySQLClient) -> None:
def _in_transaction(self) -> "TransactionContext":
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> ConnectionWrapper:
def acquire_connection(self) -> ConnectionWrapper[mysql.Connection]:
return ConnectionWrapper(self._lock, self)

@translate_exceptions
Expand Down
19 changes: 11 additions & 8 deletions tortoise/backends/odbc/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from abc import ABC
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar, Union

import asyncodbc
import pyodbc
Expand All @@ -22,13 +22,16 @@
TransactionManagementError,
)

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
T = TypeVar("T")
FuncType = Callable[..., Coroutine[None, None, T]]
ConnWrapperType = Union[
ConnectionWrapper[asyncodbc.Connection], PoolConnectionWrapper[asyncodbc.Connection]
]


def translate_exceptions(func: F) -> F:
def translate_exceptions(func: FuncType) -> FuncType:
@wraps(func)
async def translate_exceptions_(self, *args):
async def translate_exceptions_(self, *args) -> T:
try:
return await func(self, *args)
except (
Expand All @@ -43,7 +46,7 @@ async def translate_exceptions_(self, *args):
except (pyodbc.IntegrityError, pyodbc.Error) as exc:
raise IntegrityError(exc)

return translate_exceptions_ # type: ignore
return translate_exceptions_


class ODBCClient(BaseDBAsyncClient, ABC):
Expand Down Expand Up @@ -110,7 +113,7 @@ async def close(self) -> None:
self.log.debug("Closed connection %s with params: %s", self._connection, self._template)
self._pool = None

def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
def acquire_connection(self) -> ConnWrapperType:
return PoolConnectionWrapper(self)

@translate_exceptions
Expand Down Expand Up @@ -174,7 +177,7 @@ def __init__(self, connection: ODBCClient) -> None:
def _in_transaction(self) -> "TransactionContext":
return NestedTransactionPooledContext(self)

def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
def acquire_connection(self) -> ConnWrapperType:
return ConnectionWrapper(self._lock, self)

@translate_exceptions
Expand Down
9 changes: 6 additions & 3 deletions tortoise/backends/oracle/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import functools
from typing import Any, SupportsInt, Union
from typing import TYPE_CHECKING, Any, SupportsInt, Union

import pyodbc
import pytz
Expand Down Expand Up @@ -28,6 +28,9 @@
from tortoise.backends.oracle.executor import OracleExecutor
from tortoise.backends.oracle.schema_generator import OracleSchemaGenerator

if TYPE_CHECKING:
import asyncodbc


class OracleClient(ODBCClient):
query_class = OracleQuery
Expand Down Expand Up @@ -99,8 +102,8 @@ def _timestamp_convert(self, value: bytes) -> datetime.date:
except ValueError:
return parse_datetime(value.decode()[:-32]).astimezone(tz=pytz.utc)

async def __aenter__(self):
connection = await super(OraclePoolConnectionWrapper, self).__aenter__() # type: ignore
async def __aenter__(self) -> "asyncodbc.Connection":
connection = await super().__aenter__()
if getattr(self.client, "database", False) and not hasattr(connection, "current_schema"):
await connection.execute(f'ALTER SESSION SET CURRENT_SCHEMA = "{self.client.user}"')
await connection.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'")
Expand Down
14 changes: 9 additions & 5 deletions tortoise/backends/oracle/executor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import TYPE_CHECKING, cast

from tortoise import Model
from tortoise.backends.odbc.executor import ODBCExecutor

if TYPE_CHECKING:
from .client import OracleClient # pylint: disable=W0611


class OracleExecutor(ODBCExecutor):
async def _process_insert_result(self, instance: Model, results: int) -> None:
sql = "SELECT SEQUENCE_NAME FROM ALL_TAB_IDENTITY_COLS where TABLE_NAME = ? and OWNER = ?"
ret = await self.db.execute_query_dict(
sql, values=[instance._meta.db_table, self.db.database] # type: ignore
)
db = cast("OracleClient", self.db)
ret = await db.execute_query_dict(sql, values=[instance._meta.db_table, db.database])
try:
seq = ret[0]["SEQUENCE_NAME"]
except IndexError:
return
sql = f"SELECT {seq}.CURRVAL FROM DUAL" # nosec:B608
ret = await self.db.execute_query_dict(sql)
await super(OracleExecutor, self)._process_insert_result(instance, ret[0]["CURRVAL"])
ret = await db.execute_query_dict(sql)
await super()._process_insert_result(instance, ret[0]["CURRVAL"])
2 changes: 1 addition & 1 deletion tortoise/backends/psycopg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(self, connection: PsycopgClient) -> None:
def _in_transaction(self) -> base_client.TransactionContext:
return base_client.NestedTransactionPooledContext(self)

def acquire_connection(self) -> base_client.ConnectionWrapper:
def acquire_connection(self) -> base_client.ConnectionWrapper[psycopg.AsyncConnection]:
return base_client.ConnectionWrapper(self._lock, self)

@postgres_client.translate_exceptions
Expand Down
24 changes: 17 additions & 7 deletions tortoise/backends/sqlite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import os
import sqlite3
from functools import wraps
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
from typing import (
Any,
Callable,
Coroutine,
List,
Optional,
Sequence,
Tuple,
TypeVar,
cast,
)

import aiosqlite
from pypika import SQLLiteQuery
Expand All @@ -23,21 +33,21 @@
TransactionManagementError,
)

FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
T = TypeVar("T")
FuncType = Callable[..., Coroutine[None, None, T]]


def translate_exceptions(func: F) -> F:
def translate_exceptions(func: FuncType) -> FuncType:
@wraps(func)
async def translate_exceptions_(self, query, *args):
async def translate_exceptions_(self, query, *args) -> T:
try:
return await func(self, query, *args)
except sqlite3.OperationalError as exc:
raise OperationalError(exc)
except sqlite3.IntegrityError as exc:
raise IntegrityError(exc)

return translate_exceptions_ # type: ignore
return translate_exceptions_


class SqliteClient(BaseDBAsyncClient):
Expand Down Expand Up @@ -158,7 +168,7 @@ async def execute_script(self, query: str) -> None:
class TransactionWrapper(SqliteClient, BaseTransactionWrapper):
def __init__(self, connection: SqliteClient) -> None:
self.connection_name = connection.connection_name
self._connection: aiosqlite.Connection = connection._connection # type: ignore
self._connection: aiosqlite.Connection = cast(aiosqlite.Connection, connection._connection)
self._lock = asyncio.Lock()
self._trxlock = connection._lock
self.log = connection.log
Expand Down

0 comments on commit e353b97

Please sign in to comment.