diff --git a/src/nominatim_db/db/connection.py b/src/nominatim_db/db/connection.py index 97d81fc47..e8f9b6a0c 100644 --- a/src/nominatim_db/db/connection.py +++ b/src/nominatim_db/db/connection.py @@ -188,3 +188,11 @@ def get_pg_env(dsn: str, LOG.error("Unknown connection parameter '%s' ignored.", param) return env + + +async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None: + """ Open a connection to the database and run a single query + asynchronously. + """ + async with await psycopg.AsyncConnection.connect(dsn) as aconn: + await aconn.execute(query) diff --git a/src/nominatim_db/indexer/indexer.py b/src/nominatim_db/indexer/indexer.py index e5c70b713..15d355144 100644 --- a/src/nominatim_db/indexer/indexer.py +++ b/src/nominatim_db/indexer/indexer.py @@ -16,7 +16,7 @@ from psycopg_pool import AsyncConnectionPool from ..db.connection import connect, execute_scalar -from ..utils.async_task_pool import AsyncTaskPool +from ..db.query_pool import QueryPool from ..tokenizer.base import AbstractTokenizer from .progress import ProgressLogger from . import runners @@ -141,39 +141,36 @@ async def _index(self, runner: runners.Runner, batch: int = 1) -> int: total_tuples = self._prepare_indexing(runner) - async with await psycopg.AsyncConnection.connect( - self.dsn, row_factory=psycopg.rows.dict_row) as aconn: - progress = ProgressLogger(runner.name(), total_tuples) + progress = ProgressLogger(runner.name(), total_tuples) - if total_tuples > 0: + if total_tuples > 0: + async with await psycopg.AsyncConnection.connect( + self.dsn, row_factory=psycopg.rows.dict_row) as aconn: fetcher_time = 0.0 async with aconn.cursor(name='places') as cur: await cur.execute(runner.sql_get_objects()) - async with AsyncConnectionPool( - self.dsn, min_size = self.num_threads) as conn_pool: - async with AsyncTaskPool(2 * self.num_threads) as task_pool: + async with QueryPool(self.dsn, self.num_threads, autocommit=True) as pool: + places_task = asyncio.create_task(_fetch_next_batch(cur, runner)) + places = await places_task + while places is not None: + # asynchronously query the next batch places_task = asyncio.create_task(_fetch_next_batch(cur, runner)) - places = await places_task - while places is not None: - # asynchronously query the next batch - places_task = asyncio.create_task(_fetch_next_batch(cur, runner)) - # And insert the current batch - for idx in range(0, len(places), batch): - part = places[idx:idx + batch] - LOG.debug("Processing places: %s", str(part)) - await task_pool.add_coroutine(_index_places(conn_pool, - runner, part)) - progress.add(len(part)) + # And insert the current batch + for idx in range(0, len(places), batch): + part = places[idx:idx + batch] + LOG.debug("Processing places: %s", str(part)) + await pool.put_query(*runner.index_places(part)) + progress.add(len(part)) - # get the results for the next batch - tstart = time.time() - places = await places_task - fetcher_time += time.time() - tstart + # get the results for the next batch + tstart = time.time() + places = await places_task + fetcher_time += time.time() - tstart - LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs", - fetcher_time, task_pool.wait_time) + LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs", + fetcher_time, pool.wait_time) @@ -207,17 +204,3 @@ async def _fetch_next_batch(cursor: psycopg.AsyncCursor[psycopg.rows.DictRow], return await cur.fetchall() return ids - - -async def _index_places(pool: AsyncConnectionPool, runner: runners.Runner, - places: List[psycopg.rows.DictRow]) -> None: - sql, params = runner.index_places(places) - - async with pool.connection() as conn: - while True: - try: - await conn.execute(sql, params) - return - except psycopg.errors.DeadlockDetected: - LOG.info("Deadlock detected (params = %s), retry.", str(params)) - await conn.rollback() diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py index 9b931994a..e3b962558 100644 --- a/src/nominatim_db/tools/database_import.py +++ b/src/nominatim_db/tools/database_import.py @@ -225,7 +225,7 @@ async def load_data(dsn: str, threads: int) -> None: except Exception as ex: for task in pool: task.cancel() - progress.cacnel() + progress.cancel() raise ex progress.cancel()