Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove limit from FilterSearch #12

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def _add_texts_to_memory(self, documents: Sequence[Document]) -> None:
)

def get_filtered_documents_with_scores(
self, query: str, limit: int, filter: models.Filter
self, query: str, filter: models.Filter
) -> Sequence[SearchResult]:
"""Specific method for `InMemoryRetriever` to support filtering search results."""
query_embedding = self._embed(query, self._query_representation)
search_result = self._search_client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=limit,
limit=self._k,
query_filter=filter,
)
return [self._point_to_search_result(point) for point in search_result]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RetrieverType,
)
from intelligence_layer.core.logger import DebugLogger
from intelligence_layer.core.task import Chunk, Probability, Task
from intelligence_layer.core.task import Chunk, Probability
from intelligence_layer.use_cases.classify.classify import (
Classify,
ClassifyInput,
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
retriever = InMemoryRetriever(
client,
documents=documents,
k=len(documents),
k=scoring.value,
retriever_type=RetrieverType.SYMMETRIC,
)
self._filter_search = FilterSearch(retriever)
Expand Down Expand Up @@ -149,7 +149,6 @@ def _label_search(
) -> SearchOutput:
search_input = FilterSearchInput(
query=chunk,
limit=self._scoring.value,
filter=models.Filter(
must=[
models.FieldCondition(
Expand Down
5 changes: 1 addition & 4 deletions src/intelligence_layer/use_cases/search/filter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ class FilterSearchInput(BaseModel):

Attributes:
query: The text to be searched with.
limit: The maximum number of items to be retrieved.
filter: Conditions to filter by as offered by Qdrant.
"""

query: str
limit: int
filter: Filter


Expand All @@ -44,7 +42,6 @@ class FilterSearch(Task[FilterSearchInput, SearchOutput]):
>>> task = FilterSearch(retriever)
>>> input = FilterSearchInput(
>>> query="When did East and West Germany reunite?",
>>> limit=1,
>>> filter=models.Filter(
>>> must=[
>>> models.FieldCondition(
Expand All @@ -64,6 +61,6 @@ def __init__(self, in_memory_retriever: InMemoryRetriever):

def run(self, input: FilterSearchInput, logger: DebugLogger) -> SearchOutput:
results = self._in_memory_retriever.get_filtered_documents_with_scores(
input.query, input.limit, input.filter
input.query, input.filter
)
return SearchOutput(results=results)
33 changes: 33 additions & 0 deletions tests/use_cases/classify/test_embedding_based_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ def test_embedding_based_classify_raises_for_unknown_label(
embedding_based_classify.run(classify_input, NoOpDebugLogger())


def test_embedding_based_classify_works_for_empty_labels_in_request(
    embedding_based_classify: EmbeddingBasedClassify,
) -> None:
    """Running the classifier with an empty label set yields an empty score mapping."""
    request = ClassifyInput(chunk=Chunk("This is good"), labels=frozenset())
    output = embedding_based_classify.run(request, NoOpDebugLogger())
    # No labels were requested, so no scores are expected back.
    assert output.scores == {}


def test_embedding_based_classify_works_without_examples(
    client: Client,
) -> None:
    """A classifier built from labels that carry no examples returns empty scores
    when the request itself asks for no labels."""
    # Both labels are registered with an empty example list.
    example_free_labels = [
        LabelWithExamples(name=label_name, examples=[])
        for label_name in ("positive", "negative")
    ]
    classifier = EmbeddingBasedClassify(example_free_labels, client)
    request = ClassifyInput(chunk=Chunk("This is good"), labels=frozenset())
    output = classifier.run(request, NoOpDebugLogger())
    assert output.scores == {}


def test_can_evaluate_embedding_based_classify(
embedding_based_classify: EmbeddingBasedClassify,
) -> None:
Expand Down
5 changes: 2 additions & 3 deletions tests/use_cases/search/test_filter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def in_memory_retriever_documents() -> Sequence[Document]:
metadata={"type": "doc"},
),
Document(
text="Cats are small animals. Well, I do not fit at all but I am of the correct type.",
metadata={"type": "doc"},
text="Cats are small animals. Well, I do not fit at all and I am of the correct type.",
metadata={"type": "no doc"},
),
Document(
text="Germany reunited in 1990. This document fits perfectly but it is of the wrong type.",
Expand All @@ -44,7 +44,6 @@ def test_filter_search(
) -> None:
search_input = FilterSearchInput(
query="When did Germany reunite?",
limit=1,
filter=models.Filter(
must=[
models.FieldCondition(
Expand Down