Skip to content

Commit

Permalink
Merge pull request #3260 from lonvia/improve-catgeory-search
Browse files Browse the repository at this point in the history
Various improvements to search with special phrases for Python frontend
  • Loading branch information
lonvia authored Nov 27, 2023
2 parents d6fe58f + a7f5c6c commit d8ed565
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 63 deletions.
15 changes: 8 additions & 7 deletions nominatim/api/search/db_search_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
Convertion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator
from typing import Optional, List, Tuple, Iterator, Dict
import heapq

from nominatim.api.types import SearchDetails, DataLayer
Expand Down Expand Up @@ -339,12 +339,13 @@ def get_search_categories(self,
Returns None if no category search is requested.
"""
if assignment.category:
tokens = [t for t in self.query.get_tokens(assignment.category,
TokenType.CATEGORY)
if not self.details.categories
or t.get_category() in self.details.categories]
return dbf.WeightedCategories([t.get_category() for t in tokens],
[t.penalty for t in tokens])
tokens: Dict[Tuple[str, str], float] = {}
for t in self.query.get_tokens(assignment.category, TokenType.CATEGORY):
cat = t.get_category()
if (not self.details.categories or cat in self.details.categories)\
and t.penalty < tokens.get(cat, 1000.0):
tokens[cat] = t.penalty
return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

if self.details.categories:
return dbf.WeightedCategories(self.details.categories,
Expand Down
15 changes: 11 additions & 4 deletions nominatim/api/search/db_search_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
Data structures for more complex fields in abstract search descriptions.
"""
from typing import List, Tuple, Iterator, cast
from typing import List, Tuple, Iterator, cast, Dict
import dataclasses

import sqlalchemy as sa
Expand Down Expand Up @@ -195,10 +195,17 @@ def set_qualifiers(self, tokens: List[Token]) -> None:
""" Set the qulaifier field from the given tokens.
"""
if tokens:
min_penalty = min(t.penalty for t in tokens)
categories: Dict[Tuple[str, str], float] = {}
min_penalty = 1000.0
for t in tokens:
if t.penalty < min_penalty:
min_penalty = t.penalty
cat = t.get_category()
if t.penalty < categories.get(cat, 1000.0):
categories[cat] = t.penalty
self.penalty += min_penalty
self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
[t.penalty - min_penalty for t in tokens])
self.qualifiers = WeightedCategories(list(categories.keys()),
list(categories.values()))


def set_ranking(self, rankings: List[FieldRanking]) -> None:
Expand Down
69 changes: 43 additions & 26 deletions nominatim/api/search/db_searches.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _select_placex(t: SaFromClause) -> SaSelect:
t.c.class_, t.c.type,
t.c.address, t.c.extratags,
t.c.housenumber, t.c.postcode, t.c.country_code,
t.c.importance, t.c.wikipedia,
t.c.wikipedia,
t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
t.c.linked_place_id, t.c.admin_level,
t.c.centroid,
Expand Down Expand Up @@ -158,7 +158,8 @@ async def _get_placex_housenumbers(conn: SearchConnection,
place_ids: List[int],
details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
t = conn.t.placex
sql = _select_placex(t).where(t.c.place_id.in_(place_ids))
sql = _select_placex(t).add_columns(t.c.importance)\
.where(t.c.place_id.in_(place_ids))

if details.geometry_output:
sql = _add_geometry_columns(sql, t.c.geometry, details)
Expand Down Expand Up @@ -255,9 +256,20 @@ async def lookup(self, conn: SearchConnection,

base.sort(key=lambda r: (r.accuracy, r.rank_search))
max_accuracy = base[0].accuracy + 0.5
if base[0].rank_address == 0:
min_rank = 0
max_rank = 0
elif base[0].rank_address < 26:
min_rank = 1
max_rank = min(25, base[0].rank_address + 4)
else:
min_rank = 26
max_rank = 30
base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
and r.accuracy <= max_accuracy
and r.bbox and r.bbox.area < 20)
and r.bbox and r.bbox.area < 20
and r.rank_address >= min_rank
and r.rank_address <= max_rank)

if base:
baseids = [b.place_id for b in base[:5] if b.place_id]
Expand All @@ -279,28 +291,37 @@ async def lookup_category(self, results: nres.SearchResults,
"""
table = await conn.get_class_table(*category)

t = conn.t.placex
tgeom = conn.t.placex.alias('pgeom')

sql = _select_placex(t).where(tgeom.c.place_id.in_(ids))\
.where(t.c.class_ == category[0])\
.where(t.c.type == category[1])

if table is None:
# No classtype table available, do a simplified lookup in placex.
sql = sql.join(tgeom, t.c.geometry.ST_DWithin(tgeom.c.centroid, 0.01))\
.order_by(tgeom.c.centroid.ST_Distance(t.c.centroid))
table = conn.t.placex.alias('inner')
sql = sa.select(table.c.place_id,
sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
.label('dist'))\
.join(tgeom, table.c.geometry.intersects(tgeom.c.centroid.ST_Expand(0.01)))\
.where(table.c.class_ == category[0])\
.where(table.c.type == category[1])
else:
# Use classtype table. We can afford to use a larger
# radius for the lookup.
sql = sql.join(table, t.c.place_id == table.c.place_id)\
.join(tgeom,
table.c.centroid.ST_CoveredBy(
sa.case((sa.and_(tgeom.c.rank_address > 9,
sql = sa.select(table.c.place_id,
sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
.label('dist'))\
.join(tgeom,
table.c.centroid.ST_CoveredBy(
sa.case((sa.and_(tgeom.c.rank_address > 9,
tgeom.c.geometry.is_area()),
tgeom.c.geometry),
else_ = tgeom.c.centroid.ST_Expand(0.05))))\
.order_by(tgeom.c.centroid.ST_Distance(table.c.centroid))
tgeom.c.geometry),
else_ = tgeom.c.centroid.ST_Expand(0.05))))

inner = sql.where(tgeom.c.place_id.in_(ids))\
.group_by(table.c.place_id).subquery()

t = conn.t.placex
sql = _select_placex(t).add_columns((-inner.c.dist).label('importance'))\
.join(inner, inner.c.place_id == t.c.place_id)\
.order_by(inner.c.dist)

sql = sql.where(no_index(t.c.rank_address).between(MIN_RANK_PARAM, MAX_RANK_PARAM))
if details.countries:
Expand Down Expand Up @@ -342,6 +363,8 @@ async def lookup(self, conn: SearchConnection,
# simply search in placex table
def _base_query() -> SaSelect:
return _select_placex(t) \
.add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
.label('importance'))\
.where(t.c.linked_place_id == None) \
.where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
.order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
Expand Down Expand Up @@ -370,6 +393,7 @@ def _base_query() -> SaSelect:
table = await conn.get_class_table(*category)
if table is not None:
sql = _select_placex(t)\
.add_columns(t.c.importance)\
.join(table, t.c.place_id == table.c.place_id)\
.where(t.c.class_ == category[0])\
.where(t.c.type == category[1])
Expand Down Expand Up @@ -415,6 +439,7 @@ async def lookup(self, conn: SearchConnection,

ccodes = self.countries.values
sql = _select_placex(t)\
.add_columns(t.c.importance)\
.where(t.c.country_code.in_(ccodes))\
.where(t.c.rank_address == 4)

Expand Down Expand Up @@ -591,15 +616,7 @@ async def lookup(self, conn: SearchConnection,
tsearch = conn.t.search_name

sql: SaLambdaSelect = sa.lambda_stmt(lambda:
sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
t.c.class_, t.c.type,
t.c.address, t.c.extratags, t.c.admin_level,
t.c.housenumber, t.c.postcode, t.c.country_code,
t.c.wikipedia,
t.c.parent_place_id, t.c.rank_address, t.c.rank_search,
t.c.centroid,
t.c.geometry.ST_Expand(0).label('bbox'))
.where(t.c.place_id == tsearch.c.place_id))
_select_placex(t).where(t.c.place_id == tsearch.c.place_id))


if details.geometry_output:
Expand Down
5 changes: 4 additions & 1 deletion nominatim/api/search/geocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
return

for result in results:
if not result.display_name:
# Negative importance indicates ordering by distance, which is
# more important than word matching.
if not result.display_name\
or (result.importance is not None and result.importance < 0):
continue
distance = 0.0
norm = self.query_analyzer.normalize_text(result.display_name)
Expand Down
14 changes: 9 additions & 5 deletions nominatim/api/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ class PhraseType(enum.Enum):
COUNTRY = enum.auto()
""" Contains the country name or code. """

def compatible_with(self, ttype: TokenType) -> bool:
def compatible_with(self, ttype: TokenType,
is_full_phrase: bool) -> bool:
""" Check if the given token type can be used with the phrase type.
"""
if self == PhraseType.NONE:
return True
return not is_full_phrase or ttype != TokenType.QUALIFIER
if self == PhraseType.AMENITY:
return ttype in (TokenType.WORD, TokenType.PARTIAL,
TokenType.QUALIFIER, TokenType.CATEGORY)
return ttype in (TokenType.WORD, TokenType.PARTIAL)\
or (is_full_phrase and ttype == TokenType.CATEGORY)\
or (not is_full_phrase and ttype == TokenType.QUALIFIER)
if self == PhraseType.STREET:
return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
if self == PhraseType.POSTCODE:
Expand Down Expand Up @@ -244,7 +246,9 @@ def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
be added to, then the token is silently dropped.
"""
snode = self.nodes[trange.start]
if snode.ptype.compatible_with(ttype):
full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
if snode.ptype.compatible_with(ttype, full_phrase):
tlist = snode.get_tokens(trange.end, ttype)
if tlist is None:
snode.starting.append(TokenList(trange.end, ttype, [token]))
Expand Down
37 changes: 35 additions & 2 deletions test/python/api/search/test_api_search_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def mktoken(tid: int):
('COUNTRY', 'COUNTRY'),
('POSTCODE', 'POSTCODE')])
def test_phrase_compatible(ptype, ttype):
assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype])
assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype], False)


@pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE'])
def test_phrase_incompatible(ptype):
assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL)
assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL, True)


def test_query_node_empty():
Expand Down Expand Up @@ -99,3 +99,36 @@ def test_query_struct_incompatible_token():

assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == []
assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1


def test_query_struct_amenity_single_word():
q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'bar')])
q.add_node(query.BreakType.END, query.PhraseType.NONE)

q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
q.add_token(query.TokenRange(0, 1), query.TokenType.CATEGORY, mktoken(2))
q.add_token(query.TokenRange(0, 1), query.TokenType.QUALIFIER, mktoken(3))

assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 1
assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 0


def test_query_struct_amenity_two_words():
q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'foo bar')])
q.add_node(query.BreakType.WORD, query.PhraseType.AMENITY)
q.add_node(query.BreakType.END, query.PhraseType.NONE)

for trange in [(0, 1), (1, 2)]:
q.add_token(query.TokenRange(*trange), query.TokenType.PARTIAL, mktoken(1))
q.add_token(query.TokenRange(*trange), query.TokenType.CATEGORY, mktoken(2))
q.add_token(query.TokenRange(*trange), query.TokenType.QUALIFIER, mktoken(3))

assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 0
assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 1

assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.PARTIAL)) == 1
assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.CATEGORY)) == 0
assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.QUALIFIER)) == 1

13 changes: 5 additions & 8 deletions test/python/api/search/test_db_search_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,18 @@ def get_category(self):


def make_query(*args):
q = None
q = QueryStruct([Phrase(PhraseType.NONE, '')])

for tlist in args:
if q is None:
q = QueryStruct([Phrase(PhraseType.NONE, '')])
else:
q.add_node(BreakType.WORD, PhraseType.NONE)
for _ in range(max(inner[0] for tlist in args for inner in tlist)):
q.add_node(BreakType.WORD, PhraseType.NONE)
q.add_node(BreakType.END, PhraseType.NONE)

start = len(q.nodes) - 1
for start, tlist in enumerate(args):
for end, ttype, tinfo in tlist:
for tid, word in tinfo:
q.add_token(TokenRange(start, end), ttype,
MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))

q.add_node(BreakType.END, PhraseType.NONE)

return q

Expand Down
16 changes: 6 additions & 10 deletions test/python/api/search/test_token_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,17 @@ def get_category(self):


def make_query(*args):
q = None
q = QueryStruct([Phrase(args[0][1], '')])
dummy = MyToken(3.0, 45, 1, 'foo', True)

for btype, ptype, tlist in args:
if q is None:
q = QueryStruct([Phrase(ptype, '')])
else:
q.add_node(btype, ptype)
for btype, ptype, _ in args[1:]:
q.add_node(btype, ptype)
q.add_node(BreakType.END, PhraseType.NONE)

start = len(q.nodes) - 1
for end, ttype in tlist:
for start, t in enumerate(args):
for end, ttype in t[2]:
q.add_token(TokenRange(start, end), ttype, dummy)

q.add_node(BreakType.END, PhraseType.NONE)

return q


Expand Down

0 comments on commit d8ed565

Please sign in to comment.