Skip to content

Commit

Permalink
Merge branch 'main' into docemb
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Oct 23, 2023
2 parents 11070ef + 101bd81 commit cd59aad
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 97 deletions.
22 changes: 17 additions & 5 deletions haystack/preview/components/file_converters/azure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Optional
import os

from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, Document, default_to_dict
Expand All @@ -22,22 +23,33 @@ class AzureOCRDocumentConverter:
to set up your resource.
"""

def __init__(self, endpoint: str, api_key: str, model_id: str = "prebuilt-read"):
def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
"""
Create an AzureOCRDocumentConverter component.
:param endpoint: The endpoint of your Azure resource.
:param api_key: The key of your Azure resource.
:param api_key: The key of your Azure resource. It can be
explicitly provided or automatically read from the
environment variable AZURE_AI_API_KEY (recommended).
:param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
for a list of available models. Default: `"prebuilt-read"`.
"""
azure_import.check()

if api_key is None:
try:
api_key = os.environ["AZURE_AI_API_KEY"]
except KeyError as e:
raise ValueError(
"AzureOCRDocumentConverter expects an Azure Credential key. "
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
) from e

self.api_key = api_key
self.document_analysis_client = DocumentAnalysisClient(
endpoint=endpoint, credential=AzureKeyCredential(api_key)
)
self.endpoint = endpoint
self.api_key = api_key
self.model_id = model_id

@component.output_types(documents=List[Document], azure=List[Dict])
Expand Down Expand Up @@ -70,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id)
return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)

@staticmethod
def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: str) -> Document:
Expand Down
20 changes: 13 additions & 7 deletions haystack/preview/components/websearch/serper_dev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import logging
from typing import Dict, List, Optional, Any

Expand Down Expand Up @@ -26,20 +27,29 @@ class SerperDevWebSearch:

def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SerperDev API.
:param api_key: API key for the SerperDev API. It can be
explicitly provided or automatically read from the
environment variable SERPERDEV_API_KEY (recommended).
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details.
"""
if api_key is None:
try:
api_key = os.environ["SERPERDEV_API_KEY"]
except KeyError as e:
raise ValueError(
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
) from e
raise ValueError("API key for SerperDev API must be set.")
self.api_key = api_key
self.top_k = top_k
Expand All @@ -51,11 +61,7 @@ def to_dict(self) -> Dict[str, Any]:
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
api_key=self.api_key,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
)

@component.output_types(documents=List[Document], links=List[str])
Expand Down
8 changes: 3 additions & 5 deletions haystack/preview/dataclasses/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Type

import numpy
import pandas
Expand Down Expand Up @@ -42,8 +42,6 @@ def document_decoder(self, dictionary):
dictionary["array"] = numpy.array(dictionary.get("array"))
if "dataframe" in dictionary and dictionary.get("dataframe"):
dictionary["dataframe"] = pandas.read_json(dictionary.get("dataframe", None))
if "embedding" in dictionary and dictionary.get("embedding"):
dictionary["embedding"] = numpy.array(dictionary.get("embedding"))

return dictionary

Expand Down Expand Up @@ -75,7 +73,7 @@ class Document:
mime_type: str = field(default="text/plain")
metadata: Dict[str, Any] = field(default_factory=dict)
score: Optional[float] = field(default=None)
embedding: Optional[numpy.ndarray] = field(default=None, repr=False)
embedding: Optional[List[float]] = field(default=None, repr=False)

def __str__(self):
fields = [f"mimetype: '{self.mime_type}'"]
Expand Down Expand Up @@ -120,7 +118,7 @@ def _create_id(self):
blob = self.blob or None
mime_type = self.mime_type or None
metadata = self.metadata or {}
embedding = self.embedding.tolist() if self.embedding is not None else None
embedding = self.embedding if self.embedding is not None else None
data = f"{text}{array}{dataframe}{blob}{mime_type}{metadata}{embedding}"
return hashlib.sha256(data.encode("utf-8")).hexdigest()

Expand Down
32 changes: 15 additions & 17 deletions haystack/preview/testing/document_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=too-many-public-methods
from typing import List
import random

import pytest
import numpy as np
Expand All @@ -11,37 +12,41 @@
from haystack.preview.errors import FilterError


def _random_embeddings(n):
return [random.random() for _ in range(n)]


class DocumentStoreBaseTests:
@pytest.fixture
def docstore(self) -> DocumentStore:
raise NotImplementedError()

@pytest.fixture
def filterable_docs(self) -> List[Document]:
embedding_zero = np.zeros(768).astype(np.float32)
embedding_one = np.ones(768).astype(np.float32)
embedding_zero = [0.0] * 768
embedding_one = [1.0] * 768

documents = []
for i in range(3):
documents.append(
Document(
text=f"A Foo Document {i}",
metadata={"name": f"name_{i}", "page": "100", "chapter": "intro", "number": 2},
embedding=np.random.rand(768).astype(np.float32),
embedding=_random_embeddings(768),
)
)
documents.append(
Document(
text=f"A Bar Document {i}",
metadata={"name": f"name_{i}", "page": "123", "chapter": "abstract", "number": -2},
embedding=np.random.rand(768).astype(np.float32),
embedding=_random_embeddings(768),
)
)
documents.append(
Document(
text=f"A Foobar Document {i}",
metadata={"name": f"name_{i}", "page": "90", "chapter": "conclusion", "number": -10},
embedding=np.random.rand(768).astype(np.float32),
embedding=_random_embeddings(768),
)
)
documents.append(
Expand Down Expand Up @@ -209,11 +214,9 @@ def test_eq_filter_table(self, docstore: DocumentStore, filterable_docs: List[Do
@pytest.mark.unit
def test_eq_filter_embedding(self, docstore: DocumentStore, filterable_docs: List[Document]):
docstore.write_documents(filterable_docs)
embedding = np.zeros(768).astype(np.float32)
embedding = [0.0] * 768
result = docstore.filter_documents(filters={"embedding": embedding})
assert self.contains_same_docs(
result, [doc for doc in filterable_docs if np.array_equal(embedding, doc.embedding)] # type: ignore
)
assert self.contains_same_docs(result, [doc for doc in filterable_docs if embedding == doc.embedding])

@pytest.mark.unit
def test_in_filter_explicit(self, docstore: DocumentStore, filterable_docs: List[Document]):
Expand Down Expand Up @@ -248,17 +251,12 @@ def test_in_filter_table(self, docstore: DocumentStore, filterable_docs: List[Do
@pytest.mark.unit
def test_in_filter_embedding(self, docstore: DocumentStore, filterable_docs: List[Document]):
docstore.write_documents(filterable_docs)
embedding_zero = np.zeros(768, np.float32)
embedding_one = np.ones(768, np.float32)
embedding_zero = [0.0] * 768
embedding_one = [1.0] * 768
result = docstore.filter_documents(filters={"embedding": {"$in": [embedding_zero, embedding_one]}})
assert self.contains_same_docs(
result,
[
doc
for doc in filterable_docs
if isinstance(doc.embedding, np.ndarray)
and (np.array_equal(embedding_zero, doc.embedding) or np.array_equal(embedding_one, doc.embedding))
],
[doc for doc in filterable_docs if (embedding_zero == doc.embedding or embedding_one == doc.embedding)],
)

@pytest.mark.unit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Change `Document`'s `embedding` field type from `numpy.ndarray` to `List[float]`
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Remove "api_key" from serialization of AzureOCRDocumentConverter and SerperDevWebSearch.
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@


class TestAzureOCRDocumentConverter:
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("AZURE_AI_API_KEY", raising=False)
with pytest.raises(ValueError, match="AzureOCRDocumentConverter expects an Azure Credential key"):
AzureOCRDocumentConverter(endpoint="test_endpoint")

@pytest.mark.unit
def test_to_dict(self):
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
data = component.to_dict()
assert data == {
"type": "AzureOCRDocumentConverter",
"init_parameters": {
"api_key": "test_credential_key",
"endpoint": "test_endpoint",
"model_id": "prebuilt-read",
},
"init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"},
}

@pytest.mark.unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def test_valid_run(self):
top_k = 3
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
docs = [
Document(text="my document", embedding=np.array([0.1, 0.2, 0.3, 0.4])),
Document(text="another document", embedding=np.array([1.0, 1.0, 1.0, 1.0])),
Document(text="third document", embedding=np.array([0.5, 0.7, 0.5, 0.7])),
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(text="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)

Expand All @@ -142,17 +142,17 @@ def test_run_with_pipeline(self):
ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
top_k = 2
docs = [
Document(text="my document", embedding=np.array([0.1, 0.2, 0.3, 0.4])),
Document(text="another document", embedding=np.array([1.0, 1.0, 1.0, 1.0])),
Document(text="third document", embedding=np.array([0.5, 0.7, 0.5, 0.7])),
Document(text="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
Document(text="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(text="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
]
ds.write_documents(docs)
retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)

pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
result: Dict[str, Any] = pipeline.run(
data={"retriever": {"query_embedding": np.array([0.1, 0.1, 0.1, 0.1]), "return_embedding": True}}
data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
)

assert result
Expand Down
13 changes: 7 additions & 6 deletions test/preview/components/websearch/test_serperdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def mock_serper_dev_search_result():


class TestSerperDevSearchAPI:
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("SERPERDEV_API_KEY", raising=False)
with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"):
SerperDevWebSearch()

@pytest.mark.unit
def test_to_dict(self):
component = SerperDevWebSearch(
Expand All @@ -116,12 +122,7 @@ def test_to_dict(self):
data = component.to_dict()
assert data == {
"type": "SerperDevWebSearch",
"init_parameters": {
"api_key": "test_key",
"top_k": 10,
"allowed_domains": ["test.com"],
"search_params": {"param": "test"},
},
"init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}},
}

@pytest.mark.unit
Expand Down
14 changes: 7 additions & 7 deletions test/preview/dataclasses/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __eq__(self, other):
return True

foo = TestObject()
doc1 = Document(text="test text", metadata={"value": np.array([0, 1, 2]), "path": Path("."), "obj": foo})
doc2 = Document(text="test text", metadata={"value": np.array([0, 1, 2]), "path": Path("."), "obj": foo})
doc1 = Document(text="test text", metadata={"value": [0, 1, 2], "path": Path("."), "obj": foo})
doc2 = Document(text="test text", metadata={"value": [0, 1, 2], "path": Path("."), "obj": foo})
assert doc1 == doc2


Expand Down Expand Up @@ -107,7 +107,7 @@ def test_full_document_to_dict():
mime_type="application/pdf",
metadata={"some": "values", "test": 10},
score=0.99,
embedding=np.zeros([10, 10]),
embedding=[10, 10],
)
dictionary = doc.to_dict()

Expand All @@ -121,7 +121,7 @@ def test_full_document_to_dict():
assert blob == doc.blob

embedding = dictionary.pop("embedding")
assert (embedding == doc.embedding).all()
assert embedding == doc.embedding

assert dictionary == {
"id": doc.id,
Expand All @@ -134,7 +134,7 @@ def test_full_document_to_dict():

@pytest.mark.unit
def test_document_with_most_attributes_from_dict():
embedding = np.zeros([10, 10])
embedding = [10, 10]
assert Document.from_dict(
{
"text": "test text",
Expand Down Expand Up @@ -194,7 +194,7 @@ def __repr__(self):
mime_type="application/pdf",
metadata={"some object": TestClass(), "a path": tmp_path / "test.txt"},
score=0.5,
embedding=np.array([1, 2, 3, 4]),
embedding=[1, 2, 3, 4],
)
assert doc_1.to_json() == json.dumps(
{
Expand Down Expand Up @@ -241,7 +241,7 @@ def __eq__(self, other):
# Note the object serialization
metadata={"some object": "<the object>", "a path": str((tmp_path / "test.txt").absolute())},
score=0.5,
embedding=np.array([1, 2, 3, 4]),
embedding=[1, 2, 3, 4],
)


Expand Down
Loading

0 comments on commit cd59aad

Please sign in to comment.