Skip to content

Commit

Permalink
Fix ser/deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Oct 28, 2024
1 parent 54570e3 commit b9563ba
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008
azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False), # noqa: B008
index_name: str = "default",
embedding_dimension: int = 768,
embedding_dimension: Optional[int] = 768,
metadata_fields: Optional[Dict[str, type]] = None,
vector_search_configuration: VectorSearch = None,
create_index: bool = True,
Expand Down Expand Up @@ -104,6 +104,7 @@ def __init__(
if not azure_endpoint:
msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
raise ValueError(msg)

api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY")

self._client = None
Expand All @@ -122,15 +123,16 @@ def __init__(
@property
def client(self) -> SearchClient:

if isinstance(self._azure_endpoint, Secret):
self._azure_endpoint = self._azure_endpoint.resolve_value()
# resolve secrets for authentication
resolved_endpoint = (
self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint
)
resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key

if isinstance(self._api_key, Secret):
self._api_key = self._api_key.resolve_value()
credential = AzureKeyCredential(self._api_key) if self._api_key else DefaultAzureCredential()
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
try:
if not self._index_client:
self._index_client = SearchIndexClient(self._azure_endpoint, credential, **self._kwargs)
self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
if not self.index_exists(self._index_name):
# Create a new index if it does not exist
logger.debug(
Expand Down Expand Up @@ -202,7 +204,7 @@ def to_dict(self) -> Dict[str, Any]:
create_index=self._create_index,
embedding_dimension=self._embedding_dimension,
metadata_fields=self._metadata_fields,
vector_search_configuration=self._vector_search_configuration,
vector_search_configuration=self._vector_search_configuration.as_dict(),
**self._kwargs,
)

Expand All @@ -219,6 +221,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
"""

deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"])
if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None:
data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration)
return default_from_dict(cls, data)

def count_documents(self, **kwargs: Any) -> int:
Expand Down
13 changes: 12 additions & 1 deletion integrations/azure_ai_search/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,18 @@ def test_to_dict(monkeypatch):
"embedding_dimension": 768,
"metadata_fields": None,
"create_index": True,
"vector_search_configuration": DEFAULT_VECTOR_SEARCH,
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"algorithms": [
{
"name": "cosine-algorithm-config",
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
},
}

Expand Down
13 changes: 12 additions & 1 deletion integrations/azure_ai_search/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,18 @@ def test_to_dict():
"create_index": True,
"embedding_dimension": 768,
"metadata_fields": None,
"vector_search_configuration": DEFAULT_VECTOR_SEARCH,
"vector_search_configuration": {
"profiles": [
{"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"}
],
"algorithms": [
{
"name": "cosine-algorithm-config",
"kind": "hnsw",
"parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"},
}
],
},
"hosts": "some fake host",
},
},
Expand Down

0 comments on commit b9563ba

Please sign in to comment.