diff --git a/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py b/src/intelligence_layer/connectors/retrievers/qdrant_in_memory_retriever.py similarity index 95% rename from src/intelligence_layer/connectors/retrievers/in_memory_retriever.py rename to src/intelligence_layer/connectors/retrievers/qdrant_in_memory_retriever.py index 8a18d9807..a2d557e23 100644 --- a/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py +++ b/src/intelligence_layer/connectors/retrievers/qdrant_in_memory_retriever.py @@ -31,7 +31,7 @@ class RetrieverType(Enum): SYMMETRIC = (SemanticRepresentation.Symmetric, SemanticRepresentation.Symmetric) -class InMemoryRetriever(BaseRetriever): +class QdrantInMemoryRetriever(BaseRetriever): """Search through documents stored in memory using semantic search. This retriever uses a [Qdrant](https://github.com/qdrant/qdrant)-in-Memory vector store instance to store documents and their asymmetric embeddings. @@ -49,7 +49,7 @@ class InMemoryRetriever(BaseRetriever): Example: >>> client = Client(os.getenv("AA_TOKEN")) >>> documents = [Document(text=t) for t in ["I do not like rain.", "Summer is warm.", "We are so back."]] - >>> retriever = InMemoryRetriever(client, documents) + >>> retriever = QdrantInMemoryRetriever(client, documents) >>> query = "Do you like summer?" >>> documents = retriever.get_relevant_documents_with_scores(query) """ @@ -69,8 +69,7 @@ def __init__( self._collection_name = "in_memory_collection" self._k = k self._threshold = threshold - self._query_representation = retriever_type.value[0] - self._document_representation = retriever_type.value[1] + self._query_representation, self._document_representation = retriever_type.value self._search_client.recreate_collection( collection_name=self._collection_name, diff --git a/src/intelligence_layer/use_cases/classify/classify.py b/src/intelligence_layer/use_cases/classify/classify.py index fee788a4f..66cce39dd 100644 --- a/src/intelligence_layer/use_cases/classify/classify.py +++ b/src/intelligence_layer/use_cases/classify/classify.py @@ -34,12 +34,6 @@ class ClassifyOutput(BaseModel): scores: Mapping[str, Probability] -class Classify(Task[ClassifyInput, ClassifyOutput]): - """Placeholder class for any classifier implementation.""" - - pass - - class ClassifyEvaluation(BaseModel): """The evaluation of a single label classification run. @@ -72,7 +66,7 @@ class ClassifyEvaluator( AggregatedClassifyEvaluation, ] ): - def __init__(self, task: Classify): + def __init__(self, task: Task[ClassifyInput, ClassifyOutput]): self.task = task def evaluate( diff --git a/src/intelligence_layer/use_cases/classify/embedding_based_classify.py b/src/intelligence_layer/use_cases/classify/embedding_based_classify.py index 7116dec16..2d56bc8f9 100644 --- a/src/intelligence_layer/use_cases/classify/embedding_based_classify.py +++ b/src/intelligence_layer/use_cases/classify/embedding_based_classify.py @@ -7,20 +7,19 @@ from qdrant_client.http.models import models from intelligence_layer.connectors.retrievers.base_retriever import Document -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, RetrieverType, ) from intelligence_layer.core.logger import DebugLogger -from intelligence_layer.core.task import Chunk, Probability +from intelligence_layer.core.task import Chunk, Probability, Task from intelligence_layer.use_cases.classify.classify import ( - Classify, ClassifyInput, ClassifyOutput, ) -from intelligence_layer.use_cases.search.filter_search import ( - FilterSearch, - FilterSearchInput, +from intelligence_layer.use_cases.search.qdrant_search import ( + QdrantSearch, + QdrantSearchInput, ) from intelligence_layer.use_cases.search.search import SearchOutput @@ -38,19 +37,7 @@ class LabelWithExamples(BaseModel): examples: Sequence[str] -class EmbeddingBasedClassifyScoring(Enum): - """Specify the type of scoring to use. - - Attributes: - MAX: Takes the mean of the top match, i.e., the max. - MEAN_TOP_5: Takes the mean of the top 5 matches. - """ - - MAX = 1 - MEAN_TOP_5 = 5 - - -class EmbeddingBasedClassify(Classify): +class EmbeddingBasedClassify(Task[ClassifyInput, ClassifyOutput]): """Task that classifies a given input text based on examples. The input contains a complete set of all possible labels. The output will return a score @@ -102,35 +89,28 @@ def __init__( self, labels_with_examples: Sequence[LabelWithExamples], client: Client, - scoring: EmbeddingBasedClassifyScoring = EmbeddingBasedClassifyScoring.MEAN_TOP_5, + top_k_per_label: int = 5, ) -> None: super().__init__() self._labels_with_examples = labels_with_examples documents = self._labels_with_examples_to_documents(labels_with_examples) - self._scoring = scoring - retriever = InMemoryRetriever( + self._scoring = top_k_per_label + retriever = QdrantInMemoryRetriever( client, documents=documents, - k=scoring.value, + k=top_k_per_label, retriever_type=RetrieverType.SYMMETRIC, ) - self._filter_search = FilterSearch(retriever) + self._filter_search = QdrantSearch(retriever) def run(self, input: ClassifyInput, logger: DebugLogger) -> ClassifyOutput: - available_labels = set( - class_with_examples.name - for class_with_examples in self._labels_with_examples - ) - unknown_labels = input.labels - available_labels - if unknown_labels: - raise ValueError(f"Got unexpected labels: {', '.join(unknown_labels)}.") - labels = list(input.labels) # converting to list to preserve order + self._validate_input_labels(input) results_per_label = [ - self._label_search(input.chunk, label, logger) for label in labels + self._label_search(input.chunk, label, logger) for label in input.labels ] scores = self._calculate_scores(results_per_label) return ClassifyOutput( - scores={l: Probability(s) for l, s in zip(labels, scores)} + scores={l: Probability(s) for l, s in zip(input.labels, scores)} ) def _labels_with_examples_to_documents( @@ -144,10 +124,19 @@ def _labels_with_examples_to_documents( for e in class_with_examples.examples ] + def _validate_input_labels(self, input: ClassifyInput) -> None: + available_labels = set( + class_with_examples.name + for class_with_examples in self._labels_with_examples + ) + unknown_labels = input.labels - available_labels + if unknown_labels: + raise ValueError(f"Got unexpected labels: {', '.join(unknown_labels)}.") + def _label_search( self, chunk: Chunk, label: str, logger: DebugLogger ) -> SearchOutput: - search_input = FilterSearchInput( + search_input = QdrantSearchInput( query=chunk, filter=models.Filter( must=[ diff --git a/src/intelligence_layer/use_cases/classify/single_label_classify.py b/src/intelligence_layer/use_cases/classify/single_label_classify.py index ba44097dc..ae5b573dd 100644 --- a/src/intelligence_layer/use_cases/classify/single_label_classify.py +++ b/src/intelligence_layer/use_cases/classify/single_label_classify.py @@ -18,9 +18,8 @@ ) from intelligence_layer.core.echo import EchoInput, EchoTask, TokenWithProb from intelligence_layer.core.logger import DebugLogger -from intelligence_layer.core.task import Probability, Token +from intelligence_layer.core.task import Probability, Task, Token from intelligence_layer.use_cases.classify.classify import ( - Classify, ClassifyInput, ClassifyOutput, ) @@ -30,7 +29,7 @@ def to_aa_tokens_prompt(tokens: Sequence[Token]) -> Prompt: return Prompt.from_tokens([token.token_id for token in tokens]) -class SingleLabelClassify(Classify): +class SingleLabelClassify(Task[ClassifyInput, ClassifyOutput]): """Task that classifies a given input text with one of the given classes. The input contains a complete set of all possible labels. The output will return a score for diff --git a/src/intelligence_layer/use_cases/qa/long_context_qa.py b/src/intelligence_layer/use_cases/qa/long_context_qa.py index 42b4c0a16..a31e0d26b 100644 --- a/src/intelligence_layer/use_cases/qa/long_context_qa.py +++ b/src/intelligence_layer/use_cases/qa/long_context_qa.py @@ -10,8 +10,8 @@ MultipleChunkQaInput, MultipleChunkQaOutput, ) -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) from intelligence_layer.use_cases.search.search import Search, SearchInput from intelligence_layer.core.task import Chunk, Task @@ -76,7 +76,7 @@ def run( ) -> MultipleChunkQaOutput: chunks = self._chunk(input.text) logger.log("chunks", chunks) - retriever = InMemoryRetriever( + retriever = QdrantInMemoryRetriever( self._client, documents=[Document(text=c) for c in chunks], k=self._k, diff --git a/src/intelligence_layer/use_cases/search/filter_search.py b/src/intelligence_layer/use_cases/search/qdrant_search.py similarity index 85% rename from src/intelligence_layer/use_cases/search/filter_search.py rename to src/intelligence_layer/use_cases/search/qdrant_search.py index 0dbcf7e7b..ce78f6a25 100644 --- a/src/intelligence_layer/use_cases/search/filter_search.py +++ b/src/intelligence_layer/use_cases/search/qdrant_search.py @@ -1,15 +1,15 @@ from pydantic import BaseModel from qdrant_client.http.models.models import Filter -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) from intelligence_layer.core.task import Task from intelligence_layer.core.logger import DebugLogger from intelligence_layer.use_cases.search.search import SearchOutput -class FilterSearchInput(BaseModel): +class QdrantSearchInput(BaseModel): """The input for a `FilterSearch` task. Attributes: @@ -21,7 +21,7 @@ class FilterSearchInput(BaseModel): filter: Filter -class FilterSearch(Task[FilterSearchInput, SearchOutput]): +class QdrantSearch(Task[QdrantSearchInput, SearchOutput]): """Performs search to find documents using QDrant filtering methods. Given a query, this task will utilize a retriever to fetch relevant text search results. @@ -55,11 +55,11 @@ class FilterSearch(Task[FilterSearchInput, SearchOutput]): >>> output = task.run(input, logger) """ - def __init__(self, in_memory_retriever: InMemoryRetriever): + def __init__(self, in_memory_retriever: QdrantInMemoryRetriever): super().__init__() self._in_memory_retriever = in_memory_retriever - def run(self, input: FilterSearchInput, logger: DebugLogger) -> SearchOutput: + def run(self, input: QdrantSearchInput, logger: DebugLogger) -> SearchOutput: results = self._in_memory_retriever.get_filtered_documents_with_scores( input.query, input.filter ) diff --git a/tests/conftest.py b/tests/conftest.py index 543a9bcd4..9ebf20349 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ DocumentIndexRetriever, ) from intelligence_layer.connectors.document_index import DocumentIndex -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, RetrieverType, ) @@ -45,24 +45,24 @@ def prompt_image() -> Image: @fixture -def asymmetric_in_memory_retriever( - client: Client, in_memory_retriever_documents: Sequence[Document] -) -> InMemoryRetriever: - return InMemoryRetriever( +def asymmetric_qdrant_in_memory_retriever( + client: Client, qdrant_in_memory_retriever_documents: Sequence[Document] +) -> QdrantInMemoryRetriever: + return QdrantInMemoryRetriever( client, - in_memory_retriever_documents, + qdrant_in_memory_retriever_documents, k=2, retriever_type=RetrieverType.ASYMMETRIC, ) @fixture -def symmetric_in_memory_retriever( - client: Client, in_memory_retriever_documents: Sequence[Document] -) -> InMemoryRetriever: - return InMemoryRetriever( +def symmetric_qdrant_in_memory_retriever( + client: Client, qdrant_in_memory_retriever_documents: Sequence[Document] +) -> QdrantInMemoryRetriever: + return QdrantInMemoryRetriever( client, - in_memory_retriever_documents, + qdrant_in_memory_retriever_documents, k=2, retriever_type=RetrieverType.SYMMETRIC, ) diff --git a/tests/retrievers/test_in_memory.py b/tests/retrievers/test_qdrant_in_memory.py similarity index 80% rename from tests/retrievers/test_in_memory.py rename to tests/retrievers/test_qdrant_in_memory.py index f70b09dca..46115fc23 100644 --- a/tests/retrievers/test_in_memory.py +++ b/tests/retrievers/test_qdrant_in_memory.py @@ -3,8 +3,8 @@ from pytest import fixture from intelligence_layer.connectors.retrievers.base_retriever import Document -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) @@ -18,7 +18,7 @@ def in_memory_retriever_documents() -> Sequence[Document]: def test_asymmetric_in_memory( - asymmetric_in_memory_retriever: InMemoryRetriever, + asymmetric_in_memory_retriever: QdrantInMemoryRetriever, in_memory_retriever_documents: Sequence[Document], ) -> None: query = "Do you like summer?" @@ -28,7 +28,7 @@ def test_asymmetric_in_memory( def test_symmetric_in_memory( - symmetric_in_memory_retriever: InMemoryRetriever, + symmetric_in_memory_retriever: QdrantInMemoryRetriever, in_memory_retriever_documents: Sequence[Document], ) -> None: query = "I hate drizzle" diff --git a/tests/use_cases/qa/test_retriever_based_qa.py b/tests/use_cases/qa/test_retriever_based_qa.py index 1e5656b36..17a14d476 100644 --- a/tests/use_cases/qa/test_retriever_based_qa.py +++ b/tests/use_cases/qa/test_retriever_based_qa.py @@ -7,8 +7,8 @@ from intelligence_layer.connectors.retrievers.document_index_retriever import ( DocumentIndexRetriever, ) -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) from intelligence_layer.core.logger import NoOpDebugLogger from intelligence_layer.use_cases.qa.retriever_based_qa import ( @@ -37,7 +37,7 @@ def in_memory_retriever_documents() -> Sequence[Document]: @fixture def retriever_based_qa_with_in_memory_retriever( - client: Client, asymmetric_in_memory_retriever: InMemoryRetriever + client: Client, asymmetric_in_memory_retriever: QdrantInMemoryRetriever ) -> RetrieverBasedQa: return RetrieverBasedQa( client, asymmetric_in_memory_retriever, model="luminous-base-control" diff --git a/tests/use_cases/search/test_filter_search.py b/tests/use_cases/search/test_qdrant_search.py similarity index 76% rename from tests/use_cases/search/test_filter_search.py rename to tests/use_cases/search/test_qdrant_search.py index 5959fd2ff..95784026f 100644 --- a/tests/use_cases/search/test_filter_search.py +++ b/tests/use_cases/search/test_qdrant_search.py @@ -4,13 +4,13 @@ from qdrant_client.http.models import models from intelligence_layer.connectors.retrievers.base_retriever import Document -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) from intelligence_layer.core.logger import NoOpDebugLogger -from intelligence_layer.use_cases.search.filter_search import ( - FilterSearch, - FilterSearchInput, +from intelligence_layer.use_cases.search.qdrant_search import ( + QdrantSearch, + QdrantSearchInput, ) @@ -33,16 +33,16 @@ def in_memory_retriever_documents() -> Sequence[Document]: @fixture -def filter_search(asymmetric_in_memory_retriever: InMemoryRetriever) -> FilterSearch: - return FilterSearch(asymmetric_in_memory_retriever) +def filter_search(asymmetric_in_memory_retriever: QdrantInMemoryRetriever) -> QdrantSearch: + return QdrantSearch(asymmetric_in_memory_retriever) def test_filter_search( - filter_search: FilterSearch, + filter_search: QdrantSearch, no_op_debug_logger: NoOpDebugLogger, in_memory_retriever_documents: Sequence[Document], ) -> None: - search_input = FilterSearchInput( + search_input = QdrantSearchInput( query="When did Germany reunite?", filter=models.Filter( must=[ diff --git a/tests/use_cases/search/test_search.py b/tests/use_cases/search/test_search.py index aaf930cf3..5e6d1aae8 100644 --- a/tests/use_cases/search/test_search.py +++ b/tests/use_cases/search/test_search.py @@ -2,8 +2,8 @@ from typing import Sequence from intelligence_layer.connectors.retrievers.base_retriever import Document -from intelligence_layer.connectors.retrievers.in_memory_retriever import ( - InMemoryRetriever, +from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( + QdrantInMemoryRetriever, ) from intelligence_layer.core.logger import NoOpDebugLogger from intelligence_layer.use_cases.search.search import Search, SearchInput @@ -19,7 +19,7 @@ def in_memory_retriever_documents() -> Sequence[Document]: @fixture -def search(asymmetric_in_memory_retriever: InMemoryRetriever) -> Search: +def search(asymmetric_in_memory_retriever: QdrantInMemoryRetriever) -> Search: return Search(asymmetric_in_memory_retriever)