Skip to content

Commit

Permalink
typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Nov 22, 2024
1 parent 710543a commit 6dbd6f9
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libs/elasticsearch/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def async_es_embeddings_cache_fx(
@pytest.fixture
def es_cache_fx(
es_client_fx: MagicMock,
) -> Generator[AsyncElasticsearchCache, None, None]:
) -> Generator[ElasticsearchCache, None, None]:
with mock.patch(
"langchain_elasticsearch._sync.cache.create_elasticsearch_client",
return_value=es_client_fx,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import uuid
from typing import Iterator
from typing import AsyncIterator

import pytest
from langchain.memory import ConversationBufferMemory
Expand All @@ -24,7 +24,7 @@

class TestElasticsearch:
@pytest.fixture
async def elasticsearch_connection(self) -> Iterator[dict]:
async def elasticsearch_connection(self) -> AsyncIterator[dict]:
params = read_env()
es = create_es_client(params)

Expand Down Expand Up @@ -60,7 +60,7 @@ async def test_memory_with_message_store(
)

# get the message history from the memory store and turn it into a json
messages = await memory.chat_memory.aget_messages()
messages = await memory.chat_memory.aget_messages() # type: ignore[attr-defined]
messages_json = json.dumps([message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
Expand All @@ -69,4 +69,4 @@ async def test_memory_with_message_store(
# remove the record from Elasticsearch, so the next test run won't pick it up
await memory.chat_memory.aclear()

assert await memory.chat_memory.aget_messages() == []
assert await memory.chat_memory.aget_messages() == [] # type: ignore[attr-defined]
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_memory_with_message_store(
)

# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.get_messages()
messages = memory.chat_memory.get_messages() # type: ignore[attr-defined]
messages_json = json.dumps([message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
Expand All @@ -69,4 +69,4 @@ def test_memory_with_message_store(
# remove the record from Elasticsearch, so the next test run won't pick it up
memory.chat_memory.clear()

assert memory.chat_memory.get_messages() == []
assert memory.chat_memory.get_messages() == [] # type: ignore[attr-defined]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test Elasticsearch functionality."""

import re
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, AsyncGenerator, List, Optional
from unittest.mock import AsyncMock

import pytest
Expand Down Expand Up @@ -192,7 +192,7 @@ def embeddings(self) -> Embeddings:
return AsyncConsistentFakeEmbeddings()

@pytest.fixture
async def store(self) -> Generator[AsyncElasticsearchStore, None, None]:
async def store(self) -> AsyncGenerator:
client = AsyncElasticsearch(hosts=["http://dummy:9200"]) # never connected to
store = AsyncElasticsearchStore(index_name="test_index", es_connection=client)
try:
Expand All @@ -203,7 +203,7 @@ async def store(self) -> Generator[AsyncElasticsearchStore, None, None]:
@pytest.fixture
async def hybrid_store(
self, embeddings: Embeddings
) -> Generator[AsyncElasticsearchStore, None, None]:
) -> AsyncGenerator:
client = AsyncElasticsearch(hosts=["http://dummy:9200"]) # never connected to
store = AsyncElasticsearchStore(
index_name="test_index",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def embeddings(self) -> Embeddings:
return ConsistentFakeEmbeddings()

@pytest.fixture
def store(self) -> Generator[ElasticsearchStore, None, None]:
def store(self) -> Generator:
client = Elasticsearch(hosts=["http://dummy:9200"]) # never connected to
store = ElasticsearchStore(index_name="test_index", es_connection=client)
try:
Expand All @@ -203,7 +203,7 @@ def store(self) -> Generator[ElasticsearchStore, None, None]:
@pytest.fixture
def hybrid_store(
self, embeddings: Embeddings
) -> Generator[ElasticsearchStore, None, None]:
) -> Generator:
client = Elasticsearch(hosts=["http://dummy:9200"]) # never connected to
store = ElasticsearchStore(
index_name="test_index",
Expand Down

0 comments on commit 6dbd6f9

Please sign in to comment.