diff --git a/.github/workflows/azure_ai_search.yml b/.github/workflows/azure_ai_search.yml new file mode 100644 index 000000000..1c10edc91 --- /dev/null +++ b/.github/workflows/azure_ai_search.yml @@ -0,0 +1,72 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / azure_ai_search + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/azure_ai_search/**" + - ".github/workflows/azure_ai_search.yml" + +concurrency: + group: azure_ai_search-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }} + AZURE_SEARCH_SERVICE_ENDPOINT: ${{ secrets.AZURE_SEARCH_SERVICE_ENDPOINT }} + +defaults: + run: + working-directory: integrations/azure_ai_search + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + max-parallel: 3 + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/integrations/azure_ai_search/LICENSE b/integrations/azure_ai_search/LICENSE new file mode 100644 index 000000000..de4c7f39f --- /dev/null +++ b/integrations/azure_ai_search/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md new file mode 100644 index 000000000..915a23b63 --- /dev/null +++ b/integrations/azure_ai_search/README.md @@ -0,0 +1,26 @@ +# Azure AI Search Document Store for Haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) + +----- + +**Table of Contents** + +- [Azure AI Search Document Store for Haystack](#azure-ai-search-document-store-for-haystack) + - [Installation](#installation) + - [Examples](#examples) + - [License](#license) + +## Installation + +```console +pip install azure-ai-search-haystack +``` + +## Examples +You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). + +## License + +`azure-ai-search-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py new file mode 100644 index 000000000..779f28935 --- /dev/null +++ b/integrations/azure_ai_search/example/document_store.py @@ -0,0 +1,44 @@ +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchDocumentStore to write and filter documents. +To run this example, you'll need an Azure Search service endpoint and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. +See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" +document_store = AzureAISearchDocumentStore( + metadata_fields={"version": float, "label": str}, + index_name="document-store-example", +) + +documents = [ + Document( + content="This is an introduction to using Python for data analysis.", + meta={"version": 1.0, "label": "chapter_one"}, + ), + Document( + content="Learn how to use Python libraries for machine learning.", + meta={"version": 1.5, "label": "chapter_two"}, + ), + Document( + content="Advanced Python techniques for data visualization.", + meta={"version": 2.0, "label": "chapter_three"}, + ), +] +document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) + +filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.version", "operator": ">", "value": 1.2}, + {"field": "meta.label", "operator": "in", "value": ["chapter_one", "chapter_three"]}, + ], +} + +results = document_store.filter_documents(filters) +print(results) diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py new file mode 100644 index 000000000..088b08653 --- /dev/null +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -0,0 +1,58 @@ +from haystack import Document, Pipeline +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchEmbeddingRetriever to retrieve documents +using embeddings based on a query. To run this example, you'll need an Azure Search service endpoint +and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. +See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" + +document_store = AzureAISearchDocumentStore(index_name="retrieval-example") + +model = "sentence-transformers/all-mpnet-base-v2" + +documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a + high level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, and + San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), +] + +document_embedder = SentenceTransformersDocumentEmbedder(model=model) +document_embedder.warm_up() + +# Indexing Pipeline +indexing_pipeline = Pipeline() +indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") +indexing_pipeline.add_component( + instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" +) +indexing_pipeline.connect("doc_embedder", "doc_writer") + +indexing_pipeline.run({"doc_embedder": {"documents": documents}}) + +# Query Pipeline +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model=model)) +query_pipeline.add_component("retriever", AzureAISearchEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "How many languages are there?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result["retriever"]["documents"][0]) diff --git a/integrations/azure_ai_search/pydoc/config.yml b/integrations/azure_ai_search/pydoc/config.yml new file mode 100644 index 000000000..ec411af60 --- /dev/null +++ b/integrations/azure_ai_search/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever", + "haystack_integrations.document_stores.azure_ai_search.document_store", + "haystack_integrations.document_stores.azure_ai_search.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Azure AI Search integration for Haystack + category_slug: integrations-api + title: Azure AI Search + slug: integrations-azure_ai_search + order: 180 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_azure_ai_search.md diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml new file mode 100644 index 000000000..49ca623e7 --- /dev/null +++ b/integrations/azure_ai_search/pyproject.toml @@ -0,0 +1,163 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "azure-ai-search-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for Azure AI Search' +readme = "README.md" +requires-python = ">=3.8,<3.13" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset", email = "info@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-xdist", + "haystack-pydoc-tools", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:src/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +exclude = ["example"] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252", "S311"] +"example/**/*" = ["T201"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + + +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure.identity.*", "mypy.*", "azure.core.*", "azure.search.documents.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py new file mode 100644 index 000000000..eb75ffa6c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -0,0 +1,3 @@ +from .embedding_retriever import AzureAISearchEmbeddingRetriever + +__all__ = ["AzureAISearchEmbeddingRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py new file mode 100644 index 000000000..ab649f874 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchEmbeddingRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchEmbeddingRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py new file mode 100644 index 000000000..635878a38 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore +from .filters import normalize_filters + +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py new file mode 100644 index 000000000..0b59b6e37 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace + +from .errors import AzureAISearchDocumentStoreConfigError +from .filters import normalize_filters + +type_mapping = { + str: "Edm.String", + bool: "Edm.Boolean", + int: "Edm.Int32", + float: "Edm.Double", + datetime: "Edm.DateTimeOffset", +} + +DEFAULT_VECTOR_SEARCH = VectorSearch( + profiles=[ + VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config") + ], + algorithms=[ + HnswAlgorithmConfiguration( + name="cosine-algorithm-config", + parameters=HnswParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ) + ], +) + +logger = logging.getLogger(__name__) +logging.getLogger("azure").setLevel(logging.ERROR) +logging.getLogger("azure.identity").setLevel(logging.DEBUG) + + +class AzureAISearchDocumentStore: + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008 + azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=True), # noqa: B008 + index_name: str = "default", + embedding_dimension: int = 768, + metadata_fields: Optional[Dict[str, type]] = None, + vector_search_configuration: VectorSearch = None, + **kwargs, + ): + """ + A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) + as the backend. + + :param azure_endpoint: The URL endpoint of an Azure AI Search service. + :param api_key: The API key to use for authentication. + :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. + :param embedding_dimension: Dimension of the embeddings. + :param metadata_fields: A dictionary of metadata keys and their types to create + additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, + it is necessary to specify the metadata fields in advance. + (e.g. metadata_fields = {"author": str, "date": datetime}) + :param vector_search_configuration: Configuration option related to vector search. + Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. + + :param kwargs: Optional keyword parameters for Azure AI Search. + Some of the supported parameters: + - `api_version`: The Search API version to use for requests. + - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). + The audience is not considered when using a shared key. If audience is not provided, + the public cloud audience will be assumed. + + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + """ + + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None + if not azure_endpoint: + msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + raise ValueError(msg) + + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None + + self._client = None + self._index_client = None + self._index_fields = [] # type: List[Any] # stores all fields in the final schema of index + self._api_key = api_key + self._azure_endpoint = azure_endpoint + self._index_name = index_name + self._embedding_dimension = embedding_dimension + self._dummy_vector = [-10.0] * self._embedding_dimension + self._metadata_fields = metadata_fields + self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH + self._kwargs = kwargs + + @property + def client(self) -> SearchClient: + + # resolve secrets for authentication + resolved_endpoint = ( + self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint + ) + resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key + + credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() + try: + if not self._index_client: + self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + if not self._index_exists(self._index_name): + # Create a new index if it does not exist + logger.debug( + "The index '%s' does not exist. A new index will be created.", + self._index_name, + ) + self._create_index(self._index_name) + except (HttpResponseError, ClientAuthenticationError) as error: + msg = f"Failed to authenticate with Azure Search: {error}" + raise AzureAISearchDocumentStoreConfigError(msg) from error + + if self._index_client: + # Get the search client, if index client is initialized + index_fields = self._index_client.get_index(self._index_name).fields + self._index_fields = [field.name for field in index_fields] + self._client = self._index_client.get_search_client(self._index_name) + else: + msg = "Search Index Client is not initialized." + raise AzureAISearchDocumentStoreConfigError(msg) + + return self._client + + def _create_index(self, index_name: str, **kwargs) -> None: + """ + Creates a new search index. + :param index_name: Name of the index to create. If None, the index name from the constructor is used. + :param kwargs: Optional keyword parameters. + """ + + # default fields to create index based on Haystack Document (id, content, embedding) + default_fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), + SearchableField(name="content", type=SearchFieldDataType.String), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + hidden=False, + vector_search_dimensions=self._embedding_dimension, + vector_search_profile_name="default-vector-config", + ), + ] + + if not index_name: + index_name = self._index_name + if self._metadata_fields: + default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + index = SearchIndex( + name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + ) + if self._index_client: + self._index_client.create_index(index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, + api_key=self._api_key.to_dict() if self._api_key is not None else None, + index_name=self._index_name, + embedding_dimension=self._embedding_dimension, + metadata_fields=self._metadata_fields, + vector_search_configuration=self._vector_search_configuration.as_dict(), + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) + if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None: + data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration) + return default_from_dict(cls, data) + + def count_documents(self) -> int: + """ + Returns how many documents are present in the search index. + + :returns: list of retrieved documents. + """ + return self.client.get_document_count() + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes the provided documents to search index. + + :param documents: documents to write to the index. + :return: the number of documents added to index. + """ + + def _convert_input_document(documents: Document): + document_dict = asdict(documents) + if not isinstance(document_dict["id"], str): + msg = f"Document id {document_dict['id']} is not a string, " + raise Exception(msg) + index_document = self._convert_haystack_documents_to_azure(document_dict) + + return index_document + + if len(documents) > 0: + if not isinstance(documents[0], Document): + msg = "param 'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: + logger.warning( + f"AzureAISearchDocumentStore only supports `DuplicatePolicy.OVERWRITE`" + f"but got {policy}. Overwriting duplicates is enabled by default." + ) + client = self.client + documents_to_write = [(_convert_input_document(doc)) for doc in documents] + + if documents_to_write != []: + client.upload_documents(documents_to_write) + return len(documents_to_write) + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with a matching document_ids from the search index. + + :param document_ids: ids of the documents to be deleted. + """ + if self.count_documents() == 0: + return + documents = self._get_raw_documents_by_id(document_ids) + if documents: + self.client.delete_documents(documents) + + def get_documents_by_id(self, document_ids: List[str]) -> List[Document]: + return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids)) + + def search_documents(self, search_text: str = "*", top_k: int = 10) -> List[Document]: + """ + Returns all documents that match the provided search_text. + If search_text is None, returns all documents. + :param search_text: the text to search for in the Document list. + :param top_k: Maximum number of documents to return. + :returns: A list of Documents that match the given search_text. + """ + result = self.client.search(search_text=search_text, top=top_k) + return self._convert_search_result_to_documents(list(result)) + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the provided filters. + Filters should be given as a dictionary supporting filtering by metadata. For details on + filters, see the [metadata filtering documentation](https://docs.haystack.deepset.ai/docs/metadata-filtering). + + :param filters: the filters to apply to the document list. + :returns: A list of Documents that match the given filters. + """ + if filters: + normalized_filters = normalize_filters(filters) + result = self.client.search(filter=normalized_filters) + return self._convert_search_result_to_documents(result) + else: + return self.search_documents() + + def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) -> List[Document]: + """ + Converts Azure search results to Haystack Documents. + """ + documents = [] + + for azure_doc in azure_docs: + embedding = azure_doc.get("embedding") + if embedding == self._dummy_vector: + embedding = None + + # Anything besides default fields (id, content, and embedding) is considered metadata + meta = { + key: value + for key, value in azure_doc.items() + if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None + } + + # Create the document with meta only if it's non-empty + doc = Document( + id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {} + ) + + documents.append(doc) + return documents + + def _index_exists(self, index_name: Optional[str]) -> bool: + """ + Check if the index exists in the Azure AI Search service. + + :param index_name: The name of the index to check. + :returns bool: whether the index exists. + """ + + if self._index_client and index_name: + return index_name in self._index_client.list_index_names() + else: + msg = "Index name is required to check if the index exists." + raise ValueError(msg) + + def _get_raw_documents_by_id(self, document_ids: List[str]): + """ + Retrieves all Azure documents with a matching document_ids from the document store. + + :param document_ids: ids of the documents to be retrieved. + :returns: list of retrieved Azure documents. + """ + azure_documents = [] + for doc_id in document_ids: + try: + document = self.client.get_document(doc_id) + azure_documents.append(document) + except ResourceNotFoundError: + logger.warning(f"Document with ID {doc_id} not found.") + return azure_documents + + def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]: + """Map the document keys to fields of search index""" + + # Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema + index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields} + if index_document["embedding"] is None: + index_document["embedding"] = self._dummy_vector + + return index_document + + def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: + """Create a list of index fields for storing metadata values.""" + + index_fields = [] + metadata_field_mapping = self._map_metadata_field_types(metadata) + + for key, field_type in metadata_field_mapping.items(): + index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) + + return index_fields + + def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: + """Map metadata field types to Azure Search field types.""" + + metadata_field_mapping = {} + + for key, value_type in metadata.items(): + + if not key[0].isalpha(): + msg = ( + f"Azure Search index only allows field names starting with letters. " + f"Invalid key: {key} will be dropped." + ) + logger.warning(msg) + continue + + field_type = type_mapping.get(value_type) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value_type}" + raise ValueError(error_message) + metadata_field_mapping[key] = field_type + + return metadata_field_mapping + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the vector configuration of the document store. By default it uses the HNSW algorithm + with cosine similarity. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query_embedding` is an empty list + :returns: List of Document that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py new file mode 100644 index 000000000..0fbc80696 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py @@ -0,0 +1,20 @@ +from haystack.document_stores.errors import DocumentStoreError +from haystack.errors import FilterError + + +class AzureAISearchDocumentStoreError(DocumentStoreError): + """Parent class for all AzureAISearchDocumentStore exceptions.""" + + pass + + +class AzureAISearchDocumentStoreConfigError(AzureAISearchDocumentStoreError): + """Raised when a configuration is not valid for a AzureAISearchDocumentStore.""" + + pass + + +class AzureAISearchDocumentStoreFilterError(FilterError): + """Raised when filter is not valid for AzureAISearchDocumentStore.""" + + pass diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py new file mode 100644 index 000000000..650e3f8be --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -0,0 +1,112 @@ +from typing import Any, Dict + +from dateutil import parser + +from .errors import AzureAISearchDocumentStoreFilterError + +LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} + + +def normalize_filters(filters: Dict[str, Any]) -> str: + """ + Converts Haystack filters in Azure AI Search compatible filters. + """ + if not isinstance(filters, dict): + msg = """Filters must be a dictionary. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("operator", "conditions") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + operator = condition["operator"] + if operator not in LOGICAL_OPERATORS: + msg = f"Unknown operator {operator}" + raise AzureAISearchDocumentStoreFilterError(msg) + conditions = [] + for c in condition["conditions"]: + # Recursively parse if the condition itself is a logical condition + if isinstance(c, dict) and "operator" in c and c["operator"] in LOGICAL_OPERATORS: + conditions.append(_parse_logical_condition(c)) + else: + # Otherwise, parse it as a comparison condition + conditions.append(_parse_comparison_condition(c)) + + # Format the result based on the operator + if operator == "NOT": + return f"not ({' and '.join([f'({c})' for c in conditions])})" + else: + return f" {LOGICAL_OPERATORS[operator]} ".join([f"({c})" for c in conditions]) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("field", "operator", "value") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + # Remove the "meta." prefix from the field name if present + field = condition["field"][5:] if condition["field"].startswith("meta.") else condition["field"] + operator = condition["operator"] + value = "null" if condition["value"] is None else condition["value"] + + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown operator {operator}. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise AzureAISearchDocumentStoreFilterError(msg) + + return COMPARISON_OPERATORS[operator](field, value) + + +def _eq(field: str, value: Any) -> str: + return f"{field} eq '{value}'" if isinstance(value, str) and value != "null" else f"{field} eq {value}" + + +def _ne(field: str, value: Any) -> str: + return f"not ({field} eq '{value}')" if isinstance(value, str) and value != "null" else f"not ({field} eq {value})" + + +def _in(field: str, value: Any) -> str: + if not isinstance(value, list) or any(not isinstance(v, str) for v in value): + msg = "Azure AI Search only supports a list of strings for 'in' comparators" + raise AzureAISearchDocumentStoreFilterError(msg) + values = ", ".join(map(str, value)) + return f"search.in({field},'{values}')" + + +def _comparison_operator(field: str, value: Any, operator: str) -> str: + _validate_type(value, operator) + return f"{field} {operator} {value}" + + +def _validate_type(value: Any, operator: str) -> None: + """Validates that the value is either an integer, float, or ISO 8601 string.""" + msg = f"Invalid value type for '{operator}' comparator. Supported types are: int, float, or ISO 8601 string." + + if isinstance(value, str): + try: + parser.isoparse(value) + except ValueError as e: + raise AzureAISearchDocumentStoreFilterError(msg) from e + elif not isinstance(value, (int, float)): + raise AzureAISearchDocumentStoreFilterError(msg) + + +COMPARISON_OPERATORS = { + "==": _eq, + "!=": _ne, + "in": _in, + ">": lambda f, v: _comparison_operator(f, v, "gt"), + ">=": lambda f, v: _comparison_operator(f, v, "ge"), + "<": lambda f, v: _comparison_operator(f, v, "lt"), + "<=": lambda f, v: _comparison_operator(f, v, "le"), +} diff --git a/integrations/azure_ai_search/tests/__init__.py b/integrations/azure_ai_search/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/azure_ai_search/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py new file mode 100644 index 000000000..3017c79c2 --- /dev/null +++ b/integrations/azure_ai_search/tests/conftest.py @@ -0,0 +1,68 @@ +import os +import time +import uuid + +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.indexes import SearchIndexClient +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# This is the approximate time in seconds it takes for the documents to be available in Azure Search index +SLEEP_TIME_IN_SECONDS = 5 + + +@pytest.fixture() +def sleep_time(): + return SLEEP_TIME_IN_SECONDS + + +@pytest.fixture +def document_store(request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + index_name = f"haystack_test_{uuid.uuid4().hex}" + metadata_fields = getattr(request, "param", {}).get("metadata_fields", None) + + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + if index_name in client.list_index_names(): + client.delete_index(index_name) + + store = AzureAISearchDocumentStore( + api_key=api_key, + azure_endpoint=azure_endpoint, + index_name=index_name, + create_index=True, + embedding_dimension=768, + metadata_fields=metadata_fields, + ) + + # Override some methods to wait for the documents to be available + original_write_documents = store.write_documents + + def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): + written_docs = original_write_documents(documents, policy) + time.sleep(SLEEP_TIME_IN_SECONDS) + return written_docs + + original_delete_documents = store.delete_documents + + def delete_documents_and_wait(filters): + original_delete_documents(filters) + time.sleep(SLEEP_TIME_IN_SECONDS) + + store.write_documents = write_documents_and_wait + store.delete_documents = delete_documents_and_wait + + yield store + try: + client.delete_index(index_name) + except ResourceNotFoundError: + pass diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py new file mode 100644 index 000000000..1bcd967c6 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +import random +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import Document +from haystack.errors import FilterError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + WriteDocumentsTest, +) +from haystack.utils.auth import EnvVarSecret, Secret + +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_to_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + document_store = AzureAISearchDocumentStore() + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + }, + } + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_from_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + + data = { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "embedding_dimension": 768, + "index_name": "default", + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + document_store = AzureAISearchDocumentStore.from_dict(data) + assert isinstance(document_store._api_key, EnvVarSecret) + assert isinstance(document_store._azure_endpoint, EnvVarSecret) + assert document_store._index_name == "default" + assert document_store._embedding_dimension == 768 + assert document_store._metadata_fields is None + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init_is_lazy(_mock_azure_search_client): + AzureAISearchDocumentStore(azure_endpoint=Secret.from_token("test_endpoint")) + _mock_azure_search_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init(_mock_azure_search_client): + + document_store = AzureAISearchDocumentStore( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint=Secret.from_token("fake_endpoint"), + index_name="my_index", + embedding_dimension=15, + metadata_fields={"Title": str, "Pages": int}, + ) + + assert document_store._index_name == "my_index" + assert document_store._embedding_dimension == 15 + assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + + def test_write_documents(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + + # Parametrize the test with metadata fields + @pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"author": str, "publication_year": int, "rating": float}}, + ], + indirect=True, + ) + def test_write_documents_with_meta(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document( + id="1", + meta={"author": "Tom", "publication_year": 2021, "rating": 4.5}, + content="This is a test document.", + ) + ] + document_store.write_documents(docs) + doc = document_store.get_documents_by_id(["1"]) + assert doc[0] == docs[0] + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_fail(self, document_store: AzureAISearchDocumentStore): ... + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_skip(self, document_store: AzureAISearchDocumentStore): ... + + +def _random_embeddings(n): + return [round(random.random(), 7) for _ in range(n)] # nosec: S311 + + +TEST_EMBEDDING_1 = _random_embeddings(768) +TEST_EMBEDDING_2 = _random_embeddings(768) + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"name": str, "page": str, "chapter": str, "number": int, "date": datetime}}, + ], + indirect=True, +) +class TestFilters(FilterDocumentsTest): + + # Overriding to change "date" to compatible ISO 8601 format + # and remove incompatible fields (dataframes) for Azure search index + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """Fixture that returns a list of Documents that can be used to test filtering.""" + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00Z", + }, + embedding=_random_embeddings(768), + ) + ) + + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + + # Overriding to compare the documents with the same order + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. This can happen for example when the Document Store sets + a score to returned Documents. Since we can't know what the score will be, we can't compare + the Documents reliably. + """ + sorted_recieved = sorted(received, key=lambda doc: doc.id) + sorted_expected = sorted(expected, key=lambda doc: doc.id) + assert sorted_recieved == sorted_expected + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): ... + + # Azure search index supports UTC datetime in ISO 8601 format + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + > datetime.strptime("1972-12-11T19:54:58Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + >= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + < datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + <= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + # Override as comparison operators with None/null raise errors + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) + + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) + + def test_comparison_less_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) + + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) + + # Override as Azure AI Search supports 'in' operator only for strings + def test_comparison_in(self, document_store, filterable_docs): + """Test filter_documents() with 'in' comparator""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents({"field": "meta.page", "operator": "in", "value": ["100", "123"]}) + assert len(result) + expected = [d for d in filterable_docs if d.meta.get("page") is not None and d.meta["page"] in ["100", "123"]] + self.assert_documents_are_equal(result, expected) + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): ... + + def test_missing_condition_operator_key(self, document_store, filterable_docs): + """Test filter_documents() with missing operator key""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"conditions": [{"field": "meta.name", "operator": "eq", "value": "test"}]} + ) + + def test_nested_logical_filters(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + {"field": "meta.name", "operator": "==", "value": "name_0"}, + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "!=", "value": 0}, + {"field": "meta.page", "operator": "==", "value": "123"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + {"field": "meta.page", "operator": "==", "value": "90"}, + ], + }, + ], + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + # Ensure all required fields are present in doc.meta + ("name" in doc.meta and doc.meta.get("name") == "name_0") + or ( + all(key in doc.meta for key in ["number", "page"]) + and doc.meta.get("number") != 0 + and doc.meta.get("page") == "123" + ) + or ( + all(key in doc.meta for key in ["page", "chapter"]) + and doc.meta.get("chapter") == "conclusion" + and doc.meta.get("page") == "90" + ) + ) + ], + ) diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py new file mode 100644 index 000000000..d4615ec44 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.run(query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is thrid document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + results = retriever.run(query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._embedding_retrieval(query_embedding=query_embedding)