Skip to content

Commit

Permalink
Merge pull request #12 from aleph-alpha-intelligence-layer/iterate-on…
Browse files Browse the repository at this point in the history
…-mindset-classifier

Remove limit from FilterSearch
  • Loading branch information
NickyHavoc authored Oct 27, 2023
2 parents bf87d87 + b79bf20 commit 6572d90
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def _add_texts_to_memory(self, documents: Sequence[Document]) -> None:
)

def get_filtered_documents_with_scores(
self, query: str, limit: int, filter: models.Filter
self, query: str, filter: models.Filter
) -> Sequence[SearchResult]:
"""Specific method for `InMemoryRetriever` to support filtering search results."""
query_embedding = self._embed(query, self._query_representation)
search_result = self._search_client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=limit,
limit=self._k,
query_filter=filter,
)
return [self._point_to_search_result(point) for point in search_result]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RetrieverType,
)
from intelligence_layer.core.logger import DebugLogger
from intelligence_layer.core.task import Chunk, Probability, Task
from intelligence_layer.core.task import Chunk, Probability
from intelligence_layer.use_cases.classify.classify import (
Classify,
ClassifyInput,
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
retriever = InMemoryRetriever(
client,
documents=documents,
k=len(documents),
k=scoring.value,
retriever_type=RetrieverType.SYMMETRIC,
)
self._filter_search = FilterSearch(retriever)
Expand Down Expand Up @@ -149,7 +149,6 @@ def _label_search(
) -> SearchOutput:
search_input = FilterSearchInput(
query=chunk,
limit=self._scoring.value,
filter=models.Filter(
must=[
models.FieldCondition(
Expand Down
5 changes: 1 addition & 4 deletions src/intelligence_layer/use_cases/search/filter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ class FilterSearchInput(BaseModel):
Attributes:
query: The text to be searched with.
limit: The maximum number of items to be retrieved.
filter: Conditions to filter by as offered by Qdrant.
"""

query: str
limit: int
filter: Filter


Expand All @@ -44,7 +42,6 @@ class FilterSearch(Task[FilterSearchInput, SearchOutput]):
>>> task = FilterSearch(retriever)
>>> input = FilterSearchInput(
>>> query="When did East and West Germany reunite?"
>>> limit=1,
>>> filter=models.Filter(
>>> must=[
>>> models.FieldCondition(
Expand All @@ -64,6 +61,6 @@ def __init__(self, in_memory_retriever: InMemoryRetriever):

def run(self, input: FilterSearchInput, logger: DebugLogger) -> SearchOutput:
results = self._in_memory_retriever.get_filtered_documents_with_scores(
input.query, input.limit, input.filter
input.query, input.filter
)
return SearchOutput(results=results)
33 changes: 33 additions & 0 deletions tests/use_cases/classify/test_embedding_based_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ def test_embedding_based_classify_raises_for_unknown_label(
embedding_based_classify.run(classify_input, NoOpDebugLogger())


def test_embedding_based_classify_works_for_empty_labels_in_request(
embedding_based_classify: EmbeddingBasedClassify,
) -> None:
classify_input = ClassifyInput(
chunk=Chunk("This is good"),
labels=frozenset(),
)
result = embedding_based_classify.run(classify_input, NoOpDebugLogger())
assert result.scores == {}


def test_embedding_based_classify_works_without_examples(
client: Client,
) -> None:
labels_with_examples = [
LabelWithExamples(
name="positive",
examples=[],
),
LabelWithExamples(
name="negative",
examples=[],
),
]
embedding_based_classify = EmbeddingBasedClassify(labels_with_examples, client)
classify_input = ClassifyInput(
chunk=Chunk("This is good"),
labels=frozenset(),
)
result = embedding_based_classify.run(classify_input, NoOpDebugLogger())
assert result.scores == {}


def test_can_evaluate_embedding_based_classify(
embedding_based_classify: EmbeddingBasedClassify,
) -> None:
Expand Down
5 changes: 2 additions & 3 deletions tests/use_cases/search/test_filter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def in_memory_retriever_documents() -> Sequence[Document]:
metadata={"type": "doc"},
),
Document(
text="Cats are small animals. Well, I do not fit at all but I am of the correct type.",
metadata={"type": "doc"},
text="Cats are small animals. Well, I do not fit at all and I am of the correct type.",
metadata={"type": "no doc"},
),
Document(
text="Germany reunited in 1990. This document fits perfectly but it is of the wrong type.",
Expand All @@ -44,7 +44,6 @@ def test_filter_search(
) -> None:
search_input = FilterSearchInput(
query="When did Germany reunite?",
limit=1,
filter=models.Filter(
must=[
models.FieldCondition(
Expand Down

0 comments on commit 6572d90

Please sign in to comment.