Skip to content

Commit

Permalink
Use raw SQL query for get_subjects_intersection()
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovsky authored and drew2a committed Nov 21, 2022
1 parent fa5a55c commit 3619ca5
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions src/tribler/core/components/knowledge/db/knowledge_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, Iterator, List, Optional, Set

from pony import orm
from pony.orm import raw_sql
from pony.orm.core import Entity, Query, select
from pony.utils import between

Expand Down Expand Up @@ -358,39 +359,30 @@ def get_suggestions(self, subject_type: Optional[ResourceType] = None, subject:
def get_subjects_intersection(self, subjects_type: Optional[ResourceType], objects: Set[str],
predicate: Optional[ResourceType],
case_sensitive: bool = True) -> Set[str]:
"""Queries the subjects with the given objects and the predicate. Then made an intersection among them.
if not objects:
return set()

In Tribler, this method is mostly used for searching by tags.
Args:
subjects_type: a type of subjects.
objects: a set of strings that represents the objects.
predicate: the enum that represents a predicate of querying operations.
case_sensitive: if True, Resources are selected in a case-sensitive manner; if False, Resources
are selected in a case-insensitive manner.
Returns: a set of strings representing the subjects.
"""
if case_sensitive:
def name_condition(obj, obj_name):
return obj.name == obj_name
name_condition = '"obj"."name" = $obj_name'
else:
def name_condition(obj, obj_name):
return obj.name.lower() == obj_name.lower()

query = select(subject.name for subject in self.instance.Resource if subject.type == subjects_type.value)
for object_name in objects:
query = query.where(lambda subject: subject in (
s.subject for s in self.instance.Statement
if (s.local_operation == Operation.ADD.value
or not s.local_operation
and s.score >= SHOW_THRESHOLD
and s.object in (
obj for obj in self.instance.Resource
if (obj.type == predicate.value
and name_condition(obj, object_name))
))
))
name_condition = 'py_lower("obj"."name") = py_lower($obj_name)'

query = select(r.name for r in self.instance.Resource)
for obj_name in objects:
query = query.filter(raw_sql(f"""
r.id IN (
SELECT "s"."subject"
FROM "Statement" "s"
WHERE (
"s"."local_operation" = $(Operation.ADD.value)
OR
("s"."local_operation" = 0 OR "s"."local_operation" IS NULL)
AND ("s"."added_count" - "s"."removed_count") >= $SHOW_THRESHOLD
) AND "s"."object" IN (
SELECT "obj"."id" FROM "Resource" "obj"
WHERE "obj"."type" = $(predicate.value) AND {name_condition}
)
)"""))
return set(query)

def get_clock(self, operation: StatementOperation) -> int:
Expand Down

0 comments on commit 3619ca5

Please sign in to comment.