Skip to content

Commit

Permalink
Return the search terms as search highlights for SQLite instead of no…
Browse files Browse the repository at this point in the history
…thing
  • Loading branch information
mlaily committed Mar 17, 2024
1 parent 52f456a commit b4dbcc5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog.d/17000.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed search feature of Element Android on homesevers using SQLite by returning search terms as search highlights.
28 changes: 23 additions & 5 deletions synapse/storage/databases/main/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ async def search_msgs(
count_args = args
count_clauses = clauses

sqlite_highlights = []

if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
sql = """
Expand All @@ -486,7 +488,7 @@ async def search_msgs(
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_for_sqlite(search_term)
search_query, sqlite_highlights = _parse_query_for_sqlite(search_term)

sql = """
SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
Expand Down Expand Up @@ -534,6 +536,8 @@ async def search_msgs(
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
else:
highlights = sqlite_highlights

count_sql += " GROUP BY room_id"

Expand Down Expand Up @@ -597,6 +601,8 @@ async def search_rooms(
count_args = list(args)
count_clauses = list(clauses)

sqlite_highlights = []

if pagination_token:
try:
origin_server_ts_str, stream_str = pagination_token.split(",")
Expand Down Expand Up @@ -647,7 +653,7 @@ async def search_rooms(
CROSS JOIN events USING (event_id)
WHERE
"""
search_query = _parse_query_for_sqlite(search_term)
search_query, sqlite_highlights = _parse_query_for_sqlite(search_term)
args = [search_query] + args

count_sql = """
Expand Down Expand Up @@ -697,6 +703,8 @@ async def search_rooms(
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
else:
highlights = sqlite_highlights

count_sql += " GROUP BY room_id"

Expand Down Expand Up @@ -892,19 +900,25 @@ def _tokenize_query(query: str) -> TokenList:
return tokens


def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
def _tokens_to_sqlite_match_query(tokens: TokenList) -> Tuple[str, List[str]]:
"""
Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
Assume sqlite was compiled with enhanced query syntax.
Returns the sqlite-formatted query string and the tokenized search terms
that can be used as highlights.
Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
"""
match_query = []
highlights = []
for token in tokens:
if isinstance(token, str):
match_query.append(token)
highlights.append(token)
elif isinstance(token, Phrase):
match_query.append('"' + " ".join(token.phrase) + '"')
highlights.append(" ".join(token.phrase))
elif token == SearchToken.Not:
# TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
# term has already been added before this.
Expand All @@ -916,11 +930,15 @@ def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
else:
raise ValueError(f"unknown token {token}")

return "".join(match_query)
return "".join(match_query), highlights


def _parse_query_for_sqlite(search_term: str) -> str:
def _parse_query_for_sqlite(search_term: str) -> Tuple[str, List[str]]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to sqllite's matchinfo().
Returns the converted query string and the tokenized search terms
that can be used as highlights.
"""
return _tokens_to_sqlite_match_query(_tokenize_query(search_term))

13 changes: 6 additions & 7 deletions tests/storage/test_room_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,16 @@ def test_null_byte(self) -> None:
store.search_msgs([room_id], "hi bob", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("hi", result.get("highlights"))
self.assertIn("bob", result.get("highlights"))
self.assertIn("hi", result.get("highlights"))
self.assertIn("bob", result.get("highlights"))

# Check that search works for an unrelated message
result = self.get_success(
store.search_msgs([room_id], "another", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("another", result.get("highlights"))

self.assertIn("another", result.get("highlights"))

# Check that search works for a search term that overlaps with the message
# containing a null byte and an unrelated message.
Expand All @@ -90,8 +89,8 @@ def test_null_byte(self) -> None:
result = self.get_success(
store.search_msgs([room_id], "hi alice", ["content.body"])
)
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights"))

self.assertIn("alice", result.get("highlights"))

def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
Expand Down

0 comments on commit b4dbcc5

Please sign in to comment.