Skip to content

Commit

Permalink
finalize EmbeddingBasedClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
NickyHavoc committed Oct 27, 2023
1 parent 6572d90 commit fb525b3
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RetrieverType(Enum):
SYMMETRIC = (SemanticRepresentation.Symmetric, SemanticRepresentation.Symmetric)


class InMemoryRetriever(BaseRetriever):
class QdrantInMemoryRetriever(BaseRetriever):
"""Search through documents stored in memory using semantic search.
This retriever uses a [Qdrant](https://github.com/qdrant/qdrant)-in-Memory vector store instance to store documents and their asymmetric embeddings.
Expand All @@ -49,7 +49,7 @@ class InMemoryRetriever(BaseRetriever):
Example:
>>> client = Client(os.getenv("AA_TOKEN"))
>>> documents = [Document(text=t) for t in ["I do not like rain.", "Summer is warm.", "We are so back."]]
>>> retriever = InMemoryRetriever(client, documents)
>>> retriever = QdrantInMemoryRetriever(client, documents)
>>> query = "Do you like summer?"
>>> documents = retriever.get_relevant_documents_with_scores(query)
"""
Expand All @@ -69,8 +69,7 @@ def __init__(
self._collection_name = "in_memory_collection"
self._k = k
self._threshold = threshold
self._query_representation = retriever_type.value[0]
self._document_representation = retriever_type.value[1]
self._query_representation, self._document_representation = retriever_type.value

self._search_client.recreate_collection(
collection_name=self._collection_name,
Expand Down
8 changes: 1 addition & 7 deletions src/intelligence_layer/use_cases/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ class ClassifyOutput(BaseModel):
scores: Mapping[str, Probability]


class Classify(Task[ClassifyInput, ClassifyOutput]):
"""Placeholder class for any classifier implementation."""

pass


class ClassifyEvaluation(BaseModel):
"""The evaluation of a single label classification run.
Expand Down Expand Up @@ -72,7 +66,7 @@ class ClassifyEvaluator(
AggregatedClassifyEvaluation,
]
):
def __init__(self, task: Classify):
def __init__(self, task: Task[ClassifyInput, ClassifyOutput]):
self.task = task

def evaluate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@
from qdrant_client.http.models import models

from intelligence_layer.connectors.retrievers.base_retriever import Document
from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
RetrieverType,
)
from intelligence_layer.core.logger import DebugLogger
from intelligence_layer.core.task import Chunk, Probability
from intelligence_layer.core.task import Chunk, Probability, Task
from intelligence_layer.use_cases.classify.classify import (
Classify,
ClassifyInput,
ClassifyOutput,
)
from intelligence_layer.use_cases.search.filter_search import (
FilterSearch,
FilterSearchInput,
from intelligence_layer.use_cases.search.qdrant_search import (
QdrantSearch,
QdrantSearchInput,
)
from intelligence_layer.use_cases.search.search import SearchOutput

Expand All @@ -38,19 +37,7 @@ class LabelWithExamples(BaseModel):
examples: Sequence[str]


class EmbeddingBasedClassifyScoring(Enum):
"""Specify the type of scoring to use.
Attributes:
MAX: Takes the mean of the top match, i.e., the max.
MEAN_TOP_5: Takes the mean of the top 5 matches.
"""

MAX = 1
MEAN_TOP_5 = 5


class EmbeddingBasedClassify(Classify):
class EmbeddingBasedClassify(Task[ClassifyInput, ClassifyOutput]):
"""Task that classifies a given input text based on examples.
The input contains a complete set of all possible labels. The output will return a score
Expand Down Expand Up @@ -102,35 +89,28 @@ def __init__(
self,
labels_with_examples: Sequence[LabelWithExamples],
client: Client,
scoring: EmbeddingBasedClassifyScoring = EmbeddingBasedClassifyScoring.MEAN_TOP_5,
top_k_per_label: int = 5,
) -> None:
super().__init__()
self._labels_with_examples = labels_with_examples
documents = self._labels_with_examples_to_documents(labels_with_examples)
self._scoring = scoring
retriever = InMemoryRetriever(
self._scoring = top_k_per_label
retriever = QdrantInMemoryRetriever(
client,
documents=documents,
k=scoring.value,
k=top_k_per_label,
retriever_type=RetrieverType.SYMMETRIC,
)
self._filter_search = FilterSearch(retriever)
self._filter_search = QdrantSearch(retriever)

def run(self, input: ClassifyInput, logger: DebugLogger) -> ClassifyOutput:
available_labels = set(
class_with_examples.name
for class_with_examples in self._labels_with_examples
)
unknown_labels = input.labels - available_labels
if unknown_labels:
raise ValueError(f"Got unexpected labels: {', '.join(unknown_labels)}.")
labels = list(input.labels) # converting to list to preserve order
self._validate_input_labels(input)
results_per_label = [
self._label_search(input.chunk, label, logger) for label in labels
self._label_search(input.chunk, label, logger) for label in input.labels
]
scores = self._calculate_scores(results_per_label)
return ClassifyOutput(
scores={l: Probability(s) for l, s in zip(labels, scores)}
scores={l: Probability(s) for l, s in zip(input.labels, scores)}
)

def _labels_with_examples_to_documents(
Expand All @@ -144,10 +124,19 @@ def _labels_with_examples_to_documents(
for e in class_with_examples.examples
]

def _validate_input_labels(self, input: ClassifyInput) -> None:
available_labels = set(
class_with_examples.name
for class_with_examples in self._labels_with_examples
)
unknown_labels = input.labels - available_labels
if unknown_labels:
raise ValueError(f"Got unexpected labels: {', '.join(unknown_labels)}.")

def _label_search(
self, chunk: Chunk, label: str, logger: DebugLogger
) -> SearchOutput:
search_input = FilterSearchInput(
search_input = QdrantSearchInput(
query=chunk,
filter=models.Filter(
must=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
)
from intelligence_layer.core.echo import EchoInput, EchoTask, TokenWithProb
from intelligence_layer.core.logger import DebugLogger
from intelligence_layer.core.task import Probability, Token
from intelligence_layer.core.task import Probability, Task, Token
from intelligence_layer.use_cases.classify.classify import (
Classify,
ClassifyInput,
ClassifyOutput,
)
Expand All @@ -30,7 +29,7 @@ def to_aa_tokens_prompt(tokens: Sequence[Token]) -> Prompt:
return Prompt.from_tokens([token.token_id for token in tokens])


class SingleLabelClassify(Classify):
class SingleLabelClassify(Task[ClassifyInput, ClassifyOutput]):
"""Task that classifies a given input text with one of the given classes.
The input contains a complete set of all possible labels. The output will return a score for
Expand Down
6 changes: 3 additions & 3 deletions src/intelligence_layer/use_cases/qa/long_context_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
MultipleChunkQaInput,
MultipleChunkQaOutput,
)
from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)
from intelligence_layer.use_cases.search.search import Search, SearchInput
from intelligence_layer.core.task import Chunk, Task
Expand Down Expand Up @@ -76,7 +76,7 @@ def run(
) -> MultipleChunkQaOutput:
chunks = self._chunk(input.text)
logger.log("chunks", chunks)
retriever = InMemoryRetriever(
retriever = QdrantInMemoryRetriever(
self._client,
documents=[Document(text=c) for c in chunks],
k=self._k,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pydantic import BaseModel
from qdrant_client.http.models.models import Filter

from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)
from intelligence_layer.core.task import Task
from intelligence_layer.core.logger import DebugLogger
from intelligence_layer.use_cases.search.search import SearchOutput


class FilterSearchInput(BaseModel):
class QdrantSearchInput(BaseModel):
"""The input for a `FilterSearch` task.
Attributes:
Expand All @@ -21,7 +21,7 @@ class FilterSearchInput(BaseModel):
filter: Filter


class FilterSearch(Task[FilterSearchInput, SearchOutput]):
class QdrantSearch(Task[QdrantSearchInput, SearchOutput]):
"""Performs search to find documents using QDrant filtering methods.
Given a query, this task will utilize a retriever to fetch relevant text search results.
Expand Down Expand Up @@ -55,11 +55,11 @@ class FilterSearch(Task[FilterSearchInput, SearchOutput]):
>>> output = task.run(input, logger)
"""

def __init__(self, in_memory_retriever: InMemoryRetriever):
def __init__(self, in_memory_retriever: QdrantInMemoryRetriever):
super().__init__()
self._in_memory_retriever = in_memory_retriever

def run(self, input: FilterSearchInput, logger: DebugLogger) -> SearchOutput:
def run(self, input: QdrantSearchInput, logger: DebugLogger) -> SearchOutput:
results = self._in_memory_retriever.get_filtered_documents_with_scores(
input.query, input.filter
)
Expand Down
24 changes: 12 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
DocumentIndexRetriever,
)
from intelligence_layer.connectors.document_index import DocumentIndex
from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
RetrieverType,
)

Expand Down Expand Up @@ -45,24 +45,24 @@ def prompt_image() -> Image:


@fixture
def asymmetric_in_memory_retriever(
client: Client, in_memory_retriever_documents: Sequence[Document]
) -> InMemoryRetriever:
return InMemoryRetriever(
def asymmetric_qdrant_in_memory_retriever(
client: Client, qdrant_in_memory_retriever_documents: Sequence[Document]
) -> QdrantInMemoryRetriever:
return QdrantInMemoryRetriever(
client,
in_memory_retriever_documents,
qdrant_in_memory_retriever_documents,
k=2,
retriever_type=RetrieverType.ASYMMETRIC,
)


@fixture
def symmetric_in_memory_retriever(
client: Client, in_memory_retriever_documents: Sequence[Document]
) -> InMemoryRetriever:
return InMemoryRetriever(
def symmetric_qdrant_in_memory_retriever(
client: Client, qdrant_in_memory_retriever_documents: Sequence[Document]
) -> QdrantInMemoryRetriever:
return QdrantInMemoryRetriever(
client,
in_memory_retriever_documents,
qdrant_in_memory_retriever_documents,
k=2,
retriever_type=RetrieverType.SYMMETRIC,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pytest import fixture
from intelligence_layer.connectors.retrievers.base_retriever import Document

from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)


Expand All @@ -18,7 +18,7 @@ def in_memory_retriever_documents() -> Sequence[Document]:


def test_asymmetric_in_memory(
asymmetric_in_memory_retriever: InMemoryRetriever,
asymmetric_in_memory_retriever: QdrantInMemoryRetriever,
in_memory_retriever_documents: Sequence[Document],
) -> None:
query = "Do you like summer?"
Expand All @@ -28,7 +28,7 @@ def test_asymmetric_in_memory(


def test_symmetric_in_memory(
symmetric_in_memory_retriever: InMemoryRetriever,
symmetric_in_memory_retriever: QdrantInMemoryRetriever,
in_memory_retriever_documents: Sequence[Document],
) -> None:
query = "I hate drizzle"
Expand Down
6 changes: 3 additions & 3 deletions tests/use_cases/qa/test_retriever_based_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from intelligence_layer.connectors.retrievers.document_index_retriever import (
DocumentIndexRetriever,
)
from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)
from intelligence_layer.core.logger import NoOpDebugLogger
from intelligence_layer.use_cases.qa.retriever_based_qa import (
Expand Down Expand Up @@ -37,7 +37,7 @@ def in_memory_retriever_documents() -> Sequence[Document]:

@fixture
def retriever_based_qa_with_in_memory_retriever(
client: Client, asymmetric_in_memory_retriever: InMemoryRetriever
client: Client, asymmetric_in_memory_retriever: QdrantInMemoryRetriever
) -> RetrieverBasedQa:
return RetrieverBasedQa(
client, asymmetric_in_memory_retriever, model="luminous-base-control"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from qdrant_client.http.models import models

from intelligence_layer.connectors.retrievers.base_retriever import Document
from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)
from intelligence_layer.core.logger import NoOpDebugLogger
from intelligence_layer.use_cases.search.filter_search import (
FilterSearch,
FilterSearchInput,
from intelligence_layer.use_cases.search.qdrant_search import (
QdrantSearch,
QdrantSearchInput,
)


Expand All @@ -33,16 +33,16 @@ def in_memory_retriever_documents() -> Sequence[Document]:


@fixture
def filter_search(asymmetric_in_memory_retriever: InMemoryRetriever) -> FilterSearch:
return FilterSearch(asymmetric_in_memory_retriever)
def filter_search(asymmetric_in_memory_retriever: QdrantInMemoryRetriever) -> QdrantSearch:
return QdrantSearch(asymmetric_in_memory_retriever)


def test_filter_search(
filter_search: FilterSearch,
filter_search: QdrantSearch,
no_op_debug_logger: NoOpDebugLogger,
in_memory_retriever_documents: Sequence[Document],
) -> None:
search_input = FilterSearchInput(
search_input = QdrantSearchInput(
query="When did Germany reunite?",
filter=models.Filter(
must=[
Expand Down
6 changes: 3 additions & 3 deletions tests/use_cases/search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Sequence
from intelligence_layer.connectors.retrievers.base_retriever import Document

from intelligence_layer.connectors.retrievers.in_memory_retriever import (
InMemoryRetriever,
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
)
from intelligence_layer.core.logger import NoOpDebugLogger
from intelligence_layer.use_cases.search.search import Search, SearchInput
Expand All @@ -19,7 +19,7 @@ def in_memory_retriever_documents() -> Sequence[Document]:


@fixture
def search(asymmetric_in_memory_retriever: InMemoryRetriever) -> Search:
def search(asymmetric_in_memory_retriever: QdrantInMemoryRetriever) -> Search:
return Search(asymmetric_in_memory_retriever)


Expand Down

0 comments on commit fb525b3

Please sign in to comment.