Skip to content

Commit

Permalink
Refine DocumentIndexClient
Browse files Browse the repository at this point in the history
  • Loading branch information
NickyHavoc committed Nov 2, 2023
1 parent 69b02ca commit 4e93262
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/intelligence_layer/connectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .document_index.document_index import DocumentIndex
from .document_index.document_index import DocumentIndexClient
from .retrievers.base_retriever import BaseRetriever, Document, SearchResult
from .retrievers.document_index_retriever import DocumentIndexRetriever
from .retrievers.qdrant_in_memory_retriever import (
Expand Down
206 changes: 184 additions & 22 deletions src/intelligence_layer/connectors/document_index/document_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@


class DocumentContents(BaseModel):
"""Actual content of a document.
Note:
Currently only supports text-only documents.
Args:
contents: List of text items.
"""

contents: Sequence[str]

@classmethod
Expand All @@ -33,11 +42,27 @@ def _to_modalities_json(self) -> Sequence[Mapping[str, str]]:


class CollectionPath(BaseModel):
    """Path to a collection, made up of a namespace and a collection name.

    Args:
        namespace: Holds collections.
        collection: Holds documents.
            Unique within a namespace.
    """

    namespace: str
    collection: str


class DocumentPath(BaseModel):
"""Path to a document.
Args:
collection_path: Path to a collection.
document_name: Points to a document.
Unique within a collection.
"""

collection_path: CollectionPath
document_name: str

Expand All @@ -53,6 +78,15 @@ def _from_json(cls, document_path_json: Mapping[str, str]) -> "DocumentPath":


class DocumentInfo(BaseModel):
"""Presents an overview of a document.
Args:
document_path: Path to a document.
created: When this version of the document was created.
Equivalent to when it was last updated.
version: How many times the document was updated.
"""

document_path: DocumentPath
created: datetime
version: int
Expand All @@ -71,12 +105,32 @@ def _from_list_documents_response(


class SearchQuery(BaseModel):
    """Query to search through a collection with.

    Args:
        query: Actual text to be searched with.
        max_results: Max number of search results to be retrieved by the query.
            Must be greater than or equal to 0.
        min_score: Min score needed for a search result to be returned.
            Must be between 0 and 1 (both inclusive).
    """

    query: str
    max_results: int = Field(..., ge=0)
    min_score: float = Field(..., ge=0.0, le=1.0)


class DocumentSearchResult(BaseModel):
"""Result of a search query for one individual section.
Args:
document_path: Path to the document that the section originates from.
section: Actual section of the document that was found as a match to the query.
score: Actual search score of the section found.
Generally, higher scores correspond to better matches.
Will be between 0 and 1.
"""

document_path: DocumentPath
section: str
score: float
Expand All @@ -92,7 +146,20 @@ def _from_search_response(
)


class DocumentIndex:
class DocumentIndexError(Exception):
    """Raised in case of any `DocumentIndexClient`-related errors.

    Attributes:
        message: The error message as returned by the Document Index.
        status_code: The http error code.
    """

    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        # `Exception` does not expose a `message` attribute on its own, so it
        # is stored explicitly to match the documented attributes.
        self.message = message
        self.status_code = status_code


class DocumentIndexClient:
"""Client for the Document Index, allowing management of documents and search.
Document Index is a tool for managing collections of documents, enabling operations such as creation, deletion, listing, and searching.
Expand All @@ -104,21 +171,25 @@ class DocumentIndex:
Example:
>>> document_index = DocumentIndexClient(os.getenv("AA_TOKEN"))
>>> document_index.create_collection(namespace="my_namespace", collection="germany_facts_collection")
>>> collection_path = CollectionPath(
>>> namespace="my_namespace",
>>> collection="germany_facts_collection"
>>> )
>>> document_index.create_collection(collection_path)
>>> document_index.add_document(
>>> document_path=CollectionPath(
>>> namespace="my_namespace",
>>> collection="germany_facts_collection",
>>> document_name="Fun facts about Germany",
>>> )
>>> content=DocumentContents.from_text("Germany is a country located in ...")
>>> document_path=DocumentPath(
>>> collection_path=collection_path,
>>> document_name="Fun facts about Germany"
>>> ),
>>> contents=DocumentContents.from_text("Germany is a country located in ...")
>>> )
>>> documents = document_index.search(
>>> namespace="my_namespace",
>>> collection="germany_facts_collection",
>>> query: "What is the capital of Germany",
>>> max_results=4,
>>> min_score: 0.5
>>> search_result = document_index.asymmetric_search(
>>> collection_path=collection_path,
>>> search_query=SearchQuery(
>>> query="What is the capital of Germany",
>>> max_results=4,
>>> min_score=0.5
>>> )
>>> )
"""

Expand All @@ -134,20 +205,51 @@ def __init__(
"Authorization": f"Bearer {token}",
}

def _raise_for_status(self, response: requests.Response) -> None:
    """Raise a `DocumentIndexError` if the response carries an HTTP error status.

    Args:
        response: Response returned by a request to the Document Index.

    Raises:
        DocumentIndexError: If `response.raise_for_status()` raises an HTTP error.
            Carries the response text and the status code.
    """
    try:
        response.raise_for_status()
    except requests.HTTPError as error:
        # Only HTTP errors are converted; a bare `except` would also swallow
        # KeyboardInterrupt/SystemExit. The original error is kept as the cause.
        raise DocumentIndexError(response.text, response.status_code) from error

def create_collection(self, collection_path: CollectionPath) -> None:
    """Creates a collection at the path.

    Note:
        Collection's name must be unique within a namespace.

    Args:
        collection_path: Path to the collection of interest.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{collection_path.namespace}/{collection_path.collection}"
    response = requests.put(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)

def delete_collection(self, collection_path: CollectionPath) -> None:
    """Deletes the collection at the path.

    Args:
        collection_path: Path to the collection of interest.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{collection_path.namespace}/{collection_path.collection}"
    response = requests.delete(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)

def list_collections(self, namespace: str) -> Sequence[str]:
    """Lists all collections within a namespace.

    Args:
        namespace: For a collection of documents.
            Typically corresponds to an organization.

    Returns:
        List of all collections' names.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{namespace}"
    response = requests.get(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)
    collections: Sequence[str] = response.json()
    return collections

Expand All @@ -156,29 +258,67 @@ def add_document(
document_path: DocumentPath,
contents: DocumentContents,
) -> None:
"""Add a document to a collection.
Note:
If a document with the same `document_path` exists, it will be updated with the new `contents`.
Args:
document_path: Consists of `collection_path` and name of document to be created.
contents: Actual content of the document.
Currently only supports text.
"""

url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
data = {
"schema_version": "V1",
"contents": contents._to_modalities_json(),
}
response = requests.put(url, data=json.dumps(data), headers=self.headers)
response.raise_for_status()
self._raise_for_status(response)

def delete_document(self, document_path: DocumentPath) -> None:
    """Delete a document from a collection.

    Args:
        document_path: Consists of `collection_path` and name of document to be deleted.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
    response = requests.delete(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)

def document(self, document_path: DocumentPath) -> DocumentContents:
    """Retrieve a document from a collection.

    Args:
        document_path: Consists of `collection_path` and name of document to be retrieved.

    Returns:
        Content of the retrieved document.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
    response = requests.get(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)
    return DocumentContents._from_modalities_json(response.json())

def list_documents(self, collection_path: CollectionPath) -> Sequence[DocumentInfo]:
    """List all documents within a collection.

    Note:
        Does not return each document's content.

    Args:
        collection_path: Path to the collection of interest.

    Returns:
        Overview of all documents within the collection.

    Raises:
        DocumentIndexError: If the request returns an HTTP error status.
    """
    url = f"{self._base_document_index_url}/collections/{collection_path.namespace}/{collection_path.collection}/docs"
    response = requests.get(url, headers=self.headers)
    # Errors go through the wrapper only: a direct `response.raise_for_status()`
    # would raise before the error could be converted to `DocumentIndexError`.
    self._raise_for_status(response)
    return [DocumentInfo._from_list_documents_response(r) for r in response.json()]

def search(
Expand All @@ -187,6 +327,18 @@ def search(
index: str,
search_query: SearchQuery,
) -> Sequence[DocumentSearchResult]:
"""Search through a collection with a `search_query`.
Args:
collection_path: Path to the collection of interest.
index: Name of the search configuration.
Currently only supports "asymmetric".
search_query: The query to search with.
Returns:
Result of the search operation. Will be empty if nothing was retrieved.
"""

url = f"{self._base_document_index_url}/collections/{collection_path.namespace}/{collection_path.collection}/indexes/{index}/search"
data = {
"query": [{"modality": "text", "text": search_query.query}],
Expand All @@ -195,12 +347,22 @@ def search(
"filter": [{"with": [{"modality": "text"}]}],
}
response = requests.post(url, data=json.dumps(data), headers=self.headers)
response.raise_for_status()
self._raise_for_status(response)
return [DocumentSearchResult._from_search_response(r) for r in response.json()]

def asymmetric_search(
    self,
    collection_path: CollectionPath,
    search_query: SearchQuery,
) -> Sequence[DocumentSearchResult]:
    """Run `search` on a collection with the index fixed to "asymmetric".

    Args:
        collection_path: Path to the collection of interest.
        search_query: The query to search with.

    Returns:
        Whatever `search` returns for the "asymmetric" index.
    """
    index_name = "asymmetric"
    results = self.search(collection_path, index_name, search_query)
    return results
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from intelligence_layer.connectors.document_index.document_index import (
CollectionPath,
DocumentIndex,
DocumentIndexClient,
SearchQuery,
)
from intelligence_layer.connectors.retrievers.base_retriever import (
Expand All @@ -13,14 +13,14 @@


class DocumentIndexRetriever(BaseRetriever):
"""Search through documents within collections in the `DocumentIndex`.
"""Search through documents within collections in the `DocumentIndexClient`.
We initialize this Retriever with a collection & namespace names, and we can find the documents in the collection
most semantically similar to our query.
Args:
document_index: Client offering functionality for search.
namespace: The namespace within the `DocumentIndex` where all collections are stored.
namespace: The namespace within the `DocumentIndexClient` where all collections are stored.
collection: The collection within the namespace that holds the desired documents.
k: The (top) number of documents to be returned by search.
threshold: The minimum value of cosine similarity between the query vector and the document vector.
Expand All @@ -34,7 +34,7 @@ class DocumentIndexRetriever(BaseRetriever):

def __init__(
self,
document_index: DocumentIndex,
document_index: DocumentIndexClient,
namespace: str,
collection: str,
k: int,
Expand Down
12 changes: 8 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from intelligence_layer.connectors.retrievers.document_index_retriever import (
DocumentIndexRetriever,
)
from intelligence_layer.connectors.document_index.document_index import DocumentIndex
from intelligence_layer.connectors.document_index.document_index import (
DocumentIndexClient,
)
from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import (
QdrantInMemoryRetriever,
RetrieverType,
Expand Down Expand Up @@ -69,12 +71,14 @@ def symmetric_in_memory_retriever(


@fixture
def document_index(token: str) -> DocumentIndex:
return DocumentIndex(token)
def document_index(token: str) -> DocumentIndexClient:
return DocumentIndexClient(token)


@fixture
def document_index_retriever(document_index: DocumentIndex) -> DocumentIndexRetriever:
def document_index_retriever(
document_index: DocumentIndexClient,
) -> DocumentIndexRetriever:
return DocumentIndexRetriever(
document_index, namespace="aleph-alpha", collection="wikipedia-de", k=2
)
Loading

0 comments on commit 4e93262

Please sign in to comment.