From c4fd3ab97f94f184369f1d9e7189a992ffc8f4f7 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Dec 2023 11:29:16 +0100 Subject: [PATCH 01/17] hide type differences between Postgres and Sqlite in custom types Also define a custom set of operators in preparation of differences in implementation. --- nominatim/api/core.py | 2 +- nominatim/api/search/db_search_fields.py | 6 +- nominatim/api/search/db_searches.py | 13 +--- nominatim/api/search/icu_tokenizer.py | 3 +- nominatim/db/sqlalchemy_schema.py | 54 +++----------- nominatim/db/sqlalchemy_types/__init__.py | 17 +++++ .../geometry.py} | 0 nominatim/db/sqlalchemy_types/int_array.py | 73 +++++++++++++++++++ nominatim/db/sqlalchemy_types/json.py | 30 ++++++++ nominatim/db/sqlalchemy_types/key_value.py | 47 ++++++++++++ nominatim/typing.py | 1 + 11 files changed, 187 insertions(+), 59 deletions(-) create mode 100644 nominatim/db/sqlalchemy_types/__init__.py rename nominatim/db/{sqlalchemy_types.py => sqlalchemy_types/geometry.py} (100%) create mode 100644 nominatim/db/sqlalchemy_types/int_array.py create mode 100644 nominatim/db/sqlalchemy_types/json.py create mode 100644 nominatim/db/sqlalchemy_types/key_value.py diff --git a/nominatim/api/core.py b/nominatim/api/core.py index 44ac91606f..c8045c2d14 100644 --- a/nominatim/api/core.py +++ b/nominatim/api/core.py @@ -137,7 +137,7 @@ def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None: self._property_cache['DB:server_version'] = server_version - self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member + self._tables = SearchTables(sa.MetaData()) # pylint: disable=no-member self._engine = engine diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 59af826086..52693e95fc 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -11,7 +11,6 @@ import dataclasses import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import ARRAY from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token @@ -155,10 +154,9 @@ def sql_condition(self, table: SaFromClause) -> SaColumn: if self.lookup_type == 'lookup_all': return col.contains(self.tokens) if self.lookup_type == 'lookup_any': - return cast(SaColumn, col.overlap(self.tokens)) + return cast(SaColumn, col.overlaps(self.tokens)) - return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'), - type_=ARRAY(sa.Integer())).contains(self.tokens) + return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable class SearchData: diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index 232f816ef8..2b4dfd3c9b 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -11,7 +11,7 @@ import abc import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import ARRAY, array_agg +from sqlalchemy.dialects.postgresql import array_agg from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \ SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind @@ -494,10 +494,7 @@ async def lookup_in_country_table(self, conn: SearchConnection, sub = sql.subquery('grid') sql = sa.select(t.c.country_code, - (t.c.name - + sa.func.coalesce(t.c.derived_name, - sa.cast('', type_=conn.t.types.Composite)) - ).label('name'), + t.c.name.merge(t.c.derived_name).label('name'), sub.c.centroid, sub.c.bbox)\ .join(sub, t.c.country_code == sub.c.country_code) @@ -569,10 +566,8 @@ async def lookup(self, 
conn: SearchConnection, assert self.lookups[0].lookup_type == 'restrict' tsearch = conn.t.search_name sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ - .where(sa.func.array_cat(tsearch.c.name_vector, - tsearch.c.nameaddress_vector, - type_=ARRAY(sa.Integer)) - .contains(self.lookups[0].tokens)) + .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector) + .contains(self.lookups[0].tokens)) for ranking in self.rankings: penalty += ranking.sql_penalty(conn.t.search_name) diff --git a/nominatim/api/search/icu_tokenizer.py b/nominatim/api/search/icu_tokenizer.py index fceec2df52..eabd329d57 100644 --- a/nominatim/api/search/icu_tokenizer.py +++ b/nominatim/api/search/icu_tokenizer.py @@ -22,6 +22,7 @@ from nominatim.api.logging import log from nominatim.api.search import query as qmod from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer +from nominatim.db.sqlalchemy_types import Json DB_TO_TOKEN_TYPE = { @@ -159,7 +160,7 @@ async def _make_transliterator() -> Any: sa.Column('word_token', sa.Text, nullable=False), sa.Column('type', sa.Text, nullable=False), sa.Column('word', sa.Text), - sa.Column('info', self.conn.t.types.Json)) + sa.Column('info', Json)) async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct: diff --git a/nominatim/db/sqlalchemy_schema.py b/nominatim/db/sqlalchemy_schema.py index 7dd1e0ce0b..0ec22b7e1f 100644 --- a/nominatim/db/sqlalchemy_schema.py +++ b/nominatim/db/sqlalchemy_schema.py @@ -7,37 +7,10 @@ """ SQLAlchemy definitions for all tables used by the frontend. """ -from typing import Any - import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import HSTORE, ARRAY, JSONB, array -from sqlalchemy.dialects.sqlite import JSON as sqlite_json import nominatim.db.sqlalchemy_functions #pylint: disable=unused-import -from nominatim.db.sqlalchemy_types import Geometry - -class PostgresTypes: - """ Type definitions for complex types as used in Postgres variants. - """ - Composite = HSTORE - Json = JSONB - IntArray = ARRAY(sa.Integer()) #pylint: disable=invalid-name - to_array = array - - -class SqliteTypes: - """ Type definitions for complex types as used in Postgres variants. - """ - Composite = sqlite_json - Json = sqlite_json - IntArray = sqlite_json - - @staticmethod - def to_array(arr: Any) -> Any: - """ Sqlite has no special conversion for arrays. - """ - return arr - +from nominatim.db.sqlalchemy_types import Geometry, KeyValueStore, IntArray #pylint: disable=too-many-instance-attributes class SearchTables: @@ -47,14 +20,7 @@ class SearchTables: Any data used for updates only will not be visible. 
""" - def __init__(self, meta: sa.MetaData, engine_name: str) -> None: - if engine_name == 'postgresql': - self.types: Any = PostgresTypes - elif engine_name == 'sqlite': - self.types = SqliteTypes - else: - raise ValueError("Only 'postgresql' and 'sqlite' engines are supported.") - + def __init__(self, meta: sa.MetaData) -> None: self.meta = meta self.import_status = sa.Table('import_status', meta, @@ -80,9 +46,9 @@ def __init__(self, meta: sa.MetaData, engine_name: str) -> None: sa.Column('class', sa.Text, nullable=False, key='class_'), sa.Column('type', sa.Text, nullable=False), sa.Column('admin_level', sa.SmallInteger), - sa.Column('name', self.types.Composite), - sa.Column('address', self.types.Composite), - sa.Column('extratags', self.types.Composite), + sa.Column('name', KeyValueStore), + sa.Column('address', KeyValueStore), + sa.Column('extratags', KeyValueStore), sa.Column('geometry', Geometry, nullable=False), sa.Column('wikipedia', sa.Text), sa.Column('country_code', sa.String(2)), @@ -118,14 +84,14 @@ def __init__(self, meta: sa.MetaData, engine_name: str) -> None: sa.Column('step', sa.SmallInteger), sa.Column('indexed_status', sa.SmallInteger), sa.Column('linegeo', Geometry), - sa.Column('address', self.types.Composite), + sa.Column('address', KeyValueStore), sa.Column('postcode', sa.Text), sa.Column('country_code', sa.String(2))) self.country_name = sa.Table('country_name', meta, sa.Column('country_code', sa.String(2)), - sa.Column('name', self.types.Composite), - sa.Column('derived_name', self.types.Composite), + sa.Column('name', KeyValueStore), + sa.Column('derived_name', KeyValueStore), sa.Column('partition', sa.Integer)) self.country_grid = sa.Table('country_osm_grid', meta, @@ -139,8 +105,8 @@ def __init__(self, meta: sa.MetaData, engine_name: str) -> None: sa.Column('importance', sa.Float), sa.Column('search_rank', sa.SmallInteger), sa.Column('address_rank', sa.SmallInteger), - sa.Column('name_vector', self.types.IntArray), - sa.Column('nameaddress_vector', self.types.IntArray), + sa.Column('name_vector', IntArray), + sa.Column('nameaddress_vector', IntArray), sa.Column('country_code', sa.String(2)), sa.Column('centroid', Geometry)) diff --git a/nominatim/db/sqlalchemy_types/__init__.py b/nominatim/db/sqlalchemy_types/__init__.py new file mode 100644 index 0000000000..dc417995d2 --- /dev/null +++ b/nominatim/db/sqlalchemy_types/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Module with custom types for SQLAlchemy +""" + +# See also https://github.com/PyCQA/pylint/issues/6006 +# pylint: disable=useless-import-alias + +from .geometry import (Geometry as Geometry) +from .int_array import (IntArray as IntArray) +from .key_value import (KeyValueStore as KeyValueStore) +from .json import (Json as Json) diff --git a/nominatim/db/sqlalchemy_types.py b/nominatim/db/sqlalchemy_types/geometry.py similarity index 100% rename from nominatim/db/sqlalchemy_types.py rename to nominatim/db/sqlalchemy_types/geometry.py diff --git a/nominatim/db/sqlalchemy_types/int_array.py b/nominatim/db/sqlalchemy_types/int_array.py new file mode 100644 index 0000000000..335d554197 --- /dev/null +++ b/nominatim/db/sqlalchemy_types/int_array.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. 
(https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Custom type for an array of integers.
+"""
+from typing import Any, List, cast, Optional
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import ARRAY
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class IntList(sa.types.TypeDecorator[Any]):
+    """ A list of integers saved as a text of comma-separated numbers.
+    """
+    impl = sa.types.Unicode
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
+        if value is None:
+            return None
+
+        assert isinstance(value, list)
+        return ','.join(map(str, value))
+
+    def process_result_value(self, value: Optional[Any],
+                             dialect: SaDialect) -> Optional[List[int]]:
+        return [int(v) for v in value.split(',')] if value is not None else None
+
+    def copy(self, **kw: Any) -> 'IntList':
+        return IntList(self.impl.length)
+
+
+class IntArray(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent list of integers.
+    """
+    impl = IntList
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return ARRAY(sa.Integer()) #pylint: disable=invalid-name
+
+        return IntList()
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
+            """ Concatenate the array with the given array. If one of the
+                operands is null, the value of the other will be returned.
+            """
+            return sa.func.array_cat(self, other, type_=IntArray)
+
+
+        def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
+            """ Return true if the array contains all the values of the
+                argument array.
+            """
+            return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
+
+
+        def overlaps(self, other: SaColumn) -> 'sa.Operators':
+            """ Return true if at least one value of the argument is contained
+                in the array.
+            """
+            return self.op('&&', is_comparison=True)(other)
diff --git a/nominatim/db/sqlalchemy_types/json.py b/nominatim/db/sqlalchemy_types/json.py
new file mode 100644
index 0000000000..31635fd518
--- /dev/null
+++ b/nominatim/db/sqlalchemy_types/json.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Common json type for different dialects.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect
+
+# pylint: disable=all
+
+class Json(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type for JSON.
+    """
+    impl = sa.types.JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
diff --git a/nominatim/db/sqlalchemy_types/key_value.py b/nominatim/db/sqlalchemy_types/key_value.py
new file mode 100644
index 0000000000..4f2d824aff
--- /dev/null
+++ b/nominatim/db/sqlalchemy_types/key_value.py
@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim.
(https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+A custom type that implements a simple key-value store of strings.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class KeyValueStore(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type of a simple key-value store of strings.
+    """
+    impl = HSTORE
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return HSTORE() # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def merge(self, other: SaColumn) -> 'sa.Operators':
+            """ Merge the values from the given KeyValueStore into this
+                one, overwriting values where necessary. When the argument
+                is null, nothing happens.
+            """
+            return self.op('||')(sa.func.coalesce(other,
+                                                  sa.type_coerce('', KeyValueStore)))
+
+
+        def has_key(self, key: SaColumn) -> 'sa.Operators':
+            """ Return true if the key is contained in the store.
+            """
+            return self.op('?', is_comparison=True)(key)
diff --git a/nominatim/typing.py b/nominatim/typing.py
index 7274f1d396..62ecd8c3e1 100644
--- a/nominatim/typing.py
+++ b/nominatim/typing.py
@@ -72,3 +72,4 @@ def __getitem__(self, x: Union[int, str]) -> Any: ...
 SaFromClause: TypeAlias = 'sa.FromClause'
 SaSelectable: TypeAlias = 'sa.Selectable'
 SaBind: TypeAlias = 'sa.BindParameter[Any]'
+SaDialect: TypeAlias = 'sa.Dialect'

From 1b7c8240baeb383a2481f895fc85a74aaa16e94e Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Tue, 5 Dec 2023 12:22:00 +0100
Subject: [PATCH 02/17] enable connection pools for sqlite

Connecting is reasonably expensive because the spatialite extension
needs to be loaded. Disable pooling for tests because there is a
memory leak when quickly opening and closing QueuePools with sqlite
connections.
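For illustration, the engine setup now selects the pool roughly as
follows (a minimal sketch; the aiosqlite URL and the helper function
name are assumptions made for the example, not part of the patch):

    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import create_async_engine

    def pool_args(pool_size: int) -> dict:
        # pool_size == 0 disables pooling entirely (used by the tests);
        # otherwise use a fixed-size pool without overflow so the
        # number of open connections stays bounded.
        if pool_size == 0:
            return {'poolclass': sa.pool.NullPool}
        return {'poolclass': sa.pool.QueuePool,
                'pool_size': pool_size,
                'max_overflow': 0}

    engine = create_async_engine('sqlite+aiosqlite:///nominatim.sqlite',
                                 **pool_args(10))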
--- nominatim/api/core.py | 42 ++++++++++++++++++++++--------------- test/python/api/conftest.py | 3 ++- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/nominatim/api/core.py b/nominatim/api/core.py index c8045c2d14..b262422758 100644 --- a/nominatim/api/core.py +++ b/nominatim/api/core.py @@ -84,6 +84,14 @@ async def setup_database(self) -> None: extra_args: Dict[str, Any] = {'future': True, 'echo': self.config.get_bool('DEBUG_SQL')} + if self.config.get_int('API_POOL_SIZE') == 0: + extra_args['poolclass'] = sa.pool.NullPool + else: + extra_args['poolclass'] = sa.pool.QueuePool + extra_args['max_overflow'] = 0 + extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE') + + is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:') if is_sqlite: @@ -105,28 +113,12 @@ async def setup_database(self) -> None: host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None, query=query) - extra_args['max_overflow'] = 0 - extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE') engine = sa_asyncio.create_async_engine(dburl, **extra_args) - try: - async with engine.begin() as conn: - result = await conn.scalar(sa.text('SHOW server_version_num')) - server_version = int(result) - except (PGCORE_ERROR, sa.exc.OperationalError): + if is_sqlite: server_version = 0 - if server_version >= 110000 and not is_sqlite: - @sa.event.listens_for(engine.sync_engine, "connect") - def _on_connect(dbapi_con: Any, _: Any) -> None: - cursor = dbapi_con.cursor() - cursor.execute("SET jit_above_cost TO '-1'") - cursor.execute("SET max_parallel_workers_per_gather TO '0'") - # Make sure that all connections get the new settings - await self.close() - - if is_sqlite: @sa.event.listens_for(engine.sync_engine, "connect") def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None: dbapi_con.run_async(lambda conn: conn.enable_load_extension(True)) @@ -134,6 +126,22 @@ def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None: cursor.execute("SELECT load_extension('mod_spatialite')") cursor.execute('SELECT SetDecimalPrecision(7)') dbapi_con.run_async(lambda conn: conn.enable_load_extension(False)) + else: + try: + async with engine.begin() as conn: + result = await conn.scalar(sa.text('SHOW server_version_num')) + server_version = int(result) + except (PGCORE_ERROR, sa.exc.OperationalError): + server_version = 0 + + if server_version >= 110000: + @sa.event.listens_for(engine.sync_engine, "connect") + def _on_connect(dbapi_con: Any, _: Any) -> None: + cursor = dbapi_con.cursor() + cursor.execute("SET jit_above_cost TO '-1'") + cursor.execute("SET max_parallel_workers_per_gather TO '0'") + # Make sure that all connections get the new settings + await engine.dispose() self._property_cache['DB:server_version'] = server_version diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index cb7f324a39..8f0604d407 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -198,7 +198,8 @@ def mkapi(apiobj, options={'reverse'}): db, options)) return napi.NominatimAPI(Path('/invalid'), {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={db}", - 'NOMINATIM_USE_US_TIGER_DATA': 'yes'}) + 'NOMINATIM_USE_US_TIGER_DATA': 'yes', + 'NOMINATIM_API_POOL_SIZE': '0'}) elif request.param == 'postgres_db': def mkapi(apiobj, options=None): return apiobj.api From 05e47fbb28eb0f3f7803b6bfe194896b6e6c1ed0 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Dec 2023 15:23:16 +0100 Subject: [PATCH 03/17] fix parameter formatting in sqlite debug output --- nominatim/api/logging.py | 44 
+++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/nominatim/api/logging.py b/nominatim/api/logging.py index 37ae7f5f04..e16e0bd2d3 100644 --- a/nominatim/api/logging.py +++ b/nominatim/api/logging.py @@ -90,26 +90,42 @@ def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable', params = dict(compiled.params) if isinstance(extra_params, Mapping): for k, v in extra_params.items(): - params[k] = str(v) + if hasattr(v, 'to_wkt'): + params[k] = v.to_wkt() + elif isinstance(v, (int, float)): + params[k] = v + else: + params[k] = str(v) elif isinstance(extra_params, Sequence) and extra_params: for k in extra_params[0]: params[k] = f':{k}' sqlstr = str(compiled) - if sa.__version__.startswith('1'): - try: - sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr) - return sqlstr % tuple((repr(params.get(name, None)) - for name in compiled.positiontup)) # type: ignore - except TypeError: - return sqlstr - - # Fixes an odd issue with Python 3.7 where percentages are not - # quoted correctly. - sqlstr = re.sub(r'%(?!\()', '%%', sqlstr) - sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr) - return sqlstr % params + if conn.dialect.name == 'postgresql': + if sa.__version__.startswith('1'): + try: + sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr) + return sqlstr % tuple((repr(params.get(name, None)) + for name in compiled.positiontup)) # type: ignore + except TypeError: + return sqlstr + + # Fixes an odd issue with Python 3.7 where percentages are not + # quoted correctly. + sqlstr = re.sub(r'%(?!\()', '%%', sqlstr) + sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr) + return sqlstr % params + + assert conn.dialect.name == 'sqlite' + + # params in positional order + pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore + + sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr) + sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr) + + return sqlstr class HTMLLogger(BaseLogger): """ Logger that formats messages in HTML. From c41f2fed2133668dc3179813261d39d3ff69cbdd Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Dec 2023 16:07:56 +0100 Subject: [PATCH 04/17] simplify weigh_search() function Use JSON arrays which can have mixed types and therefore have a more logical structure than separate arrays. Avoid JSON dicts because of their verboseness. 
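For illustration, the rankings are now passed as a single JSON array
of [penalty, [token...]] pairs that is scanned in order. A pure-Python
sketch of the reworked weigh_search() semantics (the token ids and
penalties below are made up):

    def weigh_search(search_vector, rankings, def_weight):
        # rankings: list of [penalty, [token, ...]] pairs
        for penalty, tokens in rankings:
            # the first entry fully contained in the search vector wins
            if all(t in search_vector for t in tokens):
                return penalty
        return def_weight

    assert weigh_search([1, 21, 34],
                        [[0.0, [21, 34]], [0.3, [10]]],
                        1.0) == 0.0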
--- lib-sql/functions/ranking.sql | 14 ++++++-------- nominatim/api/search/db_search_fields.py | 17 ++++++++++++----- nominatim/utils/json_writer.py | 4 ++-- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/lib-sql/functions/ranking.sql b/lib-sql/functions/ranking.sql index 0b18954ced..97a0cde38e 100644 --- a/lib-sql/functions/ranking.sql +++ b/lib-sql/functions/ranking.sql @@ -287,21 +287,19 @@ LANGUAGE plpgsql IMMUTABLE; CREATE OR REPLACE FUNCTION weigh_search(search_vector INT[], - term_vectors TEXT[], - weight_vectors FLOAT[], + rankings TEXT, def_weight FLOAT) RETURNS FLOAT AS $$ DECLARE - pos INT := 1; - terms TEXT; + rank JSON; BEGIN - FOREACH terms IN ARRAY term_vectors + FOR rank IN + SELECT * FROM json_array_elements(rankings::JSON) LOOP - IF search_vector @> terms::INTEGER[] THEN - RETURN weight_vectors[pos]; + IF true = ALL(SELECT x::int = ANY(search_vector) FROM json_array_elements_text(rank->1) as x) THEN + RETURN (rank->>0)::float; END IF; - pos := pos + 1; END LOOP; RETURN def_weight; END; diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 52693e95fc..324a7acc2c 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -14,6 +14,7 @@ from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token +from nominatim.utils.json_writer import JsonWriter @dataclasses.dataclass class WeightedStrings: @@ -128,11 +129,17 @@ def sql_penalty(self, table: SaFromClause) -> SaColumn: """ assert self.rankings - return sa.func.weigh_search(table.c[self.column], - [f"{{{','.join((str(s) for s in r.tokens))}}}" - for r in self.rankings], - [r.penalty for r in self.rankings], - self.default) + rout = JsonWriter().start_array() + for rank in self.rankings: + rout.start_array().value(rank.penalty).next() + rout.start_array() + for token in rank.tokens: + rout.value(token).next() + rout.end_array() + rout.end_array().next() + rout.end_array() + + return sa.func.weigh_search(table.c[self.column], rout(), self.default) @dataclasses.dataclass diff --git a/nominatim/utils/json_writer.py b/nominatim/utils/json_writer.py index bb642233e7..fcc355d5ee 100644 --- a/nominatim/utils/json_writer.py +++ b/nominatim/utils/json_writer.py @@ -76,8 +76,8 @@ def start_array(self) -> 'JsonWriter': def end_array(self) -> 'JsonWriter': """ Write the closing bracket of a JSON array. """ - assert self.pending in (',', '[', '') - if self.pending == '[': + assert self.pending in (',', '[', ']', ')', '') + if self.pending not in (',', ''): self.data.write(self.pending) self.pending = ']' return self From 615b166c6850c0fa2af2e370147b26b978be3659 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Dec 2023 18:02:40 +0100 Subject: [PATCH 05/17] clean up ST_DWithin and intersects() functions A non-index version of ST_DWithin is not necessary. ST_Distance can be used for that purpose. Index use for intersects can be covered with a simple parameter. 
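The intended usage pattern then looks roughly like this (a sketch
only; the table and bind parameter are placeholders, not the real
schema):

    import sqlalchemy as sa
    from nominatim.db.sqlalchemy_types import Geometry

    t = sa.Table('places', sa.MetaData(),
                 sa.Column('place_id', sa.Integer),
                 sa.Column('geometry', Geometry))
    near = sa.bindparam('near', type_=Geometry)

    # small radius: index-assisted filter (ST_DWithin on the column)
    sql = sa.select(t.c.place_id)\
            .where(t.c.geometry.within_distance(near, 0.001))

    # where an index would not help: plain distance comparison
    sql = sa.select(t.c.place_id)\
            .where(t.c.geometry.ST_Distance(near) < 0.5)

    # bbox intersection with index use switchable by parameter
    sql = sa.select(t.c.place_id)\
            .where(t.c.geometry.intersects(near, use_index=False))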
--- nominatim/api/reverse.py | 8 ++--- nominatim/api/search/db_searches.py | 37 +++++++++++------------ nominatim/db/sqlalchemy_types/geometry.py | 34 ++++----------------- 3 files changed, 28 insertions(+), 51 deletions(-) diff --git a/nominatim/api/reverse.py b/nominatim/api/reverse.py index fb4c0b23d0..df5c10f266 100644 --- a/nominatim/api/reverse.py +++ b/nominatim/api/reverse.py @@ -180,7 +180,7 @@ async def _find_closest_street_or_poi(self, distance: float) -> Optional[SaRow]: diststr = sa.text(f"{distance}") sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t) - .where(t.c.geometry.ST_DWithin(WKT_PARAM, diststr)) + .where(t.c.geometry.within_distance(WKT_PARAM, diststr)) .where(t.c.indexed_status == 0) .where(t.c.linked_place_id == None) .where(sa.or_(sa.not_(t.c.geometry.is_area()), @@ -219,7 +219,7 @@ async def _find_housenumber_for_street(self, parent_place_id: int) -> Optional[S t = self.conn.t.placex sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t) - .where(t.c.geometry.ST_DWithin(WKT_PARAM, 0.001)) + .where(t.c.geometry.within_distance(WKT_PARAM, 0.001)) .where(t.c.parent_place_id == parent_place_id) .where(sa.func.IsAddressPoint(t)) .where(t.c.indexed_status == 0) @@ -241,7 +241,7 @@ async def _find_interpolation_for_street(self, parent_place_id: Optional[int], sa.select(t, t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'), _locate_interpolation(t)) - .where(t.c.linegeo.ST_DWithin(WKT_PARAM, distance)) + .where(t.c.linegeo.within_distance(WKT_PARAM, distance)) .where(t.c.startnumber != None) .order_by('distance') .limit(1)) @@ -275,7 +275,7 @@ def _base_query() -> SaSelect: inner = sa.select(t, t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'), _locate_interpolation(t))\ - .where(t.c.linegeo.ST_DWithin(WKT_PARAM, 0.001))\ + .where(t.c.linegeo.within_distance(WKT_PARAM, 0.001))\ .where(t.c.parent_place_id == parent_place_id)\ .order_by('distance')\ .limit(1)\ diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index 2b4dfd3c9b..48bd6272c8 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -56,7 +56,7 @@ def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]: COUNTRIES_PARAM: SaBind = sa.bindparam('countries') def _within_near(t: SaFromClause) -> Callable[[], SaExpression]: - return lambda: t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM) + return lambda: t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM) def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]: return lambda: t.c.place_id.not_in(sa.bindparam('excluded')) @@ -366,7 +366,7 @@ def _base_query() -> SaSelect: .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM)) .label('importance'))\ .where(t.c.linked_place_id == None) \ - .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \ + .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \ .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \ .limit(LIMIT_PARAM) @@ -403,8 +403,8 @@ def _base_query() -> SaSelect: if details.near and details.near_radius is not None: sql = sql.order_by(table.c.centroid.ST_Distance(NEAR_PARAM))\ - .where(table.c.centroid.ST_DWithin(NEAR_PARAM, - NEAR_RADIUS_PARAM)) + .where(table.c.centroid.within_distance(NEAR_PARAM, + NEAR_RADIUS_PARAM)) if self.countries: sql = sql.where(t.c.country_code.in_(self.countries.values)) @@ -632,11 +632,11 @@ async def lookup(self, conn: SearchConnection, sql = sql.where(tsearch.c.address_rank > 9) tpc = conn.t.postcode pcs = 
self.postcodes.values - if self.expected_count > 1000: + if self.expected_count > 5000: # Many results expected. Restrict by postcode. sql = sql.where(sa.select(tpc.c.postcode) .where(tpc.c.postcode.in_(pcs)) - .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12)) + .where(tsearch.c.centroid.within_distance(tpc.c.geometry, 0.12)) .exists()) # Less results, only have a preference for close postcodes @@ -648,27 +648,26 @@ async def lookup(self, conn: SearchConnection, if details.viewbox is not None: if details.bounded_viewbox: - if details.viewbox.area < 0.2: - sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX_PARAM)) - else: - sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX_PARAM)) + sql = sql.where(tsearch.c.centroid + .intersects(VIEWBOX_PARAM, + use_index=details.viewbox.area < 0.2)) elif self.expected_count >= 10000: - if details.viewbox.area < 0.5: - sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX2_PARAM)) - else: - sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX2_PARAM)) + sql = sql.where(tsearch.c.centroid + .intersects(VIEWBOX2_PARAM, + use_index=details.viewbox.area < 0.5)) else: - penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0), - (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5), + penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM, use_index=False), 0.0), + (t.c.geometry.intersects(VIEWBOX2_PARAM, use_index=False), 0.5), else_=1.0) if details.near is not None: if details.near_radius is not None: if details.near_radius < 0.1: - sql = sql.where(tsearch.c.centroid.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) + sql = sql.where(tsearch.c.centroid.within_distance(NEAR_PARAM, + NEAR_RADIUS_PARAM)) else: - sql = sql.where(tsearch.c.centroid.ST_DWithin_no_index(NEAR_PARAM, - NEAR_RADIUS_PARAM)) + sql = sql.where(tsearch.c.centroid + .ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM) sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM)) .label('importance')) sql = sql.order_by(sa.desc(sa.text('importance'))) diff --git a/nominatim/db/sqlalchemy_types/geometry.py b/nominatim/db/sqlalchemy_types/geometry.py index a36e8c462a..4520fc8e53 100644 --- a/nominatim/db/sqlalchemy_types/geometry.py +++ b/nominatim/db/sqlalchemy_types/geometry.py @@ -165,7 +165,6 @@ def spatialite_dwithin_column(element: SaColumn, compiler.process(dist, **kw)) - class Geometry(types.UserDefinedType): # type: ignore[type-arg] """ Simplified type decorator for PostGIS geometry. This type only supports geometries in 4326 projection. 
@@ -206,7 +205,10 @@ def bind_expression(self, bindvalue: SaBind) -> SaColumn: class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg] - def intersects(self, other: SaColumn) -> 'sa.Operators': + def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators': + if not use_index: + return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other) + if isinstance(self.expr, sa.Column): return Geometry_ColumnIntersectsBbox(self.expr, other) @@ -221,20 +223,11 @@ def is_area(self) -> SaColumn: return Geometry_IsAreaLike(self) - def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn: + def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn: if isinstance(self.expr, sa.Column): return Geometry_ColumnDWithin(self.expr, other, distance) - return sa.func.ST_DWithin(self.expr, other, distance) - - - def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn: - return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self), - other, distance) - - - def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators': - return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other) + return self.ST_Distance(other) < distance def ST_Distance(self, other: SaColumn) -> SaColumn: @@ -313,18 +306,3 @@ def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any: for alias in SQLITE_FUNCTION_ALIAS: _add_function_alias(*alias) - - -class ST_DWithin(sa.sql.functions.GenericFunction[Any]): - name = 'ST_DWithin' - inherit_cache = True - - -@compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc] -def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str: - geom1, geom2, dist = list(element.clauses) - return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % ( - compiler.process(geom1, **kw), compiler.process(geom2, **kw), - compiler.process(dist, **kw), - compiler.process(geom1, **kw), compiler.process(geom2, **kw), - compiler.process(dist, **kw)) From 8791c6cb69e344dac314078100c5c030b04e26f1 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Tue, 5 Dec 2023 21:20:57 +0100 Subject: [PATCH 06/17] correctly close API objects during testing --- nominatim/tools/convert_sqlite.py | 9 ++++++--- test/python/api/conftest.py | 16 +++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/nominatim/tools/convert_sqlite.py b/nominatim/tools/convert_sqlite.py index 0702e5d8c0..16f51b661a 100644 --- a/nominatim/tools/convert_sqlite.py +++ b/nominatim/tools/convert_sqlite.py @@ -29,9 +29,12 @@ async def convert(project_dir: Path, outfile: Path, options: Set[str]) -> None: outapi = napi.NominatimAPIAsync(project_dir, {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}"}) - async with api.begin() as src, outapi.begin() as dest: - writer = SqliteWriter(src, dest, options) - await writer.write() + try: + async with api.begin() as src, outapi.begin() as dest: + writer = SqliteWriter(src, dest, options) + await writer.write() + finally: + await outapi.close() finally: await api.close() diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index 8f0604d407..91a3107fbc 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -190,18 +190,24 @@ def apiobj(temp_db_with_extensions, temp_db_conn, monkeypatch): @pytest.fixture(params=['postgres_db', 'sqlite_db']) def frontend(request, event_loop, tmp_path): + testapis = [] if request.param == 'sqlite_db': db = str(tmp_path / 
'test_nominatim_python_unittest.sqlite') def mkapi(apiobj, options={'reverse'}): event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'), db, options)) - return napi.NominatimAPI(Path('/invalid'), - {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={db}", - 'NOMINATIM_USE_US_TIGER_DATA': 'yes', - 'NOMINATIM_API_POOL_SIZE': '0'}) + outapi = napi.NominatimAPI(Path('/invalid'), + {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={db}", + 'NOMINATIM_USE_US_TIGER_DATA': 'yes'}) + testapis.append(outapi) + + return outapi elif request.param == 'postgres_db': def mkapi(apiobj, options=None): return apiobj.api - return mkapi + yield mkapi + + for api in testapis: + api.close() From b06f5fddcbe9c716afddde1e6d02df6f43ec1081 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 10:37:06 +0100 Subject: [PATCH 07/17] simplify handling of SQL lookup code for search_name Use function classes which can be instantiated directly. --- nominatim/api/search/db_search_builder.py | 19 ++--- nominatim/api/search/db_search_fields.py | 29 ++++--- nominatim/api/search/db_search_lookups.py | 78 +++++++++++++++++++ nominatim/api/search/db_searches.py | 1 - .../api/search/test_db_search_builder.py | 16 ++-- test/python/api/search/test_search_near.py | 3 +- test/python/api/search/test_search_places.py | 55 ++++++------- 7 files changed, 139 insertions(+), 62 deletions(-) create mode 100644 nominatim/api/search/db_search_lookups.py diff --git a/nominatim/api/search/db_search_builder.py b/nominatim/api/search/db_search_builder.py index c755f2a74f..fd8cc7af90 100644 --- a/nominatim/api/search/db_search_builder.py +++ b/nominatim/api/search/db_search_builder.py @@ -15,6 +15,7 @@ from nominatim.api.search.token_assignment import TokenAssignment import nominatim.api.search.db_search_fields as dbf import nominatim.api.search.db_searches as dbs +import nominatim.api.search.db_search_lookups as lookups def wrap_near_search(categories: List[Tuple[str, str]], @@ -152,7 +153,7 @@ def build_special_search(self, sdata: dbf.SearchData, sdata.lookups = [dbf.FieldLookup('nameaddress_vector', [t.token for r in address for t in self.query.get_partials_list(r)], - 'restrict')] + lookups.Restrict)] penalty += 0.2 yield dbs.PostcodeSearch(penalty, sdata) @@ -162,7 +163,7 @@ def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token], """ Build a simple address search for special entries where the housenumber is the main name token. 
""" - sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')] + sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)] expected_count = sum(t.count for t in hnrs) partials = [t for trange in address @@ -170,16 +171,16 @@ def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token], if expected_count < 8000: sdata.lookups.append(dbf.FieldLookup('nameaddress_vector', - [t.token for t in partials], 'restrict')) + [t.token for t in partials], lookups.Restrict)) elif len(partials) != 1 or partials[0].count < 10000: sdata.lookups.append(dbf.FieldLookup('nameaddress_vector', - [t.token for t in partials], 'lookup_all')) + [t.token for t in partials], lookups.LookupAll)) else: sdata.lookups.append( dbf.FieldLookup('nameaddress_vector', [t.token for t in self.query.get_tokens(address[0], TokenType.WORD)], - 'lookup_any')) + lookups.LookupAny)) sdata.housenumbers = dbf.WeightedStrings([], []) yield dbs.PlaceSearch(0.05, sdata, expected_count) @@ -232,16 +233,16 @@ def yield_lookups(self, name: TokenRange, address: List[TokenRange])\ penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed) # Any of the full names applies with all of the partials from the address yield penalty, fulls_count / (2**len(addr_partials)),\ - dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens, - 'restrict' if fulls_count < 10000 else 'lookup_all') + dbf.lookup_by_any_name([t.token for t in name_fulls], + addr_tokens, fulls_count > 10000) # To catch remaining results, lookup by name and address # We only do this if there is a reasonable number of results expected. exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count if exp_count < 10000 and all(t.is_indexed for t in name_partials): - lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')] + lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)] if addr_tokens: - lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')) + lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)) penalty += 0.35 * max(0, 5 - len(name_partials) - len(addr_tokens)) yield penalty, exp_count, lookup diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 324a7acc2c..6947a565f8 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -7,15 +7,17 @@ """ Data structures for more complex fields in abstract search descriptions. """ -from typing import List, Tuple, Iterator, cast, Dict +from typing import List, Tuple, Iterator, Dict, Type import dataclasses import sqlalchemy as sa from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token +import nominatim.api.search.db_search_lookups as lookups from nominatim.utils.json_writer import JsonWriter + @dataclasses.dataclass class WeightedStrings: """ A list of strings together with a penalty. @@ -152,18 +154,12 @@ class FieldLookup: """ column: str tokens: List[int] - lookup_type: str + lookup_type: Type[lookups.LookupType] def sql_condition(self, table: SaFromClause) -> SaColumn: """ Create an SQL expression for the given match condition. 
""" - col = table.c[self.column] - if self.lookup_type == 'lookup_all': - return col.contains(self.tokens) - if self.lookup_type == 'lookup_any': - return cast(SaColumn, col.overlaps(self.tokens)) - - return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable + return self.lookup_type(table, self.column, self.tokens) class SearchData: @@ -229,22 +225,23 @@ def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[Fiel """ Create a lookup list where name tokens are looked up via index and potential address tokens are used to restrict the search further. """ - lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')] + lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)] if addr_tokens: - lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict')) + lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict)) return lookup def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int], - lookup_type: str) -> List[FieldLookup]: + use_index_for_addr: bool) -> List[FieldLookup]: """ Create a lookup list where name tokens are looked up via index and only one of the name tokens must be present. Potential address tokens are used to restrict the search further. """ - lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')] + lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)] if addr_tokens: - lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type)) + lookup.append(FieldLookup('nameaddress_vector', addr_tokens, + lookups.LookupAll if use_index_for_addr else lookups.Restrict)) return lookup @@ -253,5 +250,5 @@ def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[Field """ Create a lookup list where address tokens are looked up via index and the name tokens are only used to restrict the search further. """ - return [FieldLookup('name_vector', name_tokens, 'restrict'), - FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')] + return [FieldLookup('name_vector', name_tokens, lookups.Restrict), + FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)] diff --git a/nominatim/api/search/db_search_lookups.py b/nominatim/api/search/db_search_lookups.py new file mode 100644 index 0000000000..3e307235b8 --- /dev/null +++ b/nominatim/api/search/db_search_lookups.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Implementation of lookup functions for the search_name table. +""" +from typing import List, Any + +import sqlalchemy as sa +from sqlalchemy.ext.compiler import compiles + +from nominatim.typing import SaFromClause +from nominatim.db.sqlalchemy_types import IntArray + +# pylint: disable=consider-using-f-string + +LookupType = sa.sql.expression.FunctionElement[Any] + +class LookupAll(LookupType): + """ Find all entries in search_name table that contain all of + a given list of tokens using an index for the search. 
+ """ + inherit_cache = True + + def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: + super().__init__(getattr(table.c, column), + sa.type_coerce(tokens, IntArray)) + + +@compiles(LookupAll) # type: ignore[no-untyped-call, misc] +def _default_lookup_all(element: LookupAll, + compiler: 'sa.Compiled', **kw: Any) -> str: + col, tokens = list(element.clauses) + return "(%s @> %s)" % (compiler.process(col, **kw), + compiler.process(tokens, **kw)) + + + +class LookupAny(LookupType): + """ Find all entries that contain at least one of the given tokens. + Use an index for the search. + """ + inherit_cache = True + + def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: + super().__init__(getattr(table.c, column), + sa.type_coerce(tokens, IntArray)) + + +@compiles(LookupAny) # type: ignore[no-untyped-call, misc] +def _default_lookup_any(element: LookupAny, + compiler: 'sa.Compiled', **kw: Any) -> str: + col, tokens = list(element.clauses) + return "(%s && %s)" % (compiler.process(col, **kw), + compiler.process(tokens, **kw)) + + + +class Restrict(LookupType): + """ Find all entries that contain all of the given tokens. + Do not use an index for the search. + """ + inherit_cache = True + + def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: + super().__init__(getattr(table.c, column), + sa.type_coerce(tokens, IntArray)) + + +@compiles(Restrict) # type: ignore[no-untyped-call, misc] +def _default_restrict(element: Restrict, + compiler: 'sa.Compiled', **kw: Any) -> str: + arg1, arg2 = list(element.clauses) + return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw), + compiler.process(arg2, **kw)) diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index 48bd6272c8..35c1274659 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -563,7 +563,6 @@ async def lookup(self, conn: SearchConnection, if self.lookups: assert len(self.lookups) == 1 - assert self.lookups[0].lookup_type == 'restrict' tsearch = conn.t.search_name sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector) diff --git a/test/python/api/search/test_db_search_builder.py b/test/python/api/search/test_db_search_builder.py index 87d7526152..d3aea90002 100644 --- a/test/python/api/search/test_db_search_builder.py +++ b/test/python/api/search/test_db_search_builder.py @@ -420,8 +420,8 @@ def test_infrequent_partials_in_name(): assert len(search.lookups) == 2 assert len(search.rankings) == 2 - assert set((l.column, l.lookup_type) for l in search.lookups) == \ - {('name_vector', 'lookup_all'), ('nameaddress_vector', 'restrict')} + assert set((l.column, l.lookup_type.__name__) for l in search.lookups) == \ + {('name_vector', 'LookupAll'), ('nameaddress_vector', 'Restrict')} def test_frequent_partials_in_name_and_address(): @@ -432,10 +432,10 @@ def test_frequent_partials_in_name_and_address(): assert all(isinstance(s, dbs.PlaceSearch) for s in searches) searches.sort(key=lambda s: s.penalty) - assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \ - {('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')} - assert set((l.column, l.lookup_type) for l in searches[1].lookups) == \ - {('nameaddress_vector', 'lookup_all'), ('name_vector', 'lookup_all')} + assert set((l.column, l.lookup_type.__name__) for l in searches[0].lookups) == \ + {('name_vector', 'LookupAny'), ('nameaddress_vector', 
'Restrict')} + assert set((l.column, l.lookup_type.__name__) for l in searches[1].lookups) == \ + {('nameaddress_vector', 'LookupAll'), ('name_vector', 'LookupAll')} def test_too_frequent_partials_in_name_and_address(): @@ -446,5 +446,5 @@ def test_too_frequent_partials_in_name_and_address(): assert all(isinstance(s, dbs.PlaceSearch) for s in searches) searches.sort(key=lambda s: s.penalty) - assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \ - {('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')} + assert set((l.column, l.lookup_type.__name__) for l in searches[0].lookups) == \ + {('name_vector', 'LookupAny'), ('nameaddress_vector', 'Restrict')} diff --git a/test/python/api/search/test_search_near.py b/test/python/api/search/test_search_near.py index 2a0acb7459..c0caa9ae6a 100644 --- a/test/python/api/search/test_search_near.py +++ b/test/python/api/search/test_search_near.py @@ -14,6 +14,7 @@ from nominatim.api.search.db_searches import NearSearch, PlaceSearch from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ FieldLookup, FieldRanking, RankedTokens +from nominatim.api.search.db_search_lookups import LookupAll def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[], @@ -25,7 +26,7 @@ class PlaceSearchData: countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) housenumbers = WeightedStrings([], []) qualifiers = WeightedStrings([], []) - lookups = [FieldLookup('name_vector', [56], 'lookup_all')] + lookups = [FieldLookup('name_vector', [56], LookupAll)] rankings = [] if ccodes is not None: diff --git a/test/python/api/search/test_search_places.py b/test/python/api/search/test_search_places.py index 8a363e9773..44e4098dad 100644 --- a/test/python/api/search/test_search_places.py +++ b/test/python/api/search/test_search_places.py @@ -16,6 +16,7 @@ from nominatim.api.search.db_searches import PlaceSearch from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ FieldLookup, FieldRanking, RankedTokens +from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict def run_search(apiobj, global_penalty, lookup, ranking, count=2, hnrs=[], pcs=[], ccodes=[], quals=[], @@ -55,7 +56,7 @@ def fill_database(self, apiobj): centroid=(-10.3, 56.9)) - @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) @pytest.mark.parametrize('rank,res', [([10], [100, 101]), ([20], [101, 100])]) def test_lookup_all_match(self, apiobj, lookup_type, rank, res): @@ -67,7 +68,7 @@ def test_lookup_all_match(self, apiobj, lookup_type, rank, res): assert [r.place_id for r in results] == res - @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) def test_lookup_all_partial_match(self, apiobj, lookup_type): lookup = FieldLookup('name_vector', [1,20], lookup_type) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) @@ -80,7 +81,7 @@ def test_lookup_all_partial_match(self, apiobj, lookup_type): @pytest.mark.parametrize('rank,res', [([10], [100, 101]), ([20], [101, 100])]) def test_lookup_any_match(self, apiobj, rank, res): - lookup = FieldLookup('name_vector', [11,21], 'lookup_any') + lookup = FieldLookup('name_vector', [11,21], LookupAny) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) results = run_search(apiobj, 0.1, [lookup], [ranking]) @@ -89,7 +90,7 @@ def test_lookup_any_match(self, apiobj, 
rank, res): def test_lookup_any_partial_match(self, apiobj): - lookup = FieldLookup('name_vector', [20], 'lookup_all') + lookup = FieldLookup('name_vector', [20], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) results = run_search(apiobj, 0.1, [lookup], [ranking]) @@ -100,7 +101,7 @@ def test_lookup_any_partial_match(self, apiobj): @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)]) def test_lookup_restrict_country(self, apiobj, cc, res): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc]) @@ -109,7 +110,7 @@ def test_lookup_restrict_country(self, apiobj, cc, res): def test_lookup_restrict_placeid(self, apiobj): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], @@ -123,7 +124,7 @@ def test_lookup_restrict_placeid(self, apiobj): napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) def test_return_geometries(self, apiobj, geom): - lookup = FieldLookup('name_vector', [20], 'lookup_all') + lookup = FieldLookup('name_vector', [20], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) results = run_search(apiobj, 0.1, [lookup], [ranking], @@ -140,7 +141,7 @@ def test_return_simplified_geometry(self, apiobj, factor, npoints): apiobj.add_search_name(333, names=[55], country_code='us', centroid=(5.6, 4.3)) - lookup = FieldLookup('name_vector', [55], 'lookup_all') + lookup = FieldLookup('name_vector', [55], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) results = run_search(apiobj, 0.1, [lookup], [ranking], @@ -158,7 +159,7 @@ def test_return_simplified_geometry(self, apiobj, factor, npoints): @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0']) @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])]) def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids): - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])]) results = run_search(apiobj, 0.1, [lookup], [ranking]) @@ -171,7 +172,7 @@ def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids): @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31']) def test_force_viewbox(self, apiobj, viewbox): - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) details=SearchDetails.from_kwargs({'viewbox': viewbox, 'bounded_viewbox': True}) @@ -181,7 +182,7 @@ def test_force_viewbox(self, apiobj, viewbox): def test_prefer_near(self, apiobj): - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) results = run_search(apiobj, 0.1, [lookup], [ranking]) @@ -195,7 +196,7 @@ def test_prefer_near(self, apiobj): @pytest.mark.parametrize('radius', [0.09, 0.11]) def test_force_near(self, apiobj, radius): - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) details=SearchDetails.from_kwargs({'near': '5.6,4.3', 'near_radius': 
radius}) @@ -242,7 +243,7 @@ def fill_database(self, apiobj): ('21', [2]), ('22', [2, 92]), ('24', [93]), ('25', [])]) def test_lookup_by_single_housenumber(self, apiobj, hnr, res): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr]) @@ -252,7 +253,7 @@ def test_lookup_by_single_housenumber(self, apiobj, hnr, res): @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])]) def test_lookup_with_country_restriction(self, apiobj, cc, res): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -262,7 +263,7 @@ def test_lookup_with_country_restriction(self, apiobj, cc, res): def test_lookup_exclude_housenumber_placeid(self, apiobj): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -272,7 +273,7 @@ def test_lookup_exclude_housenumber_placeid(self, apiobj): def test_lookup_exclude_street_placeid(self, apiobj): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -282,7 +283,7 @@ def test_lookup_exclude_street_placeid(self, apiobj): def test_lookup_only_house_qualifier(self, apiobj): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -292,7 +293,7 @@ def test_lookup_only_house_qualifier(self, apiobj): def test_lookup_only_street_qualifier(self, apiobj): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -303,7 +304,7 @@ def test_lookup_only_street_qualifier(self, apiobj): @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)]) def test_lookup_min_rank(self, apiobj, rank, found): - lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], @@ -317,7 +318,7 @@ def test_lookup_min_rank(self, apiobj, rank, found): napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) def test_return_geometries(self, apiobj, geom): - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'], details=SearchDetails(geometry_output=geom)) @@ -337,7 +338,7 @@ def test_very_large_housenumber(apiobj): search_rank=26, address_rank=26, country_code='pt') - lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + lookup = FieldLookup('name_vector', [1, 2], LookupAll) results 
= run_search(apiobj, 0.1, [lookup], [],
                         hnrs=['2467463524544'],
                         details=SearchDetails())

@@ -365,7 +366,7 @@ def test_name_and_postcode(apiobj, wcount, rids):
     apiobj.add_postcode(place_id=100, country_code='ch',
                         postcode='11225', geometry='POINT(10 10)')

-    lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    lookup = FieldLookup('name_vector', [111], LookupAll)

     results = run_search(apiobj, 0.1, [lookup], [], pcs=['11225'], count=wcount,
                          details=SearchDetails())

@@ -398,7 +399,7 @@ def fill_database(self, apiobj):
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
     def test_lookup_housenumber(self, apiobj, hnr, res):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+        lookup = FieldLookup('name_vector', [111], LookupAll)

         results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])

@@ -410,7 +411,7 @@ def test_lookup_housenumber(self, apiobj, hnr, res):
                                               napi.GeometryFormat.SVG,
                                               napi.GeometryFormat.TEXT])
     def test_osmline_with_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+        lookup = FieldLookup('name_vector', [111], LookupAll)

         results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))

@@ -446,7 +447,7 @@ def fill_database(self, apiobj):
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
     def test_lookup_housenumber(self, apiobj, hnr, res):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+        lookup = FieldLookup('name_vector', [111], LookupAll)

         results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])

@@ -458,7 +459,7 @@ def test_lookup_housenumber(self, apiobj, hnr, res):
                                               napi.GeometryFormat.SVG,
                                               napi.GeometryFormat.TEXT])
     def test_tiger_with_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+        lookup = FieldLookup('name_vector', [111], LookupAll)

         results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))

@@ -513,7 +514,7 @@ def fill_database(self, apiobj):
                                           (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
                                           (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
     def test_layers_rank30(self, apiobj, layer, res):
-        lookup = FieldLookup('name_vector', [34], 'lookup_any')
+        lookup = FieldLookup('name_vector', [34], LookupAny)

         results = run_search(apiobj, 0.1, [lookup], [],
                              details=SearchDetails(layers=layer))

From b6c8c0e72bdd76a8214fc360cfeaf9188c17d8b1 Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Wed, 6 Dec 2023 10:55:21 +0100
Subject: [PATCH 08/17] factor out SQL for filtering by location

Also improves the decision on whether an index is used or not.
---
 nominatim/api/search/db_searches.py | 47 +++++++++++++++++------------
 1 file changed, 27 insertions(+), 20 deletions(-)

diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py
index 35c1274659..a5461d6ad1 100644
--- a/nominatim/api/search/db_searches.py
+++ b/nominatim/api/search/db_searches.py
@@ -55,12 +55,29 @@ def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
 NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
 COUNTRIES_PARAM: SaBind = sa.bindparam('countries')

-def _within_near(t: SaFromClause) -> Callable[[], SaExpression]:
-    return lambda: t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)
+
+def filter_by_area(sql: SaSelect, t: SaFromClause,
+                   details: SearchDetails, avoid_index: bool = False) -> SaSelect:
+    """ Apply SQL statements for filtering by viewbox and near point,
+        if applicable.
+ """ + if details.near is not None and details.near_radius is not None: + if details.near_radius < 0.1 and not avoid_index: + sql = sql.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) + else: + sql = sql.where(t.c.geometry.ST_Distance(NEAR_PARAM) <= NEAR_RADIUS_PARAM) + if details.viewbox is not None and details.bounded_viewbox: + sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM, + use_index=not avoid_index and + details.viewbox.area < 0.2)) + + return sql + def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]: return lambda: t.c.place_id.not_in(sa.bindparam('excluded')) + def _select_placex(t: SaFromClause) -> SaSelect: return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, t.c.class_, t.c.type, @@ -449,11 +466,7 @@ async def lookup(self, conn: SearchConnection, if details.excluded: sql = sql.where(_exclude_places(t)) - if details.viewbox is not None and details.bounded_viewbox: - sql = sql.where(lambda: t.c.geometry.intersects(VIEWBOX_PARAM)) - - if details.near is not None and details.near_radius is not None: - sql = sql.where(_within_near(t)) + sql = filter_by_area(sql, t, details) results = nres.SearchResults() for row in await conn.execute(sql, _details_to_bind_params(details)): @@ -486,10 +499,7 @@ async def lookup_in_country_table(self, conn: SearchConnection, .where(tgrid.c.country_code.in_(self.countries.values))\ .group_by(tgrid.c.country_code) - if details.viewbox is not None and details.bounded_viewbox: - sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM)) - if details.near is not None and details.near_radius is not None: - sql = sql.where(_within_near(tgrid)) + sql = filter_by_area(sql, tgrid, details, avoid_index=True) sub = sql.subquery('grid') @@ -542,19 +552,16 @@ async def lookup(self, conn: SearchConnection, penalty: SaExpression = sa.literal(self.penalty) - if details.viewbox is not None: - if details.bounded_viewbox: - sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM)) - else: - penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0), - (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5), - else_=1.0) + if details.viewbox is not None and not details.bounded_viewbox: + penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0), + (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5), + else_=1.0) if details.near is not None: - if details.near_radius is not None: - sql = sql.where(_within_near(t)) sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM)) + sql = filter_by_area(sql, t, details) + if self.countries: sql = sql.where(t.c.country_code.in_(self.countries.values)) From df6eddebcd74982092477357ecf4a457a9acf561 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 11:03:12 +0100 Subject: [PATCH 09/17] void unnecessary aliases --- nominatim/api/search/db_searches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index a5461d6ad1..c56554fdc0 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -312,7 +312,7 @@ async def lookup_category(self, results: nres.SearchResults, if table is None: # No classtype table available, do a simplified lookup in placex. 
- table = conn.t.placex.alias('inner') + table = conn.t.placex sql = sa.select(table.c.place_id, sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid)) .label('dist'))\ From b5c61e0b5b4e9f955d55c2368cdc904f9390b288 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 11:13:12 +0100 Subject: [PATCH 10/17] improve typing for @compiles constructs The first parameter is in fact the self parameter referring to the function class. --- nominatim/db/sqlalchemy_functions.py | 26 +++++++++++------------ nominatim/db/sqlalchemy_types/geometry.py | 24 ++++++++++----------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/nominatim/db/sqlalchemy_functions.py b/nominatim/db/sqlalchemy_functions.py index cb04f7626f..8b1967381c 100644 --- a/nominatim/db/sqlalchemy_functions.py +++ b/nominatim/db/sqlalchemy_functions.py @@ -29,7 +29,7 @@ class PlacexGeometryReverseLookuppolygon(sa.sql.functions.GenericFunction[Any]): @compiles(PlacexGeometryReverseLookuppolygon) # type: ignore[no-untyped-call, misc] -def _default_intersects(element: SaColumn, +def _default_intersects(element: PlacexGeometryReverseLookuppolygon, compiler: 'sa.Compiled', **kw: Any) -> str: return ("(ST_GeometryType(placex.geometry) in ('ST_Polygon', 'ST_MultiPolygon')" " AND placex.rank_address between 4 and 25" @@ -40,7 +40,7 @@ def _default_intersects(element: SaColumn, @compiles(PlacexGeometryReverseLookuppolygon, 'sqlite') # type: ignore[no-untyped-call, misc] -def _sqlite_intersects(element: SaColumn, +def _sqlite_intersects(element: PlacexGeometryReverseLookuppolygon, compiler: 'sa.Compiled', **kw: Any) -> str: return ("(ST_GeometryType(placex.geometry) in ('POLYGON', 'MULTIPOLYGON')" " AND placex.rank_address between 4 and 25" @@ -61,7 +61,7 @@ def __init__(self, table: sa.Table, geom: SaColumn) -> None: @compiles(IntersectsReverseDistance) # type: ignore[no-untyped-call, misc] -def default_reverse_place_diameter(element: SaColumn, +def default_reverse_place_diameter(element: IntersectsReverseDistance, compiler: 'sa.Compiled', **kw: Any) -> str: table = element.tablename return f"({table}.rank_address between 4 and 25"\ @@ -74,7 +74,7 @@ def default_reverse_place_diameter(element: SaColumn, @compiles(IntersectsReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc] -def sqlite_reverse_place_diameter(element: SaColumn, +def sqlite_reverse_place_diameter(element: IntersectsReverseDistance, compiler: 'sa.Compiled', **kw: Any) -> str: geom1, rank, geom2 = list(element.clauses) table = element.tablename @@ -102,7 +102,7 @@ class IsBelowReverseDistance(sa.sql.functions.GenericFunction[Any]): @compiles(IsBelowReverseDistance) # type: ignore[no-untyped-call, misc] -def default_is_below_reverse_distance(element: SaColumn, +def default_is_below_reverse_distance(element: IsBelowReverseDistance, compiler: 'sa.Compiled', **kw: Any) -> str: dist, rank = list(element.clauses) return "%s < reverse_place_diameter(%s)" % (compiler.process(dist, **kw), @@ -110,7 +110,7 @@ def default_is_below_reverse_distance(element: SaColumn, @compiles(IsBelowReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc] -def sqlite_is_below_reverse_distance(element: SaColumn, +def sqlite_is_below_reverse_distance(element: IsBelowReverseDistance, compiler: 'sa.Compiled', **kw: Any) -> str: dist, rank = list(element.clauses) return "%s < 14.0 * exp(-0.2 * %s) - 0.03" % (compiler.process(dist, **kw), @@ -139,7 +139,7 @@ def __init__(self, table: sa.Table) -> None: @compiles(IsAddressPoint) # type: ignore[no-untyped-call, 
misc] -def default_is_address_point(element: SaColumn, +def default_is_address_point(element: IsAddressPoint, compiler: 'sa.Compiled', **kw: Any) -> str: rank, hnr, name = list(element.clauses) return "(%s = 30 AND (%s IS NOT NULL OR %s ? 'addr:housename'))" % ( @@ -149,7 +149,7 @@ def default_is_address_point(element: SaColumn, @compiles(IsAddressPoint, 'sqlite') # type: ignore[no-untyped-call, misc] -def sqlite_is_address_point(element: SaColumn, +def sqlite_is_address_point(element: IsAddressPoint, compiler: 'sa.Compiled', **kw: Any) -> str: rank, hnr, name = list(element.clauses) return "(%s = 30 AND coalesce(%s, json_extract(%s, '$.addr:housename')) IS NOT NULL)" % ( @@ -166,7 +166,7 @@ class CrosscheckNames(sa.sql.functions.GenericFunction[Any]): inherit_cache = True @compiles(CrosscheckNames) # type: ignore[no-untyped-call, misc] -def compile_crosscheck_names(element: SaColumn, +def compile_crosscheck_names(element: CrosscheckNames, compiler: 'sa.Compiled', **kw: Any) -> str: arg1, arg2 = list(element.clauses) return "coalesce(avals(%s) && ARRAY(SELECT * FROM json_array_elements_text(%s)), false)" % ( @@ -174,7 +174,7 @@ def compile_crosscheck_names(element: SaColumn, @compiles(CrosscheckNames, 'sqlite') # type: ignore[no-untyped-call, misc] -def compile_sqlite_crosscheck_names(element: SaColumn, +def compile_sqlite_crosscheck_names(element: CrosscheckNames, compiler: 'sa.Compiled', **kw: Any) -> str: arg1, arg2 = list(element.clauses) return "EXISTS(SELECT *"\ @@ -191,12 +191,12 @@ class JsonArrayEach(sa.sql.functions.GenericFunction[Any]): @compiles(JsonArrayEach) # type: ignore[no-untyped-call, misc] -def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str: +def default_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str: return "json_array_elements(%s)" % compiler.process(element.clauses, **kw) @compiles(JsonArrayEach, 'sqlite') # type: ignore[no-untyped-call, misc] -def sqlite_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str: +def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str: return "json_each(%s)" % compiler.process(element.clauses, **kw) @@ -208,5 +208,5 @@ class Greatest(sa.sql.functions.GenericFunction[Any]): @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc] -def sqlite_greatest(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str: +def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str: return "max(%s)" % compiler.process(element.clauses, **kw) diff --git a/nominatim/db/sqlalchemy_types/geometry.py b/nominatim/db/sqlalchemy_types/geometry.py index 4520fc8e53..0731b0b796 100644 --- a/nominatim/db/sqlalchemy_types/geometry.py +++ b/nominatim/db/sqlalchemy_types/geometry.py @@ -28,7 +28,7 @@ class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]): @compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc] -def _default_distance_spheroid(element: SaColumn, +def _default_distance_spheroid(element: Geometry_DistanceSpheroid, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_DistanceSpheroid(%s,"\ " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\ @@ -36,7 +36,7 @@ def _default_distance_spheroid(element: SaColumn, @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc] -def _spatialite_distance_spheroid(element: SaColumn, +def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid, 
compiler: 'sa.Compiled', **kw: Any) -> str: return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw) @@ -49,14 +49,14 @@ class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]): @compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc] -def _default_is_line_like(element: SaColumn, +def _default_is_line_like(element: Geometry_IsLineLike, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \ compiler.process(element.clauses, **kw) @compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc] -def _sqlite_is_line_like(element: SaColumn, +def _sqlite_is_line_like(element: Geometry_IsLineLike, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \ compiler.process(element.clauses, **kw) @@ -70,14 +70,14 @@ class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]): @compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc] -def _default_is_area_like(element: SaColumn, +def _default_is_area_like(element: Geometry_IsAreaLike, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \ compiler.process(element.clauses, **kw) @compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc] -def _sqlite_is_area_like(element: SaColumn, +def _sqlite_is_area_like(element: Geometry_IsAreaLike, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \ compiler.process(element.clauses, **kw) @@ -91,14 +91,14 @@ class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]): @compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc] -def _default_intersects(element: SaColumn, +def _default_intersects(element: Geometry_IntersectsBbox, compiler: 'sa.Compiled', **kw: Any) -> str: arg1, arg2 = list(element.clauses) return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) @compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc] -def _sqlite_intersects(element: SaColumn, +def _sqlite_intersects(element: Geometry_IntersectsBbox, compiler: 'sa.Compiled', **kw: Any) -> str: return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw) @@ -114,14 +114,14 @@ class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]): @compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc] -def default_intersects_column(element: SaColumn, +def default_intersects_column(element: Geometry_ColumnIntersectsBbox, compiler: 'sa.Compiled', **kw: Any) -> str: arg1, arg2 = list(element.clauses) return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) @compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc] -def spatialite_intersects_column(element: SaColumn, +def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox, compiler: 'sa.Compiled', **kw: Any) -> str: arg1, arg2 = list(element.clauses) return "MbrIntersects(%s, %s) = 1 and "\ @@ -145,12 +145,12 @@ class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]): @compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc] -def default_dwithin_column(element: SaColumn, +def default_dwithin_column(element: Geometry_ColumnDWithin, compiler: 'sa.Compiled', **kw: Any) -> str: return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw) 
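
Before the SpatiaLite variant below, note that every construct in this file follows the same two-step pattern: declare a typed function element, then register per-dialect renderings via @compiles. A minimal, self-contained sketch of the pattern this commit re-types (the Least function here is made up for illustration and is not part of the patch):

from typing import Any

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles


class Least(sa.sql.functions.GenericFunction[Any]):
    """ Hypothetical function computing the minimum of its parameters. """
    name = 'least'
    inherit_cache = True


@compiles(Least, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_least(element: Least, compiler: 'sa.Compiled', **kw: Any) -> str:
    # 'element' is the Least instance itself, not a column, which is
    # exactly why this patch annotates it with the concrete class.
    return "min(%s)" % compiler.process(element.clauses, **kw)

Without a dialect override, GenericFunction renders under its 'name', so PostgreSQL emits least(...) unchanged; SQLite's scalar min() accepts multiple arguments and serves the same purpose.
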
@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc] -def spatialite_dwithin_column(element: SaColumn, +def spatialite_dwithin_column(element: Geometry_ColumnDWithin, compiler: 'sa.Compiled', **kw: Any) -> str: geom1, geom2, dist = list(element.clauses) return "ST_Distance(%s, %s) < %s and "\ From 381bd0b5768113f87dc105afaeeaced1fda96069 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 11:14:36 +0100 Subject: [PATCH 11/17] remove unused function --- nominatim/db/sqlalchemy_functions.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/nominatim/db/sqlalchemy_functions.py b/nominatim/db/sqlalchemy_functions.py index 8b1967381c..5872401cca 100644 --- a/nominatim/db/sqlalchemy_functions.py +++ b/nominatim/db/sqlalchemy_functions.py @@ -117,18 +117,6 @@ def sqlite_is_below_reverse_distance(element: IsBelowReverseDistance, compiler.process(rank, **kw)) -def select_index_placex_geometry_reverse_lookupplacenode(table: str) -> 'sa.TextClause': - """ Create an expression with the necessary conditions over a placex - table that the index 'idx_placex_geometry_reverse_lookupPlaceNode' - can be used. - """ - return sa.text(f"{table}.rank_address between 4 and 25" - f" AND {table}.type != 'postcode'" - f" AND {table}.name is not null" - f" AND {table}.linked_place_id is null" - f" AND {table}.osm_type = 'N'") - - class IsAddressPoint(sa.sql.functions.GenericFunction[Any]): name = 'IsAddressPoint' inherit_cache = True From 0d840c8d4ee6ea233ed32350e1d633402c80e46a Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 13:42:58 +0100 Subject: [PATCH 12/17] extend sqlite converter for search tables --- nominatim/db/sqlalchemy_types/int_array.py | 14 ++++ nominatim/tools/convert_sqlite.py | 97 +++++++++++++++++++++- 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/nominatim/db/sqlalchemy_types/int_array.py b/nominatim/db/sqlalchemy_types/int_array.py index 335d554197..499376cb85 100644 --- a/nominatim/db/sqlalchemy_types/int_array.py +++ b/nominatim/db/sqlalchemy_types/int_array.py @@ -10,6 +10,7 @@ from typing import Any, List, cast, Optional import sqlalchemy as sa +from sqlalchemy.ext.compiler import compiles from sqlalchemy.dialects.postgresql import ARRAY from nominatim.typing import SaDialect, SaColumn @@ -71,3 +72,16 @@ def overlaps(self, other: SaColumn) -> 'sa.Operators': in the array. """ return self.op('&&', is_comparison=True)(other) + + +class ArrayAgg(sa.sql.functions.GenericFunction[Any]): + """ Aggregate function to collect elements in an array. 
+ """ + type = IntArray() + identifier = 'ArrayAgg' + name = 'array_agg' + inherit_cache = True + +@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc] +def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str: + return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw) diff --git a/nominatim/tools/convert_sqlite.py b/nominatim/tools/convert_sqlite.py index 16f51b661a..d9e39ba374 100644 --- a/nominatim/tools/convert_sqlite.py +++ b/nominatim/tools/convert_sqlite.py @@ -14,7 +14,8 @@ import sqlalchemy as sa from nominatim.typing import SaSelect -from nominatim.db.sqlalchemy_types import Geometry +from nominatim.db.sqlalchemy_types import Geometry, IntArray +from nominatim.api.search.query_analyzer_factory import make_query_analyzer import nominatim.api as napi LOG = logging.getLogger() @@ -54,18 +55,24 @@ async def write(self) -> None: """ Create the database structure and copy the data from the source database to the destination. """ + LOG.warning('Setting up spatialite') await self.dest.execute(sa.select(sa.func.InitSpatialMetaData(True, 'WGS84'))) await self.create_tables() await self.copy_data() + if 'search' in self.options: + await self.create_word_table() await self.create_indexes() async def create_tables(self) -> None: """ Set up the database tables. """ + LOG.warning('Setting up tables') if 'search' not in self.options: self.dest.t.meta.remove(self.dest.t.search_name) + else: + await self.create_class_tables() await self.dest.connection.run_sync(self.dest.t.meta.create_all) @@ -78,6 +85,41 @@ async def create_tables(self) -> None: col.type.subtype.upper(), 'XY'))) + async def create_class_tables(self) -> None: + """ Set up the table that serve class/type-specific geometries. + """ + sql = sa.text("""SELECT tablename FROM pg_tables + WHERE tablename LIKE 'place_classtype_%'""") + for res in await self.src.execute(sql): + for db in (self.src, self.dest): + sa.Table(res[0], db.t.meta, + sa.Column('place_id', sa.BigInteger), + sa.Column('centroid', Geometry)) + + + async def create_word_table(self) -> None: + """ Create the word table. + This table needs the property information to determine the + correct format. Therefore needs to be done after all other + data has been copied. + """ + await make_query_analyzer(self.src) + await make_query_analyzer(self.dest) + src = self.src.t.meta.tables['word'] + dest = self.dest.t.meta.tables['word'] + + await self.dest.connection.run_sync(dest.create) + + LOG.warning("Copying word table") + async_result = await self.src.connection.stream(sa.select(src)) + + async for partition in async_result.partitions(10000): + data = [{k: getattr(r, k) for k in r._fields} for r in partition] + await self.dest.execute(dest.insert(), data) + + await self.dest.connection.run_sync(sa.Index('idx_word_woken', dest.c.word_token).create) + + async def copy_data(self) -> None: """ Copy data for all registered tables. """ @@ -90,6 +132,14 @@ async def copy_data(self) -> None: for r in partition] await self.dest.execute(table.insert(), data) + # Set up a minimal copy of pg_tables used to look up the class tables later. 
+ pg_tables = sa.Table('pg_tables', self.dest.t.meta, + sa.Column('schemaname', sa.Text, default='public'), + sa.Column('tablename', sa.Text)) + await self.dest.connection.run_sync(pg_tables.create) + data = [{'tablename': t} for t in self.dest.t.meta.tables] + await self.dest.execute(pg_tables.insert().values(data)) + async def create_indexes(self) -> None: """ Add indexes necessary for the frontend. @@ -119,6 +169,22 @@ async def create_indexes(self) -> None: await self.create_index('placex', 'parent_place_id') await self.create_index('placex', 'rank_address') await self.create_index('addressline', 'place_id') + await self.create_index('postcode', 'place_id') + await self.create_index('osmline', 'place_id') + await self.create_index('tiger', 'place_id') + + if 'search' in self.options: + await self.create_spatial_index('postcode', 'geometry') + await self.create_spatial_index('search_name', 'centroid') + await self.create_index('search_name', 'place_id') + await self.create_index('osmline', 'parent_place_id') + await self.create_index('tiger', 'parent_place_id') + await self.create_search_index() + + for t in self.dest.t.meta.tables: + if t.startswith('place_classtype_'): + await self.dest.execute(sa.select( + sa.func.CreateSpatialIndex(t, 'centroid'))) async def create_spatial_index(self, table: str, column: str) -> None: @@ -136,6 +202,35 @@ async def create_index(self, table_name: str, column: str) -> None: sa.Index(f"idx_{table}_{column}", getattr(table.c, column)).create) + async def create_search_index(self) -> None: + """ Create the tables and indexes needed for word lookup. + """ + tsrc = self.src.t.search_name + for column in ('name_vector', 'nameaddress_vector'): + table_name = f'reverse_search_{column}' + LOG.warning("Creating reverse search %s", table_name) + rsn = sa.Table(table_name, self.dest.t.meta, + sa.Column('word', sa.Integer()), + sa.Column('places', IntArray)) + await self.dest.connection.run_sync(rsn.create) + + sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'), + sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\ + .group_by('word') + + async_result = await self.src.connection.stream(sql) + async for partition in async_result.partitions(100): + data = [] + for row in partition: + row.places.sort() + data.append({'word': row.word, + 'places': row.places}) + await self.dest.execute(rsn.insert(), data) + + await self.dest.connection.run_sync( + sa.Index(f'idx_reverse_search_{column}_word', rsn.c.word).create) + + def select_from(self, table: str) -> SaSelect: """ Create the SQL statement to select the source columns and rows. 
""" From 6d39563b872b21825a61949e88bc47a0e88c7573 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 20:56:21 +0100 Subject: [PATCH 13/17] enable all API tests for sqlite and port missing features --- nominatim/api/core.py | 2 + nominatim/api/search/db_search_lookups.py | 46 ++++++- nominatim/api/search/db_searches.py | 36 ++--- nominatim/db/sqlalchemy_functions.py | 21 +++ nominatim/db/sqlalchemy_types/int_array.py | 52 +++++-- nominatim/db/sqlalchemy_types/key_value.py | 27 +++- nominatim/db/sqlite_functions.py | 122 +++++++++++++++++ nominatim/tools/convert_sqlite.py | 19 +-- test/python/api/conftest.py | 33 +++++ test/python/api/search/test_search_country.py | 39 +++--- test/python/api/search/test_search_near.py | 42 +++--- test/python/api/search/test_search_places.py | 127 ++++++++++-------- test/python/api/search/test_search_poi.py | 28 ++-- .../python/api/search/test_search_postcode.py | 48 +++---- test/python/api/test_api_search.py | 102 +++++++------- 15 files changed, 514 insertions(+), 230 deletions(-) create mode 100644 nominatim/db/sqlite_functions.py diff --git a/nominatim/api/core.py b/nominatim/api/core.py index b262422758..f975f44aae 100644 --- a/nominatim/api/core.py +++ b/nominatim/api/core.py @@ -19,6 +19,7 @@ from nominatim.errors import UsageError from nominatim.db.sqlalchemy_schema import SearchTables from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR +import nominatim.db.sqlite_functions from nominatim.config import Configuration from nominatim.api.connection import SearchConnection from nominatim.api.status import get_status, StatusResult @@ -122,6 +123,7 @@ async def setup_database(self) -> None: @sa.event.listens_for(engine.sync_engine, "connect") def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None: dbapi_con.run_async(lambda conn: conn.enable_load_extension(True)) + nominatim.db.sqlite_functions.install_custom_functions(dbapi_con) cursor = dbapi_con.cursor() cursor.execute("SELECT load_extension('mod_spatialite')") cursor.execute('SELECT SetDecimalPrecision(7)') diff --git a/nominatim/api/search/db_search_lookups.py b/nominatim/api/search/db_search_lookups.py index 3e307235b8..aa5cef5f47 100644 --- a/nominatim/api/search/db_search_lookups.py +++ b/nominatim/api/search/db_search_lookups.py @@ -26,18 +26,38 @@ class LookupAll(LookupType): inherit_cache = True def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: - super().__init__(getattr(table.c, column), + super().__init__(table.c.place_id, getattr(table.c, column), column, sa.type_coerce(tokens, IntArray)) @compiles(LookupAll) # type: ignore[no-untyped-call, misc] def _default_lookup_all(element: LookupAll, compiler: 'sa.Compiled', **kw: Any) -> str: - col, tokens = list(element.clauses) + _, col, _, tokens = list(element.clauses) return "(%s @> %s)" % (compiler.process(col, **kw), compiler.process(tokens, **kw)) +@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc] +def _sqlite_lookup_all(element: LookupAll, + compiler: 'sa.Compiled', **kw: Any) -> str: + place, col, colname, tokens = list(element.clauses) + return "(%s IN (SELECT CAST(value as bigint) FROM"\ + " (SELECT array_intersect_fuzzy(places) as p FROM"\ + " (SELECT places FROM reverse_search_name"\ + " WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\ + " AND column = %s"\ + " ORDER BY length(places)) as x) as u,"\ + " json_each('[' || u.p || ']'))"\ + " AND array_contains(%s, %s))"\ + % (compiler.process(place, **kw), + compiler.process(tokens, **kw), + 
compiler.process(colname, **kw), + compiler.process(col, **kw), + compiler.process(tokens, **kw) + ) + + class LookupAny(LookupType): """ Find all entries that contain at least one of the given tokens. @@ -46,17 +66,28 @@ class LookupAny(LookupType): inherit_cache = True def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: - super().__init__(getattr(table.c, column), + super().__init__(table.c.place_id, getattr(table.c, column), column, sa.type_coerce(tokens, IntArray)) - @compiles(LookupAny) # type: ignore[no-untyped-call, misc] def _default_lookup_any(element: LookupAny, compiler: 'sa.Compiled', **kw: Any) -> str: - col, tokens = list(element.clauses) + _, col, _, tokens = list(element.clauses) return "(%s && %s)" % (compiler.process(col, **kw), compiler.process(tokens, **kw)) +@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc] +def _sqlite_lookup_any(element: LookupAny, + compiler: 'sa.Compiled', **kw: Any) -> str: + place, _, colname, tokens = list(element.clauses) + return "%s IN (SELECT CAST(value as bigint) FROM"\ + " (SELECT array_union(places) as p FROM reverse_search_name"\ + " WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\ + " AND column = %s) as u,"\ + " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw), + compiler.process(tokens, **kw), + compiler.process(colname, **kw)) + class Restrict(LookupType): @@ -76,3 +107,8 @@ def _default_restrict(element: Restrict, arg1, arg2 = list(element.clauses) return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) + +@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc] +def _sqlite_restrict(element: Restrict, + compiler: 'sa.Compiled', **kw: Any) -> str: + return "array_contains(%s)" % compiler.process(element.clauses, **kw) diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index c56554fdc0..ee98100c63 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -11,7 +11,6 @@ import abc import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import array_agg from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \ SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind @@ -19,7 +18,7 @@ from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox import nominatim.api.results as nres from nominatim.api.search.db_search_fields import SearchData, WeightedCategories -from nominatim.db.sqlalchemy_types import Geometry +from nominatim.db.sqlalchemy_types import Geometry, IntArray #pylint: disable=singleton-comparison,not-callable #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements @@ -110,7 +109,7 @@ def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDet def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause, numerals: List[int], details: SearchDetails) -> SaScalarSelect: - all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call] + all_ids = sa.func.ArrayAgg(table.c.place_id) sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id) if len(numerals) == 1: @@ -134,9 +133,7 @@ def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn: orexpr.append(no_index(table.c.rank_address).between(1, 30)) elif layers & DataLayer.ADDRESS: orexpr.append(no_index(table.c.rank_address).between(1, 29)) - orexpr.append(sa.and_(no_index(table.c.rank_address) == 30, - sa.or_(table.c.housenumber != None, - 
table.c.address.has_key('addr:housename')))) + orexpr.append(sa.func.IsAddressPoint(table)) elif layers & DataLayer.POI: orexpr.append(sa.and_(no_index(table.c.rank_address) == 30, table.c.class_.not_in(('place', 'building')))) @@ -188,12 +185,21 @@ async def _get_placex_housenumbers(conn: SearchConnection, yield result +def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery': + """ Create a subselect that returns the given list of integers + as rows in the column 'nr'. + """ + vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\ + .table_valued(sa.column('value', type_=sa.JSON)) # type: ignore[no-untyped-call] + return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery() + + async def _get_osmline(conn: SearchConnection, place_ids: List[int], numerals: List[int], details: SearchDetails) -> AsyncIterator[nres.SearchResult]: t = conn.t.osmline - values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\ - .data([(n,) for n in numerals]) + + values = _int_list_to_subquery(numerals) sql = sa.select(t.c.place_id, t.c.osm_id, t.c.parent_place_id, t.c.address, values.c.nr.label('housenumber'), @@ -216,8 +222,7 @@ async def _get_tiger(conn: SearchConnection, place_ids: List[int], numerals: List[int], osm_id: int, details: SearchDetails) -> AsyncIterator[nres.SearchResult]: t = conn.t.tiger - values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\ - .data([(n,) for n in numerals]) + values = _int_list_to_subquery(numerals) sql = sa.select(t.c.place_id, t.c.parent_place_id, sa.literal('W').label('osm_type'), sa.literal(osm_id).label('osm_id'), @@ -573,7 +578,8 @@ async def lookup(self, conn: SearchConnection, tsearch = conn.t.search_name sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector) - .contains(self.lookups[0].tokens)) + .contains(sa.type_coerce(self.lookups[0].tokens, + IntArray))) for ranking in self.rankings: penalty += ranking.sql_penalty(conn.t.search_name) @@ -692,10 +698,10 @@ async def lookup(self, conn: SearchConnection, sql = sql.order_by(sa.text('accuracy')) if self.housenumbers: - hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M" + hnr_list = '|'.join(self.housenumbers.values) sql = sql.where(tsearch.c.address_rank.between(16, 30))\ .where(sa.or_(tsearch.c.address_rank < 30, - t.c.housenumber.op('~*')(hnr_regexp))) + sa.func.RegexpWord(hnr_list, t.c.housenumber))) # Cross check for housenumbers, need to do that on a rather large # set. Worst case there are 40.000 main streets in OSM. 
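
RegexpWord hides the dialect-specific word-boundary syntax: the \m...\M markers of the Postgres regex removed above versus the \b anchors understood by the Python-backed regexp() function on SQLite. A rough illustration of the intended semantics (a sketch, not code from this patch; the Postgres variant is additionally case-insensitive):

import re

def regexp_word(pattern: str, value: str) -> bool:
    # Approximately what RegexpWord('21|22a', housenumber) asks the
    # database: does any alternative occur in value as a complete word?
    return re.search(r'\b(' + pattern + r')\b', value) is not None

assert regexp_word('21|22a', '22a; 23')    # '22a' matches as a full word
assert not regexp_word('21|22a', '121')    # no match inside '121'
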
@@ -703,10 +709,10 @@ async def lookup(self, conn: SearchConnection, # Housenumbers from placex thnr = conn.t.placex.alias('hnr') - pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call] + pid_list = sa.func.ArrayAgg(thnr.c.place_id) place_sql = sa.select(pid_list)\ .where(thnr.c.parent_place_id == inner.c.place_id)\ - .where(thnr.c.housenumber.op('~*')(hnr_regexp))\ + .where(sa.func.RegexpWord(hnr_list, thnr.c.housenumber))\ .where(thnr.c.linked_place_id == None)\ .where(thnr.c.indexed_status == 0) diff --git a/nominatim/db/sqlalchemy_functions.py b/nominatim/db/sqlalchemy_functions.py index 5872401cca..e2437dd2e3 100644 --- a/nominatim/db/sqlalchemy_functions.py +++ b/nominatim/db/sqlalchemy_functions.py @@ -188,6 +188,7 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw return "json_each(%s)" % compiler.process(element.clauses, **kw) + class Greatest(sa.sql.functions.GenericFunction[Any]): """ Function to compute maximum of all its input parameters. """ @@ -198,3 +199,23 @@ class Greatest(sa.sql.functions.GenericFunction[Any]): @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc] def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str: return "max(%s)" % compiler.process(element.clauses, **kw) + + + +class RegexpWord(sa.sql.functions.GenericFunction[Any]): + """ Check if a full word is in a given string. + """ + name = 'RegexpWord' + inherit_cache = True + + +@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc] +def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str: + arg1, arg2 = list(element.clauses) + return "%s ~* ('\\m(' || %s || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw)) + + +@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc] +def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str: + arg1, arg2 = list(element.clauses) + return "regexp('\\b(' || %s || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) diff --git a/nominatim/db/sqlalchemy_types/int_array.py b/nominatim/db/sqlalchemy_types/int_array.py index 499376cb85..a31793f3f5 100644 --- a/nominatim/db/sqlalchemy_types/int_array.py +++ b/nominatim/db/sqlalchemy_types/int_array.py @@ -57,22 +57,16 @@ def __add__(self, other: SaColumn) -> 'sa.ColumnOperators': """ Concate the array with the given array. If one of the operants is null, the value of the other will be returned. """ - return sa.func.array_cat(self, other, type_=IntArray) + return ArrayCat(self.expr, other) def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators': """ Return true if the array contains all the value of the argument array. """ - return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other)) + return ArrayContains(self.expr, other) - def overlaps(self, other: SaColumn) -> 'sa.Operators': - """ Return true if at least one value of the argument is contained - in the array. - """ - return self.op('&&', is_comparison=True)(other) - class ArrayAgg(sa.sql.functions.GenericFunction[Any]): """ Aggregate function to collect elements in an array. 
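
The motivation for routing '+' and contains() through dedicated function elements shows in the hunk below: on SQLite an integer array is emulated as a comma-separated string, so the operators must compile down to plain string functions. In Python terms the emulation behaves roughly like this (an illustrative sketch only; NULL and empty-array handling is elided):

def array_cat(a: str, b: str) -> str:
    # Counterpart of the "(a || ',' || b)" rendering of ArrayCat.
    return a + ',' + b

def array_contains(container: str, containee: str) -> bool:
    # Counterpart of the Postgres '@>' operator behind ArrayContains.
    values = container.split(',')
    return all(v in values for v in containee.split(','))

assert array_contains(array_cat('1,2', '3'), '3,1')
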
@@ -82,6 +76,48 @@ class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
     name = 'array_agg'
     inherit_cache = True

+
 @compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
 def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
+
+
+class ArrayContains(sa.sql.expression.FunctionElement[Any]):
+    """ Function to check if an array is fully contained in another.
+    """
+    name = 'ArrayContains'
+    inherit_cache = True
+
+
+@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
+def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s @> %s)" % (compiler.process(arg1, **kw),
+                           compiler.process(arg2, **kw))
+
+
+@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)
+
+
+class ArrayCat(sa.sql.expression.FunctionElement[Any]):
+    """ Function to concatenate two arrays.
+    """
+    type = IntArray()
+    identifier = 'ArrayCat'
+    inherit_cache = True
+
+
+@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
+def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_cat(%s)" % compiler.process(element.clauses, **kw)
+
+
+@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
diff --git a/nominatim/db/sqlalchemy_types/key_value.py b/nominatim/db/sqlalchemy_types/key_value.py
index 4f2d824aff..937caa021b 100644
--- a/nominatim/db/sqlalchemy_types/key_value.py
+++ b/nominatim/db/sqlalchemy_types/key_value.py
@@ -10,6 +10,7 @@
 from typing import Any

 import sqlalchemy as sa
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.sqlite import JSON as sqlite_json

@@ -37,11 +38,25 @@ def merge(self, other: SaColumn) -> 'sa.Operators':
             one, overwriting values where necessary. When the argument
             is null, nothing happens.
         """
-        return self.op('||')(sa.func.coalesce(other,
-                                              sa.type_coerce('', KeyValueStore)))
+        return KeyValueConcat(self.expr, other)
+
+
+class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
+    """ Return the merged key-value store from the input parameters.
+    """
+    type = KeyValueStore()
+    name = 'JsonConcat'
+    inherit_cache = True
+
+@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
+def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
+@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+

-    def has_key(self, key: SaColumn) -> 'sa.Operators':
-        """ Return true if the key is cotained in the store.
- """ - return self.op('?', is_comparison=True)(key) diff --git a/nominatim/db/sqlite_functions.py b/nominatim/db/sqlite_functions.py new file mode 100644 index 0000000000..2134ae457b --- /dev/null +++ b/nominatim/db/sqlite_functions.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Custom functions for SQLite. +""" +from typing import cast, Optional, Set, Any +import json + +# pylint: disable=protected-access + +def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float: + """ Custom weight function for search results. + """ + if search_vector is not None: + svec = [int(x) for x in search_vector.split(',')] + for rank in json.loads(rankings): + if all(r in svec for r in rank[1]): + return cast(float, rank[0]) + + return default + + +class ArrayIntersectFuzzy: + """ Compute the array of common elements of all input integer arrays. + Very large input paramenters may be ignored to speed up + computation. Therefore, the result is a superset of common elements. + + Input and output arrays are given as comma-separated lists. + """ + def __init__(self) -> None: + self.first = '' + self.values: Optional[Set[int]] = None + + def step(self, value: Optional[str]) -> None: + """ Add the next array to the intersection. + """ + if value is not None: + if not self.first: + self.first = value + elif len(value) < 10000000: + if self.values is None: + self.values = {int(x) for x in self.first.split(',')} + self.values.intersection_update((int(x) for x in value.split(','))) + + def finalize(self) -> str: + """ Return the final result. + """ + if self.values is not None: + return ','.join(map(str, self.values)) + + return self.first + + +class ArrayUnion: + """ Compute the set of all elements of the input integer arrays. + + Input and output arrays are given as strings of comma-separated lists. + """ + def __init__(self) -> None: + self.values: Optional[Set[str]] = None + + def step(self, value: Optional[str]) -> None: + """ Add the next array to the union. + """ + if value is not None: + if self.values is None: + self.values = set(value.split(',')) + else: + self.values.update(value.split(',')) + + def finalize(self) -> str: + """ Return the final result. + """ + return '' if self.values is None else ','.join(self.values) + + +def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]: + """ Is the array 'containee' completely contained in array 'container'. + """ + if container is None or containee is None: + return None + + vset = container.split(',') + return all(v in vset for v in containee.split(',')) + + +def array_pair_contains(container1: Optional[str], container2: Optional[str], + containee: Optional[str]) -> Optional[bool]: + """ Is the array 'containee' completely contained in the union of + array 'container1' and array 'container2'. + """ + if container1 is None or container2 is None or containee is None: + return None + + vset = container1.split(',') + container2.split(',') + return all(v in vset for v in containee.split(',')) + + +def install_custom_functions(conn: Any) -> None: + """ Install helper functions for Nominatim into the given SQLite + database connection. 
+ """ + conn.create_function('weigh_search', 3, weigh_search, deterministic=True) + conn.create_function('array_contains', 2, array_contains, deterministic=True) + conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True) + _create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy) + _create_aggregate(conn, 'array_union', 1, ArrayUnion) + + +async def _make_aggregate(aioconn: Any, *args: Any) -> None: + await aioconn._execute(aioconn._conn.create_aggregate, *args) + + +def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None: + try: + conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate)) + except Exception as error: # pylint: disable=broad-exception-caught + conn._handle_exception(error) diff --git a/nominatim/tools/convert_sqlite.py b/nominatim/tools/convert_sqlite.py index d9e39ba374..16139c5fbc 100644 --- a/nominatim/tools/convert_sqlite.py +++ b/nominatim/tools/convert_sqlite.py @@ -205,15 +205,15 @@ async def create_index(self, table_name: str, column: str) -> None: async def create_search_index(self) -> None: """ Create the tables and indexes needed for word lookup. """ + LOG.warning("Creating reverse search table") + rsn = sa.Table('reverse_search_name', self.dest.t.meta, + sa.Column('word', sa.Integer()), + sa.Column('column', sa.Text()), + sa.Column('places', IntArray)) + await self.dest.connection.run_sync(rsn.create) + tsrc = self.src.t.search_name for column in ('name_vector', 'nameaddress_vector'): - table_name = f'reverse_search_{column}' - LOG.warning("Creating reverse search %s", table_name) - rsn = sa.Table(table_name, self.dest.t.meta, - sa.Column('word', sa.Integer()), - sa.Column('places', IntArray)) - await self.dest.connection.run_sync(rsn.create) - sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'), sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\ .group_by('word') @@ -224,11 +224,12 @@ async def create_search_index(self) -> None: for row in partition: row.places.sort() data.append({'word': row.word, + 'column': column, 'places': row.places}) await self.dest.execute(rsn.insert(), data) - await self.dest.connection.run_sync( - sa.Index(f'idx_reverse_search_{column}_word', rsn.c.word).create) + await self.dest.connection.run_sync( + sa.Index('idx_reverse_search_name_word', rsn.c.word).create) def select_from(self, table: str) -> SaSelect: diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index 91a3107fbc..05eaddf5fc 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -16,6 +16,7 @@ import nominatim.api as napi from nominatim.db.sql_preprocessor import SQLPreprocessor +from nominatim.api.search.query_analyzer_factory import make_query_analyzer from nominatim.tools import convert_sqlite import nominatim.api.logging as loglib @@ -160,6 +161,22 @@ def add_class_type_table(self, cls, typ): """))) + def add_word_table(self, content): + data = [dict(zip(['word_id', 'word_token', 'type', 'word', 'info'], c)) + for c in content] + + async def _do_sql(): + async with self.api._async_api.begin() as conn: + if 'word' not in conn.t.meta.tables: + await make_query_analyzer(conn) + word_table = conn.t.meta.tables['word'] + await conn.connection.run_sync(word_table.create) + if data: + await conn.execute(conn.t.meta.tables['word'].insert(), data) + + self.async_to_sync(_do_sql()) + + async def exec_async(self, sql, *args, **kwargs): async with self.api._async_api.begin() as conn: return await conn.execute(sql, *args, **kwargs) @@ 
-195,6 +212,22 @@ def frontend(request, event_loop, tmp_path): db = str(tmp_path / 'test_nominatim_python_unittest.sqlite') def mkapi(apiobj, options={'reverse'}): + apiobj.add_data('properties', + [{'property': 'tokenizer', 'value': 'icu'}, + {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'}, + {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"}, + ]) + + async def _do_sql(): + async with apiobj.api._async_api.begin() as conn: + if 'word' in conn.t.meta.tables: + return + await make_query_analyzer(conn) + word_table = conn.t.meta.tables['word'] + await conn.connection.run_sync(word_table.create) + + apiobj.async_to_sync(_do_sql()) + event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'), db, options)) outapi = napi.NominatimAPI(Path('/invalid'), diff --git a/test/python/api/search/test_search_country.py b/test/python/api/search/test_search_country.py index 82b1d37fe3..dc87d313a0 100644 --- a/test/python/api/search/test_search_country.py +++ b/test/python/api/search/test_search_country.py @@ -15,7 +15,7 @@ from nominatim.api.search.db_search_fields import WeightedStrings -def run_search(apiobj, global_penalty, ccodes, +def run_search(apiobj, frontend, global_penalty, ccodes, country_penalties=None, details=SearchDetails()): if country_penalties is None: country_penalties = [0.0] * len(ccodes) @@ -25,15 +25,16 @@ class MySearchData: countries = WeightedStrings(ccodes, country_penalties) search = CountrySearch(MySearchData()) + api = frontend(apiobj, options=['search']) async def run(): - async with apiobj.api._async_api.begin() as conn: + async with api._async_api.begin() as conn: return await search.lookup(conn, details) - return apiobj.async_to_sync(run()) + return api._loop.run_until_complete(run()) -def test_find_from_placex(apiobj): +def test_find_from_placex(apiobj, frontend): apiobj.add_placex(place_id=55, class_='boundary', type='administrative', rank_search=4, rank_address=4, name={'name': 'Lolaland'}, @@ -41,32 +42,32 @@ def test_find_from_placex(apiobj): centroid=(10, 10), geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))') - results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3]) + results = run_search(apiobj, frontend, 0.5, ['de', 'yw'], [0.0, 0.3]) assert len(results) == 1 assert results[0].place_id == 55 assert results[0].accuracy == 0.8 -def test_find_from_fallback_countries(apiobj): +def test_find_from_fallback_countries(apiobj, frontend): apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country_name('ro', {'name': 'România'}) - results = run_search(apiobj, 0.0, ['ro']) + results = run_search(apiobj, frontend, 0.0, ['ro']) assert len(results) == 1 assert results[0].names == {'name': 'România'} -def test_find_none(apiobj): - assert len(run_search(apiobj, 0.0, ['xx'])) == 0 +def test_find_none(apiobj, frontend): + assert len(run_search(apiobj, frontend, 0.0, ['xx'])) == 0 @pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)]) -def test_find_near(apiobj, coord, numres): +def test_find_near(apiobj, frontend, coord, numres): apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country_name('ro', {'name': 'România'}) - results = run_search(apiobj, 0.0, ['ro'], + results = run_search(apiobj, frontend, 0.0, ['ro'], details=SearchDetails(near=napi.Point(*coord), near_radius=0.1)) @@ -92,8 +93,8 @@ def fill_database(self, apiobj): napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) @pytest.mark.parametrize('cc', ['yw', 'ro']) - 
def test_return_geometries(self, apiobj, geom, cc): - results = run_search(apiobj, 0.5, [cc], + def test_return_geometries(self, apiobj, frontend, geom, cc): + results = run_search(apiobj, frontend, 0.5, [cc], details=SearchDetails(geometry_output=geom)) assert len(results) == 1 @@ -101,8 +102,8 @@ def test_return_geometries(self, apiobj, geom, cc): @pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])]) - def test_exclude_place_id(self, apiobj, pid, rids): - results = run_search(apiobj, 0.5, ['yw', 'ro'], + def test_exclude_place_id(self, apiobj, frontend, pid, rids): + results = run_search(apiobj, frontend, 0.5, ['yw', 'ro'], details=SearchDetails(excluded=[pid])) assert [r.place_id for r in results] == rids @@ -110,8 +111,8 @@ def test_exclude_place_id(self, apiobj, pid, rids): @pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]), ((-10, -10, -3, -3), [])]) - def test_bounded_viewbox_in_placex(self, apiobj, viewbox, rids): - results = run_search(apiobj, 0.5, ['yw'], + def test_bounded_viewbox_in_placex(self, apiobj, frontend, viewbox, rids): + results = run_search(apiobj, frontend, 0.5, ['yw'], details=SearchDetails.from_kwargs({'viewbox': viewbox, 'bounded_viewbox': True})) @@ -120,8 +121,8 @@ def test_bounded_viewbox_in_placex(self, apiobj, viewbox, rids): @pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1), ((-10, -10, -3, -3), 0)]) - def test_bounded_viewbox_in_fallback(self, apiobj, viewbox, numres): - results = run_search(apiobj, 0.5, ['ro'], + def test_bounded_viewbox_in_fallback(self, apiobj, frontend, viewbox, numres): + results = run_search(apiobj, frontend, 0.5, ['ro'], details=SearchDetails.from_kwargs({'viewbox': viewbox, 'bounded_viewbox': True})) diff --git a/test/python/api/search/test_search_near.py b/test/python/api/search/test_search_near.py index c0caa9ae6a..5b60dd51d5 100644 --- a/test/python/api/search/test_search_near.py +++ b/test/python/api/search/test_search_near.py @@ -17,7 +17,7 @@ from nominatim.api.search.db_search_lookups import LookupAll -def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[], +def run_search(apiobj, frontend, global_penalty, cat, cat_penalty=None, ccodes=[], details=SearchDetails()): class PlaceSearchData: @@ -39,21 +39,23 @@ class PlaceSearchData: near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search) + api = frontend(apiobj, options=['search']) + async def run(): - async with apiobj.api._async_api.begin() as conn: + async with api._async_api.begin() as conn: return await near_search.lookup(conn, details) - results = apiobj.async_to_sync(run()) + results = api._loop.run_until_complete(run()) results.sort(key=lambda r: r.accuracy) return results -def test_no_results_inner_query(apiobj): - assert not run_search(apiobj, 0.4, [('this', 'that')]) +def test_no_results_inner_query(apiobj, frontend): + assert not run_search(apiobj, frontend, 0.4, [('this', 'that')]) -def test_no_appropriate_results_inner_query(apiobj): +def test_no_appropriate_results_inner_query(apiobj, frontend): apiobj.add_placex(place_id=100, country_code='us', centroid=(5.6, 4.3), geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))') @@ -62,7 +64,7 @@ def test_no_appropriate_results_inner_query(apiobj): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6001, 4.2994)) - assert not run_search(apiobj, 0.4, [('amenity', 'bank')]) + assert not run_search(apiobj, frontend, 0.4, [('amenity', 'bank')]) class TestNearSearch: @@ -79,18 +81,18 @@ def fill_database(self, 
apiobj): centroid=(-10.3, 56.9)) - def test_near_in_placex(self, apiobj): + def test_near_in_placex(self, apiobj, frontend): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6001, 4.2994)) apiobj.add_placex(place_id=23, class_='amenity', type='bench', centroid=(5.6001, 4.2994)) - results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')]) assert [r.place_id for r in results] == [22] - def test_multiple_types_near_in_placex(self, apiobj): + def test_multiple_types_near_in_placex(self, apiobj, frontend): apiobj.add_placex(place_id=22, class_='amenity', type='bank', importance=0.002, centroid=(5.6001, 4.2994)) @@ -98,13 +100,13 @@ def test_multiple_types_near_in_placex(self, apiobj): importance=0.001, centroid=(5.6001, 4.2994)) - results = run_search(apiobj, 0.1, [('amenity', 'bank'), - ('amenity', 'bench')]) + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank'), + ('amenity', 'bench')]) assert [r.place_id for r in results] == [22, 23] - def test_near_in_classtype(self, apiobj): + def test_near_in_classtype(self, apiobj, frontend): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6, 4.34)) apiobj.add_placex(place_id=23, class_='amenity', type='bench', @@ -112,13 +114,13 @@ def test_near_in_classtype(self, apiobj): apiobj.add_class_type_table('amenity', 'bank') apiobj.add_class_type_table('amenity', 'bench') - results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')]) assert [r.place_id for r in results] == [22] @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)]) - def test_restrict_by_country(self, apiobj, cc, rid): + def test_restrict_by_country(self, apiobj, frontend, cc, rid): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6001, 4.2994), country_code='us') @@ -132,13 +134,13 @@ def test_restrict_by_country(self, apiobj, cc, rid): centroid=(-10.3001, 56.9), country_code='us') - results = run_search(apiobj, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr']) + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr']) assert [r.place_id for r in results] == [rid] @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)]) - def test_exclude_place_by_id(self, apiobj, excluded, rid): + def test_exclude_place_by_id(self, apiobj, frontend, excluded, rid): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6001, 4.2994), country_code='us') @@ -147,7 +149,7 @@ def test_exclude_place_by_id(self, apiobj, excluded, rid): country_code='us') - results = run_search(apiobj, 0.1, [('amenity', 'bank')], + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], details=SearchDetails(excluded=[excluded])) assert [r.place_id for r in results] == [rid] @@ -155,12 +157,12 @@ def test_exclude_place_by_id(self, apiobj, excluded, rid): @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]), (napi.DataLayer.MANMADE, [])]) - def test_with_layer(self, apiobj, layer, rids): + def test_with_layer(self, apiobj, frontend, layer, rids): apiobj.add_placex(place_id=22, class_='amenity', type='bank', centroid=(5.6001, 4.2994), country_code='us') - results = run_search(apiobj, 0.1, [('amenity', 'bank')], + results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], details=SearchDetails(layers=layer)) assert [r.place_id for r in results] == rids diff --git a/test/python/api/search/test_search_places.py 
b/test/python/api/search/test_search_places.py index 44e4098dad..c446a35f88 100644 --- a/test/python/api/search/test_search_places.py +++ b/test/python/api/search/test_search_places.py @@ -18,7 +18,9 @@ FieldLookup, FieldRanking, RankedTokens from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict -def run_search(apiobj, global_penalty, lookup, ranking, count=2, +APIOPTIONS = ['search'] + +def run_search(apiobj, frontend, global_penalty, lookup, ranking, count=2, hnrs=[], pcs=[], ccodes=[], quals=[], details=SearchDetails()): class MySearchData: @@ -32,11 +34,16 @@ class MySearchData: search = PlaceSearch(0.0, MySearchData(), count) + if frontend is None: + api = apiobj + else: + api = frontend(apiobj, options=APIOPTIONS) + async def run(): - async with apiobj.api._async_api.begin() as conn: + async with api._async_api.begin() as conn: return await search.lookup(conn, details) - results = apiobj.async_to_sync(run()) + results = api._loop.run_until_complete(run()) results.sort(key=lambda r: r.accuracy) return results @@ -59,61 +66,61 @@ def fill_database(self, apiobj): @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) @pytest.mark.parametrize('rank,res', [([10], [100, 101]), ([20], [101, 100])]) - def test_lookup_all_match(self, apiobj, lookup_type, rank, res): + def test_lookup_all_match(self, apiobj, frontend, lookup_type, rank, res): lookup = FieldLookup('name_vector', [1,2], lookup_type) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking]) assert [r.place_id for r in results] == res @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) - def test_lookup_all_partial_match(self, apiobj, lookup_type): + def test_lookup_all_partial_match(self, apiobj, frontend, lookup_type): lookup = FieldLookup('name_vector', [1,20], lookup_type) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking]) assert len(results) == 1 assert results[0].place_id == 101 @pytest.mark.parametrize('rank,res', [([10], [100, 101]), ([20], [101, 100])]) - def test_lookup_any_match(self, apiobj, rank, res): + def test_lookup_any_match(self, apiobj, frontend, rank, res): lookup = FieldLookup('name_vector', [11,21], LookupAny) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking]) assert [r.place_id for r in results] == res - def test_lookup_any_partial_match(self, apiobj): + def test_lookup_any_partial_match(self, apiobj, frontend): lookup = FieldLookup('name_vector', [20], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking]) assert len(results) == 1 assert results[0].place_id == 101 @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)]) - def test_lookup_restrict_country(self, apiobj, cc, res): + def test_lookup_restrict_country(self, apiobj, frontend, cc, res): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc]) + results = run_search(apiobj, frontend, 0.1, [lookup], 
[ranking], ccodes=[cc]) assert [r.place_id for r in results] == [res] - def test_lookup_restrict_placeid(self, apiobj): + def test_lookup_restrict_placeid(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], details=SearchDetails(excluded=[101])) assert [r.place_id for r in results] == [100] @@ -123,18 +130,18 @@ def test_lookup_restrict_placeid(self, apiobj): napi.GeometryFormat.KML, napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) - def test_return_geometries(self, apiobj, geom): + def test_return_geometries(self, apiobj, frontend, geom): lookup = FieldLookup('name_vector', [20], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], details=SearchDetails(geometry_output=geom)) assert geom.name.lower() in results[0].geometry @pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)]) - def test_return_simplified_geometry(self, apiobj, factor, npoints): + def test_return_simplified_geometry(self, apiobj, frontend, factor, npoints): apiobj.add_placex(place_id=333, country_code='us', centroid=(9.0, 9.0), geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)') @@ -144,7 +151,7 @@ def test_return_simplified_geometry(self, apiobj, factor, npoints): lookup = FieldLookup('name_vector', [55], LookupAll) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON, geometry_simplification=factor)) @@ -158,50 +165,52 @@ def test_return_simplified_geometry(self, apiobj, factor, npoints): @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0']) @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])]) - def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids): + def test_prefer_viewbox(self, apiobj, frontend, viewbox, wcount, rids): lookup = FieldLookup('name_vector', [1, 2], LookupAll) ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + api = frontend(apiobj, options=APIOPTIONS) + results = run_search(api, None, 0.1, [lookup], [ranking]) assert [r.place_id for r in results] == [101, 100] - results = run_search(apiobj, 0.1, [lookup], [ranking], count=wcount, + results = run_search(api, None, 0.1, [lookup], [ranking], count=wcount, details=SearchDetails.from_kwargs({'viewbox': viewbox})) assert [r.place_id for r in results] == rids @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31']) - def test_force_viewbox(self, apiobj, viewbox): + def test_force_viewbox(self, apiobj, frontend, viewbox): lookup = FieldLookup('name_vector', [1, 2], LookupAll) details=SearchDetails.from_kwargs({'viewbox': viewbox, 'bounded_viewbox': True}) - results = run_search(apiobj, 0.1, [lookup], [], details=details) + results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details) assert [r.place_id for r in results] == [100] - def test_prefer_near(self, apiobj): + def test_prefer_near(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1, 2], LookupAll) ranking = FieldRanking('name_vector', 0.9, 
[RankedTokens(0.0, [21])]) - results = run_search(apiobj, 0.1, [lookup], [ranking]) + api = frontend(apiobj, options=APIOPTIONS) + results = run_search(api, None, 0.1, [lookup], [ranking]) assert [r.place_id for r in results] == [101, 100] - results = run_search(apiobj, 0.1, [lookup], [ranking], + results = run_search(api, None, 0.1, [lookup], [ranking], details=SearchDetails.from_kwargs({'near': '5.6,4.3'})) results.sort(key=lambda r: -r.importance) assert [r.place_id for r in results] == [100, 101] @pytest.mark.parametrize('radius', [0.09, 0.11]) - def test_force_near(self, apiobj, radius): + def test_force_near(self, apiobj, frontend, radius): lookup = FieldLookup('name_vector', [1, 2], LookupAll) details=SearchDetails.from_kwargs({'near': '5.6,4.3', 'near_radius': radius}) - results = run_search(apiobj, 0.1, [lookup], [], details=details) + results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details) assert [r.place_id for r in results] == [100] @@ -242,72 +251,72 @@ def fill_database(self, apiobj): @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]), ('21', [2]), ('22', [2, 92]), ('24', [93]), ('25', [])]) - def test_lookup_by_single_housenumber(self, apiobj, hnr, res): + def test_lookup_by_single_housenumber(self, apiobj, frontend, hnr, res): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr]) + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=[hnr]) assert [r.place_id for r in results] == res + [1000, 2000] @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])]) - def test_lookup_with_country_restriction(self, apiobj, cc, res): + def test_lookup_with_country_restriction(self, apiobj, frontend, cc, res): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'], ccodes=[cc]) assert [r.place_id for r in results] == res - def test_lookup_exclude_housenumber_placeid(self, apiobj): + def test_lookup_exclude_housenumber_placeid(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'], details=SearchDetails(excluded=[92])) assert [r.place_id for r in results] == [2, 1000, 2000] - def test_lookup_exclude_street_placeid(self, apiobj): + def test_lookup_exclude_street_placeid(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'], details=SearchDetails(excluded=[1000])) assert [r.place_id for r in results] == [2, 92, 2000] - def test_lookup_only_house_qualifier(self, apiobj): + def test_lookup_only_house_qualifier(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], 
hnrs=['22'], quals=[('place', 'house')]) assert [r.place_id for r in results] == [2, 92] - def test_lookup_only_street_qualifier(self, apiobj): + def test_lookup_only_street_qualifier(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'], quals=[('highway', 'residential')]) assert [r.place_id for r in results] == [1000, 2000] @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)]) - def test_lookup_min_rank(self, apiobj, rank, found): + def test_lookup_min_rank(self, apiobj, frontend, rank, found): lookup = FieldLookup('name_vector', [1,2], LookupAll) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'], details=SearchDetails(min_rank=rank)) assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92]) @@ -317,17 +326,17 @@ def test_lookup_min_rank(self, apiobj, rank, found): napi.GeometryFormat.KML, napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) - def test_return_geometries(self, apiobj, geom): + def test_return_geometries(self, apiobj, frontend, geom): lookup = FieldLookup('name_vector', [1, 2], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'], + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['20', '21', '22'], details=SearchDetails(geometry_output=geom)) assert results assert all(geom.name.lower() in r.geometry for r in results) -def test_very_large_housenumber(apiobj): +def test_very_large_housenumber(apiobj, frontend): apiobj.add_placex(place_id=93, class_='place', type='house', parent_place_id=2000, housenumber='2467463524544', country_code='pt') @@ -340,7 +349,7 @@ def test_very_large_housenumber(apiobj): lookup = FieldLookup('name_vector', [1, 2], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=['2467463524544'], + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['2467463524544'], details=SearchDetails()) assert results @@ -348,7 +357,7 @@ def test_very_large_housenumber(apiobj): @pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])]) -def test_name_and_postcode(apiobj, wcount, rids): +def test_name_and_postcode(apiobj, frontend, wcount, rids): apiobj.add_placex(place_id=990, class_='highway', type='service', rank_search=27, rank_address=27, postcode='11225', @@ -368,7 +377,7 @@ def test_name_and_postcode(apiobj, wcount, rids): lookup = FieldLookup('name_vector', [111], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], pcs=['11225'], count=wcount, + results = run_search(apiobj, frontend, 0.1, [lookup], [], pcs=['11225'], count=wcount, details=SearchDetails()) assert results @@ -398,10 +407,10 @@ def fill_database(self, apiobj): @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) - def test_lookup_housenumber(self, apiobj, hnr, res): + def test_lookup_housenumber(self, apiobj, frontend, hnr, res): lookup = FieldLookup('name_vector', [111], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr]) assert [r.place_id for r in results] == res + [990] @@ -410,10 +419,10 @@ def test_lookup_housenumber(self, 
apiobj, hnr, res): napi.GeometryFormat.KML, napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) - def test_osmline_with_geometries(self, apiobj, geom): + def test_osmline_with_geometries(self, apiobj, frontend, geom): lookup = FieldLookup('name_vector', [111], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'], + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'], details=SearchDetails(geometry_output=geom)) assert results[0].place_id == 992 @@ -446,10 +455,10 @@ def fill_database(self, apiobj): @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) - def test_lookup_housenumber(self, apiobj, hnr, res): + def test_lookup_housenumber(self, apiobj, frontend, hnr, res): lookup = FieldLookup('name_vector', [111], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr]) assert [r.place_id for r in results] == res + [990] @@ -458,10 +467,10 @@ def test_lookup_housenumber(self, apiobj, hnr, res): napi.GeometryFormat.KML, napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) - def test_tiger_with_geometries(self, apiobj, geom): + def test_tiger_with_geometries(self, apiobj, frontend, geom): lookup = FieldLookup('name_vector', [111], LookupAll) - results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'], + results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'], details=SearchDetails(geometry_output=geom)) assert results[0].place_id == 992 @@ -513,10 +522,10 @@ def fill_database(self, apiobj): (napi.DataLayer.NATURAL, [227]), (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]), (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])]) - def test_layers_rank30(self, apiobj, layer, res): + def test_layers_rank30(self, apiobj, frontend, layer, res): lookup = FieldLookup('name_vector', [34], LookupAny) - results = run_search(apiobj, 0.1, [lookup], [], + results = run_search(apiobj, frontend, 0.1, [lookup], [], details=SearchDetails(layers=layer)) assert [r.place_id for r in results] == res diff --git a/test/python/api/search/test_search_poi.py b/test/python/api/search/test_search_poi.py index b80c075200..a0b578baff 100644 --- a/test/python/api/search/test_search_poi.py +++ b/test/python/api/search/test_search_poi.py @@ -15,7 +15,7 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories -def run_search(apiobj, global_penalty, poitypes, poi_penalties=None, +def run_search(apiobj, frontend, global_penalty, poitypes, poi_penalties=None, ccodes=[], details=SearchDetails()): if poi_penalties is None: poi_penalties = [0.0] * len(poitypes) @@ -27,16 +27,18 @@ class MySearchData: search = PoiSearch(MySearchData()) + api = frontend(apiobj, options=['search']) + async def run(): - async with apiobj.api._async_api.begin() as conn: + async with api._async_api.begin() as conn: return await search.lookup(conn, details) - return apiobj.async_to_sync(run()) + return api._loop.run_until_complete(run()) @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), ('5.0, 4.59933', 1)]) -def test_simple_near_search_in_placex(apiobj, coord, pid): +def test_simple_near_search_in_placex(apiobj, frontend, coord, pid): apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', centroid=(5.0, 4.6)) apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', @@ -44,7 +46,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid): details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 
0.001}) - results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details) assert [r.place_id for r in results] == [pid] @@ -52,7 +54,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid): @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), ('34.3, 56.4', 2), ('5.0, 4.59933', 1)]) -def test_simple_near_search_in_classtype(apiobj, coord, pid): +def test_simple_near_search_in_classtype(apiobj, frontend, coord, pid): apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', centroid=(5.0, 4.6)) apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', @@ -61,7 +63,7 @@ def test_simple_near_search_in_classtype(apiobj, coord, pid): details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5}) - results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details) assert [r.place_id for r in results] == [pid] @@ -83,25 +85,25 @@ def fill_database(self, apiobj, request): self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001} - def test_unrestricted(self, apiobj): - results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + def test_unrestricted(self, apiobj, frontend): + results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=SearchDetails.from_kwargs(self.args)) assert [r.place_id for r in results] == [1, 2] - def test_restict_country(self, apiobj): - results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + def test_restict_country(self, apiobj, frontend): + results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], ccodes=['de', 'nz'], details=SearchDetails.from_kwargs(self.args)) assert [r.place_id for r in results] == [2] - def test_restrict_by_viewbox(self, apiobj): + def test_restrict_by_viewbox(self, apiobj, frontend): args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'} args.update(self.args) - results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], ccodes=['de', 'nz'], details=SearchDetails.from_kwargs(args)) diff --git a/test/python/api/search/test_search_postcode.py b/test/python/api/search/test_search_postcode.py index e7153f38bf..6976b6a592 100644 --- a/test/python/api/search/test_search_postcode.py +++ b/test/python/api/search/test_search_postcode.py @@ -15,7 +15,7 @@ from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \ FieldRanking, RankedTokens -def run_search(apiobj, global_penalty, pcs, pc_penalties=None, +def run_search(apiobj, frontend, global_penalty, pcs, pc_penalties=None, ccodes=[], lookup=[], ranking=[], details=SearchDetails()): if pc_penalties is None: pc_penalties = [0.0] * len(pcs) @@ -29,28 +29,30 @@ class MySearchData: search = PostcodeSearch(0.0, MySearchData()) + api = frontend(apiobj, options=['search']) + async def run(): - async with apiobj.api._async_api.begin() as conn: + async with api._async_api.begin() as conn: return await search.lookup(conn, details) - return apiobj.async_to_sync(run()) + return api._loop.run_until_complete(run()) -def test_postcode_only_search(apiobj): +def test_postcode_only_search(apiobj, frontend): apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') - 
results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1]) + results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1]) assert len(results) == 2 assert [r.place_id for r in results] == [100, 101] -def test_postcode_with_country(apiobj): +def test_postcode_with_country(apiobj, frontend): apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') - results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1], + results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1], ccodes=['de', 'pl']) assert len(results) == 1 @@ -81,30 +83,30 @@ def fill_database(self, apiobj): country_code='pl') - def test_lookup_both(self, apiobj): + def test_lookup_both(self, apiobj, frontend): lookup = FieldLookup('name_vector', [1,2], 'restrict') ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking]) + results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup], ranking=[ranking]) assert [r.place_id for r in results] == [100, 101] - def test_restrict_by_name(self, apiobj): + def test_restrict_by_name(self, apiobj, frontend): lookup = FieldLookup('name_vector', [10], 'restrict') - results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup]) + results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup]) assert [r.place_id for r in results] == [100] @pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100), ((-45.1, 7.004), 101)]) - def test_lookup_near(self, apiobj, coord, place_id): + def test_lookup_near(self, apiobj, frontend, coord, place_id): lookup = FieldLookup('name_vector', [1,2], 'restrict') ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) - results = run_search(apiobj, 0.1, ['12345'], + results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup], ranking=[ranking], details=SearchDetails(near=napi.Point(*coord), near_radius=0.6)) @@ -116,8 +118,8 @@ def test_lookup_near(self, apiobj, coord, place_id): napi.GeometryFormat.KML, napi.GeometryFormat.SVG, napi.GeometryFormat.TEXT]) - def test_return_geometries(self, apiobj, geom): - results = run_search(apiobj, 0.1, ['12345'], + def test_return_geometries(self, apiobj, frontend, geom): + results = run_search(apiobj, frontend, 0.1, ['12345'], details=SearchDetails(geometry_output=geom)) assert results @@ -126,8 +128,8 @@ def test_return_geometries(self, apiobj, geom): @pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]), ('16,4,18,6', [100,101])]) - def test_prefer_viewbox(self, apiobj, viewbox, rids): - results = run_search(apiobj, 0.1, ['12345'], + def test_prefer_viewbox(self, apiobj, frontend, viewbox, rids): + results = run_search(apiobj, frontend, 0.1, ['12345'], details=SearchDetails.from_kwargs({'viewbox': viewbox})) assert [r.place_id for r in results] == rids @@ -135,8 +137,8 @@ def test_prefer_viewbox(self, apiobj, viewbox, rids): @pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101), ('16,4,18,6', 100)]) - def test_restrict_to_viewbox(self, apiobj, viewbox, rid): - results = run_search(apiobj, 0.1, ['12345'], + def test_restrict_to_viewbox(self, apiobj, frontend, viewbox, rid): + results = run_search(apiobj, frontend, 0.1, ['12345'], details=SearchDetails.from_kwargs({'viewbox': viewbox, 'bounded_viewbox': True})) @@ -145,16 +147,16 @@ def test_restrict_to_viewbox(self, apiobj, viewbox, rid): @pytest.mark.parametrize('coord,rids', [((17.05, 
5), [100, 101]), ((-45, 7.1), [101, 100])]) - def test_prefer_near(self, apiobj, coord, rids): - results = run_search(apiobj, 0.1, ['12345'], + def test_prefer_near(self, apiobj, frontend, coord, rids): + results = run_search(apiobj, frontend, 0.1, ['12345'], details=SearchDetails(near=napi.Point(*coord))) assert [r.place_id for r in results] == rids @pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)]) - def test_exclude(self, apiobj, pid, rid): - results = run_search(apiobj, 0.1, ['12345'], + def test_exclude(self, apiobj, frontend, pid, rid): + results = run_search(apiobj, frontend, 0.1, ['12345'], details=SearchDetails(excluded=[pid])) assert [r.place_id for r in results] == [rid] diff --git a/test/python/api/test_api_search.py b/test/python/api/test_api_search.py index aa263d24dd..22dbaa2642 100644 --- a/test/python/api/test_api_search.py +++ b/test/python/api/test_api_search.py @@ -19,6 +19,8 @@ import nominatim.api as napi import nominatim.api.logging as loglib +API_OPTIONS = {'search'} + @pytest.fixture(autouse=True) def setup_icu_tokenizer(apiobj): """ Setup the propoerties needed for using the ICU tokenizer. @@ -30,66 +32,62 @@ def setup_icu_tokenizer(apiobj): ]) -def test_search_no_content(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') +def test_search_no_content(apiobj, frontend): + apiobj.add_word_table([]) - assert apiobj.api.search('foo') == [] + api = frontend(apiobj, options=API_OPTIONS) + assert api.search('foo') == [] -def test_search_simple_word(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', - content=[(55, 'test', 'W', 'test', None), +def test_search_simple_word(apiobj, frontend): + apiobj.add_word_table([(55, 'test', 'W', 'test', None), (2, 'test', 'w', 'test', None)]) apiobj.add_placex(place_id=444, class_='place', type='village', centroid=(1.3, 0.7)) apiobj.add_search_name(444, names=[2, 55]) - results = apiobj.api.search('TEST') + api = frontend(apiobj, options=API_OPTIONS) + results = api.search('TEST') assert [r.place_id for r in results] == [444] @pytest.mark.parametrize('logtype', ['text', 'html']) -def test_search_with_debug(apiobj, table_factory, logtype): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', - content=[(55, 'test', 'W', 'test', None), +def test_search_with_debug(apiobj, frontend, logtype): + apiobj.add_word_table([(55, 'test', 'W', 'test', None), (2, 'test', 'w', 'test', None)]) apiobj.add_placex(place_id=444, class_='place', type='village', centroid=(1.3, 0.7)) apiobj.add_search_name(444, names=[2, 55]) + api = frontend(apiobj, options=API_OPTIONS) loglib.set_log_output(logtype) - results = apiobj.api.search('TEST') + results = api.search('TEST') assert loglib.get_and_disable() -def test_address_no_content(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') +def test_address_no_content(apiobj, frontend): + apiobj.add_word_table([]) - assert apiobj.api.search_address(amenity='hotel', - street='Main St 34', - city='Happyville', - county='Wideland', - state='Praerie', - postalcode='55648', - country='xx') == [] + api = frontend(apiobj, options=API_OPTIONS) + assert api.search_address(amenity='hotel', + street='Main St 34', + city='Happyville', + county='Wideland', + state='Praerie', + postalcode='55648', + country='xx') == [] 
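
The tests above swap the verbose `table_factory('word', ...)` setup for an `apiobj.add_word_table()` helper. The helper itself is defined in the shared test fixtures, which this excerpt does not show; the sketch below only illustrates its expected shape, assuming the test wrapper exposes a SQLAlchemy engine and mirroring the five columns of the old `table_factory()` definition:

```python
# Hypothetical sketch of the add_word_table() test helper; the real one
# lives in the test fixtures (conftest.py). 'apiobj.engine' is an assumed
# attribute, and sa.JSON stands in for the backend-specific JSONB/JSON
# type of the 'info' column.
import sqlalchemy as sa

def add_word_table(apiobj, content):
    """ Create the tokenizer's 'word' table and fill it with the given
        (word_id, word_token, type, word, info) rows.
    """
    table = sa.Table('word', sa.MetaData(),
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text),
                     sa.Column('type', sa.Text),
                     sa.Column('word', sa.Text),
                     sa.Column('info', sa.JSON))
    table.create(apiobj.engine)
    if content:
        keys = ('word_id', 'word_token', 'type', 'word', 'info')
        with apiobj.engine.begin() as conn:
            conn.execute(table.insert(),
                         [dict(zip(keys, row)) for row in content])
```
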
@pytest.mark.parametrize('atype,address,search', [('street', 26, 26), ('city', 16, 18), ('county', 12, 12), ('state', 8, 8)]) -def test_address_simple_places(apiobj, table_factory, atype, address, search): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', - content=[(55, 'test', 'W', 'test', None), +def test_address_simple_places(apiobj, frontend, atype, address, search): + apiobj.add_word_table([(55, 'test', 'W', 'test', None), (2, 'test', 'w', 'test', None)]) apiobj.add_placex(place_id=444, @@ -97,53 +95,51 @@ def test_address_simple_places(apiobj, table_factory, atype, address, search): centroid=(1.3, 0.7)) apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search) - results = apiobj.api.search_address(**{atype: 'TEST'}) + api = frontend(apiobj, options=API_OPTIONS) + results = api.search_address(**{atype: 'TEST'}) assert [r.place_id for r in results] == [444] -def test_address_country(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', - content=[(None, 'ro', 'C', 'ro', None)]) +def test_address_country(apiobj, frontend): + apiobj.add_word_table([(None, 'ro', 'C', 'ro', None)]) apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country_name('ro', {'name': 'România'}) - assert len(apiobj.api.search_address(country='ro')) == 1 + api = frontend(apiobj, options=API_OPTIONS) + assert len(api.search_address(country='ro')) == 1 -def test_category_no_categories(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') +def test_category_no_categories(apiobj, frontend): + apiobj.add_word_table([]) - assert apiobj.api.search_category([], near_query='Berlin') == [] + api = frontend(apiobj, options=API_OPTIONS) + assert api.search_category([], near_query='Berlin') == [] -def test_category_no_content(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') +def test_category_no_content(apiobj, frontend): + apiobj.add_word_table([]) - assert apiobj.api.search_category([('amenity', 'restaurant')]) == [] + api = frontend(apiobj, options=API_OPTIONS) + assert api.search_category([('amenity', 'restaurant')]) == [] -def test_category_simple_restaurant(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') +def test_category_simple_restaurant(apiobj, frontend): + apiobj.add_word_table([]) apiobj.add_placex(place_id=444, class_='amenity', type='restaurant', centroid=(1.3, 0.7)) apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18) - results = apiobj.api.search_category([('amenity', 'restaurant')], - near=(1.3, 0.701), near_radius=0.015) + api = frontend(apiobj, options=API_OPTIONS) + results = api.search_category([('amenity', 'restaurant')], + near=(1.3, 0.701), near_radius=0.015) assert [r.place_id for r in results] == [444] -def test_category_with_search_phrase(apiobj, table_factory): - table_factory('word', - definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB', - content=[(55, 'test', 'W', 'test', None), +def test_category_with_search_phrase(apiobj, frontend): + apiobj.add_word_table([(55, 'test', 'W', 'test', None), (2, 'test', 'w', 'test', None)]) apiobj.add_placex(place_id=444, class_='place', type='village', @@ -153,7 +149,7 @@ def test_category_with_search_phrase(apiobj, 
table_factory): apiobj.add_placex(place_id=95, class_='amenity', type='restaurant', centroid=(1.3, 0.7003)) - results = apiobj.api.search_category([('amenity', 'restaurant')], - near_query='TEST') + api = frontend(apiobj, options=API_OPTIONS) + results = api.search_category([('amenity', 'restaurant')], near_query='TEST') assert [r.place_id for r in results] == [95] From ff06b6432992831c597ef3e1ae340ddff424ce2d Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 6 Dec 2023 20:57:09 +0100 Subject: [PATCH 14/17] enable all BDD API tests for sqlite --- test/bdd/api/search/geocodejson.feature | 1 + test/bdd/api/search/language.feature | 1 + test/bdd/api/search/params.feature | 1 + test/bdd/api/search/postcode.feature | 1 + test/bdd/api/search/queries.feature | 1 + test/bdd/api/search/simple.feature | 1 + test/bdd/api/search/structured.feature | 1 + 7 files changed, 7 insertions(+) diff --git a/test/bdd/api/search/geocodejson.feature b/test/bdd/api/search/geocodejson.feature index b0ef92dacf..271ec10c16 100644 --- a/test/bdd/api/search/geocodejson.feature +++ b/test/bdd/api/search/geocodejson.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Parameters for Search API Testing correctness of geocodejson output. diff --git a/test/bdd/api/search/language.feature b/test/bdd/api/search/language.feature index b76adbef5b..fe14cdbe6c 100644 --- a/test/bdd/api/search/language.feature +++ b/test/bdd/api/search/language.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Localization of search results diff --git a/test/bdd/api/search/params.feature b/test/bdd/api/search/params.feature index d5512f5b66..e667b690b0 100644 --- a/test/bdd/api/search/params.feature +++ b/test/bdd/api/search/params.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Search queries Testing different queries and parameters diff --git a/test/bdd/api/search/postcode.feature b/test/bdd/api/search/postcode.feature index 81836efb57..e372f449a9 100644 --- a/test/bdd/api/search/postcode.feature +++ b/test/bdd/api/search/postcode.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Searches with postcodes Various searches involving postcodes diff --git a/test/bdd/api/search/queries.feature b/test/bdd/api/search/queries.feature index 847f1dbf02..eba903ea30 100644 --- a/test/bdd/api/search/queries.feature +++ b/test/bdd/api/search/queries.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Search queries Generic search result correctness diff --git a/test/bdd/api/search/simple.feature b/test/bdd/api/search/simple.feature index 11cd4801be..121271cdf1 100644 --- a/test/bdd/api/search/simple.feature +++ b/test/bdd/api/search/simple.feature @@ -1,3 +1,4 @@ +@SQLITE @APIDB Feature: Simple Tests Simple tests for internal server errors and response format. 
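
The `@SQLITE` tag added at the top of each feature file only takes effect if the BDD runner evaluates it. The hook doing so belongs in the behave environment setup, which is not part of this patch series; a minimal sketch of such a hook follows, where `context.nominatim.api_engine` is an assumed attribute naming the active backend:

```python
# Hypothetical behave hook honouring the @SQLITE tag; the attribute names
# are assumptions for illustration, not the project's actual environment
# code.
def before_feature(context, feature):
    """ Skip features that lack the SQLITE tag when the test suite runs
        against an SQLite database instead of PostgreSQL.
    """
    if context.nominatim.api_engine == 'sqlite' and 'SQLITE' not in feature.tags:
        feature.skip("Feature not supported on SQLite")
```
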
diff --git a/test/bdd/api/search/structured.feature b/test/bdd/api/search/structured.feature
index 517c0eddd2..a1dd5b83d4 100644
--- a/test/bdd/api/search/structured.feature
+++ b/test/bdd/api/search/structured.feature
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Structured search queries
     Testing correctness of results with

From 3f5484f48fb3ab132f877d1d5a3e42cdaf274f07 Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Thu, 7 Dec 2023 09:33:42 +0100
Subject: [PATCH 15/17] enable search for sqlite conversion by default

---
 nominatim/clicmd/convert.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nominatim/clicmd/convert.py b/nominatim/clicmd/convert.py
index 26b3fb1ffe..7ba77172bd 100644
--- a/nominatim/clicmd/convert.py
+++ b/nominatim/clicmd/convert.py
@@ -76,7 +76,7 @@ def add_args(self, parser: argparse.ArgumentParser) -> None:
         group.add_argument('--reverse', action=WithAction, dest_set=self.options, default=True,
                            help='Enable/disable support for reverse and lookup API'
                                 ' (default: enabled)')
-        group.add_argument('--search', action=WithAction, dest_set=self.options, default=False,
-                           help='Enable/disable support for search API (default: disabled)')
+        group.add_argument('--search', action=WithAction, dest_set=self.options, default=True,
+                           help='Enable/disable support for search API (default: enabled)')
         group.add_argument('--details', action=WithAction, dest_set=self.options, default=True,
                            help='Enable/disable support for details API (default: enabled)')

From 89094cf92e82ebb1004498fad8822975f1af4347 Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Thu, 7 Dec 2023 10:24:53 +0100
Subject: [PATCH 16/17] error out when a SQLite database does not exist

Requires marking the database r/w when it is newly created
in the convert function.
---
 nominatim/api/core.py             | 4 ++++
 nominatim/tools/convert_sqlite.py | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/nominatim/api/core.py b/nominatim/api/core.py
index f975f44aae..1c0c4423fc 100644
--- a/nominatim/api/core.py
+++ b/nominatim/api/core.py
@@ -101,6 +101,10 @@ async def setup_database(self) -> None:
             dburl = sa.engine.URL.create('sqlite+aiosqlite',
                                          database=params.get('dbname'))

+            if not ('NOMINATIM_DATABASE_RW' in self.config.environ
+                    and self.config.get_bool('DATABASE_RW')) \
+               and not Path(params.get('dbname', '')).is_file():
+                raise UsageError(f"SQLite database '{params.get('dbname')}' does not exist.")
         else:
             dsn = self.config.get_database_params()
             query = {k: v for k, v in dsn.items()
diff --git a/nominatim/tools/convert_sqlite.py b/nominatim/tools/convert_sqlite.py
index 16139c5fbc..3e5847107e 100644
--- a/nominatim/tools/convert_sqlite.py
+++ b/nominatim/tools/convert_sqlite.py
@@ -28,7 +28,8 @@ async def convert(project_dir: Path, outfile: Path, options: Set[str]) -> None:
     try:
         outapi = napi.NominatimAPIAsync(project_dir,
-                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}"})
+                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}",
+                                         'NOMINATIM_DATABASE_RW': '1'})

         try:
             async with api.begin() as src, outapi.begin() as dest:

From ab45db5360aab95bd2ef1f09a87fa1c105a4fa46 Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Sat, 9 Dec 2023 16:30:31 +0100
Subject: [PATCH 17/17] add minimal documentation for SQLite usage

---
 docs/customize/SQLite.md | 55 ++++++++++++++++++++++++++++++++++++++++
 docs/mkdocs.yml          |  1 +
 2 files changed, 56 insertions(+)
 create mode 100644 docs/customize/SQLite.md

diff --git a/docs/customize/SQLite.md b/docs/customize/SQLite.md
new file mode 100644
index 0000000000..9614feabb8
--- /dev/null
+++ b/docs/customize/SQLite.md
@@ -0,0 +1,55 @@
+A Nominatim database can be converted into an SQLite database and used as
+a read-only source for geocoding queries. This section describes how to
+create and use an SQLite database.
+
+!!! danger
+    This feature is in an experimental state at the moment. Use at your own
+    risk.
+
+## Installing prerequisites
+
+To use an SQLite database, you need to install:
+
+* SQLite (>= 3.30)
+* Spatialite (> 5.0.0)
+
+On Ubuntu/Debian, you can run:
+
+    sudo apt install sqlite3 libsqlite3-mod-spatialite libspatialite7
+
+## Creating a new SQLite database
+
+Nominatim cannot import directly into an SQLite database. Instead, you first
+have to create a geocoding database in PostgreSQL by running a
+[regular Nominatim import](../admin/Import.md).
+
+Once this is done, the database can be converted to SQLite with
+
+    nominatim convert -o mydb.sqlite
+
+This will create a database where all geocoding functions are available.
+Depending on what functions you need, the database can be made smaller:
+
+* `--without-reverse` omits indexes only needed for reverse geocoding
+* `--without-search` omits tables and indexes used for forward search
+* `--without-details` leaves out extra information only available in the
+  details API
+
+## Using an SQLite database
+
+Once you have created the database, you can use it by simply pointing the
+database DSN to the SQLite file:
+
+    NOMINATIM_DATABASE_DSN=sqlite:dbname=mydb.sqlite
+
+Please note that SQLite support is only available for the Python frontend. To
+use the test server with an SQLite database, you therefore need to switch
+the frontend engine:
+
+    nominatim serve --engine falcon
+
+You need to install falcon or starlette for this, depending on which engine
+you choose.
+
+The CLI query commands and the library interface already use the new Python
+frontend and therefore work right out of the box.
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 3301356d71..f332640ff9 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -40,6 +40,7 @@ nav:
         - 'Special Phrases': 'customize/Special-Phrases.md'
         - 'External data: US housenumbers from TIGER': 'customize/Tiger.md'
         - 'External data: Postcodes': 'customize/Postcodes.md'
+        - 'Conversion to SQLite': 'customize/SQLite.md'
     - 'Library Guide':
         - 'Getting Started': 'library/Getting-Started.md'
         - 'Nominatim API class': 'library/NominatimAPI.md'
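
The library interface mentioned at the end of the new documentation can be exercised in a few lines. A hedged example follows, assuming the synchronous `NominatimAPI` class accepts the same environment overrides as the `NominatimAPIAsync` constructor used in `convert_sqlite.py` above; `mydb.sqlite` is the file from the documentation example:

```python
# Minimal sketch: forward search against a converted SQLite database.
# The environment-override dict mirrors the NominatimAPIAsync call in
# convert_sqlite.py; whether the synchronous NominatimAPI takes it the
# same way is an assumption here.
from pathlib import Path

import nominatim.api as napi

api = napi.NominatimAPI(Path('.'),
                        {'NOMINATIM_DATABASE_DSN': 'sqlite:dbname=mydb.sqlite'})

for result in api.search('Berlin'):
    print(result.place_id, result.accuracy)
```

If the file does not exist, the check introduced in [PATCH 16/17] raises a `UsageError` instead of silently creating an empty database.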